def prime_internal_model(self, model):
    """Reset an internal model's state and choose its first observation.

    The model's `state_value` is always overwritten with its zero state
    first. What follows depends on `self.priming_mode`:

    * 'random_midi': one stored priming state is picked at random and
      loaded into the model; the priming note stored at the same index is
      passed through `rl_tuner_ops.make_onehot`, flattened, and returned.
    * 'single_midi': the model primes itself via `model.prime_model()` and
      its `priming_note` is returned.
    * 'random_note': the result of `self.get_random_note()` is returned.
    * any other value: a warning is logged and a random note is returned.

    Args:
      model: The internal model that should be primed (e.g. the q_network).

    Returns:
      The first observation to feed into the model.
    """
    model.state_value = model.get_zero_state()
    mode = self.priming_mode

    if mode == 'random_midi':
        # A single random index selects both the saved state and its note.
        idx = np.random.randint(0, len(self.priming_states))
        chosen_state = self.priming_states[idx, :]
        model.state_value = np.reshape(
            chosen_state, (1, model.cell.state_size))
        note = self.priming_notes[idx]
        onehot = rl_tuner_ops.make_onehot([note], self.num_actions)
        first_obs = np.array(onehot).flatten()
        tf.logging.debug(
            'Feeding priming state for midi file %s and corresponding note %s',
            idx, note)
        return first_obs

    if mode == 'single_midi':
        # prime_model() runs before priming_note is read, as in the original.
        model.prime_model()
        return model.priming_note

    # 'random_note' and unrecognised modes both end with a random note; only
    # the unrecognised case emits the warning beforehand.
    if mode != 'random_note':
        tf.logging.warn('Error! Invalid priming mode. Priming with random note')
    return self.get_random_note()
评论列表
文章目录