def assert_allclose(x, y, rtol=1e-10, atol=1e-8, equal_nan=True):
    """Like `numpy.testing.assert_allclose`, but report which elements do not match.

    Elements are compared with ``np.isclose``, i.e. ``|x - y| <= atol + rtol * |y|``.

    Parameters
    ----------
    x, y : array_like
        Actual and desired values. Scalars, 0-d arrays, lists and ndarrays are accepted.
    rtol, atol : float
        Relative and absolute tolerance.
    equal_nan : bool, optional
        If True (the default, matching ``numpy.testing.assert_allclose``), NaNs in the
        same position of ``x`` and ``y`` compare equal.

    Raises
    ------
    AssertionError
        If the shapes differ, or if any element is outside the tolerance. The message
        gives the number and percentage of mismatches, their indices (flat indices for
        1-D input, one ``(i, j, ...)`` row per mismatch otherwise), and the offending
        values of ``x`` and ``y``.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if x.ndim == 0 and y.ndim == 0:
        # Two scalars: there is nothing to index, so numpy's own message suffices.
        return np.testing.assert_allclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan)
    if x.shape != y.shape:
        raise AssertionError("Shape mismatch: %s vs %s" % (str(x.shape), str(y.shape)))
    mismatch = ~np.isclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan)
    if np.any(mismatch):
        # np.where(mask)[0] alone would only give the first-axis index for N-d input,
        # so report full coordinates there.
        if x.ndim == 1:
            where = np.flatnonzero(mismatch)
        else:
            where = np.argwhere(mismatch)
        n_miss = int(np.count_nonzero(mismatch))
        raise AssertionError(
            "Mismatch of %d elements (%g %%) at the level of rtol=%g, atol=%g\n%s\n%s\n%s"
            % (
                n_miss,
                100.0 * n_miss / x.size,  # a percentage, as the "%" label promises
                rtol,
                atol,
                repr(where),
                str(x[mismatch]),
                str(y[mismatch]),
            )
        )
评论列表
文章目录