cdk.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:xRBM 作者: omimo 项目源码 文件源码
def apply_updates(self, model, grads):
        """
        Updates the model parameters based on the given gradients, using momentum
        """
        update_ops = []
        mom_ops = []

        if isinstance(self._learning_rate, list):
            lrs = self._learning_rate
            print('d')
        else:
            lrs = [self._learning_rate for p in model.model_params]

        with tf.name_scope('CDLearning/updates'):
            for param, grad, mv, lr in zip(model.model_params, grads, self._momentum_vector, lrs):
                mv = tf.assign(mv, self._momentum * mv + grad * lr)
                update_ops.append(tf.assign_sub(param, mv))

        return update_ops, mom_ops
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号