def _print_train_val(self):
    """Print training and validation error information.

    First delegates to ``ClassificationSW._print_train_val``. Then:

    * every ``cfg.TRAIN.TRAIN_FREQ`` iterations, prints the mean of
      ``self._err_mean`` as the training error;
    * every ``cfg.TRAIN.VAL_FREQ`` iterations, runs
      ``cfg.TRAIN.VAL_SIZE * cfg.TRAIN.IMS_PER_BATCH`` forward passes of the
      first test net, turns the ``'acc'`` blob of each pass into an error
      (``1 - acc``), averages over passes ignoring NaNs, and prints the
      NaN-ignoring mean of that average.
    """
    # evaluate training performance (parent-class reporting)
    ClassificationSW._print_train_val(self)
    cur_iter = self._cur_iter
    cur_round = self._cur_round
    # display training errors
    if cur_iter % cfg.TRAIN.TRAIN_FREQ == 0:
        err_train = self._err_mean
        print('Round {}, Iteration {}: training error = {}'.format(
            cur_round, cur_iter, err_train.mean()))
    # display validation errors
    if cur_iter % cfg.TRAIN.VAL_FREQ == 0:
        # perform validation
        num_val = cfg.TRAIN.VAL_SIZE * cfg.TRAIN.IMS_PER_BATCH
        test_net = self._solver.test_nets[0]
        rows = []
        for _ in range(num_val):
            test_net.forward()
            # One row per pass. Flatten the accuracy blob, whatever its
            # shape, into 1-D. The previous code wrote into a 1-D buffer
            # with ``[i, :]``, which raises IndexError.
            rows.append(np.ravel(1.0 - test_net.blobs['acc'].data))
        if rows:
            # average over validation passes (axis 0), ignoring NaNs
            err_val = np.nanmean(np.vstack(rows), axis=0)
            err_val_mean = np.nanmean(err_val)
        else:
            # no validation passes configured; report NaN as before
            err_val_mean = float('nan')
        print('Round {}, Iteration {}: validation error = {}'.format(
            cur_round, cur_iter, err_val_mean))
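# NOTE: leftover web-page navigation text, kept as comments so the file parses:
# 评论列表 ("comment list")
# 文章目录 ("article table of contents")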