mnist_cond_lsgan.py source code

python

Project: WGAN_mnist  Author: rajeswar18
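import lasagne.layers as ll
from lasagne.layers import batch_norm

# conv_layer, CondConvConcatLayer and lrelu are not shown in this excerpt;
# they are assumed to be defined elsewhere in mnist_cond_lsgan.py.
# Judging by the calls, conv_layer appears to take
# (incoming, filter_size, num_filters, stride, pad).


def discriminator(input_var, Y):
    """Build the label-conditioned discriminator for 28x28 MNIST images.

    Returns the output layer and the two input layers (image, label).
    """
    # Image input of shape (batch, 1, 28, 28) and label input of shape (batch, 10).
    D_1 = ll.InputLayer(shape=(None, 1, 28, 28), input_var=input_var)
    D_2 = ll.InputLayer(shape=(None, 10), input_var=Y)
    network = D_1
    network_yb = D_2

    # Condition on the label before the first convolution.
    network = CondConvConcatLayer([network, network_yb])
    network = ll.DropoutLayer(network, p=0.4)

    network = conv_layer(network, 3, 32, 1, 'same', nonlinearity=lrelu)
    network = CondConvConcatLayer([network, network_yb])

    # Three stride-2 convolutions downsample the feature maps.
    network = conv_layer(network, 3, 64, 2, 'same', nonlinearity=lrelu)
    network = CondConvConcatLayer([network, network_yb])

    network = conv_layer(network, 3, 64, 2, 'same', nonlinearity=lrelu)
    # network = batch_norm(conv_layer(network, 3, 128, 1, 'same', nonlinearity=lrelu))
    # network = ll.DropoutLayer(network, p=0.2)

    network = conv_layer(network, 3, 128, 2, 'same', nonlinearity=lrelu)
    network = CondConvConcatLayer([network, network_yb])

    # 4x4 'valid' convolution with batch normalization.
    network = batch_norm(conv_layer(network, 4, 128, 1, 'valid', nonlinearity=lrelu))
    network = CondConvConcatLayer([network, network_yb])

    # network = ll.DropoutLayer(network, p=0.5)
    # 1x1 convolution to a single output channel, with no nonlinearity.
    network = conv_layer(network, 1, 1, 1, 'valid', nonlinearity=None)

    return network, D_1, D_2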
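
The excerpt does not show CondConvConcatLayer, so the NumPy sketch below only illustrates what such a layer commonly does in conditional GANs: it tiles the label vector over the spatial grid and appends it as extra channels. That behaviour is an assumption, not something confirmed by the source. The helper names cond_concat and conv_out_size are hypothetical. The second helper traces the spatial size through the six convolutions above, assuming conv_layer takes (filter_size, num_filters, stride, pad) and that 'same' pads by half the filter size, as Lasagne's convolution layers document.

```python
import numpy as np


def cond_concat(x, y):
    """Append labels y (batch, n_classes) to feature maps x (batch, C, H, W)
    as n_classes spatially constant channels (assumed behaviour)."""
    batch, _, height, width = x.shape
    # (batch, n_classes) -> (batch, n_classes, 1, 1), then repeat over H x W.
    y_maps = np.broadcast_to(y[:, :, None, None],
                             (batch, y.shape[1], height, width))
    return np.concatenate([x, y_maps], axis=1)


def conv_out_size(size, filter_size, stride, pad):
    """Spatial output size of a convolution with 'same' or 'valid' padding."""
    padding = filter_size // 2 if pad == 'same' else 0
    return (size + 2 * padding - filter_size) // stride + 1


# Two blank 28x28 images with one-hot labels for the digits 3 and 7.
images = np.zeros((2, 1, 28, 28), dtype=np.float32)
labels = np.eye(10, dtype=np.float32)[[3, 7]]
merged = cond_concat(images, labels)

# (filter_size, stride, pad) of the six conv_layer calls, in order.
conv_specs = [(3, 1, 'same'), (3, 2, 'same'), (3, 2, 'same'),
              (3, 2, 'same'), (4, 1, 'valid'), (1, 1, 'valid')]
size = 28
sizes = []
for filter_size, stride, pad in conv_specs:
    size = conv_out_size(size, filter_size, stride, pad)
    sizes.append(size)

print(merged.shape)  # (2, 11, 28, 28)
print(sizes)         # [28, 14, 7, 4, 1, 1]
```

Under these assumptions the 4x4 'valid' convolution reduces the 4x4 feature maps to 1x1, so the network would emit one unbounded score per image.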