def updateOutput(self, input):
    """Forward pass: evaluate the bilinear form for every output unit.

    For each index ``k`` over the first dimension of ``self.weight``, column
    ``k`` of ``self.output`` receives the row-wise sum of
    ``(input[0] @ self.weight[k]) * input[1]`` (matrix product, then
    element-wise product). When ``self.bias`` is not None it is added to
    every row afterwards.

    NOTE(review): shapes are not checked here; presumably ``input[0]`` is
    ``(N, in1)``, ``input[1]`` is ``(N, in2)`` and ``self.weight`` is
    ``(out, in1, in2)`` -- confirm against ``_assertInput``.

    Args:
        input: indexable pair of tensors; only ``input[0]`` and ``input[1]``
            are read. Validated first by ``self._assertInput`` (defined
            elsewhere).

    Returns:
        ``self.output``, resized in place to
        ``(input[0].size(0), self.weight.size(0))`` and filled.

    Side effects:
        Lazily allocates ``self.buff2`` via ``input[0].new()`` and reuses it as
        scratch space; its contents after the call are intermediate values.
    """
    self._assertInput(input)
    lhs = input[0]
    rhs = input[1]

    # The scratch buffer is created only once, but it is reshaped to match
    # input[1] on every call so changing batch sizes are handled.
    if self.buff2 is None:
        self.buff2 = lhs.new()
    scratch = self.buff2
    scratch.resize_as_(rhs)

    result = self.output
    result.resize_(lhs.size(0), self.weight.size(0))
    for k, weight_k in enumerate(self.weight):
        # scratch <- (lhs @ W_k), then scaled element-wise by rhs
        torch.mm(lhs, weight_k, out=scratch)
        scratch.mul_(rhs)
        # keepdim=True yields an (N, 1) result that is written straight into
        # column k of the output, avoiding an extra temporary.
        torch.sum(scratch, dim=1, keepdim=True, out=result.narrow(1, k, 1))

    bias = self.bias
    if bias is not None:
        # Lay the bias out as a single row and repeat it down every row.
        bias_row = bias.view(1, bias.nelement())
        result.add_(bias_row.expand_as(result))
    return result
# (scraped web-page residue: "Comment list")
# (scraped web-page residue: "Table of contents")