def prune(self, question, paragraphs: List[ExtractedParagraph]):
    """Keep the paragraphs closest to ``question`` under TF-IDF cosine distance.

    A fresh TF-IDF model is fit on the candidate paragraphs. The question is
    projected into that space, and paragraphs are ranked by ascending cosine
    distance to it.

    Args:
        question: Iterable of question tokens; joined with single spaces
            before vectorizing.
        paragraphs: Candidate paragraphs. Each ``para.text`` is iterated as a
            sequence of token sequences (each joined with spaces, then the
            results joined with spaces). ``para.start`` is used to break
            distance ties.

    Returns:
        Up to ``self.n_to_select`` paragraphs, nearest first. Ties go to the
        paragraph with the smaller ``start``. When ``self.filter_dist_one`` is
        set, any selected paragraph whose distance is not strictly below 1.0
        is dropped. A lone paragraph is returned as-is, without scoring, when
        ``filter_dist_one`` is off. An empty list is returned if fitting or
        transforming raises ``ValueError``.
    """
    # With a single candidate and no distance filter there is nothing to decide.
    if not self.filter_dist_one and len(paragraphs) == 1:
        return paragraphs

    vectorizer = TfidfVectorizer(strip_accents="unicode", stop_words=self.stop.words)
    documents = [
        " ".join(" ".join(sentence) for sentence in para.text)
        for para in paragraphs
    ]

    try:
        paragraph_matrix = vectorizer.fit_transform(documents)
        question_vector = vectorizer.transform([" ".join(question)])
    except ValueError:
        # NOTE(review): presumably hit when no vocabulary survives fitting
        # (e.g. every token is a stop word) -- confirm against callers.
        return []

    distances = pairwise_distances(question_vector, paragraph_matrix, "cosine").ravel()
    # np.lexsort treats the LAST key as primary: order by distance first,
    # then by paragraph start so earlier paragraphs win ties.
    start_offsets = [para.start for para in paragraphs]
    ranking = np.lexsort((start_offsets, distances))
    shortlist = ranking[:self.n_to_select]

    if not self.filter_dist_one:
        return [paragraphs[ix] for ix in shortlist]
    # A cosine distance of 1.0 or more means no positive similarity, so drop those.
    return [paragraphs[ix] for ix in shortlist if distances[ix] < 1.0]
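# 评论列表 ("comment list") and 文章目录 ("table of contents"): stray web-page
# navigation text captured along with this snippet; kept as a comment so it
# is not executed as Python.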