hetr.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def mlsl_gather_send(self, gather_send_id, x_nparr):
        """Hand this process's data to the gather-send op ``gather_send_id``.

        The op is looked up in ``self.gather_send_nodes``. On the root
        process the (at least 1-d) array is stored directly on the op as
        ``op.arr``. On any other process the array is wrapped with
        ``self.as_buffer`` and submitted through ``self.distribution`` --
        as a SUM reduce when ``op.use_reduce`` is truthy, otherwise as a
        gather -- and the resulting request is passed to
        ``self.mlsl_obj.wait``.

        Args:
            gather_send_id: Key of the op in ``self.gather_send_nodes``.
            x_nparr: Value to send; scalars are promoted to 1-d arrays.

        Returns:
            None.
        """
        op = self.gather_send_nodes[gather_send_id]

        # todo: get real root_idx
        root = 0

        # Promote scalars to 1-d so reductions down to a single value still
        # yield an array with a usable ``size``/buffer.
        data = np.atleast_1d(x_nparr)

        if self.process_idx == root:
            # todo: remove that workaround for non-symmetric case
            op.arr = data
            return

        payload = self.as_buffer(data)
        n_elems = data.size

        if op.use_reduce:
            # The same buffer is passed as both send and receive argument.
            request = self.distribution.reduce(
                payload, payload, n_elems,
                mlsl.DataType.FLOAT, mlsl.ReductionType.SUM,
                root, mlsl.GroupType.DATA)
        else:
            # No receive buffer is supplied on the sending side.
            request = self.distribution.gather(
                payload, n_elems, None,
                mlsl.DataType.FLOAT, root,
                mlsl.GroupType.DATA)

        self.mlsl_obj.wait(request)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号