沿给定轴将numpy ndarray与1d数组相乘

发布于 2021-01-29 16:08:53

看来我迷失于潜在的愚蠢之举。我有一个n维的numpy数组,我想将其与沿某个维度(可以改变!)的向量(1d数组)相乘。例如,假设我要沿着第一个数组的轴0将2d数组乘以1d数组,我可以执行以下操作:

a=np.arange(20).reshape((5,4))
b=np.ones(5)
c=a*b[:,np.newaxis]

容易,但我想将此概念扩展到n维(对于a,而b始终为1d)和任何轴。换句话说,我想知道如何在正确的位置使用np.newaxis生成切片。假设a是3d,并且我想沿axis
= 1进行乘法,那么我想生成一个可以正确给出的切片:

c=a*b[np.newaxis,:,np.newaxis]

即给定a的维数(例如3),以及要沿其相乘的轴(例如axis = 1),如何生成和传递切片:

np.newaxis,:,np.newaxis

谢谢。

关注者
0
被浏览
191
1 个回答
  • 面试哥
    面试哥 2021-01-29
    为面试而生,有面试问题,就找面试哥。

    解决方案代码-

    import numpy as np
    
    # Given axis along which elementwise multiplication with broadcasting 
    # is to be performed
    given_axis = 1
    
    # Create an array which would be used to reshape 1D array, b to have 
    # singleton dimensions except for the given axis where we would put -1 
    # signifying to use the entire length of elements along that axis  
    dim_array = np.ones((1,a.ndim),int).ravel()
    dim_array[given_axis] = -1
    
    # Reshape b with dim_array and perform elementwise multiplication with 
    # broadcasting along the singleton dimensions for the final output
    b_reshaped = b.reshape(dim_array)
    mult_out = a*b_reshaped
    

    运行示例以演示步骤-

    In [149]: import numpy as np
    
    In [150]: a = np.random.randint(0,9,(4,2,3))
    
    In [151]: b = np.random.randint(0,9,(2,1)).ravel()
    
    In [152]: whos
    Variable   Type       Data/Info
    -------------------------------
    a          ndarray    4x2x3: 24 elems, type `int32`, 96 bytes
    b          ndarray    2: 2 elems, type `int32`, 8 bytes
    
    In [153]: given_axis = 1
    

    现在,我们要沿进行元素乘法given axis = 1。让我们来创建dim_array

    In [154]: dim_array = np.ones((1,a.ndim),int).ravel()
         ...: dim_array[given_axis] = -1
         ...:
    
    In [155]: dim_array
    Out[155]: array([ 1, -1,  1])
    

    最后,重塑b形状并执行逐元素乘法:

    In [156]: b_reshaped = b.reshape(dim_array)
         ...: mult_out = a*b_reshaped
         ...:
    

    whos再次查看信息,并特别注意b_reshapedmult_out

    In [157]: whos
    Variable     Type       Data/Info
    ---------------------------------
    a            ndarray    4x2x3: 24 elems, type `int32`, 96 bytes
    b            ndarray    2: 2 elems, type `int32`, 8 bytes
    b_reshaped   ndarray    1x2x1: 2 elems, type `int32`, 8 bytes
    dim_array    ndarray    3: 3 elems, type `int32`, 12 bytes
    given_axis   int        1
    mult_out     ndarray    4x2x3: 24 elems, type `int32`, 96 bytes
    


知识点
面圈网VIP题库

面圈网VIP题库全新上线,海量真题题库资源。 90大类考试,超10万份考试真题开放下载啦

去下载看看