def to_tfrecord(data, file_dir):
    """Write each entry of ``data`` to its own TFRecord file.

    For every ``key, values`` pair, a file named ``<key>.tfrecord`` is
    written in ``file_dir``. It holds a single serialized
    ``tf.train.Example`` with these bytes features:

    * ``example_name`` -- ``key`` (encoded as UTF-8 if it is not already bytes).
    * ``shape`` -- ``values['image'].shape`` as raw int32 bytes.
    * ``img_raw`` -- raw C-order bytes of ``values['image']``.
    * ``gt_raw`` -- raw C-order bytes of ``values['ground_truth']``.

    Neither array's dtype nor the ground truth's shape is stored, so a reader
    must know them in advance to decode ``img_raw`` / ``gt_raw``.

    Args:
        data: Mapping from example name (a str, used in the file name) to a
            mapping holding NumPy arrays under ``'image'`` and
            ``'ground_truth'``.
        file_dir: Directory the ``.tfrecord`` files are written into; it is
            not created here.
    """
    for key, values in data.items():
        image = values['image']
        ground_truth = values['ground_truth']
        # Store the image shape with a fixed dtype so it can be decoded
        # independently of the platform's default integer size.
        shape = np.asarray(image.shape, dtype=np.int32)
        # Presumably _bytes_feature wraps tf.train.BytesList, which rejects
        # str values on Python 3 -- so encode text keys first.
        name = key if isinstance(key, bytes) else key.encode('utf-8')
        example = tf.train.Example(features=tf.train.Features(feature={
            'example_name': _bytes_feature(name),
            'shape': _bytes_feature(shape.tobytes()),
            'img_raw': _bytes_feature(image.tobytes()),
            'gt_raw': _bytes_feature(ground_truth.tobytes()),
        }))
        path = os.path.join(file_dir, key + '.tfrecord')
        # The context manager closes the writer even if the write raises.
        with tf.python_io.TFRecordWriter(path) as writer:
            writer.write(example.SerializeToString())
评论列表
文章目录