def run_grad_and_gradgrad_checks(test_case, test_name, apply_method, output_variable, input_variables):
    """Assert that ``apply_method`` passes first- and second-order gradient checks.

    Runs ``gradcheck`` with ``eps=1e-6`` and ``atol=PRECISION``. Then runs
    ``gradgradcheck`` with a non-contiguous gradient output built from
    ``output_variable``. Each check's boolean result is passed to
    ``test_case.assertTrue``.

    Args:
        test_case: object providing ``assertTrue`` (presumably a
            ``unittest.TestCase`` -- confirm against callers).
        test_name: key passed to ``gradgradcheck_method_precision_override``
            to look up optional per-test tolerances for the second-order check.
        apply_method: callable whose gradients are checked.
        input_variables: inputs forwarded unchanged to both ``gradcheck`` and
            ``gradgradcheck``.
        output_variable: passed to ``generate_gradoutput`` to build the
            gradient output used by ``gradgradcheck``.
    """
    # First-order check.
    test_case.assertTrue(gradcheck(apply_method, input_variables, eps=1e-6, atol=PRECISION))

    # Second-order check. The grad output is requested non-contiguous,
    # presumably so backward code that assumes contiguous memory gets exercised.
    grad_output = generate_gradoutput(output_variable, non_contiguous=True)

    # A registered override (non-None) supplies 'atol' and 'rtol'. Without one,
    # no tolerance kwargs are passed and gradgradcheck's own defaults apply.
    override = gradgradcheck_method_precision_override(test_name)
    tolerance_kwargs = {}
    if override is not None:
        tolerance_kwargs = {'atol': override['atol'], 'rtol': override['rtol']}

    test_case.assertTrue(gradgradcheck(apply_method, input_variables, grad_output, **tolerance_kwargs))
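# 评论列表 ("comment list") -- stray web-page label captured with the code; not part of the program
# 文章目录 ("table of contents") -- stray web-page label captured with the code; not part of the program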