def get_cell(self, prev_keyboard, prev_state_enco):
    """Decode the next keyboard with an RNN that steps over the keys.

    The previous keyboard is split along its key axis into one input per
    key. These inputs are run through ``self.rnn_cell`` with
    ``tf.nn.seq2seq.rnn_decoder``, starting from ``prev_state_enco``. Each
    per-key output is projected with ``self.project_key`` and the
    projections are concatenated back along the same axis to form the next
    keyboard.

    See parent class for arguments details.

    Args:
        prev_keyboard: Tensor whose dimension 0 is the batch and whose
            dimension 1 must have the static size ``music.NB_NOTES``.
        prev_state_enco: Initial state given to the decoder; presumably the
            encoder's final state (TODO confirm against the parent class).

    Returns:
        tuple: ``(next_keyboard, final_state)``. ``next_keyboard`` is the
        concatenation of the projected per-key outputs along dimension 1,
        and ``final_state`` is the final state returned by the decoder.

    Raises:
        ValueError: if dimension 1 of ``prev_keyboard`` is not statically
            known to equal ``music.NB_NOTES``.
    """
    axis = 1  # The first dimension is the batch; we split along the keys
    nb_keys = prev_keyboard.get_shape()[axis].value
    # Explicit check rather than `assert`, so it is not stripped by
    # `python -O`: with a wrong but divisible size, tf.split would silently
    # put several columns in each "key" instead of failing.
    if nb_keys != music.NB_NOTES:
        raise ValueError(
            'prev_keyboard dimension {} must be {} (music.NB_NOTES), '
            'got {}'.format(axis, music.NB_NOTES, nb_keys))
    # TF 0.x argument order: tf.split(split_dim, num_split, value)
    inputs = tf.split(axis, music.NB_NOTES, prev_keyboard)
    outputs, final_state = tf.nn.seq2seq.rnn_decoder(
        decoder_inputs=inputs,
        initial_state=prev_state_enco,
        cell=self.rnn_cell,
        # TODO: Which loop function (should use prediction)? It should take
        # the previous generated input/ground truth (as the global model
        # loop_fct). Needs a new bool placeholder.
    )
    # NOTE: open question - is it better to project before or after packing?
    next_keys = [self.project_key(output) for output in outputs]
    # TF 0.x argument order: tf.concat(concat_dim, values)
    next_keyboard = tf.concat(axis, next_keys)
    return next_keyboard, final_state
评论列表
文章目录