resnet.py 文件源码

python
阅读 41 收藏 0 点赞 0 评论 0

项目: loss-correction | 作者: giorgiop | 项目源码 | 文件源码
def cifar10_resnet(depth, cifar10model, decay, loss):
    """Build a three-stage residual-network classifier for CIFAR-10.

    Architecture, in order:
      * stem: Conv2D (32 filters, ``num_conv x num_conv`` kernel, "same"
        padding) + BatchNormalization + ReLU;
      * three stages of ``depth`` residual blocks each.
        - The first block of stage 1 is built with ``first=True``.
        - The first block of stages 2 and 3 is built with ``more_filters=True``.
        - The remaining ``depth - 1`` blocks of every stage use default flags.
        - The first block of each stage is always created, even if ``depth < 1``;
      * BatchNormalization + ReLU, 8x8 average pooling ("valid", stride 1),
        Flatten;
      * a Dense head with ``cifar10model.classes`` units (see ``loss``).

    Each ``residual`` block is noted to contain 2 layers. That gives the
    author's count of ``2 + 6 * depth`` layers in total.

    NOTE(review): the 8x8 "valid" pooling presumably expects an 8x8 feature
    map there, i.e. 32x32 inputs downsampled twice by the
    ``more_filters=True`` blocks. TODO confirm against ``residual``.

    Args:
        depth: number of residual blocks per stage.
        cifar10model: configuration object. Reads ``img_rows``, ``img_cols``,
            ``img_channels``, ``num_conv`` and ``classes``, and is passed
            through to ``residual``.
        decay: L2 factor for the kernel regularizers. Bias regularizers are
            ``l2(0)``.
        loss: loss identifier that selects the head.
            - In ``yes_softmax``: Dense with softmax activation. This takes
              precedence over ``yes_bound``.
            - Else in ``yes_bound``: linear Dense followed by
              BatchNormalization.
            - Otherwise: linear Dense (raw logits).

    Returns:
        A Keras ``Model`` mapping the input tensor to the head output.
    """
    model = cifar10model
    # Named ``inputs`` to avoid shadowing the ``input`` builtin.
    inputs = Input(shape=(model.img_rows, model.img_cols, model.img_channels))

    # Stem: 1 conv + BN + ReLU.
    b = Conv2D(filters=32, kernel_size=(model.num_conv, model.num_conv),
               kernel_initializer="he_normal", padding="same",
               kernel_regularizer=l2(decay), bias_regularizer=l2(0))(inputs)
    b = BatchNormalization(axis=BN_AXIS)(b)
    b = Activation("relu")(b)

    # Three residual stages. Only each stage's opening block gets special
    # flags: stage 1 opens without striding (first=True); stages 2 and 3
    # open with striding / more filters (more_filters=True).
    stage_openers = (
        {"first": True},
        {"more_filters": True},
        {"more_filters": True},
    )
    for opener_kwargs in stage_openers:
        b = residual(model, decay, **opener_kwargs)(b)  # 2 layers inside
        for _ in range(1, depth):  # opener + (depth - 1) => depth blocks
            b = residual(model, decay)(b)

    b = BatchNormalization(axis=BN_AXIS)(b)
    b = Activation("relu")(b)

    b = AveragePooling2D(pool_size=(8, 8), strides=(1, 1),
                         padding="valid")(b)

    out = Flatten()(b)

    # Build the Dense head once; only the activation and an optional BN
    # differ between losses. Softmax wins over the bounded head, matching
    # the original if/elif precedence.
    use_softmax = loss in yes_softmax
    dense_kwargs = dict(units=model.classes, kernel_initializer="he_normal",
                        kernel_regularizer=l2(decay), bias_regularizer=l2(0))
    if use_softmax:
        dense_kwargs["activation"] = "softmax"
    dense = Dense(**dense_kwargs)(out)
    if not use_softmax and loss in yes_bound:
        dense = BatchNormalization(axis=BN_AXIS)(dense)

    return Model(inputs=inputs, outputs=dense)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号