build_training_data_tfrecord.py 文件源码

python
阅读 49 收藏 0 点赞 0 评论 0

项目: human-rl 作者: gsastry 项目源码 文件源码
def convert_episode_to_tf_records(base_directory, new_directory, dataloader, path):
    """Convert one saved episode into a directory of per-frame GZIP TFRecord files.

    The episode at ``path`` is loaded with ``frame.load_episode``. Its location
    relative to ``base_directory`` is mirrored under ``new_directory``, with up
    to two file extensions stripped (e.g. ``.pkl.gz``). The output is a
    directory named ``<relative path>.tfrecord``. Each frame ``i`` of the
    episode is written to ``<that directory>/<i>.tfrecord`` as a single
    ``tf.train.Example`` with the features ``action``, ``label``,
    ``observation``, ``observation_shape``, ``image`` and ``image_shape``.

    Args:
        base_directory: Directory prefix that must occur in ``path``. The part
            of ``path`` after its last occurrence (plus one separator
            character) becomes the relative output path.
        new_directory: Root directory under which the output is written.
        dataloader: Object exposing ``load_features_and_labels_episode``.
        path: Path of the episode file to convert.

    Returns:
        The output directory path (ending in ``.tfrecord``) that holds the
        per-frame record files.

    Raises:
        ValueError: If ``base_directory`` does not occur in ``path``.
    """
    episode = frame.load_episode(path)
    # NOTE(review): features/labels are not used below. The call (and its
    # two-value unpacking) is kept in case it is relied on for validation —
    # confirm, and drop it if it is dead work.
    _features, _labels = dataloader.load_features_and_labels_episode(episode)

    base_index = path.rfind(base_directory)
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if base_index == -1:
        raise ValueError(
            "base_directory {!r} not found in path {!r}".format(base_directory, path))
    # Skip the base directory itself and the separator character after it.
    relative_path = path[base_index + len(base_directory) + 1:]
    # Strip up to two extensions, e.g. "ep.pkl.gz" -> "ep".
    relative_path = os.path.splitext(os.path.splitext(relative_path)[0])[0]
    new_path = os.path.join(new_directory, relative_path + ".tfrecord")

    options = tf.python_io.TFRecordOptions(
        compression_type=tf.python_io.TFRecordCompressionType.GZIP)
    os.makedirs(new_path, exist_ok=True)
    for i, f in enumerate(episode.frames):
        example = tf.train.Example(features=tf.train.Features(feature={
            'action': _int64_feature([f.action]),
            'label': _int64_feature([f.label] if f.label is not None else []),
            'observation': _float_feature(f.observation.reshape(-1)),
            'observation_shape': _int64_feature(f.observation.shape),
            'image': _bytes_feature([f.image.tobytes()]),
            'image_shape': _int64_feature(f.image.shape),
        }))
        # ``with`` guarantees the writer is closed, and the GZIP stream
        # finalized, even if writing raises.
        record_path = os.path.join(new_path, "{}.tfrecord".format(i))
        with tf.python_io.TFRecordWriter(record_path, options=options) as writer:
            writer.write(example.SerializeToString())
    return new_path
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号