torch_backend.py source code

python

Project: ktorch    Author: farizrahman4u
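import keras
import numpy as np
import torch

# Tensor, Variable, _prepare_name and make_keras_tensor are defined elsewhere
# in ktorch's torch_backend.py and are not imported here.


def variable(value, dtype=None, name=None, constraint=None):
    # Unwrap ktorch Tensors and autograd Variables down to a raw torch tensor.
    if isinstance(value, Tensor):
        value = value.value
    if isinstance(value, torch.autograd.Variable):
        value = value.data
    # Convert torch tensors to NumPy; lists and scalars also become arrays.
    if 'torch' in str(type(value)):
        value = value.numpy()
    value = np.asarray(value)
    name = _prepare_name(name, 'variable')
    if dtype is None:
        dtype = keras.backend.floatx()
    # Cast to the requested dtype (np.cast is deprecated and removed in NumPy 2.0).
    if value.dtype != dtype:
        value = value.astype(dtype)
    torch_tensor = torch.from_numpy(value)
    torch_variable = torch.autograd.Variable(torch_tensor, requires_grad=True)
    ktorch_variable = Variable(torch_variable, name=name)
    # Record the caller's constraint; Keras optimizers apply it after each update.
    ktorch_variable.constraint = constraint
    make_keras_tensor(ktorch_variable)
    return ktorch_variable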
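In current PyTorch releases, `torch.autograd.Variable` has been merged into `torch.Tensor`, so the same normalize, cast and wrap steps can be written without it. The following is a minimal sketch under that assumption; `to_trainable_tensor` is a hypothetical helper, not part of ktorch or PyTorch, and it skips the ktorch-specific `Tensor` wrapper and Keras bookkeeping.

```python
import numpy as np
import torch


def to_trainable_tensor(value, dtype="float32"):
    """Hypothetical helper: turn `value` into a leaf tensor that tracks gradients.

    `dtype` must be a floating-point type, because only floating-point
    tensors can require gradients.
    """
    if isinstance(value, torch.Tensor):
        # Drop any autograd history and move to host memory before converting.
        value = value.detach().cpu().numpy()
    # np.array copies and casts in one step; it accepts scalars, lists and arrays.
    array = np.array(value, dtype=dtype)
    # from_numpy shares memory with `array`; requires_grad_ marks it as trainable.
    return torch.from_numpy(array).requires_grad_(True)
```

Copying with `np.array` (rather than `np.asarray`) means later in-place updates to the tensor cannot silently modify the caller's original array.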