misc.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:ml-utils 作者: LinxiFan 项目源码 文件源码
def to_float_tensor(x, copy=True):
    """
    FloatTensor is the most used torch type, so we create a special method for it
    """
    if torch.is_tensor(x):
        assert isinstance(x, torch.FloatTensor)
        return x
    elif TC.is_variable(x):
        x = TC.to_tensor(x)
        assert isinstance(x, torch.FloatTensor)
        return x
    elif not TC.is_numpy(x):
        x = np.array(x, copy=False)
    x = np_cast(x, np.float32)
    if copy:
        return torch.FloatTensor(x)
    else:
        return torch.from_numpy(x)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号