def __matmul__(self, other):
    """Implement ``self @ other`` for 1-D and 2-D tensors.

    Dispatch on the ranks reported by ``.dim()`` of both operands:

    * 1-D @ 1-D -> ``torch.dot(self, other)`` (inner product, per PEP 465)
    * 1-D @ 2-D -> ``torch.mv(other.t(), self)`` (vector-matrix product)
    * 2-D @ 1-D -> ``torch.mv(self, other)`` (matrix-vector product)
    * 2-D @ 2-D -> ``torch.mm(self, other)`` (matrix-matrix product)

    Args:
        self: left operand; must provide ``.dim()``.
        other: right operand; must provide ``.dim()``.

    Returns:
        The product for the supported rank combinations above, otherwise
        ``NotImplemented`` so Python can try ``other.__rmatmul__`` and
        ultimately raise ``TypeError`` instead of yielding ``None``.
    """
    dim_self = self.dim()
    dim_other = other.dim()
    if dim_self == 1 and dim_other == 1:
        # PEP 465: vector @ vector is the (scalar) inner product.
        return torch.dot(self, other)
    if dim_self == 1 and dim_other == 2:
        # (n,) @ (n, m) -> (m,) is the same as (m, n) @ (n,) after transposing.
        return torch.mv(other.t(), self)
    if dim_self == 2 and dim_other == 1:
        return torch.mv(self, other)
    if dim_self == 2 and dim_other == 2:
        return torch.mm(self, other)
    # Unsupported rank combination: signal it per the binary-operator protocol
    # rather than falling off the end and returning None.
    return NotImplemented
评论列表
文章目录