def _test_pipeline(self, mode, params=None):
"""Helper function to test the full model pipeline.
"""
# Create source and target example
source_len = self.sequence_length + 5
target_len = self.sequence_length + 10
source = " ".join(np.random.choice(self.vocab_list, source_len))
target = " ".join(np.random.choice(self.vocab_list, target_len))
sources_file, targets_file = test_utils.create_temp_parallel_data(
sources=[source], targets=[target])
  # Build the model graph
  model = self.create_model(mode, params)
  input_pipeline_ = input_pipeline.ParallelTextInputPipeline(
      params={
          "source_files": [sources_file.name],
          "target_files": [targets_file.name]
      },
      mode=mode)
  # Create an input function that reads batches from the pipeline
  input_fn = training_utils.create_input_fn(
      pipeline=input_pipeline_, batch_size=self.batch_size)
  features, labels = input_fn()
  # Call the model and drop any fetches that are None
  fetches = model(features, labels, None)
  fetches = [_ for _ in fetches if _ is not None]
  # Run the graph: initialize variables and lookup tables, start the input
  # queue runners, and evaluate the model fetches once
  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    sess.run(tf.tables_initializer())
    with tf.contrib.slim.queues.QueueRunners(sess):
      fetches_ = sess.run(fetches)
  sources_file.close()
  targets_file.close()
  return model, fetches_
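
The helper reads several attributes from `self` (`sequence_length`, `vocab_list`, `batch_size`) and calls `self.create_model(...)` and `self.test_session()`, so it is meant to live inside a `tf.test.TestCase`-style model test. The sketch below shows one way the surrounding scaffolding could look; the class name, the `setUp` values, and the `seq2seq.*` import paths are assumptions for illustration, not taken from the snippet itself.

```python
import numpy as np
import tensorflow as tf

# Import paths follow the tf-seq2seq project layout (an assumption here).
from seq2seq.data import input_pipeline
from seq2seq.test import utils as test_utils
from seq2seq.training import utils as training_utils


class ExampleModelTest(tf.test.TestCase):
  """Hypothetical test case supplying the attributes _test_pipeline reads."""

  def setUp(self):
    super(ExampleModelTest, self).setUp()
    self.batch_size = 2
    self.sequence_length = 10
    # A tiny toy vocabulary; real tests would use the project's fixtures.
    self.vocab_list = [str(i) for i in range(10)]

  def create_model(self, mode, params=None):
    # Concrete subclasses would return the seq2seq model under test here.
    raise NotImplementedError()

  # _test_pipeline (shown above) would be defined here.

  def test_train(self):
    # Exercise the full pipeline in training mode and check that the model
    # produced at least one fetch.
    model, fetches_ = self._test_pipeline(tf.contrib.learn.ModeKeys.TRAIN)
    self.assertTrue(fetches_)
```

Individual test methods then only choose a mode (and optionally override params), while all graph construction, session setup, and cleanup stays in the shared helper.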