def __init__(self,
             screen_width,
             screen_height,
             num_channels,
             num_actions,
             metrics_directory,
             batched_forward_pass_size,
             hyperparameters=None):
    """Build the TensorFlow graph, session and helpers for this network.

    All ops are created in a private ``tf.Graph`` owned by the instance:
    two forward-pass bundles (one built with size 1, one with
    ``batched_forward_pass_size``) and one training bundle. A session
    bound to that graph is then opened, the summary writer and saver are
    created, and all variables are initialized.

    Args:
        screen_width: Stored as ``self.screen_width``
            (presumably the input frame width -- confirm with the
            graph builders).
        screen_height: Stored as ``self.screen_height``
            (presumably the input frame height -- confirm).
        num_channels: Stored as ``self.num_channels``.
        num_actions: Stored as ``self.num_actions``.
        metrics_directory: Log directory handed to the summary writer.
        batched_forward_pass_size: Size passed to the builder of the
            batched forward-pass bundle; also stored on the instance.
        hyperparameters: Hyperparameter object to use. When ``None``
            (the default) a fresh ``QNetworkHyperparameters()`` is
            created for this instance.
    """
    # A default of ``QNetworkHyperparameters()`` in the signature would be
    # evaluated once, at definition time, and shared by every instance.
    # Creating it here gives each network its own object.
    if hyperparameters is None:
        hyperparameters = QNetworkHyperparameters()

    self.logger = logging.getLogger(__name__)

    # Plain configuration; assigned before the graph builders run, as in
    # the original statement order.
    self.screen_width = screen_width
    self.screen_height = screen_height
    self.num_channels = num_channels
    self.num_actions = num_actions
    self.batched_forward_pass_size = batched_forward_pass_size
    self.hyperparameters = hyperparameters

    # Everything below lives in one dedicated graph.
    self.tf_graph = tf.Graph()
    self.tf_graph_forward_pass_bundle_single = (
        self._build_graph_forward_pass_bundle(self.tf_graph, 1))
    self.tf_graph_forward_pass_bundle_batched = (
        self._build_graph_forward_pass_bundle(
            self.tf_graph, batched_forward_pass_size))
    self.tf_graph_train_bundle = self._build_graph_train_bundle(
        self.tf_graph)

    self.tf_session = tf.Session(graph=self.tf_graph)

    # The calls below take no explicit graph for the ops they collect or
    # create, so ``self.tf_graph`` is installed as the default graph
    # while they run.
    with self.tf_graph.as_default():
        self.tf_all_summaries = tf.merge_all_summaries()
        self.tf_summary_writer = tf.train.SummaryWriter(
            logdir=metrics_directory, graph=self.tf_graph)
        self.tf_saver = tf.train.Saver()
        tf.initialize_all_variables().run(session=self.tf_session)
        # NOTE(review): kept inside the ``with`` block so that any ops the
        # helper creates land in ``self.tf_graph`` even if the helper does
        # not select the graph itself -- confirm against the helper.
        self.assigns_train_to_forward_pass_variables = (
            self._build_assigns_train_to_forward_pass_variables())
# (Web-page residue, not code: "Comment list")
# (Web-page residue, not code: "Table of contents")