wide_resnet.py source code

python

Project: keras-contrib  Author: farizrahman4u
from keras import backend as K
from keras.layers import Activation, BatchNormalization, Conv2D, Dropout, add


def __conv2_block(input, k=1, dropout=0.0):
    """Wide residual block with 16 * k filters (the conv2 group of a Wide ResNet)."""
    init = input

    # Channel axis: 1 for channels_first ('th' ordering), -1 for channels_last ('tf').
    channel_axis = 1 if K.image_data_format() == 'channels_first' else -1

    # If the input does not already have 16 * k channels, project the shortcut
    # with a 1x1 convolution so it can be added to the residual branch.
    if K.int_shape(init)[channel_axis] != 16 * k:
        init = Conv2D(16 * k, (1, 1), activation='linear', padding='same')(init)

    # Residual branch: 3x3 conv -> BN -> ReLU -> (optional dropout) -> 3x3 conv -> BN -> ReLU
    x = Conv2D(16 * k, (3, 3), padding='same')(input)
    x = BatchNormalization(axis=channel_axis)(x)
    x = Activation('relu')(x)

    if dropout > 0.0:
        x = Dropout(dropout)(x)

    x = Conv2D(16 * k, (3, 3), padding='same')(x)
    x = BatchNormalization(axis=channel_axis)(x)
    x = Activation('relu')(x)

    # Element-wise sum of the (possibly projected) shortcut and the residual branch
    m = add([init, x])
    return m
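To make the shortcut logic concrete, here is a minimal pure-Python sketch that traces what __conv2_block does to a tensor shape and counts its weights, without needing Keras installed. The helper name conv2_block_summary is hypothetical (it is not part of keras-contrib). It assumes the layer settings used in the block: stride 1, 'same' padding, biases enabled, BatchNormalization with both scale and center, and a shape tuple that includes the batch dimension (as Keras' _keras_shape does). The data_format argument follows Keras' 'channels_first' / 'channels_last' naming, which corresponds to the 'th' / 'tf' orderings.

```python
def conv2_block_summary(input_shape, k=1, data_format="channels_last"):
    """Hypothetical helper: trace output shape and parameter count of __conv2_block.

    Assumes input_shape includes the batch dimension, stride 1, 'same' padding,
    use_bias=True on every Conv2D, and default BatchNormalization (scale + center).
    """
    channel_axis = 1 if data_format == "channels_first" else -1
    filters = 16 * k
    in_channels = input_shape[channel_axis]

    params = 0
    # Shortcut: a 1x1 projection is added only when channel counts differ.
    needs_projection = in_channels != filters
    if needs_projection:
        params += 1 * 1 * in_channels * filters + filters
    # First 3x3 conv reads the raw input (in_channels -> filters).
    params += 3 * 3 * in_channels * filters + filters
    # Second 3x3 conv maps filters -> filters.
    params += 3 * 3 * filters * filters + filters
    # Two BatchNormalization layers: gamma, beta, moving mean, moving variance each.
    params += 2 * 4 * filters
    # Activation, Dropout and add contribute no weights.

    # 'same' padding with stride 1 keeps spatial dims; only the channel count changes.
    output_shape = list(input_shape)
    output_shape[channel_axis] = filters
    return tuple(output_shape), needs_projection, params


if __name__ == "__main__":
    print(conv2_block_summary((None, 32, 32, 16), k=1))
    print(conv2_block_summary((None, 32, 32, 16), k=2))
    print(conv2_block_summary((None, 16, 32, 32), k=2, data_format="channels_first"))
```

For example, a 16-channel input with k=1 needs no projection and the block holds 4768 weights, while k=2 triggers the 1x1 projection and gives 14688. Because the second 3x3 convolution's weights grow with (16k)^2, widening the block by k inflates its parameter count roughly quadratically.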