def train_model(self, max_iters):
    """Network training loop.

    Steps the solver one iteration at a time until ``self.solver.iter``
    reaches ``max_iters``. Along the way it:

    * prints the timer's average time per step every
      ``10 * self.solver_param.display`` iterations;
    * calls ``self.snapshot()`` every ``cfg.TRAIN.SNAPSHOT_ITERS``
      iterations.

    On exit, ``self.snapshot()`` is called once more unless the final
    iteration was already snapshotted. This also happens when the loop
    body never runs, e.g. if the solver is already at or past
    ``max_iters``.

    A zero interval disables the corresponding periodic action instead of
    raising ZeroDivisionError.

    Args:
        max_iters: Solver iteration count at which training stops.
    """
    last_snapshot_iter = -1
    timer = Timer()
    # NOTE(review): solver_param looks like a Caffe SolverParameter, where
    # display == 0 (the proto default) means "no display". Without this
    # guard, a zero interval would make the modulo below raise
    # ZeroDivisionError.
    speed_interval = 10 * self.solver_param.display
    # Treat a zero/unset snapshot interval the same way: skip periodic
    # snapshots. The final snapshot after the loop still runs.
    snapshot_interval = cfg.TRAIN.SNAPSHOT_ITERS
    while self.solver.iter < max_iters:
        # Make one SGD update, timed for the periodic speed report.
        timer.tic()
        self.solver.step(1)
        timer.toc()
        if speed_interval and self.solver.iter % speed_interval == 0:
            # Single-argument print(...) behaves identically on Python 2 and 3.
            print('speed: {:.3f}s / iter'.format(timer.average_time))
        if snapshot_interval and self.solver.iter % snapshot_interval == 0:
            last_snapshot_iter = self.solver.iter
            self.snapshot()
    # Save the final state unless this exact iteration was just saved.
    if last_snapshot_iter != self.solver.iter:
        self.snapshot()
评论列表
文章目录