def g(self, tilde_z_l, u_l):
    """Combine a corrupted activation with a top-down signal into a denoised estimate.

    The naming (``tilde_z_l``, ``u_l``, ``hat_z_l``, ``a1`` .. ``a10``)
    suggests this is the vanilla Ladder Network combinator (Rasmus et al.,
    2015) -- NOTE(review): confirm against the surrounding model.
    Element-wise it computes::

        mu_l    = a1 * sigmoid(a2 * u_l + a3) + a4 * u_l + a5
        v_l     = a6 * sigmoid(a7 * u_l + a8) + a9 * u_l + a10
        hat_z_l = (tilde_z_l - mu_l) * v_l + mu_l

    ``self.a1`` .. ``self.a10`` are 2-D row tensors of shape ``(1, D)``.
    They are broadcast across the leading (batch) dimension.

    Args:
        tilde_z_l: Tensor to be denoised. Its first dimension is the batch
            (assumed ``(batch, D)`` -- TODO confirm).
        u_l: Tensor broadcast-compatible with ``tilde_z_l`` (presumably the
            same ``(batch, D)`` shape).

    Returns:
        ``hat_z_l``, the broadcast result of the formula above.
    """
    # Broadcasting the (1, D) rows over the batch replaces an explicit
    # ``ones(batch, 1).mm(a_i)`` expansion. This avoids allocating a
    # throwaway requires-grad ones tensor on every call. It also runs on
    # whatever device the inputs and parameters already live on, so no
    # ``use_cuda`` switch is needed.
    def _gated(a, b, c, d, e):
        # a * sigmoid(b * u_l + c) + d * u_l + e
        return a * torch.sigmoid(b * u_l + c) + d * u_l + e

    mu_l = _gated(self.a1, self.a2, self.a3, self.a4, self.a5)
    v_l = _gated(self.a6, self.a7, self.a8, self.a9, self.a10)
    hat_z_l = (tilde_z_l - mu_l) * v_l + mu_l
    return hat_z_l
评论列表
文章目录