gradient_check.py 文件源码

python
阅读 67 收藏 0 点赞 0 评论 0

项目:vanilla-neural-nets 作者: cavaunpeu 项目源码 文件源码
def _passes_gradient_check(self, parameter):
        iterator = np.nditer(parameter.value, flags=['multi_index'], op_flags=['readwrite'])

        while not iterator.finished:
            multi_index = iterator.multi_index
            numerical_gradient = self._compute_numerical_gradient(parameter=parameter, multi_index=multi_index)
            analytical_gradient = parameter.gradient[multi_index]

            relative_error = self._compute_relative_error(
                numerical_gradient=numerical_gradient,
                analytical_gradient=analytical_gradient
            )
            if (relative_error > self.error_threshold) or np.isnan(relative_error):
                return False

            iterator.iternext()

        return True
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号