def _get_expanded_coords_data(coords, data, params, broadcast_shape):
    """
    Expand coordinates/data to broadcast_shape. Does most of the heavy lifting for broadcast_to.

    Produces sorted output for sorted inputs.

    Parameters
    ----------
    coords : np.ndarray
        The coordinates to expand: one row per source dimension, one column
        per stored entry.
    data : np.ndarray
        The data corresponding to the coordinates (one value per column of
        ``coords``).
    params : list
        The broadcast parameters, one per dimension of ``broadcast_shape``.
        A truthy value means the output dimension is copied from the next
        source coordinate row; ``None`` means the output dimension has no
        source row; any other falsy value means the source row is consumed
        but the coordinate is generated over the full broadcast length.
    broadcast_shape : tuple[int]
        The shape to broadcast to.

    Returns
    -------
    expanded_coords : np.ndarray
        2-D array with one row of coordinates per dimension of
        ``broadcast_shape``.
    expanded_data : np.ndarray
        The data corresponding to expanded_coords.
    """
    n_entries = coords.shape[1]

    # Build the axes of an index grid. The stored-entry axis sits at the
    # position of the first dimension copied from the source; every other
    # output dimension gets an axis spanning its full broadcast length.
    expand_shapes = []
    entry_axis = None
    grid_axis = []  # grid row feeding each generated output dimension
    for p, size in zip(params, broadcast_shape):
        if p:
            if entry_axis is None:
                entry_axis = len(expand_shapes)
                expand_shapes.append(n_entries)
            grid_axis.append(None)
        else:
            grid_axis.append(len(expand_shapes))
            expand_shapes.append(size)

    if entry_axis is None:
        # No output dimension is copied from the source (e.g. broadcasting a
        # 0-d array or one whose dimensions are all size 1). The entry axis
        # must still exist so every stored entry (possibly none) is repeated
        # over the whole broadcast grid; putting it outermost keeps the
        # generated coordinates in sorted order.
        entry_axis = 0
        expand_shapes.insert(0, n_entries)
        grid_axis = [row + 1 for row in grid_axis]

    # Cartesian product of arange(s) for s in expand_shapes, one row per
    # axis, last axis varying fastest, using the narrowest fitting dtype.
    idx_dtype = np.result_type(*(np.min_scalar_type(s - 1) for s in expand_shapes))
    all_idx = np.indices(expand_shapes, dtype=idx_dtype).reshape(len(expand_shapes), -1)

    # Narrowest dtype able to hold every output coordinate; an empty shape
    # has no coordinate rows, so any integer dtype will do.
    coord_dtypes = [np.min_scalar_type(s - 1) for s in broadcast_shape]
    out_dtype = np.result_type(*coord_dtypes) if coord_dtypes else np.intp

    expanded_coords = np.empty((len(broadcast_shape), all_idx.shape[1]), dtype=out_dtype)
    entry_idx = all_idx[entry_axis]
    expanded_data = data[entry_idx]

    src_dim = 0
    for d, (p, row) in enumerate(zip(params, grid_axis)):
        if p:
            expanded_coords[d] = coords[src_dim, entry_idx]
        else:
            expanded_coords[d] = all_idx[row]
        # Only ``None`` dimensions lack a corresponding source row.
        if p is not None:
            src_dim += 1

    return np.asarray(expanded_coords), np.asarray(expanded_data)
# (c) senderle
# Taken from https://stackoverflow.com/a/11146645/774273
# License: https://creativecommons.org/licenses/by-sa/3.0/
评论列表
文章目录