nn.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:event_chain 作者: wangzq870305 项目源码 文件源码
def cnn_train(X_train, y_train, vocab_size):
    """Pad the input sequences, build a 1-D CNN binary classifier, and train it.

    Layer stack, in order:
        Embedding(vocab_size, EMBED_SIZE) -> Dropout(0.25)
        -> Convolution1D (ReLU, 'valid' border) -> MaxPooling1D(pool_length=2)
        -> Flatten -> Dense(HIDDEN_SIZE) -> Dropout(0.25) -> ReLU
        -> Dense(1) -> sigmoid

    The model is compiled with binary cross-entropy and RMSprop, then fitted
    with ``BATCH_SIZE`` and ``EPOCHS``.

    Args:
        X_train: token-index sequences. They are padded/truncated to
            ``MAX_LEN`` with ``sequence.pad_sequences`` before training.
        y_train: training targets. These are presumably 0/1 labels, given the
            sigmoid output and binary cross-entropy loss (confirm with caller).
        vocab_size: input dimension of the Embedding layer.

    Returns:
        The compiled and fitted ``Sequential`` model.

    NOTE(review): this uses legacy Keras keyword names (nb_filter,
    filter_length, border_mode, subsample_length, pool_length, nb_epoch,
    show_accuracy). Later Keras releases renamed or removed them.
    """
    padded = sequence.pad_sequences(X_train, maxlen=MAX_LEN)

    print('Build model...')
    model = Sequential()

    # Layers are listed in stacking order and then appended one by one.
    stack = [
        Embedding(vocab_size, EMBED_SIZE, input_length=MAX_LEN),
        Dropout(0.25),
        # nb_filter word-group detectors, each spanning filter_length tokens.
        Convolution1D(nb_filter=nb_filter,
                      filter_length=filter_length,
                      border_mode='valid',
                      activation='relu',
                      subsample_length=1),
        # A pooling window of 2 halves the length of the conv output.
        MaxPooling1D(pool_length=2),
        # Flatten so that a plain dense hidden layer can follow.
        Flatten(),
        Dense(HIDDEN_SIZE),
        Dropout(0.25),
        Activation('relu'),
        # A single output unit squashed into (0, 1).
        Dense(1),
        Activation('sigmoid'),
    ]
    for layer in stack:
        model.add(layer)

    model.compile(loss='binary_crossentropy', optimizer='rmsprop')
    model.fit(padded, y_train,
              batch_size=BATCH_SIZE,
              nb_epoch=EPOCHS,
              show_accuracy=True)

    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号