nutszebra_cifar10_with_multi_gpus.py source code

python

Project: trainer · Author: nutszebra
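Below is the worker-process loop from a multi-GPU CIFAR-10 trainer built on Chainer, CuPy, and NCCL. Each worker pins itself to one GPU, receives jobs from the master over a multiprocessing pipe, trains on sub-batches with parallel data augmentation, reduces its gradients onto the root device, and scatters the parameters broadcast back from the root into its local copy of the model. Hedged sketches of the helper pieces it relies on follow after the code.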
def run(self):
        # pin this worker process to its assigned GPU
        dev = cuda.Device(self.device)
        dev.use()
        # build communication via nccl
        self.setup()
        gp = None
        # one set of data-augmentation arguments per sample in a full batch
        da_args = [self.da() for _ in six.moves.range(self.batch)]
        # CPU worker pool for parallel data augmentation
        p = multiprocessing.Pool(self.parallel)
        # split each batch into sub-batches to bound peak GPU memory
        batch_of_batch = int(float(self.batch) / self.train_batch_divide)
        while True:
            # block until the master process sends the next job over the pipe
            job, data = self.pipe.recv()
            if job == 'finalize':
                dev.synchronize()
                break
            if job == 'update':
                # for reducing memory
                self.model.zerograds()
                # draw one random batch of sample indices from the training set
                indices = list(self.sampling.yield_random_batch_samples(1, self.batch, len(self.train_x), sort=False))[0]
                for ii in six.moves.range(0, len(indices), batch_of_batch):
                    x = self.train_x[indices[ii:ii + batch_of_batch]]
                    t = self.train_y[indices[ii:ii + batch_of_batch]]
                    # augment the sub-batch in parallel on the CPU pool
                    # (zip truncates da_args to the sub-batch length)
                    args = list(six.moves.zip(x, t, da_args))
                    processed = p.starmap(process_train, args)
                    tmp_x, tmp_t = list(zip(*processed))
                    train = True
                    x = self.model.prepare_input(tmp_x, dtype=np.float32, volatile=not train, gpu=self.device)
                    t = self.model.prepare_input(tmp_t, dtype=np.int32, volatile=not train, gpu=self.device)
                    y = self.model(x, train=train)
                    # scale the loss so gradients summed over all devices and
                    # sub-batches reproduce the full-batch average
                    loss = self.model.calc_loss(y, t) / self.number_of_devices / self.train_batch_divide
                    # gradients accumulate across sub-batch iterations
                    loss.backward()
                    # free intermediate variables to keep GPU memory low
                    del x
                    del t
                    del y
                    del loss

                # pack all gradients into one contiguous array and sum them
                # onto the root device (rank 0) with a single nccl reduce
                gg = gather_grads(self.model)
                null_stream = cuda.Stream.null
                self.communication.reduce(gg.data.ptr,
                                          gg.data.ptr,
                                          gg.size,
                                          nccl.NCCL_FLOAT,
                                          nccl.NCCL_SUM,
                                          0,
                                          null_stream.ptr)
                del gg
                # clear local gradients; the root device applies the optimizer step
                self.model.zerograds()
                # receive the updated parameters broadcast by the root device
                gp = gather_params(self.model)
                self.communication.bcast(gp.data.ptr,
                                         gp.size,
                                         nccl.NCCL_FLOAT,
                                         0,
                                         null_stream.ptr)
                # unpack the received buffer back into the model's parameters
                scatter_params(self.model, gp)
                gp = None
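The code above leans on gather_grads, gather_params, and scatter_params, which are defined elsewhere in the project. A minimal sketch of what such helpers might look like with CuPy, assuming every parameter is float32 and that model.params() yields parameters in the same order on every device (which NCCL requires for the buffers to line up):

import numpy as np
import cupy

def _pack(model, attr):
    # concatenate one attribute ('data' or 'grad') of every parameter
    # into a single contiguous float32 buffer on the current GPU
    arrays = [getattr(param, attr) for param in model.params()]
    buf = cupy.empty(sum(a.size for a in arrays), dtype=np.float32)
    offset = 0
    for a in arrays:
        buf[offset:offset + a.size] = a.ravel()
        offset += a.size
    return buf

def gather_grads(model):
    return _pack(model, 'grad')

def gather_params(model):
    return _pack(model, 'data')

def scatter_params(model, buf):
    # inverse of gather_params: copy slices of the flat buffer back
    # into each parameter array, in the same iteration order
    offset = 0
    for param in model.params():
        n = param.data.size
        param.data[...] = buf[offset:offset + n].reshape(param.data.shape)
        offset += n

Packing everything into one buffer matters here: one reduce and one bcast per step is far cheaper than issuing a separate NCCL call per parameter tensor.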
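process_train is the augmentation worker handed to Pool.starmap above; it must live at module scope so multiprocessing can pickle it. Its real definition is in the project; a hypothetical minimal version, assuming each entry of da_args is a callable that transforms one image:

def process_train(x, t, da):
    # run this sample's augmentation pipeline on the CPU and
    # return the transformed image with its unchanged label
    return da(x), t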
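For context, the other end of the pipe protocol looks roughly like this. This is a hypothetical sketch, not the project's actual driver; worker_pipes and optimizer are illustrative assumptions. The root device participates in the same NCCL reduce as rank 0, steps the optimizer on the summed gradients, and its subsequent broadcast is what each worker's bcast/scatter_params pair picks up:

def train_loop(worker_pipes, optimizer, steps):
    for _ in range(steps):
        # ask every worker to compute gradients for one batch
        for pipe in worker_pipes:
            pipe.send(('update', None))
        # meanwhile the root device runs its own forward/backward pass,
        # joins the nccl reduce as rank 0, and steps the optimizer on the
        # summed (pre-scaled) gradients before broadcasting parameters
        optimizer.update()
    # tell every worker to synchronize its GPU and exit
    for pipe in worker_pipes:
        pipe.send(('finalize', None))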