def _train(self, model, optimizer, train_iter, log_interval, logger, start_time):
    """Run one training pass of ``model`` over ``train_iter``.

    For every batch: wrap the inputs in ``Variable``, move them to the GPU
    when ``self.gpu`` is set, compute ``mean_squared_error`` between the
    model output and the target poses, back-propagate, and take one
    optimizer step.  Every ``log_interval``-th iteration (counted from 1),
    a line with the elapsed wall-clock time since ``start_time`` and the
    current batch loss is written to ``logger``.

    Args:
        model: network to train; switched to training mode via
            ``model.train()`` and called as ``model(image)``.
        optimizer: optimizer whose ``zero_grad()`` / ``step()`` are called
            once per batch.
        train_iter: iterable of batches indexed as
            ``batch[0], batch[1], batch[2]`` -> image, pose, visibility
            (presumably tensors -- TODO confirm against the data loader).
        log_interval: log on every iteration whose 1-based index is a
            multiple of this value.
        logger: object providing ``write(str)``.
        start_time: timestamp in seconds, subtracted from ``time.time()``
            to get the elapsed time that is logged.
    """
    model.train()
    for iteration, batch in enumerate(tqdm(train_iter, desc='this epoch'), 1):
        image, pose, visibility = Variable(batch[0]), Variable(batch[1]), Variable(batch[2])
        if self.gpu:
            image, pose, visibility = image.cuda(), pose.cuda(), visibility.cuda()
        optimizer.zero_grad()
        output = model(image)
        loss = mean_squared_error(output, pose, visibility, self.use_visibility)
        loss.backward()
        optimizer.step()
        if iteration % log_interval == 0:
            # ``loss.backward()`` above takes no gradient argument, so ``loss``
            # holds a single element.  On PyTorch >= 0.4 such a reduced loss
            # is typically a 0-dim tensor, and ``loss.data[0]`` raises
            # IndexError there; ``item()`` is the supported accessor.  Fall
            # back to the legacy indexing only on versions without ``item()``.
            loss_value = loss.item() if hasattr(loss, 'item') else loss.data[0]
            log = 'elapsed_time: {0}, loss: {1}'.format(time.time() - start_time, loss_value)
            logger.write(log)
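# NOTE(review): the two lines below were web-page residue from scraping, not code.
# 评论列表 ("comment list")
# 文章目录 ("table of contents")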