tensorport.py source code

python

Project: jack  Author: uclmr
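```python
import numpy as np
import tensorflow as tf
import torch


def create_torch_variable(self, value, gpu=False):
    """Convenience method that converts `value` to a torch Variable of this port's dtype.

    Returns: a torch Variable of the same type. Empty inputs, and inputs whose
    dtype is not in the supported numeric set, are returned as numpy arrays.
    """
    # Already a Variable: only move it to the GPU if requested.
    if isinstance(value, torch.autograd.Variable):
        if gpu:
            value = value.cuda()
        return value
    if not torch.is_tensor(value):
        # Lists, scalars and numpy arrays are first cast to the port's numpy dtype.
        if not isinstance(value, np.ndarray):
            value = np.array(value, dtype=self.dtype.as_numpy_dtype)
        else:
            value = value.astype(self.dtype.as_numpy_dtype)
        # Empty arrays are returned unchanged, as numpy arrays.
        if value.size == 0:
            return value
        # Only numeric dtypes with a torch counterpart are wrapped.
        allowed = [tf.int16, tf.int32, tf.int64, tf.float16, tf.float32, tf.float64, tf.int8]
        if self.dtype in allowed:
            value = torch.autograd.Variable(torch.from_numpy(value))
    else:
        # A raw torch tensor is wrapped directly.
        value = torch.autograd.Variable(value)
    # Values left as numpy arrays (unsupported dtypes) are never moved to the GPU.
    if gpu and isinstance(value, torch.autograd.Variable):
        value = value.cuda()
    return value
```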
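The only part of the method that does not depend on torch is the numpy cast and the two early exits (empty input, unsupported dtype). The following is a minimal numpy-only sketch of that decision, assuming the numpy dtypes listed stand in for the tf dtypes in `allowed`; `prepare_numpy_value` and `TORCH_CONVERTIBLE` are hypothetical names, not part of jack.

```python
import numpy as np

# Assumption: numpy counterparts of the tf dtypes in the method's `allowed` list.
TORCH_CONVERTIBLE = {
    np.dtype(t)
    for t in (np.int8, np.int16, np.int32, np.int64,
              np.float16, np.float32, np.float64)
}


def prepare_numpy_value(value, np_dtype):
    """Hypothetical helper mirroring the non-tensor branch of create_torch_variable.

    Returns (array, should_wrap): the value cast to `np_dtype`, and whether the
    original method would go on to call torch.from_numpy on it.
    """
    if isinstance(value, np.ndarray):
        arr = value.astype(np_dtype)
    else:
        arr = np.array(value, dtype=np_dtype)
    if arr.size == 0:
        # Empty input: the method returns the numpy array unchanged.
        return arr, False
    # Non-numeric dtypes (e.g. bool, strings) also stay as numpy arrays.
    return arr, arr.dtype in TORCH_CONVERTIBLE


arr, wrap = prepare_numpy_value([1, 2, 3], np.int64)
print(arr.dtype, wrap)
```

Keeping the dtype check separate from the torch call makes the cases where the original method hands back a numpy array instead of a Variable easy to test without torch or a GPU.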