pytorch_word2vec.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目: pytorch_word2vec 作者: bamtercelboo 项目源码 文件源码
def cbow_train(self):
        """Train the CBOW model and save the resulting embeddings.

        Steps:
          1. Save the embedding as it is before training to
             ``cbow_begin_embedding.txt``.
          2. Run ``pair_count // self.batch_size`` optimisation steps. Each step:
             - draws a batch of positive pairs from ``self.data``;
             - expands the batch into (positive, negative) pairs, using the
               Huffman-tree helper when ``self.using_hs`` is true and the
               negative-sampling helper otherwise;
             - runs the forward/backward pass and applies one
               ``self.optimizer`` step.
          3. Save the trained embedding to ``self.output_file_name``.

        A trailing partial batch (``pair_count % self.batch_size`` pairs) gets
        no extra step. If ``pair_count < self.batch_size``, no training step
        runs at all and only the two embedding files are written.
        """
        print("CBOW Training......")
        self.cbow_model.save_embedding(self.data.id2word, 'cbow_begin_embedding.txt')
        # Only the number of pairs is needed to size the loop. Taking len()
        # directly lets the full pair list be garbage-collected instead of
        # keeping it alive for the entire training run.
        pair_count = len(self.data.get_cbow_batch_all_pairs(self.batch_size, self.context_size))
        # Integer floor division: no float round-trip. For positive ints this
        # equals the previous int(pair_count / batch_size).
        num_batches = pair_count // self.batch_size
        process_bar = tqdm(range(num_batches))
        for _ in process_bar:
            # NOTE(review): the pair count above is computed with
            # self.context_size, but the batch is drawn with self.window_size.
            # Confirm the two are meant to differ.
            pos_pairs = self.data.get_cbow_batch_pairs(self.batch_size, self.window_size)
            if self.using_hs:
                pos_pairs, neg_pairs = self.data.get_cbow_pairs_by_huffman(pos_pairs)
            else:
                pos_pairs, neg_pairs = self.data.get_cbow_pairs_by_neg_sampling(pos_pairs, self.context_size)

            # Split each pair into its two parts; the second element is
            # coerced to int. Presumably pair[0] is the context and pair[1]
            # the target id; confirm against the data helper.
            pos_u = [pair[0] for pair in pos_pairs]
            pos_v = [int(pair[1]) for pair in pos_pairs]
            neg_u = [pair[0] for pair in neg_pairs]
            neg_v = [int(pair[1]) for pair in neg_pairs]

            self.optimizer.zero_grad()
            # NOTE(review): if cbow_model is a torch.nn.Module, prefer
            # self.cbow_model(...) over .forward(...), because calling
            # .forward directly skips registered hooks.
            loss = self.cbow_model.forward(pos_u, pos_v, neg_u, neg_v)
            loss.backward()
            self.optimizer.step()
        print("CBOW Trained and Saving File......")
        self.cbow_model.save_embedding(self.data.id2word, self.output_file_name)
        print("CBOW Trained and Saved File.")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号