def _finish(self, caches):
    """Apply each cached step to its variable and group all update ops.

    Every element of ``caches`` is a dict that is read for:

    * ``'x_tm1'`` -- the variable being updated,
    * ``'s_t'`` -- the step that is subtracted from that variable,
    * ``'updates'`` -- a list to which the newly created ops are added,
    * ``'idxs'`` (optional) -- indices; when present the step is applied
      with ``tf.scatter_sub`` instead of ``tf.assign_sub``.

    If ``self.clip`` is positive, all ``'s_t'`` entries are first clipped
    together with ``tf.clip_by_global_norm`` and written back into their
    caches. If ``self.chi`` is positive, the updated value is also handed
    to ``self._sparse_moving_average`` / ``self._dense_moving_average``
    (presumably maintaining a running average of the parameters; confirm
    against those helpers) and the two ops they return are added to the
    cache's updates as well.

    Returns the op named ``'update'`` produced by grouping, per cache, a
    ``tf.group`` of that cache's ``'updates'``.
    """
    if self.clip > 0:
        # Clip all steps jointly, then store the clipped versions back.
        clipped_steps, _ = tf.clip_by_global_norm(
            [entry['s_t'] for entry in caches], self.clip)
        for entry, clipped in zip(caches, clipped_steps):
            entry['s_t'] = clipped

    for entry in caches:
        variable = entry['x_tm1']
        step = entry['s_t']
        pending_ops = entry['updates']

        # Create the ops under a per-variable scope, on the variable's device.
        with tf.name_scope('update_' + variable.op.name):
            with tf.device(variable.device):
                if 'idxs' not in entry:
                    new_value = tf.assign_sub(variable, step)
                    if self.chi > 0:
                        average, counter = self._dense_moving_average(
                            variable, new_value, 'x', beta=self.chi)
                else:
                    indices = entry['idxs']
                    new_value = tf.scatter_sub(variable, indices, step)
                    if self.chi > 0:
                        # Only the rows that were touched feed the average.
                        touched_rows = tf.gather(new_value, indices)
                        average, counter = self._sparse_moving_average(
                            variable, indices, touched_rows, 'x',
                            beta=self.chi)

        pending_ops.append(new_value)
        if self.chi > 0:
            pending_ops.extend([average, counter])

    grouped = [tf.group(*entry['updates']) for entry in caches]
    return tf.group(*grouped, name='update')
#==============================================================
# Comment list
# Article table of contents