dual_encoder_lstm.py 文件源码

python
阅读 14 收藏 0 点赞 0 评论 0

项目:snli_dual_encoder_lstm 作者: hist0613 项目源码 文件源码
def _build_sentence_encoder(embedding_matrix):
    """Build one LSTM sentence encoder over frozen pre-trained embeddings.

    Used for both input sentences so the two branches always share the same
    architecture. Each call creates a new, independent ``Sequential`` model:
    the two branches start from the same frozen embedding weights but have
    separate LSTM weights.

    Args:
        embedding_matrix: pre-trained weights for the ``Embedding`` layer.
            Presumably shaped ``(MAX_NB_WORDS, EMBEDDING_DIM)``; TODO confirm
            against ``load_embedding_matrix``.

    Returns:
        A ``Sequential`` model mapping a padded index sequence of length
        ``MAX_SEQUENCE_LENGTH`` to an ``LSTM_DIM``-unit LSTM output.
    """
    encoder = Sequential()
    encoder.add(Embedding(output_dim=EMBEDDING_DIM,
                          input_dim=MAX_NB_WORDS,
                          input_length=MAX_SEQUENCE_LENGTH,
                          weights=[embedding_matrix],
                          # index 0 is treated as padding and masked out
                          mask_zero=True,
                          # keep the pre-trained embeddings fixed
                          trainable=False))
    encoder.add(LSTM(output_dim=LSTM_DIM))
    return encoder


def load_model():
    """Return the SNLI dual-encoder classifier, training it first if needed.

    If a saved model exists at ``TRAINED_CLASSIFIER_PATH``, it is loaded with
    ``K_load_model`` and returned. Otherwise this function:

    1. loads the data splits and the embedding matrix,
    2. builds two LSTM sentence encoders,
    3. combines them with an element-wise product,
    4. adds a 3-way softmax classifier,
    5. trains it, saves it to ``TRAINED_CLASSIFIER_PATH`` and returns it.

    Returns:
        A Keras model taking ``[sentence1_ids, sentence2_ids]`` and producing
        a 3-way softmax (presumably the SNLI entailment / neutral /
        contradiction labels; confirm against ``load_data``).
    """
    if os.path.exists(TRAINED_CLASSIFIER_PATH):
        print("Found pre-trained model...")
        return K_load_model(TRAINED_CLASSIFIER_PATH)

    print("No pre-trained model...")
    print("Start building model...")

    print("Now loading SNLI data...")
    # The test split is deliberately not used here: it must stay held out
    # and is not looked at while training.
    (X_train_1, X_train_2, Y_train,
     _, _, _,
     X_dev_1, X_dev_2, Y_dev) = load_data()

    print("Now loading embedding matrix...")
    embedding_matrix = load_embedding_matrix()

    print("Now building dual encoder lstm model...")
    branch1 = _build_sentence_encoder(embedding_matrix)  # sentence1
    branch2 = _build_sentence_encoder(embedding_matrix)  # sentence2

    # Classifier over the two sentence encodings.
    model = Sequential()
    # Element-wise product of the two sentence vectors; Merge(mode='mul') has
    # no trainable weights of its own.
    model.add(Merge([branch1, branch2], mode='mul'))
    model.add(Dense(3))
    model.add(Activation('softmax'))

    model.compile(loss='categorical_crossentropy',
                  optimizer=OPTIMIZER,
                  metrics=['accuracy'])

    print("Now training the model...")
    print("\tbatch_size={}, nb_epoch={}".format(BATCH_SIZE, NB_EPOCH))
    # Monitor the dev split (previously loaded but unused) rather than the
    # test split, so the test split stays untouched during training.
    model.fit([X_train_1, X_train_2], Y_train,
              batch_size=BATCH_SIZE, nb_epoch=NB_EPOCH,
              validation_data=([X_dev_1, X_dev_2], Y_dev))

    print("Now saving the model... at {}".format(TRAINED_CLASSIFIER_PATH))
    model.save(TRAINED_CLASSIFIER_PATH)

    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号