Cosine.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:pytorch-coriander 作者: hughperkins 项目源码 文件源码
def updateOutput(self, input):
        """Compute the cosine similarity between each input row and each weight row.

        For an input row ``x`` and weight row ``w_j`` the result is
        ``y_j = (w_j . x) / (||w_j|| * ||x||)``, with ``1e-12`` added to each
        norm so that an all-zero row does not cause a division by zero.

        Args:
            input: 2-D tensor; its first dimension is treated as the batch
                dimension and its second dimension is multiplied against
                ``self.weight.t()``, so it has to match ``self.weight.size(1)``.

        Returns:
            ``self.output``, resized in place to ``(input.size(0),
            self.weight.size(0))`` and overwritten with the similarities.

        Side effects:
            Lazily allocates ``self._weightNorm`` and ``self._inputNorm`` (as
            tensors of the same type as ``self.weight``) and overwrites them
            with the per-row L2 norms on every call.
        """
        assert input.dim() == 2, "input must be a 2-D tensor"

        outputSize = self.weight.size(0)
        batchSize = input.size(0)

        # Scratch buffers are created once and reused across calls.
        if self._weightNorm is None:
            self._weightNorm = self.weight.new()
        if self._inputNorm is None:
            self._inputNorm = self.weight.new()

        # y_j = (w_j * x) / ( || w_j || * || x || )

        # Per-row L2 norm of the weights, shape (outputSize, 1).
        torch.norm(self.weight, 2, 1, out=self._weightNorm, keepdim=True).add_(1e-12)

        # Numerator: input @ weight^T written straight into the output buffer.
        # torch.mm overwrites every element, so no prior zeroing is needed and
        # the deprecated positional addmm_(beta, alpha, mat1, mat2) overload is
        # avoided.
        self.output.resize_(batchSize, outputSize)
        torch.mm(input, self.weight.t(), out=self.output)

        # Per-row L2 norm of the input, shape (batchSize, 1).
        torch.norm(input, 2, 1, out=self._inputNorm, keepdim=True).add_(1e-12)

        # Divide column j by ||w_j|| and row i by ||x_i||.
        self.output.div_(self._weightNorm.view(1, outputSize).expand_as(self.output))
        self.output.div_(self._inputNorm.expand_as(self.output))
        return self.output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号