def testParseExample(self):
    """Check parsing of one serialized Example into two feature columns.

    Builds a bucketized column over a 3-dimensional "price" feature and a
    keyed sparse column over "wire_cast", serializes a single
    tf.train.Example carrying both features, parses it with
    parse_feature_columns_from_examples, and asserts the evaluated outputs.
    """
    price_column = tf.contrib.layers.real_valued_column("price", dimension=3)
    price_buckets = tf.contrib.layers.bucketized_column(
        price_column, boundaries=[0., 10., 100.])
    cast_column = tf.contrib.layers.sparse_column_with_keys(
        "wire_cast", ["marlo", "omar", "stringer"])

    # The prices below are expected to land in buckets 2, 3, 0.
    price_feature = tf.train.Feature(
        float_list=tf.train.FloatList(value=[20., 110, -3]))
    cast_feature = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[b"stringer", b"marlo"]))
    example = tf.train.Example(
        features=tf.train.Features(
            feature={"price": price_feature, "wire_cast": cast_feature}))

    parsed = tf.contrib.layers.parse_feature_columns_from_examples(
        serialized=[example.SerializeToString()],
        feature_columns=[price_buckets, cast_column])

    # The parsed result is keyed by the feature column objects themselves.
    self.assertIn(price_buckets, parsed)
    self.assertIn(cast_column, parsed)

    with self.test_session():
        # Tables are initialized before any output is evaluated.
        tf.initialize_all_tables().run()
        self.assertAllEqual(parsed[price_buckets].eval(), [[2, 3, 0]])
        # "stringer" and "marlo" are asserted to come back as ids 2 and 0,
        # at positions (0, 0) and (0, 1) of the sparse output.
        cast_tensor = parsed[cast_column]
        self.assertAllEqual(cast_tensor.indices.eval(), [[0, 0], [0, 1]])
        self.assertAllEqual(cast_tensor.values.eval(), [2, 0])
评论列表
文章目录