def resnet(repetition=2, k=1):
    '''Build a Wide Residual Network (with a slight modification).

    depth == repetition*6 + 2

    The network is a stem convolution followed by three residual stages
    with k*16, k*32 and k*64 filters, an average pooling layer, and a
    sigmoid-activated dense layer with one unit per entry of ``_columns``.

    Args:
        repetition: Forwarded unchanged to every ``residual_block`` call.
        k: Widening factor; multiplies the filter count of each residual
            stage.

    Returns:
        An uncompiled ``keras.models.Model`` mapping an input of shape
        ``(1, _img_len, _img_len)`` to ``len(_columns)`` sigmoid outputs.
    '''
    from keras.models import Model
    from keras.layers import Input, Dense, Flatten, AveragePooling2D
    from keras.regularizers import l2

    input_shape = (1, _img_len, _img_len)
    output_dim = len(_columns)

    x = Input(shape=input_shape)

    # Shape notes below are the original author's; they rely on conv2d and
    # residual_block halving the spatial size, which is defined elsewhere.
    z = conv2d(nb_filter=8, k_size=5, downsample=True)(x)  # out_shape == 8, _img_len/ 2, _img_len/ 2
    z = bn_lrelu(0.01)(z)
    z = residual_block(nb_filter=k*16, repetition=repetition)(z)  # out_shape == k*16, _img_len/ 4, _img_len/ 4
    z = residual_block(nb_filter=k*32, repetition=repetition)(z)  # out_shape == k*32, _img_len/ 8, _img_len/ 8
    z = residual_block(nb_filter=k*64, repetition=repetition)(z)  # out_shape == k*64, _img_len/16, _img_len/16

    # Pool over the whole remaining feature map. Floor division keeps the
    # pool size an integer even where `/` is true division (Python 3 or
    # `from __future__ import division`); a float pool size is not valid.
    pool_len = _img_len // 16
    z = AveragePooling2D((pool_len, pool_len))(z)
    z = Flatten()(z)

    # Zero-initialised weights make every sigmoid output start at 0.5
    # (assuming the layer's default bias initialisation is zero as well).
    z = Dense(output_dim=output_dim, activation='sigmoid', W_regularizer=l2(_Wreg_l2), init='zero')(z)

    return Model(input=x, output=z)
trainer.py 文件源码
python
阅读 28
收藏 0
点赞 0
评论 0
评论列表
文章目录