def type(self, type=None, tensorCache=None):
    """Return the current tensor type, or convert the module to ``type``.

    With no (or a falsy) ``type`` this acts as a getter and returns
    ``self._type``. Otherwise the parent class's ``type`` performs the
    conversion, and then the index scratch buffers are re-created as fresh,
    empty tensors on the matching device:

    * ``'torch.cuda.FloatTensor'``: ``_sorted``, ``_indices``, ``_count`` and
      ``_input`` all become empty ``torch.cuda.LongTensor`` objects.
    * any other type: ``_count`` becomes an empty ``torch.IntTensor`` and
      ``_input`` an empty ``torch.LongTensor``; ``_sorted`` and ``_indices``
      are not touched.

    Args:
        type: target tensor type name, e.g. ``'torch.cuda.FloatTensor'``.
            A falsy value selects getter behaviour.
        tensorCache: forwarded unchanged to the parent ``type`` call.

    Returns:
        ``self._type`` in getter mode, otherwise ``self`` (allows chaining).
    """
    # Getter mode: nothing to convert.
    if not type:
        return self._type

    super(LookupTable, self).type(type, tensorCache)

    # NOTE(review): only this exact string selects the GPU buffers; other
    # CUDA types (e.g. 'torch.cuda.DoubleTensor') take the CPU path below --
    # confirm whether that is intended.
    if type != 'torch.cuda.FloatTensor':
        # CPU path: only _count and _input are re-created.
        # NOTE(review): _count is an IntTensor here but a LongTensor on CUDA;
        # presumably each matches its backend kernel -- confirm.
        self._count = torch.IntTensor()
        self._input = torch.LongTensor()
        return self

    # CUDA path additionally needs the _sorted/_indices scratch tensors;
    # every index buffer is an empty torch.cuda.LongTensor.
    for buffer_name in ('_sorted', '_indices', '_count', '_input'):
        setattr(self, buffer_name, torch.cuda.LongTensor())
    return self
# (stray web-page navigation label: "Comment list")
# (stray web-page navigation label: "Table of contents")