def create_model(sess, forward_only):
    """Create a Seq2Seq model and initialize or restore its parameters.

    Args:
        sess: the TensorFlow session to run the initializer / restore op in.
        forward_only: if True, build the model for inference only (no
            backward pass / gradient updates).

    Returns:
        A `seq2seq_model.Seq2SeqModel`, with parameters restored from the
        latest checkpoint in FLAGS.train_dir when one exists, otherwise
        freshly initialized.
    """
    dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
    model = seq2seq_model.Seq2SeqModel(
        # NOTE(review): original read FLAGS.form_vocab_size — assumed to be a
        # typo for from_vocab_size (canonical TF translate-tutorial flag name);
        # confirm against the flag definitions in this file.
        FLAGS.from_vocab_size,
        FLAGS.to_vocab_size,
        _buckets,
        FLAGS.size,
        FLAGS.num_layers,
        FLAGS.max_gradient_norm,  # was FALGS.max_gradinet_norm (NameError + misspelling)
        FLAGS.batch_size,
        FLAGS.learning_rate,  # was FALGS (NameError)
        FLAGS.learning_rate_decay_factor,  # was FALGS (NameError)
        forward_only=forward_only,
        dtype=dtype)
    ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
    # was tf.train.checkpoint_exits — the real API is checkpoint_exists
    if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        print("Reading model params from %s" % ckpt.model_checkpoint_path)
        model.saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        print("Created model with fresh params")
        sess.run(tf.global_variables_initializer())
    return model
# NOTE(review): the following lines are web-scrape residue from the source
# blog page ("评论列表" = comment list, "文章目录" = article table of
# contents) — not part of the program.