def dot(a, b):
    """Multiply two 2-D distributed arrays block by block.

    Args:
        a: Left operand. Must expose ``ndim``, ``shape`` and a 2-D indexable
            ``objectids`` grid of block references (presumably a
            ``DistArray`` — confirm against callers).
        b: Right operand, with the same interface as ``a``.

    Returns:
        A new ``DistArray`` of shape ``[a.shape[0], b.shape[1]]``. Its block
        ``(i, j)`` is set to ``blockwise_dot.remote(...)`` called with the
        blocks of row ``i`` of ``a.objectids`` followed by the blocks of
        column ``j`` of ``b.objectids``.

    Raises:
        ValueError: If either operand is not 2-dimensional, or if
            ``a.shape[1] != b.shape[0]``. ``ValueError`` subclasses
            ``Exception``, so existing ``except Exception`` handlers still
            catch it.
    """
    # Validate rank first (a before b), then the inner dimensions.
    for name, arr in (("a", a), ("b", b)):
        if arr.ndim != 2:
            raise ValueError(
                "dot expects its arguments to be 2-dimensional, but "
                "{}.ndim = {}.".format(name, arr.ndim))
    if a.shape[1] != b.shape[0]:
        raise ValueError(
            "dot expects a.shape[1] to equal b.shape[0], but "
            "a.shape = {} and b.shape = {}.".format(a.shape, b.shape))

    shape = [a.shape[0], b.shape[1]]
    result = DistArray(shape)
    # NOTE(review): assumes a's block grid has as many block columns as b's
    # has block rows (i.e. both arrays use the same block size) — TODO
    # confirm in DistArray; otherwise the blocks below are mis-paired.
    for i, j in np.ndindex(*result.num_blocks):
        # Row i of a's blocks and column j of b's blocks are passed to
        # blockwise_dot as one flat argument list, row blocks first.
        row_blocks = list(a.objectids[i, :])
        col_blocks = list(b.objectids[:, j])
        result.objectids[i, j] = blockwise_dot.remote(*(row_blocks + col_blocks))
    return result
评论列表
文章目录