def forward(self, input, indices=None):
    """Score ``input`` against all output rows, or only against selected rows.

    When ``indices`` is None this delegates to the parent class's ``forward``
    (``super(IndexLinear, self).forward``). Otherwise only the rows of
    ``self.weight`` and entries of ``self.bias`` named by ``indices`` are used,
    so each sample is scored against its own small candidate set instead of
    the full output layer.

    Args:
        input: tensor of shape ``(N, E)`` (``E`` = embedding size); the
            ``unsqueeze(1)`` + batched matmul below require this 2-D shape.
        indices: optional index tensor of shape ``(N, K)`` selecting, for each
            of the ``N`` samples, ``K`` rows of ``self.weight`` / ``self.bias``.
            Per the original note ``K = 1 + N_r`` (``N_r`` = noise ratio),
            presumably one target plus ``N_r`` noise samples -- confirm with
            the caller.

    Returns:
        The parent's output when ``indices`` is None. Otherwise a tensor of
        shape ``(N, K)`` with
        ``out[n, k] = input[n] . weight[indices[n, k]] + bias[indices[n, k]]``
        (the bias term is omitted when ``self.bias`` is None).
    """
    if indices is None:
        return super(IndexLinear, self).forward(input)

    batch_size, num_candidates = indices.size(0), indices.size(1)
    embed_dim = self.weight.size(1)
    # reshape (not view) so non-contiguous index tensors are accepted too.
    flat_indices = indices.reshape(-1)

    # Rows are gathered with index_select instead of [] indexing: the original
    # author noted that backprop through [] did not work correctly in the
    # PyTorch version this code was written for.
    # (N, E) -> (N, 1, E): one row vector per sample for the batched matmul.
    input = input.unsqueeze(1)
    # (N*K, E) -> (N, K, E) -> (N, E, K); explicit E keeps empty batches valid.
    target_batch = (
        self.weight.index_select(0, flat_indices)
        .view(batch_size, num_candidates, embed_dim)
        .transpose(1, 2)
    )

    if self.bias is None:
        # A layer built without bias has nothing to gather or add.
        out = torch.bmm(input, target_batch)
    else:
        # (N, 1, K), added to the (N, 1, E) @ (N, E, K) product.
        bias = self.bias.index_select(0, flat_indices).view(
            batch_size, 1, num_candidates
        )
        # Keyword form (beta = alpha = 1); the positional
        # baddbmm(beta, input, alpha, batch1, batch2) overload is deprecated.
        out = torch.baddbmm(bias, input, target_batch)

    # Drop only the singleton matmul dim. A bare squeeze() would also drop the
    # batch dim when N == 1 (or the candidate dim when K == 1).
    return out.squeeze(1)
# Web-page residue from the original source (headings "Comment list" and "Table of contents"); not code.