def __init__(self, alpha=0.9, graph_path='', checkpoint_path='', metagraph_path=''):
    """Load a TensorFlow model and look up its input/output tensors.

    Two loading modes are supported. ``graph_path`` wins when it is set:

    * ``graph_path``: a frozen ``GraphDef`` (binary protobuf) whose weights
      are stored as constants. ``checkpoint_path`` and ``metagraph_path``
      are ignored in this mode.
    * ``checkpoint_path`` + ``metagraph_path``: a meta-graph file plus the
      checkpoint its variables are restored from. ``checkpoint_path`` may
      be a single checkpoint file or a V2 checkpoint prefix
      (e.g. ``model.ckpt-1000``) for which ``<prefix>.index`` exists.

    Either way, the graph is loaded into a private ``tf.Graph`` stored as
    ``self.graph``. The graph is run through ``self.session``, which is
    created with ``allow_soft_placement=True``. The tensors named
    ``input_placeholder:0`` and ``output_steer:0`` are stored as
    ``self.model_input`` and ``self.model_output``.

    Args:
        alpha: Stored as ``self.alpha``. NOTE(review): presumably a
            smoothing factor for the steering output -- confirm with the
            code that reads it.
        graph_path: Path to a frozen graph file.
        checkpoint_path: Checkpoint file or V2 checkpoint prefix.
        metagraph_path: Path to the meta-graph (``.meta``) file.

    Raises:
        AssertionError: If a required model file does not exist.
    """
    # Validate with explicit raises instead of ``assert`` so the checks
    # survive ``python -O``. AssertionError is kept so callers see the same
    # exception type as before.
    if graph_path:
        if not os.path.isfile(graph_path):
            raise AssertionError(
                "frozen graph not found: {!r}".format(graph_path))
    else:
        if not os.path.isfile(metagraph_path):
            raise AssertionError(
                "meta graph not found: {!r}".format(metagraph_path))
        # A V2 checkpoint (TensorFlow's default format) is a path prefix,
        # not a file. Only "<prefix>.index" and the data shards exist on
        # disk, so accept that as well as a plain checkpoint file.
        if not checkpoint_path or not (
                os.path.isfile(checkpoint_path)
                or os.path.isfile(checkpoint_path + ".index")):
            raise AssertionError(
                "checkpoint not found: {!r}".format(checkpoint_path))

    self.graph = tf.Graph()
    session_config = tf.ConfigProto(allow_soft_placement=True)
    with self.graph.as_default():
        if graph_path:
            # Frozen graph: weights are constants, so nothing is restored.
            graph_def = tf.GraphDef()
            with open(graph_path, "rb") as f:
                graph_def.ParseFromString(f.read())
            # name="" imports nodes without a name-scope prefix, matching
            # the unprefixed tensor names looked up below.
            tf.import_graph_def(graph_def, name="")
            self.session = tf.Session(graph=self.graph, config=session_config)
        else:
            # Meta-graph: rebuild the graph structure, then initialize the
            # variables from the checkpoint.
            saver = tf.train.import_meta_graph(metagraph_path)
            self.session = tf.Session(graph=self.graph, config=session_config)
            saver.restore(self.session, checkpoint_path)
        self.model_input = self.graph.get_tensor_by_name("input_placeholder:0")
        self.model_output = self.graph.get_tensor_by_name("output_steer:0")
    # Initial "previous steering" value. NOTE(review): presumably updated by
    # the prediction method together with alpha -- confirm.
    self.last_steering_angle = 0
    self.alpha = alpha
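# [web page residue, not code] "评论列表" (Comment list)
# [web page residue, not code] "文章目录" (Table of contents)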