def __init__(self, sess, env_name, model_dir, variables, max_update_per_step, max_to_keep=20):
    """Build the step counter, the checkpoint saver and the scalar summary ops.

    Args:
        sess: Session object; its ``graph`` attribute is passed to the
            summary writer.
        env_name: Prefix of every scalar summary tag (``"<env_name>/<tag>"``).
        model_dir: Stored on the instance and used as the sub-directory
            under ``./logs`` for the summary writer.
        variables: List of variables handed to the saver; the step counter
            variable is appended to it.
        max_update_per_step: Stored unchanged on the instance.
        max_to_keep: Forwarded to ``tf.train.Saver``.
    """
    self.sess = sess
    self.env_name = env_name
    self.max_update_per_step = max_update_per_step

    # reset() is defined elsewhere on the class; max_avg_r is cleared only
    # after it has run, so the order of these two statements is kept.
    self.reset()
    self.max_avg_r = None

    # NOTE(review): the pasted source had no indentation, so the extent of
    # both ``with`` scopes below is reconstructed -- confirm against the
    # original file.
    with tf.variable_scope('t'):
        step_counter = tf.Variable(0, trainable=False, name='t')
        self.t_op = step_counter
        self.t_add_op = step_counter.assign_add(1)

    self.model_dir = model_dir
    saved_variables = variables + [self.t_op]
    self.saver = tf.train.Saver(saved_variables, max_to_keep=max_to_keep)
    log_dir = './logs/%s' % self.model_dir
    self.writer = tf.train.SummaryWriter(log_dir, self.sess.graph)

    with tf.variable_scope('summary'):
        placeholders = self.summary_placeholders = {}
        summaries = self.summary_ops = {}

        # One float32 placeholder plus one scalar summary op per tag; the
        # dict keys keep the spaces, the graph node names use underscores.
        for tag in ('total r', 'avg r', 'avg q', 'avg v', 'avg a', 'avg l'):
            node_name = tag.replace(' ', '_')
            feed_slot = tf.placeholder('float32', None, name=node_name)
            placeholders[tag] = feed_slot
            summaries[tag] = tf.scalar_summary('%s/%s' % (self.env_name, tag), feed_slot)
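# Web-page residue, not part of the code: "评论列表" (comment list)
# Web-page residue, not part of the code: "文章目录" (table of contents)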