def create_merge_multiple(save_path, creators, shuffle=True):
    """Write the samples produced by several creators into one TFRecord file.

    For every creator, ``_read_list()`` is called and ``n_samples`` is then
    read. The output receives ``n_samples`` draws from each creator. Draws
    come creator after creator, or in a random interleaving when *shuffle*
    is true. A draw whose ``_create_next_sample()`` returns ``None`` is
    skipped and not counted.

    Args:
        save_path: Output path passed to ``tf.python_io.TFRecordWriter``.
        creators: Indexable sequence of creator objects. Each must provide
            ``_read_list()``, an integer ``n_samples`` and
            ``_create_next_sample()``, which returns an object with
            ``SerializeToString()`` or ``None``.
        shuffle: If true, randomly interleave draws across creators using
            numpy's global RNG (``np.random.shuffle``). Otherwise creators
            are written back to back in order.
    """
    n_sample_total = 0
    creator_indices = []
    for i, creator in enumerate(creators):
        # Must run before n_samples is read (presumably it loads the
        # creator's sample list and sets n_samples -- confirm in creator).
        creator._read_list()
        n_sample_total += creator.n_samples
        # One entry per sample, holding the index of the creator to draw from.
        # `int` replaces the removed `np.int` alias (identical dtype).
        creator_indices.append(np.full(creator.n_samples, i, dtype=int))

    if creator_indices:
        creator_indices = np.concatenate(creator_indices)
    else:
        # np.concatenate raises ValueError on an empty list (no creators).
        creator_indices = np.empty(0, dtype=int)

    if shuffle:
        np.random.shuffle(creator_indices)

    print('Start creating dataset with {} examples. Output path: {}'.format(
        n_sample_total, save_path))

    count = 0
    # The context manager guarantees the writer is closed (and its buffered
    # records flushed) even if a creator raises part-way through.
    with tf.python_io.TFRecordWriter(save_path) as writer:
        for i, creator_idx in enumerate(creator_indices):
            example = creators[creator_idx]._create_next_sample()
            if example is not None:
                writer.write(example.SerializeToString())
                count += 1
            if i > 0 and i % 100 == 0:
                print('Progress %d / %d' % (i, n_sample_total))
    print('Done creating %d samples' % count)
评论列表
文章目录