@classmethod
def forward(cls, ctx, indices, weight, padding_idx, max_norm, norm_type,
            scale_grad_by_freq, sparse=False):
    # Stash the non-tensor arguments on the autograd context for backward().
    ctx.padding_idx = padding_idx
    ctx.scale_grad_by_freq = scale_grad_by_freq
    ctx._indices = None
    ctx.sparse = sparse

    # Indices may be a 1-D sequence or a 2-D (batch, seq_len) tensor;
    # gradients flow only into the weight, never into the indices.
    assert indices.dim() <= 2
    assert not ctx.needs_input_grad[0], "Embedding doesn't " \
        "compute the gradient w.r.t. the indices"

    # Pick the THNN backend matching the weight's tensor type, and remember
    # the weight shape so backward() can build the gradient tensor.
    ctx._backend = type2backend[type(weight)]
    ctx._weight_size = weight.size()

    # backward() needs contiguous indices: either keep a contiguous copy
    # directly on ctx, or save the original tensor via the autograd machinery.
    if not indices.is_contiguous():
        ctx._indices = indices.contiguous()
        indices = ctx._indices
    else:
        ctx.save_for_backward(indices)

    output = weight.new()
    # Optionally renormalize (in place) any rows of weight whose norm
    # exceeds max_norm before the lookup.
    if max_norm is not None:
        cls._renorm(ctx, indices, weight, max_norm, norm_type)

    # The lookup itself is a row gather from the weight matrix; for 2-D
    # indices, flatten, gather, then restore the (batch, seq, dim) shape.
    if indices.dim() == 1:
        output = torch.index_select(weight, 0, indices)
    else:
        output = torch.index_select(weight, 0, indices.view(-1))
        output = output.view(indices.size(0), indices.size(1), weight.size(1))

    return output
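For reference, the gather that this forward performs can be reproduced with public torch ops alone. Below is a minimal sketch, assuming a modern PyTorch and a made-up 10x4 weight table with hypothetical indices, and omitting all of the ctx/backend bookkeeping above:

import torch

weight = torch.randn(10, 4)                     # embedding table: 10 rows of dim 4
indices = torch.tensor([[1, 3, 7], [0, 2, 9]])  # hypothetical 2-D (batch, seq) lookup

# 1-D indices would be a plain row gather; for 2-D, flatten, gather, reshape:
flat = torch.index_select(weight, 0, indices.view(-1))
output = flat.view(indices.size(0), indices.size(1), weight.size(1))

print(output.shape)  # torch.Size([2, 3, 4])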