linear_test.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def testSdcaOptimizerSparseFeatures(self):
    """Tests LinearClasssifier with SDCAOptimizer and sparse features."""

    def input_fn():
      return {
          'example_id': tf.constant(['1', '2', '3']),
          'price': tf.constant([[0.4], [0.6], [0.3]]),
          'country': tf.SparseTensor(values=['IT', 'US', 'GB'],
                                     indices=[[0, 0], [1, 3], [2, 1]],
                                     shape=[3, 5]),
          'weights': tf.constant([[1.0], [1.0], [1.0]])
      }, tf.constant([[1], [0], [1]])

    price = tf.contrib.layers.real_valued_column('price')
    country = tf.contrib.layers.sparse_column_with_hash_bucket(
        'country', hash_bucket_size=5)
    sdca_optimizer = tf.contrib.linear_optimizer.SDCAOptimizer(
        example_id_column='example_id')
    classifier = tf.contrib.learn.LinearClassifier(
        feature_columns=[price, country],
        weight_column_name='weights',
        optimizer=sdca_optimizer)
    classifier.fit(input_fn=input_fn, steps=50)
    scores = classifier.evaluate(input_fn=input_fn, steps=1)
    self.assertGreater(scores['accuracy'], 0.9)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号