def backward(self, grad_output):
    """Compute the weight gradient of an embedding (lookup-table) operation.

    The lookup indices come from ``self._indices`` when it is set, otherwise
    from the single tensor stored in ``self.saved_tensors``. The gradient is
    accumulated by ``self._backend.LookupTable_accGradParameters`` into a
    freshly zero-filled, dense tensor of shape ``self._weight_size``.

    Args:
        grad_output: gradient flowing back from the lookup's output.

    Returns:
        ``(None, grad_weight)``. There is no gradient for the first forward
        input (presumably the indices; confirm against ``forward``), and
        ``grad_weight`` is the dense weight gradient. It has the same tensor
        type and device as ``grad_output``.
    """
    if self._indices is not None:
        indices = self._indices
    else:
        indices, = self.saved_tensors

    # 2-D index batches are flattened to 1-D before being handed to the backend.
    if indices.dim() == 2:
        indices = indices.view(-1)

    # Presumably the backend kernel requires contiguous memory -- confirm.
    grad_output = grad_output.contiguous()

    # Choose scratch buffers by device, not by exact type name. Comparing
    # against 'torch.cuda.FloatTensor' sent every other CUDA type (double,
    # half, ...) down the CPU path with CPU/None buffers.
    if grad_output.is_cuda:
        _sorted = torch.cuda.LongTensor()
        _indices = torch.cuda.LongTensor()
        _count = torch.cuda.LongTensor()
    else:
        _count = torch.IntTensor()
        _sorted = _indices = None

    # TODO: sparse updates...
    # ``new`` allocates with grad_output's type AND device. Calling
    # type(grad_output)(...) allocates on the current default device instead.
    grad_weight = grad_output.new(self._weight_size).zero_()

    self._backend.LookupTable_accGradParameters(
        self._backend.library_state,
        indices,
        grad_output,
        grad_weight,
        _count,
        _sorted,
        _indices,
        self.scale_grad_by_freq,
        self.padding_idx,
        1,  # presumably the gradient scale factor -- confirm against backend signature
    )
    return None, grad_weight
评论列表
文章目录