def _read_pretrained_hdf5_format_embedding_file(embeddings_filename: str, # pylint: disable=invalid-name
                                                embedding_dim: int,
                                                vocab: Vocabulary,
                                                namespace: str = "tokens") -> torch.FloatTensor:
    """
    Load a pretrained embedding matrix from an hdf5 formatted file.

    The matrix is looked up under the key 'embedding' and must have shape
    ``(vocab.get_vocab_size(namespace), embedding_dim)``.

    Parameters
    ----------
    embeddings_filename : ``str``
        Path of the hdf5 file to open (read-only).
    embedding_dim : ``int``
        Expected size of the second dimension of the stored matrix.
    vocab : ``Vocabulary``
        Vocabulary whose size in ``namespace`` gives the expected number of rows.
    namespace : ``str``, optional (default = "tokens")
        Vocabulary namespace used to look up the expected number of rows.

    Returns
    -------
    A ``torch.FloatTensor`` built from the stored matrix.

    Raises
    ------
    ConfigurationError
        If the shape of the stored matrix differs from the expected shape.
    """
    with h5py.File(embeddings_filename, 'r') as hdf5_file:
        # Index with ``[...]`` inside the ``with`` block so the values are
        # fetched while the file handle is still open.
        matrix = hdf5_file['embedding'][...]

    actual_shape = list(matrix.shape)
    expected_shape = [vocab.get_vocab_size(namespace), embedding_dim]

    if actual_shape == expected_shape:
        return torch.FloatTensor(matrix)

    raise ConfigurationError(
            "Read shape {0} embeddings from the file, but expected {1}".format(
                    actual_shape, expected_shape))
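# Comment list        (navigation label left over from the source web page)
# Table of contents   (navigation label left over from the source web page)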