def update_hyper_param(self):
    """Build the graph ops that refresh the ``mu`` and ``lr`` hyperparameters.

    Candidate values come from ``self.get_mu_tensor()`` /
    ``self.get_lr_tensor()`` when ``self._do_tune`` is true, and from the
    stored ``self._mu_var`` / ``self._lr_var`` otherwise. ``self._do_tune`` is
    assumed to be a scalar boolean tensor or Python bool, as ``tf.cond``
    requires. The ``lr`` candidate is built under a control dependency on the
    new ``mu`` tensor, so ``mu`` is evaluated first (presumably because
    ``get_lr_tensor`` reads it; confirm).

    If ``self._use_unsmoothed_lr_mu`` is true, the candidates are written to the
    variables as-is. Otherwise each candidate is first blended with the stored
    value as ``beta * stored + (1 - beta) * candidate`` with
    ``beta = self._beta``.

    Side effects:
        Rebinds ``self._mu`` and ``self._lr`` to the tensors that are written
        back (smoothed or not).

    Returns:
        A ``tf.group`` op that runs the assignments to ``self._mu_var`` and
        ``self._lr_var``.
    """
    # tf.cond accepts callables, so the bound getters are passed directly.
    self._mu = tf.identity(
        tf.cond(self._do_tune, self.get_mu_tensor, lambda: self._mu_var))
    with tf.control_dependencies([self._mu]):
        self._lr = tf.identity(
            tf.cond(self._do_tune, self.get_lr_tensor, lambda: self._lr_var))

    def _write_back():
        # Reads self._mu / self._lr at call time, so it assigns whichever
        # tensors (raw or smoothed) are bound when it runs.
        return [tf.assign(self._mu_var, self._mu),
                tf.assign(self._lr_var, self._lr)]

    with tf.control_dependencies([self._mu, self._lr]):
        if self._use_unsmoothed_lr_mu:
            assign_ops = _write_back()
        else:
            beta = self._beta
            self._mu = beta * self._mu_var + (1 - beta) * self._mu
            self._lr = beta * self._lr_var + (1 - beta) * self._lr
            # Only assign once the smoothed tensors are available.
            with tf.control_dependencies([self._mu, self._lr]):
                assign_ops = _write_back()
    return tf.group(*assign_ops)
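# Scraped web-page residue, kept as a comment so the file stays valid Python:
# "评论列表" (comment list), "文章目录" (table of contents).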