qrnn.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:chainer-qrnn 作者: musyoku 项目源码 文件源码
def __call__(self, X, ht_enc):
        pad = self._kernel_size - 1
        WX = self.W(X)
        if pad > 0:
            WX = WX[..., :-pad]
        Vh = self.V(ht_enc)

        # copy Vh
        # e.g.
        # WX = [[[  0   1   2]
        #        [  3   4   5]
        #        [  6   7   8]
        # Vh = [[11, 12, 13]]
        # 
        # Vh, WX = F.broadcast(F.expand_dims(Vh, axis=2), WX)
        # 
        # WX = [[[  0   1   2]
        #        [  3   4   5]
        #        [  6   7   8]
        # Vh = [[[  11  11  11]
        #        [  12  12  12]
        #        [  13  13  13]
        Vh, WX = functions.broadcast(functions.expand_dims(Vh, axis=2), WX)

        return self.pool(functions.split_axis(WX + Vh, self.num_split, axis=1))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号