def create_model(session, restore_only=False):
    """Build the Seq2Seq training graph and restore or initialize weights.

    Args:
        session: the tf.Session the model's variables live in.
        restore_only: if True, require an existing checkpoint and raise
            instead of initializing fresh parameters when none is found.

    Returns:
        Tuple of (model, saver).

    Raises:
        FileNotFoundError: if restore_only is True and no checkpoint
            exists in checkpoint_dir.
    """
    is_training = tf.placeholder(dtype=tf.bool, name='is_training')

    # With a bidirectional encoder, the decoder state size must be
    # 2x the encoder state size (64 -> 128).
    #
    # BUG FIX: the original used MultiRNNCell([cell] * 5), which reuses
    # ONE cell object (and hence one set of weights) for all 5 layers;
    # recent TF versions reject this with a variable-scope reuse error.
    # Each layer must get its own cell instance.
    encoder_cell = MultiRNNCell([LSTMCell(64) for _ in range(5)])
    decoder_cell = MultiRNNCell([LSTMCell(128) for _ in range(5)])

    model = Seq2SeqModel(encoder_cell=encoder_cell,
                         decoder_cell=decoder_cell,
                         vocab_size=wiki.vocab_size,
                         embedding_size=300,
                         attention=True,
                         bidirectional=True,
                         is_training=is_training,
                         device=args.device,
                         debug=False)

    # keep_checkpoint_every_n_hours=1: retain one permanent checkpoint
    # per hour in addition to the rolling most-recent set.
    saver = tf.train.Saver(tf.global_variables(), keep_checkpoint_every_n_hours=1)
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    if checkpoint:
        print("Reading model parameters from %s" % checkpoint.model_checkpoint_path)
        saver.restore(session, checkpoint.model_checkpoint_path)
    elif restore_only:
        raise FileNotFoundError("Cannot restore model")
    else:
        print("Created model with fresh parameters")
        session.run(tf.global_variables_initializer())

    # Freeze the graph: any accidental op creation after this point
    # (e.g. inside a training loop) fails loudly instead of leaking memory.
    tf.get_default_graph().finalize()
    return model, saver
# (removed web-scrape residue: blog navigation text "评论列表" [comment list]
#  and "文章目录" [article table of contents] — not part of the program)