test_objectives.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:NumpyDL 作者: oujago 项目源码 文件源码
def test_MeanSquaredError():
    from npdl.objectives import MeanSquaredError

    obj = MeanSquaredError()

    outputs = np.random.rand(10, 20)
    targets = np.random.rand(10, 20)

    f_res = obj.forward(outputs, targets)
    b_res = obj.backward(outputs, targets)

    assert np.ndim(f_res) == 0
    assert np.ndim(b_res) == 2
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号