custom_train.py source code

python

Project: TerpreT    Author: 51alg
# Module-level imports used by this method (it belongs to the CustomTrainer
# class and targets TensorFlow 1.x):
import tensorflow as tf
from tensorflow.python.ops import clip_ops


def apply_update(self, optimizer, grads_and_vars):
    (grads, vars) = zip(*grads_and_vars)

    # Gradient clipping: rescale the gradients so that their global norm
    # does not exceed the GRADIENT_CLIP hyperparameter.
    if CustomTrainer.GRADIENT_CLIP in self.train_hypers:
        grads, global_norm = clip_ops.clip_by_global_norm(grads,
                                self.train_hypers[CustomTrainer.GRADIENT_CLIP])

    # Gradient noise: add zero-mean Gaussian noise with variance sigma_sqr,
    # optionally annealed as sigma_sqr / (1 + step)^decay.
    if CustomTrainer.GRADIENT_NOISE in self.train_hypers:
        sigma_sqr = self.train_hypers[CustomTrainer.GRADIENT_NOISE]
        if CustomTrainer.GRADIENT_NOISE_DECAY in self.train_hypers:
            sigma_sqr /= tf.pow(1.0 + tf.to_float(self.global_step),
                                self.train_hypers[CustomTrainer.GRADIENT_NOISE_DECAY])
        grads_tmp = []
        for g in grads:
            if g is not None:
                noisy_grad = g + tf.sqrt(sigma_sqr) * tf.random_normal(tf.shape(g))
                grads_tmp.append(noisy_grad)
            else:
                # Variables with no gradient are passed through unchanged.
                grads_tmp.append(g)
        grads = grads_tmp

    train_op = optimizer.apply_gradients(zip(grads, vars), global_step=self.global_step)
    return train_op
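
The method above clips the gradient list to a global-norm bound, optionally adds annealed Gaussian gradient noise (variance sigma_sqr / (1 + step)^decay), and hands the result to the optimizer. Below is a minimal usage sketch; it assumes TensorFlow 1.x and invents the hyperparameter key strings and the surrounding CustomTrainer skeleton purely for illustration, since only the apply_update method appears in this excerpt.

python
import tensorflow as tf

class CustomTrainer(object):
    # Hypothetical key names -- the real constants live in TerpreT's CustomTrainer.
    GRADIENT_CLIP = "gradient_clip"
    GRADIENT_NOISE = "gradient_noise"
    GRADIENT_NOISE_DECAY = "gradient_noise_decay"

    def __init__(self, train_hypers):
        self.train_hypers = train_hypers
        self.global_step = tf.Variable(0, trainable=False, name="global_step")

    apply_update = apply_update  # reuse the method defined above

# Clip to global norm 5.0; add noise with variance 0.01, decayed as 0.01 / (1 + step)**0.55.
trainer = CustomTrainer({CustomTrainer.GRADIENT_CLIP: 5.0,
                         CustomTrainer.GRADIENT_NOISE: 0.01,
                         CustomTrainer.GRADIENT_NOISE_DECAY: 0.55})

# Toy objective: minimize x^2 starting from x = 3.
x = tf.Variable(3.0)
loss = tf.square(x)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
train_op = trainer.apply_update(optimizer, optimizer.compute_gradients(loss))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(10):
        sess.run(train_op)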