generator_utils.py source code

python

Project: tensor2tensor  Author: tensorflow
def generate_files(generator, output_filenames, max_cases=None):
  """Generate cases from a generator and save as TFRecord files.

  Generated cases are transformed to tf.Example protos and written round-robin
  to the TFRecord files listed in output_filenames (sharded files named
  output_dir/output_name-00..N-of-00..M=num_shards).

  Args:
    generator: a generator yielding (string -> int/float/str list) dictionaries.
    output_filenames: List of output file paths.
    max_cases: maximum number of cases to get from the generator;
      if None (default), we use the generator until StopIteration is raised.
  """
  if outputs_exist(output_filenames):
    tf.logging.info("Skipping generator because output files exist")
    return
  num_shards = len(output_filenames)
  writers = [tf.python_io.TFRecordWriter(fname) for fname in output_filenames]
  counter, shard = 0, 0
  for case in generator:
    if counter > 0 and counter % 100000 == 0:
      tf.logging.info("Generating case %d." % counter)
    counter += 1
    if max_cases and counter > max_cases:
      break
    sequence_example = to_example(case)
    writers[shard].write(sequence_example.SerializeToString())
    shard = (shard + 1) % num_shards

  for writer in writers:
    writer.close()
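
The round-robin sharding and the max_cases cutoff used above can be tried out without TensorFlow. The following is only a minimal sketch: sharded_names and assign_round_robin are hypothetical helpers that are not part of generator_utils.py, and the five-digit name format is an assumption based on the "00..N-of-00..M" pattern in the docstring. Instead of writing TFRecords, it records which cases would land in which file.

```python
import itertools


def sharded_names(prefix, num_shards):
    """Hypothetical helper: build names like prefix-00000-of-00003.

    The five-digit zero padding is an assumption, not taken from the source.
    """
    return ["%s-%05d-of-%05d" % (prefix, i, num_shards)
            for i in range(num_shards)]


def assign_round_robin(cases, filenames, max_cases=None):
    """Return a dict mapping each filename to the cases it would receive."""
    shards = {name: [] for name in filenames}
    # generate_files treats a falsy max_cases (None or 0) as "no limit",
    # so `max_cases or None` mirrors that behavior with islice.
    limited = itertools.islice(cases, max_cases or None)
    for index, case in enumerate(limited):
        # Same rotation as `shard = (shard + 1) % num_shards` in the original.
        shards[filenames[index % len(filenames)]].append(case)
    return shards


names = sharded_names("/tmp/demo", 3)
result = assign_round_robin(range(7), names, max_cases=5)
for name in names:
    print(name, result[name])
```

With max_cases=5 only the first five cases are distributed, so the first two shards receive two cases each and the last shard receives one. In generate_files itself, each case would additionally be converted with to_example and serialized before being written.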