solver.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:deepwater-nae 作者: h2oai 项目源码 文件源码
def start(self, rank):
        """Configure Caffe for this worker and build its Adam solver.

        Steps, in order:

        1. Store ``rank`` on the instance.
        2. If ``self.gpus`` is non-empty, select ``self.gpus[rank]`` as the
           device and put Caffe in GPU / multi-process mode using
           ``self.size`` and ``self.rank``; otherwise switch Caffe to CPU mode.
        3. Obtain the solver definition: parsed from the file named by
           ``self.cmd.graph`` when that name ends with ``.json``, otherwise
           produced by ``self.solver_graph()``.
        4. Write the definition to a temporary file, construct
           ``caffe.AdamSolver`` from it, and remove the temporary file.
        5. If ``self.uid`` is truthy, attach an NCCL object to the solver.

        Args:
            rank: Index of this worker; used to index ``self.gpus`` and
                passed to ``caffe.set_solver_rank``.
        """
        self.rank = rank

        if len(self.gpus) > 0:
            self.device = self.gpus[rank]
            # `debug` is not defined in this block; presumably a module-level
            # flag -- confirm against the rest of the file.
            if debug:
                s = 'solver gpu %d' % self.gpus[self.rank] + \
                    ' pid %d' % os.getpid() + ' size %d' % self.size + \
                    ' rank %d' % self.rank
                print(s, file = sys.stderr)
            caffe.set_mode_gpu()
            caffe.set_device(self.device)
            caffe.set_solver_count(self.size)
            caffe.set_solver_rank(self.rank)
            caffe.set_multiprocess(True)
        else:
            print('solver cpu', file = sys.stderr)
            caffe.set_mode_cpu()

        if self.cmd.graph.endswith('.json'):
            # NOTE(review): despite the '.json' suffix the content is parsed
            # with protobuf text_format, not a JSON parser -- confirm the
            # expected on-disk format with the producer of this file.
            with open(self.cmd.graph, mode = 'r') as f:
                graph = caffe_pb2.SolverParameter()
                text_format.Merge(f.read(), graph)
                self.graph = graph
        else:
            self.graph = self.solver_graph()

        import tempfile
        # The solver is constructed from a file path, so the definition is
        # serialised to disk first. delete=False keeps the file around after
        # the `with` block closes it, so it can be reopened by path.
        with tempfile.NamedTemporaryFile(mode = 'w+', delete = False) as f:
            text_format.PrintMessage(self.graph, f)
            tmp = f.name
        try:
            self.caffe = caffe.AdamSolver(tmp)
        finally:
            # Previously this file was never removed and accumulated in the
            # temp directory. Cleanup is best-effort: a failed removal must
            # not mask a solver construction error or abort a valid solver.
            try:
                os.remove(tmp)
            except OSError:
                pass

        if self.uid:
            self.nccl = caffe.NCCL(self.caffe, self.uid)
            self.nccl.bcast()
            self.caffe.add_callback(self.nccl)
            if self.caffe.param.layer_wise_reduce:
                self.caffe.net.after_backward(self.nccl)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号