def testOneHotColumnFromSparseColumnWithHashBucket(self):
  """One-hot over a hash-bucketed sparse column yields a 3-D sequence input.

  The sparse input has dense shape [4, 3, 2] and holds six string ids in
  examples 0, 1 and 3 (example 2 has none; position [1, 0] holds two ids,
  "a" and "c"). Only the output shape, [4, 3, num_buckets], is asserted;
  which bucket each string hashes into is not checked.
  """
  num_buckets = 10
  dense_shape = [4, 3, 2]

  id_values = ["c", "b",
               "a", "c", "b",
               "b"]
  id_positions = [[0, 0, 0], [0, 1, 0],
                  [1, 0, 0], [1, 0, 1], [1, 1, 0],
                  [3, 2, 0]]
  # NOTE(review): newer TF releases rename `shape=` to `dense_shape=` and
  # `initialize_all_tables` to `tables_initializer`; confirm the targeted TF
  # version before changing either.
  sparse_ids = tf.SparseTensor(
      values=id_values, indices=id_positions, shape=dense_shape)

  bucketized = tf.contrib.layers.sparse_column_with_hash_bucket(
      "ids", num_buckets)
  one_hot = tf.contrib.layers.one_hot_column(bucketized)
  sequence_input = tf.contrib.layers.sequence_input_from_feature_columns(
      {"ids": sparse_ids}, [one_hot])

  with self.test_session() as session:
    # Run variable and table initializers before evaluating the input tensor.
    tf.global_variables_initializer().run()
    tf.initialize_all_tables().run()
    result = session.run(sequence_input)

  # First two dims mirror the sparse input's dense shape; last equals
  # num_buckets.
  expected_shape = np.array(dense_shape[:2] + [num_buckets])
  self.assertAllEqual(expected_shape, result.shape)
评论列表
文章目录