def compute_gradients(self, loss, *args, **kwargs):
    train_vars = None
    if self.trainable_names is not None:
        log.info('All trainable vars:\n' + str(
            [var.name for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)]))
        # Restrict training to the variables under the requested scope names.
        train_vars = []
        for scope_name in self.trainable_names:
            new_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope_name)
            if len(new_vars) == 0:
                raise ValueError(
                    'The scope name, {}, you specified does not contain '
                    'any trainable variables.'.format(scope_name))
            train_vars.extend(new_vars)
        log.info('Variables to be trained:\n' + str([var.name for var in train_vars]))
    if train_vars is not None:
        self.var_list = train_vars
    gvs = self._optimizer.compute_gradients(loss,
                                            var_list=train_vars,
                                            *args, **kwargs)
    if self.clip:
        # Gradient clipping: some returned gradients are None because the
        # corresponding variable has no dependency on the loss, so skip those.
        gvs = [(tf.clip_by_value(grad, -1., 1.), var)
               for grad, var in gvs if grad is not None]
    return gvs
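
For context, here is a minimal usage sketch. It assumes the method above belongs to an optimizer wrapper class (called ClipOptimizer here purely for illustration) that stores the wrapped optimizer, the clip flag, and the optional trainable_names list; the class name, constructor signature, and apply_gradients pass-through are assumptions for illustration, not taken from the source.

import logging
import tensorflow as tf

log = logging.getLogger(__name__)

class ClipOptimizer(object):
    """Hypothetical wrapper holding the state used by compute_gradients above."""

    def __init__(self, optimizer, clip=True, trainable_names=None):
        self._optimizer = optimizer              # the underlying tf.train optimizer
        self.clip = clip                         # whether to clip gradients to [-1, 1]
        self.trainable_names = trainable_names   # optional list of scope names to train
        self.var_list = None

    # compute_gradients(...) as defined above

    def apply_gradients(self, gvs, global_step=None):
        # Delegate the actual parameter update to the wrapped optimizer.
        return self._optimizer.apply_gradients(gvs, global_step=global_step)

# Typical use (TensorFlow 1.x graph mode; scope name is illustrative):
# opt = ClipOptimizer(tf.train.MomentumOptimizer(0.01, 0.9),
#                     clip=True, trainable_names=['model/readout'])
# gvs = opt.compute_gradients(loss)
# train_op = opt.apply_gradients(gvs, global_step=tf.train.get_or_create_global_step())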