def updateOutput(self, input):
    """Compute the cosine similarity between every input row and every weight row.

    With ``self.weight`` of shape (outputSize, inputSize) and a 2-D ``input``
    of shape (batchSize, inputSize), fills ``self.output`` (shape
    (batchSize, outputSize)) with::

        y[i, j] = (w_j . x_i) / (||w_j|| * ||x_i||)

    A small epsilon (1e-12) is added to every L2 norm so that all-zero rows
    do not cause a division by zero.

    Side effects: ``self._weightNorm`` is filled with the row norms of
    ``self.weight`` as an (outputSize, 1) column. ``self._inputNorm`` is filled
    with the row norms of ``input`` as a (batchSize, 1) column. Presumably the
    backward pass reuses these buffers (verify against ``updateGradInput``).
    ``self.output`` is resized and overwritten in place.

    Args:
        input: 2-D tensor whose second dimension matches ``self.weight.size(1)``.

    Returns:
        ``self.output`` (the same tensor object, updated in place).

    Raises:
        AssertionError: if ``input`` is not 2-D.
    """
    assert input.dim() == 2, "Cosine expects a 2-D (batch x features) input"
    outputSize = self.weight.size(0)
    batchSize = input.size(0)

    # Lazily allocate the norm buffers. Compare with None explicitly: the
    # `buf or new()` idiom takes a tensor's truth value, which raises as soon
    # as the buffer holds more than one element (i.e. on every later call).
    if self._weightNorm is None:
        self._weightNorm = self.weight.new()
    if self._inputNorm is None:
        self._inputNorm = self.weight.new()

    # ||w_j|| as an (outputSize, 1) column. Resize first, because implicitly
    # resizing a non-empty `out=` tensor is deprecated. keepdim=True keeps the
    # reduced axis that the old result-first torch.norm kept implicitly.
    self._weightNorm.resize_(outputSize, 1)
    torch.norm(self.weight, 2, 1, out=self._weightNorm, keepdim=True).add_(1e-12)

    nelement = self.output.nelement()
    self.output.resize_(batchSize, outputSize)
    if self.output.nelement() != nelement:
        # Freshly resized storage is uninitialised. Presumably it is zeroed so
        # that NaN/Inf garbage cannot leak through beta=0 on older backends;
        # current PyTorch documents that beta=0 ignores the existing contents.
        self.output.zero_()

    # output = 0 * output + 1 * (input @ weight^T)  ->  raw dot products w_j . x_i
    self.output.addmm_(input, self.weight.t(), beta=0, alpha=1)

    # ||x_i|| as a (batchSize, 1) column so it broadcasts across output columns.
    self._inputNorm.resize_(batchSize, 1)
    torch.norm(input, 2, 1, out=self._inputNorm, keepdim=True).add_(1e-12)

    # Divide column j by ||w_j|| and row i by ||x_i||.
    self.output.div_(self._weightNorm.view(1, outputSize).expand_as(self.output))
    self.output.div_(self._inputNorm.expand_as(self.output))
    return self.output
评论列表
文章目录