model.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:DeepClassificationBot 作者: AntreasAntoniou 项目源码 文件源码
def get_deep_anime_model(n_outputs=1000, input_size=128):
    """Build the deep convolutional classifier used by the deep anime bot.

    The network is four convolutional stages (64, 128, 256 and 512 filters),
    each ending in a pooling layer followed by batch normalization, and then
    a fully connected head with two 2048-unit layers and a softmax output.

    The model is returned uncompiled; its summary is printed as a side
    effect of building it.

    Args:
        n_outputs: Number of units in the final softmax layer.
        input_size: Height and width of the square input. The input shape is
            declared as ``(3, input_size, input_size)``.

    Returns:
        The assembled ``Sequential`` model.
    """
    # NOTE(review): the positional ``Convolution2D(filters, rows, cols)``
    # calls look like the Keras 1.x signature - confirm before upgrading.
    conv = Sequential()

    # Stage 1 (64 filters). The very first convolution is not zero-padded,
    # so it trims one pixel from each border of the input.
    conv.add(Convolution2D(64, 3, 3, activation='relu',
                           input_shape=(3, input_size, input_size)))
    conv.add(ZeroPadding2D((1, 1)))
    conv.add(Convolution2D(64, 3, 3, activation='relu'))
    conv.add(MaxPooling2D((2, 2), strides=(2, 2)))
    conv.add(BatchNormalization())

    # Stage 2 (128 filters): one padded 3x3 convolution, then a 1x1.
    conv.add(ZeroPadding2D((1, 1)))
    conv.add(Convolution2D(128, 3, 3, activation='relu'))
    conv.add(Convolution2D(128, 1, 1, activation='relu'))
    conv.add(MaxPooling2D((2, 2), strides=(2, 2)))
    conv.add(BatchNormalization())

    # Stage 3 (256 filters): two padded 3x3 convolutions, then a 1x1.
    conv.add(ZeroPadding2D((1, 1)))
    conv.add(Convolution2D(256, 3, 3, activation='relu'))
    conv.add(ZeroPadding2D((1, 1)))
    conv.add(Convolution2D(256, 3, 3, activation='relu'))
    conv.add(Convolution2D(256, 1, 1, activation='relu'))
    conv.add(MaxPooling2D((2, 2), strides=(2, 2)))
    conv.add(BatchNormalization())

    # Stage 4 (512 filters): same layout as stage 3, but it ends in an 8x8
    # average pool (stride 2) instead of a 2x2 max pool.
    conv.add(ZeroPadding2D((1, 1)))
    conv.add(Convolution2D(512, 3, 3, activation='relu'))
    conv.add(ZeroPadding2D((1, 1)))
    conv.add(Convolution2D(512, 3, 3, activation='relu'))
    conv.add(Convolution2D(512, 1, 1, activation='relu'))
    conv.add(AveragePooling2D((8, 8), strides=(2, 2)))
    conv.add(BatchNormalization())

    # Classifier head: two identical Dense -> BatchNorm -> Dropout groups.
    conv.add(Flatten())
    conv.add(Dropout(0.5))
    for _ in range(2):
        conv.add(Dense(2048))
        conv.add(BatchNormalization())
        conv.add(Dropout(0.7))
    conv.add(Dense(n_outputs))
    conv.add(Activation('softmax'))

    # summary() prints the table itself and returns None; wrapping it in
    # print() would emit a stray "None" line after the table.
    conv.summary()
    return conv
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号