def test_with_extra_args(self):
    """Check that extra keyword arguments override the definition's params.

    The YAML definition sets ``num_epochs: 1`` and ``shuffle: True``, while
    the call to ``make_input_pipeline_from_def`` passes ``num_epochs=5`` and
    ``shuffle=False``. The test asserts that the resulting pipeline is a
    ``ParallelTextInputPipeline`` whose ``params`` keep the file lists from
    the definition but carry the values given as keyword arguments.
    """
    # safe_load is used instead of yaml.load: calling yaml.load without an
    # explicit Loader is deprecated (and rejected by newer PyYAML releases),
    # and this document only contains plain scalars, lists and mappings.
    # The entries under ``params`` must be nested so that ``params`` parses
    # as a mapping rather than as a null value.
    pipeline_def = yaml.safe_load("""
        class: ParallelTextInputPipeline
        params:
            source_files: ["file1"]
            target_files: ["file2"]
            num_epochs: 1
            shuffle: True
    """)

    pipeline = input_pipeline.make_input_pipeline_from_def(
        def_dict=pipeline_def,
        mode=tf.contrib.learn.ModeKeys.TRAIN,
        num_epochs=5,
        shuffle=False,
    )

    self.assertIsInstance(pipeline, input_pipeline.ParallelTextInputPipeline)
    # pylint: disable=W0212
    # Values only present in the definition are passed through unchanged.
    self.assertEqual(pipeline.params["source_files"], ["file1"])
    self.assertEqual(pipeline.params["target_files"], ["file2"])
    # Values also given as keyword arguments win over the definition.
    self.assertEqual(pipeline.params["num_epochs"], 5)
    self.assertEqual(pipeline.params["shuffle"], False)
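# NOTE(review): stray web-page navigation labels ("comment list",
# "article table of contents") were pasted here; they are not part of the test.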