torch_backend.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:ktorch 作者: farizrahman4u 项目源码 文件源码
def constant(value, dtype=None, shape=None, name=None):
    value = np.array(value)
    name = _prepare_name(name, 'constant')
    if dtype is None:
        dtype = keras.backend.floatx()
    if value.dtype != dtype:
        value = np.cast[dtype](value)
    if value.shape == ():
        if shape is None:
            shape = ()
        value = np.ones(shape) * value
    torch_tensor = torch.from_numpy(value)
    torch_variable = torch.autograd.Variable(torch_tensor, requires_grad=False)
    ktorch_variable = Variable(torch_variable, name=name)
    make_keras_tensor(ktorch_variable)
    return ktorch_variable
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号