def restore_model(model, sess, log_path):
    """
    Restore a previously saved model (including hidden variables).

    In practice this is used to resume the training of the same model. The
    checkpoint directory is ``<log_path>/<model.name>``; the checkpoint that
    ``tf.train.get_checkpoint_state`` reports for it is restored into ``sess``.

    ``model`` is only used for its ``name``. The ``tf.train.Saver`` is built
    without a ``var_list``; per the TF1 Saver documentation that means it
    covers the saveable variables of the whole default graph, not just those
    of ``model``.

    Args:
        model: model whose ``name`` selects the checkpoint sub-directory.
        sess: TensorFlow session to restore the variables into.
        log_path: root directory the checkpoints were saved under (read only).

    Returns:
        step_b: the step number at which training ended, as a *string*: the
        part of the checkpoint file name after its last '-'
        (e.g. ``'1000'`` for ``model.ckpt-1000``).

    Raises:
        SystemExit: with status 1 if no checkpoint is found.
    """
    import os

    path = os.path.join(log_path, model.name)
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(path)
    if not (ckpt and ckpt.model_checkpoint_path):
        print('------------------------------------------------------')
        print('No checkpoint file found')
        print('------------------------------------------------------ \n')
        # Raise SystemExit directly rather than calling the site-provided
        # ``exit()`` helper (absent under ``python -S`` / embedded
        # interpreters); a non-zero status lets shells see the failure.
        raise SystemExit(1)

    saver.restore(sess, ckpt.model_checkpoint_path)
    # basename splits on the platform separator(s), unlike split('/').
    return os.path.basename(ckpt.model_checkpoint_path).split('-')[-1]
operations.py 文件源码
python
阅读 28
收藏 0
点赞 0
评论 0
评论列表
文章目录