def variable(value, dtype=None, name=None, constraint=None):
    """Create a backend variable that tracks gradients, initialised from ``value``.

    Args:
        value: Initial contents. Accepted forms, unwrapped in this order:
            a backend ``Tensor`` wrapper (its ``.value`` attribute is used),
            a ``torch.autograd.Variable`` (its ``.data`` is used), any object
            whose type name contains ``'torch'`` (converted with ``.numpy()``),
            or anything ``np.asarray`` accepts (NumPy array, scalar, nested
            list).
        dtype: Target dtype for the stored data. Defaults to
            ``keras.backend.floatx()`` when ``None``.
        name: Optional name; passed to ``_prepare_name`` together with the
            string ``'variable'`` before being given to the wrapper.
        constraint: Optional constraint object; stored unchanged on the
            returned variable as its ``constraint`` attribute.

    Returns:
        A ``Variable`` wrapping a ``torch.autograd.Variable`` created with
        ``requires_grad=True``.
    """
    # Peel wrappers until we are left with raw array-like data.
    if isinstance(value, Tensor):
        value = value.value
    if isinstance(value, torch.autograd.Variable):
        value = value.data
    if 'torch' in str(type(value)):
        value = value.numpy()

    name = _prepare_name(name, 'variable')
    if dtype is None:
        dtype = keras.backend.floatx()

    # np.asarray replaces the removed np.cast[...] helper, also accepts plain
    # Python scalars/lists, and does not copy when the dtype already matches.
    value = np.asarray(value, dtype=dtype)

    # NOTE: torch.from_numpy shares memory with the NumPy array it is given.
    torch_tensor = torch.from_numpy(value)
    torch_variable = torch.autograd.Variable(torch_tensor, requires_grad=True)
    ktorch_variable = Variable(torch_variable, name=name)
    # Keep the caller-supplied constraint instead of discarding it.
    ktorch_variable.constraint = constraint
    make_keras_tensor(ktorch_variable)
    return ktorch_variable
评论列表
文章目录