train_dcganae.py 文件源码

python
阅读 15 收藏 0 点赞 0 评论 0

项目:experiments 作者: tencia 项目源码 文件源码
def build_nets(input_var, channels=1, do_batchnorm=True, z_dim=100):
    """Build a convolutional autoencoder and a discriminator on 28x28 images.

    Both sub-networks get their own ``InputLayer`` of shape
    ``(None, channels, 28, 28)`` bound to the same ``input_var``.

    Autoencoder (keys prefixed ``ae_``): two conv + 2x2 max-pool stages
    (64 filters of 5x5, then 128 filters of 3x3), a tanh dense code layer of
    ``z_dim`` units (``'ae_enc'``), then a mirrored decoder
    (dense -> reshape -> 2x upscale -> 'full' conv, twice) ending in a
    sigmoid conv layer ``'ae_out'`` with ``channels`` feature maps.

    Discriminator (keys prefixed ``disc_``): the same conv/pool stack,
    a 100-unit dense hidden layer, and a single sigmoid output unit
    (``'disc_out'``).

    Args:
        input_var: symbolic input variable shared by both InputLayers.
        channels: number of image channels in the input; the autoencoder
            reconstruction has the same number of channels.
        do_batchnorm: if True, conv and hidden dense layers are wrapped in
            ``batch_norm``; otherwise they are used unwrapped.
        z_dim: number of units in the code layer ``'ae_enc'``.

    Returns:
        dict mapping layer names ('ae_in', 'ae_conv1', ..., 'ae_out',
        'disc_in', ..., 'disc_out') to the corresponding layer objects.
    """
    bn = batch_norm if do_batchnorm else (lambda x: x)
    ret = {}

    # ---- Autoencoder: encoder ---------------------------------------------
    # Spatial sizes below assume Conv2DLayer defaults of stride 1 and no
    # padding: 28 -conv5-> 24 -pool-> 12 -conv3-> 10 -pool-> 5.
    ret['ae_in'] = layer = InputLayer(shape=(None, channels, 28, 28),
                                      input_var=input_var)
    ret['ae_conv1'] = layer = bn(Conv2DLayer(layer, num_filters=64, filter_size=5))
    ret['ae_pool1'] = layer = MaxPool2DLayer(layer, pool_size=2)
    ret['ae_conv2'] = layer = bn(Conv2DLayer(layer, num_filters=128, filter_size=3))
    ret['ae_pool2'] = layer = MaxPool2DLayer(layer, pool_size=2)

    # The decoder must reproduce exactly this feature-map shape, so look it
    # up once and reuse it for both the dense size and the reshape target.
    pool2_shape = nn.layers.get_output_shape(ret['ae_pool2'])

    ret['ae_enc'] = layer = DenseLayer(layer, num_units=z_dim,
                                       nonlinearity=nn.nonlinearities.tanh)

    # ---- Autoencoder: decoder ---------------------------------------------
    # np.prod replaces np.product, which was removed in NumPy 2.0; int()
    # hands the layer a plain Python int for its unit count.
    ret['ae_unenc'] = layer = bn(nn.layers.DenseLayer(
        layer, num_units=int(np.prod(pool2_shape[1:]))))
    # In a Lasagne ReshapeLayer shape, the list [0] means "reuse the size of
    # input dimension 0", i.e. keep whatever batch size comes in.
    ret['ae_resh'] = layer = ReshapeLayer(
        layer, shape=([0],) + tuple(pool2_shape[1:]))
    # 'full' padding grows the map by filter_size - 1, mirroring the encoder:
    # 5 -up-> 10 -conv3 full-> 12 -up-> 24 -conv5 full-> 28.
    ret['ae_depool2'] = layer = Upscale2DLayer(layer, scale_factor=2)
    ret['ae_deconv2'] = layer = bn(Conv2DLayer(layer, num_filters=64, filter_size=3,
                                               pad='full'))
    ret['ae_depool1'] = layer = Upscale2DLayer(layer, scale_factor=2)
    # Output one map per input channel so the reconstruction matches the
    # input shape (previously hard-coded to 1, wrong when channels > 1).
    ret['ae_out'] = Conv2DLayer(layer, num_filters=channels, filter_size=5,
                                pad='full',
                                nonlinearity=nn.nonlinearities.sigmoid)

    # ---- Discriminator ----------------------------------------------------
    ret['disc_in'] = layer = InputLayer(shape=(None, channels, 28, 28),
                                        input_var=input_var)
    ret['disc_conv1'] = layer = bn(Conv2DLayer(layer, num_filters=64, filter_size=5))
    ret['disc_pool1'] = layer = MaxPool2DLayer(layer, pool_size=2)
    ret['disc_conv2'] = layer = bn(Conv2DLayer(layer, num_filters=128, filter_size=3))
    ret['disc_pool2'] = layer = MaxPool2DLayer(layer, pool_size=2)
    ret['disc_hid'] = layer = bn(DenseLayer(layer, num_units=100))
    ret['disc_out'] = DenseLayer(layer, num_units=1,
                                 nonlinearity=nn.nonlinearities.sigmoid)

    return ret
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号