sklearn_module.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:srep 作者: Answeror 项目源码 文件源码
def fit(self, train_data, eval_data, eval_metric='acc', **kargs):
        snapshot = kargs.pop('snapshot')
        self.clf.fit(*self._get_data_label(train_data))
        jb.dump(self.clf, snapshot + '-0001.params')

        if not isinstance(eval_metric, mx.metric.EvalMetric):
            eval_metric = mx.metric.create(eval_metric)
        data, label = self._get_data_label(eval_data)
        pred = self.clf.predict(data).astype(np.int64)
        prob = np.zeros((len(pred), pred.max() + 1))
        prob[np.arange(len(prob)), pred] = 1
        eval_metric.update([mx.nd.array(label)], [mx.nd.array(prob)])
        for name, val in eval_metric.get_name_value():
            logger.info('Epoch[0] Validation-{}={}', name, val)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号