def fit(self, docs, y):
    """Train the Doc2Vec model on labelled documents and index the result.

    Each document is passed through ``self.analyzer`` and wrapped in a
    ``TaggedDocument`` whose single tag is its label from ``y``. The
    vocabulary of ``self.model`` is built from these documents. The model
    is then trained for ``self.n_epochs`` passes. After each pass the
    learning rate is lowered by ``(self.alpha - self.min_alpha) /
    self.n_epochs``.

    Finally, ``self._matching`` is fitted on the raw documents when it is
    truthy. Otherwise ``self._neighbors`` is fitted on the learned vector
    of every label in ``y``.

    Parameters
    ----------
    docs : sequence
        Raw documents; each one is fed to ``self.analyzer``.
    y : sequence
        One label per document, used as that document's Doc2Vec tag.

    Returns
    -------
    self

    Raises
    ------
    ValueError
        If ``docs`` and ``y`` have different lengths.
    """
    # Explicit check rather than ``assert`` so it is not stripped by -O.
    if len(docs) != len(y):
        raise ValueError(
            "docs and y must have the same length, got {} and {}".format(
                len(docs), len(y)))
    model = self.model
    n_epochs = self.n_epochs
    verbose = self.verbose
    # Per-epoch learning-rate step. With no epochs there is nothing to
    # decay, and dividing by zero would raise.
    decay = (self.alpha - self.min_alpha) / n_epochs if n_epochs else 0.0
    X = [TaggedDocument(self.analyzer(doc), [label])
         for doc, label in zip(docs, y)]
    if verbose > 0:
        print("First 3 tagged documents:\n", X[:3])
        print("Training doc2vec model")
    model.build_vocab(X)
    for epoch in range(n_epochs):
        if verbose > 0:
            print("Doc2Vec: Epoch {} of {}.".format(epoch + 1, n_epochs))
        # NOTE(review): newer gensim releases require explicit
        # total_examples/epochs arguments to train(); confirm the pinned
        # gensim version accepts this call.
        model.train(X)
        # Manual schedule: lower alpha after each epoch, then pin min_alpha
        # to it so the rate stays constant within the next epoch.
        # NOTE(review): the first epoch runs with the model's own
        # alpha/min_alpha, presumably initialised from self.alpha and
        # self.min_alpha -- confirm.
        model.alpha -= decay
        model.min_alpha = model.alpha
    if verbose > 0:
        print("Finished.")
        print("model:", model)
    if self._matching:
        self._matching.fit(docs)
    else:
        # Without a matching step it suffices to index one learned vector
        # per entry of y before query time. Row i of the index corresponds
        # to self._y[i].
        dvs = np.asarray([model.docvecs[tag] for tag in y])
        self._neighbors.fit(dvs)
    self._y = y
    return self
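# Stray page-navigation text from a web scrape ("comment list",
# "table of contents"), kept as a comment so the file parses.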