test_compare.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:mobula 作者: wkcn 项目源码 文件源码
def test_compare():
    N, C, H, W = 2,3,4,5
    a = np.random.random((N, C, H, W))
    b = np.random.random((N, C, H, W))
    a_copy = a.copy()
    [la, lb, lac] = L.Data([a,b,a_copy]) 

    ops = [operator.ge, operator.gt, operator.le, operator.lt]
    for op in ops:
        l = op(la, lb)
        l2 = op(la, lac)
        assert np.allclose(l.eval(), op(a,b))
        assert np.allclose(l2.eval(), op(a,a_copy))
        l.backward()
        l2.backward()
        assert np.allclose(l.dX[0], np.zeros(l.X[0].shape)) 
        assert np.allclose(l.dX[1], np.zeros(l.X[1].shape)) 
        assert np.allclose(l2.dX[0], np.zeros(l2.X[0].shape)) 
        assert np.allclose(l2.dX[1], np.zeros(l2.X[1].shape))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号