cnn_layer.py source code

python

Project: textGAN_public  Author: dreasysnail
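# Legacy Theano modules matching the calls below. `_p` is the repo's
# parameter-name helper; importing it from utils is an assumption.
from theano import tensor
from theano.tensor.nnet import conv
from theano.tensor.signal import downsample
from utils import _p


def encoder(tparams, layer0_input, filter_shape, pool_size, options, prefix='cnn_d'):
    """Convolve the input with one filter bank, add the bias, apply the
    activation named in options['cnn_activation'], then max-pool.

    filter_shape: (number of filters, num input feature maps, filter height,
                   filter width)
    layer0_input: 4-D tensor laid out as (batch_size, num input feature maps,
                  image height, image width)
    Returns a 2-D tensor with the batch axis first and all pooled features
    flattened into the second axis.
    """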

    # Convolve the input with the filter bank stored in tparams
    conv_out = conv.conv2d(input=layer0_input,
                           filters=tparams[_p(prefix, 'W')],
                           filter_shape=filter_shape)

    # Broadcast the per-filter bias over the batch and both spatial axes
    bias = tparams[_p(prefix, 'b')].dimshuffle('x', 0, 'x', 'x')

    if options['cnn_activation'] == 'tanh':
        activated = tensor.tanh(conv_out + bias)
    elif options['cnn_activation'] == 'linear':
        activated = conv_out + bias
    else:
        # Without a valid activation there is nothing to pool or return
        raise ValueError("Unknown options['cnn_activation']: %r "
                         "(expected 'tanh' or 'linear')" % options['cnn_activation'])

    # ignore_border=False keeps the partial pooling window at the edge instead
    # of dropping it, so trailing positions still contribute to the output
    output = downsample.max_pool_2d(input=activated, ds=pool_size, ignore_border=False)

    # Keep the batch axis and flatten the remaining axes into one
    return output.flatten(2)
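
The shapes that flow through `encoder` can be checked without Theano. The sketch below is a hypothetical helper, not part of textGAN_public. It assumes the default 'valid' border mode of `conv.conv2d` and non-overlapping pooling windows, and the example sizes (40 rows, 300 columns, 100 filters) are made up for illustration. It shows how `ignore_border` changes the pooled height and what width `flatten(2)` would produce.

```python
import math


def conv_pool_output_shape(image_shape, filter_shape, pool_size, ignore_border=False):
    """Shape of the pooled feature maps for a 'valid' 2-D convolution
    followed by non-overlapping max pooling (hypothetical helper)."""
    batch_size, in_maps, img_h, img_w = image_shape
    n_filters, f_in_maps, f_h, f_w = filter_shape
    if in_maps != f_in_maps:
        raise ValueError('input feature maps do not match the filter shape')
    # 'valid' convolution: the filter must fit entirely inside the image
    conv_h, conv_w = img_h - f_h + 1, img_w - f_w + 1
    # ignore_border=False keeps the partial window at the edge (ceil);
    # ignore_border=True drops it (floor)
    rnd = math.floor if ignore_border else math.ceil
    pool_h = int(rnd(conv_h / float(pool_size[0])))
    pool_w = int(rnd(conv_w / float(pool_size[1])))
    return (batch_size, n_filters, pool_h, pool_w)


def flattened_width(shape):
    """Second dimension after flatten(2): every axis but the batch is merged."""
    return shape[1] * shape[2] * shape[3]


image_shape = (64, 1, 40, 300)    # made-up batch of 64 single-channel inputs
filter_shape = (100, 1, 5, 300)   # 100 filters spanning 5 rows and all columns

# The convolution output is 36 x 1; a pool height of 10 leaves a partial window
kept = conv_pool_output_shape(image_shape, filter_shape, (10, 1))
dropped = conv_pool_output_shape(image_shape, filter_shape, (10, 1), ignore_border=True)
print(kept, flattened_width(kept))        # (64, 100, 4, 1) 400
print(dropped, flattened_width(dropped))  # (64, 100, 3, 1) 300
```

With a pool height equal to the full convolution height (36 here), both settings give one value per filter, so the flattened width equals the number of filters.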