generate.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:opt-mmd 作者: dougalsutherland 项目源码 文件源码
def _sample_trained_minibatch_gan(params_file, n, batch_size, rs):
    """Draw ``n`` samples from a trained minibatch-discrimination GAN generator.

    Rebuilds the generator network, loads its trainable parameters from
    ``params_file``, then runs it in minibatches until ``n`` samples have been
    produced.

    The network maps 100-d uniform noise through a dense layer to
    4x4x512, then through three 5x5 deconvolutions (8x8, 16x16, 32x32).
    The last deconvolution is configured with a tanh nonlinearity.

    Parameters
    ----------
    params_file : str or file-like
        ``.npz`` archive readable by ``np.load``. Arrays ``arr_0``,
        ``arr_1``, ... are assigned, in order, to the network's trainable
        parameters. This presumably matches a file written with
        ``np.savez(f, *params)`` — verify against the training script.
    n : int
        Number of samples to return.
    batch_size : int
        Number of samples produced per call of the compiled generator.
    rs : numpy.random.RandomState (presumably)
        Only ``rs.randint`` is used. It seeds the Theano noise stream and
        also replaces lasagne's *global* RNG via ``lasagne.random.set_rng``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n, 3, 32, 32)``. When ``n <= 0`` an empty array
        with shape ``(0, 3, 32, 32)`` is returned.

    Raises
    ------
    KeyError
        If ``params_file`` holds fewer ``arr_<i>`` entries than the network
        has trainable parameters.
    """
    import lasagne
    from lasagne.init import Normal
    import lasagne.layers as ll
    import theano as th
    from theano.sandbox.rng_mrg import MRG_RandomStreams
    import theano.tensor as T

    import nn

    # Seed both RNGs from the caller's RandomState so sampling is
    # reproducible. Note that set_rng changes lasagne's global RNG.
    theano_rng = MRG_RandomStreams(rs.randint(2 ** 15))
    lasagne.random.set_rng(np.random.RandomState(rs.randint(2 ** 15)))

    noise_dim = (batch_size, 100)
    noise = theano_rng.uniform(size=noise_dim)
    ls = [ll.InputLayer(shape=noise_dim, input_var=noise)]
    ls.append(nn.batch_norm(
        ll.DenseLayer(ls[-1], num_units=4*4*512, W=Normal(0.05),
                      nonlinearity=nn.relu),
        g=None))
    ls.append(ll.ReshapeLayer(ls[-1], (batch_size,512,4,4)))
    ls.append(nn.batch_norm(
        nn.Deconv2DLayer(ls[-1], (batch_size,256,8,8), (5,5), W=Normal(0.05),
                         nonlinearity=nn.relu),
        g=None)) # 4 -> 8
    ls.append(nn.batch_norm(
        nn.Deconv2DLayer(ls[-1], (batch_size,128,16,16), (5,5), W=Normal(0.05),
                         nonlinearity=nn.relu),
        g=None)) # 8 -> 16
    ls.append(nn.weight_norm(
        nn.Deconv2DLayer(ls[-1], (batch_size,3,32,32), (5,5), W=Normal(0.05),
                         nonlinearity=T.tanh),
        train_g=True, init_stdv=0.1)) # 16 -> 32
    # get_output is called without deterministic=True. Whether the
    # batch-norm layers then use per-batch statistics depends on
    # nn.batch_norm (not visible here) — NOTE(review): confirm intended.
    gen_dat = ll.get_output(ls[-1])

    # Load exactly as many arrays as the network has trainable parameters.
    # set_all_param_values(..., trainable=True) consumes the same set, so
    # there is no hard-coded count to go stale if the architecture changes.
    n_params = len(ll.get_all_params(ls[-1], trainable=True))
    with np.load(params_file) as d:
        params = [d['arr_{}'.format(i)] for i in range(n_params)]
    ll.set_all_param_values(ls[-1], params, trainable=True)

    sample_batch = th.function(inputs=[], outputs=gen_dat)

    # Collect whole batches and concatenate once. This avoids splitting
    # every batch into per-sample arrays and re-stacking them.
    batches = []
    n_drawn = 0
    while n_drawn < n:
        batch = sample_batch()
        batches.append(batch)
        n_drawn += len(batch)
    if not batches:
        # n <= 0: keep the per-sample trailing shape so callers can still
        # index the image axes of the (empty) result.
        return np.empty((0,) + tuple(ls[-1].output_shape[1:]),
                        dtype=th.config.floatX)
    return np.concatenate(batches)[:n]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号