def backward(self, grad_output):
    """Compute the gradient of the embedding lookup w.r.t. the weight.

    Returns ``(None, grad_weight)``: no gradient flows back to the index
    input, only to the embedding weight. When ``self.sparse`` is False the
    gradient is a dense tensor accumulated by the legacy THNN
    ``LookupTable_accGradParameters`` kernel; otherwise it is a sparse
    tensor built directly from the saved indices and ``grad_output``.
    """
    # Indices were either stashed eagerly on the instance in forward()
    # (self._indices) or saved via the autograd saved-tensors mechanism.
    if self._indices is not None:
        indices = self._indices
    else:
        indices, = self.saved_tensors

    # Both the THNN kernel and the sparse-tensor constructor below need
    # contiguous memory.
    grad_output = grad_output.contiguous()
    if not self.sparse:
        if indices.dim() == 2:
            # Flatten a 2-D (presumably batch x seq — set by forward;
            # confirm against caller) index matrix to 1-D for the kernel.
            indices = indices.view(-1)

        # Allocate scratch buffers on the same CUDA device as grad_output.
        with torch.cuda.device_of(grad_output):
            if grad_output.is_cuda:
                # The CUDA kernel uses sort/count scratch space (e.g. for
                # the scale_grad_by_freq path).
                _sorted = torch.cuda.LongTensor()
                _indices = torch.cuda.LongTensor()
                _count = torch.cuda.LongTensor()
            else:
                # CPU kernel only needs a count buffer; no sort scratch.
                _count = torch.IntTensor()
                _sorted = _indices = None

        # Dense gradient buffer shaped like the weight (self._weight_size
        # is set elsewhere; presumably (num_embeddings, embedding_dim)).
        grad_weight = grad_output.new(self._weight_size).zero_()
        # Legacy THNN backend call; argument order is fixed by the C API.
        self._backend.LookupTable_accGradParameters(
            self._backend.library_state,
            indices,
            grad_output,
            grad_weight,
            _count,
            _sorted,
            _indices,
            self.scale_grad_by_freq,
            self.padding_idx,
            1  # scale factor for the accumulation
        )
    else:
        # Sparse gradient: pick the sparse tensor class matching
        # grad_output's value type (e.g. FloatTensor -> sparse.FloatTensor).
        tensor_type = type(grad_output).__name__
        if grad_output.is_cuda:
            SparseTensor = getattr(torch.cuda.sparse, tensor_type)
        else:
            SparseTensor = getattr(torch.sparse, tensor_type)
        # One gradient row per looked-up index; indices become the single
        # sparse dimension, values are rows of width embedding_dim.
        grad_weight = SparseTensor(
            indices.view(1, -1),
            grad_output.view(-1, self._weight_size[1]),
            self._weight_size,
        )
    return None, grad_weight
# NOTE(review): removed stray web-scrape artifacts ("评论列表" / "文章目录",
# i.e. blog "comment list" / "table of contents" text) — not valid Python.