def compute_exponential_averages(variables, decay):
    """Create exponential moving average ops for a list of scalar variables.

    Parameters
    ----------
    variables: [tf.Tensor]
        List of scalar tensors.
    decay:
        Value forwarded as the `decay` argument of
        `tf.train.ExponentialMovingAverage`.

    Returns
    -------
    averages: [tf.Tensor]
        List of scalar tensors corresponding to the averages
        of all the `variables` (in the same order).
    apply_op: tf.runnable
        Op to be run to update the averages with the current value
        of the variables.
    """
    ema = tf.train.ExponentialMovingAverage(decay=decay)
    # The averages are looked up after `apply`, matching the original order
    # of operations.
    update_op = ema.apply(variables)

    averages = []
    for variable in variables:
        averages.append(ema.average(variable))
    return averages, update_op
# Comment list
# Table of contents