# backward() of the legacy Embedding autograd Function (ctx carries
# state saved in forward()).
@staticmethod
def backward(ctx, grad_output):
    # Recover the input indices, either cached directly on ctx or
    # stashed via save_for_backward in forward().
    if ctx._indices is not None:
        indices = ctx._indices
    else:
        indices, = ctx.saved_tensors

    grad_output = grad_output.contiguous()
    if not ctx.sparse:
        # Dense path: flatten a 2-D index matrix to 1-D.
        if indices.dim() == 2:
            indices = indices.view(-1)

        # The CUDA kernel needs scratch buffers for sorting indices and
        # counting duplicates; the CPU kernel only needs the count buffer.
        with torch.cuda.device_of(grad_output):
            if grad_output.is_cuda:
                _sorted = torch.cuda.LongTensor()
                _indices = torch.cuda.LongTensor()
                _count = torch.cuda.LongTensor()
            else:
                _count = torch.IntTensor()
                _sorted = _indices = None

        # Accumulate each row of grad_output into the matching row of a
        # zero-initialized dense weight gradient.
        grad_weight = grad_output.new(ctx._weight_size).zero_()
        # Doesn't support Variable grad_output
        ctx._backend.LookupTable_accGradParameters(
            ctx._backend.library_state,
            indices,
            grad_output,
            grad_weight,
            _count,
            _sorted,
            _indices,
            ctx.scale_grad_by_freq,
            ctx.padding_idx,
            1
        )
    else:
        # Sparse path: build a sparse COO gradient holding only the rows
        # that were actually looked up, instead of a full dense matrix.
        tensor_type = type(grad_output).__name__
        if grad_output.is_cuda:
            SparseTensor = getattr(torch.cuda.sparse, tensor_type)
        else:
            SparseTensor = getattr(torch.sparse, tensor_type)
        grad_weight = SparseTensor(
            indices.view(1, -1),
            grad_output.view(-1, ctx._weight_size[1]),
            ctx._weight_size,
        )
    # One gradient slot per forward() argument; only the weight gets one.
    return None, grad_weight, None, None, None, None, None
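For context, here is a minimal sketch of how the two gradient paths above surface to users, written against the modern public `torch.nn.Embedding` API rather than the legacy `_backend` internals; it assumes a current PyTorch install:

import torch
import torch.nn as nn

idx = torch.tensor([[1, 2], [4, 1]])

# Dense path: grad is a regular tensor, zero everywhere except at the
# looked-up rows (what LookupTable_accGradParameters fills in above).
emb_dense = nn.Embedding(10, 3, sparse=False)
emb_dense(idx).sum().backward()
print(emb_dense.weight.grad.is_sparse)   # False

# Sparse path: grad is a sparse COO tensor holding only (index, grad_row)
# pairs, mirroring the SparseTensor branch above.
emb_sparse = nn.Embedding(10, 3, sparse=True)
emb_sparse(idx).sum().backward()
print(emb_sparse.weight.grad.is_sparse)  # True

The sparse form is the motivation for this backward: with a large vocabulary and a small batch, only a handful of embedding rows receive nonzero gradient, so materializing the full dense weight-sized gradient is wasted work.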