def _update_adabn(self, eval_data):
    '''Refresh the moving mean and moving var using the evaluation data.

    Makes ``self.num_adabn_epoch`` passes over ``eval_data``, forwarding
    every batch with ``is_train=True`` and blocking on each output, then
    logs the elapsed wall-clock time at debug level.

    :param eval_data: resettable iterable of evaluation batches. It is
        wrapped in ``self._restore_eval_data`` while the update runs
        (presumably so its state is put back afterwards -- confirm against
        that helper).
    '''
    from time import time as now

    began = now()
    with self._restore_eval_data(eval_data):
        for _epoch in range(self.num_adabn_epoch):
            # Rewind the iterator so each epoch sees the full data set.
            eval_data.reset()
            for batch in eval_data:
                # Forward in training mode; presumably this is what makes the
                # batch-norm layers update their moving statistics.
                self.forward(batch, is_train=True)
                # Block until every output is readable. Without this wait a
                # memory leak was observed (it does not keep growing once
                # this update has finished).
                # TODO: fixme
                # NOTE: a disabled earlier variant waited on the 'moving'
                # entries of self._exec_group.aux_arrays instead.
                for output in self.get_outputs():
                    output.wait_to_read()
    elapsed = now() - began
    logger.debug(
        'AdaBN with {} epochs takes {} seconds',
        self.num_adabn_epoch,
        elapsed
    )
# (web page residue, not code) Comment list
# (web page residue, not code) Article table of contents