import torch
import torch.optim as optim
from tqdm import trange


def train(self, lr, iters, batch_size=256):
    # NB: if this class is an nn.Module, this method shadows nn.Module.train().
    optimizer = optim.Adam(self.parameters(), lr=lr)
    t = trange(iters)
    for _ in t:
        optimizer.zero_grad()
        # Sample a batch of random indices in [0, M).
        inds = torch.floor(torch.rand(batch_size) * self.M).long().cuda()
        # floor(rand() * M) can round up to exactly M in rare floating-point
        # edge cases, so clamp any out-of-range index back to M - 1.
        inds[inds >= self.M] = self.M - 1
        loss = self.forward(inds)
        # Report the current loss in the progress bar.
        t.set_description(str(loss.item()))
        loss.backward()
        optimizer.step()
    return self.state_model, self.goal_model
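
# On PyTorch >= 0.4 the floor-and-clamp sampling above can be replaced with
# torch.randint, which draws integers uniformly from [0, M) directly and has
# no rounding edge case. A minimal standalone sketch; the values for M and
# batch_size here are stand-ins, not from the original:
M, batch_size = 10000, 256
inds = torch.randint(0, M, (batch_size,))  # int64 tensor, values in [0, M)

# Hypothetical usage, assuming a model object exposing this method and the
# state_model / goal_model attributes it returns:
# state_model, goal_model = model.train(lr=1e-3, iters=5000, batch_size=256)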