def FeedInput(self, request, context):
    """Run a registered computation on the values carried by ``request``.

    Looks up the computation stored under ``request.comp_id``, converts each
    entry of ``request.values`` (via ``protobuf_scalar_to_python`` when it has
    a ``scalar`` field set, otherwise via ``pb_to_tensor`` on its ``tensor``),
    calls the computation with those values and stores its outputs in
    ``self.results[request.comp_id]``.

    Args:
        request: message with a ``comp_id`` and a repeated ``values`` field.
        context: RPC context argument; not used by this method.

    Returns:
        ``hetr_pb2.FeedInputReply`` with ``status=True`` on success. On
        failure, ``status=False`` and ``message`` holds either an
        "unknown computation id" note or the formatted traceback of the
        exception raised during conversion or computation.
    """
    logger.debug("server: feed_input")
    if request.comp_id not in self.computations:
        message = 'unknown computation id {}'.format(request.comp_id)
        return hetr_pb2.FeedInputReply(status=False, message=message)
    try:
        values = [
            protobuf_scalar_to_python(v.scalar) if v.HasField('scalar')
            else pb_to_tensor(v.tensor)
            for v in request.values
        ]
        computation = self.computations[request.comp_id]
        if self.transformer.transformer_name == "gpu":
            # Imported here so that only the GPU path needs pycuda.
            import pycuda.driver as drv
            runtime = self.transformer.runtime
            pushed = False
            if runtime and not runtime.ctx == drv.Context.get_current():
                # Activate the runtime's CUDA context if it is not current.
                runtime.ctx.push()
                pushed = True
            try:
                # TODO figure out doc for rpdb to pass in port
                # give unique port per device (4444 + device_id)
                outputs = computation(*values)
            finally:
                # Pop only the context pushed above, and do so even if the
                # computation raised, so the context stack stays balanced.
                # Previously this popped unconditionally: it crashed when
                # runtime was unset and popped a context it never pushed.
                if pushed:
                    runtime.ctx.pop()
        else:
            outputs = computation(*values)
        self.results[request.comp_id] = outputs
        return hetr_pb2.FeedInputReply(status=True)
    except Exception:
        # RPC boundary: report any failure to the caller instead of raising.
        return hetr_pb2.FeedInputReply(status=False, message=traceback.format_exc())
评论列表
文章目录