def _prepare(self, grads_and_vars):
    """Compute and cache the learning rate when none has been set yet.

    When ``self._lr`` is ``None``, the step size is derived from the
    per-variable ``'g'`` and ``'s'`` slots and the current gradients:

        y    = g_t - slot('g')
        sTs  = sum over variables of reduce_sum(s * s)
        sTy  = sum over variables of reduce_sum(s * y)
        _lr  = sTs / (|sTy| + self._eps)

    NOTE(review): this has the shape of a Barzilai-Borwein step size, with
    ``'g'`` presumably the previous gradient and ``'s'`` the previous
    update -- confirm against the code that writes those slots.

    Args:
        grads_and_vars: iterable of ``(gradient, variable)`` pairs. Pairs
            whose gradient is ``None`` are skipped. A gradient that is not
            an ``ops.Tensor`` is treated as sparse: it must expose
            ``.indices`` and ``.values`` (presumably ``IndexedSlices``).

    Returns:
        None. The result is stored on ``self._lr``; an already-set
        ``self._lr`` is left untouched.
    """
    if self._lr is None:
        sTy = 0
        sTs = 0
        for g_t, x_tm1 in grads_and_vars:
            if g_t is None:
                continue
            with ops.name_scope('update_' + x_tm1.op.name), ops.device(x_tm1.device):
                g_tm1 = self.get_slot(x_tm1, 'g')
                s_tm1 = self.get_slot(x_tm1, 's')
                if isinstance(g_t, ops.Tensor):
                    g_cur = g_t
                else:
                    # Sparse gradient: merge rows that share an index, then
                    # restrict the slots to the rows that were touched so
                    # all three operands line up row-for-row.
                    idxs, idxs_ = array_ops.unique(g_t.indices)
                    g_cur = math_ops.unsorted_segment_sum(
                        g_t.values, idxs_, array_ops.size(idxs))
                    g_tm1 = array_ops.gather(g_tm1, idxs)
                    s_tm1 = array_ops.gather(s_tm1, idxs)
                y_t = g_cur - g_tm1
                sTy += math_ops.reduce_sum(s_tm1 * y_t)
                sTs += math_ops.reduce_sum(s_tm1 ** 2)
        # The magnitude of sTy keeps the step size non-negative; eps guards
        # against division by zero.
        sTy = math_ops.abs(sTy)
        self._lr = sTs / (sTy + self._eps)
#=============================================================
# (web page residue) Comment list
# (web page residue) Article table of contents