def cbow_train(self):
    """Train the CBOW model and write its embeddings to disk.

    Saves the initial (untrained) embeddings to ``cbow_begin_embedding.txt``,
    then runs ``pair_count // self.batch_size`` optimizer steps. Each step
    draws one batch of CBOW pairs from ``self.data`` and expands it into
    positive/negative pairs. The expansion uses Huffman-tree paths when
    ``self.using_hs`` is true (presumably hierarchical softmax) and negative
    sampling otherwise. Finally, the trained embeddings are saved to
    ``self.output_file_name``.
    """
    print("CBOW Training......")
    self.cbow_model.save_embedding(self.data.id2word, 'cbow_begin_embedding.txt')

    # Only the number of pairs is needed here. Taking len() directly, without
    # binding the list, lets it be freed before the training loop instead of
    # staying alive for the whole run (assuming self.data does not keep its
    # own reference to it).
    # NOTE(review): this count uses self.context_size, while the per-step
    # batches below use self.window_size -- confirm these are meant to differ.
    pair_count = len(
        self.data.get_cbow_batch_all_pairs(self.batch_size, self.context_size)
    )
    # Exact integer division (int(a / b) goes through a float). If
    # pair_count < batch_size this is 0 and no training step runs.
    batch_count = pair_count // self.batch_size

    process_bar = tqdm(range(batch_count))
    for _ in process_bar:
        pos_pairs = self.data.get_cbow_batch_pairs(self.batch_size, self.window_size)
        if self.using_hs:
            pos_pairs, neg_pairs = self.data.get_cbow_pairs_by_huffman(pos_pairs)
        else:
            pos_pairs, neg_pairs = self.data.get_cbow_pairs_by_neg_sampling(
                pos_pairs, self.context_size
            )

        # Split each pair into its first element (context/input side) and
        # its second element, cast to int (target/output id).
        pos_u = [pair[0] for pair in pos_pairs]
        pos_v = [int(pair[1]) for pair in pos_pairs]
        neg_u = [pair[0] for pair in neg_pairs]
        neg_v = [int(pair[1]) for pair in neg_pairs]

        self.optimizer.zero_grad()
        # Invoke the module rather than .forward() directly, so any registered
        # nn.Module hooks run. Assumes cbow_model is a torch.nn.Module -- TODO
        # confirm.
        loss = self.cbow_model(pos_u, pos_v, neg_u, neg_v)
        loss.backward()
        self.optimizer.step()

    print("CBOW Trained and Saving File......")
    self.cbow_model.save_embedding(self.data.id2word, self.output_file_name)
    print("CBOW Trained and Saved File.")
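# 评论列表 ("Comment list") and 文章目录 ("Table of contents"): navigation
# labels left over from the web page this code was copied from; not code.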