def train_classifier(download=True, parameters=None, ngram_range=(1, 1)):
    """Train the intent classifier."""
    if download:
        download_wiki()

    # Load the hand-written and Wikipedia-derived training examples.
    path = os.path.join(l.TOPDIR, 'train.json')
    with open(path) as train_file:
        training_set = json.load(train_file)
    path = os.path.join(l.TOPDIR, 'wiki.json')
    with open(path) as wiki_file:
        wiki_set = json.load(wiki_file)

    # Each distinct unit name becomes a target class.
    target_names = list(set(i['unit'] for i in training_set + wiki_set))
    train_data, train_target = [], []
    for example in training_set + wiki_set:
        train_data.append(clean_text(example['text']))
        train_target.append(target_names.index(example['unit']))

    # Build TF-IDF features over the cleaned example text.
    tfidf_model = TfidfVectorizer(sublinear_tf=True,
                                  ngram_range=ngram_range,
                                  stop_words='english')
    matrix = tfidf_model.fit_transform(train_data)

    # Default SGD hyperparameters (logistic loss, L2 regularization).
    if parameters is None:
        parameters = {'loss': 'log', 'penalty': 'l2', 'n_iter': 50,
                      'alpha': 0.00001, 'fit_intercept': True}

    clf = SGDClassifier(**parameters).fit(matrix, train_target)

    # Persist the vectorizer, classifier and label names together.
    obj = {'tfidf_model': tfidf_model,
           'clf': clf,
           'target_names': target_names}
    path = os.path.join(l.TOPDIR, 'clf.pickle')
    with open(path, 'wb') as clf_file:  # pickle needs a binary file handle
        pickle.dump(obj, clf_file)
###############################################################################
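# Hypothetical usage sketch (not part of the original module): assuming the same
# l.TOPDIR layout and clean_text() helper used by train_classifier above, this
# shows how the pickled vectorizer/classifier pair could be loaded to predict a
# unit label for a new piece of text.
def _example_classify(text):
    path = os.path.join(l.TOPDIR, 'clf.pickle')
    with open(path, 'rb') as clf_file:
        obj = pickle.load(clf_file)
    # Vectorize the cleaned text with the stored TF-IDF model, then predict the
    # class index and map it back to the unit name saved at training time.
    features = obj['tfidf_model'].transform([clean_text(text)])
    prediction = obj['clf'].predict(features)[0]
    return obj['target_names'][prediction]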