test_dot.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def test_flat_tensor_dot_tensor():
    """
    Ensure that a flattened argument axis is not unflattened in the result.

    A ones-filled constant over axes (H, W, C) is flattened with
    ``ng.flatten_at(..., 2)`` and used as the right operand of ``ng.dot``
    against a ones-filled constant over axes (C, K).  The executed result
    is expected to hold ``C.length`` in every position.
    """
    # Axes are created in the same order as before (H, W, C, K) in case the
    # library assigns names/ids by creation order.
    height = ng.make_axis(2)
    width = ng.make_axis(7)
    channel = ng.make_axis(3)
    kernel = ng.make_axis(11)

    # Left-hand tensor: all ones over (H, W, C), then flattened at index 2.
    image_axes = ng.make_axes([height, width, channel])
    image = ng.constant(np.ones(image_axes.lengths), axes=image_axes)
    flat_image = ng.flatten_at(image, 2)

    # Right-hand tensor: all ones over (C, K).
    weight_axes = ng.make_axes([channel, kernel])
    weights = ng.constant(np.ones(weight_axes.lengths), axes=weight_axes)

    product = ng.dot(weights, flat_image)

    with ExecutorFactory() as factory:
        actual = factory.executor(product)()

    # Both operands are all ones, so each output entry is expected to equal
    # the length of the shared axis C.
    expected = np.ones_like(actual) * channel.length
    ng.testing.assert_allclose(actual, expected)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号