qrnn.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:chainer-qrnn 作者: musyoku 项目源码 文件源码
def __call__(self, X, ht_enc, H_enc, skip_mask=None):
    """Run the attention decoder over X and return every step's output.

    Args:
        X: Input handed to ``self.W``. The result is indexed with three
            axes and its last axis is iterated as the time axis
            (presumably (batch, channels, time) -- TODO confirm with caller).
        ht_enc: Encoder state handed to ``self.V``; the projection is
            expanded on axis 2 and broadcast against ``self.W(X)``.
        H_enc: Encoder hidden states with three axes. The length of axis 2
            must equal the length of axis 1 of ``skip_mask``.
        skip_mask: Optional mask. Positions equal to 0 receive a -1e6 bias
            before the softmax and a zero weight after it (skips PAD).

    Returns:
        ``self.H``: the per-step outputs joined along a new axis 2.

    Side effects:
        Sets ``self.contexts`` (list of per-step contexts), ``self.ht``
        (output of the last step) and ``self.H``.
    """
    pad = self._kernel_size - 1
    WX = self.W(X)
    if pad > 0:
        # Drop the last `pad` positions of the time axis (presumably the
        # surplus outputs of a padded convolution in self.W -- confirm).
        WX = WX[:, :, :-pad]
    Vh = self.V(ht_enc)
    Vh, WX = functions.broadcast(functions.expand_dims(Vh, axis=2), WX)

    # f-pooling: split the summed projection into the three gates
    Z, F, O = functions.split_axis(WX + Vh, 3, axis=1)
    Z = functions.tanh(Z)
    F = self.zoneout(F)
    O = functions.sigmoid(O)
    T = Z.shape[2]

    # compute ungated hidden states
    self.contexts = []
    for t in range(T):
        z = Z[..., t]
        f = F[..., t]
        if self.contexts:
            ct = f * self.contexts[-1] + (1 - f) * z
        else:
            # first step has no previous context to carry over
            ct = (1 - f) * z
        self.contexts.append(ct)

    # The PAD bias and mask do not depend on t, so build them once.
    if skip_mask is None:
        bias = 0
        mask = 1
    else:
        assert skip_mask.shape[1] == H_enc.shape[2]
        bias = ((skip_mask == 0) * -1e6)[..., None]  # to skip PAD
        mask = skip_mask[..., None]                  # to skip PAD

    # compute attention weights (eq.8)
    H_enc = functions.swapaxes(H_enc, 1, 2)
    outputs = []
    for t, ct in enumerate(self.contexts):
        alpha = functions.batch_matmul(H_enc, ct) + bias
        alpha = functions.softmax(alpha) * mask
        alpha = functions.broadcast_to(alpha, H_enc.shape)  # copy
        kt = functions.sum(alpha * H_enc, axis=1)
        ot = O[..., t]
        self.ht = ot * self.o(functions.concat((kt, ct), axis=1))
        outputs.append(functions.expand_dims(self.ht, 2))

    # Join all steps in a single concat instead of growing self.H step by
    # step, which re-copied every earlier step on each iteration.
    if outputs:
        self.H = functions.concat(tuple(outputs), axis=2)
    # With no time steps self.H is left untouched, as in the original loop.
    return self.H
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号