Euclidean.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:pytorch-dist 作者: apaszke 项目源码 文件源码
def updateOutput(self, input):
    """Compute the distance of every input row to every weight column.

    Following the in-code formula ``y_j = || w_j - x ||``: for batch row
    ``b`` and output unit ``j`` the result is the norm (p=2, reduced over
    the input dimension) of ``input[b] - weight[:, j]``.

    Args:
        input: 2-D tensor whose first dimension is the batch size. Its
            second dimension is presumably ``self.weight.size(0)``; that
            is the size it is viewed with below.

    Returns:
        ``self.output``, resized to ``(batchSize, outputSize)``.

    Raises:
        AssertionError: if ``input`` is not 2-D.
    """
    # Lazily create the scratch buffers. Compare against None explicitly:
    # ``buf or x.new()`` asks for the truth value of a tensor, which is
    # ambiguous for a tensor holding more than one element, so a populated
    # buffer could make a later call raise instead of being reused.
    if self._input is None:
        self._input = input.new()
    if self._weight is None:
        self._weight = self.weight.new()
    if self._expand is None:
        self._expand = self.output.new()
    if self._expand2 is None:
        self._expand2 = self.output.new()
    if self._repeat is None:
        self._repeat = self.output.new()
    if self._repeat2 is None:
        self._repeat2 = self.output.new()

    inputSize, outputSize = self.weight.size(0), self.weight.size(1)

    # y_j = || w_j - x || = || x - w_j ||
    assert input.dim() == 2

    batchSize = input.size(0)
    # ``_view`` is defined elsewhere in the class; presumably it makes
    # ``self._input`` a (batchSize, inputSize, 1) view of ``input``.
    self._view(self._input, input, batchSize, inputSize, 1)
    self._expand = self._input.expand(batchSize, inputSize, outputSize)
    # make the expanded tensor contiguous (requires lots of memory)
    self._repeat.resize_as_(self._expand).copy_(self._expand)

    # Broadcast the weights over the batch dimension so they line up
    # element-for-element with the repeated input.
    self._weight = self.weight.view(1, inputSize, outputSize)
    self._expand2 = self._weight.expand_as(self._repeat)

    # _repeat <- x - w, computed in place.
    if torch.typename(input) == 'torch.cuda.FloatTensor':
        # TODO: after adding new allocators this can be changed
        # requires lots of memory, but minimizes cudaMallocs and loops
        self._repeat2.resize_as_(self._expand2).copy_(self._expand2)
        self._repeat.add_(-1, self._repeat2)
    else:
        self._repeat.add_(-1, self._expand2)

    # NOTE(review): this is the output-first call form
    # ``torch.norm(result, source, p, dim)``. Newer torch releases expect
    # ``torch.norm(source, p, dim, keepdim, out=result)`` - confirm against
    # the torch version this module targets before changing it.
    torch.norm(self.output, self._repeat, 2, 1)
    # Set the output shape to (batchSize, outputSize).
    self.output.resize_(batchSize, outputSize)

    return self.output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号