def _window_overlap(sl, size):
    """Return ``(src_lo, src_hi, dst_lo, dst_hi)`` for one axis of a padded crop.

    ``src_lo:src_hi`` indexes the source axis of length ``size``;
    ``dst_lo:dst_hi`` indexes the output window of length
    ``sl.stop - sl.start``. Both spans have the same length, which is 0 when
    the window does not intersect the axis at all.
    """
    # Clamp into [0, size] so negative bounds mean "before the array" rather
    # than Python's wrap-around indexing, and overshooting bounds stop at the end.
    src_lo = min(max(sl.start, 0), size)
    src_hi = min(max(sl.stop, 0), size)
    if src_hi <= src_lo:
        # Window lies entirely outside this axis: nothing to copy.
        return 0, 0, 0, 0
    return src_lo, src_hi, src_lo - sl.start, src_hi - sl.start


def extended_2d_fancy_indexing(arr, sl1, sl2, value_of_nan):
    """Crop ``arr[sl1, sl2]``, padding cells that fall outside ``arr``.

    Unlike plain slicing, the window may extend past any edge of ``arr``, or
    even lie completely outside it. The result always has shape
    ``(sl1.stop - sl1.start, sl2.stop - sl2.start) + arr.shape[2:]``. Cells
    that do not overlap ``arr`` are set to ``value_of_nan``.

    Args:
        arr: 2-D or 3-D NumPy array. The window is taken over the first two
            axes; a third axis, if present, is copied whole.
        sl1: Slice over axis 0 with integer ``start``/``stop`` (``step`` is
            ignored). Negative values denote positions before the first row,
            not Python wrap-around indices.
        sl2: Same as ``sl1``, for axis 1.
        value_of_nan: Fill value for out-of-bounds cells. It must be
            representable in ``arr.dtype``, otherwise ``np.full`` raises.

    Returns:
        A new array with dtype ``arr.dtype``.

    Raises:
        WrongTensorShapeError: If ``arr`` is not 2-D or 3-D.
        ValueError: From ``np.full`` if a slice has ``stop < start``.
    """
    if len(arr.shape) not in (2, 3):
        raise WrongTensorShapeError()

    new_shape = (sl1.stop - sl1.start, sl2.stop - sl2.start) + tuple(arr.shape[2:])
    result = np.full(new_shape, value_of_nan, dtype=arr.dtype)

    x_lo, x_hi, new_x_lo, new_x_hi = _window_overlap(sl1, arr.shape[0])
    y_lo, y_hi, new_y_lo, new_y_hi = _window_overlap(sl2, arr.shape[1])

    # Indexing only the first two axes carries any trailing axis over whole,
    # so one assignment covers both the 2-D and the 3-D case.
    result[new_x_lo:new_x_hi, new_y_lo:new_y_hi] = arr[x_lo:x_hi, y_lo:y_hi]
    return result
# Residue from the source web page (not code): "评论列表" (comment list),
# "文章目录" (table of contents).