input_pipeline_test.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:seq2seq 作者: google 项目源码 文件源码
def test_without_extra_args(self):
    pipeline_def = yaml.load("""
      class: ParallelTextInputPipeline
      params:
        source_files: ["file1"]
        target_files: ["file2"]
        num_epochs: 1
        shuffle: True
    """)
    pipeline = input_pipeline.make_input_pipeline_from_def(
        pipeline_def, tf.contrib.learn.ModeKeys.TRAIN)
    self.assertIsInstance(pipeline, input_pipeline.ParallelTextInputPipeline)
    #pylint: disable=W0212
    self.assertEqual(pipeline.params["source_files"], ["file1"])
    self.assertEqual(pipeline.params["target_files"], ["file2"])
    self.assertEqual(pipeline.params["num_epochs"], 1)
    self.assertEqual(pipeline.params["shuffle"], True)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号