def getdiag(x):
    """Return the flat (row-major) indices of the main diagonal of ``x``.

    The main diagonal is the set of positions ``(i, i, ..., i)`` for
    ``i`` in ``range(min(shape))``. The returned indices address those
    positions in the C-order flattened array, so
    ``np.ravel(x)[getdiag(x)]`` yields the diagonal elements. Works for
    any number of dimensions >= 2.

    Parameters
    ----------
    x : array_like
        Anything with a ``.shape`` attribute, or anything ``np.asarray``
        accepts. Only the shape is used; element values are never read.

    Returns
    -------
    int or numpy.ndarray
        ``0`` for 0-d and 1-d inputs (the flat index of the first element;
        the 1-d result is kept for backward compatibility). Otherwise a
        1-d integer array of length ``min(shape)``.
    """
    # np.shape reads x.shape when present, so no full-size scratch array
    # is allocated (the previous np.zeros(shape) cost memory equal to x).
    shp = np.shape(x)
    ndim = len(shp)
    if ndim <= 1:
        # No multi-axis diagonal exists; return flat index 0 as before.
        return 0
    i = np.arange(min(shp))
    # The same index vector along every axis selects (i, i, ..., i),
    # which replaces the duplicated 2-d/3-d branches for any rank.
    return np.ravel_multi_index((i,) * ndim, shp)
# Scraped page residue (not code): "评论列表" (Comment list), "文章目录" (Table of contents)