def forward(self, inp):
    """Apply the ``l1 -> tanh -> l3`` network to a batch.

    Args:
        inp: Either a tensor, which is used as-is, or an indexable container
            whose first element is the input tensor (presumably a
            ``(data, label)`` batch from a data loader -- confirm against
            callers). The tensor is assumed to be 2-D ``(N, features)``;
            TODO confirm.

    Returns:
        A pair ``(output, h1)`` where ``h1 = tanh(self.l1(x))`` is the hidden
        activation and ``output = self.l3(h1)``. ``output`` is returned raw
        when ``self.arguments.tr_method == 'adversarial_wasserstein'`` and
        passed through a sigmoid otherwise.

    Raises:
        ValueError: if ``self.arguments.pack_num`` is set but is not a
            positive divisor of the batch size. Packing would otherwise yield
            ragged chunks that cannot be concatenated along dim 1.
    """
    # isinstance (not ``type(...) ==``): since PyTorch 0.4 Variable is merged
    # into Tensor, so type(tensor) is torch.Tensor and an exact-type check
    # would wrongly treat a plain tensor as a container and keep only inp[0].
    if not isinstance(inp, Variable):
        inp = Variable(inp[0])

    if hasattr(self.arguments, 'pack_num'):
        pack_num = self.arguments.pack_num
        n = inp.size(0)
        if pack_num <= 0 or n % pack_num != 0:
            raise ValueError(
                'pack_num={} must be a positive divisor of the batch size {}'
                .format(pack_num, n))
        # Cut the batch into pack_num consecutive row-chunks of n // pack_num
        # rows each and lay them side by side along dim 1, so every row of
        # the result concatenates pack_num samples.
        chunks = torch.split(inp, n // pack_num, dim=0)
        inp = torch.cat(chunks, dim=1)

    h1 = torch.tanh(self.l1(inp))
    output = self.l3(h1)
    # The Wasserstein objective uses the unbounded score; every other
    # training method gets a sigmoid-squashed output.
    if self.arguments.tr_method != 'adversarial_wasserstein':
        output = torch.sigmoid(output)
    return output, h1
评论列表
文章目录