model_dir.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:document-qa 作者: allenai 项目源码 文件源码
def restore_checkpoint(self, sess, var_list=None, load_ema=True):
        """
        Restore either the best weights or the most recent checkpoint into `sess`.

        Assumes the matching variables have already been added to the TF default
        graph, e.g. `.get_prediction()` has been called on the model stored in
        `self`. If `load_ema` is set and the checkpoint contains exponential
        moving average (EMA) values for any of the variables, those values are
        loaded into the corresponding variables instead of the raw weights.

        :param sess: session to restore the weights into
        :param var_list: variables to restore; when None and `load_ema` is set,
            defaults to the same collections `tf.train.Saver` uses (global
            variables plus saveable objects)
        :param load_ema: whether to look for, and prefer, EMA weights
        :raises ValueError: if neither best weights nor a checkpoint is found
        """
        checkpoint = self.get_best_weights()
        if checkpoint is None:
            print("Loading most recent checkpoint")
            checkpoint = self.get_latest_checkpoint()
        else:
            print("Loading best weights")

        if checkpoint is None:
            # Fail with a clear message instead of letting `None` reach
            # NewCheckpointReader / Saver.restore and fail obscurely there.
            raise ValueError("No best weights or checkpoint found to restore")

        if load_ema:
            if var_list is None:
                # Same default used by `Saver`
                var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) + \
                           tf.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS)

            # Automatically check if there are EMA variables, if so use those
            reader = tf.train.NewCheckpointReader(checkpoint)
            ema = tf.train.ExponentialMovingAverage(0)
            ema_names = {}
            for var in var_list:
                # Checkpoint key under which the EMA value of `var` would be stored
                ema_name = ema.average_name(var)
                if reader.has_tensor(ema_name):
                    ema_names[ema_name] = var

            if ema_names:
                print("Found EMA weights, loading them")
                ema_vars = set(ema_names.values())
                # Build a {checkpoint key: variable} map: variables without an
                # EMA value restore from their own name, the rest from the EMA key.
                var_list = {v.op.name: v for v in var_list if v not in ema_vars}
                var_list.update(ema_names)

        saver = tf.train.Saver(var_list)
        saver.restore(sess, checkpoint)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号