def restore_weight_from(model, name, sess, log_path, copy_concat=False):
    """Restore model weights (excluding hidden variables) from another model's checkpoint.

    In practice, used to initialize a model with the weights of a previously
    trained model. Works as long as both models use the architecture from the
    original model.py; compatible with or without direct (concat) connections.

    Args:
        model: model whose variables are restored into.
        name: name of the model to copy weights from (used to locate the
            checkpoint directory under log_path).
        sess: TensorFlow session used to run the restore ops.
        log_path: directory containing the '<name>_weight_only' checkpoint folder.
        copy_concat: True if the source model also had direct connections.

    Returns:
        The global-step suffix (string) of the restored checkpoint path.
        Exits the process with status 1 if no checkpoint is found.
    """
    import sys  # local import: only needed on the error path

    path = log_path + '/' + name + '_weight_only'
    variable_to_save = {}
    for i in range(30):
        # NOTE: use a separate key variable instead of rebinding the `name`
        # parameter (the original shadowed it inside the loop).
        variable_to_save['conv_' + str(i)] = model.parameters_conv[i]
        if i < 2:
            # Deconv layers 0-1 are only transferable when both models agree
            # on the direct-connection setting.
            if copy_concat == model.concat:
                variable_to_save['deconv_' + str(i)] = model.parameters_deconv[i]
                variable_to_save['deconv_bis_' + str(i)] = model.deconv[i]
        else:
            if i in [2, 4] and model.concat:
                # With direct connections, layers 2 and 4 store a pair of
                # variable groups; the second is only present in the source
                # checkpoint when it also used direct connections.
                variable_to_save['deconv_' + str(i)] = model.parameters_deconv[i][0]
                if copy_concat:
                    variable_to_save['deconv_' + str(i) + '_bis'] = model.parameters_deconv[i][1]
            else:
                # Merged: the original had two byte-identical branches here
                # (`i in [2, 4] and not model.concat` and the final else).
                variable_to_save['deconv_' + str(i)] = model.parameters_deconv[i]

    saver = tf.train.Saver(variable_to_save)
    ckpt = tf.train.get_checkpoint_state(path)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        # Checkpoint paths end in '...-<step>'; return the step suffix.
        return ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
    else:
        print('------------------------------------------------------')
        print('No checkpoint file found')
        print('------------------------------------------------------ \n')
        # sys.exit(1): signal failure to the shell (the original bare exit()
        # returned status 0 even on this error path).
        sys.exit(1)
# Source: operations.py
# (scraped page metadata: language: python; views: 23; bookmarks: 0;
#  likes: 0; comments: 0 — comment list / article table of contents omitted)