def _eq(x, y):
    """
    Equality comparison for nested data structures with tensors.

    Two values are considered equal only if they have exactly the same type
    (``type(x) is type(y)``) at every level of nesting, and:

    * ``dict``: the key sets match and every value compares equal recursively.
    * ``list`` / ``tuple``: the lengths match and elements compare equal
      pairwise, recursively (same strict type rule as dict values).
    * ``numpy.ndarray`` / ``torch.Tensor``: the shapes match and all elements
      compare equal.
    * ``torch.autograd.Variable``: same rule, applied to ``.data``.
    * anything else: the result of plain ``==``.

    Args:
        x: First value.
        y: Second value.

    Returns:
        ``True``/``False`` for containers, arrays and tensors; for any other
        type, whatever ``x == y`` returns.
    """
    if type(x) is not type(y):
        return False

    if isinstance(x, dict):
        if set(x.keys()) != set(y.keys()):
            return False
        return all(_eq(x_val, y[key]) for key, x_val in x.items())

    if isinstance(x, (list, tuple)):
        # Recurse instead of using list `==`, which would call bool() on an
        # elementwise array/tensor comparison and raise for multi-element ones.
        return len(x) == len(y) and all(_eq(a, b) for a, b in zip(x, y))

    if isinstance(x, (np.ndarray, torch.Tensor)):
        # Check shapes first: `==` broadcasts (so [1, 1, 1] vs [1] would look
        # equal) or raises on incompatible shapes.
        if tuple(x.shape) != tuple(y.shape):
            return False
        # Coerce np.bool_ / 0-d tensor to a plain Python bool.
        return bool((x == y).all())

    if isinstance(x, torch.autograd.Variable):
        # In current PyTorch, Variables are Tensors and are handled above;
        # this branch presumably exists for older versions.
        if tuple(x.data.size()) != tuple(y.data.size()):
            return False
        return bool((x.data == y.data).all())

    return x == y
评论列表
文章目录