def updateOutput(self, input):
    """Compute the bilinear forward pass and store it in ``self.output``.

    For every slice ``k`` of ``self.weight``, column ``k`` of the output is
    the row-wise sum of ``(input[0] @ self.weight[k]) * input[1]``. When a
    bias is present it is then added to every row of the output.

    Args:
        input: Indexable pair of 2-D tensors. ``input[0]`` is
            matrix-multiplied by each ``self.weight[k]`` and the product is
            combined element-wise with ``input[1]``.

    Returns:
        ``self.output``, resized in place to
        ``(input[0].size(0), self.weight.size(0))``.
    """
    # Input checking is delegated to a helper defined outside this block.
    self._assertInput(input)

    # Set up the scratch buffer. Compare against None explicitly: the truth
    # value of a tensor holding more than one element is ambiguous and
    # raises, which would break every call after the first one.
    if self.buff2 is None:
        self.buff2 = input[0].new()
    self.buff2.resize_as_(input[1])

    # Compute output scores, one output column per weight slice.
    num_outputs = self.weight.size(0)
    self.output.resize_(input[0].size(0), num_outputs)
    for k in range(num_outputs):
        # buff2 <- input[0] @ weight[k], written into the reused buffer.
        torch.mm(input[0], self.weight[k], out=self.buff2)
        self.buff2.mul_(input[1])
        # keepdim=True keeps the reduced result shaped (batch, 1) so that it
        # matches the narrowed view of column k it is written into.
        torch.sum(self.buff2, 1, keepdim=True,
                  out=self.output.narrow(1, k, 1))

    # Same None-vs-truthiness concern as for the buffer above.
    if self.bias is not None:
        self.output.add_(
            self.bias.view(1, self.bias.nelement()).expand_as(self.output)
        )
    return self.output
评论列表
文章目录