def test_read_hdf5_raises_on_invalid_shape(self):
    """Loading an HDF5 embedding matrix whose width disagrees with the config fails.

    The file written below stores a matrix with ``stored_dim`` (10) columns,
    while the params request ``embedding_dim`` = ``requested_dim`` (5).
    ``Embedding.from_params`` is expected to raise ``ConfigurationError``
    for that mismatch.
    """
    import os

    vocab = Vocabulary()
    vocab.add_token_to_namespace("word")
    # os.path.join works whether TEST_DIR is a str (with or without a trailing
    # separator) or a pathlib.Path; plain "+" concatenation only works for a
    # str that already ends in a separator.
    embeddings_filename = os.path.join(self.TEST_DIR, "embeddings.hdf5")
    stored_dim = 10
    requested_dim = 5
    # Row count matches the vocabulary size; only the column count is wrong.
    embeddings = numpy.random.rand(vocab.get_vocab_size(), stored_dim)
    with h5py.File(embeddings_filename, 'w') as fout:
        fout.create_dataset(
            'embedding', embeddings.shape, dtype='float32', data=embeddings
        )
    params = Params({
        'pretrained_file': embeddings_filename,
        'embedding_dim': requested_dim,
    })
    with pytest.raises(ConfigurationError):
        Embedding.from_params(vocab, params)
评论列表
文章目录