classifier_tf.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:human-rl 作者: gsastry 项目源码 文件源码
def save_checkpoint(self, checkpoint_name):
        tf.get_collection_ref("threshold")[:] = [float(self.threshold)]
        tf.get_collection_ref("features")[:] = self.features.values()
        tf.get_collection_ref("loss")[:] = [self.loss]
        tf.get_collection_ref("prediction")[:] = [self.prediction]

        os.makedirs(os.path.dirname(checkpoint_name), exist_ok=True)
        saver = tf.train.Saver()
        saver.save(tf.get_default_session(), checkpoint_name)

        with open(os.path.join(os.path.dirname(checkpoint_name), "hparams.txt"), "w") as f:
            f.write(repr(self.hparams.__dict__))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号