def _print_train_val(self):
    """Print training and validation error information.

    Calls ``ClassificationSW._print_train_val`` first, then:

    * every ``cfg.TRAIN.TRAIN_FREQ`` iterations, prints the mean of
      ``self._err_mean`` as the training error;
    * every ``cfg.TRAIN.VAL_FREQ`` iterations, runs
      ``cfg.TRAIN.VAL_SIZE * cfg.TRAIN.IMS_PER_BATCH`` forward passes of the
      first test net (``self._solver.test_nets[0]``). Each pass thresholds
      the net's ``'error'`` blob at 0.5 and stores one row of a
      ``(passes, self._num_classes)`` array. The method then prints the
      NaN-aware mean over passes and classes as the validation error.

    NaN entries in the ``'error'`` blob stay NaN so that ``np.nanmean``
    excludes them. Before this fix, the comparison turned them into 0.
    """
    ClassificationSW._print_train_val(self)
    cur_iter = self._cur_iter
    cur_round = self._cur_round

    # Display training errors.
    if cur_iter % cfg.TRAIN.TRAIN_FREQ == 0:
        err_train = self._err_mean
        print('Round {}, Iteration {}: training error = {}'.format(
            cur_round, cur_iter, err_train.mean()))

    # Display validation errors.
    if cur_iter % cfg.TRAIN.VAL_FREQ == 0:
        num_val = cfg.TRAIN.VAL_SIZE * cfg.TRAIN.IMS_PER_BATCH
        test_net = self._solver.test_nets[0]
        err_val = np.zeros((num_val, self._num_classes))
        for i in range(num_val):
            test_net.forward()
            err = np.asarray(test_net.blobs['error'].data, dtype=float)
            # `err > 0.5` is False for NaN, which would silently count NaN
            # entries (presumably "no label" for that class -- confirm) as
            # correct and make the nanmean below a no-op. Keep them as NaN.
            err_val[i, :] = np.where(np.isnan(err), np.nan, err > 0.5)
        # Per-class error rate over all validation passes, ignoring NaNs.
        err_val = np.nanmean(err_val, axis=0)
        print('Round {}, Iteration {}: validation error = {}'.format(
            cur_round, cur_iter, np.nanmean(err_val)))
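# NOTE(review): stray web-page UI text ("Comment list", "Table of contents")
# from the page this code was copied from; not part of the program.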