LookupTable.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:pytorch-dist 作者: apaszke 项目源码 文件源码
def type(self, type=None, tensorCache=None):
        if not type:
            return self._type
        super(LookupTable, self).type(type, tensorCache)

        if type == 'torch.cuda.FloatTensor':
            # CUDA uses _sorted and _indices temporary tensors
            self._sorted = torch.cuda.LongTensor()
            self._indices = torch.cuda.LongTensor()
            self._count = torch.cuda.LongTensor()
            self._input = torch.cuda.LongTensor()
        else:
            # self._count and self._input should only be converted if using Cuda
            self._count = torch.IntTensor()
            self._input = torch.LongTensor()

        return self
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号