def save_weight_only(model, sess, log_path, step):
    """Save only the model's weight variables to a fresh checkpoint directory.

    Only the conv/deconv parameters collected below are handed to the
    Saver, so no other graph variables are written. In practice this is
    used to transfer weights from one model to another.

    Args:
        model: model to save variables from. Must expose ``name``,
            ``concat``, ``parameters_conv``, ``parameters_deconv`` and
            ``deconv``.
        sess: TensorFlow session holding the variable values.
        log_path: parent directory. The checkpoint is written to
            ``<log_path>/<model.name>_weight_only``, which is deleted and
            recreated first.
        step: number of steps at save time, passed as ``global_step`` so it
            is appended to the checkpoint file name.
    """
    # Use os.path.join (as for checkpoint_path below) rather than manual '/'
    # concatenation, so a trailing slash on log_path is handled cleanly.
    path = os.path.join(log_path, model.name + '_weight_only')
    # Wipe any previous export so stale checkpoint files are not mixed in.
    if tf.gfile.Exists(path):
        tf.gfile.DeleteRecursively(path)
    tf.gfile.MakeDirs(path)

    variable_to_save = {}
    # NOTE(review): assumes parameters_conv / parameters_deconv have at least
    # 30 entries -- confirm against the model definition.
    for i in range(30):
        variable_to_save['conv_' + str(i)] = model.parameters_conv[i]
        if i in (2, 4) and model.concat:
            # With concat enabled these layers hold two parameter entries;
            # the second one is stored under a '_bis' suffix.
            deconv_params = model.parameters_deconv[i]
            variable_to_save['deconv_' + str(i)] = deconv_params[0]
            variable_to_save['deconv_' + str(i) + '_bis'] = deconv_params[1]
        else:
            variable_to_save['deconv_' + str(i)] = model.parameters_deconv[i]
        if i < 2:
            # Extra deconv variables exist only for the first two layers.
            variable_to_save['deconv_bis_' + str(i)] = model.deconv[i]

    saver = tf.train.Saver(variable_to_save)
    checkpoint_path = os.path.join(path, 'model.ckpt')
    saver.save(sess, checkpoint_path, global_step=step)
operations.py 文件源码
python
阅读 32
收藏 0
点赞 0
评论 0
评论列表
文章目录