def testPredictAsIterable(self):
    """Tests predict() and predict_proba() with as_iterable set to True.

    Trains a small _DNNSampledSoftmaxClassifier for a single step. Then, using
    a single-epoch input, it checks two things:
      * iterating over predict() and predict_proba() yields one result per
        input row;
      * each predicted class equals the argmax of the corresponding
        probability row.
    """
    def _input_fn(num_epochs=None):
        # Three examples. Only 'age' is wrapped in limit_epochs, so
        # num_epochs bounds how many times the whole input can be evaluated.
        features = {
            'age': tf.train.limit_epochs(tf.constant([[.9], [.1], [.1]]),
                                         num_epochs=num_epochs),
            # Row 0 has two tokens, row 1 has none, row 2 has one.
            'language': tf.SparseTensor(values=['en', 'fr', 'zh'],
                                        indices=[[0, 0], [0, 1], [2, 0]],
                                        shape=[3, 2])
        }
        target = tf.constant([[1], [0], [0]], dtype=tf.int64)
        return features, target

    sparse_column = tf.contrib.layers.sparse_column_with_hash_bucket(
        'language', hash_bucket_size=20)
    feature_columns = [
        tf.contrib.layers.embedding_column(sparse_column, dimension=1),
        tf.contrib.layers.real_valued_column('age')
    ]

    classifier = dnn_sampled_softmax_classifier._DNNSampledSoftmaxClassifier(
        n_classes=3,
        n_samples=2,
        n_labels=1,
        feature_columns=feature_columns,
        hidden_units=[4, 4])

    # One training step is enough: the assertions below check consistency
    # between predict() and predict_proba(), not prediction quality.
    classifier.fit(input_fn=_input_fn, steps=1)

    # Limit prediction input to a single epoch so the iterators terminate.
    predict_input_fn = functools.partial(_input_fn, num_epochs=1)
    predictions = list(
        classifier.predict(input_fn=predict_input_fn, as_iterable=True))
    predictions_proba = list(
        classifier.predict_proba(input_fn=predict_input_fn, as_iterable=True))

    # Each iterator should yield exactly one result per input row.
    num_examples = 3
    self.assertEqual(num_examples, len(predictions))
    self.assertEqual(num_examples, len(predictions_proba))
    # assert_array_equal reports the mismatching elements on failure, whereas
    # assertTrue(np.array_equal(...)) only reports "False is not true".
    np.testing.assert_array_equal(predictions,
                                  np.argmax(predictions_proba, 1))
dnn_sampled_softmax_classifier_test.py 文件源码
python
阅读 22
收藏 0
点赞 0
评论 0
评论列表
文章目录