nutszebra_ilsvrc_object_localization_with_multi_gpus.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:trainer 作者: nutszebra 项目源码 文件源码
def run(self):
        """Worker loop: serve jobs received over ``self.pipe`` until finalized.

        Binds this process to CUDA device ``self.device``, calls
        ``self.setup()`` (per the original comment, this builds the NCCL
        communication used below as ``self.communication``), creates a
        ``multiprocessing.Pool`` of ``self.parallel_train`` processes for data
        preprocessing, and then blocks on ``self.pipe.recv()`` for
        ``(job, data)`` pairs:

        * ``'update'``: take one set of sample indices from ``self.sampling``,
          preprocess those samples in the pool via ``process_train``, run a
          forward/backward pass on ``self.model``, NCCL-``reduce`` (sum) the
          gathered gradients into root rank 0, then receive the parameters
          NCCL-``bcast`` from rank 0 and write them back into ``self.model``.
        * ``'finalize'``: synchronize the device and leave the loop.

        Any other job name is ignored, and ``data`` is not used here.

        The preprocessing pool is always terminated and joined when the method
        exits, including when an exception propagates out of the loop (the
        original never released it, leaking its worker processes).
        """
        dev = cuda.Device(self.device)
        dev.use()
        # build communication via nccl
        self.setup()
        # All NCCL collectives below are enqueued on the default (null) stream;
        # the lookup is loop-invariant, so do it once.
        null_stream = cuda.Stream.null
        p = multiprocessing.Pool(self.parallel_train)
        try:
            # One self.da() object per sample slot; each is zipped with one
            # (x, t) pair and passed to process_train.
            args_da = [self.da() for _ in six.moves.range(self.batch)]
            while True:
                job, data = self.pipe.recv()
                if job == 'finalize':
                    dev.synchronize()
                    break
                if job != 'update':
                    # Unknown jobs are ignored, as in the original protocol.
                    continue

                # for reducing memory
                self.model.cleargrads()
                # Only the first item yielded is used.  NOTE(review): if the
                # generator yields more than one batch, list() builds all of
                # them just to keep one; kept as-is so the sampler's RNG
                # consumption is unchanged.
                indices = list(self.sampling.yield_random_batch_from_category(1, self.picture_number_at_each_categories, self.batch, shuffle=True))[0]
                x = self.train_x[indices]
                t = self.train_y[indices]
                args = list(zip(x, t, args_da))
                # starmap blocks until every sample is processed, so no pool
                # task is outstanding once this call returns.
                processed = p.starmap(process_train, args)
                tmp_x, tmp_t = list(zip(*processed))
                train = True
                x = self.model.prepare_input(tmp_x, dtype=np.float32, volatile=not train, gpu=self.device)
                t = self.model.prepare_input(tmp_t, dtype=np.int32, volatile=not train, gpu=self.device)
                y = self.model(x, train=train)
                # Presumably scaled so that the gradients summed across devices
                # (and sub-batches) form an average -- TODO confirm with caller.
                loss = self.model.calc_loss(y, t) / self.number_of_devices / self.train_batch_divide
                loss.backward()

                # Drop references to the batch and graph before the
                # collectives to lower peak device memory.
                del x
                del t
                del y
                del loss

                # send gradients of self.model: in-place sum into root rank 0
                # (presumably the master process).  Assumes gather_grads
                # returns a flat float32 device array (NCCL_FLOAT) -- confirm.
                gg = gather_grads(self.model)
                self.communication.reduce(gg.data.ptr,
                                          gg.data.ptr,
                                          gg.size,
                                          nccl.NCCL_FLOAT,
                                          nccl.NCCL_SUM,
                                          0,
                                          null_stream.ptr)
                del gg
                self.model.cleargrads()
                # receive parameters of self.model broadcast from rank 0 and
                # copy them into the local model.
                gp = gather_params(self.model)
                self.communication.bcast(gp.data.ptr,
                                         gp.size,
                                         nccl.NCCL_FLOAT,
                                         0,
                                         null_stream.ptr)
                scatter_params(self.model, gp)
                del gp
        finally:
            # Release the preprocessing workers on every exit path.  On a
            # normal exit nothing is pending (starmap is synchronous), so
            # terminate() is safe; on an error it also stops stray tasks.
            p.terminate()
            p.join()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号