def _backward(self, loss, summaries=False):
  hps = self.hps

  # Rescale the loss by the number of unrolled steps before taking gradients.
  loss = loss * hps.num_steps

  emb_vars = find_trainable_variables("emb")
  lstm_vars = find_trainable_variables("LSTM")
  softmax_vars = find_trainable_variables("softmax")

  all_vars = emb_vars + lstm_vars + softmax_vars
  grads = tf.gradients(loss, all_vars)
  orig_grads = grads[:]
  emb_grads = grads[:len(emb_vars)]
  grads = grads[len(emb_vars):]

  # Embedding gradients come back as sparse IndexedSlices; scale their values
  # by the batch size while keeping them sparse.
  for i in range(len(emb_grads)):
    assert isinstance(emb_grads[i], tf.IndexedSlices)
    emb_grads[i] = tf.IndexedSlices(emb_grads[i].values * hps.batch_size,
                                    emb_grads[i].indices,
                                    emb_grads[i].dense_shape)

  lstm_grads = grads[:len(lstm_vars)]
  softmax_grads = grads[len(lstm_vars):]

  # Only the LSTM gradients are clipped by global norm; embedding and softmax
  # gradients pass through unchanged.
  lstm_grads, lstm_norm = tf.clip_by_global_norm(lstm_grads, hps.max_grad_norm)
  clipped_grads = emb_grads + lstm_grads + softmax_grads
  assert len(clipped_grads) == len(orig_grads)

  if summaries:
    # tf.scalar_summary / tf.histogram_summary are the pre-1.0 TensorFlow
    # names (tf.summary.scalar / tf.summary.histogram from TF 1.0 onward).
    tf.scalar_summary("model/lstm_grad_norm", lstm_norm)
    tf.scalar_summary("model/lstm_grad_scale",
                      tf.minimum(hps.max_grad_norm / lstm_norm, 1.0))
    tf.scalar_summary("model/lstm_weight_norm", tf.global_norm(lstm_vars))
    # for v, g, cg in zip(all_vars, orig_grads, clipped_grads):
    #   name = v.name.lstrip("model/")
    #   tf.histogram_summary(name + "/var", v)
    #   tf.histogram_summary(name + "/grad", g)
    #   tf.histogram_summary(name + "/clipped_grad", cg)

  # (gradient, variable) pairs ready for Optimizer.apply_gradients.
  return list(zip(clipped_grads, all_vars))
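
For context, a minimal sketch of how the (gradient, variable) pairs returned by `_backward` could be consumed. The names `model`, `loss`, `hps`, and `global_step`, as well as the choice of Adagrad, are assumptions for illustration, not details taken from the listing above; pre-1.0 TensorFlow style is assumed to match the summary calls.

# Hypothetical usage sketch: `model`, `loss`, `hps`, and `global_step` are
# assumed to be defined elsewhere, and the optimizer choice is an assumption.
grads_and_vars = model._backward(loss, summaries=True)
optimizer = tf.train.AdagradOptimizer(hps.learning_rate,
                                      initial_accumulator_value=1.0)
train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)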