def _passes_gradient_check(self, parameter):
    """Compare numerical and analytical gradients element by element.

    For every multi-index of ``parameter.value``, a numerical gradient is
    obtained from ``self._compute_numerical_gradient`` and compared with the
    matching entry of ``parameter.gradient`` using
    ``self._compute_relative_error``.

    Args:
        parameter: Object exposing ``value`` (array-like) and ``gradient``
            (indexable with the same multi-indices as ``value``).

    Returns:
        bool: ``False`` as soon as one element's relative error exceeds
        ``self.error_threshold`` or is NaN; ``True`` otherwise, including when
        ``parameter.value`` has no elements.
    """
    # np.ndindex yields the same C-order index tuples as
    # nditer(..., flags=['multi_index']). It needs no write access to the array
    # (the old 'readwrite' flag made read-only arrays fail). It also handles
    # zero-size shapes, on which a plain nditer raises ValueError.
    for multi_index in np.ndindex(*np.shape(parameter.value)):
        numerical_gradient = self._compute_numerical_gradient(
            parameter=parameter, multi_index=multi_index
        )
        # Read the analytical entry only after the numerical computation,
        # keeping the original call order in case that computation has side
        # effects on `parameter` (not visible from here).
        analytical_gradient = parameter.gradient[multi_index]
        relative_error = self._compute_relative_error(
            numerical_gradient=numerical_gradient,
            analytical_gradient=analytical_gradient,
        )
        # NaN compares False against any threshold, so it is checked explicitly.
        if relative_error > self.error_threshold or np.isnan(relative_error):
            return False
    return True
评论列表
文章目录