gated_pixelcnn.py source code

python

Project: eva  Author: israelg99
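# Keras 1.x imports used by this snippet (the call signatures below are Keras 1.x style).
from keras.layers import Input, Convolution2D
from keras.layers.advanced_activations import PReLU
from keras.models import Model
from keras.optimizers import Nadam

# GatedCNN, GatedCNNs and OutChannels are custom layers from the eva project;
# import them from wherever that project defines them.


def GatedPixelCNN(input_shape, filters, depth, latent=None, build=True):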
    height, width, channels = input_shape
    palette = 256  # TODO: support arbitrary palette sizes.

    input_img = Input(shape=input_shape, name=str(channels)+'_channels_'+str(palette)+'_palette')

    latent_vector = None
    if latent is not None:
        latent_vector = Input(shape=(latent,), name='latent_vector')

    model = GatedCNNs(filters, depth, latent_vector)(*GatedCNN(filters, latent_vector)(input_img))

    for _ in range(2):
        model = Convolution2D(filters, 1, 1, border_mode='valid')(model)
        model = PReLU()(model)

    outs = OutChannels(*input_shape, masked=False, palette=palette)(model)

    if build:
        model = Model(input=[input_img, latent_vector] if latent is not None else input_img, output=outs)
        model.compile(optimizer=Nadam(), loss='binary_crossentropy' if channels == 1 else 'sparse_categorical_crossentropy')

    # Note: when build is False, this returns the tensor before OutChannels; `outs` is not returned.
    return model
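
The compile step in GatedPixelCNN selects 'sparse_categorical_crossentropy' when the image has more than one channel, with palette = 256. As a rough illustration of what that loss measures for a single sub-pixel, here is a standalone sketch in plain Python. It is not code from the eva project: the helper names softmax and sparse_categorical_crossentropy, the target intensity 200, and the logit values are made up for illustration, and it assumes the output layer produces a 256-way softmax per sub-pixel, which the snippet itself does not show.

```python
import math

PALETTE = 256  # same palette size as in GatedPixelCNN (assumed 256-way output)


def softmax(logits):
    """Turn raw scores into a probability distribution (numerically stable)."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_categorical_crossentropy(probs, target):
    """Negative log-probability assigned to the integer target class."""
    return -math.log(probs[target])


# Hypothetical sub-pixel: a model that is uniform over all 256 intensities.
uniform_probs = softmax([0.0] * PALETTE)
uniform_loss = sparse_categorical_crossentropy(uniform_probs, target=200)

# Hypothetical sub-pixel: a model that strongly favours the true intensity 200.
logits = [0.0] * PALETTE
logits[200] = 10.0
confident_loss = sparse_categorical_crossentropy(softmax(logits), target=200)

print(uniform_loss)                    # equals log(256) for a uniform prediction
print(confident_loss < uniform_loss)   # a confident, correct prediction scores lower
```

The "sparse" variant takes the target as an integer intensity (0 to 255) rather than a one-hot vector, which is why the training targets for such a model can stay as plain integer pixel values.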