def get_embeddings_index(embedding_type='glove.42B.300d'):
    """Retrieves embeddings index from embedding name. Will automatically download and cache as needed.

    The built index is memoized in the module-level `_EMBEDDINGS_CACHE`, so a
    repeated call with the same `embedding_type` returns the cached object
    without touching the filesystem again.

    Args:
        embedding_type: The embedding type to load. Looked up in `_EMBEDDING_TYPES`.

    Returns:
        The embeddings indexed by word.

    Raises:
        ValueError: If `embedding_type` is not already cached and has no entry
            in `_EMBEDDING_TYPES`.
    """
    embeddings_index = _EMBEDDINGS_CACHE.get(embedding_type)
    if embeddings_index is not None:
        return embeddings_index

    data_obj = _EMBEDDING_TYPES.get(embedding_type)
    if data_obj is None:
        # Materialize the keys so the message lists the valid names instead of
        # a `dict_keys(...)` wrapper on Python 3.
        raise ValueError("Embedding name should be one of '{}'".format(list(_EMBEDDING_TYPES.keys())))

    cache_dir = os.path.expanduser(os.path.join('~', '.keras-text'))
    # Create first and tolerate "already exists": a separate exists() check can
    # race with another process creating the same directory.
    try:
        os.makedirs(cache_dir)
    except OSError:
        if not os.path.isdir(cache_dir):
            raise

    # NOTE(review): `get_file` is presumably Keras' download helper, which fetches
    # `origin` into cache_dir/cache_subdir and extracts the archive -- confirm
    # against the import at the top of the file.
    downloaded_path = get_file(embedding_type, origin=data_obj['url'], extract=True,
                               cache_dir=cache_dir, cache_subdir='embeddings')

    # Only the directory of the downloaded path is used; the embeddings file to
    # parse is the one named by data_obj['file'] inside that directory.
    file_path = os.path.join(os.path.dirname(downloaded_path), data_obj['file'])

    embeddings_index = _build_embeddings_index(file_path)
    _EMBEDDINGS_CACHE[embedding_type] = embeddings_index
    return embeddings_index
评论列表
文章目录