def restore_from_directory(self, directory=None, checkpoint_name=None,
                           reward_file_name=None):
    """Restores this model from a saved checkpoint.

    The checkpoint is restored into `self.session` using `self.saver`. If a
    reward file name is given, the stored reward arrays are also loaded from
    that .npz file and assigned to the corresponding attributes on `self`.

    If no checkpoint can be located, a fatal-level message is logged and the
    method returns without restoring anything or loading rewards.

    Args:
      directory: Path to directory where checkpoint is located. If
        None, defaults to self.output_dir.
      checkpoint_name: The name of the checkpoint within the
        directory. If None, the latest checkpoint found in the directory
        is used.
      reward_file_name: The name of the .npz file where the stored
        rewards are saved. If None, will not attempt to load stored
        rewards.
    """
    if directory is None:
        directory = self.output_dir

    if checkpoint_name is not None:
        checkpoint_file = os.path.join(directory, checkpoint_name)
    else:
        tf.logging.info('Directory %s.', directory)
        checkpoint_file = tf.train.latest_checkpoint(directory)

    if checkpoint_file is None:
        tf.logging.fatal('Error! Cannot locate checkpoint in the directory')
        return

    # TODO(natashamjaques): Remove print statement once tf.logging outputs
    # to Jupyter notebooks (once the following issue is resolved:
    # https://github.com/tensorflow/tensorflow/issues/3047)
    print('Attempting to restore from checkpoint', checkpoint_file)
    tf.logging.info('Attempting to restore from checkpoint %s', checkpoint_file)

    self.saver.restore(self.session, checkpoint_file)

    if reward_file_name is None:
        return

    npz_file_name = os.path.join(directory, reward_file_name)
    # TODO(natashamjaques): Remove print statement once tf.logging outputs
    # to Jupyter notebooks (once the following issue is resolved:
    # https://github.com/tensorflow/tensorflow/issues/3047)
    print('Attempting to load saved reward values from file', npz_file_name)
    tf.logging.info('Attempting to load saved reward values from file %s',
                    npz_file_name)

    # np.load on an .npz archive keeps the underlying file open until the
    # returned object is closed, so use it as a context manager to avoid
    # leaking the file handle. Each array is read into memory when indexed,
    # so the values assigned below remain usable after the archive is closed.
    with np.load(npz_file_name) as npz_file:
        self.rewards_batched = npz_file['train_rewards']
        self.music_theory_rewards_batched = (
            npz_file['train_music_theory_rewards'])
        self.note_rnn_rewards_batched = npz_file['train_note_rnn_rewards']
        self.eval_avg_reward = npz_file['eval_rewards']
        self.eval_avg_music_theory_reward = (
            npz_file['eval_music_theory_rewards'])
        self.eval_avg_note_rnn_reward = npz_file['eval_note_rnn_rewards']
        self.target_val_list = npz_file['target_val_list']
评论列表
文章目录