nutszebra_ilsvrc_object_localization_with_multi_gpus.py source code

python

Project: trainer    Author: nutszebra
def update_core(self, x, t, p, args_da):
    """One data-parallel training step on the master GPU.

    Preprocess the batch in the worker pool, run forward/backward on
    gpus[0], sum gradients across devices with NCCL, update, and
    broadcast the new parameters back to the other GPUs.
    """
    # wake the worker processes so they run their own forward/backward pass
    self._send_message(('update', None))
    with cuda.Device(self.gpus[0]):
        self.model.cleargrads()
        # data augmentation / preprocessing in the multiprocessing pool p
        args = list(zip(x, t, args_da))
        processed = p.starmap(process_train, args)
        tmp_x, tmp_t = list(zip(*processed))
        data_length = len(tmp_x)
        train = True
        x = self.model.prepare_input(tmp_x, dtype=np.float32, volatile=not train, gpu=self.gpus[0])
        t = self.model.prepare_input(tmp_t, dtype=np.int32, volatile=not train, gpu=self.gpus[0])
        y = self.model(x, train=train)
        # divide by the GPU count so the NCCL sum of gradients is their mean
        loss = self.model.calc_loss(y, t) / len(self.gpus)
        loss.backward()
        loss.to_cpu()
        # plain float of the loss for reporting, scaled by the local batch size
        loss = float(loss.data) * data_length

        del x
        del t
        del y

        # NCCL: sum gradients from all GPUs onto the master device
        null_stream = cuda.Stream.null
        if self.communication is not None:
            # pack all gradients into one flat array and reduce in one call
            gg = gather_grads(self.model)
            self.communication.reduce(gg.data.ptr,
                                      gg.data.ptr,
                                      gg.size,
                                      nccl.NCCL_FLOAT,
                                      nccl.NCCL_SUM,
                                      0,
                                      null_stream.ptr)
            # copy the reduced gradients in gg back into self.model
            scatter_grads(self.model, gg)
            del gg
        self.optimizer.update()
        if self.communication is not None:
            # broadcast the updated parameters to the worker GPUs
            gp = gather_params(self.model)
            self.communication.bcast(gp.data.ptr,
                                     gp.size,
                                     nccl.NCCL_FLOAT,
                                     0,
                                     null_stream.ptr)
    return loss
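
The helpers gather_grads, scatter_grads and gather_params are defined elsewhere in the project and are not shown in this excerpt. As a rough idea of what they have to do, here is a minimal sketch, in the spirit of Chainer's MultiprocessParallelUpdater helpers: flatten all gradients into one contiguous cupy array so a single NCCL reduce can sum them, then write the reduced values back. The *_sketch names and the exact packing order are my assumptions, not the project's code.

import numpy as np
import cupy


def gather_grads_sketch(model):
    # Collect every gradient of the link and pack them into one flat float32
    # array on the current GPU, so a single NCCL call can reduce them all.
    grads = [p.grad for p in model.params() if p.grad is not None]
    total = sum(g.size for g in grads)
    packed = cupy.empty(total, dtype=np.float32)
    offset = 0
    for g in grads:
        packed[offset:offset + g.size] = g.ravel()
        offset += g.size
    return packed


def scatter_grads_sketch(model, packed):
    # Write the (already reduced) flat array back into each parameter's grad,
    # walking the parameters in the same order as gather_grads_sketch.
    offset = 0
    for p in model.params():
        if p.grad is None:
            continue
        n = p.grad.size
        p.grad[...] = packed[offset:offset + n].reshape(p.grad.shape)
        offset += n

Packing everything into one buffer is why the method above needs only a single self.communication.reduce call per iteration instead of one per parameter. The worker processes woken by self._send_message(('update', None)) presumably run the same forward/backward pass and join the same reduce with root 0, then receive the broadcast parameters.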