classification_sw.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:deep_share 作者: luyongxi 项目源码 文件源码
def _print_train_val(self):
        """Print training and validation error statistics.

        First delegates to ``ClassificationSW._print_train_val``, then:

        * every ``cfg.TRAIN.TRAIN_FREQ`` iterations, prints the mean of
          ``self._err_mean`` as the training error;
        * every ``cfg.TRAIN.VAL_FREQ`` iterations, runs forward passes on the
          first test net ``cfg.TRAIN.VAL_SIZE * cfg.TRAIN.IMS_PER_BATCH``
          times, thresholds its ``'error'`` blob at 0.5, and prints the
          NaN-ignoring mean of the per-class error rates as the validation
          error.

        NaN entries in the ``'error'`` blob are kept as NaN (instead of being
        scored as 0 by the ``> 0.5`` comparison), so the ``np.nanmean``
        reductions actually exclude them.
        """
        # NOTE(review): this explicitly calls ClassificationSW's version;
        # presumably the enclosing class derives from ClassificationSW. If
        # this method is ClassificationSW's own, the call recurses forever.
        ClassificationSW._print_train_val(self)

        cur_iter = self._cur_iter
        cur_round = self._cur_round

        # display training errors
        if cur_iter % cfg.TRAIN.TRAIN_FREQ == 0:
            err_train = self._err_mean
            print('Round {}, Iteration {}: training error = {}'.format(
                cur_round, cur_iter, err_train.mean()))

        # display validation errors
        if cur_iter % cfg.TRAIN.VAL_FREQ == 0:
            test_net = self._solver.test_nets[0]
            num_val = cfg.TRAIN.VAL_SIZE * cfg.TRAIN.IMS_PER_BATCH
            err_val = np.zeros((num_val, self._num_classes))
            for i in range(num_val):
                test_net.forward()
                # assumes the 'error' blob broadcasts to (num_classes,),
                # i.e. one row of per-class errors per forward pass -- TODO confirm
                err = test_net.blobs['error'].data
                # A bare `err > 0.5` turns NaN into False (counted as no
                # error), which would make the nanmean calls below no-ops.
                err_val[i, :] = np.where(np.isnan(err), np.nan, err > 0.5)
            # per-class error rate over the validation passes, ignoring NaN
            err_val = np.nanmean(err_val, axis=0)
            print('Round {}, Iteration {}: validation error = {}'.format(
                cur_round, cur_iter, np.nanmean(err_val)))
评论列表
文章目录


问题


面经


文章

微信公众号

扫码关注公众号