model_base.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:nmt_v2 作者: rpryzant 项目源码 文件源码
def create_or_load_model(model, model_dir, session, name):
    latest_ckpt = tf.train.latest_checkpoint(model_dir)

    if latest_ckpt:
        model = load_model(model, latest_ckpt, session, name)
    else:
        start_time = time.time()
        session.run(tf.global_variables_initializer())
        session.run(tf.tables_initializer())
        print "  created %s model with fresh parameters, time %.2fs" % \
                        (name, time.time() - start_time)

    global_step = model.global_step.eval(session=session)
    return model, global_step
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号