def forward(self, v, z):
    '''
    Reflect each row of ``z`` through the hyperplane orthogonal to the
    matching row of ``v`` (a Householder reflection, applied row-wise).

    :param v: batch_size (B) x latent_size (L) -- reflection vectors
    :param z: batch_size (B) x latent_size (L) -- vectors to reflect
    :return: z_new = z - 2 * v v_T z / ||v||_2^2, shape B x L
    '''
    # Squared norm ||v||^2 for each row, kept as B x 1 so it broadcasts
    # across the L columns of the same row. (Summing without keepdim gives
    # shape (B,), which would broadcast along the wrong axis.)
    norm_sq = torch.sum(v * v, dim=1, keepdim=True)
    # v_T * z for each row, B x 1. (v v_T) z == v (v_T z), so this avoids
    # materialising the B x L x L outer product.
    vTz = torch.sum(v * z, dim=1, keepdim=True)
    # z - 2 * v * (v_T z) / ||v||^2. An all-zero row of v divides by zero,
    # producing nan/inf in that row.
    z_new = z - 2 * v * vTz / norm_sq
    return z_new
评论列表
文章目录