hetr_server.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def FeedInput(self, request, context):
        """Run a registered computation on the inputs carried by ``request``.

        Each entry of ``request.values`` is decoded (scalar entries via
        ``protobuf_scalar_to_python``, all others via ``pb_to_tensor``) and
        the decoded values are passed positionally to the computation stored
        under ``request.comp_id``. The outputs are stored in
        ``self.results[request.comp_id]``.

        Args:
            request: message exposing ``comp_id`` and an iterable ``values``
                whose items support ``HasField('scalar')``.
            context: not used by this method.

        Returns:
            ``hetr_pb2.FeedInputReply`` with ``status=True`` on success, or
            ``status=False`` and a ``message`` (unknown-id text or the
            formatted traceback) on failure. Exceptions are reported through
            the reply rather than raised.
        """
        logger.debug("server: feed_input")
        if request.comp_id not in self.computations:
            message = 'unknown computation id {}'.format(request.comp_id)
            return hetr_pb2.FeedInputReply(status=False, message=message)

        try:
            values = [
                protobuf_scalar_to_python(v.scalar) if v.HasField('scalar')
                else pb_to_tensor(v.tensor)
                for v in request.values
            ]
            computation = self.computations[request.comp_id]
            if self.transformer.transformer_name == "gpu":
                import pycuda.driver as drv
                runtime = self.transformer.runtime
                # Make the transformer's CUDA context current on this thread
                # if it is not already.
                if runtime and not runtime.ctx == drv.Context.get_current():
                    runtime.ctx.push()
                # TODO figure out doc for rpdb to pass in port
                # give unique port per device (4444 + device_id)
                try:
                    outputs = computation(*values)
                finally:
                    # Pop even when the computation raises so a failed call
                    # does not leave the context on this thread's stack, and
                    # skip the pop when there is no runtime to pop from.
                    # NOTE(review): as before, the pop is issued even if no
                    # push happened above -- confirm this matches how the
                    # runtime expects its context stack to be balanced.
                    if runtime:
                        runtime.ctx.pop()
            else:
                outputs = computation(*values)

            self.results[request.comp_id] = outputs

            return hetr_pb2.FeedInputReply(status=True)
        except Exception:
            # Report any failure (decoding, execution, context handling) to
            # the caller as a failed reply carrying the traceback text.
            return hetr_pb2.FeedInputReply(status=False, message=traceback.format_exc())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号