def test_read_hdf5_format_file(self):
    """Check that ``Embedding.from_params`` loads pretrained weights from HDF5.

    A random matrix with one row per vocabulary entry is written to
    ``embeddings.hdf5`` inside ``self.TEST_DIR``, under the dataset name
    ``'embedding'``. The layer built with ``pretrained_file`` pointing at
    that file must reproduce the matrix (same shape, values equal within
    float32 precision).
    """
    import os

    embedding_dim = 5

    vocab = Vocabulary()
    vocab.add_token_to_namespace("word")
    vocab.add_token_to_namespace("word2")

    # os.path.join is correct whether or not TEST_DIR ends with a separator,
    # and also accepts a pathlib.Path, unlike plain string concatenation.
    embeddings_filename = os.path.join(self.TEST_DIR, "embeddings.hdf5")

    # One row per vocabulary entry. get_vocab_size() presumably also counts
    # any padding/OOV entries Vocabulary adds by default -- TODO confirm.
    embeddings = numpy.random.rand(vocab.get_vocab_size(), embedding_dim)
    with h5py.File(embeddings_filename, 'w') as fout:
        # 'embedding' is the dataset name the loader is expected to read.
        fout.create_dataset(
            'embedding', embeddings.shape, dtype='float32', data=embeddings
        )

    params = Params({
        'pretrained_file': embeddings_filename,
        'embedding_dim': embedding_dim,
    })
    embedding_layer = Embedding.from_params(vocab, params)

    loaded = embedding_layer.weight.data.numpy()
    # numpy.allclose broadcasts its arguments, so check the shape explicitly;
    # otherwise a wrongly-shaped but broadcast-compatible matrix could pass.
    assert loaded.shape == embeddings.shape
    # The file stores float32 while the reference matrix is float64, hence a
    # tolerance-based comparison rather than exact equality.
    assert numpy.allclose(loaded, embeddings)
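# NOTE(review): stray web-page residue, translated from Chinese; not part of the test.
# "Comment list" / "Table of contents"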