def matmul_v3(a, b, **kwargs):
    """Multiply ``a`` and ``b``, dispatching on their number of dimensions.

    Two 3-D operands are multiplied with ``F.batch_matmul``; two 2-D operands
    with ``F.matmul``. Any other combination of ranks is rejected.

    ``F`` is a module-level name not visible in this function; it presumably
    refers to ``chainer.functions`` -- TODO confirm against the file's imports.

    Args:
        a: Left operand; must expose ``ndim`` and ``shape``.
        b: Right operand; must expose ``ndim`` and ``shape``.
        **kwargs: Forwarded unchanged to the selected ``F`` function.

    Returns:
        Whatever ``F.batch_matmul`` or ``F.matmul`` returns for the inputs.

    Raises:
        ValueError: If the operands are not both 3-D or both 2-D. This is a
            subclass of ``Exception``, so existing ``except Exception``
            handlers continue to catch it.
    """
    ranks = (a.ndim, b.ndim)
    if ranks == (3, 3):
        return F.batch_matmul(a, b, **kwargs)
    if ranks == (2, 2):
        return F.matmul(a, b, **kwargs)
    # Mixed or unsupported ranks: report both shapes to aid debugging.
    raise ValueError("unsupported shapes: {}, {}".format(a.shape, b.shape))
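# 评论列表 (comment list -- web-page scrape residue, not code)
# 文章目录 (article table of contents -- web-page scrape residue, not code)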