def testIrisAll(self):
    """Fit on iris via SKCompat, then cross-check score() and predict().

    Trains an ``SKCompat``-wrapped ``Estimator`` for 100 steps on the iris
    data, then verifies that:

    * ``predict`` returns one ``'prob'`` row per target,
    * requesting only the ``'class'`` output yields the same classes as
      the unfiltered ``predict`` call,
    * ``'class'`` equals the argmax of ``'prob'`` along axis 1,
    * the ``'accuracy'`` reported by ``score`` matches
      ``_sklearn.accuracy_score`` computed from the predicted classes,
    * ``score`` reports a ``'global_step'`` of 100.
    """
    # NOTE(review): logistic_model_no_mode_fn is defined outside this block;
    # presumably a model_fn that takes no `mode` argument -- confirm.
    iris = base.load_iris()
    est = estimator.SKCompat(
        estimator.Estimator(model_fn=logistic_model_no_mode_fn))
    est.fit(iris.data, iris.target, steps=100)

    # The metric key pairs the reported name ('accuracy') with the
    # prediction key it is computed against ('class').
    scores = est.score(
        x=iris.data,
        y=iris.target,
        metrics={('accuracy', 'class'): metric_ops.streaming_accuracy})

    predictions = est.predict(x=iris.data)
    predictions_class = est.predict(x=iris.data, outputs=['class'])['class']

    self.assertEqual(predictions['prob'].shape[0], iris.target.shape[0])
    self.assertAllClose(predictions['class'], predictions_class)
    self.assertAllClose(
        predictions['class'], np.argmax(predictions['prob'], axis=1))

    # Recompute accuracy independently of the estimator's own metric.
    other_score = _sklearn.accuracy_score(iris.target, predictions['class'])
    self.assertAllClose(scores['accuracy'], other_score)

    # assertIn reports the missing key and container on failure, unlike
    # assertTrue(... in ...), which only says "False is not true".
    self.assertIn('global_step', scores)
    self.assertEqual(100, scores['global_step'])
estimator_test.py 文件源码
python
阅读 19
收藏 0
点赞 0
评论 0
评论列表
文章目录