dnn_sampled_softmax_classifier_test.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def testMultiLabelTopKWithCustomMetrics(self):
    """Tests the case where n_labels > 1, top_k > 1 and a custom metric uses top_k.

    Trains a 3-class sampled-softmax DNN classifier for 50 steps on a tiny
    in-memory batch, then checks:
      * evaluate() without custom metrics reports precision/recall at 1 and 2
        above 0.4;
      * evaluate() with a custom metric includes that metric in its output;
      * predict(get_top_k=True) returns an array of shape [3, 2].
    """
    def _input_fn():
      """Returns a fixed batch of 3 examples, each labelled with classes 0 and 1."""
      # Sparse 'language' feature: example 0 has two values ('en', 'fr'),
      # example 1 has no values, example 2 has one value ('zh').
      features = {
          'language': tf.SparseTensor(values=['en', 'fr', 'zh'],
                                      indices=[[0, 0], [0, 1], [2, 0]],
                                      shape=[3, 2])
      }
      # Two labels per example (n_labels=2); every example has labels 0 and 1.
      target = tf.constant([[0, 1], [0, 1], [0, 1]], dtype=tf.int64)
      return features, target

    def _my_metric_op(predictions, targets):
      """Returns the element-wise sum of predictions (cast to int64) and targets."""
      return tf.add(math_ops.to_int64(predictions), targets)

    sparse_column = tf.contrib.layers.sparse_column_with_hash_bucket(
        'language', hash_bucket_size=20)
    embedding_features = [
        tf.contrib.layers.embedding_column(sparse_column, dimension=1)
    ]

    classifier = dnn_sampled_softmax_classifier._DNNSampledSoftmaxClassifier(
        n_classes=3,
        n_samples=2,
        n_labels=2,
        top_k=2,
        feature_columns=embedding_features,
        hidden_units=[4, 4],
        optimizer=tf.train.AdamOptimizer(learning_rate=0.01),
        # Fixed graph-level seed to make training as repeatable as possible.
        config=tf.contrib.learn.RunConfig(tf_random_seed=5))

    classifier.fit(input_fn=_input_fn, steps=50)
    # evaluate() without custom metrics.
    evaluate_output = classifier.evaluate(input_fn=_input_fn, steps=1)
    self.assertGreater(evaluate_output['precision_at_1'], 0.4)
    self.assertGreater(evaluate_output['recall_at_1'], 0.4)
    self.assertGreater(evaluate_output['precision_at_2'], 0.4)
    self.assertGreater(evaluate_output['recall_at_2'], 0.4)

    # evaluate() with custom metrics.
    # NOTE(review): the tuple key appears to be (metric name, prediction key),
    # i.e. the metric is fed the 'top_k' predictions — confirm against the
    # Estimator.evaluate metrics contract.
    metrics = {('my_metric', 'top_k'): _my_metric_op}
    evaluate_output = classifier.evaluate(input_fn=_input_fn, steps=1,
                                          metrics=metrics)
    # This test's output is flaky, so only check that 'my_metric' is present
    # in evaluate_output. assertIn reports the missing key and the dict
    # contents on failure, unlike assertTrue(x in y).
    self.assertIn('my_metric', evaluate_output)

    # predict() with top_k: expect one row per input example (3) and
    # top_k (2) columns.
    predict_output = classifier.predict(input_fn=_input_fn, get_top_k=True)
    self.assertListEqual([3, 2], list(predict_output.shape))
    # TODO(dnivara): Setup this test such that it is not flaky and predict() and
    # evaluate() outputs can be tested.
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号