def type(self, type, tensorCache=None):
    """Convert this module to tensor type ``type``, keeping ``self._indices`` integral.

    The indices buffer is detached before delegating to the parent class's
    ``type`` and converted separately afterwards. Per the original author's
    note, detaching it avoids an unnecessary memory allocation, presumably
    because the parent would otherwise cast it to ``type`` as well.

    Args:
        type: Target tensor type name, e.g. ``'torch.cuda.FloatTensor'``.
        tensorCache: Forwarded unchanged to the parent ``type`` implementation.

    Returns:
        ``self``, to allow call chaining.

    Raises:
        Whatever the parent ``type`` or the index conversion raises. If the
        parent conversion fails, the original ``self._indices`` is restored
        rather than lost.
    """
    # Detach the indices so the parent conversion does not touch them.
    indices, self._indices = self._indices, None
    try:
        super(Min, self).type(type, tensorCache)
    finally:
        # Always reattach, so a failed parent conversion does not drop the buffer.
        self._indices = indices

    if indices is not None:
        if type == 'torch.cuda.FloatTensor':
            # Indices must stay integral; on the CUDA float path use a CUDA long tensor.
            self._indices = indices.type('torch.cuda.LongTensor')
        else:
            # NOTE(review): other CUDA target types (e.g. 'torch.cuda.DoubleTensor')
            # also land here and just get ``.long()``, which keeps the indices on
            # their current device -- confirm this is intended.
            self._indices = indices.long()
    return self
评论列表
文章目录