import os
import pickle

import numpy as np


def process_glove(glove_dir, glove_dim, vocab_dir, save_path, random_init=True):
"""
:param vocab_list: [vocab]
:return:
"""
    save_path = save_path + '.{}'.format(glove_dim)
    if not os.path.isfile(save_path + ".npz"):
        # read the vocabulary; each entry's first element is the word
        with open(os.path.join(vocab_dir, 'vocabulary.pickle'), 'rb') as f:
            vocab_map = pickle.load(f)
        vocab_list = [pair[0] for pair in vocab_map]
        # word -> row index, for O(1) lookups instead of repeated list.index()
        vocab_index = {word: idx for idx, word in enumerate(vocab_list)}
        glove_path = os.path.join(glove_dir, "glove.6B.{}d.txt".format(glove_dim))
        if random_init:
            glove = np.random.uniform(-0.25, 0.25, (len(vocab_list), glove_dim))
        else:
            glove = np.zeros((len(vocab_list), glove_dim))
        found = 0
        with open(glove_path, 'r', encoding='utf-8') as fh:
            for line in fh:
                array = line.rstrip().split(" ")
                word = array[0]
                vector = list(map(float, array[1:]))
                # also try capitalized and upper-cased variants of the GloVe token
                for variant in (word, word.capitalize(), word.upper()):
                    if variant in vocab_index:
                        glove[vocab_index[variant], :] = vector
                        found += 1
print("{}/{} of word vocab have corresponding vectors in {}".format(found, len(vocab_list), glove_path))
np.savez_compressed(save_path, glove=glove)
print("saved trimmed glove matrix at: {}".format(save_path))