def u_duvenaud(self, h_v, m_v, opt):
    """Apply a per-``opt['deg']`` linear map to ``m_v`` and squash with a sigmoid.

    Computes ``sigmoid(m_v @ W)`` with ``W = self.learn_args[0][opt['deg']]``.

    Args:
        h_v: Not used by this function; kept so the signature stays
            unchanged for callers.
        m_v: Tensor whose last dimension equals ``W.size(0)``. The previous
            implementation required it to be 3-D with the batch on axis 0
            (presumably ``(batch, nodes, features)`` -- confirm with callers).
        opt: Mapping providing ``'deg'``, the index used to select the weight
            matrix from ``self.learn_args[0]`` (presumably a node degree --
            confirm with callers).

    Returns:
        Tensor shaped like ``m_v`` except that the last dimension becomes
        ``W.size(1)``, with every element passed through a sigmoid.
    """
    weight = self.learn_args[0][opt['deg']]
    # m_v @ W is the same product the old code built by expanding W^T over
    # the batch, calling bmm on the transposed input and transposing back;
    # matmul broadcasts the 2-D weight without the extra views.
    return torch.sigmoid(torch.matmul(m_v, weight))
评论列表
文章目录