pixelcnn.py source code

python

Project: eva  Author: israelg99
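```python
# Keras 1.x API (Model(input=..., output=...), border_mode=...).
from keras.layers import Input
from keras.layers.advanced_activations import PReLU
from keras.models import Model
from keras.optimizers import Nadam
# MaskedConvolution2D, ResidualBlockList and OutChannels are custom layers
# from the eva project and must be imported from it.


def PixelCNN(input_shape, filters, depth, build=True):
    height, width, channels = input_shape
    palette = 256  # TODO: Make it scalable to any amount of palette.

    input_img = Input(shape=input_shape, name=str(channels)+'_channels_'+str(palette)+'_palette')

    # 7x7 convolution with mask 'A': each output position cannot see its own input pixel.
    model = MaskedConvolution2D(filters, 7, 7, mask='A', border_mode='same', name='masked2d_A')(input_img)

    # `depth` residual blocks, then a PReLU activation.
    model = ResidualBlockList(filters, depth)(model)
    model = PReLU()(model)

    # Two 1x1 masked convolutions (the layer's default mask), each followed by PReLU.
    for _ in range(2):
        model = MaskedConvolution2D(filters, 1, 1, border_mode='valid')(model)
        model = PReLU()(model)

    # eva's OutChannels layer produces the per-channel outputs.
    outs = OutChannels(*input_shape, masked=True, palette=palette)(model)

    if build:
        model = Model(input=input_img, output=outs)
        # Binary cross-entropy for one channel, sparse categorical cross-entropy otherwise.
        model.compile(optimizer=Nadam(), loss='binary_crossentropy' if channels == 1 else 'sparse_categorical_crossentropy')

    # With build=False this returns the feature tensor before OutChannels,
    # not `outs`.
    return model
```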
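The first layer passes `mask='A'`. In PixelCNN, a type-A mask hides the centre pixel of the kernel (the pixel being predicted) along with every "later" pixel in raster order, while a type-B mask keeps the centre. As a rough illustration, here is a minimal NumPy sketch of such a spatial mask, assuming a kernel layout of `(kernel_h, kernel_w, in_channels, out_channels)`. The function name `causal_mask` is hypothetical, and the sketch ignores the R/G/B channel grouping that the PixelCNN paper applies to colour images, so it is not eva's `MaskedConvolution2D` implementation.

```python
import numpy as np


def causal_mask(kernel_h, kernel_w, in_channels, out_channels, mask_type):
    """Hypothetical helper: PixelCNN-style spatial mask, shape (kh, kw, in, out).

    Ones mark kernel taps the convolution may use; zeros hide "future" pixels.
    Type 'A' also hides the centre tap; type 'B' keeps it.
    Simplification: no per-colour (R/G/B) channel grouping.
    """
    if mask_type not in ("A", "B"):
        raise ValueError("mask_type must be 'A' or 'B'")
    mask = np.ones((kernel_h, kernel_w, in_channels, out_channels), dtype=np.float32)
    cy, cx = kernel_h // 2, kernel_w // 2
    # Centre row: hide taps right of the centre (and the centre itself for 'A').
    start = cx if mask_type == "A" else cx + 1
    mask[cy, start:, :, :] = 0.0
    # Every row below the centre lies in the "future" and is hidden.
    mask[cy + 1:, :, :, :] = 0.0
    return mask


if __name__ == "__main__":
    # 3x3 type-A mask for a single input/output channel.
    print(causal_mask(3, 3, 1, 1, "A")[:, :, 0, 0].astype(int))
    # [[1 1 1]
    #  [1 0 0]
    #  [0 0 0]]
    # 3x3 type-B mask: same, but the centre tap is kept.
    print(causal_mask(3, 3, 1, 1, "B")[:, :, 0, 0].astype(int))
    # [[1 1 1]
    #  [1 1 0]
    #  [0 0 0]]
```

A masked convolution multiplies its kernel element-wise by a mask like this (`kernel * mask`) before convolving, so the hidden taps never contribute. With this spatial-only mask, a 1x1 type-B mask is all ones, so masking a 1x1 layer only has an effect once colour-channel grouping is added.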