def reconstruct_batch(self, output, batch_id, chosen_labels=None):
    """ Create the song associated with the network output

    For every timestep of the output, one label is selected for the sample
    `batch_id` and converted into a relative note: label 0 is the <next>
    token (no pitch class), any other label `k` maps to pitch class `k-1`.

    Args:
        output (list[np.Array]): The output of the network, one array per
            timestep (size batch_size*output_dim)
        batch_id (int): The batch id
        chosen_labels (list[np.Array[batch_size, int]]): the sampled class at
            each timestep (useful to reconstruct the generated song). If
            None, empty, or shorter than output, the arg-max of the network
            output is used for the timesteps which have no sampled label.
    Return:
        Song: The reconstructed song (result of self.reconstruct_song)
    """
    # Label 0 is decoded below as the <next> token, which presumably requires
    # the relative representation to define an empty class
    assert Relative.HAS_EMPTY

    processed_song = Relative.RelativeSong()
    processed_song.first_note = music.Note()
    processed_song.first_note.note = 56  # TODO: Define what should be the first note

    # Explicit None/length check: avoids the ambiguous truth value of a numpy
    # array and an IndexError when output is longer than chosen_labels
    nb_sampled = 0 if chosen_labels is None else len(chosen_labels)

    print('Reconstruct')
    for i, note in enumerate(output):
        relative = Relative.RelativeNote()
        # If the output has been sampled, reuse the label which has been
        # selected; the last generated note has not been sampled, so fall
        # back to the most probable class
        if i >= nb_sampled:
            # Cast np.int64 to int to avoid compatibility issues with mido
            chosen_label = int(np.argmax(note[batch_id, :]))
        else:
            chosen_label = int(chosen_labels[i][batch_id])
        print(chosen_label, end=' ')  # TODO: Add a text output connector

        if chosen_label == 0:  # <next> token
            relative.pitch_class = None
        else:
            relative.pitch_class = chosen_label - 1
        # TODO: relative.scale and relative.prev_tick are not set (not used)

        processed_song.notes.append(relative)
    print()
    return self.reconstruct_song(processed_song)
评论列表
文章目录