def updateOutput(self, input):
    """Compute the cosine similarity of every input row with every weight row.

    For each row ``x`` of ``input`` and each row ``w_j`` of ``self.weight``::

        y_j = (w_j . x) / (||w_j|| * ||x||)

    The result is written in place into ``self.output``; the row norms are
    cached in ``self._weightNorm`` and ``self._inputNorm`` (both allocated
    lazily with ``self.weight.new()`` when they are still ``None``).

    Args:
        input: 2-D tensor whose second dimension must match
            ``self.weight.size(1)`` for the matrix product to succeed.

    Returns:
        ``self.output``, resized to ``(input.size(0), self.weight.size(0))``.

    Raises:
        AssertionError: if ``input`` is not 2-D.
    """
    assert input.dim() == 2, "updateOutput expects a 2-D (batch, features) input"

    outputSize = self.weight.size(0)
    batchSize = input.size(0)

    if self._weightNorm is None:
        self._weightNorm = self.weight.new()
    if self._inputNorm is None:
        self._inputNorm = self.weight.new()

    # Pre-size the cached buffers so that ``out=`` below never has to resize
    # a non-empty tensor (e.g. when the batch size changes between calls).
    self._weightNorm.resize_(outputSize, 1)
    self._inputNorm.resize_(batchSize, 1)

    # Row-wise L2 norms, kept as column vectors. The 1e-12 offset keeps the
    # divisions below finite when a row is all zeros.
    torch.norm(self.weight, 2, 1, out=self._weightNorm, keepdim=True).add_(1e-12)
    torch.norm(input, 2, 1, out=self._inputNorm, keepdim=True).add_(1e-12)

    # Numerator: w_j . x for every (row of input, row of weight) pair.
    # torch.mm overwrites the buffer, so no zeroing of stale contents is
    # needed; this replaces the deprecated positional
    # addmm_(beta, alpha, mat1, mat2) overload.
    self.output.resize_(batchSize, outputSize)
    torch.mm(input, self.weight.t(), out=self.output)

    # Denominator: divide column j by ||w_j|| and row i by ||x_i||.
    self.output.div_(self._weightNorm.view(1, outputSize).expand_as(self.output))
    self.output.div_(self._inputNorm.expand_as(self.output))
    return self.output
评论列表
文章目录