texttfrecords.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:tefla 作者: openAGI 项目源码 文件源码
def generate_files(self, generator, output_filenames, max_cases=None):
        """Serialize cases from a generator into sharded TFRecord files.

        Each case is converted with ``self.to_example`` and the serialized
        result is written to one of the output files. Cases are distributed
        round-robin: case 0 goes to the first file, case 1 to the second, and
        so on, wrapping around after the last file.

        All writers that were opened are closed before this method returns,
        including when the generator, the conversion, or a write raises.

        Args:
          generator: an iterable yielding cases accepted by ``self.to_example``
            (per the original documentation: string -> int/float/str list
            dictionaries).
          output_filenames: list of output file paths, one per shard.
          max_cases: maximum number of cases to write. Any falsy value
            (``None``, the default, but also ``0``) means no limit: the
            generator is consumed until it is exhausted. Note that when the
            limit is hit, the generator has already been advanced one item
            past the last written case.
        """
        num_shards = len(output_filenames)
        writers = []
        try:
            # Open writers one at a time inside the try block so that a
            # failure while opening shard k still closes shards 0..k-1.
            for fname in output_filenames:
                writers.append(tf.python_io.TFRecordWriter(fname))

            counter, shard = 0, 0
            for case in generator:
                # Progress log every 100000 cases (skipping the very first).
                if counter > 0 and counter % 100000 == 0:
                    tf.logging.info("Generating case %d." % counter)
                counter += 1
                if max_cases and counter > max_cases:
                    break
                sequence_example = self.to_example(case)
                writers[shard].write(sequence_example.SerializeToString())
                # Advance to the next shard, wrapping around (round-robin).
                shard = (shard + 1) % num_shards
        finally:
            # Always release the file handles, even on error, so partially
            # written shards are closed rather than leaked.
            for writer in writers:
                writer.close()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号