svm_test.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def testSparseFeatures(self):
    """Tests SVM classifier with (hashed) sparse features."""

    def input_fn():
      return {
          'example_id': tf.constant(['1', '2', '3']),
          'price': tf.constant([[0.8], [0.6], [0.3]]),
          'country': tf.SparseTensor(
              values=['IT', 'US', 'GB'],
              indices=[[0, 0], [1, 0], [2, 0]],
              shape=[3, 1]),
      }, tf.constant([[0], [1], [1]])

    price = tf.contrib.layers.real_valued_column('price')
    country = tf.contrib.layers.sparse_column_with_hash_bucket(
        'country', hash_bucket_size=5)
    svm_classifier = tf.contrib.learn.SVM(feature_columns=[price, country],
                                          example_id_column='example_id',
                                          l1_regularization=0.0,
                                          l2_regularization=1.0)
    svm_classifier.fit(input_fn=input_fn, steps=30)
    accuracy = svm_classifier.evaluate(input_fn=input_fn, steps=1)['accuracy']
    self.assertAlmostEqual(accuracy, 1.0, places=3)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号