model.py 文件源码

python
阅读 40 收藏 0 点赞 0 评论 0

项目:chainer-speech-recognition 作者: musyoku 项目源码 文件源码
def __call__(self, x, split_into_variables=True, discard_context=False):
        """Run the conv -> SRU -> dense stack on a batch of input features.

        Args:
            x: Input batch. Only ``x.shape[0]`` (batch size) and ``x.shape[3]``
                (sequence length) are read here. NOTE(review): presumably a
                4-D (batch, channel, freq, time) array -- confirm with caller.
            split_into_variables: If true, return a tuple of ``seq_length``
                Variables, one per time step, each of shape
                ``(batchsize, -1)``. Otherwise return a single Variable with
                the time axis moved to axis 1 and axis 2 squeezed away.
            discard_context: If true, the context returned by each SRU layer
                is NOT written back to ``self.contexts``, so the next call
                starts from the same contexts as this one.

        Returns:
            A tuple of per-time-step Variables, or a single Variable
            (see ``split_into_variables``).
        """
        batchsize = x.shape[0]
        seq_length = x.shape[3]

        # conv: merge all non-batch, non-time axes into one feature axis so
        # the recurrent layers see (batchsize, features, seq_length).
        # NOTE(review): assumes conv_blocks keeps time as the last axis.
        out_data = self.conv_blocks(x)
        out_data = F.reshape(out_data, (batchsize, -1, seq_length))

        # rnn: stack of SRU layers, each optionally followed by dropout.
        for index, blocks in enumerate(self.rnn_blocks.blocks):
            sru = blocks[0]
            dropout = blocks[1] if len(blocks) == 2 else None
            hidden, _cell, context = sru(out_data, self.contexts[index])
            if not discard_context:
                self.contexts[index] = context
            # Feed this layer's hidden state into the next layer; without this
            # rebinding every SRU would read the conv output and its own result
            # would be discarded.
            out_data = hidden
            if dropout is not None:
                out_data = dropout(out_data)

        # fc
        out_data = self.dense_blocks(out_data)
        assert out_data.shape[2] == seq_length

        # Move the time axis to position 1: (batchsize, seq_length, ...).
        out_data = F.swapaxes(out_data, 1, 2)
        if split_into_variables:
            # Flatten each sample to (seq_length * features) and cut it into
            # seq_length equal pieces -> one Variable per time step.
            # Presumably the layout the CTC loss expects (TODO confirm).
            out_data = F.reshape(out_data, (batchsize, -1))
            return F.split_axis(out_data, seq_length, axis=1)

        # NOTE(review): squeeze(axis=2) requires axis 2 to have size 1 after
        # the swap; confirm the output shape of dense_blocks.
        return F.squeeze(out_data, axis=2)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号