grid_search_test.py source code


Project: DeepLearning_VirtualReality_BigData_Project, Author: rashmitripathi
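import random

from tensorflow.contrib.learn.python import learn  # TensorFlow 1.x only; contrib was removed in 2.0
from tensorflow.python.platform import test

try:
  from sklearn import datasets
  from sklearn.grid_search import GridSearchCV  # module removed in scikit-learn 0.20
  from sklearn.metrics import accuracy_score
  HAS_SKLEARN = True  # True only if the optional sklearn imports succeed
except ImportError:
  HAS_SKLEARN = False


class GridSearchTest(test.TestCase):
  def testIrisDNN(self):
    if HAS_SKLEARN:
      random.seed(42)
      iris = datasets.load_iris()
      feature_columns = learn.infer_real_valued_columns_from_input(iris.data)
      classifier = learn.DNNClassifier(
          feature_columns=feature_columns,
          hidden_units=[10, 20, 10],
          n_classes=3)
      # Each grid value replaces hidden_units; fit_params are forwarded to every candidate's fit().
      grid_search = GridSearchCV(
          classifier, {'hidden_units': [[5, 5], [10, 10]]},
          scoring='accuracy',
          fit_params={'steps': [50]})
      grid_search.fit(iris.data, iris.target)  # cross-validates each candidate, then refits the best
      score = accuracy_score(iris.target, grid_search.predict(iris.data))
      self.assertGreater(score, 0.5, 'Failed with score = {0}'.format(score))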
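The test hands the classifier to scikit-learn's GridSearchCV, which tries every `hidden_units` value under cross-validation, keeps the setting with the best mean accuracy, refits it on all the data and predicts with it. Below is a minimal, stdlib-only sketch of that loop under stated assumptions; it is not scikit-learn's implementation. `expand_grid`, `kfold_indices`, `KNNClassifier`, `grid_search` and the toy data are all hypothetical names introduced here, a k-nearest-neighbour model stands in for the DNN so the sketch runs without TensorFlow, and the folds are plain shuffled k-fold, whereas scikit-learn uses stratified folds for classifiers by default.

```python
import itertools
import random
from statistics import mean


def expand_grid(param_grid):
    """Yield one dict per combination of values (cf. sklearn's ParameterGrid)."""
    keys = sorted(param_grid)
    for values in itertools.product(*(param_grid[k] for k in keys)):
        yield dict(zip(keys, values))


def kfold_indices(n, n_folds, seed=42):
    """Shuffle 0..n-1 and yield (train, test) index lists for each fold."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    folds = [idx[i::n_folds] for i in range(n_folds)]
    for i, test_idx in enumerate(folds):
        train_idx = [j for fold in folds[:i] + folds[i + 1:] for j in fold]
        yield train_idx, test_idx


class KNNClassifier:
    """Hypothetical stand-in for DNNClassifier: majority vote of the k nearest points."""

    def __init__(self, k=1, p=2):
        self.k, self.p = k, p  # p selects the Minkowski distance (1 = Manhattan, 2 = Euclidean)

    def fit(self, X, y):
        self.X_, self.y_ = list(X), list(y)
        return self

    def predict(self, X):
        preds = []
        for x in X:
            # The p-th root is skipped: it does not change the neighbour ranking.
            ranked = sorted(
                (sum(abs(a - b) ** self.p for a, b in zip(x, xi)), yi)
                for xi, yi in zip(self.X_, self.y_))
            labels = [yi for _, yi in ranked[:self.k]]
            preds.append(max(set(labels), key=labels.count))
        return preds


def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def grid_search(make_model, param_grid, X, y, n_folds=3):
    """Pick the params with the best mean CV accuracy, then refit on all data."""
    best_params, best_score = None, float('-inf')
    for params in expand_grid(param_grid):
        fold_scores = []
        for train_idx, test_idx in kfold_indices(len(X), n_folds):
            model = make_model(**params).fit([X[i] for i in train_idx],
                                              [y[i] for i in train_idx])
            fold_scores.append(accuracy([y[i] for i in test_idx],
                                        model.predict([X[i] for i in test_idx])))
        score = mean(fold_scores)
        if score > best_score:  # strict '>' keeps the earliest candidate on ties
            best_params, best_score = params, score
    return make_model(**best_params).fit(X, y), best_params, best_score


# Two well-separated toy clusters (made-up data standing in for Iris).
X = [(0.0, 0.1), (0.2, 0.0), (0.1, 0.3), (0.3, 0.2), (0.2, 0.2), (0.0, 0.3),
     (2.0, 2.1), (2.2, 1.9), (1.9, 2.3), (2.1, 2.0), (2.3, 2.2), (2.0, 1.8)]
y = [0] * 6 + [1] * 6

model, params, cv_score = grid_search(KNNClassifier, {'k': [1, 3], 'p': [1, 2]}, X, y)
train_acc = accuracy(y, model.predict(X))
assert train_acc > 0.5, 'Failed with score = {0}'.format(train_acc)
```

The sketch builds each candidate by calling `make_model(**params)`. scikit-learn instead clones the base estimator and applies `set_params`, which is why, in the test, each grid value overrides the classifier's own `hidden_units=[10, 20, 10]`.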