model_base.py source code

python

Project: nmt_v2    Author: rpryzant
def _make_training_op(self):
    # Build the training op: pick the optimizer, clip gradients by global
    # norm, log summaries, and apply the clipped gradients.
    if self.config.optimizer == 'sgd':
        # Keep the learning rate constant until start_decay_step, then decay
        # it exponentially every decay_steps steps (staircase decay).
        self.learning_rate = tf.cond(
            self.global_step < self.config.start_decay_step,
            lambda: tf.constant(self.config.learning_rate),
            lambda: tf.train.exponential_decay(
                self.config.learning_rate,
                (self.global_step - self.config.start_decay_step),
                self.config.decay_steps,
                self.config.decay_factor,
                staircase=True),
            name='learning_rate')
        optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
    elif self.config.optimizer == 'adam':
        # Adam is used with a small, constant learning rate.
        assert self.config.learning_rate < 0.007
        self.learning_rate = tf.constant(self.config.learning_rate)
        optimizer = tf.train.AdamOptimizer(self.learning_rate)

    params = tf.trainable_variables()
    gradients = tf.gradients(self.loss, params)
    # Rescale the gradients so their global norm stays below the configured cap.
    clipped_gradients, gradient_norm = tf.clip_by_global_norm(
        gradients, self.config.max_gradient_norm)

    tf.summary.scalar("grad_norm", gradient_norm)
    tf.summary.scalar("clipped_norm", tf.global_norm(clipped_gradients))
    tf.summary.scalar("learning_rate", self.learning_rate)

    train_op = optimizer.apply_gradients(
        zip(clipped_gradients, params), global_step=self.global_step)

    return train_op
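
Once the SGD schedule starts decaying (staircase=True), the effective rate is learning_rate * decay_factor ** floor((global_step - start_decay_step) / decay_steps). The sketch below reproduces the same schedule-plus-clipping pattern on a toy loss so it can be run standalone with TF 1.x; it is an illustration, not part of model_base.py, and every hyperparameter value in it is made up.

python
# Minimal self-contained TF 1.x sketch of the same training-op pattern.
# All values (start_decay_step=100, decay_steps=50, decay_factor=0.5,
# max norm 5.0) are illustrative, not taken from nmt_v2.
import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
w = tf.get_variable('w', shape=[], initializer=tf.constant_initializer(5.0))
loss = tf.square(w - 1.0)

start_decay_step, decay_steps, decay_factor = 100, 50, 0.5
# Constant rate until start_decay_step, then staircase exponential decay.
learning_rate = tf.cond(
    global_step < start_decay_step,
    lambda: tf.constant(1.0),
    lambda: tf.train.exponential_decay(
        1.0, global_step - start_decay_step, decay_steps, decay_factor,
        staircase=True),
    name='learning_rate')

optimizer = tf.train.GradientDescentOptimizer(learning_rate)
params = tf.trainable_variables()
gradients = tf.gradients(loss, params)
# Rescale gradients so their global norm does not exceed 5.0.
clipped_gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
train_op = optimizer.apply_gradients(
    zip(clipped_gradients, params), global_step=global_step)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(200):
        sess.run(train_op)
    print(sess.run([w, learning_rate, global_step]))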