nn.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:event_chain 作者: wangzq870305 项目源码 文件源码
def lstm_train(X_train,y_train,vocab_size):

    X_train = sequence.pad_sequences(X_train, maxlen=MAX_LEN)

    main_input = Input(shape=(MAX_LEN,), dtype='int32')

    x = Embedding(output_dim=EMBED_SIZE, input_dim=vocab_size, input_length=MAX_LEN)(main_input)

    lstm_out = LSTM(HIDDEN_SIZE)(x)

    main_loss = Dense(1, activation='sigmoid', name='main_output')(lstm_out)

    model = Model(input=main_input, output=main_loss)

    model.compile(loss='binary_crossentropy', optimizer='rmsprop')
    model.fit(X_train, y_train, batch_size=BATCH_SIZE, nb_epoch=EPOCHS)

    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号