q_network.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:agent-trainer 作者: lopespm 项目源码 文件源码
def __init__(self,
             screen_width,
             screen_height,
             num_channels,
             num_actions,
             metrics_directory,
             batched_forward_pass_size,
             hyperparameters=None):
    """Build the Q-network's TensorFlow graph, session and bookkeeping ops.

    All ops are built in a single ``tf.Graph`` owned by this instance:
    - a forward-pass bundle built with size 1 (``tf_graph_forward_pass_bundle_single``);
    - a forward-pass bundle built with ``batched_forward_pass_size``
      (``tf_graph_forward_pass_bundle_batched``);
    - a training bundle (``tf_graph_train_bundle``).

    A ``tf.Session`` bound to that graph is then created. Summaries are
    merged, a summary writer and a saver are set up, and the variables
    present at that point are initialized. Finally the train-to-forward-pass
    assignment ops are built.

    Args:
        screen_width: Stored as ``self.screen_width``; presumably the width
            of the input frames -- TODO confirm against the graph builders.
        screen_height: Stored as ``self.screen_height``; presumably the
            height of the input frames.
        num_channels: Stored as ``self.num_channels``; presumably the number
            of channels per input (e.g. stacked frames) -- confirm.
        num_actions: Stored as ``self.num_actions``; presumably the size of
            the action space (number of Q-values output) -- confirm.
        metrics_directory: ``logdir`` for the summary writer, which is also
            given ``graph=self.tf_graph``.
        batched_forward_pass_size: Size passed to the builder of the batched
            forward-pass bundle (the single bundle is built with 1).
        hyperparameters: A ``QNetworkHyperparameters`` instance. When omitted
            or ``None``, a fresh default instance is created for this network.
    """
    # Create the default per instance. The former default,
    # ``QNetworkHyperparameters()`` in the signature, was evaluated once at
    # definition time, so all networks built without explicit hyperparameters
    # shared (and could mutate) the same object.
    if hyperparameters is None:
        hyperparameters = QNetworkHyperparameters()

    self.logger = logging.getLogger(__name__)
    self.screen_width = screen_width
    self.screen_height = screen_height
    self.num_channels = num_channels
    self.num_actions = num_actions
    self.batched_forward_pass_size = batched_forward_pass_size
    self.hyperparameters = hyperparameters

    # Both forward-pass bundles and the train bundle share one graph.
    self.tf_graph = tf.Graph()
    self.tf_graph_forward_pass_bundle_single = self._build_graph_forward_pass_bundle(self.tf_graph, 1)
    self.tf_graph_forward_pass_bundle_batched = self._build_graph_forward_pass_bundle(self.tf_graph, batched_forward_pass_size)
    self.tf_graph_train_bundle = self._build_graph_train_bundle(self.tf_graph)

    self.tf_session = tf.Session(graph=self.tf_graph)

    # NOTE: these are pre-1.0 TensorFlow names (tf.merge_all_summaries,
    # tf.train.SummaryWriter, tf.initialize_all_variables). Their 1.x
    # replacements are tf.summary.merge_all, tf.summary.FileWriter and
    # tf.global_variables_initializer. They are kept to match the TF version
    # this project targets.
    with self.tf_graph.as_default():
        self.tf_all_summaries = tf.merge_all_summaries()
        self.tf_summary_writer = tf.train.SummaryWriter(logdir=metrics_directory, graph=self.tf_graph)
        self.tf_saver = tf.train.Saver()
        tf.initialize_all_variables().run(session=self.tf_session)

    # NOTE(review): this runs after the Saver and the variable initializer
    # have been created. If the helper ever creates new variables, they would
    # be neither saved nor initialized here -- confirm it only builds ops.
    self.assigns_train_to_forward_pass_variables = self._build_assigns_train_to_forward_pass_variables()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号