def forward(self, inp):
    """Run the two-layer softplus network on ``inp``.

    Computes ``softplus(l2(softplus(l1(inp))))``. When ``self.smooth_output``
    is set, the result is additionally viewed as single-channel square maps
    of side ``int(sqrt(self.L2))``, passed through ``self.sml`` and a
    softplus, and flattened back to ``(-1, self.L2)``.

    Args:
        inp: Either a tensor/``Variable`` batch, which is used as-is, or an
            indexable container whose first element is the batch (presumably
            the ``[x]`` list yielded by a ``DataLoader`` over a
            ``TensorDataset`` -- TODO confirm against callers).

    Returns:
        The network output. Its shape is ``(-1, self.L2)`` when
        ``self.smooth_output`` is set.

    Raises:
        ValueError: if ``self.arguments.tr_method`` is not one of
            ``'adversarial'``, ``'adversarial_wasserstein'`` or ``'ML'``.
    """
    # isinstance rather than ``type(...) ==``: since PyTorch 0.4 tensors are
    # never of the exact type ``Variable``, so the equality test was always
    # False and a plain tensor batch was silently sliced to its first sample.
    if not isinstance(inp, Variable):
        inp = Variable(inp[0])

    method = self.arguments.tr_method
    # Every supported training method currently uses the same first layer,
    # so a single membership check replaces the duplicated branches.
    if method not in ('adversarial', 'adversarial_wasserstein', 'ML'):
        raise ValueError(
            "Unknown tr_method {!r}; expected 'adversarial', "
            "'adversarial_wasserstein' or 'ML'".format(method)
        )
    h = F.softplus(self.l1(inp))
    output = F.softplus(self.l2(h))

    if self.smooth_output:
        # Assumes self.L2 is a perfect square, so the flat output can be
        # viewed as (batch, 1, side, side) maps for self.sml.
        side = int(np.sqrt(self.L2))
        output = output.view(-1, 1, side, side)
        output = F.softplus(self.sml(output))
        output = output.view(-1, self.L2)
    return output
评论列表
文章目录