def updateOutput(self, input):
    """Compute the forward pass restricted to ``self.partition``.

    Runs ``self.network`` on ``[input, self.partition]`` and makes
    ``self.output`` share that result's storage via ``Tensor.set_``.
    When a bias is present, the bias entries selected by ``self.partition``
    along dimension 1 are added in place, broadcast to the output's shape.
    Because of the shared storage, that in-place add is also visible through
    the tensor returned by ``self.network.forward``.

    Args:
        self: module instance; reads ``network``, ``partition``, ``bias`` and
            ``addBuffer``, and writes ``output`` and ``addBuffer``.
        input: input tensor. ``input.size(0)`` sets the length of the ones
            buffer (presumably the batch dimension — confirm with callers).

    Returns:
        ``self.output``, updated in place.
    """
    self.output.set_(self.network.forward([input, self.partition]))

    # Compare against None explicitly: truth-testing a tensor with more than
    # one element (or with none) raises a RuntimeError instead of branching.
    if self.bias is not None:
        # assumes bias is 2-D with output units along dim 1 and partition is
        # an integer index tensor, as index_select requires — TODO confirm
        selected_bias = torch.index_select(self.bias, 1, self.partition)
        self.output.add_(selected_bias.expand_as(self.output))

        # Lazily keep a vector of ones with one entry per row of input.
        # Presumably consumed by the backward pass to accumulate the bias
        # gradient — confirm against accGradParameters.
        if self.addBuffer is None:
            self.addBuffer = input.new()
        if self.addBuffer.nelement() != input.size(0):
            # Only refilled on a size change; assumes nothing else writes
            # into this buffer between calls.
            self.addBuffer.resize_(input.size(0)).fill_(1)

    return self.output
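# Trailing labels left over from the scraped web page, not part of the code:
# "评论列表" ("Comment list"), "文章目录" ("Table of contents").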