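Pointwise dot product of block elements
Is there an elegant, NumPy-ish way to apply the dot product pointwise? Or, how can the following code be turned into a nicer version?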
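
    import numpy as np

    # Example inputs with the shapes from the question (values are arbitrary)
    m0 = np.random.rand(5, 3, 2, 2)  # shape (5, 3, 2, 2)
    m1 = np.random.rand(5, 2, 2)     # shape (5, 2, 2)

    r = np.empty((5, 3, 2, 2))
    for i in range(5):
        for j in range(3):
            r[i, j] = np.dot(m0[i, j], m1[i])
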
Thanks in advance!
-
Approach #1
Use np.einsum -

    np.einsum('ijkl,ilm->ijkm', m0, m1)

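Steps involved:
- Keep the first axes of the inputs aligned.
- Lose the last axis of m0 against the second axis of m1 in a sum-reduction.
- Let the remaining axes of m0 and m1 spread out/expand with elementwise multiplication, in an outer-product fashion.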
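To make those three steps concrete, here is a sketch that spells out the same computation with plain broadcasting and a sum, then checks it against the einsum call. It builds the full (i, j, k, l, m) product array, so it is only meant as an illustration, not as a faster alternative; the random inputs are arbitrary.

```python
import numpy as np

# Arbitrary example inputs with the question's shapes.
rng = np.random.default_rng(0)
m0 = rng.random((5, 3, 2, 2))  # axes (i, j, k, l)
m1 = rng.random((5, 2, 2))     # axes (i, l, m)

# Step 1: axis i stays aligned, so both arrays keep it as their leading axis.
# Step 3: insert length-1 axes so that j, k (from m0) and m (from m1)
#         broadcast against each other in an outer-product fashion.
a = m0[:, :, :, :, None]       # (i, j, k, l, 1)
b = m1[:, None, None, :, :]    # (i, 1, 1, l, m)
prod = a * b                   # (i, j, k, l, m), elementwise products

# Step 2: sum-reduce the shared axis l (last axis of m0, second axis of m1).
r_manual = prod.sum(axis=3)    # (i, j, k, m)

r_einsum = np.einsum('ijkl,ilm->ijkm', m0, m1)
print(r_manual.shape, np.allclose(r_manual, r_einsum))  # (5, 3, 2, 2) True
```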
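Approach #2
If you are after performance and the sum-reduction axis is a small one, you are better off with a single loop that does matrix multiplication with np.tensordot, like so -

    s0, s1, s2, s3 = m0.shape
    s4 = m1.shape[-1]
    r = np.empty((s0, s1, s2, s4))
    for i in range(s0):
        r[i] = np.tensordot(m0[i], m1[i], axes=([2], [0]))
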
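In case the axes=([2], [0]) argument is unclear: for a fixed i it contracts axis 2 of m0[i] (shape (s1, s2, s3)) with axis 0 of m1[i] (shape (s3, s4)). A small sketch that checks one output entry against the explicit sum, using arbitrary inputs:

```python
import numpy as np

rng = np.random.default_rng(1)
m0 = rng.random((5, 3, 2, 2))
m1 = rng.random((5, 2, 2))

i = 0
a, b = m0[i], m1[i]                       # shapes (3, 2, 2) and (2, 2)
t = np.tensordot(a, b, axes=([2], [0]))   # contract a's axis 2 with b's axis 0
print(t.shape)                            # (3, 2, 2)

# The same contraction for a single entry: t[j, k, m] = sum_l a[j, k, l] * b[l, m]
j, k, m = 1, 0, 1
manual = sum(a[j, k, l] * b[l, m] for l in range(a.shape[2]))
print(np.isclose(t[j, k, m], manual))     # True
```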
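Approach #3
Now, np.dot is at its most efficient on 2D inputs, so merging the two middle axes of m0 lets each iteration do a single 2D matrix product for a further performance boost. The modified version, though a bit longer, is hopefully the most performant one -

    s0, s1, s2, s3 = m0.shape
    s4 = m1.shape[-1]
    m0.shape = s0, s1*s2, s3   # Get m0 as 3D for temporary usage
    r = np.empty((s0, s1*s2, s4))
    for i in range(s0):
        r[i] = m0[i].dot(m1[i])
    r.shape = s0, s1, s2, s4
    m0.shape = s0, s1, s2, s3  # Put m0 back to 4D
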
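One caveat with assigning to m0.shape: it modifies the caller's array in place, and it raises an error when the new shape cannot be expressed as a view of the existing memory (which can happen for non-contiguous inputs). A minimal sketch of the same idea using reshape instead; the name dot_reshape_app and the dtype handling are my own additions, not part of the answer above, and this variant is not included in the timings.

```python
import numpy as np

def dot_reshape_app(m0, m1):
    """Hypothetical variant of Approach #3 using reshape instead of m0.shape.

    reshape returns a view when possible and a copy otherwise, so the
    caller's m0 is never modified and non-contiguous inputs also work.
    """
    s0, s1, s2, s3 = m0.shape
    s4 = m1.shape[-1]
    m0_3d = m0.reshape(s0, s1 * s2, s3)   # merge axes j and k
    r = np.empty((s0, s1 * s2, s4), dtype=np.result_type(m0, m1))
    for i in range(s0):
        r[i] = m0_3d[i].dot(m1[i])        # one 2D matrix product per i
    return r.reshape(s0, s1, s2, s4)      # split j and k back out

rng = np.random.default_rng(2)
m0 = rng.random((5, 3, 2, 2))
m1 = rng.random((5, 2, 2))
out = dot_reshape_app(m0, m1)
print(np.allclose(out, np.einsum('ijkl,ilm->ijkm', m0, m1)))  # True
```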
Runtime test
Function definitions -

    import numpy as np

    def original_app(m0, m1):
        s0, s1, s2, s3 = m0.shape
        s4 = m1.shape[-1]
        r = np.empty((s0, s1, s2, s4))
        for i in range(s0):
            for j in range(s1):
                r[i, j] = np.dot(m0[i, j], m1[i])
        return r

    def einsum_app(m0, m1):
        return np.einsum('ijkl,ilm->ijkm', m0, m1)

    def tensordot_app(m0, m1):
        s0, s1, s2, s3 = m0.shape
        s4 = m1.shape[-1]
        r = np.empty((s0, s1, s2, s4))
        for i in range(s0):
            r[i] = np.tensordot(m0[i], m1[i], axes=([2], [0]))
        return r

    def dot_app(m0, m1):
        s0, s1, s2, s3 = m0.shape
        s4 = m1.shape[-1]
        m0.shape = s0, s1*s2, s3   # Get m0 as 3D for temporary usage
        r = np.empty((s0, s1*s2, s4))
        for i in range(s0):
            r[i] = m0[i].dot(m1[i])
        r.shape = s0, s1, s2, s4
        m0.shape = s0, s1, s2, s3  # Put m0 back to 4D
        return r

Timings and verification -

    In [291]: # Inputs
         ...: m0 = np.random.rand(50,30,20,20)
         ...: m1 = np.random.rand(50,20,20)
         ...:

    In [292]: out1 = original_app(m0, m1)
         ...: out2 = einsum_app(m0, m1)
         ...: out3 = tensordot_app(m0, m1)
         ...: out4 = dot_app(m0, m1)
         ...:
         ...: print np.allclose(out1, out2)
         ...: print np.allclose(out1, out3)
         ...: print np.allclose(out1, out4)
         ...:
    True
    True
    True

    In [293]: %timeit original_app(m0, m1)
         ...: %timeit einsum_app(m0, m1)
         ...: %timeit tensordot_app(m0, m1)
         ...: %timeit dot_app(m0, m1)
         ...:
    100 loops, best of 3: 10.3 ms per loop
    10 loops, best of 3: 31.3 ms per loop
    100 loops, best of 3: 5.12 ms per loop
    100 loops, best of 3: 4.06 ms per loop

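For what it's worth, on NumPy 1.10+ np.matmul (the @ operator) treats the last two axes as matrices and broadcasts over the leading ones, so the whole computation can also be written without a Python loop. This is a different technique from the three approaches above and was not part of the timings, so benchmark it on your own data; a sketch, assuming inputs shaped like the test inputs:

```python
import numpy as np

def matmul_app(m0, m1):
    # m0 is (s0, s1, s2, s3) and m1[:, None] is (s0, 1, s3, s4).
    # matmul multiplies the trailing (s2, s3) @ (s3, s4) matrices and
    # broadcasts the leading axes (s0, s1) with (s0, 1) to (s0, s1).
    return np.matmul(m0, m1[:, None])   # result: (s0, s1, s2, s4)

rng = np.random.default_rng(3)
m0 = rng.random((50, 30, 20, 20))
m1 = rng.random((50, 20, 20))
out = matmul_app(m0, m1)
print(out.shape, np.allclose(out, np.einsum('ijkl,ilm->ijkm', m0, m1)))
# (50, 30, 20, 20) True
```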