def curvature_range(self):
    """Build ops tracking smoothed min/max curvature over a sliding window.

    The current ``self._grad_norm_squared`` is written into a fixed-size ring
    buffer of length ``self._curv_win_width`` at slot
    ``global_step % curv_win_width``. The min and max are taken over the slots
    filled so far, and both values are passed through ``self._moving_averager``
    (presumably a ``tf.train.ExponentialMovingAverage`` -- TODO confirm).

    Attributes set on ``self``:
      _curv_win: the ring-buffer tensor *after* the scatter update. The
        underlying ``tf.Variable`` handle is not kept separately.
      _h_min_t, _h_max_t: raw min / max over the filled part of the window.
      _h_min, _h_max: moving averages of the above. They are read only after
        the averaging op has run.

    Assumes ``self._global_step`` is an integer scalar starting at 0 and
    ``self._grad_norm_squared`` is a float32 scalar -- verify against caller.

    Returns:
      list: a one-element list containing the moving-average update op.
    """
    # Ring buffer of recent squared gradient norms. It stays zero until filled.
    self._curv_win = tf.Variable(
        np.zeros([self._curv_win_width]), dtype=tf.float32,
        name="curv_win", trainable=False)
    # Overwrite the oldest slot with this step's value.
    self._curv_win = tf.scatter_update(
        self._curv_win, self._global_step % self._curv_win_width,
        self._grad_norm_squared)

    # Iterations start at 0, so after step t only min(width, t + 1) slots hold
    # real values. Slicing to that prefix keeps the initial zeros out of the
    # min/max. The Python int is converted to global_step's dtype instead of
    # being forced to int32 via tf.constant, so tf.minimum and tf.slice (whose
    # begin/size dtypes must match) work for int32 and int64 steps alike.
    window_len = tf.minimum(self._global_step + 1, self._curv_win_width)
    valid_window = tf.slice(
        self._curv_win,
        tf.zeros([1], dtype=window_len.dtype),
        tf.expand_dims(window_len, axis=0))  # `dim=` is deprecated; use `axis=`.

    self._h_min_t = tf.reduce_min(valid_window)
    self._h_max_t = tf.reduce_max(valid_window)

    # Ordering: compute the raw extrema, then update the averages, then read them.
    with tf.control_dependencies([self._h_min_t, self._h_max_t]):
        avg_op = self._moving_averager.apply([self._h_min_t, self._h_max_t])
        with tf.control_dependencies([avg_op]):
            # tf.identity creates new ops inside this dependency scope, so the
            # averaged values are read only after avg_op has run.
            self._h_min = tf.identity(self._moving_averager.average(self._h_min_t))
            self._h_max = tf.identity(self._moving_averager.average(self._h_max_t))
    return [avg_op]
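# 评论列表 ("comment list") -- apparently leftover text from the web page this code was copied from
# 文章目录 ("table of contents") -- apparently leftover text from the web page this code was copied from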