def u_intnet(self, h_v, m_v, opt):
    """Build the update-network input and apply ``self.learn_modules[0]`` to it.

    Concatenates the hidden state ``h_v``, the optional extra features
    ``opt['x_v']`` and the message ``m_v`` along dimension 1, then returns the
    result of feeding that tensor to the first learned module.

    Args:
        h_v: Hidden-state tensor; dimension 1 is the concatenation axis.
            Presumably shaped (batch, hidden) -- confirm against callers.
        m_v: Message tensor. Its size-1 dimensions are dropped (as
            ``torch.squeeze`` does); if that would leave it with fewer
            dimensions than ``h_v`` it is instead reshaped to ``h_v``'s
            leading/trailing layout, so a batch of one (or a width-1 message)
            still concatenates correctly.
        opt: Mapping of optional inputs. ``opt['x_v']`` is inserted between
            ``h_v`` and the message only when it is a tensor with at least one
            dimension; a 0-dim tensor or ``None`` means "no extra features".

    Returns:
        Whatever ``self.learn_modules[0]`` returns for the concatenated input.
    """
    x_v = opt['x_v']

    m_flat = torch.squeeze(m_v)
    if m_flat.dim() < h_v.dim():
        # squeeze() removes *every* size-1 dim -- including the batch dim when
        # the batch holds a single item, or the feature dim of a 1-wide message.
        # Restore h_v's layout (all dims except the concat axis) so torch.cat
        # can align the tensors; element order is unchanged by squeeze/reshape.
        m_flat = m_v.reshape(h_v.shape[0], -1, *h_v.shape[2:])

    parts = [h_v]
    if x_v is not None and x_v.dim() > 0:
        parts.append(x_v)
    parts.append(m_flat)

    return self.learn_modules[0](torch.cat(parts, 1))
评论列表
文章目录