def _createTestInferModel(
        self, m_creator, hparams, sess, init_global_vars=False):
    """Create an INFER-mode model for a test and initialize its input state.

    Args:
      m_creator: Callable that builds the model. It is invoked as
        ``m_creator(hparams, mode, iterator, src_vocab_table,
        tgt_vocab_table, reverse_tgt_vocab_table, scope='dynamic_seq2seq')``
        with the iterator and tables from
        ``common_test_utils.create_test_iterator``.
      hparams: Hyperparameters passed to both
        ``common_test_utils.create_test_iterator`` and ``m_creator``.
      sess: Session whose ``run`` method executes the initializer ops.
      init_global_vars: If True, ``tf.global_variables_initializer()`` is run
        before the table and iterator initializers. Defaults to False.

    Returns:
      The object returned by ``m_creator``.
    """
    mode = tf.contrib.learn.ModeKeys.INFER

    # create_test_iterator yields exactly four items: the input iterator
    # followed by three vocab lookup tables.
    created = common_test_utils.create_test_iterator(hparams, mode)
    iterator, source_vocab, target_vocab, reverse_target_vocab = created

    # NOTE(review): the fixed scope name presumably has to match the scope
    # used when building the other test models -- confirm with callers.
    model = m_creator(
        hparams,
        mode,
        iterator,
        source_vocab,
        target_vocab,
        reverse_target_vocab,
        scope='dynamic_seq2seq',
    )

    # Global variables are initialized only on request.
    # NOTE(review): presumably so callers that already initialized or
    # restored variables can skip this step -- confirm.
    if init_global_vars:
        sess.run(tf.global_variables_initializer())

    # Tables are initialized before the iterator.
    # NOTE(review): presumably the iterator's input pipeline looks tokens up
    # in the vocab tables -- confirm.
    sess.run(tf.tables_initializer())
    sess.run(iterator.initializer)

    return model
评论列表
文章目录