def backward(self, grad_output):
    if self._indices is not None:
        indices = self._indices
    else:
        indices, = self.saved_tensors

    grad_output = grad_output.contiguous()
    if not self.sparse:
        if indices.dim() == 2:
            indices = indices.view(-1)

        # Allocate the scratch buffers the THNN kernel expects: the CUDA
        # path needs sort/count workspaces, the CPU path only a count buffer.
        with torch.cuda.device_of(grad_output):
            if grad_output.is_cuda:
                _sorted = torch.cuda.LongTensor()
                _indices = torch.cuda.LongTensor()
                _count = torch.cuda.LongTensor()
            else:
                _count = torch.IntTensor()
                _sorted = _indices = None

        # Dense gradient: scatter-add each grad_output row into the row of
        # grad_weight selected by the corresponding index.
        grad_weight = grad_output.new(self._weight_size).zero_()
        self._backend.LookupTable_accGradParameters(
            self._backend.library_state,
            indices,
            grad_output,
            grad_weight,
            _count,
            _sorted,
            _indices,
            self.scale_grad_by_freq,
            self.padding_idx,
            1
        )
    else:
        # Sparse gradient: one COO entry per lookup, built with the pre-0.4
        # torch.sparse.*Tensor / torch.cuda.sparse.*Tensor classes.
        tensor_type = type(grad_output).__name__
        if grad_output.is_cuda:
            SparseTensor = getattr(torch.cuda.sparse, tensor_type)
        else:
            SparseTensor = getattr(torch.sparse, tensor_type)
        grad_weight = SparseTensor(
            indices.view(1, -1),
            grad_output.view(-1, self._weight_size[1]),
            self._weight_size,
        )
    return None, grad_weight
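For readers without the legacy THNN backend, the dense branch above amounts to a scatter-add of grad_output rows into grad_weight. Below is a minimal sketch of that accumulation using only public PyTorch ops; dense_embedding_grad and its arguments are illustrative names, and scale_grad_by_freq is omitted.

import torch

def dense_embedding_grad(indices, grad_output, weight_size, padding_idx=-1):
    # Accumulate each incoming gradient row into the grad_weight row
    # selected by its index (what LookupTable_accGradParameters computes).
    grad_weight = grad_output.new_zeros(weight_size)
    flat_indices = indices.contiguous().view(-1)
    flat_grad = grad_output.contiguous().view(-1, weight_size[1])
    grad_weight.index_add_(0, flat_indices, flat_grad)
    if padding_idx >= 0:
        grad_weight[padding_idx].zero_()  # the padding row gets no gradient
    return grad_weight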
def backward(ctx, grad_output):
    if ctx._indices is not None:
        indices = ctx._indices
    else:
        indices, = ctx.saved_tensors

    grad_output = grad_output.contiguous()
    if not ctx.sparse:
        if indices.dim() == 2:
            indices = indices.view(-1)

        with torch.cuda.device_of(grad_output):
            if grad_output.is_cuda:
                _sorted = torch.cuda.LongTensor()
                _indices = torch.cuda.LongTensor()
                _count = torch.cuda.LongTensor()
            else:
                _count = torch.IntTensor()
                _sorted = _indices = None

        grad_weight = grad_output.new(ctx._weight_size).zero_()
        # Doesn't support Variable grad_output
        ctx._backend.LookupTable_accGradParameters(
            ctx._backend.library_state,
            indices,
            grad_output,
            grad_weight,
            _count,
            _sorted,
            _indices,
            ctx.scale_grad_by_freq,
            ctx.padding_idx,
            1
        )
    else:
        tensor_type = type(grad_output).__name__
        if grad_output.is_cuda:
            SparseTensor = getattr(torch.cuda.sparse, tensor_type)
        else:
            SparseTensor = getattr(torch.sparse, tensor_type)
        grad_weight = SparseTensor(
            indices.view(1, -1),
            grad_output.view(-1, ctx._weight_size[1]),
            ctx._weight_size,
        )
    return None, grad_weight, None, None, None, None, None
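In the sparse branch the gradient is a COO tensor whose i-th value row is the gradient for the i-th looked-up embedding; duplicate indices are left uncoalesced and are summed on coalesce(). A rough modern-API equivalent is sketched below (torch.sparse_coo_tensor replaces the pre-0.4 torch.sparse.*Tensor constructors used above; the function name is illustrative).

import torch

def sparse_embedding_grad(indices, grad_output, weight_size):
    # One sparse entry per lookup: coordinate = embedding row,
    # value = the corresponding (dim,)-sized gradient row.
    i = indices.contiguous().view(1, -1)
    v = grad_output.contiguous().view(-1, weight_size[1])
    return torch.sparse_coo_tensor(i, v, weight_size)

Optimizers with sparse support, such as torch.optim.SparseAdam, can consume such a gradient directly without densifying the full weight matrix.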
def sparse_getitem(sparse, idxs):
    if not isinstance(idxs, tuple):
        idxs = idxs,

    if not sparse.ndimension() <= 2:
        raise RuntimeError('Must be a 1d or 2d sparse tensor')

    if len(idxs) > sparse.ndimension():
        raise RuntimeError('Invalid index for %d-order tensor' % sparse.ndimension())

    indices = sparse._indices()
    values = sparse._values()
    size = list(sparse.size())

    # Process the indices right-to-left so deleting a dimension does not
    # shift the positions of dimensions still to be processed.
    for i, idx in list(enumerate(idxs))[::-1]:
        if isinstance(idx, int):
            # Integer index: keep only entries whose coordinate along dim i
            # equals idx, then drop that row of the indices matrix.
            del size[i]
            mask = indices[i].eq(idx)
            if sum(mask):
                new_indices = indices.new().resize_(indices.size(0) - 1, sum(mask)).zero_()
                for j in range(indices.size(0)):
                    if i > j:
                        new_indices[j].copy_(indices[j][mask])
                    elif i < j:
                        new_indices[j - 1].copy_(indices[j][mask])
                indices = new_indices
                values = values[mask]
            else:
                # No matching entries: fall back to a single explicit zero.
                indices.resize_(indices.size(0) - 1, 1).zero_()
                values.resize_(1).zero_()

            if not len(size):
                return sum(values)
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(size[i])
            size = list(size[:i]) + [stop - start] + list(size[i + 1:])
            if step != 1:
                raise RuntimeError('Slicing with step is not supported')
            # Keep entries whose coordinate falls in [start, stop) and
            # shift them so the slice starts at zero.
            mask = indices[i].lt(stop) * indices[i].ge(start)
            if sum(mask):
                new_indices = indices.new().resize_(indices.size(0), sum(mask)).zero_()
                for j in range(indices.size(0)):
                    new_indices[j].copy_(indices[j][mask])
                new_indices[i].sub_(start)
                indices = new_indices
                values = values[mask]
            else:
                indices.resize_(indices.size(0), 1).zero_()
                values.resize_(1).zero_()
        else:
            raise RuntimeError('Unknown index type')

    return sparse.__class__(indices, values, torch.Size(size))
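A quick sanity check of the helper under the pre-0.4 sparse API it targets, where sparse tensors are instances of torch.sparse.FloatTensor and the sparse.__class__ call on the last line is therefore a real constructor:

import torch

i = torch.LongTensor([[0, 1, 1],
                      [2, 0, 2]])
v = torch.FloatTensor([3, 4, 5])
s = torch.sparse.FloatTensor(i, v, torch.Size([2, 3]))

row = sparse_getitem(s, 1)                               # int index drops a dimension
block = sparse_getitem(s, (slice(0, 2), slice(0, 2)))    # only step-1 slices allowed

Note that an integer index on a 1-d sparse tensor hits the `if not len(size)` branch and returns sum(values), a scalar rather than a sparse tensor.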
def backward(ctx, grad_output):
    if ctx._indices is not None:
        indices = ctx._indices
    else:
        indices, = ctx.saved_tensors

    grad_output = grad_output.contiguous()
    if not ctx.sparse:
        if indices.dim() == 2:
            indices = indices.view(-1)

        with torch.cuda.device_of(grad_output):
            if grad_output.is_cuda:
                _sorted = torch.cuda.LongTensor()
                _indices = torch.cuda.LongTensor()
                _count = torch.cuda.LongTensor()
            else:
                _count = torch.IntTensor()
                _sorted = _indices = None

        grad_weight = grad_output.new(ctx._weight_size).zero_()
        # Doesn't support Variable grad_output
        ctx._backend.LookupTable_accGradParameters(
            ctx._backend.library_state,
            indices,
            grad_output,
            grad_weight,
            _count,
            _sorted,
            _indices,
            ctx.scale_grad_by_freq,
            ctx.padding_idx,
            1
        )
    else:
        tensor_type = type(grad_output).__name__
        if grad_output.is_cuda:
            SparseTensor = getattr(torch.cuda.sparse, tensor_type)
        else:
            SparseTensor = getattr(torch.sparse, tensor_type)
        padding_idx = ctx.padding_idx
        indices = indices.view(1, -1)
        grad_output = grad_output.view(-1, ctx._weight_size[1])
        if padding_idx is not None:
            # Drop the lookups that hit the padding index so the padding
            # row of the weight matrix never accumulates gradient.
            nonpadding_indices_indices = (indices.view(-1) != padding_idx).nonzero().view(-1)
            indices = indices.index_select(1, nonpadding_indices_indices)
            grad_output = grad_output.index_select(0, nonpadding_indices_indices)
        grad_weight = SparseTensor(
            indices,
            grad_output,
            ctx._weight_size,
        )
    return None, grad_weight, None, None, None, None, None
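The only change from the previous variant is in the sparse branch: lookups equal to padding_idx are filtered out before the sparse gradient is assembled. The filtering step in isolation looks like the sketch below (drop_padding_rows is a hypothetical standalone name):

import torch

def drop_padding_rows(indices, grad_output, padding_idx):
    # Positions of the lookups that are NOT the padding index.
    keep = (indices.view(-1) != padding_idx).nonzero().view(-1)
    # Select the surviving columns of indices and rows of grad_output.
    return (indices.view(1, -1).index_select(1, keep),
            grad_output.index_select(0, keep))

For example, with indices = torch.LongTensor([2, 0, 5, 0]) and padding_idx = 0, keep is [0, 2], leaving only the gradient rows for embeddings 2 and 5.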