def test_plugin_interval(self):
    """Check how often a plugin is called for each trigger unit.

    For every entry of ``self.intervals``, ``self.setUp()`` is called again,
    a new ``SimplePlugin`` built from that entry is registered on
    ``self.trainer``, and the trainer is run for ``self.num_epochs`` epochs.
    Each ``num_<unit>`` counter on the plugin is then compared with the
    number of whole periods that fit into the trigger count used for that
    unit; units the interval does not mention are expected to stay at 0.
    """
    for interval in self.intervals:
        # setUp() is re-run by hand for each interval -- presumably to start
        # from a fresh trainer; confirm against the fixture.
        self.setUp()
        simple_plugin = SimplePlugin(interval)
        self.trainer.register_plugin(simple_plugin)
        self.trainer.run(epochs=self.num_epochs)

        # (unit name, trigger count compared against for that unit).
        # A tuple rather than a set keeps the order of the checks, and so
        # the first failure that gets reported, identical on every run.
        units = (
            ('iteration', self.num_iters),
            ('epoch', self.num_epochs),
            ('batch', self.num_iters),
            ('update', self.num_iters),
        )
        for unit, num_triggers in units:
            # Period of the first (period, unit) pair in the interval that
            # names this unit, or None when the unit is not listed.
            call_every = next(
                (period for period, period_unit in interval
                 if period_unit == unit),
                None,
            )
            if call_every:
                # Integer floor division: exact, with no detour through
                # float true division.
                # NOTE(review): assumes periods are integer counts -- confirm.
                expected_num_calls = num_triggers // call_every
            else:
                expected_num_calls = 0
            num_calls = getattr(simple_plugin, 'num_' + unit)
            # The trailing 0 is kept exactly as in the original call; its
            # meaning depends on the TestCase base class in use -- confirm.
            self.assertEqual(num_calls, expected_num_calls, 0)
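# 评论列表 (comment list) -- web-page navigation residue, not part of the code
# 文章目录 (table of contents) -- web-page navigation residue, not part of the code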