def test_restore_1(self):
    """Test restore from directory with one valid checkpoint.

    A ``TrainableModel`` is run for 1000 training steps on a fixed batch, the
    value of its ``var`` variable is recorded, and the model is saved to
    ``self.tmpdir``. A ``BaseModel`` is then constructed with
    ``restore_from=self.tmpdir`` and the tensor named ``var:0`` in its graph
    must be numerically close to the recorded value.
    """
    # test model saving
    trainable_model = TrainableModel(dataset=None, log_dir=self.tmpdir, **_IO, optimizer=_OPTIMIZER)
    batch = {'input': [[1] * 10], 'target': [[0] * 10]}
    # NOTE(review): presumably the training steps move ``var`` away from its
    # initial value so that a freshly initialized model would not pass the
    # check below -- confirm against TrainableModel.
    for _ in range(1000):
        trainable_model.run(batch, train=True)
    # Snapshot the variable before saving; this is the reference value.
    saved_var_value = trainable_model.var.eval(session=trainable_model.session)
    trainable_model.save('1')

    # test restoring
    restored_model = BaseModel(dataset=None, log_dir='', restore_from=self.tmpdir, **_IO, optimizer=_OPTIMIZER)
    # Look the variable up by name in the restored graph rather than via an
    # attribute of the restored model object.
    var = restored_model.graph.get_tensor_by_name('var:0')
    var_value = var.eval(session=restored_model.session)
    # Include both values in the failure message; a bare assertTrue would
    # only report "False is not true".
    self.assertTrue(
        np.allclose(saved_var_value, var_value),
        msg='restored var {!r} differs from saved var {!r}'.format(var_value, saved_var_value))
评论列表
文章目录