run_seq2seq.py source code


Project: DialogueBreakdownDetection2016  Author: icoxfog417
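```python
import tensorflow as tf  # TensorFlow 1.x API (tf.Session)
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split

# DbdReader, make_model, Detector and the constants DATA_DIR, TRAIN_PATH,
# TARGET_PATH, TRAIN_DIR, _vocab_size_ and _buckets_ are defined elsewhere in
# the project; they are not part of this excerpt.


def classify():
    # Load the dialogue data and build the user/system vocabularies.
    reader = DbdReader(DATA_DIR, TRAIN_PATH, target_for_vocabulary=TARGET_PATH, max_vocabulary_size=_vocab_size_, filter="140", threshold=0.6, clear_when_exit=False)
    reader.init()
    dataset, user_vocab, system_vocab = reader.get_dataset()
    labels = reader.get_labels()

    # Build the seq2seq model and its bucketed interface.
    model = make_model(user_vocab, system_vocab)
    model_if = model.create_interface(_buckets_, TRAIN_DIR)

    # Hold out 20% of the samples for evaluation (fixed seed for reproducibility).
    train_x, test_x, train_t, test_t = train_test_split(dataset, labels, test_size=0.2, random_state=42)

    with tf.Session() as sess:
        detector = Detector(sess, model_if)
        detector.train(sess, train_x, train_t)
        # predict() returns a (label, probability) pair; keep only the label.
        predictions = [detector.predict(sess, p) for p in test_x]
        y_pred = [lb for lb, _prob in predictions]

    y_true = [lb.label for lb in test_t]
    report = classification_report(y_true, y_pred, target_names=DbdReader.get_label_names())
    print(report)
```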
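You can sketch the 80/20 hold-out done by `train_test_split(dataset, labels, test_size=0.2, random_state=42)` in plain Python if scikit-learn is not at hand. `split_train_test` is a hypothetical helper written only for illustration. It shuffles indices with a fixed seed and holds out `ceil(test_size * n)` samples, which matches how scikit-learn sizes the test set. It will not reproduce scikit-learn's exact index order.

```python
import math
import random


def split_train_test(xs, ts, test_size=0.2, seed=42):
    """Hypothetical stand-in for train_test_split: shuffle indices with a
    fixed seed and hold out ceil(test_size * n) items, keeping inputs and
    labels aligned. Returns (train_x, test_x, train_t, test_t)."""
    if len(xs) != len(ts):
        raise ValueError("inputs and labels must have the same length")
    n = len(xs)
    n_test = math.ceil(test_size * n)
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # same seed -> same split every run
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    train_x = [xs[i] for i in train_idx]
    test_x = [xs[i] for i in test_idx]
    train_t = [ts[i] for i in train_idx]
    test_t = [ts[i] for i in test_idx]
    return train_x, test_x, train_t, test_t


# Toy data: ten "utterances" with made-up labels.
samples = [f"utt{i}" for i in range(10)]
sample_labels = [i % 3 for i in range(10)]
train_x, test_x, train_t, test_t = split_train_test(samples, sample_labels)
print(len(train_x), len(test_x))
```

The test set gets two of the ten samples. The paired lists stay aligned because both are built from the same index lists.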
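The evaluation at the end of `classify()` keeps only the label from each `(label, probability)` pair returned by `detector.predict`. It then hands those labels to `classification_report`. The sketch below recomputes that report's per-class precision, recall, F1 and support columns in plain Python. Several parts are assumptions:

- `per_class_report` is a hypothetical helper.
- The predictions are made-up toy data.
- The label names `O`, `T`, `X` assume the not-a-breakdown / possible-breakdown / breakdown scheme of the Dialogue Breakdown Detection Challenge. The project's `DbdReader.get_label_names()` may return something else.

```python
from collections import Counter


def per_class_report(y_true, y_pred, labels):
    """Hypothetical helper: precision, recall, F1 and support per label,
    the same quantities classification_report prints for each row."""
    true_count = Counter(y_true)   # support: how often each label is the truth
    pred_count = Counter(y_pred)   # how often each label was predicted
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)  # true positives
    report = {}
    for label in labels:
        precision = hits[label] / pred_count[label] if pred_count[label] else 0.0
        recall = hits[label] / true_count[label] if true_count[label] else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        report[label] = {"precision": precision, "recall": recall,
                         "f1": f1, "support": true_count[label]}
    return report


# Toy (label, probability) pairs; "O"/"T"/"X" are assumed label names.
predictions = [("O", 0.9), ("X", 0.7), ("O", 0.6), ("T", 0.5), ("X", 0.8)]
y_pred = [label for label, _prob in predictions]  # drop the probabilities
y_true = ["O", "X", "X", "T", "O"]

report = per_class_report(y_true, y_pred, labels=["O", "T", "X"])
for label, row in report.items():
    print(label, row)
```

With this toy data, `T` is predicted perfectly (F1 of 1.0). `O` and `X` are each confused once with the other, giving precision and recall of 0.5 for both.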