def verify_solve_grad(self, m, n, A_structure, lower, rng):
    """Run ``utt.verify_grad`` on a ``Solve`` op over a random linear system.

    Builds a random ``m x m`` coefficient matrix and a random right-hand
    side, then passes both to ``utt.verify_grad`` together with a ``Solve``
    op configured from ``A_structure`` and ``lower``.

    Parameters
    ----------
    m : int
        Order of the square coefficient matrix.
    n : int or None
        Number of right-hand-side columns. ``None`` draws a 1-D vector of
        length ``m`` instead of an ``(m, n)`` matrix.
    A_structure : str
        Forwarded to ``Solve``. ``'lower_triangular'`` and
        ``'upper_triangular'`` also zero out the opposite triangle of the
        generated matrix; any other value leaves the matrix dense.
    lower
        Forwarded unchanged to ``Solve``.
    rng
        Random generator exposing ``normal(size=...)``. It is also passed
        on to ``utt.verify_grad``.
    """
    # Draw A before b: the order of RNG draws determines the test data.
    noise = rng.normal(size=(m, m))
    # Half-scale noise added to the identity keeps the diagonal entries
    # relatively large, to avoid numerical precision issues.
    a_mat = (0.5 * noise + numpy.eye(m)).astype(config.floatX)

    # Triangular structures keep only the matching triangle of A.
    triangularize = {
        'lower_triangular': numpy.tril,
        'upper_triangular': numpy.triu,
    }.get(A_structure)
    if triangularize is not None:
        a_mat = triangularize(a_mat)

    rhs_shape = m if n is None else (m, n)
    rhs = rng.normal(size=rhs_shape).astype(config.floatX)

    # Only float64 gets an explicit eps (presumably the finite-difference
    # step); other dtypes pass None through to verify_grad.
    eps = 2e-8 if config.floatX == "float64" else None

    op = Solve(A_structure=A_structure, lower=lower)
    # The positional 3 is presumably the number of random test points
    # (n_tests) -- confirm against utt.verify_grad's signature.
    utt.verify_grad(op, [a_mat, rhs], 3, rng, eps=eps)
评论列表
文章目录