model.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:what-celebrity 作者: dansbecker 项目源码 文件源码
def model_from_thumbnails(train_x, train_y, val_x, val_y, epochs=100, patience=15):
    """Build, compile and train a small CNN classifier on thumbnail images.

    Written against the Keras 1.x API (``Convolution2D(nb_filter, nb_row,
    nb_col)``, ``border_mode``, ``nb_epoch``).

    Args:
        train_x: 4-D training image array unpacked as
            (n_samples, n_channels, n_rows, n_cols), i.e. channels-first.
            NOTE(review): this relies on Keras being configured for
            channels-first ('th') image ordering -- confirm.
        train_y: 2-D training target array of shape (n_samples, n_classes).
            It must be one-hot encoded, as ``categorical_crossentropy``
            requires.
        val_x: validation images, laid out like ``train_x``.
        val_y: validation targets, laid out like ``train_y``.
        epochs: maximum number of training epochs (default 100).
        patience: number of epochs without ``val_loss`` improvement after
            which training stops early (default 15).

    Returns:
        The trained ``Sequential`` model.
    """
    _, n_channels, n_rows, n_cols = train_x.shape
    # The width of the target matrix gives the number of output units.
    # (Previously this read an undefined name ``y``.)
    n_classes = train_y.shape[1]

    model = Sequential()
    # Convolutional feature extractor: 2x2 kernels with 'valid' padding,
    # interleaved with 2x2 max-pooling.
    model.add(Convolution2D(32, 2, 2, border_mode='valid',
                            activation='relu',
                            input_shape=(n_channels, n_rows, n_cols)))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Convolution2D(64, 2, 2, border_mode='valid',
                            activation='relu'))
    model.add(Convolution2D(64, 2, 2, border_mode='valid',
                            activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Convolution2D(64, 2, 2, border_mode='valid',
                            activation='relu'))

    # Fully connected classifier head, with heavy dropout to limit
    # overfitting on a small thumbnail dataset.
    model.add(Flatten())
    model.add(Dropout(0.5))
    model.add(Dense(100, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(100, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(n_classes, activation='softmax'))

    optimizer = Adam()
    model.compile(loss='categorical_crossentropy', optimizer=optimizer,
                  metrics=['accuracy'])

    # Stop training once validation loss has not improved for `patience`
    # epochs.
    stopper = EarlyStopping(monitor='val_loss', patience=patience,
                            verbose=0, mode='auto')

    model.fit(train_x, train_y, shuffle=True,
              nb_epoch=epochs, validation_data=(val_x, val_y),
              callbacks=[stopper])
    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号