qrnn.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:depccg 作者: masashi-y 项目源码 文件源码
def pre(self, x):
        dims = len(x.shape) - 1

        if self.kernel_size == 1:
            ret = self.W(x)
        elif self.kernel_size == 2:
            if dims == 2:
                xprev = Variable(
                    self.xp.zeros((self.batch_size, 1, self.in_size),
                                  dtype=np.float32), volatile='AUTO')
                xtminus1 = F.concat((xprev, x[:, :-1, :]), axis=1)
            else:
                xtminus1 = self.x
            ret = self.W(x) + self.V(xtminus1)
        else:
            ret = F.swapaxes(self.conv(
                F.swapaxes(x, 1, 2))[:, :, :x.shape[2]], 1, 2)

        if not self.attention:
            return ret

        if dims == 1:
            enc = self.encoding[:, -1, :]
        else:
            enc = self.encoding[:, -1:, :]
        return sum(F.broadcast(self.U(enc), ret))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号