def get_priming_melodies(self):
    """Runs a batch of training data through MelodyRNN model.

    If the priming mode is 'random_midi', priming the q-network requires a
    random training melody. Therefore this function runs a batch of data from
    the training directory through the internal model, and the resulting
    internal states of the LSTM are stored in a list. The next note in each
    training melody is also stored in a corresponding list called
    'priming_notes'. Therefore, to prime the model with a random melody, it is
    only necessary to select a random index from 0 to batch_size-1 and use the
    hidden states and note at that index as input to the model.
    """
    (next_note_softmax,
     self.priming_states, lengths) = self.q_network.run_training_batch()

    # Each melody occupies TRAIN_SEQUENCE_LENGTH consecutive rows of
    # next_note_softmax; only the first lengths[idx] rows are real outputs,
    # the rest is padding. For every melody, pick the most likely note from
    # the softmax at its final non-padding timestep.
    self.priming_notes = [
        np.argmax(
            next_note_softmax[idx * TRAIN_SEQUENCE_LENGTH + melody_len - 1, :])
        for idx, melody_len in enumerate(lengths)
    ]
    tf.logging.info('Stored priming notes: %s', self.priming_notes)