def batch_dot(x, y, axes=None):
    """Prepare a batch-wise dot product of ``x`` and ``y`` (unfinished).

    NOTE(review): this function is a work in progress. The inner ``_dot``
    helper only normalises its operands (rank padding and axis alignment);
    the actual product is still marked TODO, ``_dot`` returns nothing, and
    nothing calls it yet, so ``batch_dot`` currently returns ``None``.

    Args:
        x: First operand; ``_dot`` treats it as a ``torch`` tensor.
        y: Second operand; ``_dot`` treats it as a ``torch`` tensor.
        axes: Either a single int (used for both operands) or a pair
            ``(x_axis, y_axis)`` naming the axis to contract in each
            operand. NOTE(review): ``axes=None`` is not handled by the
            visible code -- ``_dot`` indexes ``axes`` unconditionally.

    Raises:
        ValueError: From ``_dot`` when ``axes[0]`` is neither 1 nor 2.
    """
    if isinstance(axes, int):
        axes = (axes, axes)

    def _dot(X):
        # X is the pair of operands; unpack into locals that shadow the
        # outer ``x`` / ``y``.
        x, y = X
        x_ndim = len(x.size())
        y_ndim = len(y.size())
        if x_ndim <= 3 and y_ndim <= 3:
            # Pad each operand with trailing singleton dimensions up to
            # rank 3.  ``x_diff`` / ``y_diff`` record how many dimensions
            # were added (0 when the operand was already rank 3);
            # presumably the unfinished part needs them to squeeze the
            # result back -- TODO confirm.
            x_diff = 3 - x_ndim
            for i in range(x_diff):
                x = torch.unsqueeze(x, x_ndim + i)
            y_diff = 3 - y_ndim
            for i in range(y_diff):
                y = torch.unsqueeze(y, y_ndim + i)

            # Move the contraction axis of ``x`` to the last position.
            if axes[0] == 1:
                x = torch.transpose(x, 1, 2)
            elif axes[0] != 2:
                raise ValueError('Invalid axis : ' + str(axes[0]))

            # Move the contraction axis of ``y`` to position 1, so the
            # last axis of ``x`` lines up with the middle axis of ``y``.
            if axes[1] == 2:
                y = torch.transpose(y, 1, 2)
            # -------TODO--------------#
评论列表
文章目录