train_utils.py source file

python

Project: FastFPN, author: wuzheng-sjtu

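import tensorflow as tf

# TF 1.x API (use tf.compat.v1 under TF 2.x). Assumes the string flags
# --checkpoint_exclude_scopes and --checkpoint_include_scopes are defined
# elsewhere, e.g. with tf.app.flags.DEFINE_string(..., None, ...).
FLAGS = tf.app.flags.FLAGS


def get_var_list_to_restore():
  """Choose which variables to restore from a checkpoint.

  Variables whose names start with any scope in --checkpoint_exclude_scopes
  are dropped first. If --checkpoint_include_scopes is set, only the remaining
  variables whose names start with one of those scopes are kept.
  """
  variables_to_restore = []
  if FLAGS.checkpoint_exclude_scopes is not None:
    # Skip empty entries (e.g. from a trailing comma): '' is a prefix of
    # every name and would otherwise exclude all variables.
    exclusions = [scope.strip()
                  for scope in FLAGS.checkpoint_exclude_scopes.split(',')
                  if scope.strip()]

    # Build the restore list, skipping variables under an excluded scope.
    for var in tf.model_variables():
      for exclusion in exclusions:
        if var.name.startswith(exclusion):
          break
      else:  # no exclusion matched this variable
        variables_to_restore.append(var)
  else:
    variables_to_restore = tf.model_variables()

  # Optionally narrow the list down to the included scopes.
  variables_to_restore_final = []
  if FLAGS.checkpoint_include_scopes is not None:
    includes = [scope.strip()
                for scope in FLAGS.checkpoint_include_scopes.split(',')
                if scope.strip()]
    for var in variables_to_restore:
      for include in includes:
        if var.name.startswith(include):
          variables_to_restore_final.append(var)
          break
  else:
    variables_to_restore_final = variables_to_restore

  return variables_to_restore_final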
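To see the selection rule in isolation, here is a minimal sketch that applies the same exclude-then-include prefix filtering to plain name strings, so it runs without TensorFlow. The helpers `parse_scopes` and `filter_names` and the sample variable names are hypothetical, chosen only for illustration; they are not part of FastFPN.

```python
from typing import Iterable, List, Optional


def parse_scopes(flag_value: Optional[str]) -> Optional[List[str]]:
  """Split a comma-separated scope flag into stripped, non-empty prefixes."""
  if flag_value is None:
    return None
  return [s.strip() for s in flag_value.split(',') if s.strip()]


def filter_names(names: Iterable[str],
                 exclude_scopes: Optional[str] = None,
                 include_scopes: Optional[str] = None) -> List[str]:
  """Drop names under excluded scopes, then keep only included scopes (if set)."""
  exclusions = parse_scopes(exclude_scopes)
  kept = [n for n in names
          if exclusions is None or not any(n.startswith(e) for e in exclusions)]
  includes = parse_scopes(include_scopes)
  if includes is None:
    return kept
  return [n for n in kept if any(n.startswith(i) for i in includes)]


# Hypothetical variable names, in the "<scope>/.../<name>:0" form TF uses.
names = [
    'resnet_v1_50/conv1/weights:0',
    'resnet_v1_50/block1/unit_1/conv1/weights:0',
    'resnet_v1_50/logits/weights:0',
    'pyramid/P2/weights:0',
]
print(filter_names(names, exclude_scopes='resnet_v1_50/logits, pyramid'))
# -> ['resnet_v1_50/conv1/weights:0', 'resnet_v1_50/block1/unit_1/conv1/weights:0']
print(filter_names(names, include_scopes='resnet_v1_50/block1'))
# -> ['resnet_v1_50/block1/unit_1/conv1/weights:0']
```

Because matching is a plain `startswith` check, a scope such as `resnet_v1_50/block1` would also match `resnet_v1_50/block10/...`; passing the scope with a trailing slash (`resnet_v1_50/block1/`) avoids that.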