pretraining.py source code

python

Project: pandora    Author: mikekestemont
# module-level imports this method relies on
# (SentenceIterator is defined elsewhere in the pandora project):
from collections import Counter
from operator import itemgetter

from gensim.models import Word2Vec


def fit(self, tokens):
    # get most frequent items for plotting:
    tokens = [t.lower() for t in tokens]
    self.mfi = [t for t, _ in Counter(tokens).most_common(self.nb_mfi)]
    self.sentence_iterator = SentenceIterator(tokens=tokens)

    # train embeddings (`size` is the pre-4.0 gensim name;
    # gensim >= 4.0 calls this parameter `vector_size`):
    self.w2v_model = Word2Vec(self.sentence_iterator,
                              window=self.window,
                              min_count=self.minimum_count,
                              size=self.size,
                              workers=self.nb_workers,
                              negative=self.nb_negative)
    self.plot_mfi()
    self.most_similar()

    # build an index of the train tokens
    # which occur at least min_count times:
    self.token_idx = {'<UNK>': 0}
    for k, v in Counter(tokens).items():
        if v >= self.minimum_count:
            self.token_idx[k] = len(self.token_idx)

    # create an ordered vocab, sorted by index:
    self.train_token_vocab = [k for k, v in sorted(self.token_idx.items(),
                                                   key=itemgetter(1))]
    self.pretrained_embeddings = self.get_weights(self.train_token_vocab)

    return self
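
`get_weights()` and `SentenceIterator` are defined elsewhere in the project and are not shown on this page. As a rough guide to what `get_weights()` does here, the sketch below builds an embedding matrix with one row per token in the ordered vocab, copying each row from the trained Word2Vec model and falling back to a small random vector for entries the model did not keep (such as '<UNK>'). The helper name `get_weights_sketch`, the random fallback, and the seed are illustrative assumptions, not the pandora project's actual implementation.

import numpy as np


def get_weights_sketch(w2v_model, vocab, size, seed=42):
    # hypothetical stand-in for get_weights(): one embedding row per vocab token
    rng = np.random.RandomState(seed)
    weights = np.zeros((len(vocab), size), dtype='float32')
    for i, token in enumerate(vocab):
        if token in w2v_model.wv:
            # token survived min_count pruning: copy its trained vector
            weights[i] = w2v_model.wv[token]
        else:
            # e.g. '<UNK>': small random initialization (assumed strategy)
            weights[i] = rng.uniform(-0.05, 0.05, size)
    return weights

In `fit()` above, `self.get_weights(self.train_token_vocab)` presumably returns a matrix of this shape, which is stored as `self.pretrained_embeddings` for initializing a downstream embedding layer.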