def test_log_epoch_empty_log(self):
    """Check that logging an epoch with nothing recorded still writes the .npz files.

    Creates a ``logger.Logger`` and calls ``log_epoch(epoch=0)`` without
    logging anything first. It then asserts that ``actions.npz``,
    ``rewards.npz`` and ``losses.npz`` exist in the logger's ``log_dir``.
    The directory is removed afterwards whether or not the assertions pass.
    """
    test_logger = logger.Logger(agent_name='test')
    test_logger.log_epoch(epoch=0)
    log_dir = test_logger.log_dir
    # Register the removal as a cleanup instead of calling rmtree after the
    # asserts: a failing assertion would otherwise leave the directory behind.
    self.addCleanup(shutil.rmtree, log_dir)
    for filename in ('actions.npz', 'rewards.npz', 'losses.npz'):
        path = os.path.join(log_dir, filename)
        self.assertTrue(os.path.isfile(path), 'expected file missing: ' + path)
# class TestMovingAverage(unittest.TestCase):
# def test_moving_average_single_item_window(self):
# arr = [1,2,3]
# actual = logger.moving_average(arr, 1)
# self.assertSequenceEqual(actual, arr)
# def test_moving_average_small_window(self):
# arr = [1,2,3,4,5,6,7]
# actual = logger.moving_average(arr, 2)
# expected = [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5]
# self.assertSequenceEqual(actual, expected)
# def test_moving_average_small_window_large_variance(self):
# arr = [0,9,0,9,0]
# actual = logger.moving_average(arr, 3)
# expected = [3, 3, 6, 3, 3]
# self.assertSequenceEqual(actual, expected)
# def test_moving_average_large_window_large_variance(self):
# arr = [0,9,0,9,0]
# actual = logger.moving_average(arr, 4)
# expected = [2.25, 2.25, 4.5, 4.5, 2.25]
# self.assertSequenceEqual(actual, expected)
评论列表
文章目录