x2y_yz2x_xy2p_ssl_cifar10.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:triple-gan 作者: zhenxuan00 项目源码 文件源码
def build_network():
    """Assemble the classifier layer stack and return its final layer.

    The stack is, in order: an input layer of shape ``(None, 3, 32, 32)``,
    Gaussian noise (``sigma=.15``), three blocks of convolutional layers and
    a 10-unit softmax dense layer.  The first two blocks each hold three
    3x3 'same'-padded convolutions (128 and 256 filters respectively)
    followed by 2x2 max-pooling and dropout with ``p=.5``.  The third block
    is an unpadded 3x3 convolution with 512 filters, two ``NINLayer``s
    (256 and 128 units) and a global pooling layer.

    Every convolution, NIN and dense layer is wrapped in ``WN`` with
    ``momentum=.999``.
    NOTE(review): ``WN`` is defined outside this block -- presumably a
    weight-normalisation wrapper; confirm against its definition.

    Returns:
        The output layer (the ``WN``-wrapped dense layer named 'dense').
    """
    # Keyword sets shared by all layers of one family.  Each family gets its
    # own initialiser / nonlinearity instances, and those instances are
    # reused by every layer within that family.
    conv_kwargs = dict(
        W=lasagne.init.HeNormal('relu'),
        b=lasagne.init.Constant(0.0),
        filter_size=(3, 3),
        stride=(1, 1),
        nonlinearity=lasagne.nonlinearities.LeakyRectify(0.1),
    )
    nin_kwargs = dict(
        W=lasagne.init.HeNormal('relu'),
        b=lasagne.init.Constant(0.0),
        nonlinearity=lasagne.nonlinearities.LeakyRectify(0.1),
    )
    dense_kwargs = dict(
        W=lasagne.init.HeNormal(1.0),
        b=lasagne.init.Constant(0.0),
        nonlinearity=lasagne.nonlinearities.softmax,
    )
    wn_kwargs = dict(momentum=.999)

    def wrapped(layer):
        # Apply the WN wrapper with the shared momentum setting.
        return WN(layer, **wn_kwargs)

    def conv(incoming, name, num_filters, pad):
        # 3x3 / stride-1 leaky-ReLU convolution, WN-wrapped.
        return wrapped(Conv2DLayer(incoming, name=name,
                                   num_filters=num_filters, pad=pad,
                                   **conv_kwargs))

    def nin(incoming, name, num_units):
        # Network-in-network layer with leaky-ReLU, WN-wrapped.
        return wrapped(NINLayer(incoming, name=name, num_units=num_units,
                                **nin_kwargs))

    # Input and input-level noise.
    layer = InputLayer(name='input', shape=(None, 3, 32, 32))
    layer = GaussianNoiseLayer(layer, name='noise', sigma=.15)

    # Block 1: three 128-filter convolutions, then pool + dropout.
    layer = conv(layer, 'conv1a', 128, 'same')
    layer = conv(layer, 'conv1b', 128, 'same')
    layer = conv(layer, 'conv1c', 128, 'same')
    layer = MaxPool2DLayer(layer, name='pool1', pool_size=(2, 2))
    layer = DropoutLayer(layer, name='drop1', p=.5)

    # Block 2: three 256-filter convolutions, then pool + dropout.
    layer = conv(layer, 'conv2a', 256, 'same')
    layer = conv(layer, 'conv2b', 256, 'same')
    layer = conv(layer, 'conv2c', 256, 'same')
    layer = MaxPool2DLayer(layer, name='pool2', pool_size=(2, 2))
    layer = DropoutLayer(layer, name='drop2', p=.5)

    # Block 3: one unpadded 512-filter convolution, two NIN layers,
    # then global pooling.
    layer = conv(layer, 'conv3a', 512, 0)
    layer = nin(layer, 'conv3b', 256)
    layer = nin(layer, 'conv3c', 128)
    layer = GlobalPoolLayer(layer, name='pool3')

    # 10-way softmax output.
    return wrapped(DenseLayer(layer, name='dense', num_units=10,
                              **dense_kwargs))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号