paper_plots.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:gconv_experiments 作者: tscohen 项目源码 文件源码
def testplot_p4m(im=None, m=0, r=0):
    """Apply a P4M group convolution to a (transformed) test image on the GPU.

    The input image is first transformed by the D4 element ``[m, r]`` (via
    groupy's ``D4Array`` acting on a ``Z2FuncArray``), then convolved with a
    single fixed 3x3 filter by a ``P4MConvZ2`` layer. Comparing the outputs
    for different ``(m, r)`` is a way to inspect the layer's equivariance.

    Args:
        im: 2D image to convolve. Defaults to a 5x5 float32 asymmetric
            'F'-shaped glyph (a full vertical stroke in column 1 plus a top
            bar and a middle bar). Images of any 2D size are accepted.
        m: first entry of the D4 element ``D4Array([m, r])`` -- presumably
            the mirror/flip index (0 or 1); confirm against groupy.
        r: second entry of the D4 element ``D4Array([m, r])`` -- presumably
            the number of 90-degree rotations; confirm against groupy.

    Returns:
        Tuple ``(im, out)``: ``im`` is the transformed input image and
        ``out`` is the convolution output copied back to the CPU with
        ``cuda.to_cpu``.

    Note:
        Requires a CUDA device: both the input variable and the layer are
        moved to the GPU (``cuda.to_gpu`` / ``conv.to_gpu()``).
    """
    if im is None:
        im = np.zeros((5, 5), dtype='float32')
        im[0:5, 1] = 1.  # vertical stroke
        im[0, 1:4] = 1.  # top bar
        im[2, 1:3] = 1.  # middle bar

    from groupy.gfunc.z2func_array import Z2FuncArray
    from groupy.garray.D4_array import D4Array

    def rotate_flip_z2_func(im, flip, theta_index):
        """Return the values of ``im`` transformed by D4 element [flip, theta_index]."""
        imf = Z2FuncArray(im)
        rot = D4Array([flip, theta_index], 'int')
        rot_imf = rot * imf
        return rot_imf.v

    im = rotate_flip_z2_func(im, m, r)

    # Looks like a Sobel x-derivative kernel with the top-middle entry set to
    # -4 instead of 0, presumably so the filter has no symmetry and all of
    # its transformed copies differ -- NOTE(review): confirm intent.
    filter_e = np.array([[-1., -4., 1.],
                         [-2., 0., 2.],
                         [-1., 0., 1.]])

    from groupy.gconv.chainer_gconv.p4m_conv import P4MConvZ2
    from chainer import Variable
    from chainer import cuda

    # Single-argument prints of pre-formatted text produce the same output
    # on Python 2 and Python 3 (the original used Python 2 print statements).
    print('%s' % (im.shape,))

    # Add batch and channel axes using the image's own spatial shape instead
    # of a hard-coded (1, 1, 5, 5), so non-5x5 images no longer fail here.
    height, width = im.shape[-2:]
    batch = im.astype('float32').reshape(1, 1, height, width)
    imv = Variable(cuda.to_gpu(batch))
    conv = P4MConvZ2(in_channels=1, out_channels=1, ksize=3, pad=2,
                     flat_channels=True,
                     initialW=filter_e.reshape(1, 1, 1, 3, 3))
    conv.to_gpu()
    conv_imv = conv(imv)
    print('%s %s' % (im.shape, conv_imv.data.shape))
    return im, cuda.to_cpu(conv_imv.data)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号