dnn_sampled_softmax_classifier_test.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def testCustomMetrics(self):
    """Checks that a caller-supplied metric function is run by evaluate()."""
    language_column = tf.contrib.layers.sparse_column_with_hash_bucket(
        'language', hash_bucket_size=20)
    embedding_features = [
        tf.contrib.layers.embedding_column(language_column, dimension=1)
    ]

    def _input_fn():
      """Builds the sparse 'language' feature and the int64 target column."""
      language = tf.SparseTensor(
          indices=[[0, 0], [0, 1], [2, 0]],
          values=['en', 'fr', 'zh'],
          shape=[3, 2])
      target = tf.constant([[1], [0], [0]], dtype=tf.int64)
      return {'language': language}, target

    def _my_metric_op(predictions, targets):
      """Multiplies each argmax class by its target; expected to be [1, 0, 0]."""
      predicted_classes = math_ops.argmax(predictions, 1)
      # Targets are flattened with reshape([-1]) before the elementwise product.
      flat_targets = tf.reshape(targets, [-1])
      return tf.mul(predicted_classes, flat_targets)

    adam = tf.train.AdamOptimizer(learning_rate=0.01)
    run_config = tf.contrib.learn.RunConfig(tf_random_seed=5)
    classifier = dnn_sampled_softmax_classifier._DNNSampledSoftmaxClassifier(
        n_classes=3,
        n_samples=2,
        n_labels=1,
        feature_columns=embedding_features,
        hidden_units=[4, 4],
        optimizer=adam,
        config=run_config)

    # Run 50 training steps first so that evaluation sees a fitted model.
    classifier.fit(input_fn=_input_fn, steps=50)

    # NOTE(review): the key looks like (metric name, prediction key) -- confirm
    # against the estimator's evaluate() contract.
    results = classifier.evaluate(
        input_fn=_input_fn,
        steps=1,
        metrics={('my_metric', 'probabilities'): _my_metric_op})
    self.assertListEqual([1, 0, 0], list(results['my_metric']))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号