embeddings.py source code (Python)


Project: ikelos    Author: braingineer
# Imports assumed from the top of the original module (not shown in the snippet).
# glove_urls is assumed to be a project-level module/object mapping GloVe
# variant names to download URLs.
from os import path, makedirs
from io import BytesIO
from zipfile import ZipFile
from urllib.request import urlopen

import numpy as np
from tqdm import tqdm
from keras.initializations import glorot_uniform  # Keras 1.x initializer API (assumed)


def from_vocab(igor, vocab):
    print("using vocab and glove file to generate embedding matrix")
    remaining_vocab = set(vocab.keys())
    embeddings = np.zeros((len(vocab), igor.embedding_size))
    print("{} words to convert".format(len(remaining_vocab)))


    if igor.save_dir[-1] != "/":
        igor.save_dir += "/"
    if not path.exists(igor.save_dir):
        makedirs(igor.save_dir)

    if igor.from_url:
        assert hasattr(glove_urls, igor.target_glove), "You need to specify one of the glove variables"
        url = urlopen(getattr(glove_urls, igor.target_glove))
        archive = ZipFile(BytesIO(url.read()))
        # the original snippet referenced an undefined `file`; assume the first
        # member of the downloaded archive is the GloVe .txt we want
        fileiter = [line.decode("utf-8") for line in archive.open(archive.namelist()[0])]
    else:
        assert path.exists(igor.target_glove), "You need to specify a real file"
        fileiter = open(igor.target_glove, encoding="utf-8").readlines()

    count = 0
    for line in tqdm(fileiter):
        line = line.rstrip("\n").split(" ")
        try:
            word, nums = line[0], [float(x.strip()) for x in line[1:]]
            if word in remaining_vocab:
                embeddings[vocab[word]] = np.array(nums)
                remaining_vocab.remove(word)
        except Exception as e:
            # report the raw line; `word` and `x` may be unbound if parsing failed early
            print("{} broke. exception: {}. line: {}.".format(line[0], e, line))
        count += 1


    print("{} words were not in glove; saving to oov.txt".format(len(remaining_vocab)))
    with open(path.join(igor.save_dir, "oov.txt"), "w") as fp:
        fp.write("\n".join(remaining_vocab))

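    # words still missing after the GloVe pass get a random Glorot-uniform vector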
    for word in tqdm(remaining_vocab):
        embeddings[vocab[word]] = np.asarray(glorot_uniform((igor.embedding_size,)).eval())



    vocab.save('embedding.vocab')
    with open(path.join(igor.save_dir, "embedding.npy"), "wb") as fp:
        np.save(fp, embeddings)