def train_model(self, max_iters):
    """Run the network training loop until the solver reaches ``max_iters``.

    The solver is stepped one iteration at a time and each step is timed.
    Every ``10 * solver_param.display`` iterations the rank, iteration
    count and average step time are written to stderr; a ``display`` of 0
    disables this progress output instead of raising ZeroDivisionError.

    Only the rank-0 process takes snapshots: one whenever the iteration
    count is a multiple of ``cfg.TRAIN.SNAPSHOT_ITERS``, plus a final one
    after the loop if the last iteration was not already snapshotted.

    Args:
        max_iters: Solver iteration count at which training stops.

    Returns:
        List of the values returned by ``self.snapshot()``, in the order
        they were taken (presumably snapshot file paths -- confirm against
        ``snapshot``). The list is empty on ranks other than 0.
    """
    last_snapshot_iter = -1
    timer = Timer()
    model_paths = []
    while self.solver.iter < max_iters:
        # Make one SGD update
        timer.tic()
        self.solver.step(1)
        timer.toc()

        # A display interval of 0 means "no progress output"; skip the
        # modulus in that case so it cannot divide by zero.
        log_interval = 10 * self.solver_param.display
        if log_interval and self.solver.iter % log_interval == 0:
            sys.stderr.write(
                'rank: {} iteration: {} speed: {:.3f}s / iter\n'.format(
                    self.rank, self.solver.iter, timer.average_time))

        # Periodic snapshots are taken by rank 0 only.
        # NOTE(review): a SNAPSHOT_ITERS of 0 would still divide by zero
        # here on rank 0 -- confirm whether 0 is a supported setting.
        if self.rank == 0 and self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
            last_snapshot_iter = self.solver.iter
            model_paths.append(self.snapshot())

    # Save the final state unless the last iteration was just snapshotted.
    if self.rank == 0 and last_snapshot_iter != self.solver.iter:
        model_paths.append(self.snapshot())
    return model_paths
train.py 文件源码
python
阅读 18
收藏 0
点赞 0
评论 0
评论列表
文章目录