dnn_sampled_softmax_classifier_test.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def testPredictAsIterable(self):
    """Checks that iterable predict() output agrees with predict_proba().

    Fits the classifier for a single step, then requests class predictions
    and class probabilities with ``as_iterable=True`` and asserts that the
    predicted classes equal the argmax of the probabilities along axis 1.
    """

    def _build_inputs(num_epochs=None):
        """Returns a (features, target) pair for three examples."""
        age = tf.train.limit_epochs(
            tf.constant([[.9], [.1], [.1]]), num_epochs=num_epochs)
        language = tf.SparseTensor(
            values=['en', 'fr', 'zh'],
            indices=[[0, 0], [0, 1], [2, 0]],
            shape=[3, 2])
        target = tf.constant([[1], [0], [0]], dtype=tf.int64)
        return {'age': age, 'language': language}, target

    # The sparse 'language' feature is hashed into 20 buckets and fed to
    # the network through a 1-dimensional embedding; 'age' is used as-is.
    language_column = tf.contrib.layers.sparse_column_with_hash_bucket(
        'language', hash_bucket_size=20)
    language_embedding = tf.contrib.layers.embedding_column(
        language_column, dimension=1)
    age_column = tf.contrib.layers.real_valued_column('age')

    classifier = dnn_sampled_softmax_classifier._DNNSampledSoftmaxClassifier(
        n_classes=3,
        n_samples=2,
        n_labels=1,
        feature_columns=[language_embedding, age_column],
        hidden_units=[4, 4])
    classifier.fit(input_fn=_build_inputs, steps=1)

    # Prediction reads the inputs with num_epochs=1 (presumably so that the
    # iterable results are finite -- confirm against the estimator docs).
    single_epoch_inputs = functools.partial(_build_inputs, num_epochs=1)

    predicted_classes = list(
        classifier.predict(input_fn=single_epoch_inputs, as_iterable=True))
    predicted_probabilities = list(
        classifier.predict_proba(
            input_fn=single_epoch_inputs, as_iterable=True))

    # Each predicted class must be the index of the largest probability.
    most_likely_classes = np.argmax(predicted_probabilities, axis=1)
    self.assertTrue(np.array_equal(predicted_classes, most_likely_classes))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号