def testCrossWithCrossedColumn(self):
  """Checks crossing an already-crossed column with a bucketized column.

  Builds a country x language crossed column, crosses it again with a
  bucketized price column, shifts every weight of the resulting linear model
  by 0.4, and checks the summed output for a single example.
  """
  price_bucket = tf.contrib.layers.bucketized_column(
      tf.contrib.layers.real_valued_column("price"),
      boundaries=[0., 10., 100.])
  language = tf.contrib.layers.sparse_column_with_hash_bucket(
      "language", hash_bucket_size=3)
  country = tf.contrib.layers.sparse_column_with_hash_bucket(
      "country", hash_bucket_size=5)
  country_language = tf.contrib.layers.crossed_column(
      [language, country], hash_bucket_size=10)
  # Second-level cross: one input is itself a crossed column, the other a
  # bucketized column.
  country_language_price = tf.contrib.layers.crossed_column(
      {country_language, price_bucket},
      hash_bucket_size=15)
  with tf.Graph().as_default():
    # A single example (batch of 1): one price value, two countries and two
    # languages.
    features = {
        "price": tf.constant([[20.]]),
        "country": tf.SparseTensor(values=["US", "SV"],
                                   indices=[[0, 0], [0, 1]],
                                   shape=[1, 2]),
        "language": tf.SparseTensor(values=["english", "spanish"],
                                    indices=[[0, 0], [0, 1]],
                                    shape=[1, 2])
    }
    output, column_to_variable, _ = (
        tf.contrib.layers.weighted_sum_from_feature_columns(
            features, [country_language_price],
            num_outputs=1))
    with self.test_session() as sess:
      tf.initialize_all_variables().run()
      tf.initialize_all_tables().run()
      weights = column_to_variable[country_language_price][0]
      # Shift every weight by 0.4. The expected value below assumes the
      # weights (and bias) start at zero, so each active cross contributes
      # exactly 0.4 -- TODO(review): confirm against the initializer default.
      sess.run(weights.assign(weights + 0.4))
      # 2 countries x 2 languages x 1 price bucket = 4 crosses, each with
      # 0.4 weight:
      # score = 0.4 + 0.4 + 0.4 + 0.4 = 1.6
      self.assertAllClose(output.eval(), [[1.6]])
# Stray web-page navigation residue (not code): "评论列表" (comment list),
# "文章目录" (table of contents).