translate.py source code

python

Project: DeepLearning  Author: Wanwannodao
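import tensorflow as tf

import seq2seq_model  # assumed: the repository's local module defining Seq2SeqModel

# FLAGS (command-line flags) and _buckets are module-level globals defined
# elsewhere in translate.py.


def create_model(sess, forward_only):
    """Build a Seq2SeqModel, restoring its parameters from the latest checkpoint if one exists."""
    # Use half precision only when the --use_fp16 flag is set.
    dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
    model = seq2seq_model.Seq2SeqModel(
        FLAGS.from_vocab_size,
        FLAGS.to_vocab_size,
        _buckets,
        FLAGS.size,
        FLAGS.num_layers,
        FLAGS.max_gradient_norm,
        FLAGS.batch_size,
        FLAGS.learning_rate,
        FLAGS.learning_rate_decay_factor,
        forward_only=forward_only,
        dtype=dtype)
    # Look up the most recent checkpoint recorded in train_dir (None if there is none).
    ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
    if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        print("Reading model params from %s" % ckpt.model_checkpoint_path)
        model.saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        # No usable checkpoint: initialize all variables from scratch.
        print("Created model with fresh params")
        sess.run(tf.global_variables_initializer())
    return model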
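create_model follows a common restore-or-initialize pattern: look up the latest checkpoint recorded in train_dir, restore it if the referenced file still exists, and otherwise start from freshly initialized variables. The sketch below reproduces only that control flow, without TensorFlow, so it runs anywhere. It is an illustration under stated assumptions, not the project's code: pickle files stand in for TF checkpoints, a JSON file named checkpoint.json stands in for TF's checkpoint state file, and the helpers save_checkpoint, get_checkpoint_path and restore_or_init are hypothetical.

```python
import json
import os
import pickle
from typing import Any, Callable, Optional

# Hypothetical stand-in for TF's "checkpoint" state file that records the latest checkpoint.
STATE_FILE = "checkpoint.json"


def save_checkpoint(train_dir: str, step: int, params: Any) -> str:
    """Pickle params and record them as the latest checkpoint (hypothetical helper)."""
    os.makedirs(train_dir, exist_ok=True)
    path = os.path.join(train_dir, "model-%d.pkl" % step)
    with open(path, "wb") as f:
        pickle.dump(params, f)
    # Point the state file at the newest checkpoint.
    with open(os.path.join(train_dir, STATE_FILE), "w") as f:
        json.dump({"model_checkpoint_path": path}, f)
    return path


def get_checkpoint_path(train_dir: str) -> Optional[str]:
    """Return the recorded latest checkpoint path, or None if no state file exists."""
    state = os.path.join(train_dir, STATE_FILE)
    if not os.path.exists(state):
        return None
    with open(state) as f:
        return json.load(f).get("model_checkpoint_path")


def restore_or_init(train_dir: str, init_fn: Callable[[], Any]) -> Any:
    """Mirror create_model: restore if a valid checkpoint exists, else initialize fresh."""
    path = get_checkpoint_path(train_dir)
    # Like the checkpoint_exists check: the state file may point at a deleted file.
    if path and os.path.exists(path):
        print("Reading model params from %s" % path)
        with open(path, "rb") as f:
            return pickle.load(f)
    print("Created model with fresh params")
    return init_fn()
```

Keeping the "latest" pointer in a separate state file, rather than scanning the directory, mirrors the split between tf.train.get_checkpoint_state and tf.train.checkpoint_exists in create_model: the pointer says which checkpoint to use, and the existence check guards against a pointer to a file that has since been deleted.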