yahoo.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:liveqa2017 作者: codekansas 项目源码 文件源码
def get_word_embeddings(num_dimensions=500,
                        cache_loc=EMBEDDINGS_FILE):
    """Generates word embeddings, loading them from a cache when possible.

    If `cache_loc` exists, the array stored there is loaded with `np.load`.
    Otherwise a Word2Vec model is trained on the tokenized questions and
    answers of (at most) the first 1,000,000 pairs from `iterate_qa_pairs()`.
    Its vectors are copied into a (NUM_TOKENS, num_dimensions) array, which
    is then saved to `cache_loc`. Rows for tokens that received no Word2Vec
    vector keep the random standard-normal values they were initialized with.

    Args:
        num_dimensions: int, number of embedding dimensions.
        cache_loc: str, where to cache the word embeddings.

    Returns:
        numpy array representing the embeddings, with shape (NUM_TOKENS,
            num_dimensions).

    Raises:
        AssertionError: if the loaded or generated array does not have shape
            (NUM_TOKENS, num_dimensions), e.g. a stale cache written with a
            different `num_dimensions`.
    """

    if os.path.exists(cache_loc):
        embeddings = np.load(cache_loc)
    else:
        class SentenceGenerator(object):
            # An object with __iter__ (instead of a one-shot generator) so
            # the corpus can be iterated more than once by Word2Vec.
            def __iter__(self):
                iterable = itertools.islice(iterate_qa_pairs(), 1000000)
                # Bound up front so the final progress line still works when
                # the corpus is empty.
                i = 0
                for i, (question, answer) in enumerate(iterable, 1):
                    q, a, _, _ = tokenize(question=question, answer=answer,
                                          use_pad=False, include_rev=False)
                    # Word2Vec works on string tokens; int(k) below maps
                    # them back to row indices (presumably token ids).
                    yield [str(w) for w in q]
                    yield [str(w) for w in a]

                    # Note: `w` is NOT deleted here. Comprehension loop
                    # variables do not leak in Python 3, so `del w` would
                    # raise NameError.
                    del q, a

                    if i % 1000 == 0:
                        sys.stderr.write('\rprocessed %d' % i)
                        sys.stderr.flush()

                sys.stderr.write('\rprocessed %d\n' % i)
                sys.stderr.flush()

        # The default embeddings.
        embeddings = np.random.normal(size=(NUM_TOKENS, num_dimensions))

        sentences = SentenceGenerator()
        # NOTE(review): `size=`, `.syn0` and `.vocab` are the gensim < 4 API.
        # gensim 4 renamed them to `vector_size=`, `.vectors` and
        # `.key_to_index`. Confirm the pinned gensim version.
        model = models.Word2Vec(sentences, size=num_dimensions)

        word_vectors = model.wv
        # Only the keyed vectors are needed; drop the full training model.
        del model

        # Puts the Word2Vec weights into the right order.
        weights = word_vectors.syn0
        vocab = word_vectors.vocab
        for k, v in vocab.items():
            embeddings[int(k)] = weights[v.index]

        with open(cache_loc, 'wb') as f:
            np.save(f, embeddings)

    assert embeddings.shape == (NUM_TOKENS, num_dimensions), (
        'embeddings have shape %s, expected %s (stale cache at %s?)'
        % (embeddings.shape, (NUM_TOKENS, num_dimensions), cache_loc))
    return embeddings
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号