archival_test.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:allennlp 作者: allenai 项目源码 文件源码
def test_archiving(self):
        """Train a model, reload it from its archive, and compare the two.

        The reloaded model must expose the same weights and vocabulary
        mappings as the freshly trained one, and the archived config must
        equal the params that training started from.
        """
        # Training consumes ``self.params``, so snapshot them up front.
        expected_config = copy.deepcopy(self.params.as_dict())

        # ``train_model`` is expected to write ``model.tar.gz`` into the
        # serialization directory as a side effect.
        trained = train_model(self.params, serialization_dir=self.TEST_DIR)

        # Round-trip through the archive on disk.
        archive = load_archive(os.path.join(self.TEST_DIR, "model.tar.gz"))
        restored = archive.model

        # Weights: both models expose the same entry names ...
        trained_state = trained.state_dict()
        restored_state = restored.state_dict()
        assert set(trained_state.keys()) == set(restored_state.keys())

        # ... and every entry holds an identical tensor.
        for name, tensor in trained_state.items():
            assert torch.equal(tensor, restored_state[name])

        # Vocabularies: both lookup directions must match.
        trained_vocab = trained.vocab
        restored_vocab = restored.vocab
        assert trained_vocab._token_to_index == restored_vocab._token_to_index  # pylint: disable=protected-access
        assert trained_vocab._index_to_token == restored_vocab._index_to_token  # pylint: disable=protected-access

        # Config: the archive must carry the original, unconsumed params.
        assert archive.config.as_dict() == expected_config
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号