graph_handler.py source code

python

Project: bi-att-flow  Author: allenai
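import os

import tensorflow as tf  # TensorFlow 1.x API; under TF 2 use the tf.compat.v1 equivalents


# Method of the graph handler in graph_handler.py; relies on self.config and self.model.
def _load(self, sess):
    """Restore model variables into `sess` from the checkpoint selected by self.config."""
    config = self.config
    # Key every variable by its checkpoint name (the variable name without the ":0" suffix).
    # tf.all_variables() was deprecated and later removed; tf.global_variables() replaces it.
    vars_ = {var.name.split(":")[0]: var for var in tf.global_variables()}
    if config.load_ema:
        # Restore the exponential-moving-average (shadow) values into the trainable
        # variables: re-key each one by its EMA name instead of its own name.
        ema = self.model.var_ema
        for var in tf.trainable_variables():
            del vars_[var.name.split(":")[0]]
            vars_[ema.average_name(var)] = var
    saver = tf.train.Saver(vars_, max_to_keep=config.max_to_keep)

    # Checkpoint precedence: explicit path, then a specific step, then the latest in save_dir.
    if config.load_path:
        save_path = config.load_path
    elif config.load_step > 0:
        save_path = os.path.join(config.save_dir, "{}-{}".format(config.model_name, config.load_step))
    else:
        save_dir = config.save_dir
        checkpoint = tf.train.get_checkpoint_state(save_dir)
        assert checkpoint is not None, "cannot load checkpoint at {}".format(save_dir)
        save_path = checkpoint.model_checkpoint_path
    print("Loading saved model from {}".format(save_path))
    saver.restore(sess, save_path)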
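The two decisions `_load` makes, which checkpoint file to read and which checkpoint key feeds each variable, can be sketched without TensorFlow. This is a minimal, framework-free illustration under stated assumptions: `resolve_checkpoint_path`, `ema_restore_map`, and the `latest_checkpoint` argument are hypothetical helpers, and the `<var>/ExponentialMovingAverage` key format is assumed from the default `tf.train.ExponentialMovingAverage` name (the real code asks `ema.average_name(var)`). Unlike the original, the sketch raises `FileNotFoundError` instead of using `assert`, so the check survives `python -O`.

```python
import os
from typing import Dict, Iterable, Optional


def resolve_checkpoint_path(load_path: str, load_step: int, save_dir: str,
                            model_name: str,
                            latest_checkpoint: Optional[str]) -> str:
    """Hypothetical helper: pick a checkpoint with the same precedence as _load."""
    if load_path:
        # 1. An explicit path always wins.
        return load_path
    if load_step > 0:
        # 2. A specific step maps to "<save_dir>/<model_name>-<step>".
        return os.path.join(save_dir, "{}-{}".format(model_name, load_step))
    # 3. Otherwise use the most recent checkpoint recorded in save_dir
    #    (in TF: tf.train.get_checkpoint_state(save_dir).model_checkpoint_path).
    if latest_checkpoint is None:
        raise FileNotFoundError("cannot load checkpoint at {}".format(save_dir))
    return latest_checkpoint


def ema_restore_map(all_names: Iterable[str], trainable_names: Iterable[str],
                    ema_suffix: str = "ExponentialMovingAverage") -> Dict[str, str]:
    """Hypothetical helper: map checkpoint keys to variable names so that
    trainable variables are filled from their EMA shadow copies.

    Assumes shadow keys look like "<var>/<ema_suffix>".
    """
    mapping = {name: name for name in all_names}
    for name in trainable_names:
        del mapping[name]                                  # stop reading the raw weights
        mapping["{}/{}".format(name, ema_suffix)] = name   # read the shadow copy instead
    return mapping


if __name__ == "__main__":
    print(resolve_checkpoint_path("", 1000, "out/save", "basic", None))
    # → out/save/basic-1000 (on POSIX)
    print(ema_restore_map(["w", "b", "global_step"], ["w", "b"]))
    # → {'global_step': 'global_step', 'w/ExponentialMovingAverage': 'w', 'b/ExponentialMovingAverage': 'b'}
```

Non-trainable variables such as `global_step` keep their own keys, so only the learned weights are swapped for their averaged values. TensorFlow 1.x also offers `tf.train.ExponentialMovingAverage.variables_to_restore()`, which builds a similar name-to-variable map for a `Saver`.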