def forward(self, input1, input2, weight, bias=None):
    """Compute bilinear scores of two inputs against a stack of weight matrices.

    For every slice ``weight[k]`` the result column ``k`` holds, per row ``b``,
    ``sum_j (input1 @ weight[k])[b, j] * input2[b, j]``.

    Args:
        input1: 2-D tensor (it is passed to ``torch.mm``); its first dimension
            sets the number of output rows.
        input2: tensor multiplied element-wise into ``input1 @ weight[k]``.
        weight: tensor iterated along its first dimension; each slice must be
            a 2-D matrix compatible with ``torch.mm(input1, slice)``.
        bias: optional tensor expandable to the output shape; added in place
            to the result when given.

    Returns:
        Tensor of shape ``(input1.size(0), weight.size(0))`` created with
        ``input1.new`` (so it shares ``input1``'s tensor type).
    """
    # All arguments (including a possibly-None bias) are handed to the
    # enclosing object's save_for_backward, exactly as the caller passed them.
    self.save_for_backward(input1, input2, weight, bias)
    output = input1.new(input1.size(0), weight.size(0))
    # Scratch tensor reused by every iteration; torch.mm resizes it on first use.
    buff = input1.new()
    # compute output scores:
    for k, w in enumerate(weight):
        torch.mm(input1, w, out=buff)
        buff.mul_(input2)
        # Reduce first, then copy into column k. Passing the narrowed view as
        # ``out=`` to torch.sum only works when the reduced dimension is kept;
        # when it is dropped the (N,) result no longer matches the (N, 1) view
        # and the column is not filled correctly. Reshaping to a column makes
        # this independent of whether sum() keeps the reduced dimension.
        column = buff.sum(1).view(-1, 1)
        output.narrow(1, k, 1).copy_(column)
    if bias is not None:
        output.add_(bias.expand_as(output))
    return output
# (stray page text, not code) Comment list
# (stray page text, not code) Article table of contents