def op_matmul(s_x_, s_y_, axes_=(-2, -1)):
    '''
    Limited implementation of np.matmul: batched matrix multiplication
    over a chosen pair of axes.  Does not support broadcasting -- both
    operands must have the same ndim, and their non-contracted (batch)
    dimensions must agree.

    Contracts axis ``axes_[1]`` of ``s_x_`` with axis ``axes_[0]`` of
    ``s_y_``; every other axis is treated as a batch axis.

    Args:
        s_x_: (batch of) matrix(matrices), a symbolic tensor
        s_y_: (batch of) matrix(matrices), same ndim as ``s_x_``
        axes_: tuple of int, the two axes forming the matrices
               (negative indices allowed, numpy style)

    Returns:
        Symbolic tensor with the same ndim as the inputs; axis
        ``axes_[1]`` holds the output dimension taken from ``s_y_``.
    '''
    from functools import reduce  # fix: reduce is not a builtin in Python 3

    assert s_x_.ndim == s_y_.ndim
    ndim = s_x_.ndim
    assert ndim >= 2
    assert -ndim <= axes_[0] < ndim
    assert -ndim <= axes_[1] < ndim
    axes = axes_[0] % ndim, axes_[1] % ndim
    # the two matrix axes must be distinct for the permutation below to make sense
    assert axes[0] != axes[1]
    if ndim == 2:
        if axes == (0, 1):
            return T.dot(s_x_, s_y_)
        else:
            # NOTE(review): the swapped-axes branch returns y.x, i.e.
            # (x^T . y^T)^T -- presumably intended; confirm against callers.
            return T.dot(s_y_, s_x_)
    s_shp = T.shape(s_x_)
    # batch axes = all axes not participating in the matrix product
    batch_axes = [i for i in range(ndim) if i not in axes]
    # flattened size of the batch dimensions
    s_size = reduce(T.mul, [s_shp[i] for i in batch_axes])
    s_szu = s_shp[axes[0]]
    s_szv = s_shp[axes[1]]
    s_szw = T.shape(s_y_)[axes[1]]
    # Permutation moving the matrix axes to the last two positions while
    # keeping batch axes in their original relative order.
    # fix: the previous pairwise-swap construction produced a wrong
    # permutation when axes overlapped {ndim-2, ndim-1} in reversed order
    # (e.g. ndim=3, axes=(2,1) yielded the identity).
    transpp = batch_axes + [axes[0], axes[1]]
    # shape of the (still transposed) result: batch dims, then (u, w)
    s_shp2 = [s_shp[a] for a in transpp]
    # fix: after the transpose the w dimension sits at position ndim-1;
    # the old code wrote s_shp2[axes[1]], clobbering a batch dim whenever
    # axes[1] != ndim-1
    s_shp2[ndim - 1] = s_szw
    s_x = s_x_.transpose(*transpp).reshape((s_size, s_szu, s_szv))
    s_y = s_y_.transpose(*transpp).reshape((s_size, s_szv, s_szw))
    # fix: undo the transpose with the INVERSE permutation; the old code
    # reapplied transpp itself, which is not self-inverse for general axes
    # (e.g. transpp=[1,2,0] composed with itself gives [2,0,1], not identity)
    inv_transpp = [0] * ndim
    for pos, src in enumerate(transpp):
        inv_transpp[src] = pos
    return T.batched_dot(s_x, s_y).reshape(s_shp2).transpose(*inv_transpp)
# 评论列表 (comment list) / 文章目录 (article table of contents)
# -- web-page navigation text captured when this snippet was scraped; not code.