rl_tuner.py source code


Project: magenta  Author: tensorflow
def prime_internal_model(self, model):
    """Prime an internal model such as the q_network based on priming mode.

    Args:
      model: The internal model that should be primed.

    Returns:
      The first observation to feed into the model.
    """
    model.state_value = model.get_zero_state()

    if self.priming_mode == 'random_midi':
      priming_idx = np.random.randint(0, len(self.priming_states))
      model.state_value = np.reshape(
          self.priming_states[priming_idx, :],
          (1, model.cell.state_size))
      priming_note = self.priming_notes[priming_idx]
      next_obs = np.array(
          rl_tuner_ops.make_onehot([priming_note], self.num_actions)).flatten()
      tf.logging.debug(
          'Feeding priming state for midi file %s and corresponding note %s',
          priming_idx, priming_note)
    elif self.priming_mode == 'single_midi':
      model.prime_model()
      next_obs = model.priming_note
    elif self.priming_mode == 'random_note':
      next_obs = self.get_random_note()
    else:
      tf.logging.warn('Error! Invalid priming mode. Priming with random note')
      next_obs = self.get_random_note()

    return next_obs
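
For context, the observation returned by the 'random_midi' branch is just a flattened one-hot encoding of the chosen priming note. The sketch below illustrates that encoding with plain NumPy; make_onehot_sketch, the note index, and num_actions = 38 are illustrative stand-ins rather than the actual rl_tuner_ops implementation.

import numpy as np

def make_onehot_sketch(labels, depth):
  """Minimal stand-in for rl_tuner_ops.make_onehot: one-hot encode each label."""
  onehot = np.zeros((len(labels), depth))
  onehot[np.arange(len(labels)), labels] = 1.0
  return onehot

# Hypothetical values: 38 possible note actions, priming note index 21.
num_actions = 38
priming_note = 21

# The observation fed back into the model is the flattened one-hot vector,
# mirroring the 'random_midi' branch above.
next_obs = make_onehot_sketch([priming_note], num_actions).flatten()
print(next_obs.shape)  # (38,)
print(next_obs[21])    # 1.0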