imdb_lstm.py source code

python

Project: mtl, Author: zhenhongChen
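```python
# Assumed Keras 2.x imports. `ds` is the mtl project's own dataset module;
# its import is not shown in the original snippet.
from keras.preprocessing import sequence
from keras.models import Sequential
from keras.layers import Embedding, LSTM, Dense


def imdb_run(index_embedding, dataset, num_words=5000, embedding_len=100, max_len=500):
    # Load integer-encoded train/test splits via the project's loader
    # (vocabulary assumed to be capped at num_words)
    (x_train, y_train), (x_test, y_test) = ds.load_data(dataset, num_words)
    # Pad or truncate every sequence to exactly max_len token ids
    x_train = sequence.pad_sequences(x_train, maxlen=max_len)
    x_test = sequence.pad_sequences(x_test, maxlen=max_len)

    model = Sequential()
    # index_embedding must have shape (num_words, embedding_len)
    model.add(Embedding(num_words, embedding_len, input_length=max_len, weights=[index_embedding]))
    model.add(LSTM(100, dropout=0.2, recurrent_dropout=0.2))
    # Single sigmoid unit for binary classification
    model.add(Dense(1, activation='sigmoid'))

    model.compile(loss='binary_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])

    # summary() prints the table itself and returns None, so it is not wrapped in print()
    model.summary()
    model.fit(x_train, y_train, epochs=3, batch_size=64, verbose=2)
    score, acc = model.evaluate(x_test, y_test, verbose=0)

    print('Test score:', score)
    print('Test accuracy:', acc)
```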
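`imdb_run` takes a pretrained `index_embedding` matrix, but the snippet does not show how that matrix is built. Here is a minimal sketch under two assumptions: a `word_index` dict maps each word to the same integer id that appears in `x_train`, and a `word_vectors` dict maps each word to a vector of length `embedding_len` (for example, vectors loaded from GloVe). Both inputs and the helper name `build_index_embedding` are hypothetical and not part of the mtl project. Row 0 is zeroed because `pad_sequences` pads with 0. Words without a pretrained vector keep a small random initialization.

```python
import numpy as np


def build_index_embedding(word_index, word_vectors, num_words=5000, embedding_len=100, seed=0):
    """Return a float32 matrix of shape (num_words, embedding_len) for Embedding(weights=[...]).

    word_index:   hypothetical dict, word -> integer id as used in the padded sequences
    word_vectors: hypothetical dict, word -> pretrained vector of length embedding_len
    """
    rng = np.random.default_rng(seed)
    # Small random init for words that have no pretrained vector
    matrix = rng.normal(0.0, 0.05, size=(num_words, embedding_len)).astype("float32")
    # Id 0 is the padding value used by pad_sequences, so keep it as a zero vector
    matrix[0] = 0.0
    for word, idx in word_index.items():
        # Skip the padding id and ids outside the num_words vocabulary
        if not 0 < idx < num_words:
            continue
        vec = word_vectors.get(word)
        if vec is not None:
            matrix[idx] = np.asarray(vec, dtype="float32")
    return matrix
```

The result's shape matches the `Embedding(num_words, embedding_len, ...)` layer, so it can be passed in as `imdb_run(build_index_embedding(word_index, word_vectors), dataset)`.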
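With `max_len=500`, `sequence.pad_sequences` runs with its defaults `padding='pre'`, `truncating='pre'`, and `value=0`. Short sequences are therefore left-padded with zeros, and long sequences lose their first tokens. The following pure-Python sketch (hypothetical helper `pad_pre`) mimics only that default behavior on plain lists. The real function returns a NumPy `int32` array.

```python
def pad_pre(sequences, maxlen, value=0):
    """Mimic pad_sequences defaults: padding='pre', truncating='pre'."""
    padded = []
    for seq in sequences:
        seq = list(seq)
        # truncating='pre': drop tokens from the front, keeping the last maxlen
        if len(seq) > maxlen:
            seq = seq[len(seq) - maxlen:]
        # padding='pre': fill with `value` on the left up to maxlen
        padded.append([value] * (maxlen - len(seq)) + seq)
    return padded
```

For example, `pad_pre([[1, 2, 3], [4, 5, 6, 7, 8]], maxlen=4)` gives `[[0, 1, 2, 3], [5, 6, 7, 8]]`. Passing `truncating='post'` to `pad_sequences` would instead keep the first `max_len` tokens of each sequence.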