classifier.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:quantulum 作者: marcolagi 项目源码 文件源码
def train_classifier(download=True, parameters=None, ngram_range=(1, 1)):
    """Train the intent classifier and pickle it to ``clf.pickle``.

    Reads the examples stored in ``train.json`` and ``wiki.json`` (both
    located in ``l.TOPDIR``), fits a TF-IDF vectorizer on their cleaned
    ``'text'`` fields, trains an ``SGDClassifier`` to predict the ``'unit'``
    field, and writes the vectorizer, the classifier and the list of target
    names to ``clf.pickle`` in the same directory.

    Args:
        download: When true, ``download_wiki()`` is called before any file is
            read (presumably to refresh ``wiki.json`` -- confirm against
            that helper).
        parameters: Keyword arguments forwarded to ``SGDClassifier``. When
            ``None`` a default configuration is used.
        ngram_range: Forwarded unchanged to ``TfidfVectorizer``.

    Returns:
        None. The trained objects are only persisted to disk.
    """
    if download:
        download_wiki()

    # Context managers guarantee the handles are closed even if parsing fails.
    with open(os.path.join(l.TOPDIR, 'train.json')) as handle:
        training_set = json.load(handle)
    with open(os.path.join(l.TOPDIR, 'wiki.json')) as handle:
        wiki_set = json.load(handle)
    examples = training_set + wiki_set

    # Each distinct unit becomes a class; its position in target_names is the
    # integer label the classifier is trained on.
    target_names = list(set(example['unit'] for example in examples))
    # Precomputed lookup avoids a linear list.index() scan per example.
    target_index = {name: pos for pos, name in enumerate(target_names)}

    train_data = [clean_text(example['text']) for example in examples]
    train_target = [target_index[example['unit']] for example in examples]

    tfidf_model = TfidfVectorizer(sublinear_tf=True,
                                  ngram_range=ngram_range,
                                  stop_words='english')

    matrix = tfidf_model.fit_transform(train_data)

    if parameters is None:
        # NOTE(review): recent scikit-learn releases renamed `n_iter` to
        # `max_iter` and `loss='log'` to `'log_loss'` -- confirm these keys
        # against the scikit-learn version this project pins.
        parameters = {'loss': 'log', 'penalty': 'l2', 'n_iter': 50,
                      'alpha': 0.00001, 'fit_intercept': True}

    clf = SGDClassifier(**parameters).fit(matrix, train_target)
    obj = {'tfidf_model': tfidf_model,
           'clf': clf,
           'target_names': target_names}

    # Pickle streams are bytes: the file must be opened in binary mode, and
    # closing it explicitly ensures the data is flushed to disk.
    with open(os.path.join(l.TOPDIR, 'clf.pickle'), 'wb') as handle:
        pickle.dump(obj, handle)


###############################################################################
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号