model_rnn.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:botcycle 作者: D2KLab 项目源码 文件源码
def _weighted_f1_from_model(model, inputs, labels):
    """Return the support-weighted F1 score of ``model`` on ``inputs``.

    ``labels`` has one row per sample and one column per intent (presumably
    one-hot — confirm against ``prepare_inputs_and_outputs``). Both the labels
    and the model's predicted scores are reduced to class indices with
    ``argmax(axis=1)`` before scoring.
    """
    predictions = model.predict(inputs)
    return f1_score(labels.argmax(axis=1), predictions.argmax(axis=1),
                    average='weighted')


def train_and_evaluate(train, test, intents_lookup, save=False):
    """Train a fresh intent classifier on ``train`` and evaluate it.

    Args:
        train: training samples, converted with ``prepare_inputs_and_outputs``.
        test: optional evaluation samples. When falsy (e.g. ``None`` or an
            empty list), no validation is run and ``f1_test`` is ``None``.
        intents_lookup: iterable of intent names. Its length sets the number
            of output classes passed to ``create_model``.
        save: if true, write the trained model (``model.h5``) and a
            ``stats.json`` file (model name and config) into
            ``MODEL_OUTPUT_FOLDER``, creating the folder if needed.

    Returns:
        A tuple ``(history, f1_test, f1_train)``:
        - ``history`` is a dict with the per-epoch ``'train'`` and ``'test'``
          F1 curves recorded by ``model.fit``. ``'test'`` is empty when there
          is no validation data.
        - ``f1_test`` and ``f1_train`` are support-weighted F1 scores
          (``f1_score(..., average='weighted')``). ``f1_test`` is ``None``
          without a test set.
    """
    # Evaluate truthiness once so every branch agrees on whether a test set
    # exists (same semantics as the original repeated ``if test:`` checks).
    has_test = bool(test)

    train_inputs, train_labels = prepare_inputs_and_outputs(train, intents_lookup)
    validation_data = None
    if has_test:
        test_inputs, test_labels = prepare_inputs_and_outputs(test, intents_lookup)
        validation_data = (test_inputs, test_labels)

    print('Number of sentences for each intent, train and test')
    print(list(intents_lookup))
    print(train_labels.sum(axis=0))
    if has_test:
        print(test_labels.sum(axis=0))

    model = create_model(len(intents_lookup))

    fit_result = model.fit(train_inputs, train_labels,
                           validation_data=validation_data,
                           epochs=MAX_ITERATIONS, batch_size=50)

    # Keep only the per-epoch F1 curves. 'val_f1_score' is present only when
    # validation data was passed to fit(), hence the .get() fallback.
    history = {
        'train': np.array(fit_result.history['f1_score']),
        'test': np.array(fit_result.history.get('val_f1_score', [])),
    }

    # Final scores, weighted by per-class support.
    f1_train = _weighted_f1_from_model(model, train_inputs, train_labels)
    f1_test = (_weighted_f1_from_model(model, test_inputs, test_labels)
               if has_test else None)

    print(f1_test, f1_train)
    if save:
        import os  # local import: only needed when persisting artifacts

        # model.save/open fail if the output folder is missing.
        os.makedirs(MODEL_OUTPUT_FOLDER, exist_ok=True)
        model.save(os.path.join(MODEL_OUTPUT_FOLDER, 'model.h5'))
        stats = {
            'model_name': MODEL_NAME,
            'model': model.get_config(),
        }
        stats_path = os.path.join(MODEL_OUTPUT_FOLDER, 'stats.json')
        with open(stats_path, 'w', encoding='utf-8') as stats_file:
            json.dump(stats, stats_file)

    return history, f1_test, f1_train
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号