utils.py source code

python

Project: Master-Thesis  Author: AntoinePassemiers
import numpy as np


def extended_2d_fancy_indexing(arr, sl1, sl2, value_of_nan):
    # Return arr[sl1, sl2] (along the first two axes) even when the slices reach
    # outside arr; out-of-range positions are filled with value_of_nan.
    if arr.ndim not in (2, 3):
        # WrongTensorShapeError is not shown in this excerpt; presumably defined elsewhere in the project.
        raise WrongTensorShapeError()
    new_shape = (sl1.stop - sl1.start, sl2.stop - sl2.start) + arr.shape[2:]
    result = np.full(new_shape, value_of_nan, dtype=arr.dtype)

    # Clip the requested window to the valid index range of arr. Clamping both
    # bounds stops negative indices from wrapping around and keeps windows that
    # lie entirely outside arr from producing mismatched shapes.
    x_lower = min(max(sl1.start, 0), arr.shape[0])
    x_upper = max(min(sl1.stop, arr.shape[0]), x_lower)
    y_lower = min(max(sl2.start, 0), arr.shape[1])
    y_upper = max(min(sl2.stop, arr.shape[1]), y_lower)

    # Offset of the copied block inside result: a negative start leaves
    # -start rows/columns of fill values before it.
    new_x_lower = max(0, -sl1.start)
    new_x_upper = new_x_lower + (x_upper - x_lower)
    new_y_lower = max(0, -sl2.start)
    new_y_upper = new_y_lower + (y_upper - y_lower)

    # Any trailing axis (e.g. channels in the 3-D case) is copied whole.
    result[new_x_lower:new_x_upper, new_y_lower:new_y_upper] = arr[x_lower:x_upper, y_lower:y_upper]
    return result
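The function above behaves like "pad the array with the fill value, then slice". The sketch below spells out that equivalence with np.pad, which can serve as a reference when checking the direct version. The name padded_window is hypothetical and not part of the original project; it assumes integer slices with start < stop and a float-compatible fill value (np.pad cannot put NaN into an integer array).

```python
import numpy as np


def padded_window(arr, sl1, sl2, fill_value):
    """Hypothetical reference version: pad `arr` with `fill_value`, then slice.

    Assumes integer slices with start < stop on a 2-D or 3-D array.
    """
    # How far the requested window sticks out past each edge of the first two axes.
    before_x = max(0, -sl1.start)
    after_x = max(0, sl1.stop - arr.shape[0])
    before_y = max(0, -sl2.start)
    after_y = max(0, sl2.stop - arr.shape[1])
    # Trailing axes (e.g. channels) are not padded.
    pad_width = [(before_x, after_x), (before_y, after_y)] + [(0, 0)] * (arr.ndim - 2)
    padded = np.pad(arr, pad_width, mode="constant", constant_values=fill_value)
    # Shift the slice bounds by the amount of padding added before each axis.
    return padded[sl1.start + before_x:sl1.stop + before_x,
                  sl2.start + before_y:sl2.stop + before_y]


# Usage: a 3x3 patch centred on the top-left pixel of a 3x4 image.
img = np.arange(12, dtype=float).reshape(3, 4)
patch = padded_window(img, slice(-1, 2), slice(-1, 2), np.nan)
# Row 0 and column 0 of patch are NaN; patch[1:, 1:] equals img[0:2, 0:2].
print(patch)
```

The padding version is easier to reason about but allocates a padded copy of the whole array on every call. Allocating only the window, as extended_2d_fancy_indexing does, is cheaper when many small patches are taken from a large image.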