def _createTestTrainModel(self, m_creator, hparams, sess):
    """Build a training-mode model and run its initializers in `sess`.

    Args:
        m_creator: Callable invoked as ``m_creator(hparams, mode, iterator,
            src_vocab_table, tgt_vocab_table, scope='dynamic_seq2seq')``.
        hparams: Forwarded unchanged to
            ``common_test_utils.create_test_iterator`` and to `m_creator`.
        sess: Object whose ``run`` method executes the initializer ops
            (presumably a ``tf.Session`` -- confirm against callers).

    Returns:
        Whatever `m_creator` returned.
    """
    mode = tf.contrib.learn.ModeKeys.TRAIN
    iterator_bundle = common_test_utils.create_test_iterator(hparams, mode)
    iterator, src_vocab, tgt_vocab = iterator_bundle

    model = m_creator(
        hparams,
        mode,
        iterator,
        src_vocab,
        tgt_vocab,
        scope='dynamic_seq2seq',
    )

    # The initializer ops are requested only after m_creator has returned;
    # tf.global_variables_initializer() covers just the variables that
    # exist at the moment it is called.
    sess.run(tf.global_variables_initializer())
    sess.run(tf.tables_initializer())
    sess.run(iterator.initializer)
    return model
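# Comment list  (web-page navigation label left over from extraction, not code)
# Table of contents  (web-page navigation label left over from extraction, not code)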