def mlsl_gather_send(self, gather_send_id, x_nparr):
    """Contribute this process's value to an MLSL gather or reduce.

    The root process (currently hard-coded to index 0) does not send
    anything.  It stores its own contribution on the gather-send op as
    ``arr``.  Every other process:

    * passes its data to ``self.distribution.reduce`` with a SUM reduction
      when the op's ``use_reduce`` flag is set;
    * otherwise passes it to ``self.distribution.gather``.

    It then blocks on ``self.mlsl_obj.wait`` for the returned request.

    Args:
        gather_send_id: key into ``self.gather_send_nodes`` selecting the op.
        x_nparr: value to contribute; scalars are promoted to 1-D arrays.
    """
    op = self.gather_send_nodes[gather_send_id]
    # TODO: derive the real root index instead of assuming process 0.
    root = 0
    # Promote scalars to 1-D so that reducing down to a scalar value still
    # yields an array with a well-defined element count.
    data = np.atleast_1d(x_nparr)

    if self.process_idx == root:
        # TODO: workaround for the non-symmetric case -- the root keeps its
        # own contribution locally on the op instead of sending it.
        op.arr = data
        return

    buf = self.as_buffer(data)
    count = data.size
    # NOTE(review): the datatype is always FLOAT regardless of data.dtype;
    # presumably as_buffer yields float32-compatible data -- confirm.
    if op.use_reduce:
        # The same buffer is passed as both the send and receive buffer.
        request = self.distribution.reduce(
            buf, buf, count,
            mlsl.DataType.FLOAT, mlsl.ReductionType.SUM,
            root, mlsl.GroupType.DATA)
    else:
        # No receive buffer is supplied from a non-root process.
        request = self.distribution.gather(
            buf, count, None,
            mlsl.DataType.FLOAT, root,
            mlsl.GroupType.DATA)
    self.mlsl_obj.wait(request)
评论列表
文章目录