qrnn.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:chainer-qrnn 作者: musyoku 项目源码 文件源码
def pool(self, WX, skip_mask=None):
        """Run QRNN f-, fo- or ifo-pooling over gate pre-activations ``WX``.

        The variant is chosen by ``len(self._pooling)``:
          * 1 -> f-pooling,   ``WX == (Z, F)``
          * 2 -> fo-pooling,  ``WX == (Z, F, O)``
          * 3 -> ifo-pooling, ``WX == (Z, F, O, I)``

        Recurrence per time step t (c is ``self.ct``, h is ``self.ht``):
            c_t = f_t * c_{t-1} + i_t * z_t * x_t
            h_t = c_t            (f-pooling)
            h_t = o_t * c_t      (fo / ifo-pooling)
        where i_t is the input gate ``I`` for ifo-pooling and ``1 - f_t``
        otherwise. x_t is the optional skip mask.

        State lives on the instance (``self.ct``, ``self.ht``, ``self.H``).
        If it is already set from an earlier call, pooling continues from it.
        Otherwise the first step starts from c_0 = 0.

        Args:
            WX: tuple of gate pre-activations. Time length is read from
                ``Z.shape[2]`` and steps are sliced with ``[..., t]``, so the
                gates appear to be laid out as (batch, channels, time).
                NOTE(review): confirm the layout.
            skip_mask: optional mask indexed as ``skip_mask[:, t, None]``
                (presumably (batch, time)). Where it is 0, the candidate
                z_t is not written into the cell. The forget gate still
                scales the previous cell at that step.

        Returns:
            ``self.H``: all hidden outputs so far, stacked on axis 2. This is
            ``None`` if no step has ever been computed.
        """
        Z, F, O, I = None, None, None, None
        n_gates = len(self._pooling)

        if n_gates == 1:  # f-pooling
            assert len(WX) == 2
            Z, F = WX
        elif n_gates == 2:  # fo-pooling
            assert len(WX) == 3
            Z, F, O = WX
        elif n_gates == 3:  # ifo-pooling
            assert len(WX) == 4
            Z, F, O, I = WX

        # Unsupported pooling spec leaves Z/F unset.
        assert Z is not None
        assert F is not None

        Z = functions.tanh(Z)
        # self.zoneout is defined elsewhere and is applied to the raw forget
        # gate. Presumably it also squashes it into [0, 1]; TODO confirm.
        F = self.zoneout(F)
        if O is not None:
            O = functions.sigmoid(O)
        if I is not None:
            I = functions.sigmoid(I)

        step_outputs = []
        for t in range(Z.shape[2]):
            zt = Z[..., t]
            ft = F[..., t]
            # Without an explicit input gate the input weight is (1 - f_t).
            it = 1 - ft if I is None else I[..., t]
            # Will be used for seq2seq to skip PAD positions.
            xt = 1 if skip_mask is None else skip_mask[:, t, None]

            if self.ct is None:
                # c_0 = 0, so c_1 = i_1 * z_1 * x_1. Using `it` (not 1 - f_t)
                # keeps ifo-pooling's input gate on the first step too.
                self.ct = it * zt * xt
            else:
                self.ct = ft * self.ct + it * zt * xt
            self.ht = self.ct if O is None else O[..., t] * self.ct
            step_outputs.append(functions.expand_dims(self.ht, 2))

        if step_outputs:
            if self.H is not None:
                step_outputs.insert(0, self.H)
            # Concatenate once instead of once per step, to avoid
            # O(T^2) copying of the growing history.
            if len(step_outputs) == 1:
                self.H = step_outputs[0]
            else:
                self.H = functions.concat(step_outputs, axis=2)

        return self.H
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号