test_utils.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def test_np_layout_shuffle():
    # set up
    bsz = 8
    C, H, W, N = 3, 28, 28, bsz
    C, R, S, K = 3, 5, 5, 32

    # image dim-shuffle
    np_tf_image = np.random.randn(N, H, W, C)
    np_ng_image = np_layout_shuffle(np_tf_image, "NHWC", "CDHWN")
    np_tf_image_reverse = np_layout_shuffle(np_ng_image, "CDHWN", "NHWC")
    assert np.array_equal(np_tf_image, np_tf_image_reverse)

    # filter dim-shuffle
    np_tf_weight = np.random.randn(R, S, C, K)
    np_ng_weight = np_layout_shuffle(np_tf_weight, "RSCK", "CTRSK")
    np_tf_weight_reverse = np_layout_shuffle(np_ng_weight, "CTRSK", "RSCK")
    assert np.array_equal(np_tf_weight, np_tf_weight_reverse)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号