def test_RNGState(self):
    """Check that restoring a saved RNG state replays the same random draws.

    The state returned by ``torch.get_rng_state()`` is captured, 1000 samples
    are drawn, the captured state is restored with ``torch.set_rng_state()``
    and 1000 samples are drawn again; both draws must be equal.  The test also
    checks that drawing samples did not modify the captured state tensor.
    """
    state = torch.get_rng_state()
    state_cloned = state.clone()
    before = torch.rand(1000)

    # Count the elements that differ between the captured state and the clone
    # taken before drawing; the draw must leave the captured tensor untouched.
    # NOTE(review): the trailing positional 0 is presumably the tolerance
    # argument of this suite's assertEqual -- confirm against the base class.
    self.assertEqual(state.ne(state_cloned).long().sum(), 0, 0)

    # Rewind the generator to the captured state and draw again.
    torch.set_rng_state(state)
    after = torch.rand(1000)
    self.assertEqual(before, after, 0)
评论列表
文章目录