def save_checkpoint(self, checkpoint_name):
    """Save the default session's variables plus model metadata.

    Before saving, the graph collections "threshold", "features", "loss"
    and "prediction" are overwritten with this model's current values, so
    they are recorded in the meta graph that ``tf.train.Saver.save`` writes
    alongside the checkpoint (presumably so a loader can look these tensors
    up by collection name -- confirm against the restore code).
    ``repr(self.hparams.__dict__)`` is written to ``hparams.txt`` in the
    checkpoint's directory.

    Args:
        checkpoint_name: Path prefix for the checkpoint files. May be a bare
            file name, in which case all files go in the current directory.
    """
    # Slice-assign so each collection is replaced rather than appended to;
    # repeated saves therefore don't accumulate stale entries.
    tf.get_collection_ref("threshold")[:] = [float(self.threshold)]
    tf.get_collection_ref("features")[:] = list(self.features.values())
    tf.get_collection_ref("loss")[:] = [self.loss]
    tf.get_collection_ref("prediction")[:] = [self.prediction]

    checkpoint_dir = os.path.dirname(checkpoint_name)
    # os.makedirs("") raises FileNotFoundError; an empty dirname means the
    # current directory, which already exists.
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    # A Saver built with no var_list covers every saveable object in the
    # graph at construction time, so it is created fresh on each save.
    saver = tf.train.Saver()
    # NOTE(review): tf.get_default_session() returns None unless the caller
    # has installed a default session (e.g. `with sess.as_default():`).
    saver.save(tf.get_default_session(), checkpoint_name)

    hparams_path = os.path.join(checkpoint_dir, "hparams.txt")
    # Explicit UTF-8: repr() keeps non-ASCII characters, which the platform
    # default encoding may not be able to represent.
    with open(hparams_path, "w", encoding="utf-8") as f:
        f.write(repr(self.hparams.__dict__))
评论列表
文章目录