embedding.py source code

python

Project: allennlp  Author: allenai
import h5py
import torch

from allennlp.common.checks import ConfigurationError
from allennlp.data import Vocabulary


def _read_pretrained_hdf5_format_embedding_file(embeddings_filename: str,  # pylint: disable=invalid-name
                                                embedding_dim: int,
                                                vocab: Vocabulary,
                                                namespace: str = "tokens") -> torch.FloatTensor:
    """
    Reads from an HDF5-formatted file. The embedding matrix is assumed to be
    stored under the key 'embedding' with shape ``(num_tokens, embedding_dim)``,
    where ``num_tokens`` is the vocabulary size of ``namespace``.
    """
    # Load the entire 'embedding' dataset into memory as a NumPy array.
    with h5py.File(embeddings_filename, 'r') as fin:
        embeddings = fin['embedding'][...]

    # The matrix must have exactly one row per token in the namespace.
    expected_shape = [vocab.get_vocab_size(namespace), embedding_dim]
    if list(embeddings.shape) != expected_shape:
        raise ConfigurationError(
                "Read shape {0} embeddings from the file, but expected {1}".format(
                        list(embeddings.shape), expected_shape))

    return torch.FloatTensor(embeddings)
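To see what file layout this reader accepts, the following is a minimal, self-contained sketch that writes a toy matrix under the 'embedding' key and reads it back with the same shape check. It assumes only h5py and NumPy: a plain `vocab_size` integer stands in for `vocab.get_vocab_size(namespace)`, the hypothetical `ShapeMismatchError` stands in for allennlp's `ConfigurationError`, and the result is returned as a float32 NumPy array rather than a `torch.FloatTensor` (`torch.from_numpy` could bridge the two). Because the reader does no token lookup, row i is presumably the vector for vocabulary index i.

```python
import os
import tempfile

import h5py
import numpy as np


class ShapeMismatchError(ValueError):
    """Hypothetical stand-in for allennlp's ConfigurationError."""


def write_hdf5_embeddings(path: str, matrix: np.ndarray) -> None:
    # Store the matrix under the 'embedding' key, which the reader expects.
    with h5py.File(path, "w") as fout:
        fout.create_dataset("embedding", data=matrix)


def read_hdf5_embeddings(path: str, vocab_size: int, embedding_dim: int) -> np.ndarray:
    # Read the whole dataset into memory, mirroring fin['embedding'][...].
    with h5py.File(path, "r") as fin:
        embeddings = fin["embedding"][...]
    expected = [vocab_size, embedding_dim]
    if list(embeddings.shape) != expected:
        raise ShapeMismatchError(
            "Read shape {0} embeddings from the file, but expected {1}".format(
                list(embeddings.shape), expected))
    # float32 matches the element type of torch.FloatTensor.
    return embeddings.astype(np.float32)


# Demo: 4 tokens with 3-dimensional vectors, written to a temporary file.
tmp_dir = tempfile.mkdtemp()
path = os.path.join(tmp_dir, "toy_embeddings.hdf5")
toy = np.arange(12, dtype=np.float64).reshape(4, 3)
write_hdf5_embeddings(path, toy)

loaded = read_hdf5_embeddings(path, vocab_size=4, embedding_dim=3)
print(loaded.shape, loaded.dtype)  # (4, 3) float32

# A vocabulary of a different size is rejected by the shape check.
mismatch_message = None
try:
    read_hdf5_embeddings(path, vocab_size=5, embedding_dim=3)
except ShapeMismatchError as err:
    mismatch_message = str(err)
print(mismatch_message)  # Read shape [4, 3] embeddings from the file, but expected [5, 3]
```

Checking the shape eagerly, rather than letting a later matrix operation fail, surfaces a vocabulary/file mismatch at load time with a message that names both shapes.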