import pickle
from functools import reduce

import numpy as np
from gensim.models import KeyedVectors


def load_embedding(data, embedding_file, binary=True, prefix=None, file_name='embedding.pkl'):
    """Build or load a word-embedding matrix for the vocabulary of `data`.

    :param data: iterable of (sentence, label) pairs; each sentence is a sequence of tokens
    :param embedding_file: path to a pretrained word2vec file
    :param binary: whether embedding_file is in binary word2vec format
    :param prefix: if None, build the embedding from embedding_file and cache it to file_name;
                   otherwise load the cached pickle at the path given by prefix
    :param file_name: path of the pickle cache written when prefix is None
    :return: (vocab_size, word_idx, embedding)
    """
    if prefix is None:
        # Collect the vocabulary as the union of all tokens in the data set.
        vocab = sorted(reduce(lambda x, y: x | y, (set(sentence) for sentence, _ in data)))
        # Start indices at 1 so that index 0 is reserved for the nil (padding) word.
        word_idx = dict((c, i + 1) for i, c in enumerate(vocab))
        vocab_size = len(word_idx) + 1  # +1 for the nil word
        # e.g. "/home/junfeng/word2vec/GoogleNews-vectors-negative300.bin"
        # KeyedVectors.load_word2vec_format replaces the deprecated
        # word2vec.Word2Vec.load_word2vec_format in recent gensim releases.
        model = KeyedVectors.load_word2vec_format(embedding_file, binary=binary)
        dim = model.vector_size
        embedding = [np.zeros(dim)]  # row 0: nil word
        for c in vocab:  # iterate in index order so rows line up with word_idx
            if c in model:
                embedding.append(model[c])
            else:
                # Random init for out-of-vocabulary words; the original
                # uniform(0.1, 0.1, 300) degenerates to a constant vector.
                embedding.append(np.random.uniform(-0.1, 0.1, dim))
        embedding = np.array(embedding, dtype=np.float32)
        with open(file_name, 'wb') as f:
            pickle.dump(embedding, f)
            pickle.dump(vocab_size, f)
            pickle.dump(word_idx, f)
    else:
        # Load the cached embedding, vocab size, and word index written above,
        # in the same order they were dumped.
        with open(prefix, 'rb') as f:
            embedding = pickle.load(f)
            vocab_size = pickle.load(f)
            word_idx = pickle.load(f)
return vocab_size, word_idx, embedding
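

# A minimal usage sketch, not from the original post: the toy data and file
# paths below are assumptions for illustration. First call builds the matrix
# from the pretrained vectors and caches it; the second reloads the cache.
if __name__ == '__main__':
    toy_data = [(['hello', 'world'], 0), (['hello', 'there'], 1)]  # hypothetical (sentence, label) pairs
    # First run: prefix=None reads the word2vec file and writes embedding.pkl.
    vocab_size, word_idx, embedding = load_embedding(
        toy_data, 'GoogleNews-vectors-negative300.bin', binary=True)
    # Later runs: point prefix at the cache to skip loading word2vec entirely.
    vocab_size, word_idx, embedding = load_embedding(
        toy_data, 'GoogleNews-vectors-negative300.bin', prefix='embedding.pkl')
    print(vocab_size, embedding.shape)  # embedding has vocab_size rows; row 0 is the nil word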