steingan_lsun.py 文件源码

python
阅读 38 收藏 0 点赞 0 评论 0

项目:SteinGAN 作者: DartML 项目源码 文件源码
def discrim(X):
    current_input = dropout(X, 0.3)
    ### encoder ###
    cv1 = relu(dnn_conv(current_input, aew1, subsample=(1,1), border_mode=(1,1)))
    cv2 = relu(batchnorm(dnn_conv(cv1, aew2, subsample=(4,4), border_mode=(2,2)), g=aeg2, b=aeb2))
    cv3 = relu(batchnorm(dnn_conv(cv2, aew3, subsample=(1,1), border_mode=(1,1)), g=aeg3, b=aeb3))
    cv4 = relu(batchnorm(dnn_conv(cv3, aew4, subsample=(4,4), border_mode=(2,2)), g=aeg4, b=aeb4))
    cv5 = relu(batchnorm(dnn_conv(cv4, aew5, subsample=(1,1), border_mode=(1,1)), g=aeg5, b=aeb5))
    cv6 = relu(batchnorm(dnn_conv(cv5, aew6, subsample=(4,4), border_mode=(0,0)), g=aeg6, b=aeb6))

    ### decoder ###
    dv6 = relu(batchnorm(deconv(cv6, aew6, subsample=(4,4), border_mode=(0,0)), g=aeg6t, b=aeb6t)) 
    dv5 = relu(batchnorm(deconv(dv6, aew5, subsample=(1,1), border_mode=(1,1)), g=aeg5t, b=aeb5t))
    dv4 = relu(batchnorm(deconv(dv5, aew4, subsample=(4,4), border_mode=(2,2)), g=aeg4t, b=aeb4t)) 
    dv3 = relu(batchnorm(deconv(dv4, aew3, subsample=(1,1), border_mode=(1,1)), g=aeg3t, b=aeb3t))
    dv2 = relu(batchnorm(deconv(dv3, aew2, subsample=(4,4), border_mode=(2,2)), g=aeg2t, b=aeb2t))
    dv1 = tanh(deconv(dv2, aew1, subsample=(1,1), border_mode=(1,1)))

    rX = dv1

    mse = T.sqrt(T.sum(T.abs_(T.flatten(X-rX, 2)),axis=1)) + T.sqrt(T.sum(T.flatten((X-rX)**2, 2), axis=1))
    return T.flatten(cv6, 2), rX, mse
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号