def multi_perspective_expand_for_2D(in_tensor, decompose_params):
    """Weight a batch of vectors by every row of a perspective matrix.

    Each input vector is multiplied element-wise by each row of
    ``decompose_params``, giving one weighted copy of the vector per row.

    Args:
        in_tensor: 2-D tensor of shape ``[batch_size, dim]``.
        decompose_params: tensor expected to have shape
            ``[decompose_dim, dim]``; its last dimension must broadcast
            against the last dimension of ``in_tensor``.

    Returns:
        Tensor of shape ``[batch_size, decompose_dim, dim]`` with
        ``out[b, k] = in_tensor[b] * decompose_params[k]``.

    Raises:
        ValueError: if ``in_tensor`` is not 2-D. Without this check a 1-D or
            3-D input can broadcast into a silently wrong result shape.
        RuntimeError: raised by torch if the last dimensions are not
            broadcast-compatible.
    """
    if in_tensor.dim() != 2:
        raise ValueError(
            "in_tensor must be 2-D [batch_size, dim], got shape "
            f"{tuple(in_tensor.shape)}"
        )
    # [batch_size, 1, dim] * [1, decompose_dim, dim]
    #   -> broadcasts to [batch_size, decompose_dim, dim].
    # The tensor `*` operator is torch.mul. Using it keeps this block
    # independent of a module-level `torch` import.
    return in_tensor.unsqueeze(1) * decompose_params.unsqueeze(0)
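# NOTE: web-page residue removed from here: the navigation labels
# "Comment list" and "Table of contents" (originally in Chinese).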