model_factorizer.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:spatial-reasoning 作者: JannerM 项目源码 文件源码
def train(self, lr, iters, batch_size = 256):
        optimizer = optim.Adam(self.parameters(), lr=lr)

        t = trange(iters)
        for i in t:
            optimizer.zero_grad()
            inds = torch.floor(torch.rand(batch_size) * self.M).long().cuda()
            # bug: floor(rand()) sometimes gives 1
            inds[inds >= self.M] = self.M - 1
            inds = Variable(inds)

            loss = self.forward(inds)
            # print loss.data[0]
            t.set_description( str(loss.data[0]) )
            loss.backward()
            optimizer.step()

        return self.state_model, self.goal_model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号