gan_things.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:sourceseparation_misc 作者: ycemsubakan 项目源码 文件源码
def forward(self, inp):
        #if inp.dim() > 2:
        #    inp = inp.permute(0, 2, 1)
        #inp = inp.contiguous().view(-1, self.L) 

        if not (type(inp) == Variable):
            inp = Variable(inp[0])

        if hasattr(self.arguments, 'pack_num'):
            N = inp.size(0)
            Ncut = int(N/self.arguments.pack_num)
            split = torch.split(inp, Ncut, dim=0)
            inp = torch.cat(split, dim=1)

        h1 = F.tanh((self.l1(inp)))

        #h2 = F.tanh(self.l2_bn(self.l2(h1)))

        if self.arguments.tr_method == 'adversarial_wasserstein':
            output = (self.l3(h1))
        else:
            output = F.sigmoid(self.l3(h1))

        return output, h1
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号