rc_base.py source code

python

Project: RC-experiments    Author: cairoHy
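import tensorflow as tf  # TensorFlow 1.x API (tf.train.* optimizers)


# Excerpted method: expects self.args (optimizer, lr, grad_clipping),
# self.loss and self.step (the global step variable) to be defined.
def get_train_op(self):
    """Build the training op: create the optimizer, clip gradients, apply them."""
    # Choose the optimizer named by the command-line argument.
    if self.args.optimizer == "SGD":
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=self.args.lr)
    elif self.args.optimizer == "ADAM":
        optimizer = tf.train.AdamOptimizer(learning_rate=self.args.lr)
    else:
        raise NotImplementedError(
            f"Optimizer {self.args.optimizer!r} is not implemented.")

    # Gradient clipping: rescale each gradient independently so that its L2
    # norm is at most args.grad_clipping. Variables that do not affect the
    # loss get a None gradient, which tf.clip_by_norm cannot accept, so those
    # pairs are passed through unchanged.
    grad_vars = optimizer.compute_gradients(self.loss)
    grad_vars = [
        (tf.clip_by_norm(grad, self.args.grad_clipping), var)
        if grad is not None else (grad, var)
        for grad, var in grad_vars]
    # Passing the global step makes apply_gradients increment it once per update.
    self.train_op = optimizer.apply_gradients(grad_vars, global_step=self.step)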
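To see what the clipping step in get_train_op computes without needing TensorFlow, here is a minimal pure-Python sketch, assuming each gradient and variable is a flat list of floats. The helper names (l2_norm, clip_vector_by_norm, clip_grads_and_vars, sgd_apply) are hypothetical and not part of RC-experiments or TensorFlow. clip_vector_by_norm follows the formula given in the tf.clip_by_norm documentation, t * clip_norm / max(l2norm(t), clip_norm), and sgd_apply stands in for what the SGD branch's apply_gradients does to each variable (var - lr * grad), ignoring the global-step increment.

```python
import math
from typing import List, Optional, Sequence, Tuple

Vector = List[float]
GradVar = Tuple[Optional[Vector], Vector]  # (gradient or None, variable)


def l2_norm(t: Sequence[float]) -> float:
    """Euclidean (L2) norm of a flat list of floats."""
    return math.sqrt(sum(x * x for x in t))


def clip_vector_by_norm(t: Sequence[float], clip_norm: float) -> Vector:
    """Hypothetical pure-Python analogue of tf.clip_by_norm for one flat tensor.

    Returns t unchanged if its L2 norm is already <= clip_norm, otherwise
    t * clip_norm / l2norm(t), so the result has norm exactly clip_norm.
    """
    norm = l2_norm(t)
    if norm <= clip_norm:
        return list(t)
    scale = clip_norm / norm
    return [x * scale for x in t]


def clip_grads_and_vars(grads_and_vars: List[GradVar], clip_norm: float) -> List[GradVar]:
    """Clip every gradient independently; pass (None, var) pairs through untouched."""
    return [
        (clip_vector_by_norm(grad, clip_norm), var) if grad is not None else (grad, var)
        for grad, var in grads_and_vars
    ]


def sgd_apply(grads_and_vars: List[GradVar], lr: float) -> List[Vector]:
    """Plain gradient-descent update var <- var - lr * grad; None gradients skip the update."""
    updated = []
    for grad, var in grads_and_vars:
        if grad is None:
            updated.append(list(var))
        else:
            updated.append([v - lr * g for v, g in zip(var, grad)])
    return updated


if __name__ == "__main__":
    # One gradient with norm 5 (clipped to norm 1) and one variable with no gradient.
    grads_and_vars = [([3.0, 4.0], [1.0, 1.0]), (None, [0.5])]
    clipped = clip_grads_and_vars(grads_and_vars, clip_norm=1.0)
    print(clipped)  # first gradient is approximately [0.6, 0.8]; the None pair is unchanged
    print(sgd_apply(clipped, lr=0.1))  # approximately [[0.94, 0.92], [0.5]]
```

Because get_train_op clips each gradient on its own, tensors can be scaled by different factors, which changes the direction of the combined update. tf.clip_by_global_norm is the alternative that rescales all gradients by one shared factor and so preserves that direction.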