def apply_gradients(self, grads):
    """Build the update op: momentum SGD while cold, the K-FAC op afterwards.

    The returned op is a ``tf.cond`` on ``self.sgd_step > self._cold_iter``.
    While that predicate is false, the cold branch increments
    ``self.sgd_step`` by one and applies ``grads`` with a
    ``tf.train.MomentumOptimizer``. Once it is true, the branch yields the
    op produced by ``self.apply_gradients_kfac``.

    Args:
        grads: Forwarded unchanged to ``MomentumOptimizer.apply_gradients``
            and to ``self.apply_gradients_kfac``; TensorFlow's optimizer API
            expects (gradient, variable) pairs there.

    Returns:
        A 2-tuple ``(update_op, qr)``. ``update_op`` is the conditional op
        described above; ``qr`` is the second value returned by
        ``self.apply_gradients_kfac``, passed through untouched.
    """
    # The momentum optimizer's learning rate is the cold learning rate
    # scaled by (1 - momentum).
    cold_learning_rate = self._cold_lr * (1. - self._momentum)
    momentum_sgd = tf.train.MomentumOptimizer(
        cold_learning_rate, self._momentum)

    # Built before tf.cond and outside both branch callables.
    # NOTE(review): per the tf.cond docs, ops created outside the branch
    # functions are not gated by the condition -- confirm that is intended.
    kfac_update, qr = self.apply_gradients_kfac(grads)

    def _cold_sgd_branch():
        # Creation order matters for graph op naming: counter first,
        # then the momentum update.
        bump_step = tf.assign_add(self.sgd_step, 1)
        sgd_update = momentum_sgd.apply_gradients(grads)
        if KFAC_DEBUG:
            # Print only after both the counter and the update have run.
            with tf.control_dependencies([bump_step, sgd_update]):
                bump_step = tf.Print(
                    bump_step,
                    [self.sgd_step,
                     tf.convert_to_tensor('doing cold sgd step')])
        return tf.group(bump_step, sgd_update)

    past_cold_phase = tf.greater(self.sgd_step, self._cold_iter)
    update_op = tf.cond(
        past_cold_phase, lambda: kfac_update, _cold_sgd_branch)
    return update_op, qr
评论列表
文章目录