def __init__(self, checkpoint_file):
    """Restore a model from a TensorFlow checkpoint into a CPU-only session.

    Hyperparameters are read from an optional ``hparams.txt`` file located in
    the same directory as the checkpoint. The file is expected to contain a
    Python dict literal of keyword arguments for ``TensorflowClassifierHparams``.
    If the file is absent, ``TensorflowClassifierHparams()`` is built with no
    arguments. The meta-graph is imported into a private ``tf.Graph`` and the
    variables are restored into a new session bound to that graph.

    Args:
        checkpoint_file: Checkpoint prefix path. ``checkpoint_file + ".meta"``
            must exist, and the prefix itself is passed to ``Saver.restore``.

    Attributes set:
        hparams: the ``TensorflowClassifierHparams`` instance.
        graph: the ``tf.Graph`` holding the imported model.
        session: a ``tf.Session`` on ``graph`` with the weights restored.
        features: dict mapping "image" / "observation" / "action" to the
            graph tensors "image:0" / "observation:0" / "action:0", for each
            input whose ``use_*`` flag is set in ``hparams``.
        prediction, loss, threshold: the first element of the graph
            collection of the same name.

    Raises:
        Any error from ``ast.literal_eval``, ``import_meta_graph`` or
        ``Saver.restore``. If import/restore fails, the session is closed
        before the error propagates. ``IndexError`` if one of the
        collections is empty.
    """
    checkpoint_dir = os.path.dirname(checkpoint_file)
    hparams_file = os.path.join(checkpoint_dir, "hparams.txt")
    hparams_dict = {}
    if os.path.isfile(hparams_file):
        with open(hparams_file) as f:
            # literal_eval (not eval): only Python literals are accepted, so
            # the file contents cannot execute arbitrary code.
            hparams_dict = ast.literal_eval(f.read())
    self.hparams = TensorflowClassifierHparams(**hparams_dict)

    self.graph = tf.Graph()
    with self.graph.as_default():
        print("loading from file {}".format(checkpoint_file))
        # Hide every GPU so that inference always runs on the CPU.
        config = tf.ConfigProto(device_count={'GPU': 0})
        config.gpu_options.visible_device_list = ""
        self.session = tf.Session(graph=self.graph, config=config)
        try:
            # clear_devices=True drops the device placements stored in the
            # meta-graph, so it can be loaded into this CPU-only session.
            new_saver = tf.train.import_meta_graph(
                checkpoint_file + ".meta", clear_devices=True)
            new_saver.restore(self.session, checkpoint_file)
        except Exception:
            # Release the session's resources instead of leaking them when
            # the checkpoint cannot be imported or restored.
            self.session.close()
            raise

    # Look everything up on self.graph explicitly rather than through the
    # process-wide default graph, so correctness does not depend on being
    # nested inside the `with self.graph.as_default()` block above.
    self.features = {}
    if self.hparams.use_image:
        self.features["image"] = self.graph.get_tensor_by_name("image:0")
    if self.hparams.use_observation:
        self.features["observation"] = self.graph.get_tensor_by_name("observation:0")
    if self.hparams.use_action:
        self.features["action"] = self.graph.get_tensor_by_name("action:0")
    self.prediction = self.graph.get_collection('prediction')[0]
    self.loss = self.graph.get_collection('loss')[0]
    self.threshold = self.graph.get_collection('threshold')[0]
评论列表
文章目录