def forward(self, inds, use_mask=True):
    """Return the summed squared reconstruction error for a batch of rows.

    The selected rows of ``self.state_inp`` are passed through
    ``self.state_model``, and all of ``self.goal_inp`` through
    ``self.goal_model``. The product ``state_out @ goal_out.T`` is compared
    against the same selected rows of ``self.mat``.

    Args:
        inds: 1-D index tensor selecting rows of ``self.state_inp``,
            ``self.mat`` and ``self.mask``.
        use_mask: When True (default), each squared error is multiplied by
            the matching entry of ``self.mask``, so entries with a zero mask
            contribute nothing. Pass False to get the previous unmasked loss.

    Returns:
        Scalar tensor: the *sum* (not the mean) of the (masked) squared
        errors.
    """
    state_inp = self.state_inp.index_select(0, inds)
    state_out = self.state_model.forward(state_inp)
    goal_out = self.goal_model.forward(self.goal_inp)
    recon = torch.mm(state_out, goal_out.t())

    true_select = self.mat.index_select(0, inds)
    diff = torch.pow(recon - true_select, 2)

    if use_mask:
        # Previously the mask was selected but never applied, so masked-out
        # entries of self.mat still contributed to the loss.
        # NOTE(review): presumably mask marks observed/valid entries of
        # self.mat -- confirm against whoever builds self.mask.
        mask_select = self.mask.index_select(0, inds)
        # type_as lets bool/byte masks weight a floating-point error tensor.
        mask_select = mask_select.type_as(diff)
        if mask_select.dim() == 1:
            # Per-row mask: broadcast it across every column of the row.
            mask_select = mask_select.unsqueeze(1)
        diff = diff * mask_select

    # Historically named `mse`, but this is a sum, not a mean.
    mse = diff.sum()
    return mse
评论列表
文章目录