def get_cell(self, prev_keyboard, prev_state):
    """RNN encoder step.

    Reads the previous keyboard configuration key by key with
    ``self.rnn_cell`` and returns the RNN state reached after the last key.
    See parent class for arguments details.

    Args:
        prev_keyboard: tensor whose dimension 1 must statically equal
            ``music.NB_NOTES`` (dimension 0 is the batch). It is split along
            dimension 1 into ``music.NB_NOTES`` slices, each fed to the RNN
            as one step.
        prev_state: 2-tuple unpacked as (encoder state, decoder state). Only
            the second element is used, as the RNN's initial state.

    Returns:
        The final state of ``self.rnn_cell``; the per-step outputs are
        discarded.

    Raises:
        ValueError: if dimension 1 of ``prev_keyboard`` is not statically
            known to be ``music.NB_NOTES``.
    """
    # The encoder's own previous state (first element) is not used here.
    # NOTE(review): the RNN is seeded with the *decoder's* previous state
    # rather than the encoder's -- confirm this is intentional.
    _, prev_state_deco = prev_state

    axis = 1  # The first dimension is the batch, we split the keys
    nb_keys = prev_keyboard.get_shape()[axis].value  # None if unknown
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if nb_keys != music.NB_NOTES:
        raise ValueError(
            'Expected prev_keyboard to have {} keys along axis {}, got {}'.format(
                music.NB_NOTES, axis, nb_keys))

    # Legacy (pre-1.0) TensorFlow signature: tf.split(split_dim, num_split, value)
    inputs = tf.split(axis, music.NB_NOTES, prev_keyboard)

    # Static RNN over the list of per-key slices; only the final state is kept.
    _, final_state = tf.nn.rnn(
        self.rnn_cell,
        inputs,
        initial_state=prev_state_deco
    )
    return final_state
评论列表
文章目录