Source: nutszebra_cifar10_with_multi_gpus.py

def run(self):
    # bind this worker process to its assigned GPU
    dev = cuda.Device(self.device)
    dev.use()
    # build communication via nccl
    self.setup()
    gp = None
    # one data-augmentation object per sample of a full batch
    da_args = [self.da() for _ in six.moves.range(self.batch)]
    # CPU pool that runs the augmentation in parallel
    p = multiprocessing.Pool(self.parallel)
    # split each batch into sub-batches for reducing memory
    batch_of_batch = int(float(self.batch) / self.train_batch_divide)
    while True:
        # wait for a command from the parent process
        job, data = self.pipe.recv()
        if job == 'finalize':
            dev.synchronize()
            break
        if job == 'update':
            # clear gradients before accumulating over sub-batches
            self.model.zerograds()
            # draw one random batch of sample indices
            indices = list(self.sampling.yield_random_batch_samples(1, self.batch, len(self.train_x), sort=False))[0]
            for ii in six.moves.range(0, len(indices), batch_of_batch):
                x = self.train_x[indices[ii:ii + batch_of_batch]]
                t = self.train_y[indices[ii:ii + batch_of_batch]]
                # augment the sub-batch on the CPU pool
                args = list(six.moves.zip(x, t, da_args))
                processed = p.starmap(process_train, args)
                tmp_x, tmp_t = list(zip(*processed))
                train = True
                x = self.model.prepare_input(tmp_x, dtype=np.float32, volatile=not train, gpu=self.device)
                t = self.model.prepare_input(tmp_t, dtype=np.int32, volatile=not train, gpu=self.device)
                y = self.model(x, train=train)
                # scale so the NCCL-summed gradients average over devices and sub-batches
                loss = self.model.calc_loss(y, t) / self.number_of_devices / self.train_batch_divide
                loss.backward()
                # free intermediates as early as possible
                del x
                del t
                del y
                del loss
            # send gradients of self.model: sum them onto the root device (rank 0)
            gg = gather_grads(self.model)
            null_stream = cuda.Stream.null
            self.communication.reduce(gg.data.ptr,
                                      gg.data.ptr,
                                      gg.size,
                                      nccl.NCCL_FLOAT,
                                      nccl.NCCL_SUM,
                                      0,
                                      null_stream.ptr)
            del gg
            self.model.zerograds()
            # receive the updated parameters of self.model, broadcast from rank 0
            gp = gather_params(self.model)
            self.communication.bcast(gp.data.ptr,
                                     gp.size,
                                     nccl.NCCL_FLOAT,
                                     0,
                                     null_stream.ptr)
            scatter_params(self.model, gp)
            gp = None
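
The augmentation step is delegated to process_train through multiprocessing.Pool.starmap, with each CPU worker receiving one image, its label, and one of the pre-built da_args objects. That function is not part of this snippet, so the following is only a minimal sketch, under the assumption that the da object exposes a train(img) method returning the augmented image; that method name is hypothetical.

def process_train(x, t, da):
    # runs inside a CPU worker of multiprocessing.Pool, so it must be
    # picklable and must not touch the GPU
    x = da.train(x)  # hypothetical augmentation call; the real API may differ
    return x, t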
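
gather_grads, gather_params, and scatter_params are also defined elsewhere. Since the NCCL calls above treat the whole model as one contiguous float32 buffer (a single pointer gg.data.ptr / gp.data.ptr and a single size), these helpers presumably flatten every parameter into one GPU array and copy the result back afterwards. A minimal sketch of that idea, assuming model.params() yields parameters in a stable order; it is illustrative, not the file's actual implementation.

import numpy as np
from chainer import cuda


def _flatten(model, attr):
    # concatenate one attribute ('data' or 'grad') of every parameter
    # into a single contiguous float32 array on the current GPU
    xp = cuda.cupy
    return xp.concatenate(
        [xp.asarray(getattr(p, attr), dtype=np.float32).ravel() for p in model.params()])


def gather_grads(model):
    return _flatten(model, 'grad')


def gather_params(model):
    return _flatten(model, 'data')


def scatter_params(model, flat):
    # copy slices of the flat buffer back into each parameter,
    # in the same order used by gather_params
    offset = 0
    for p in model.params():
        size = p.data.size
        p.data[...] = flat[offset:offset + size].reshape(p.data.shape)
        offset += size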
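
For context, the loop is driven entirely over self.pipe: the parent process sends ('update', ...) once per iteration and ('finalize', ...) to shut the worker down. A hedged sketch of that driver side follows; drive_workers, worker_pipes, and n_iterations are hypothetical names that do not come from the file.

def drive_workers(worker_pipes, n_iterations):
    # parent-process side of the pipe protocol (illustrative only);
    # worker_pipes would be the multiprocessing.Pipe ends connected to
    # each worker's self.pipe
    for _ in range(n_iterations):
        for pipe in worker_pipes:
            pipe.send(('update', None))
        # ... the parent accumulates its own gradients, joins the NCCL
        # reduce as rank 0, runs the optimizer step, then broadcasts the
        # updated parameters back to every worker ...
    for pipe in worker_pipes:
        pipe.send(('finalize', None))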