def testIrisDNN(self):
    """Grid-search DNNClassifier hidden-layer sizes on the Iris dataset.

    Fits a GridSearchCV over two ``hidden_units`` candidates using accuracy
    scoring. It then asserts that the resulting model's accuracy on the same
    Iris data it was trained on exceeds 0.5. This is a smoke test, not a
    generalization test. Skipped when scikit-learn is unavailable.
    """
    if not HAS_SKLEARN:
        # Report a skip rather than silently passing without exercising anything.
        self.skipTest('scikit-learn is not installed')
    # NOTE(review): this seeds only Python's `random` module; numpy/TF RNGs
    # are not seeded here, so training may still vary between runs.
    random.seed(42)
    iris = datasets.load_iris()
    feature_columns = learn.infer_real_valued_columns_from_input(iris.data)
    # hidden_units is only a placeholder: GridSearchCV sets it to each
    # candidate from the parameter grid below.
    classifier = learn.DNNClassifier(
        feature_columns=feature_columns,
        hidden_units=[10, 20, 10],
        n_classes=3)
    grid_search = GridSearchCV(
        classifier, {'hidden_units': [[5, 5], [10, 10]]},
        scoring='accuracy',
        # A scalar fit param is forwarded unchanged to every classifier.fit()
        # call; `steps` is one training-step count, not a list of candidates.
        fit_params={'steps': 50})
    grid_search.fit(iris.data, iris.target)
    # predict() delegates to the best estimator (refit on all data by default).
    score = accuracy_score(iris.target, grid_search.predict(iris.data))
    self.assertGreater(score, 0.5, 'Failed with score = {0}'.format(score))
grid_search_test.py 文件源码
python
阅读 18
收藏 0
点赞 0
评论 0
评论列表
文章目录