def _apply_stats(self, statsUpdates, accumulate=False, accumulateCoeff=0.):
    """Build the ops that fold newly computed statistics into their variables.

    Args:
        statsUpdates: mapping whose keys are the statistics variables to
            update and whose values are the new statistics for them.
        accumulate: when true, ``accumulateCoeff * new`` is added onto each
            variable; when false, each variable is updated as an
            exponential running average weighted by ``self._stats_decay``.
        accumulateCoeff: scale applied to the new statistics; only used
            when ``accumulate`` is true.

    Returns:
        A one-element list holding the op that increments
        ``self.stats_step`` by 1. That op is created under a control
        dependency on every update op built here. When ``KFAC_DEBUG`` is
        set, the op is additionally wrapped in a ``tf.Print`` that reports
        the step counters, the accumulate settings and up to the first two
        update ops.
    """
    updateOps = []
    for stats_var, stats_new in statsUpdates.items():
        if accumulate:
            # simple superbatch averaging
            update_op = tf.assign_add(
                stats_var, accumulateCoeff * stats_new, use_locking=True)
        else:
            # exponential running averaging:
            #   var <- decay * var + (1 - decay) * new
            # The result of the first assign is passed as the target of the
            # second so the add is chained after the decay.
            decayed = tf.assign(
                stats_var, stats_var * self._stats_decay, use_locking=True)
            update_op = tf.assign_add(
                decayed, (1. - self._stats_decay) * stats_new,
                use_locking=True)
        updateOps.append(update_op)

    # The step counter is only bumped once all statistics were updated.
    with tf.control_dependencies(updateOps):
        stats_step_op = tf.assign_add(self.stats_step, 1)

    if KFAC_DEBUG:
        debug_values = [
            tf.convert_to_tensor('step:'),
            self.global_step,
            tf.convert_to_tensor('fac step:'),
            self.factor_step,
            tf.convert_to_tensor('sgd step:'),
            self.sgd_step,
            tf.convert_to_tensor('Accum:'),
            tf.convert_to_tensor(accumulate),
            tf.convert_to_tensor('Accum coeff:'),
            tf.convert_to_tensor(accumulateCoeff),
            tf.convert_to_tensor('stat step:'),
            self.stats_step,
        ]
        # Slice instead of indexing [0] and [1] directly, so fewer than two
        # statistics variables no longer raises IndexError in debug mode.
        debug_values.extend(updateOps[:2])
        stats_step_op = tf.Print(stats_step_op, debug_values)
    return [stats_step_op, ]
评论列表
文章目录