def testSparseFeatures(self):
    """Check that the SVM estimator fits data containing a hashed sparse column.

    Three examples are fed in, each with a dense 'price' value and a single
    'country' string held in a sparse tensor. After 30 training steps the
    classifier is evaluated on the same three examples and its accuracy is
    expected to equal 1.0 to three decimal places.
    """

    def sparse_input_fn():
        """Build the (features, labels) pair used for both fit and evaluate."""
        example_ids = tf.constant(['1', '2', '3'])
        prices = tf.constant([[0.8], [0.6], [0.3]])
        # One country string per example, stored at column 0 of each row.
        countries = tf.SparseTensor(
            values=['IT', 'US', 'GB'],
            indices=[[0, 0], [1, 0], [2, 0]],
            shape=[3, 1])
        features = {
            'example_id': example_ids,
            'price': prices,
            'country': countries,
        }
        labels = tf.constant([[0], [1], [1]])
        return features, labels

    # Feature columns: one real-valued, one sparse string hashed into 5 buckets.
    price_column = tf.contrib.layers.real_valued_column('price')
    country_column = tf.contrib.layers.sparse_column_with_hash_bucket(
        'country', hash_bucket_size=5)

    classifier = tf.contrib.learn.SVM(
        feature_columns=[price_column, country_column],
        example_id_column='example_id',
        l1_regularization=0.0,
        l2_regularization=1.0)
    classifier.fit(input_fn=sparse_input_fn, steps=30)

    # Evaluate on the same input the model was fitted on.
    metrics = classifier.evaluate(input_fn=sparse_input_fn, steps=1)
    accuracy = metrics['accuracy']
    self.assertAlmostEqual(accuracy, 1.0, places=3)
评论列表
文章目录