SMASH.py source code


Project: SMASH    Author: ajbrock
def forward(self,x):

        # Stem convolution
        out = self.conv1(x)

        # Allocate memory banks

        m = [[None for _ in range(d)] for d in self.D]
        module_index = 0
        for i,(incoming_channels,outgoing_channels,g_values, bs, trans) in enumerate(zip(
                self.incoming,self.outgoing, self.G, self.bank_sizes, [self.trans1,self.trans2,None])):

            # Write to initial memory banks
            for j in range(out.size(1) // (bs * self.N) ):
                m[i][j] = out[:, j * bs * self.N : (j + 1) * bs * self.N]

            for read,write,g in zip(incoming_channels,outgoing_channels,g_values):
                # Cat read tensors
                inp = torch.cat([m[i][index] for index in read], 1)

                # Apply module and increment op index
                out = self.mod[module_index](inp)
                module_index += 1

                for j, w in enumerate(write):
                    # Allocate that memory bank if it hasn't been written yet
                    if m[i][w] is None:
                        m[i][w] = out[:, (j % (g // bs)) * (bs * self.N) : (j % (g // bs) + 1) * (bs * self.N)]
                    # Else, if already written, accumulate into it.
                    else:
                        m[i][w] = m[i][w] + out[:, (j % (g // bs)) * (bs * self.N) : (j % (g // bs) + 1) * (bs * self.N)]


            if trans is not None:
                out = trans(torch.cat(m[i], 1))
            else:
                out = torch.cat(m[i], 1)

        out = torch.squeeze(F.avg_pool2d(F.relu(self.bn1(out)), out.size(2)))
        out = F.log_softmax(self.fc(out), dim=1)
        return out
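The bank read/write pattern in the inner loop above can be illustrated in isolation. The sketch below is a minimal, hypothetical example, not taken from the SMASH code: the names `banks`, `read`, `write`, and the shapes are illustrative stand-ins for `m[i]`, the `incoming`/`outgoing` index lists, and `self.N`/`bs`, and the learned module is replaced by an identity.

```python
import torch

# Illustrative stand-ins (hypothetical values, not SMASH's real hyperparameters)
N = 1          # channel multiplier (stand-in for self.N)
bs = 2         # bank size in units of N channels (stand-in for a bank_sizes entry)
num_banks = 4

# Memory banks: each holds bs * N channels, or None if not yet allocated.
banks = [torch.ones(1, bs * N, 8, 8) if i < 2 else None for i in range(num_banks)]

read = [0, 1]      # bank indices to read from
write = [2, 3]     # bank indices to write to
g = 4              # module output width in units of N channels

# "Read": concatenate the selected banks along the channel dimension.
inp = torch.cat([banks[i] for i in read], 1)   # shape (1, 4, 8, 8)

# Stand-in for the learned module: identity, so shapes stay transparent.
out = inp                                       # shape (1, g * N, 8, 8)

# "Write": slice the output into bank-sized chunks; allocate or accumulate,
# wrapping with (j % (g // bs)) when there are more writes than chunks.
for j, w in enumerate(write):
    chunk = out[:, (j % (g // bs)) * (bs * N) : (j % (g // bs) + 1) * (bs * N)]
    if banks[w] is None:
        banks[w] = chunk              # first write allocates the bank
    else:
        banks[w] = banks[w] + chunk   # later writes accumulate into it
```

With `g // bs == 2` chunks and two write targets, bank 2 receives the first two channels of `out` and bank 3 the last two; a longer `write` list would wrap around and accumulate.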