def __init__(self, tag, x, summary_fn=tf.summary.scalar, summary_args=None, scope=None):
    """
    Initializes an Average.

    Creates a non-trainable int32 counter and float32 running sum, a summary of
    their ratio, and ops to update (add one sample) and reset (zero) them.

    Arguments
    tag: Tag for the summary. If falsy, ``x.name + '_avg'`` is used instead.
    x: Tensor to be averaged over multiple runs. It is cast to float32 before
        being added to the running sum; presumably a scalar, since it must
        match the scalar running sum's shape in ``assign_add`` (TODO confirm).
    summary_fn: Function used for creating a summary. Called as
        ``summary_fn(tag, running_average, **summary_args)``.
    summary_args: Mapping of extra keyword arguments passed to the summary
        function. ``None`` or any empty value means no extra arguments.
    scope: Variable scope name; defaults to the class name.
    """
    # `**` unpacking requires a mapping, so normalize None / empty values
    # (including a legacy explicit `()`) to a fresh dict.
    summary_kwargs = dict(summary_args or {})
    with tf.variable_scope(scope or type(self).__name__):
        counter = tf.Variable(name="counter", initial_value=tf.constant(0),
                              dtype=tf.int32, trainable=False)
        running_sum = tf.Variable(name="running_sum", initial_value=tf.constant(0.),
                                  dtype=tf.float32, trainable=False)
        # Evaluates to NaN (0/0) until the update op has run at least once
        # since construction or the last reset.
        self._running_average = running_sum / tf.cast(counter, tf.float32)
        self._summary = summary_fn(tag or x.name + '_avg', self._running_average,
                                   **summary_kwargs)
        # Cast so integer (or other numeric) inputs can be accumulated into
        # the float32 running sum; a no-op for float32 inputs.
        self._update_op = tf.group(counter.assign_add(1),
                                   running_sum.assign_add(tf.cast(x, tf.float32)))
        self._reset_op = tf.group(counter.assign(0), running_sum.assign(0.))
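# "Comment list" — web-page navigation residue, not part of the code.
# "Table of contents" — web-page navigation residue, not part of the code.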