estimator_test.py source code

python

Project: DeepLearning_VirtualReality_BigData_Project  Author: rashmitripathi
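def testIrisAll(self):
    # Train a logistic model on Iris through the scikit-learn-style wrapper.
    iris = base.load_iris()
    est = estimator.SKCompat(
        estimator.Estimator(model_fn=logistic_model_no_mode_fn))
    est.fit(iris.data, iris.target, steps=100)
    # The tuple key names the metric 'accuracy' and applies it to the
    # 'class' entry of the predictions dict.
    scores = est.score(
        x=iris.data,
        y=iris.target,
        metrics={('accuracy', 'class'): metric_ops.streaming_accuracy})
    # Predict all outputs, then only the 'class' output; both must agree.
    predictions = est.predict(x=iris.data)
    predictions_class = est.predict(x=iris.data, outputs=['class'])['class']
    self.assertEqual(predictions['prob'].shape[0], iris.target.shape[0])
    self.assertAllClose(predictions['class'], predictions_class)
    # The predicted class must be the argmax of the predicted probabilities.
    self.assertAllClose(
        predictions['class'], np.argmax(
            predictions['prob'], axis=1))
    # The streaming accuracy must match an independently computed accuracy.
    other_score = _sklearn.accuracy_score(iris.target, predictions['class'])
    self.assertAllClose(scores['accuracy'], other_score)
    # The reported global step must reflect the 100 training steps.
    self.assertIn('global_step', scores)
    self.assertEqual(100, scores['global_step'])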
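
The consistency checks in this test (one prediction row per example, class equal to the argmax of the probabilities, accuracy equal to the fraction of matching labels) can be sketched without TensorFlow. The following is only an illustration under stated assumptions: `softmax` and `accuracy` are hypothetical helpers, and the random logits and targets stand in for a trained model and the Iris labels; none of this is the project's or the library's own code.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax (hypothetical helper, not part of the test file)."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def accuracy(y_true, y_pred):
    """Fraction of positions where the predicted label equals the target."""
    return float(np.mean(np.asarray(y_true) == np.asarray(y_pred)))


# Assumption: random data with the Iris shape (150 examples, 3 classes)
# replaces the trained model's outputs and the real labels.
rng = np.random.default_rng(0)
logits = rng.normal(size=(150, 3))
targets = rng.integers(0, 3, size=150)

# Mirror the structure of the predictions dict used in the test.
predictions = {"prob": softmax(logits)}
predictions["class"] = np.argmax(predictions["prob"], axis=1)

# One probability row per example, and each row sums to 1.
assert predictions["prob"].shape[0] == targets.shape[0]
np.testing.assert_allclose(predictions["prob"].sum(axis=1), 1.0)

# Accuracy computed directly from the predicted classes.
score = accuracy(targets, predictions["class"])
```

In the original test the same relationships are checked against a fitted estimator, so the accuracy there reflects the model rather than chance.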