operations.py source code

python

Project: Saliency_Detection_Convolutional_Autoencoder    Author: arthurmeyer
import tensorflow as tf  # TensorFlow 1.x API (tf.train.Saver is tf.compat.v1.train.Saver in TF 2.x)


def restore_weight_from(model, name, sess, log_path, copy_concat=False):
  """
  Restore a model's weights (excluding hidden variables) from another model's checkpoint.

  In practice, this is used to start training a model from the weights of another model.
  It works as long as both models use the architecture defined in the original model.py,
  with or without direct connections.

  Args:
    model            :         model whose variables are restored
    name             :         name of the model to copy from
    sess             :         TensorFlow session
    log_path         :         directory containing the '<name>_weight_only' checkpoint folder
    copy_concat      :         whether the model being copied from also had direct connections

  Returns:
    step_b           :         the step at which training of the source model ended,
                               parsed from the checkpoint file name (returned as a string)
  """

  path = log_path + '/' + name + '_weight_only'

  variable_to_save = {}
  for i in range(30):
    name = 'conv_' + str(i)
    variable_to_save[name] = model.parameters_conv[i]
    if i < 2:
      if copy_concat == model.concat:
        name = 'deconv_' + str(i)
        variable_to_save[name] = model.parameters_deconv[i]
        name = 'deconv_bis_' + str(i)
        variable_to_save[name] = model.deconv[i]
    else:
      if i in [2, 4] and model.concat:
        name = 'deconv_' + str(i)
        variable_to_save[name] = model.parameters_deconv[i][0]
        if copy_concat:
          name = 'deconv_' + str(i) + '_bis'
          variable_to_save[name] = model.parameters_deconv[i][1]
      else:
        # Without direct connections (or for any layer other than 2 and 4),
        # each deconv layer has a single set of parameters.
        name = 'deconv_' + str(i)
        variable_to_save[name] = model.parameters_deconv[i]

  saver = tf.train.Saver(variable_to_save)
  ckpt = tf.train.get_checkpoint_state(path)
  if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)
    return ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
  else:
    print('------------------------------------------------------')
    print('No checkpoint file found in ' + path)
    print('------------------------------------------------------ \n')
    raise SystemExit(1)  # exit() is provided by the site module and is not guaranteed to exist
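The variable-naming logic in `restore_weight_from` does not depend on TensorFlow, so it can be checked in isolation. The sketch below is a hypothetical, framework-free restatement of that logic: `build_restore_map`, `step_from_checkpoint_path`, the `num_layers` parameter and the `SimpleNamespace` stand-in model are illustrative names, not part of the project. It assumes, as the original indexing implies, that `parameters_deconv[2]` and `parameters_deconv[4]` hold pairs when `concat` is true; strings stand in for `tf.Variable` objects.

```python
import os
from types import SimpleNamespace


def build_restore_map(model, copy_concat=False, num_layers=30):
    """Hypothetical restatement of the checkpoint-name -> parameter mapping."""
    var_map = {}
    for i in range(num_layers):
        var_map['conv_' + str(i)] = model.parameters_conv[i]
        if i < 2:
            # Layers 0 and 1 are copied only when both models agree on direct connections.
            if copy_concat == model.concat:
                var_map['deconv_' + str(i)] = model.parameters_deconv[i]
                var_map['deconv_bis_' + str(i)] = model.deconv[i]
        elif i in (2, 4) and model.concat:
            # With direct connections, layers 2 and 4 hold a pair of parameter sets.
            var_map['deconv_' + str(i)] = model.parameters_deconv[i][0]
            if copy_concat:
                var_map['deconv_' + str(i) + '_bis'] = model.parameters_deconv[i][1]
        else:
            var_map['deconv_' + str(i)] = model.parameters_deconv[i]
    return var_map


def step_from_checkpoint_path(path):
    """Return the step suffix of a path like '.../model.ckpt-1200' as an int."""
    return int(os.path.basename(path).rsplit('-', 1)[-1])


# Hypothetical 6-layer stand-in model with direct connections enabled.
fake = SimpleNamespace(
    concat=True,
    parameters_conv=['c%d' % i for i in range(6)],
    parameters_deconv=['d0', 'd1', ('d2a', 'd2b'), 'd3', ('d4a', 'd4b'), 'd5'],
    deconv=['e0', 'e1'],
)

full = build_restore_map(fake, copy_concat=True, num_layers=6)
partial = build_restore_map(fake, copy_concat=False, num_layers=6)
print(sorted(full))
print(step_from_checkpoint_path('logs/net_weight_only/model.ckpt-1200'))
```

With `copy_concat=True` on a model with direct connections, layers 2 and 4 each contribute both a `deconv_<i>` and a `deconv_<i>_bis` entry; when the two models disagree on direct connections, the `deconv_0` and `deconv_1` entries are skipped entirely. Unlike the original, this sketch converts the step to an `int`.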