def process_glove(args, vocab_list, save_path, size=4e5, random_init=True):
    """
    Build a GloVe embedding matrix restricted to ``vocab_list`` and save it.

    Nothing is done when ``save_path + ".npz"`` already exists. Otherwise the
    file ``glove.6B.<glove_dim>d.txt`` inside ``args.glove_dir`` is read line
    by line; each line is split on single spaces into a word followed by its
    vector components. The vector is copied into the row of every vocabulary
    entry equal to the word itself, its ``capitalize()`` form or its
    ``upper()`` form. Rows that are never matched keep their initial values.

    :param args: object exposing ``glove_dir`` (directory holding the GloVe
        text file) and ``glove_dim`` (embedding dimensionality).
    :param vocab_list: [vocab]; row ``i`` of the matrix belongs to
        ``vocab_list[i]``. If a token occurs more than once, only its first
        position receives a vector.
    :param save_path: target passed to ``np.savez_compressed``; the matrix is
        stored under the key ``glove``.
    :param size: expected number of lines in the GloVe file, used only as the
        progress-bar total.
    :param random_init: when true the matrix starts from ``np.random.randn``
        samples, otherwise from zeros.
    :return: None
    """
    if gfile.Exists(save_path + ".npz"):
        return

    glove_path = os.path.join(args.glove_dir, "glove.6B.{}d.txt".format(args.glove_dim))
    if random_init:
        glove = np.random.randn(len(vocab_list), args.glove_dim)
    else:
        glove = np.zeros((len(vocab_list), args.glove_dim))

    # One-time word -> row lookup table. setdefault keeps the FIRST position
    # of a repeated token, which is what list.index() would have returned.
    word_to_idx = {}
    for idx, token in enumerate(vocab_list):
        word_to_idx.setdefault(token, idx)

    # Track distinct rows so the reported count can never exceed the
    # vocabulary size, even when several case variants (or several GloVe
    # lines) map onto the same row.
    matched_rows = set()
    with open(glove_path, 'r', encoding='utf-8') as fh:
        for line in tqdm(fh, total=size):
            array = line.strip().split(" ")
            word = array[0]
            vector = list(map(float, array[1:]))
            # A set collapses variants that coincide (e.g. digits or
            # punctuation), so each vocabulary row is written once per line.
            for variant in {word, word.capitalize(), word.upper()}:
                idx = word_to_idx.get(variant)
                if idx is not None:
                    glove[idx, :] = vector
                    matched_rows.add(idx)

    found = len(matched_rows)
    print("{}/{} of word vocab have corresponding vectors in {}".format(found, len(vocab_list), glove_path))
    np.savez_compressed(save_path, glove=glove)
    print("saved trimmed glove matrix at: {}".format(save_path))
# (Web-page residue from the scraped source: "comment list" / "table of contents"; not part of the code.)