model.py source code

Language: Python

Project: textfool · Author: bogdan-kulynych
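```python
# Imports assume standalone Keras 2.x (where `maxnorm` is an alias of MaxNorm).
from keras import backend as K
from keras.constraints import maxnorm
from keras.layers import (Conv1D, Dense, Dropout, Embedding,
                          GlobalAveragePooling1D, MaxPooling1D)
from keras.models import Sequential
from keras.regularizers import l2

# NOTE: `embeddings` is not defined in this excerpt. Judging by its use below, it
# is a pretrained matrix of shape (vocab_size, embedding_dim) defined elsewhere
# in model.py.


def build_model(max_length=1000,
                nb_filters=64,
                kernel_size=3,
                pool_size=2,
                regularization=0.01,
                weight_constraint=2.,
                dropout_prob=0.4,
                clear_session=True):
    """Build and compile a 1D-CNN binary classifier over frozen word embeddings."""
    if clear_session:
        # Reset Keras's global state so models from earlier calls are discarded.
        K.clear_session()

    model = Sequential()
    # Frozen embedding layer initialized from the pretrained matrix.
    model.add(Embedding(
        embeddings.shape[0],
        embeddings.shape[1],
        input_length=max_length,
        trainable=False,
        weights=[embeddings]))

    # Block 1: two stacked convolutions, then downsampling by pool_size.
    model.add(Conv1D(nb_filters, kernel_size, activation='relu'))
    model.add(Conv1D(nb_filters, kernel_size, activation='relu'))
    model.add(MaxPooling1D(pool_size))

    model.add(Dropout(dropout_prob))

    # Block 2: same pattern with twice as many filters.
    model.add(Conv1D(nb_filters * 2, kernel_size, activation='relu'))
    model.add(Conv1D(nb_filters * 2, kernel_size, activation='relu'))
    model.add(MaxPooling1D(pool_size))

    model.add(Dropout(dropout_prob))

    # Average over the time axis, then a single sigmoid unit for the class score.
    model.add(GlobalAveragePooling1D())
    model.add(Dense(1,
                    kernel_regularizer=l2(regularization),
                    kernel_constraint=maxnorm(weight_constraint),
                    activation='sigmoid'))

    model.compile(
        loss='binary_crossentropy',
        optimizer='rmsprop',
        metrics=['accuracy'])

    return model
```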
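To make the layer stack easier to follow, here is a small pure-Python sketch that traces the sequence length through `build_model` with its default arguments. It assumes Keras's defaults as we understand them: `Conv1D` uses `padding='valid'` with stride 1, and `MaxPooling1D` uses `padding='valid'` with strides equal to `pool_size`. The helper names (`conv1d_valid_length`, `maxpool1d_valid_length`, `trace_shapes`) are hypothetical and not part of textfool; `Dropout` layers are skipped because they do not change shapes, and the embedding width is left symbolic as `"D"` because it depends on the `embeddings` matrix.

```python
# Hypothetical helpers that trace output shapes through build_model's layers.
# Assumes Conv1D padding='valid', stride 1; MaxPooling1D padding='valid', strides=pool_size.

def conv1d_valid_length(length, kernel_size, stride=1):
    """Output length of a 'valid' 1D convolution."""
    return (length - kernel_size) // stride + 1


def maxpool1d_valid_length(length, pool_size):
    """Output length of 'valid' max pooling whose stride equals pool_size."""
    return (length - pool_size) // pool_size + 1


def trace_shapes(max_length=1000, nb_filters=64, kernel_size=3, pool_size=2):
    """Return (layer_name, output_shape) pairs, excluding the batch dimension."""
    shapes = [("Embedding", (max_length, "D"))]  # "D" = embedding width (unknown here)
    length = max_length
    for block, filters in enumerate((nb_filters, nb_filters * 2), start=1):
        for _ in range(2):  # two stacked convolutions per block
            length = conv1d_valid_length(length, kernel_size)
            shapes.append((f"Conv1D (block {block})", (length, filters)))
        length = maxpool1d_valid_length(length, pool_size)
        shapes.append((f"MaxPooling1D (block {block})", (length, filters)))
    shapes.append(("GlobalAveragePooling1D", (nb_filters * 2,)))  # time axis averaged away
    shapes.append(("Dense", (1,)))  # single sigmoid output
    return shapes


for name, shape in trace_shapes():
    print(f"{name:26s} {shape}")
```

With the defaults, each 'valid' convolution trims `kernel_size - 1 = 2` steps and each pooling halves the length, so 1000 steps shrink to 247 before global average pooling collapses the time axis into a 128-dimensional vector.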
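The output layer combines two controls on its weights: `l2(regularization)` and `maxnorm(weight_constraint)`. The NumPy sketch below re-implements the math we understand Keras 2 to apply (it does not call Keras). The L2 regularizer adds `l2 * sum(W**2)` to the training loss, and the max-norm constraint rescales each weight vector along `axis` (default 0) after every update so its L2 norm does not exceed `max_value`. The function names `l2_penalty` and `max_norm` are hypothetical, and `eps=1e-7` is assumed to mirror Keras's default `K.epsilon()`.

```python
import numpy as np


def l2_penalty(kernel, l2=0.01):
    """Loss term added by an L2 kernel regularizer: l2 * sum(W**2)."""
    return l2 * np.sum(np.square(kernel))


def max_norm(kernel, max_value=2.0, axis=0, eps=1e-7):
    """Rescale weight vectors along `axis` so no L2 norm exceeds max_value."""
    norms = np.sqrt(np.sum(np.square(kernel), axis=axis, keepdims=True))
    desired = np.clip(norms, 0, max_value)  # vectors already within the limit keep their norm
    return kernel * (desired / (eps + norms))


# The Dense(1) kernel in build_model has shape (nb_filters * 2, 1) = (128, 1).
rng = np.random.default_rng(0)
kernel = rng.normal(scale=0.5, size=(128, 1))
print("L2 penalty: ", l2_penalty(kernel, l2=0.01))
print("norm before:", np.linalg.norm(kernel[:, 0]))
print("norm after: ", np.linalg.norm(max_norm(kernel, max_value=2.0)[:, 0]))
```

Because the Dense layer has a single output unit, `axis=0` yields one norm for the whole kernel column, so the constraint effectively caps the length of the 128-dimensional weight vector at 2.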