def get_train_op(self):
    """
    Build the training operation: select the optimizer, clip gradients,
    and apply them to the trainable variables.
    """
    if self.args.optimizer == "SGD":
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=self.args.lr)
    elif self.args.optimizer == "ADAM":
        optimizer = tf.train.AdamOptimizer(learning_rate=self.args.lr)
    else:
        raise NotImplementedError(
            "Optimizer '{}' is not supported.".format(self.args.optimizer))
    # Gradient clipping: cap each gradient's L2 norm at args.grad_clipping to
    # avoid exploding gradients; leave variables with no gradient untouched.
    grad_vars = optimizer.compute_gradients(self.loss)
    grad_vars = [
        (tf.clip_by_norm(grad, self.args.grad_clipping), var)
        if grad is not None else (grad, var)
        for grad, var in grad_vars]
    self.train_op = optimizer.apply_gradients(grad_vars, global_step=self.step)
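
For reference, here is a minimal, self-contained sketch of the same compute-clip-apply pattern on a toy quadratic loss. The variable names, the learning rate, and the clipping norm of 1.0 are illustrative only and are not taken from the model above; it simply shows the TF 1.x clipping idiom end to end.

```python
import tensorflow as tf

# Toy setup: a single trainable vector and a quadratic loss.
x = tf.Variable([1.0, 2.0], name="x")
loss = tf.reduce_sum(tf.square(x))

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
global_step = tf.train.get_or_create_global_step()

# Same pattern as get_train_op: compute, clip per-gradient L2 norm, apply.
grad_vars = optimizer.compute_gradients(loss)
grad_vars = [(tf.clip_by_norm(g, 1.0), v) if g is not None else (g, v)
             for g, v in grad_vars]
train_op = optimizer.apply_gradients(grad_vars, global_step=global_step)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(5):
        _, step, loss_val = sess.run([train_op, global_step, loss])
        print(step, loss_val)
```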