def test_soft_copy_param(self):
    """Verify soft_copy_param blends the source link into the target link.

    The expected update rule is ``target = (1 - tau) * target + tau * source``.
    Starting from target W = 0.5 and source W = 1 with tau = 0.1, the first
    call should produce 0.55 and a second call 0.9 * 0.55 + 0.1 = 0.595.
    After every call the source weights must still be exactly 1.0.
    """
    target = L.Linear(1, 5)
    source = L.Linear(1, 5)
    target.W.data[:] = 0.5
    source.W.data[:] = 1

    def check_weights(link, value):
        # Every element of the link's W array should equal ``value``.
        weights = link.W.data
        np.testing.assert_almost_equal(weights, np.full(weights.shape, value))

    # NOTE(review): only the W parameter is asserted here; other parameters
    # of the links are not checked by this test.
    for expected in (0.55, 0.595):
        copy_param.soft_copy_param(
            target_link=target, source_link=source, tau=0.1)
        check_weights(target, expected)
        # The source link's weights must be left unchanged.
        check_weights(source, 1.0)
# Scraped web-page residue (not code): "评论列表" (Comment list),
# "文章目录" (Table of contents)