def test_tril_triu_ndim3():
for dtype in np.typecodes['AllFloat'] + np.typecodes['AllInteger']:
a = np.array([
[[1, 1], [1, 1]],
[[1, 1], [1, 0]],
[[1, 1], [0, 0]],
], dtype=dtype)
a_tril_desired = np.array([
[[1, 0], [1, 1]],
[[1, 0], [1, 0]],
[[1, 0], [0, 0]],
], dtype=dtype)
a_triu_desired = np.array([
[[1, 1], [0, 1]],
[[1, 1], [0, 0]],
[[1, 1], [0, 0]],
], dtype=dtype)
a_triu_observed = np.triu(a)
a_tril_observed = np.tril(a)
yield assert_array_equal, a_triu_observed, a_triu_desired
yield assert_array_equal, a_tril_observed, a_tril_desired
yield assert_equal, a_triu_observed.dtype, a.dtype
yield assert_equal, a_tril_observed.dtype, a.dtype
评论列表
文章目录