text_cnn.py source code

python

Project: snowman  Author: keeganhines
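from keras.layers import (Activation, Conv1D, Dense, Dropout, Embedding,
                          GlobalMaxPooling1D, Input, concatenate)
from keras.models import Model


def text_cnn(max_seq_index, max_seq_length):
    # Integer token IDs in [0, max_seq_index), padded/truncated to max_seq_length
    text_input = Input(shape=(max_seq_length,), name='text_input')
    # Map each token ID to a learned 15-dimensional vector
    x = Embedding(output_dim=15,
                  input_dim=max_seq_index,
                  input_length=max_seq_length)(text_input)

    # Three parallel convolutions, 15 filters each, over windows of 2, 4 and 6 tokens
    conv_a = Conv1D(15, 2, activation='relu')(x)
    conv_b = Conv1D(15, 4, activation='relu')(x)
    conv_c = Conv1D(15, 6, activation='relu')(x)

    # Keep the strongest response of each filter anywhere in the sequence
    pool_a = GlobalMaxPooling1D()(conv_a)
    pool_b = GlobalMaxPooling1D()(conv_b)
    pool_c = GlobalMaxPooling1D()(conv_c)

    # 3 branches x 15 filters = 45 features
    flattened = concatenate([pool_a, pool_b, pool_c])

    drop = Dropout(.2)(flattened)

    # Single sigmoid unit: probability of the positive class
    dense = Dense(1)(drop)
    out = Activation("sigmoid")(dense)

    model = Model(inputs=text_input, outputs=out)

    model.compile(loss='binary_crossentropy',
                  optimizer='rmsprop',
                  metrics=['accuracy'])

    return model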
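To make the data flow concrete, here is a minimal NumPy sketch of the forward pass this model performs at inference time. It is not the author's code and uses random placeholder weights instead of trained ones; the helper names `conv1d_valid`, `text_cnn_forward` and `random_params` are hypothetical. It assumes the Keras `Conv1D` defaults the original relies on (stride 1, `'valid'` padding, kernel shaped `(kernel_size, in_channels, filters)`), so a branch with kernel size k turns a length-L sequence into L - k + 1 steps before global max pooling reduces it to 15 values. This is also why `max_seq_length` must be at least 6.

```python
import numpy as np


def conv1d_valid(x, w, b):
    """Cross-correlation with 'valid' padding and stride 1 (Keras Conv1D defaults).

    x: (steps, in_ch), w: (kernel_size, in_ch, filters), b: (filters,)
    """
    k = w.shape[0]
    steps = x.shape[0] - k + 1
    out = np.empty((steps, w.shape[2]))
    for t in range(steps):
        # Contract the (k, in_ch) window against the (k, in_ch, filters) kernel
        out[t] = np.tensordot(x[t:t + k], w, axes=([0, 1], [0, 1])) + b
    return out


def text_cnn_forward(tokens, params):
    """Hypothetical inference-time forward pass mirroring text_cnn()."""
    x = params["embedding"][tokens]                  # (seq_len, 15)
    pooled = []
    for w, b in params["convs"]:
        h = np.maximum(conv1d_valid(x, w, b), 0.0)   # ReLU activation
        pooled.append(h.max(axis=0))                 # GlobalMaxPooling1D -> (15,)
    features = np.concatenate(pooled)                # (45,)
    # Dropout(.2) is the identity at inference time, so it is omitted here
    logit = features @ params["dense_w"] + params["dense_b"]
    return 1.0 / (1.0 + np.exp(-logit)), features    # sigmoid


def random_params(max_seq_index, dim=15, filters=15, kernel_sizes=(2, 4, 6), seed=0):
    """Random placeholder weights with the same shapes as the Keras layers."""
    rng = np.random.default_rng(seed)
    return {
        "embedding": rng.normal(size=(max_seq_index, dim)),
        "convs": [(rng.normal(size=(k, dim, filters)) * 0.1, np.zeros(filters))
                  for k in kernel_sizes],
        "dense_w": rng.normal(size=(filters * len(kernel_sizes),)) * 0.1,
        "dense_b": 0.0,
    }


# Usage: vocabulary of 100 token IDs, one sequence with max_seq_length = 10
params = random_params(max_seq_index=100)
tokens = np.array([3, 17, 42, 8, 99, 0, 5, 61, 12, 7])
prob, feats = text_cnn_forward(tokens, params)
print(feats.shape)  # (45,)
print(prob)
```

With trained weights, the same computation should reproduce `model.predict` for a single sequence, up to floating-point differences.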