def get_data_paths(paths, file_pattern=DEFAULT_TFRECORDS_GLOB_PATTERN):
    """Resolve one or more data locations into lists of file paths.

    Each entry of ``paths`` is either a directory or a single file. A
    directory is expanded by globbing ``os.path.join(directory, pattern)``
    with ``tf.gfile.Glob`` and sorting the matches; any other path is kept
    as a one-element list.

    Args:
        paths: A single string path, or a list of string paths.
        file_pattern: A glob pattern string applied to every directory in
            ``paths``, or a list of patterns with exactly one entry per path.

    Returns:
        A list containing one list of file paths per entry of ``paths``,
        in the same order as ``paths``.

    Raises:
        AssertionError: If ``paths`` or ``file_pattern`` is neither a list
            nor accepted by ``isstring``; if a pattern list's length differs
            from the number of paths; or if the resolved file lists do not
            all have the same length.
    """
    if not isinstance(paths, list):
        assert isstring(paths)
        paths = [paths]

    if isinstance(file_pattern, list):
        file_patterns = file_pattern
    else:
        assert isstring(file_pattern)
        # A single pattern is reused for every path.
        file_patterns = [file_pattern] * len(paths)
    assert len(file_patterns) == len(paths), (file_patterns, paths)

    datasources = []
    # `pattern` (not `file_pattern`) so the loop does not shadow the argument.
    for path, pattern in zip(paths, file_patterns):
        if os.path.isdir(path):
            # NOTE(review): os.path.isdir only inspects the local filesystem, so a
            # remote directory (e.g. a gs:// URL) takes the single-file branch
            # below. Consider tf.gfile.IsDirectory if remote paths must be
            # supported -- confirm with callers.
            matches = tf.gfile.Glob(os.path.join(path, pattern))
            # Sort so the file order is deterministic across calls.
            datasources.append(sorted(matches))
        else:
            datasources.append([path])

    # Materialize the lengths as a list: on Python 3, map() returns an
    # iterator that cannot be indexed, which made the previous `dl[0]` fail.
    lengths = [len(datasource) for datasource in datasources]
    # All sources must resolve to the same number of files (presumably so
    # they can be consumed in lockstep -- confirm with callers).
    assert all(length == lengths[0] for length in lengths[1:]), lengths
    return datasources
评论列表
文章目录