def _check_input_not_mutated(self, func):
    """Wrap ``func`` so every call asserts that its array-like inputs are unchanged.

    Each positional or keyword argument that is a pandas ``NDFrame`` or a
    NumPy ``ndarray`` is copied before ``func`` runs. Afterwards, each
    original object is compared with its pre-call copy using
    ``assert_allclose``. Other arguments are passed through unchecked.
    ``self`` is not used.

    Parameters
    ----------
    func : callable
        The function under test. It does not need a ``__name__`` attribute
        (e.g. a ``functools.partial``); its ``repr`` is used in failure
        messages in that case.

    Returns
    -------
    callable
        A ``functools.wraps``-decorated wrapper. It calls ``func`` with the
        same arguments and returns its result.

    Raises
    ------
    AssertionError
        Raised by the wrapper when an inspected input differs from its
        pre-call copy beyond tolerance. The message names the input
        (``'arg <i>'`` or the keyword name) and the function.
    """
    # Resolve the label once. Reading ``func.__name__`` inside the compare
    # loop crashed with AttributeError for callables that lack it, before
    # any comparison could run.
    func_name = getattr(func, '__name__', repr(func))

    @wraps(func)
    def check_not_mutated(*args, **kwargs):
        # Snapshot array-like inputs as (label, original object, copy):
        # positional arguments first, then keyword arguments.
        snapshots = [
            ('arg %s' % i, arg, arg.copy())
            for i, arg in enumerate(args)
            if isinstance(arg, (NDFrame, np.ndarray))
        ]
        snapshots.extend(
            (name, value, value.copy())
            for name, value in kwargs.items()
            if isinstance(value, (NDFrame, np.ndarray))
        )

        result = func(*args, **kwargs)

        # Absolute tolerance of half a unit in the DECIMAL_PLACES-th decimal
        # place. assert_allclose's default relative tolerance also applies.
        atol = 0.5 * 10 ** (-DECIMAL_PLACES)
        for label, original, snapshot in snapshots:
            assert_allclose(
                original,
                snapshot,
                atol=atol,
                err_msg="Input '%s' mutated by %s" % (label, func_name),
            )
        return result

    return check_not_mutated
评论列表
文章目录