def fit(self, tokens):
    # lowercase the tokens and count them once:
    tokens = [t.lower() for t in tokens]
    counts = Counter(tokens)
    # get the most frequent items for plotting:
    self.mfi = [t for t, _ in counts.most_common(self.nb_mfi)]
    self.sentence_iterator = SentenceIterator(tokens=tokens)
    # train the embeddings (NB: `size` is the pre-4.0 gensim keyword;
    # gensim 4.x renamed it to `vector_size`):
    self.w2v_model = Word2Vec(self.sentence_iterator,
                              window=self.window,
                              min_count=self.minimum_count,
                              size=self.size,
                              workers=self.nb_workers,
                              negative=self.nb_negative)
    # diagnostics: plot the most frequent items and inspect nearest neighbours:
    self.plot_mfi()
    self.most_similar()
    # build an index of the train tokens which occur at least
    # `minimum_count` times; index 0 is reserved for unknown tokens:
    self.token_idx = {'<UNK>': 0}
    for k, v in counts.items():
        if v >= self.minimum_count:
            self.token_idx[k] = len(self.token_idx)
    # create a vocab ordered by index:
    self.train_token_vocab = [k for k, _ in sorted(self.token_idx.items(),
                                                   key=itemgetter(1))]
    # extract the pretrained weight matrix for this vocab:
    self.pretrained_embeddings = self.get_weights(self.train_token_vocab)
    return self
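
For context, here is a minimal usage sketch, not taken from the original source: it assumes the module-level imports shown below, plus a surrounding class (called `Pretrainer` here purely as a placeholder name) whose constructor sets `nb_mfi`, `window`, `minimum_count`, `size`, `nb_workers` and `nb_negative`, and whose module defines the `SentenceIterator`, `plot_mfi`, `most_similar` and `get_weights` helpers used inside `fit`.

from collections import Counter      # frequency counts used in fit()
from operator import itemgetter      # sorting the vocab by index
from gensim.models import Word2Vec   # gensim < 4.0 keyword names assumed

# toy corpus, just to illustrate the call sequence:
tokens = "the quick brown fox jumps over the lazy dog".split()

pretrainer = Pretrainer(nb_mfi=30, window=5, minimum_count=1,
                        size=50, nb_workers=1, nb_negative=5)  # placeholder class name
pretrainer.fit(tokens)

print(pretrainer.train_token_vocab[:10])        # '<UNK>' plus the indexed train tokens
print(len(pretrainer.pretrained_embeddings))    # one embedding per vocab item (assumption)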