def generate_files(self, generator, output_filenames, max_cases=None):
    """Generate cases from a generator and save them as TFRecord files.

    Each case is converted with ``self.to_example`` and its serialized form
    is written to one of ``output_filenames``. Cases are distributed
    round-robin: case i goes to ``output_filenames[i % len(output_filenames)]``,
    so the number of shards equals ``len(output_filenames)``.

    Args:
      generator: an iterable yielding (string -> int/float/str list)
        dictionaries.
      output_filenames: list of output file paths, one per shard. It must be
        non-empty whenever the generator yields at least one case (an empty
        list then raises IndexError when the first case is written).
      max_cases: maximum number of cases to take from the generator; if None
        (default), the generator is consumed until it is exhausted. At most
        ``max_cases`` items are pulled from the generator, so it is never
        advanced past the limit; ``max_cases <= 0`` writes nothing.

    Returns:
      None. Every writer that was opened is closed in a ``finally`` clause,
      including when opening, converting, or writing raises.
    """
    import itertools

    num_shards = len(output_filenames)
    writers = []
    try:
        # Open writers inside the try so already-opened ones are still closed
        # if a later one fails to open.
        for fname in output_filenames:
            writers.append(tf.python_io.TFRecordWriter(fname))

        # Compare against None (not truthiness) so max_cases=0 means "take
        # nothing" rather than "unlimited". islice stops without pulling an
        # extra item, unlike checking the limit after fetching the next case.
        # Negative limits are clamped to 0 (islice rejects them), preserving
        # the previous "write nothing" outcome.
        if max_cases is None:
            cases = generator
        else:
            cases = itertools.islice(generator, max(0, max_cases))

        shard = 0
        for counter, case in enumerate(cases):
            if counter > 0 and counter % 100000 == 0:
                tf.logging.info("Generating case %d." % counter)
            sequence_example = self.to_example(case)
            writers[shard].write(sequence_example.SerializeToString())
            shard = (shard + 1) % num_shards
    finally:
        for writer in writers:
            writer.close()
评论列表
文章目录