def get_final_shape(data_array, out_dims, direction_to_names):
    """
    Determine the final shape that data_array must be reshaped to in order to
    have one axis for each of the out_dims (for instance, combining all
    axes collected by the '*' direction).

    Args:
        data_array: Object exposing a ``coords`` mapping from dimension name
            to a sized coordinate (presumably an xarray ``DataArray`` --
            confirm against callers).
        out_dims: Iterable of direction keys, one per output axis, in the
            order the axes should appear.
        direction_to_names: Mapping from each direction key to the sequence
            of dimension names that are collapsed into that output axis.

    Returns:
        A list with one entry per element of ``out_dims``: ``1`` when the
        direction has no dimension names, otherwise the product of the
        coordinate lengths of the named dimensions.

    Raises:
        KeyError: If a direction is missing from ``direction_to_names``
            (when it is a plain dict).
    """
    final_shape = []
    for direction in out_dims:
        names = direction_to_names[direction]
        if len(names) == 0:
            # No dimensions map to this direction: keep a length-1 axis so
            # the result still has one axis per entry of out_dims.
            final_shape.append(1)
        else:
            # determine shape once dimensions for direction (usually '*') are
            # combined. np.prod replaces np.product, which was removed in
            # NumPy 2.0.
            final_shape.append(
                np.prod([len(data_array.coords[name]) for name in names]))
    return final_shape
评论列表
文章目录