def fit(self, train_data, eval_data, eval_metric='acc', **kargs):
    """Train the wrapped classifier, save it, and log validation metrics.

    ``self.clf`` is fitted once on ``train_data`` and dumped with ``jb.dump``
    to ``<snapshot>-0001.params``. It is then evaluated on ``eval_data``.
    Its hard predictions are one-hot encoded into a probability-like matrix
    and fed to an MXNet ``EvalMetric``, and each metric value is logged
    under ``Epoch[0]``.

    Args:
        train_data: Training set, split into ``(data, label)`` by
            ``self._get_data_label``.
        eval_data: Validation set, split the same way.
        eval_metric: An ``mx.metric.EvalMetric`` instance, or anything
            ``mx.metric.create`` accepts (default ``'acc'``).
        **kargs: Must contain ``snapshot``, the path prefix for the saved
            model. Any other keys are accepted and ignored.

    Raises:
        KeyError: If ``snapshot`` is not supplied.
        ValueError: If ``eval_data`` yields no samples to evaluate.
    """
    snapshot = kargs.pop('snapshot')

    self.clf.fit(*self._get_data_label(train_data))
    jb.dump(self.clf, snapshot + '-0001.params')

    if not isinstance(eval_metric, mx.metric.EvalMetric):
        eval_metric = mx.metric.create(eval_metric)

    data, label = self._get_data_label(eval_data)
    # Predicted class ids are used directly as column indices below.
    # This presumably assumes integer class labels 0..K-1 -- TODO confirm
    # against the classifiers used here.
    pred = self.clf.predict(data).astype(np.int64)
    if pred.size == 0:
        raise ValueError('fit: eval_data produced no samples to evaluate')

    # Cover every class index seen in the predictions OR the labels.
    # Sizing by pred.max() + 1 alone leaves too few columns when a label's
    # class is never predicted. That breaks metrics which index the
    # prediction matrix by label (e.g. cross-entropy). Extra all-zero
    # columns do not change argmax-based metrics such as accuracy.
    num_classes = int(max(pred.max(), np.max(label))) + 1
    prob = np.zeros((len(pred), num_classes))
    prob[np.arange(len(pred)), pred] = 1

    eval_metric.update([mx.nd.array(label)], [mx.nd.array(prob)])
    for name, val in eval_metric.get_name_value():
        # Format eagerly. '{}' placeholders with positional args only work
        # with a str.format-style logger (e.g. loguru). stdlib logging uses
        # %-style and would emit a formatting error instead of the message.
        # A pre-formatted string logs the same text under either.
        logger.info('Epoch[0] Validation-{}={}'.format(name, val))
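# 评论列表 ("comment list") -- appears to be a web-page label left over from scraping; not code
# 文章目录 ("table of contents") -- appears to be a web-page label left over from scraping; not code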