def testIrisAllDictionaryInput(self):
    """Fit, evaluate and predict on iris with dict-valued features and labels.

    Both the features and the labels are passed to the estimator wrapped in
    single-entry dictionaries. The test then checks that:

    * one prediction is produced per iris example;
    * requesting only the 'class' output yields the same classes as the
      unrestricted prediction;
    * each predicted class is the argmax of its 'prob' entry;
    * the accuracy reported by ``evaluate`` agrees with the accuracy
      recomputed from the predictions;
    * ``evaluate`` reports a 'global_step' equal to the number of fit steps.
    """
    iris = base.load_iris()
    est = estimator.Estimator(model_fn=logistic_model_no_mode_fn)

    # Wrap inputs and targets in dicts; the keys are presumably the ones
    # logistic_model_no_mode_fn expects -- confirm against its definition.
    iris_data = {'input': iris.data}
    iris_target = {'labels': iris.target}

    est.fit(iris_data, iris_target, steps=100)

    # NOTE(review): the metrics key looks like (metric name, prediction key),
    # so accuracy would be computed on the 'class' output -- confirm against
    # the Estimator.evaluate documentation.
    scores = est.evaluate(
        x=iris_data,
        y=iris_target,
        metrics={('accuracy', 'class'): metric_ops.streaming_accuracy})

    # Predict twice: once with every output, once restricted to 'class'.
    predictions = list(est.predict(x=iris_data))
    predictions_class = list(est.predict(x=iris_data, outputs=['class']))

    self.assertEqual(len(predictions), iris.target.shape[0])

    classes_batch = np.array([p['class'] for p in predictions])
    self.assertAllClose(classes_batch,
                        np.array([p['class'] for p in predictions_class]))

    # The predicted class must be the index of the highest 'prob' entry.
    self.assertAllClose(
        classes_batch,
        np.argmax(
            np.array([p['prob'] for p in predictions]), axis=1))

    # Cross-check the evaluated accuracy against one recomputed from the
    # predictions themselves.
    other_score = _sklearn.accuracy_score(iris.target, classes_batch)
    self.assertAllClose(other_score, scores['accuracy'])

    # assertIn reports the missing key and the container on failure.
    self.assertIn('global_step', scores)
    # Matches steps=100 passed to fit() above.
    self.assertEqual(scores['global_step'], 100)
estimator_test.py 文件源码
python
阅读 18
收藏 0
点赞 0
评论 0
评论列表
文章目录