def test_example_proto_coder_default_value(self):
    """Checks that missing features decode to their FixedLenFeature defaults.

    Every entry in the schema is a `tf.FixedLenFeature` with an explicit
    `default_value`. The serialized Example has a present-but-empty
    `features` message, so no key carries a value. The decoded dict must
    therefore equal the declared defaults for all four keys.
    """
    feature_spec = {
        'scalar_feature_3':
            tf.FixedLenFeature(shape=[], dtype=tf.float32, default_value=1.0),
        'scalar_feature_4':
            tf.FixedLenFeature(shape=[], dtype=tf.float32, default_value=0.0),
        '1d_vector_feature':
            tf.FixedLenFeature(
                shape=[1], dtype=tf.float32, default_value=[2.0]),
        '2d_vector_feature':
            tf.FixedLenFeature(
                shape=[2, 2],
                dtype=tf.float32,
                default_value=[[1.0, 2.0], [3.0, 4.0]]),
    }
    schema = dataset_schema.from_feature_spec(feature_spec)
    proto_coder = example_proto_coder.ExampleProtoCoder(schema)

    # Text proto with an empty `features` block: no value is supplied for
    # any key, so decoding has to fall back to each default_value.
    text_proto = '\nfeatures {\n}\n'
    empty_example = tf.train.Example()
    text_format.Merge(text_proto, empty_example)
    serialized = empty_example.SerializeToString()

    # One entry per schema key, mirroring the default_value declared above.
    want = {
        'scalar_feature_3': 1.0,
        'scalar_feature_4': 0.0,
        '1d_vector_feature': [2.0],
        '2d_vector_feature': [[1.0, 2.0], [3.0, 4.0]],
    }
    got = proto_coder.decode(serialized)
    np.testing.assert_equal(want, got)
# Comment list
# Table of contents