test_copy_param.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:chainerrl 作者: chainer 项目源码 文件源码
def test_soft_copy_param(self):
        a = L.Linear(1, 5)
        b = L.Linear(1, 5)

        a.W.data[:] = 0.5
        b.W.data[:] = 1

        # a = (1 - tau) * a + tau * b
        copy_param.soft_copy_param(target_link=a, source_link=b, tau=0.1)

        np.testing.assert_almost_equal(a.W.data, np.full(a.W.data.shape, 0.55))
        np.testing.assert_almost_equal(b.W.data, np.full(b.W.data.shape, 1.0))

        copy_param.soft_copy_param(target_link=a, source_link=b, tau=0.1)

        np.testing.assert_almost_equal(
            a.W.data, np.full(a.W.data.shape, 0.595))
        np.testing.assert_almost_equal(b.W.data, np.full(b.W.data.shape, 1.0))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号