def test_restore_and_train(self):
    """Test model training after restoring.

    An untrained ``TrainableModel`` (created with ``log_dir=self.tmpdir``) is
    saved, then re-created as a ``BaseModel`` via ``restore_from=self.tmpdir``.
    The restored model is run with ``train=True`` for 1000 steps on a fixed
    batch whose target is all zeros. The test passes when the ``var:0`` tensor
    of the restored graph is numerically close to all zeros afterwards.
    """
    # Save a model that is not trained.
    trainable_model = TrainableModel(dataset=None, log_dir=self.tmpdir, **_IO, optimizer=_OPTIMIZER)
    trainable_model.save('')

    # Restore the model from the directory the trainable model was created with.
    # NOTE(review): assumes save('') writes its checkpoint into log_dir - confirm.
    restored_model = BaseModel(dataset=None, log_dir='', restore_from=self.tmpdir, **_IO)

    # Test whether the restored model can be trained: run it repeatedly with
    # train=True on a single batch whose target is all zeros.
    batch = {'input': [[1] * 10], 'target': [[0] * 10]}
    for _ in range(1000):
        restored_model.run(batch, train=True)

    var_tensor = restored_model.graph.get_tensor_by_name('var:0')
    after_value = var_tensor.eval(session=restored_model.session)
    # Same check as before (np.allclose with its default tolerances), but on
    # failure report the actual value instead of just "False is not true".
    self.assertTrue(np.allclose([0] * 10, after_value),
                    msg='var:0 did not converge to zeros after training: {}'.format(after_value))
评论列表
文章目录