trainer.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:deepanalytics_compe26_benchmark 作者: takagiwa-ss 项目源码 文件源码
def resnet(repetition=2, k=1):
    '''Build a Wide Residual Network (with a slight modification).

    depth == repetition*6 + 2

    Args:
        repetition: forwarded as ``repetition`` to each of the three
            ``residual_block`` stages.
        k: widening factor; the three residual stages are created with
            ``k*16``, ``k*32`` and ``k*64`` filters respectively.

    Returns:
        A Keras ``Model`` whose input has shape ``(1, _img_len, _img_len)``
        and whose output is a sigmoid ``Dense`` layer with
        ``len(_columns)`` units.
    '''
    from keras.models import Model
    from keras.layers import Input, Dense, Flatten, AveragePooling2D
    from keras.regularizers import l2

    input_shape = (1, _img_len, _img_len)
    output_dim = len(_columns)

    # Floor division keeps the pooling window an int on Python 3 as well
    # (plain `/` would produce a float there); on Python 2 ints the result
    # is unchanged.
    pool_len = _img_len // 16

    x = Input(shape=input_shape)

    # NOTE: the out_shape annotations below are the original author's; they
    # rely on conv2d / residual_block (defined elsewhere) halving the
    # spatial size at each stage.
    z = conv2d(nb_filter=8, k_size=5, downsample=True)(x)        # out_shape ==    8, _img_len/ 2, _img_len/ 2
    z = bn_lrelu(0.01)(z)
    z = residual_block(nb_filter=k*16, repetition=repetition)(z) # out_shape == k*16, _img_len/ 4, _img_len/ 4
    z = residual_block(nb_filter=k*32, repetition=repetition)(z) # out_shape == k*32, _img_len/ 8, _img_len/ 8
    z = residual_block(nb_filter=k*64, repetition=repetition)(z) # out_shape == k*64, _img_len/16, _img_len/16
    # Pool over a (pool_len, pool_len) window, i.e. _img_len // 16 per side.
    z = AveragePooling2D((pool_len, pool_len))(z)
    z = Flatten()(z)
    z = Dense(output_dim=output_dim, activation='sigmoid', W_regularizer=l2(_Wreg_l2), init='zero')(z)

    return Model(input=x, output=z)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号