def _apply_sparse(self, cache):
    """Build the step for a sparse (indexed) gradient and store it in `cache`.

    Gradient rows that share an index in `cache['idxs']` are summed first, so
    every index occurs exactly once. Then, depending on the hyperparameters:

    * `self.mu > 0`: the 'm' slot is refreshed through
      `self._sparse_moving_average` (beta=`self.mu`), gathered at the unique
      indices and blended with the merged gradient as
      `(1 - gamma) * m + gamma * g`. Otherwise the merged gradient itself is
      used as the numerator.
    * `self.nu > 0`: the 'v' slot is refreshed the same way from the squared
      merged gradient (beta=`self.nu`), gathered, and turned into the divisor
      `sqrt(v + epsilon)`. Otherwise the divisor is the integer 1.

    The resulting step is `learning_rate * numerator / divisor`.

    Args:
      cache: mapping that must provide 'x_tm1', 'g_t', 'idxs' and 'updates'.
        'updates' must support `.extend`; it receives both values returned by
        each `_sparse_moving_average` call that is made.

    Returns:
      The same `cache` object, mutated in place: 's_t' holds the step, 'g_t'
      the merged gradient and 'idxs' the de-duplicated indices.
    """
    x_tm1 = cache['x_tm1']
    grad_rows = cache['g_t']
    raw_idxs = cache['idxs']

    # `positions[i]` is where raw_idxs[i] lives inside `uniq_idxs`, so a
    # segment sum over `positions` adds together rows that hit the same index.
    uniq_idxs, positions = tf.unique(raw_idxs)
    n_uniq = tf.size(uniq_idxs)
    grad = tf.unsorted_segment_sum(grad_rows, positions, n_uniq)
    updates = cache['updates']

    numerator = grad
    if self.mu > 0:
        # NOTE(review): 'm' is presumably a momentum-style moving average kept
        # at full variable size (it is gathered at uniq_idxs below) — confirm
        # against _sparse_moving_average.
        m_full, m_aux = self._sparse_moving_average(
            x_tm1, uniq_idxs, grad, 'm', beta=self.mu)
        m_rows = tf.gather(m_full, uniq_idxs)
        numerator = (1 - self.gamma) * m_rows + self.gamma * grad
        updates.extend([m_full, m_aux])

    divisor = 1
    if self.nu > 0:
        sq_grad = grad ** 2
        v_full, v_aux = self._sparse_moving_average(
            x_tm1, uniq_idxs, sq_grad, 'v', beta=self.nu)
        v_rows = tf.gather(v_full, uniq_idxs)
        # epsilon is added *inside* the square root, not after it.
        divisor = tf.sqrt(v_rows + self.epsilon)
        updates.extend([v_full, v_aux])

    step = self.learning_rate * numerator / divisor
    cache['s_t'] = step
    cache['g_t'] = grad
    cache['idxs'] = uniq_idxs
    return cache
评论列表
文章目录