fxnn.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:fxnn 作者: khaotik 项目源码 文件源码
def lyr_sconv_gen(
    name_, s_x_,
    idim_, odim_,
    **kwargs_):
    '''
    Quick & dirty separable-kernel convolution layer for fxnn.

    Each input channel is treated as its own single-channel image.

    1. Convolve each channel with the global ``g_sconv_ker`` laid out as
       two 1x3 filters.
    2. Convolve the result with the same kernels transposed to 3x1.
       This gives 4 feature maps per input channel.
    3. Concatenate these maps with the original input along axis 1.
    4. Mix the resulting ``idim_*5`` channels with a 1x1 convolution
       built by ``g_mdl.lyr_conv``.

    Args:
        name_: layer name, forwarded to ``g_mdl.lyr_conv``.
        s_x_: symbolic 4-D input; it is reshaped as
            (batch, idim_, height, width) -- presumably NCHW, TODO confirm.
        idim_: number of input channels of ``s_x_``.
        odim_: number of output channels of the final 1x1 convolution.
        **kwargs_: extra layer options.
            ``dilation_`` (missing or None -> 1) sets the dilation of both
            separable passes.
            ``init_scale_``, when given and not None, overrides the default
            init scale ``0.5/sqrt(idim_*5 + odim_)``.
            Every option except ``init_scale_`` (including ``dilation_``)
            is forwarded to ``g_mdl.lyr_conv``.

    Returns:
        whatever ``g_mdl.lyr_conv`` returns for the 1x1 mixing layer.
    '''
    dilation = kwargs_.get('dilation_')
    if dilation is None:
        dilation = 1
    # ``init_scale_`` must be removed before forwarding the options;
    # otherwise it collides with the explicit ``init_scale_=`` keyword in
    # the lyr_conv call below and raises TypeError.
    lyr_kwargs = dict(kwargs_)
    init_scale = lyr_kwargs.pop('init_scale_', None)
    if init_scale is None:
        # Default scale from the 1x1 layer's input (idim_*5) and output
        # (odim_) channel counts.
        init_scale = 0.5/sqrt(idim_*5 + odim_)
    # Both separable passes share padding and dilation. The caller's
    # ``*_``-suffixed layer options are not conv2d arguments, so neither
    # pass receives them (they previously reached only the first pass).
    op_conv = partial(
        T.nnet.conv2d,
        border_mode='half',
        filter_dilation=(dilation, dilation))
    s_dims = T.shape(s_x_)
    batch, height, width = s_dims[0], s_dims[2], s_dims[3]
    # Fold channels into the batch axis: one single-channel image per channel.
    s_x = T.reshape(s_x_, (batch*idim_, 1, height, width))
    # First pass: 2 kernels of shape 1x3 -> 2 maps per input channel.
    s_x1 = T.reshape(
        op_conv(s_x, g_sconv_ker, filter_shape=(2, 1, 1, 3)),
        (batch*idim_*2, 1, height, width))
    # Second pass: kernels with the last two axes swapped (3x1)
    # -> 2*2 = 4 maps per input channel, unfolded back to the channel axis.
    s_x2 = T.reshape(
        op_conv(s_x1, g_sconv_ker.transpose(0, 1, 3, 2),
                filter_shape=(2, 1, 3, 1)),
        (batch, idim_*4, height, width))
    # Keep the raw input next to the filtered maps: idim_*5 channels total.
    s_y = T.join(1, s_x2, s_x_)
    return g_mdl.lyr_conv(
        name_, s_y, idim_*5, odim_,
        fsize_=1, init_scale_=init_scale, **lyr_kwargs)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号