rl_tuner.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:magenta 作者: tensorflow 项目源码 文件源码
def restore_from_directory(self, directory=None, checkpoint_name=None,
                           reward_file_name=None):
    """Restores this model from a saved checkpoint.

    Optionally also reloads the reward histories that were stored alongside
    the checkpoint in an .npz archive.

    Args:
      directory: Path to directory where checkpoint is located. If
        None, defaults to self.output_dir.
      checkpoint_name: The name of the checkpoint within the
        directory. If None, the checkpoint returned by
        tf.train.latest_checkpoint(directory) is used.
      reward_file_name: The name of the .npz file where the stored
        rewards are saved. If None, will not attempt to load stored
        rewards.

    Returns:
      None. If no checkpoint can be located, the error is logged and the
      method returns without restoring anything.
    """
    if directory is None:
      directory = self.output_dir

    if checkpoint_name is not None:
      checkpoint_file = os.path.join(directory, checkpoint_name)
    else:
      tf.logging.info('Directory %s.', directory)
      checkpoint_file = tf.train.latest_checkpoint(directory)

    # Only reachable through the latest_checkpoint branch: an explicit
    # checkpoint_name always produces a (possibly non-existent) path.
    if checkpoint_file is None:
      tf.logging.fatal('Error! Cannot locate checkpoint in the directory')
      return
    # TODO(natashamjaques): Remove print statement once tf.logging outputs
    # to Jupyter notebooks (once the following issue is resolved:
    # https://github.com/tensorflow/tensorflow/issues/3047)
    print('Attempting to restore from checkpoint', checkpoint_file)
    tf.logging.info('Attempting to restore from checkpoint %s', checkpoint_file)

    self.saver.restore(self.session, checkpoint_file)

    if reward_file_name is not None:
      npz_file_name = os.path.join(directory, reward_file_name)
      # TODO(natashamjaques): Remove print statement once tf.logging outputs
      # to Jupyter notebooks (once the following issue is resolved:
      # https://github.com/tensorflow/tensorflow/issues/3047)
      print('Attempting to load saved reward values from file', npz_file_name)
      tf.logging.info('Attempting to load saved reward values from file %s',
                      npz_file_name)

      # The NpzFile returned by np.load keeps the archive open until it is
      # closed, so use it as a context manager to avoid leaking the file
      # handle. Indexing by key reads each array fully into memory, which
      # keeps the assigned attributes valid after the archive is closed.
      with np.load(npz_file_name) as npz_file:
        self.rewards_batched = npz_file['train_rewards']
        self.music_theory_rewards_batched = (
            npz_file['train_music_theory_rewards'])
        self.note_rnn_rewards_batched = npz_file['train_note_rnn_rewards']
        self.eval_avg_reward = npz_file['eval_rewards']
        self.eval_avg_music_theory_reward = (
            npz_file['eval_music_theory_rewards'])
        self.eval_avg_note_rnn_reward = npz_file['eval_note_rnn_rewards']
        self.target_val_list = npz_file['target_val_list']
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号