encoding.py source code

python

Project: PyTorch-Encoding  Author: zhanghang1989
import torch.nn.functional as F
from torch.autograd import Variable

# Method of the Encoding layer; my_data_parallel, scaledL2 and aggregate
# are helpers defined elsewhere in PyTorch-Encoding.
def forward(self, X):
    if isinstance(X, (tuple, list)):
        # for self-parallel mode, please see encoding.nn
        return my_data_parallel(self, X)
    elif not isinstance(X, Variable):
        raise RuntimeError('unknown input type')
    # input X is a 3D (BxDxN) or 4D (BxDxHxW) tensor with D channels
    assert X.size(1) == self.D
    if X.dim() == 3:
        # BxDxN -> BxNxD
        B, N, K, D = X.size(0), X.size(2), self.K, self.D
        X = X.transpose(1, 2).contiguous()
    elif X.dim() == 4:
        # BxDxHxW -> BxNxD, with N = H*W
        B, N, K, D = X.size(0), X.size(2)*X.size(3), self.K, self.D
        X = X.view(B, D, -1).transpose(1, 2).contiguous()
    else:
        raise RuntimeError('Encoding Layer unknown input dims!')
    # assignment weights BxNxK, normalized over the K codewords
    A = F.softmax(scaledL2(X, self.codewords, self.scale), dim=2)
    # aggregate the weighted residuals over the N features
    E = aggregate(A, X, self.codewords)
    return E
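
The last two steps of forward (assignment weights, then aggregation) can be illustrated with a small dependency-free sketch for a single sample. This is not the library's implementation: the function names scaled_l2, softmax and encode are hypothetical, and the sketch assumes that scaledL2 returns scale[k] * ||x_i - c_k||^2 and that aggregate sums the weighted residuals A[i][k] * (x_i - c_k) over the N features, which is what the names and comments in the snippet suggest.

```python
import math

def scaled_l2(X, C, S):
    """Return an N x K table of S[k] * ||X[i] - C[k]||^2 (assumed behaviour)."""
    return [[S[k] * sum((x - c) ** 2 for x, c in zip(xi, C[k]))
             for k in range(len(C))]
            for xi in X]

def softmax(row):
    """Numerically stable softmax over one row of K scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def encode(X, C, S):
    """Return (A, E): N x K assignment weights and the K x D encoding."""
    A = [softmax(row) for row in scaled_l2(X, C, S)]  # each row sums to 1
    K, D = len(C), len(C[0])
    E = [[0.0] * D for _ in range(K)]
    for i, xi in enumerate(X):
        for k in range(K):
            for d in range(D):
                # accumulate the weighted residual of feature i w.r.t. codeword k
                E[k][d] += A[i][k] * (xi[d] - C[k][d])
    return A, E

# Toy input: N=3 features of dimension D=2, K=2 codewords.
X = [[0.0, 0.0], [1.0, 0.0], [4.0, 4.0]]
C = [[0.0, 0.0], [4.0, 4.0]]
S = [-1.0, -1.0]  # assumed negative, so nearer codewords get larger weights
A, E = encode(X, C, S)
```

Each feature's weights are normalized across the codewords, so the encoding has a fixed K x D size regardless of how many features N the input contains.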