WeightedEuclidean.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:pytorch-dist 作者: apaszke 项目源码 文件源码
def updateOutput(self, input):
    """Compute the weighted Euclidean distance of ``input`` to every centre.

    Implements ``y_j = || c_j * (w_j - x) ||`` where the centres ``w_j`` come
    from ``self.weight`` and the scaling factors ``c_j`` from
    ``self.diagCov``. ``self.weight`` is read as
    ``(inputSize, outputSize)``.

    Args:
        input: a 1D tensor (viewed as ``(inputSize, 1)``) or a 2D tensor
            (viewed as ``(batchSize, inputSize, 1)``).

    Returns:
        ``self.output``, resized to ``(outputSize,)`` for 1D input or to
        ``(batchSize, outputSize)`` for 2D input.

    Raises:
        RuntimeError: if ``input`` is neither 1D nor 2D.
    """
    # Lazily allocate the scratch buffers. The check is an explicit
    # ``is None``: truth-testing a tensor (which ``buf or new()`` does) is
    # ambiguous/an error once the buffer holds data, and it also made it easy
    # to test the wrong attribute (``_expand2`` used to be derived from
    # ``_expand``).
    if self._diagCov is None:
        self._diagCov = self.output.new()
    if self._input is None:
        self._input = input.new()
    if self._weight is None:
        self._weight = self.weight.new()
    if self._expand is None:
        self._expand = self.output.new()
    if self._expand2 is None:
        self._expand2 = self.output.new()
    if self._expand3 is None:
        self._expand3 = self.output.new()
    if self._repeat is None:
        self._repeat = self.output.new()
    if self._repeat2 is None:
        self._repeat2 = self.output.new()
    if self._repeat3 is None:
        self._repeat3 = self.output.new()

    inputSize, outputSize = self.weight.size(0), self.weight.size(1)

    # y_j = || c_j * (w_j - x) ||
    if input.dim() == 1:
        # NOTE(review): self._view is defined outside this block; presumably
        # it makes self._input a (inputSize, 1) view of input -- confirm.
        self._view(self._input, input, inputSize, 1)
        # Broadcast the input column across all output units. The expanded
        # view must be assigned (same form as the 2D branch below);
        # previously the call's result was discarded.
        self._expand = self._input.expand_as(self.weight)
        # materialize the expanded view so it can be modified in place
        self._repeat.resize_as_(self._expand).copy_(self._expand)
        self._repeat.add_(-1, self.weight)   # x - w_j
        self._repeat.mul_(self.diagCov)      # c_j * (x - w_j)
        # NOTE(review): result-first positional call form; confirm it matches
        # the torch version this file targets (newer releases use ``out=``).
        torch.norm(self.output, self._repeat, 2, 0)
        self.output.resize_(outputSize)
    elif input.dim() == 2:
        batchSize = input.size(0)

        self._view(self._input, input, batchSize, inputSize, 1)
        self._expand = self._input.expand(batchSize, inputSize, outputSize)
        # make the expanded tensor contiguous (requires lots of memory)
        self._repeat.resize_as_(self._expand).copy_(self._expand)

        # broadcast weight and diagCov over the batch dimension
        self._weight = self.weight.view(1, inputSize, outputSize)
        self._expand2 = self._weight.expand_as(self._repeat)

        self._diagCov = self.diagCov.view(1, inputSize, outputSize)
        self._expand3 = self._diagCov.expand_as(self._repeat)
        if input.type() == 'torch.cuda.FloatTensor':
            # TODO: this can be fixed with a custom allocator
            # requires lots of memory, but minimizes cudaMallocs and loops
            self._repeat2.resize_as_(self._expand2).copy_(self._expand2)
            self._repeat.add_(-1, self._repeat2)
            self._repeat3.resize_as_(self._expand3).copy_(self._expand3)
            self._repeat.mul_(self._repeat3)
        else:
            self._repeat.add_(-1, self._expand2)
            self._repeat.mul_(self._expand3)

        # NOTE(review): result-first positional call form; confirm it matches
        # the torch version this file targets (newer releases use ``out=``).
        torch.norm(self.output, self._repeat, 2, 1)
        self.output.resize_(batchSize, outputSize)
    else:
        raise RuntimeError("1D or 2D input expected")

    return self.output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号