model.py source code

python

Project: deepascii  Author: awentzonline
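import numpy as np
from keras import backend as K
from keras.applications import vgg16
from keras.layers.convolutional import AveragePooling2D, Convolution2D, MaxPooling2D
from keras.layers.normalization import BatchNormalization
from keras.models import Model


def make_model(
        img_shape, charset_features, layer_name='block2_conv1',
        output_pool=2, pool_type='max'):
    """Build a VGG16-based model that scores each character template at every position.

    Written against the Keras 1.x API (image_dim_ordering, mode=2, border_mode).
    """
    # image_dim_ordering() returns the string 'tf' or 'th', so compare explicitly;
    # a bare truthiness test would always take the first branch.
    if K.image_dim_ordering() == 'tf':
        num_chars, char_h, char_w, char_channels = charset_features.shape
        axis = -1
    else:
        num_chars, char_channels, char_h, char_w = charset_features.shape
        axis = 1
    # Pretrained VGG16 without the classifier head, tapped at `layer_name`.
    vgg = vgg16.VGG16(input_shape=img_shape, include_top=False)
    layer = vgg.get_layer(layer_name)
    x = layer.output
    # TODO: theano dim order (the transpose below assumes channels-last features)
    # Kernels shaped (char_h, char_w, channels, num_chars), flipped spatially and
    # L2-normalised per (channel, character) over the two spatial axes.
    features_W = charset_features.transpose((1, 2, 3, 0)).astype(np.float32)
    features_W = features_W[::-1, ::-1, :, :] / np.sqrt(
        np.sum(np.square(features_W), axis=(0, 1), keepdims=True))
    x = BatchNormalization(mode=2)(x)
    # One output channel per character, initialised from the character features; zero bias.
    x = Convolution2D(
        num_chars, char_h, char_w, border_mode='valid',
        weights=[features_W, np.zeros(num_chars, dtype=np.float32)])(x)
    if output_pool > 1:
        pool_class = dict(max=MaxPooling2D, avg=AveragePooling2D)[pool_type]
        x = pool_class((output_pool, output_pool))(x)
    # x = Argmax(axis)(x)  # `Argmax` is a custom layer not defined in this snippet
    model = Model([vgg.input], [x])
    return model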
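The kernel preparation in `make_model` (transpose, spatial flip, per-slice L2 normalisation) can be checked in isolation with plain NumPy. This is a minimal sketch, assuming channels-last character features of shape (num_chars, char_h, char_w, channels). The name `prepare_char_kernels` and the `eps` guard are hypothetical additions, not part of the original project; the original divides without an epsilon, which produces NaN for any all-zero (channel, character) slice.

```python
import numpy as np


def prepare_char_kernels(charset_features, eps=1e-8):
    """Turn per-character feature maps into convolution kernels.

    charset_features: (num_chars, char_h, char_w, channels), channels-last.
    Returns kernels shaped (char_h, char_w, channels, num_chars).
    """
    # Move the character axis last: (h, w, c, num_chars).
    w = charset_features.transpose((1, 2, 3, 0)).astype(np.float32)
    # L2 norm of each (channel, character) slice over the spatial axes.
    norms = np.sqrt(np.sum(np.square(w), axis=(0, 1), keepdims=True))
    # Flip spatially as the original does, then normalise.
    # eps (an assumption, not in the original) keeps all-zero slices finite.
    return w[::-1, ::-1, :, :] / (norms + eps)
```

Each (channel, character) slice of the result has unit spatial L2 norm, so the later convolution responds to the pattern of a character's features rather than to their overall magnitude.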
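What `BatchNormalization(mode=2)` followed by the 'valid' `Convolution2D` computes for a single image can be sketched as a sliding dot product between the feature map and every character kernel. Like TensorFlow's `conv2d`, this is cross-correlation (no kernel flip). The sketch assumes channels-last arrays and approximates the Keras 1 batch normalisation (one image, default gamma=1, beta=0, epsilon=1e-3) as a per-channel standardisation over spatial positions; the zero bias is omitted. `score_chars` is a hypothetical name, not part of the project.

```python
import numpy as np


def score_chars(feature_map, kernels, eps=1e-3):
    """Valid cross-correlation of a feature map with character kernels.

    feature_map: (H, W, C); kernels: (kh, kw, C, num_chars).
    Returns scores shaped (H - kh + 1, W - kw + 1, num_chars).
    """
    # Per-channel standardisation over spatial positions: a rough stand-in
    # for BatchNormalization(mode=2) applied to one image.
    mean = feature_map.mean(axis=(0, 1))
    var = feature_map.var(axis=(0, 1))
    x = (feature_map - mean) / np.sqrt(var + eps)

    kh, kw, _, n = kernels.shape
    out_h, out_w = x.shape[0] - kh + 1, x.shape[1] - kw + 1
    scores = np.empty((out_h, out_w, n), dtype=np.float64)
    for i in range(out_h):
        for j in range(out_w):
            patch = x[i:i + kh, j:j + kw, :]  # (kh, kw, C)
            # Contract the patch against every character kernel at once -> (n,).
            scores[i, j] = np.tensordot(patch, kernels, axes=([0, 1, 2], [0, 1, 2]))
    return scores
```

The explicit loops trade speed for readability; the Keras layer performs the same contraction in one batched convolution.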
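The optional pooling step and the commented-out `Argmax` can be sketched in NumPy as well: non-overlapping `output_pool` x `output_pool` windows (leftover rows and columns dropped, as with Keras 'valid' pooling whose strides default to the pool size), then the index of the best-scoring character in each cell. `pool_and_pick` is a hypothetical helper; in the original the argmax is disabled, so the model outputs the pooled scores.

```python
import numpy as np


def pool_and_pick(scores, output_pool=2, pool_type='max'):
    """Non-overlapping pooling followed by an argmax over characters.

    scores: (H, W, num_chars). Rows/columns that do not fill a full window
    are dropped. Returns an int array of shape (H // p, W // p) when p > 1.
    """
    p = output_pool
    if p > 1:
        h, w, n = scores.shape
        h, w = (h // p) * p, (w // p) * p
        # View the trimmed scores as (H/p, p, W/p, p, n) blocks and reduce each block.
        blocks = scores[:h, :w].reshape(h // p, p, w // p, p, n)
        reduce = {'max': np.max, 'avg': np.mean}[pool_type]
        scores = reduce(blocks, axis=(1, 3))
    # Best-matching character index per output cell.
    return np.argmax(scores, axis=-1)
```

As in `make_model`, an unknown `pool_type` raises a `KeyError` from the dictionary lookup.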