embeddings.py source code


Project: keras-text, author: raghakot
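```python
import os

from keras.utils import get_file

# _EMBEDDINGS_CACHE (dict: embedding name -> loaded index), _EMBEDDING_TYPES
# (dict: embedding name -> {'url': ..., 'file': ...}) and _build_embeddings_index
# are defined elsewhere in embeddings.py and are not shown in this excerpt.


def get_embeddings_index(embedding_type='glove.42B.300d'):
    """Retrieves the embeddings index for the given embedding name, downloading and caching it as needed.

    Args:
        embedding_type: The embedding type to load (a key of `_EMBEDDING_TYPES`).

    Returns:
        The embedding vectors, indexed by word.
    """
    # Fast path: this embedding type was already loaded in this process.
    embeddings_index = _EMBEDDINGS_CACHE.get(embedding_type)
    if embeddings_index is not None:
        return embeddings_index

    data_obj = _EMBEDDING_TYPES.get(embedding_type)
    if data_obj is None:
        raise ValueError("Embedding name should be one of {}".format(list(_EMBEDDING_TYPES.keys())))

    # Downloads are stored under ~/.keras-text/embeddings.
    cache_dir = os.path.expanduser(os.path.join('~', '.keras-text'))
    os.makedirs(cache_dir, exist_ok=True)

    # get_file downloads the archive only if it is not already on disk, then extracts it.
    file_path = get_file(embedding_type, origin=data_obj['url'], extract=True,
                         cache_dir=cache_dir, cache_subdir='embeddings')
    # The extracted text file sits next to the downloaded archive.
    file_path = os.path.join(os.path.dirname(file_path), data_obj['file'])

    # Parse the file once and memoize the result for later calls.
    embeddings_index = _build_embeddings_index(file_path)
    _EMBEDDINGS_CACHE[embedding_type] = embeddings_index
    return embeddings_index
```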
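The function hands the actual parsing to `_build_embeddings_index`, whose body is not part of this excerpt. As a rough sketch, assuming the extracted file uses the plain-text GloVe layout (one word per line followed by its whitespace-separated vector components), the helper and the cache-first lookup could look like the following. `build_embeddings_index`, `get_index` and `_CACHE` are hypothetical names, and vectors are kept as plain Python lists rather than NumPy arrays so the example has no dependencies.

```python
import os
import tempfile
from typing import Dict, List

# Hypothetical in-memory cache playing the role of _EMBEDDINGS_CACHE above.
_CACHE: Dict[str, Dict[str, List[float]]] = {}


def build_embeddings_index(file_path: str) -> Dict[str, List[float]]:
    """Parse a GloVe-style text file into a dict mapping word -> vector (assumed format)."""
    index: Dict[str, List[float]] = {}
    with open(file_path, encoding="utf-8") as f:
        for line in f:
            parts = line.rstrip().split()
            if len(parts) < 2:
                continue  # skip blank or malformed lines
            word, values = parts[0], parts[1:]
            index[word] = [float(v) for v in values]
    return index


def get_index(name: str, file_path: str) -> Dict[str, List[float]]:
    """Return the cached index for `name`, parsing `file_path` only on the first call."""
    cached = _CACHE.get(name)
    if cached is not None:
        return cached
    index = build_embeddings_index(file_path)
    _CACHE[name] = index
    return index


# Small demo with a two-word embedding file written to a temporary location.
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False, encoding="utf-8") as tmp:
    tmp.write("the 0.1 0.2 0.3\ncat 0.4 0.5 0.6\n")
    demo_path = tmp.name

first = get_index("demo", demo_path)
second = get_index("demo", demo_path)
print(first["cat"])     # [0.4, 0.5, 0.6]
print(first is second)  # True: the second call is served from _CACHE
os.remove(demo_path)
```

Because the cache is an ordinary module-level dict, a loaded index lives only for the lifetime of the process; across runs, it is the on-disk copy managed by `get_file` that avoids downloading the archive again.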