test_multiarray.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:radar 作者: amoose136 项目源码 文件源码
def test_swapaxes(self):
        a = np.arange(1*2*3*4).reshape(1, 2, 3, 4).copy()
        idx = np.indices(a.shape)
        assert_(a.flags['OWNDATA'])
        b = a.copy()
        # check exceptions
        assert_raises(ValueError, a.swapaxes, -5, 0)
        assert_raises(ValueError, a.swapaxes, 4, 0)
        assert_raises(ValueError, a.swapaxes, 0, -5)
        assert_raises(ValueError, a.swapaxes, 0, 4)

        for i in range(-4, 4):
            for j in range(-4, 4):
                for k, src in enumerate((a, b)):
                    c = src.swapaxes(i, j)
                    # check shape
                    shape = list(src.shape)
                    shape[i] = src.shape[j]
                    shape[j] = src.shape[i]
                    assert_equal(c.shape, shape, str((i, j, k)))
                    # check array contents
                    i0, i1, i2, i3 = [dim-1 for dim in c.shape]
                    j0, j1, j2, j3 = [dim-1 for dim in src.shape]
                    assert_equal(src[idx[j0], idx[j1], idx[j2], idx[j3]],
                                 c[idx[i0], idx[i1], idx[i2], idx[i3]],
                                 str((i, j, k)))
                    # check a view is always returned, gh-5260
                    assert_(not c.flags['OWNDATA'], str((i, j, k)))
                    # check on non-contiguous input array
                    if k == 1:
                        b = c
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号