def get_last_checkpoint(model_dir):
    """Return the path and epoch of the newest ``*.hdf5`` checkpoint in a directory.

    The epoch number is parsed from the second-to-last ``-``-separated field
    of each file name, e.g. ``weights-7-0.93.hdf5`` -> 7. Files whose names
    do not fit that pattern (no such field, or a non-integer there) are
    ignored. Only ``model_dir`` itself is searched, not subdirectories.

    Args:
        model_dir: Directory containing the ``*.hdf5`` checkpoint files.

    Returns:
        A ``(path, epoch_num)`` tuple: ``path`` is ``model_dir`` joined with
        the selected file name and ``epoch_num`` is its parsed epoch (int).

    Raises:
        IndexError: If no parsable ``*.hdf5`` checkpoint is found, including
            when ``model_dir`` does not exist.
    """
    # glob.escape keeps '[', '*', '?' in the directory name from acting as
    # wildcards; glob.glob1 (used previously) is undocumented and deprecated.
    pattern = os.path.join(glob.escape(model_dir), '*.hdf5')

    candidates = []
    for path in glob.glob(pattern):
        name = os.path.basename(path)
        try:
            epoch = int(name.split('-')[-2])
        except (IndexError, ValueError):
            # Names like "final.hdf5" carry no epoch field; skip them rather
            # than letting one odd file break checkpoint discovery.
            continue
        candidates.append((epoch, name))

    if not candidates:
        # IndexError matches what the old argsort-on-empty-list code raised,
        # so existing ``except IndexError`` handlers keep working.
        raise IndexError('no *.hdf5 checkpoint found in %r' % (model_dir,))

    # Highest epoch wins; ties are broken deterministically by file name.
    epoch_num, latest_model_name = max(candidates)
    print('last snapshot:', latest_model_name)
    print('last epoch=', epoch_num)
    return os.path.join(model_dir, latest_model_name), epoch_num
评论列表
文章目录