def forward(self, L, z):
    '''
    :param L: batch_size (B) x latent_size^2 (L^2)
    :param z: batch_size (B) x latent_size (L)
    :return: z_new = L*z
    '''
    # L -> tril(L): reshape the flat parameters into a batch of B x L x L matrices
    L_matrix = L.view(-1, self.args.z1_size, self.args.z1_size)
    # strictly lower-triangular mask (1s below the diagonal, 0s elsewhere)
    LTmask = torch.tril(torch.ones(self.args.z1_size, self.args.z1_size), diagonal=-1)
    # batch of identity matrices, one per example
    I = Variable(torch.eye(self.args.z1_size, self.args.z1_size).expand(L_matrix.size(0), self.args.z1_size, self.args.z1_size))
    if self.args.cuda:
        LTmask = LTmask.cuda()
        I = I.cuda()
    LTmask = Variable(LTmask)
    LTmask = LTmask.unsqueeze(0).expand(L_matrix.size(0), self.args.z1_size, self.args.z1_size)  # 1 x L x L -> B x L x L
    # batch of lower-triangular matrices with ones on the diagonal
    LT = torch.mul(L_matrix, LTmask) + I
    # z_new = L * z
    z_new = torch.bmm(LT, z.unsqueeze(2)).squeeze(2)  # (B x L x L) @ (B x L x 1) -> B x L
    return z_new
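
Because each LT above is unit lower-triangular, its determinant is 1, so the map z -> LT*z is volume-preserving and adds no log-det-Jacobian term. To sanity-check the tensor shapes, here is a minimal standalone sketch of the same masking-and-bmm computation; the function name linear_flow, the explicit z_size argument (standing in for self.args.z1_size), and the random inputs are illustrative assumptions, not part of the original module.

import torch

def linear_flow(L_flat, z, z_size):
    # Same computation as the forward above, written as a free function for shape checking.
    B = z.size(0)
    L_matrix = L_flat.view(-1, z_size, z_size)                    # B x L x L
    LTmask = torch.tril(torch.ones(z_size, z_size), diagonal=-1)  # strictly lower-triangular mask
    LTmask = LTmask.unsqueeze(0).expand(B, z_size, z_size)        # B x L x L
    I = torch.eye(z_size).expand(B, z_size, z_size)               # batch of identity matrices
    LT = L_matrix * LTmask + I                                    # unit lower-triangular matrices
    return torch.bmm(LT, z.unsqueeze(2)).squeeze(2)               # B x L

B, Z = 8, 4
L_flat = torch.randn(B, Z * Z)   # flat lower-triangular parameters, B x L^2
z = torch.randn(B, Z)            # latent codes, B x L
print(linear_flow(L_flat, z, Z).size())  # torch.Size([8, 4])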