def testInferFeatureSchema(self):
    """Checks the schema that impl_helper.infer_feature_schema builds.

    Per the expected schema asserted below:
      * 'a' (static shape (None,)) and 'b' (static shape (1, 2, 3)) lose
        their leading dimension, giving shapes [] and [2, 3];
      * 'c', a placeholder with unknown rank, gets a None shape;
      * 'd', also a placeholder with unknown rank, reports the column schema
        attached to it via api.set_column_schema rather than None.
    Every column uses a FixedColumnRepresentation.
    """
    # 'd' is created before the other placeholders; it is the tensor that
    # receives an explicit column schema below.
    explicit_tensor = tf.placeholder(tf.int64, None)
    tensors = {}
    tensors['a'] = tf.placeholder(tf.float32, (None,))
    tensors['b'] = tf.placeholder(tf.string, (1, 2, 3))
    tensors['c'] = tf.placeholder(tf.int64, None)
    tensors['d'] = explicit_tensor

    def fixed_column(dtype, shape):
        # Every column in this test uses a fixed (dense) representation.
        return sch.ColumnSchema(dtype, shape, sch.FixedColumnRepresentation())

    # Attach a schema to 'd' before inference runs.
    api.set_column_schema(explicit_tensor, fixed_column(tf.int64, [1, 2, 3]))

    inferred = impl_helper.infer_feature_schema(tf.get_default_graph(), tensors)

    expected_columns = {
        'a': fixed_column(tf.float32, []),
        'b': fixed_column(tf.string, [2, 3]),
        'c': fixed_column(tf.int64, None),
        'd': fixed_column(tf.int64, [1, 2, 3]),
    }
    self.assertEqual(inferred, sch.Schema(column_schemas=expected_columns))
评论列表
文章目录