def _build_update_ops(self, mean, variance, is_training):
    """Creates the ops that refresh the moving mean and moving variance.

    Each moving statistic is updated with
    `moving_averages.assign_moving_average`, using `self._decay_rate` as the
    decay and with zero-debiasing disabled.

    Args:
        mean: Value that `self._moving_mean` is moved towards.
        variance: Value that `self._moving_variance` is moved towards.
        is_training: Boolean Tensor indicating whether we are currently in
            training mode.

    Returns:
        The pair `(update_mean_op, update_variance_op)` when `is_training` is
        `True` or its value is unknown (`utils.constant_value` gives `None`).
        `None` when `is_training` is known to be false.
    """
    def _ema_update(variable, value, op_name):
        # Single moving-average assignment; callers only need the `.op`.
        return moving_averages.assign_moving_average(
            variable=variable,
            value=value,
            decay=self._decay_rate,
            zero_debias=False,
            name=op_name,
        ).op

    def _true_branch():
        """Builds the exponential moving average update ops."""
        # The mean update is created before the variance update.
        return (
            _ema_update(self._moving_mean, mean, "update_moving_mean"),
            _ema_update(
                self._moving_variance, variance, "update_moving_variance"),
        )

    def _false_branch():
        return (tf.no_op(), tf.no_op())

    # A statically-false `is_training` means no update ops are needed, so
    # none are built at all.
    static_training = utils.constant_value(is_training)
    if static_training is not None and not static_training:
        return None

    # Otherwise `is_training` is True or only known at run time; let
    # `smart_cond` pick between real updates and no-ops.
    mean_update, variance_update = utils.smart_cond(
        is_training,
        _true_branch,
        _false_branch,
    )
    return mean_update, variance_update