decoder.py source code

python

Project: MusicGenerator  Author: Conchylicultor
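```python
import tensorflow as tf  # this code targets the pre-1.0 TensorFlow API
# `music` is the project's own module (provides NB_NOTES); its import is not shown here


def get_cell(self, prev_keyboard, prev_state_enco):
    """An RNN decoder that processes the keyboard one key per step.

    See the parent class for argument details.
    """

    axis = 1  # Dimension 0 is the batch; split along the keys (dimension 1)
    assert prev_keyboard.get_shape()[axis].value == music.NB_NOTES
    # Pre-1.0 TensorFlow signature: tf.split(split_dim, num_split, value)
    inputs = tf.split(axis, music.NB_NOTES, prev_keyboard)  # NB_NOTES tensors of shape [batch, 1]

    outputs, final_state = tf.nn.seq2seq.rnn_decoder(
        decoder_inputs=inputs,
        initial_state=prev_state_enco,
        cell=self.rnn_cell
        # TODO: Which loop function (should it use the prediction)? It should take either the
        # previously generated input or the ground truth (like the global model's loop_fct).
        # This needs a new bool placeholder.
    )

    # Is it better to do the projection before or after the packing?
    next_keys = []
    for output in outputs:
        next_keys.append(self.project_key(output))  # one projection per key step

    # Pre-1.0 TensorFlow signature: tf.concat(concat_dim, values)
    next_keyboard = tf.concat(axis, next_keys)

    return next_keyboard, final_state
```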
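The following is a minimal pure-Python sketch of the split → unroll → project → concat pipeline in `get_cell`. It uses nested lists in place of tensors. Only the overall structure comes from the code above. Everything else is hypothetical:

- `ToyRNNCell` is a simple tanh RNN standing in for `self.rnn_cell`.
- `make_projection` is a sigmoid layer standing in for `project_key`, whose real definition is not shown.
- `NB_NOTES = 4` is an illustrative size, not the project's `music.NB_NOTES`.

The `rnn_decoder` helper mirrors the loop of TensorFlow's legacy `rnn_decoder`, including the optional `loop_function` that the TODO mentions.

```python
import math
import random
from typing import Callable, List, Tuple

Matrix = List[List[float]]  # [batch, features], standing in for a 2-D tensor


def split_keys(keyboard: Matrix) -> List[Matrix]:
    """Like splitting along axis 1: one [batch, 1] slice per key."""
    nb_notes = len(keyboard[0])
    return [[[row[k]] for row in keyboard] for k in range(nb_notes)]


def concat_keys(keys: List[Matrix]) -> Matrix:
    """Like concatenating along axis 1: rebuild [batch, NB_NOTES] from [batch, 1] slices."""
    batch = len(keys[0])
    return [[key[b][0] for key in keys] for b in range(batch)]


class ToyRNNCell:
    """Hypothetical stand-in for self.rnn_cell: h' = tanh(W_x x + W_h h + b)."""

    def __init__(self, input_size: int, hidden_size: int, seed: int = 0):
        rng = random.Random(seed)
        self.w_x = [[rng.uniform(-0.5, 0.5) for _ in range(input_size)] for _ in range(hidden_size)]
        self.w_h = [[rng.uniform(-0.5, 0.5) for _ in range(hidden_size)] for _ in range(hidden_size)]
        self.b = [0.0] * hidden_size

    def __call__(self, inputs: Matrix, state: Matrix) -> Tuple[Matrix, Matrix]:
        new_state = [
            [
                math.tanh(
                    sum(w * xi for w, xi in zip(wx, x))
                    + sum(w * hi for w, hi in zip(wh, h))
                    + bi
                )
                for wx, wh, bi in zip(self.w_x, self.w_h, self.b)
            ]
            for x, h in zip(inputs, state)
        ]
        return new_state, new_state  # a basic RNN cell's output is its new state


def rnn_decoder(decoder_inputs, initial_state, cell, loop_function=None):
    """Unroll `cell` over the inputs, mirroring the legacy seq2seq rnn_decoder loop."""
    state, outputs, prev = initial_state, [], None
    for i, inp in enumerate(decoder_inputs):
        if loop_function is not None and prev is not None:
            inp = loop_function(prev, i)  # steps i > 0 take their input from the previous output
        output, state = cell(inp, state)
        outputs.append(output)
        if loop_function is not None:
            prev = output
    return outputs, state


def make_projection(hidden_size: int, seed: int = 1) -> Callable[[Matrix], Matrix]:
    """Hypothetical project_key: sigmoid(w . h) gives one key probability per batch row."""
    rng = random.Random(seed)
    w = [rng.uniform(-0.5, 0.5) for _ in range(hidden_size)]

    def project_key(output: Matrix) -> Matrix:
        return [[1.0 / (1.0 + math.exp(-sum(wi * hi for wi, hi in zip(w, h))))] for h in output]

    return project_key


def get_cell(prev_keyboard, prev_state, cell, project_key, loop_function=None):
    inputs = split_keys(prev_keyboard)                 # NB_NOTES slices of [batch, 1]
    outputs, final_state = rnn_decoder(inputs, prev_state, cell, loop_function)
    next_keys = [project_key(out) for out in outputs]  # each [batch, 1]
    return concat_keys(next_keys), final_state         # [batch, NB_NOTES]


NB_NOTES, HIDDEN, BATCH = 4, 3, 2  # illustrative sizes only
prev_keyboard = [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0]]
prev_state = [[0.0] * HIDDEN for _ in range(BATCH)]
cell = ToyRNNCell(input_size=1, hidden_size=HIDDEN)
project_key = make_projection(HIDDEN)

next_keyboard, final_state = get_cell(prev_keyboard, prev_state, cell, project_key)
print(len(next_keyboard), len(next_keyboard[0]))  # 2 4

# Hypothetical loop function: feed back the predicted probability for the previous key.
fed_back, _ = get_cell(prev_keyboard, prev_state, cell, project_key,
                       loop_function=lambda prev, i: project_key(prev))
```

Each key index is one decoder time step. As a result, the recurrent state carries information from earlier key indices to later ones within a single keyboard.

Passing `loop_function`, as in the second call, replaces the input at every step after the first with a value computed from the previous step's output. That is the hook the TODO refers to. Step 0 still reads the real input, so `fed_back[0][0]` equals `next_keyboard[0][0]`.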