model.py source code

Language: Python

Project: chainer-qrnn    Author: butsugiri
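```python
import chainer.functions as F


def pooling(self, c, xs, train):
    """Implement fo-pooling.

    fo-pooling seemingly works best when compared with ifo- and f-pooling.
    `xs` is a time-major sequence of steps (as produced by F.transpose_sequence)
    whose batch sizes never increase over time; `c` is the initial cell state
    or None. `train` only matters for the disabled zoneout experiment below.
    """
    c_prev = c
    hs = []

    for x in xs:
        batch = x.shape[0]
        # Split the gate pre-activations into candidate (z), forget (f) and output (o).
        w0, w1, w2 = F.split_axis(x, 3, axis=1)
        z = F.tanh(w0)
        f = F.sigmoid(w1)
        o = F.sigmoid(w2)

        c_prev_rest = None
        if c_prev is None:
            # No initial state: treat c_{t-1} as zero.
            c = (1 - f) * z
        else:
            # When sequence lengths differ within the minibatch, fewer sequences
            # are active at this step; set aside the states of those that ended.
            if c_prev.shape[0] > batch:
                c_prev, c_prev_rest = F.split_axis(c_prev, [batch], axis=0)
            # Zoneout experiment (disabled):
            # if train:
            #     zoneout_mask = (0.1 < self.xp.random.rand(*f.shape))
            #     c = f * c_prev + (1 - f) * z * zoneout_mask
            # else:
            #     c = f * c_prev + (1 - f) * z
            c = f * c_prev + (1 - f) * z
        h = o * c
        # Re-attach finished sequences so their last state reaches the returned c.
        if c_prev_rest is not None:
            c = F.concat([c, c_prev_rest], axis=0)
        hs.append(h)
        c_prev = c
    # Final cell state, plus hidden states transposed back to batch-major order.
    return c, F.transpose_sequence(hs)
```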
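The fo-pooling recurrence this method computes is c_t = f_t * c_{t-1} + (1 - f_t) * z_t and h_t = o_t * c_t. As a minimal, Chainer-free sketch of the same logic, the NumPy code here assumes, as `pooling` does, that the time-major steps have non-increasing batch sizes (sequences sorted longest first). `fo_pool`, `sigmoid` and `transpose_sequence_np` are hypothetical helpers written for illustration; `transpose_sequence_np` only imitates what `chainer.functions.transpose_sequence` does with length-sorted sequences and is not the library function.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def transpose_sequence_np(seqs):
    """Hypothetical stand-in for chainer.functions.transpose_sequence.

    `seqs` must be sorted by decreasing length. Step t stacks the t-th row of
    every sequence longer than t, so applying it twice restores the input.
    """
    max_len = len(seqs[0])
    return [np.stack([s[t] for s in seqs if len(s) > t]) for t in range(max_len)]


def fo_pool(xs, c=None):
    """fo-pooling over time-major steps xs[t] of shape (batch_t, 3 * units)."""
    c_prev = c
    hs = []
    for x in xs:
        batch = x.shape[0]
        w0, w1, w2 = np.split(x, 3, axis=1)
        z, f, o = np.tanh(w0), sigmoid(w1), sigmoid(w2)
        rest = None
        if c_prev is None:
            c_new = (1 - f) * z  # no initial state: c_{t-1} treated as zero
        else:
            if c_prev.shape[0] > batch:
                # Sequences that already ended keep their last cell state.
                c_prev, rest = c_prev[:batch], c_prev[batch:]
            c_new = f * c_prev + (1 - f) * z
        hs.append(o * c_new)
        c_prev = c_new if rest is None else np.concatenate([c_new, rest], axis=0)
    # Like the Chainer method, return hidden states in batch-major order.
    return c_prev, transpose_sequence_np(hs)
```

Returning the transposed `hs` mirrors the original method's `F.transpose_sequence(hs)` call: because the time-major steps are themselves sorted by decreasing batch size, transposing them again restores one array per sequence.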