PartialLinear.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:pytorch-dist 作者: apaszke 项目源码 文件源码
def __init__(self, inputsize, outputsize, bias=True):
        super(PartialLinear, self).__init__()

        # define the layer as a small network:
        pt = ParallelTable()
        pt.add(Identity()).add(LookupTable(outputsize, inputsize))
        self.network = Sequential().add(pt).add(MM(False, True))
        if bias:
            self.bias     = torch.zeros(1, outputsize)
            self.gradBias = torch.zeros(1, outputsize)
        else:
            self.bias = self.gradBias = None

        # set partition:
        self.inputsize  = inputsize
        self.outputsize = outputsize
        self.allcolumns = torch.range(0, self.outputsize-1).long()
        self.resetPartition()
        self.addBuffer = None
        self.buffer = None
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号