model_test.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:cxflow-tensorflow 作者: Cognexa 项目源码 文件源码
def test_restore_1(self):
        """Test restore from directory with one valid checkpoint.

        Trains a ``TrainableModel`` for 1000 steps on a fixed batch and records
        the value of its ``var`` variable. It then saves the model under the
        name ``'1'`` (the model was created with ``log_dir=self.tmpdir``).
        Next, a ``BaseModel`` is restored from ``self.tmpdir``. The restored
        ``var:0`` tensor must match the recorded value.
        """

        # test model saving; the training loop is presumably there so that
        # `var` differs from its initial value -- TODO confirm against _IO setup
        trainable_model = TrainableModel(dataset=None, log_dir=self.tmpdir, **_IO, optimizer=_OPTIMIZER)
        batch = {'input': [[1] * 10], 'target': [[0] * 10]}
        for _ in range(1000):
            trainable_model.run(batch, train=True)
        saved_var_value = trainable_model.var.eval(session=trainable_model.session)
        trainable_model.save('1')

        # test restoring
        restored_model = BaseModel(dataset=None, log_dir='', restore_from=self.tmpdir, **_IO, optimizer=_OPTIMIZER)

        var = restored_model.graph.get_tensor_by_name('var:0')
        var_value = var.eval(session=restored_model.session)
        # Use assert_allclose instead of assertTrue(np.allclose(...)) so a
        # failure reports the mismatching values. rtol/atol/equal_nan mirror the
        # np.allclose defaults and the argument order is kept, so the accepted
        # tolerance is exactly the same as before.
        np.testing.assert_allclose(saved_var_value, var_value, rtol=1e-05, atol=1e-08, equal_nan=False)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号