data.py source code

python

Project: tfutils  Author: neuroailab
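import os

import tensorflow as tf

# DEFAULT_TFRECORDS_GLOB_PATTERN is defined elsewhere in data.py.


def get_data_paths(paths, file_pattern=DEFAULT_TFRECORDS_GLOB_PATTERN):
    """Resolve each entry of `paths` to a sorted list of data files.

    A directory is expanded with its glob pattern; any other path is used
    as a single-file source. All resulting lists must have the same length.
    """
    # Normalize a single path to a one-element list.
    if not isinstance(paths, list):
        assert isinstance(paths, str)
        paths = [paths]
    # Reuse one pattern for every path unless a per-path list is given.
    if not isinstance(file_pattern, list):
        assert isinstance(file_pattern, str)
        file_patterns = [file_pattern] * len(paths)
    else:
        file_patterns = file_pattern
    assert len(file_patterns) == len(paths), (file_patterns, paths)
    datasources = []
    for path, pattern in zip(paths, file_patterns):
        if os.path.isdir(path):
            # tf.io.gfile.glob replaces the TF1-only tf.gfile.Glob;
            # sorting makes the file order deterministic.
            datasource = tf.io.gfile.glob(os.path.join(path, pattern))
            datasource.sort()
            datasources.append(datasource)
        else:
            datasources.append([path])
    # list(...) is required on Python 3, where map() returns an iterator.
    dl = list(map(len, datasources))
    assert all(dl[0] == d for d in dl[1:]), dl
    return datasources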
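To see what the function returns, here is a minimal local-filesystem sketch of the same logic. It assumes a few things not in the original: the stdlib `glob` module stands in for `tf.io.gfile.glob` (so only local paths work, not remote filesystems such as GCS), `get_data_paths_local` is a hypothetical name, and `DEFAULT_PATTERN = "*.tfrecords"` is a hypothetical stand-in for the real `DEFAULT_TFRECORDS_GLOB_PATTERN`, whose value is not shown here.

```python
import glob
import os
import tempfile

# Hypothetical default; the real DEFAULT_TFRECORDS_GLOB_PATTERN lives elsewhere in data.py.
DEFAULT_PATTERN = "*.tfrecords"


def get_data_paths_local(paths, file_pattern=DEFAULT_PATTERN):
    """Sketch of get_data_paths using stdlib glob (local paths only)."""
    if not isinstance(paths, list):
        assert isinstance(paths, str)
        paths = [paths]
    if not isinstance(file_pattern, list):
        assert isinstance(file_pattern, str)
        file_patterns = [file_pattern] * len(paths)
    else:
        file_patterns = file_pattern
    assert len(file_patterns) == len(paths), (file_patterns, paths)

    datasources = []
    for path, pattern in zip(paths, file_patterns):
        if os.path.isdir(path):
            # Directory: collect matching files in sorted order.
            datasources.append(sorted(glob.glob(os.path.join(path, pattern))))
        else:
            # Anything else: treat it as a single-file source.
            datasources.append([path])

    # Every source must yield the same number of files so they line up by index.
    lengths = [len(d) for d in datasources]
    assert all(n == lengths[0] for n in lengths[1:]), lengths
    return datasources


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        images = os.path.join(root, "images")
        labels = os.path.join(root, "labels")
        for d in (images, labels):
            os.makedirs(d)
            for i in (1, 0):  # created out of order to show the sort
                open(os.path.join(d, f"part-{i}.tfrecords"), "w").close()
        for src in get_data_paths_local([images, labels]):
            print([os.path.basename(p) for p in src])
```

Run as a script, this prints `['part-0.tfrecords', 'part-1.tfrecords']` once per directory. The final length check is what lets callers pair, say, image shards with label shards by position; a directory with an extra or missing file fails the assertion instead of silently misaligning the two sources.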