model_factorizer.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:spatial-reasoning 作者: JannerM 项目源码 文件源码
def forward(self, inds):
        state_inp = self.state_inp.index_select(0, inds)
        state_out = self.state_model.forward(state_inp)
        goal_out = self.goal_model.forward(self.goal_inp)

        recon = torch.mm(state_out, goal_out.t())
        mask_select = self.mask.index_select(0, inds)
        true_select = self.mat.index_select(0, inds)

        # pdb.set_trace()

        diff = torch.pow(recon - true_select, 2)

        mse = diff.sum()

        return mse
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号