def get_update_op(self, loss, opts, global_step=None, max_gradient_norm=None, freeze_variables=None):
    # Assumes `import re` and `import tensorflow as tf` (TF 1.x graph mode) at module level.
    if loss is None:
        return None

    freeze_variables = freeze_variables or []

    # Compute gradients only for variables that are not frozen: a variable is frozen
    # if its name matches any of the regexes in `freeze_variables`.
    frozen_parameters = [var.name for var in tf.trainable_variables()
                         if any(re.match(var_, var.name) for var_ in freeze_variables)]
    params = [var for var in tf.trainable_variables() if var.name not in frozen_parameters]
    self.params = params

    gradients = tf.gradients(loss, params)
    if max_gradient_norm:
        gradients, _ = tf.clip_by_global_norm(gradients, max_gradient_norm)

    # Build one update op per optimizer; all of them apply the same (clipped) gradients.
    update_ops = []
    for opt in opts:
        with tf.variable_scope('gradients' if self.name is None else 'gradients_{}'.format(self.name)):
            update_op = opt.apply_gradients(list(zip(gradients, params)), global_step=global_step)
            update_ops.append(update_op)

    return update_ops
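
For context, here is a minimal, self-contained sketch of the same pattern outside the class: freezing variables by name regex, clipping the global gradient norm, and applying the gradients with an optimizer. It assumes the TF 1.x graph-mode API used above; the toy model, variable names, and hyperparameters are illustrative only and are not taken from the original code.

import re
import tensorflow as tf  # TF 1.x graph-mode API, as in the method above

# Toy graph: one trainable weight matrix plus an "embedding" variable we freeze by regex.
x = tf.placeholder(tf.float32, [None, 4])
w = tf.get_variable('weights', [4, 2])
emb = tf.get_variable('embedding', [10, 4])
loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))

freeze_variables = ['embedding']  # regex patterns matched against variable names
frozen = [v.name for v in tf.trainable_variables()
          if any(re.match(pattern, v.name) for pattern in freeze_variables)]
params = [v for v in tf.trainable_variables() if v.name not in frozen]

gradients = tf.gradients(loss, params)
gradients, _ = tf.clip_by_global_norm(gradients, 5.0)  # corresponds to max_gradient_norm=5.0

global_step = tf.train.get_or_create_global_step()
opt = tf.train.AdamOptimizer(1e-3)
update_op = opt.apply_gradients(list(zip(gradients, params)), global_step=global_step)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(update_op, feed_dict={x: [[1.0, 2.0, 3.0, 4.0]]})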