def convert_episode_to_tf_records(base_directory, new_directory, dataloader, path):
    """Convert one saved episode into a directory of per-frame TFRecord files.

    The output location mirrors ``path`` relative to ``base_directory``. The
    part of ``path`` after the last occurrence of ``base_directory`` has up
    to two extensions removed and ``.tfrecord`` appended. It is then joined
    under ``new_directory``. That path is created as a *directory*, and each
    frame ``i`` of the episode is written to its own GZIP-compressed file
    ``<new_path>/<i>.tfrecord``.

    Args:
        base_directory: Directory prefix that must occur somewhere in ``path``.
        new_directory: Root directory under which the output is written.
        dataloader: Object exposing ``load_features_and_labels_episode``.
        path: Path of the episode file, loaded with ``frame.load_episode``.

    Returns:
        The output directory path that holds the per-frame record files.

    Raises:
        ValueError: If ``base_directory`` does not occur in ``path``.
    """
    episode = frame.load_episode(path)
    # NOTE(review): the features/labels returned here are never used below;
    # the call is kept only in case it has side effects (validation, caching).
    # Confirm and drop it if it is dead work.
    dataloader.load_features_and_labels_episode(episode)

    start = path.rfind(base_directory)
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would let rfind's -1 produce a bogus output path.
    if start < 0:
        raise ValueError(
            "base_directory {!r} not found in path {!r}".format(base_directory, path))
    # Strip leading separators instead of skipping exactly one character, so a
    # base_directory given with a trailing separator does not swallow the first
    # character of the relative path.
    separators = os.sep + (os.altsep or "")
    relative = path[start + len(base_directory):].lstrip(separators)
    # splitext removes only the last extension per call; apply it twice to
    # drop up to two extensions.
    stem = os.path.splitext(os.path.splitext(relative)[0])[0]
    new_path = os.path.join(new_directory, stem + ".tfrecord")

    options = tf.python_io.TFRecordOptions(
        compression_type=tf.python_io.TFRecordCompressionType.GZIP)
    os.makedirs(new_path, exist_ok=True)
    for i, f in enumerate(episode.frames):
        record_path = os.path.join(new_path, "{}.tfrecord".format(i))
        # The context manager closes the writer even if building or writing
        # the example raises, so no file handle is leaked.
        with tf.python_io.TFRecordWriter(record_path, options=options) as writer:
            writer.write(_frame_to_tf_example(f).SerializeToString())
    return new_path


def _frame_to_tf_example(f):
    """Build the ``tf.train.Example`` for a single episode frame ``f``."""
    return tf.train.Example(features=tf.train.Features(feature={
        'action': _int64_feature([f.action]),
        # An unlabeled frame is stored as an empty int64 list.
        'label': _int64_feature([f.label] if f.label is not None else []),
        # Arrays are stored flattened, with their shapes alongside (presumably
        # so a reader can reshape them back).
        'observation': _float_feature(f.observation.reshape(-1)),
        'observation_shape': _int64_feature(f.observation.shape),
        'image': _bytes_feature([f.image.tobytes()]),
        'image_shape': _int64_feature(f.image.shape),
    }))
评论列表
文章目录