def create_updater(train_iter, optimizer, devices):
    """Build a training updater appropriate for the requested devices.

    Selection rules:
      * one device (``len(devices) <= 1``): ``training.StandardUpdater``
        running on ``devices['main']``;
      * several devices and ``HAVE_NCCL`` is truthy:
        ``training.updaters.MultiprocessParallelUpdater``;
      * several devices without NCCL: ``training.ParallelUpdater``, after
        dividing ``optimizer.lr`` by the number of devices (in place).

    Args:
        train_iter: Training iterator, forwarded unchanged to the updater.
        optimizer: Optimizer forwarded to the updater. On the non-NCCL
            multi-device path its ``lr`` attribute is mutated in place.
        devices: Device specification. Must support ``len()``; on the
            single-device path it must also provide a ``'main'`` key.

    Returns:
        The newly constructed updater.
    """
    # Read the NCCL availability flag before anything else, mirroring the
    # original condition order.
    nccl_available = HAVE_NCCL
    device_count = len(devices)

    # Guard clause: a single device needs no data parallelism.
    if device_count <= 1:
        return training.StandardUpdater(
            train_iter, optimizer, device=devices['main'])

    if nccl_available:
        # NOTE(review): Chainer's MultiprocessParallelUpdater is documented
        # as taking a *list* of iterators (one per device), while the other
        # updaters take a single iterator. Confirm that the caller passes a
        # `train_iter` compatible with every branch.
        # NOTE(review): the learning rate is left unscaled on this path but
        # scaled below. Confirm this asymmetry is intentional.
        return training.updaters.MultiprocessParallelUpdater(
            train_iter, optimizer, devices=devices)

    # Presumably this compensates for ParallelUpdater accumulating the
    # per-device gradients into the main model (a sum, not a mean).
    # Verify against the Chainer version in use. Mutates the caller's
    # optimizer.
    optimizer.lr /= device_count
    return training.ParallelUpdater(
        train_iter, optimizer, devices=devices)
评论列表
文章目录