def apply_updates(self, model, grads):
    """Build the ops that apply one momentum step to the model parameters.

    For each parameter ``p`` with gradient ``g``, momentum buffer ``v`` and
    learning rate ``lr`` this builds::

        v <- momentum * v + lr * g
        p <- p - v

    Args:
        model: object exposing ``model_params``, the parameters to update
            (presumably TF variables -- assumed, not shown here).
        grads: gradients aligned element-wise with ``model.model_params``.

    Returns:
        A pair ``(update_ops, mom_ops)``. ``update_ops`` holds one
        ``assign_sub`` op per parameter. ``mom_ops`` holds the matching
        momentum-buffer ``assign`` ops.
        Each parameter update consumes the output of its momentum assign, so
        in graph mode fetching ``update_ops`` already runs the momentum
        update. Fetch ``mom_ops`` in the same run if needed. Running them
        separately would advance the momentum buffer a second time.

    Raises:
        ValueError: if per-parameter learning rates were supplied and their
            count does not match the number of parameters.
    """
    params = model.model_params
    if isinstance(self._learning_rate, (list, tuple)):
        lrs = list(self._learning_rate)
        # zip() would silently drop trailing parameters on a length mismatch.
        if len(lrs) != len(params):
            raise ValueError(
                'expected %d learning rates, got %d' % (len(params), len(lrs)))
    else:
        # A single learning rate is shared by every parameter.
        lrs = [self._learning_rate] * len(params)

    update_ops = []
    mom_ops = []
    with tf.name_scope('CDLearning/updates'):
        for param, grad, velocity, lr in zip(
                params, grads, self._momentum_vector, lrs):
            new_velocity = tf.assign(
                velocity, self._momentum * velocity + grad * lr)
            mom_ops.append(new_velocity)
            # Subtract the freshly updated velocity, not the stale buffer.
            update_ops.append(tf.assign_sub(param, new_velocity))
    return update_ops, mom_ops
评论列表
文章目录