estimator_test.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:DeepLearning_VirtualReality_BigData_Project 作者: rashmitripathi 项目源码 文件源码
def testIrisAllDictionaryInput(self):
    """Fit, evaluate and predict on Iris with dict-valued features and labels.

    Both the features and the labels are handed to the estimator wrapped in
    dictionaries. The test then checks that:

    * one prediction is produced per example,
    * requesting only the 'class' output yields the same classes as the
      full prediction,
    * the predicted class equals the argmax of the 'prob' output,
    * the evaluated 'accuracy' agrees with `_sklearn.accuracy_score`,
    * the reported 'global_step' equals the number of fit steps (100).
    """
    dataset = base.load_iris()
    features = {'input': dataset.data}
    labels = {'labels': dataset.target}

    model = estimator.Estimator(model_fn=logistic_model_no_mode_fn)
    model.fit(features, labels, steps=100)

    # The metric key pairs the reported name with the prediction key it
    # is computed from.
    accuracy_metric = {
        ('accuracy', 'class'): metric_ops.streaming_accuracy,
    }
    eval_results = model.evaluate(
        x=features, y=labels, metrics=accuracy_metric)

    # Predict twice: once with every output, once restricted to 'class'.
    full_predictions = list(model.predict(x=features))
    class_only_predictions = list(
        model.predict(x=features, outputs=['class']))

    self.assertEqual(len(full_predictions), dataset.target.shape[0])

    predicted_classes = np.array([row['class'] for row in full_predictions])
    restricted_classes = np.array(
        [row['class'] for row in class_only_predictions])
    self.assertAllClose(predicted_classes, restricted_classes)

    # The reported class must be the most probable one.
    probabilities = np.array([row['prob'] for row in full_predictions])
    self.assertAllClose(predicted_classes, np.argmax(probabilities, axis=1))

    # Cross-check the streaming metric against an independent computation.
    reference_accuracy = _sklearn.accuracy_score(
        dataset.target, predicted_classes)
    self.assertAllClose(reference_accuracy, eval_results['accuracy'])

    self.assertTrue('global_step' in eval_results)
    self.assertEqual(eval_results['global_step'], 100)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号