operations.py source code

python

Project: Saliency_Detection_Convolutional_Autoencoder Author: arthurmeyer
import os

import tensorflow as tf


def save_weight_only(model, sess, log_path, step):
  """
  Save only the model's weights (no hidden variables).
  In practice, use this to transfer weights from one model to another.

  Args:
    model            :         model whose variables are saved
    sess             :         TensorFlow session
    log_path         :         directory to save into
    step             :         step count at the time of saving
  """

  # Recreate the output directory from scratch so stale checkpoints are removed
  path = log_path + '/' + model.name + '_weight_only'
  if tf.gfile.Exists(path):
    tf.gfile.DeleteRecursively(path)
  tf.gfile.MakeDirs(path)

  # Map checkpoint key names to the variables to save
  variable_to_save = {}
  for i in range(30):
    name = 'conv_' + str(i)
    variable_to_save[name] = model.parameters_conv[i]
    if i in [2, 4] and model.concat:
      # With concat enabled, layers 2 and 4 hold two deconv entries
      name = 'deconv_' + str(i)
      variable_to_save[name] = model.parameters_deconv[i][0]
      name = 'deconv_' + str(i) + '_bis'
      variable_to_save[name] = model.parameters_deconv[i][1]
    else:
      name = 'deconv_' + str(i)
      variable_to_save[name] = model.parameters_deconv[i]
    if i < 2:
      # The first two entries of model.deconv are saved as 'deconv_bis_<i>'
      name = 'deconv_bis_' + str(i)
      variable_to_save[name] = model.deconv[i]
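  # Passing a dict makes the Saver store each variable under its dict key
  saver = tf.train.Saver(variable_to_save)
  checkpoint_path = os.path.join(path, 'model.ckpt')
  # global_step appends the step number to the checkpoint file name
  saver.save(sess, checkpoint_path, global_step=step)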
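
As a rough illustration of the checkpoint key naming used in the loop above, the sketch below reproduces only the naming logic in plain Python, without TensorFlow. The helper `weight_only_names` and its parameters (`num_layers`, `concat_layers`, `num_bis`) are hypothetical and not part of the project. The layer count of 5 is chosen only to keep the result short.

```python
def weight_only_names(num_layers, concat, concat_layers=(2, 4), num_bis=2):
  """Return checkpoint keys in the order the loop would register them.

  Hypothetical helper: it mirrors the naming scheme only and touches no
  model variables.
  """
  names = []
  for i in range(num_layers):
    names.append('conv_' + str(i))
    if i in concat_layers and concat:
      # Two deconv entries for the concat layers
      names.append('deconv_' + str(i))
      names.append('deconv_' + str(i) + '_bis')
    else:
      names.append('deconv_' + str(i))
    if i < num_bis:
      names.append('deconv_bis_' + str(i))
  return names


names_concat = weight_only_names(num_layers=5, concat=True)
names_plain = weight_only_names(num_layers=5, concat=False)
print(names_concat)
print(names_plain)
```

Note that `deconv_<i>_bis` and `deconv_bis_<i>` are distinct keys, so any code that restores from such a checkpoint would need to use exactly the same names.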