def forward(self, indices, weight):
    assert indices.dim() <= 2
    assert not self.needs_input_grad[0], "Embedding doesn't " \
        "compute the gradient w.r.t. the indices"

    # Pick the THNN backend matching the weight's tensor type,
    # and remember the weight shape for backward.
    self._backend = type2backend[type(weight)]
    self._weight_size = weight.size()

    # index_select needs contiguous indices; stash whichever version
    # is actually used so backward can retrieve it.
    if not indices.is_contiguous():
        self._indices = indices.contiguous()
        indices = self._indices
    else:
        self.save_for_backward(indices)

    # Optionally renormalize the selected embedding rows in place.
    if self.max_norm is not None:
        self._renorm(indices, weight)

    # The lookup itself is a row-wise index_select on the weight matrix.
    output = weight.new()
    if indices.dim() == 1:
        torch.index_select(weight, 0, indices, out=output)
    else:
        torch.index_select(weight, 0, indices.view(-1), out=output)
        # Restore the (batch, seq_len, embedding_dim) shape for 2-D input.
        output = output.view(indices.size(0), indices.size(1), weight.size(1))
    return output
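
For reference, the lookup at the heart of forward can be reproduced with a plain index_select, independent of the legacy Function machinery. The following is a minimal, self-contained sketch; the weight and indices values are made up for illustration:

import torch

# What forward() computes for 2-D indices: an embedding lookup is a
# row-wise index_select on the weight matrix.
weight = torch.randn(10, 3)               # 10 embeddings, each of dimension 3
indices = torch.tensor([[1, 2], [4, 9]])  # a (batch=2, seq_len=2) index matrix

flat = torch.index_select(weight, 0, indices.view(-1))             # (4, 3)
out = flat.view(indices.size(0), indices.size(1), weight.size(1))  # (2, 2, 3)
print(out.shape)  # torch.Size([2, 2, 3])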