theano_backend.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:keraflow 作者: ipod825 项目源码 文件源码
def pool(self, x, mode, pool_size, strides, padding=(0,0)):
        """Apply 2D or 3D pooling to ``x`` using Theano's ``pool_2d``.

        ``x`` is pooled over the axes after the first two. The indexing
        below (``x.shape[i + 2]`` and the final slices) treats axes 0 and 1
        as non-pooled leading axes. Presumably these are (batch, channel);
        TODO confirm against callers.

        NOTE(review): ``pool.pool_2d`` refers to a module-level ``pool``
        (presumably ``theano.tensor.signal.pool``). That only resolves
        correctly while this function is a class method. A module-level
        ``def pool`` would shadow the module.

        Args:
            x: Input tensor, 4D for 2D pooling or 5D for 3D pooling.
            mode: Pooling mode forwarded to ``pool_2d``. ``'avg'`` is mapped
                to ``'average_exc_pad'``; any other value is passed through.
            pool_size: Pooling window per pooled axis (length 2 or 3).
            strides: Stride per pooled axis; ``None`` means ``pool_size``.
            padding: Per-axis padding. Each entry is either an int
                (symmetric padding) or a ``(before, after)`` tuple/list with
                ``after == before + 1``. An all-zero padding shorter than
                ``pool_size`` (e.g. the default ``(0, 0)`` with 3D pooling)
                is expanded to one zero per pooled axis.

        Returns:
            The pooled tensor, sliced on each pooled axis to at most
            ``ceil(input_len / stride)`` elements.
        """
        if strides is None:
            strides = pool_size
        assert len(strides) == len(pool_size), \
            'strides and pool_size must have the same length'
        n_axes = len(pool_size)
        do2D = n_axes == 2

        if mode == 'avg':
            mode = 'average_exc_pad'

        # The default padding=(0,0) only has two entries. Expand an all-zero
        # padding to one entry per pooled axis; otherwise 3D pooling fails
        # with an IndexError on max_padding[2] below. ``any`` is only
        # evaluated on a length mismatch.
        if len(padding) != n_axes and not any(padding):
            padding = (0,) * n_axes
        assert len(padding) == n_axes, \
            'padding must have one entry per pooled axis'

        # Theano requires symmetric padding. When the two sides of an axis
        # differ (by exactly one), pad both sides with the larger one; the
        # surplus output this produces is truncated at the end.
        max_padding = []
        for p in padding:
            if isinstance(p, (tuple, list)):
                assert p[1] == p[0] + 1, \
                    'asymmetric padding must have the form (n, n + 1)'
                max_padding.append(p[1])
            else:
                max_padding.append(p)

        if do2D:
            pool_out = pool.pool_2d(x, ds=pool_size, st=strides,
                                    ignore_border=True,
                                    padding=max_padding,
                                    mode=mode)
        else:
            # pool_2d pools the last two axes. Move axis 4 to position 2 so
            # the first pass pools original axes 2 and 3 ("HW").
            pool_out = pool.pool_2d(x.dimshuffle(0,1,4,2,3),
                                    ds=pool_size[:2],
                                    st=strides[:2],
                                    ignore_border=True,
                                    padding=max_padding[:2],
                                    mode=mode)

            # Restore the axis order, then pool the remaining axis ("Z").
            # A window of 1 on the second-to-last axis leaves it untouched.
            pool_out = pool.pool_2d(pool_out.dimshuffle(0,1,3,4,2),
                                    ds=(1,pool_size[2]),
                                    st=(1,strides[2]),
                                    ignore_border=True,
                                    padding=(0, max_padding[2]),
                                    mode=mode)

        # Theano may output more elements than expected because of the
        # max padding above; truncate each pooled axis to
        # ceil(input_len / stride).
        exp_l = []
        for i, stride in enumerate(strides):
            l = T.ceil(self.cast(x.shape[i + 2], _FLOATX) / stride)
            exp_l.append(self.cast(l, 'int32'))

        if do2D:
            return pool_out[:, :, :exp_l[0], :exp_l[1]]
        else:
            return pool_out[:, :, :exp_l[0], :exp_l[1], :exp_l[2]]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号