def test_archiving(self):
    """End-to-end check that ``train_model`` produces an archive which
    round-trips the model weights, vocabulary, and training config."""
    # Copy the params first: `train_model` consumes them during training.
    params_copy = copy.deepcopy(self.params.as_dict())
    # `train_model` should create an archive
    model = train_model(self.params, serialization_dir=self.TEST_DIR)
    archive_path = os.path.join(self.TEST_DIR, "model.tar.gz")
    # load the model back from the archive
    archive = load_archive(archive_path)
    model2 = archive.model
    # Hoist the state dicts out of the loop: each `state_dict()` call
    # rebuilds the full parameter mapping, so grab each one exactly once.
    state = model.state_dict()
    state2 = model2.state_dict()
    # check that model weights are the same
    assert set(state.keys()) == set(state2.keys())
    for key in state:
        assert torch.equal(state[key], state2[key])
    # check that vocabularies are the same
    vocab = model.vocab
    vocab2 = model2.vocab
    assert vocab._token_to_index == vocab2._token_to_index  # pylint: disable=protected-access
    assert vocab._index_to_token == vocab2._index_to_token  # pylint: disable=protected-access
    # check that the archived config matches the pre-training params
    params2 = archive.config
    assert params2.as_dict() == params_copy