def testCrossColumn(self):
    # Checks that transforming a crossed column built from two hashed
    # sparse columns yields int64 values that all fall inside the cross's
    # hash bucket range [0, 15).
    lang_col = tf.contrib.layers.sparse_column_with_hash_bucket(
        "language", hash_bucket_size=3)
    country_col = tf.contrib.layers.sparse_column_with_hash_bucket(
        "country", hash_bucket_size=5)
    crossed_col = tf.contrib.layers.crossed_column(
        [lang_col, country_col], hash_bucket_size=15)

    # Each input is a 2x1 sparse tensor: one string value per row.
    language_input = tf.SparseTensor(
        values=["english", "spanish"],
        indices=[[0, 0], [1, 0]],
        shape=[2, 1])
    country_input = tf.SparseTensor(
        values=["US", "SV"],
        indices=[[0, 0], [1, 0]],
        shape=[2, 1])
    features = {"language": language_input, "country": country_input}

    transformer = feature_column_ops._Transformer(features)
    crossed_output = transformer.transform(crossed_col)

    with self.test_session():
        self.assertEqual(crossed_output.values.dtype, tf.int64)
        # Evaluate the crossed values and confirm none escapes the 15 buckets.
        within_buckets = all(
            0 <= bucket_id < 15 for bucket_id in crossed_output.values.eval())
        self.assertTrue(within_buckets)
评论列表
文章目录