def test_reinforce_check(self):
    """Check the argument validation performed by ``Variable.reinforce``.

    Valid uses: a reward tensor matching the sample's size, or a plain
    scalar reward, applied to the output of ``torch.normal``.

    Rejected uses:
    - reinforcing a non-stochastic Variable (RuntimeError)
    - reinforcing the same sample twice (RuntimeError)
    - a reward tensor of the wrong type, here a LongTensor (TypeError)
    - a reward tensor of the wrong size (ValueError)

    NOTE(review): stochastic Variables and ``reinforce`` were removed in
    PyTorch 0.4, so this test only runs against older releases.
    """
    mean = Variable(torch.randn(5, 5), requires_grad=True)

    # Valid calls: a same-sized tensor reward, then a scalar reward.
    sample = torch.normal(mean)
    sample.reinforce(torch.randn(5, 5))
    sample = torch.normal(mean)
    sample.reinforce(2)

    # ``mean`` is not the output of a stochastic function, so it cannot
    # be reinforced.
    with self.assertRaises(RuntimeError):
        mean.reinforce(2)

    # A stochastic output accepts only one reward.
    sample = torch.normal(mean)
    sample.reinforce(2)
    with self.assertRaises(RuntimeError):
        sample.reinforce(2)

    # A reward tensor of the wrong type is rejected.
    sample = torch.normal(mean)
    with self.assertRaises(TypeError):
        sample.reinforce(torch.randn(5, 5).long())

    # A reward tensor whose size differs from the sample is rejected.
    sample = torch.normal(mean)
    with self.assertRaises(ValueError):
        sample.reinforce(torch.randn(4, 5))
评论列表
文章目录