def test_dim_reduction(self):
dim_red_fns = [
"mean", "median", "mode", "norm", "prod",
"std", "sum", "var", "max", "min"]
def normfn_attr(t, dim, keepdim=True):
attr = getattr(torch, "norm")
return attr(t, 2, dim, keepdim)
for fn_name in dim_red_fns:
x = torch.randn(3, 4, 5)
fn_attr = getattr(torch, fn_name) if fn_name != "norm" else normfn_attr
def fn(t, dim, keepdim=True):
ans = fn_attr(x, dim, keepdim)
return ans if not isinstance(ans, tuple) else ans[0]
dim = random.randint(0, 2)
self.assertEqual(fn(x, dim, False).unsqueeze(dim), fn(x, dim))
self.assertEqual(x.ndimension() - 1, fn(x, dim, False).ndimension())
self.assertEqual(x.ndimension(), fn(x, dim, True).ndimension())
# check 1-d behavior
x = torch.randn(1)
dim = 0
self.assertEqual(fn(x, dim), fn(x, dim, True))
self.assertEqual(x.ndimension(), fn(x, dim).ndimension())
self.assertEqual(x.ndimension(), fn(x, dim, True).ndimension())
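# "评论列表" ("Comment list") and "文章目录" ("Table of contents"):
# these appear to be navigation labels left over from a scraped web page.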