hooks.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:seq2seq 作者: google 项目源码 文件源码
def begin(self):
    variables = tf.contrib.framework.get_variables(scope=self.params["prefix"])

    def varname_in_checkpoint(name):
      """Removes the prefix from the variable name.
      """
      prefix_parts = self.params["prefix"].split("/")
      checkpoint_prefix = "/".join(prefix_parts[:-1])
      return name.replace(checkpoint_prefix + "/", "")

    target_names = [varname_in_checkpoint(_.op.name) for _ in variables]
    restore_map = {k: v for k, v in zip(target_names, variables)}

    tf.logging.info("Restoring variables: \n%s",
                    yaml.dump({k: v.op.name
                               for k, v in restore_map.items()}))

    self._saver = tf.train.Saver(restore_map)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号