def save(self):
    """Write a checkpoint and a JSON copy of ``self.config`` for the current step.

    The global step is fetched by running ``tf.train.get_global_step(self.graph)``
    in ``self.sess``. If ``self.config['last_checkpoint']`` already equals that
    step, nothing is written; a notice is printed only when
    ``self.config['debug']`` is truthy. Otherwise:

    * ``self.saver.save`` is called with ``<results_dir>/checkpoint`` as the
      save path and the step as its third argument,
    * ``self.config['last_checkpoint']`` is set to the step, and
    * the whole config is dumped to ``<results_dir>/config.json`` using
      ``utilities.NumPyCompatibleJSONEncoder``.

    Returns:
        None.
    """
    step = self.sess.run(tf.train.get_global_step(self.graph))

    # Guard: skip duplicate saves within the same global step.
    already_saved = self.config['last_checkpoint'] == step
    if already_saved:
        if self.config['debug']:
            print('Model has already been saved during the current global step.')
        return

    out_dir = self.config['results_dir']
    print('Saving to %s with global_step %d.' % (out_dir, step))
    checkpoint_prefix = os.path.join(out_dir, 'checkpoint')
    self.saver.save(self.sess, checkpoint_prefix, step)

    # Update the marker before dumping, so the written config.json
    # records the step of the checkpoint that was just saved.
    self.config['last_checkpoint'] = step

    # Persist the configuration alongside the checkpoint files.
    config_path = os.path.join(out_dir, 'config.json')
    with open(config_path, 'w') as handle:
        json.dump(self.config, handle, cls=utilities.NumPyCompatibleJSONEncoder)
评论列表
文章目录