statistic.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:NAF-tensorflow 作者: carpedm20 项目源码 文件源码
def __init__(self, sess, env_name, model_dir, variables, max_update_per_step, max_to_keep=20):
  """Set up step counting, checkpointing and TensorBoard scalar summaries.

  Args:
    sess: TensorFlow session; its graph is passed to the summary writer.
    env_name: environment name, used as the prefix of every summary tag.
    model_dir: subdirectory under ./logs where the summary writer logs.
    variables: iterable of variables to checkpoint; the step counter `t`
      is appended to it before building the Saver.
    max_update_per_step: stored on self; presumably caps training updates
      per environment step -- TODO confirm against callers.
    max_to_keep: forwarded to tf.train.Saver (number of recent checkpoints
      the Saver retains).
  """
  self.sess = sess
  self.env_name = env_name
  self.max_update_per_step = max_update_per_step

  self.reset()
  # Set after reset(), so whatever reset() does, max_avg_r starts as None here.
  self.max_avg_r = None

  with tf.variable_scope('t'):
    # Non-trainable step counter; it is added to the Saver's variable list
    # below, so its value is checkpointed together with the model variables.
    self.t_op = tf.Variable(0, trainable=False, name='t')
    self.t_add_op = self.t_op.assign_add(1)

  self.model_dir = model_dir
  # list() accepts any iterable (tuple, generator, ...) rather than only a list,
  # which plain `variables + [...]` required.
  self.saver = tf.train.Saver(list(variables) + [self.t_op], max_to_keep=max_to_keep)

  # tf.train.SummaryWriter / tf.scalar_summary were removed in TF 1.0 in favour of
  # tf.summary.FileWriter / tf.summary.scalar. Prefer the legacy names when they
  # exist so behaviour (including tag naming) is unchanged on old TF, and fall
  # back to the replacements otherwise.
  writer_cls = getattr(tf.train, 'SummaryWriter', None)
  if writer_cls is None:
    writer_cls = tf.summary.FileWriter
  make_scalar_summary = getattr(tf, 'scalar_summary', None)
  if make_scalar_summary is None:
    make_scalar_summary = tf.summary.scalar

  self.writer = writer_cls('./logs/%s' % self.model_dir, self.sess.graph)

  with tf.variable_scope('summary'):
    scalar_summary_tags = ['total r', 'avg r', 'avg q', 'avg v', 'avg a', 'avg l']

    # One scalar placeholder + summary op per tag, keyed by the tag string.
    self.summary_placeholders = {}
    self.summary_ops = {}

    for tag in scalar_summary_tags:
      # Op names may not contain spaces, hence the replace for the placeholder name.
      self.summary_placeholders[tag] = tf.placeholder('float32', None, name=tag.replace(' ', '_'))
      self.summary_ops[tag] = make_scalar_summary('%s/%s' % (self.env_name, tag), self.summary_placeholders[tag])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号