def test_MeanSquaredError():
    """Check the dimensionality of MeanSquaredError's forward/backward results.

    Two random arrays of shape (10, 20) are used as outputs and targets.
    The forward result must be 0-dimensional and the backward result must
    be 2-dimensional.
    """
    from npdl.objectives import MeanSquaredError

    loss = MeanSquaredError()

    # Both inputs share one shape; outputs are drawn before targets.
    shape = (10, 20)
    predictions = np.random.rand(*shape)
    expected = np.random.rand(*shape)

    forward_value = loss.forward(predictions, expected)
    backward_value = loss.backward(predictions, expected)

    assert np.ndim(forward_value) == 0
    assert np.ndim(backward_value) == 2
# Comment list
# Article table of contents