def dot(a1, a2):
    """Dot product of two arrays; shape semantics mirror numpy.dot.

    Only the 2-d x 2-d (matrix-matrix) case is computed directly (on the GPU
    via cudamat). Every other combination is reduced to it by reshaping and,
    when a2.ndim >= 3, transposing the operands.

    Cases:
      - either operand 0-d: the elementwise product ``a1 * a2``.
      - 1-d x 1-d: the inner product as a scalar.
      - N-d x 1-d (N >= 2): contracts a1's last axis with a2;
        the result has shape ``a1.shape[:-1]``.
      - 1-d x M-d (M >= 2): contracts a1 with a2's second-to-last axis;
        the result has shape ``a2.shape[:-2] + a2.shape[-1:]``.
      - N-d x M-d (N, M >= 2): contracts a1's last axis with a2's
        second-to-last axis; the result has shape
        ``a1.shape[:-1] + a2.shape[:-2] + a2.shape[-1:]``.

    Args:
        a1, a2: anything accepted by ``as_garray``.

    Raises:
        ValueError: if the contracted axes have different lengths.
    """
    a1 = as_garray(a1)
    a2 = as_garray(a2)
    if a1.ndim == 0 or a2.ndim == 0:
        return a1 * a2

    # Length of the axis a2 contributes to the contraction: its second-to-last
    # axis, or its only axis when a2 is a vector. Checking alignment here,
    # before any reshaping, reports the caller's shapes in the error message.
    # It also avoids indexing a2.shape[-2] on a 1-d a2, which used to raise
    # IndexError for a1.ndim >= 3.
    a2_contracted = a2.shape[-2] if a2.ndim >= 2 else a2.shape[0]
    if a1.shape[-1] != a2_contracted:
        raise ValueError(
            'arrays not aligned for dot product. a dot product was requested '
            'of arrays with shapes %s and %s' % (a1.shape, a2.shape))

    if a1.ndim == a2.ndim == 1:
        if a1 is a2:
            # Shortcut for the squared norm of a single vector.
            return sum(a1 ** 2)
        return dot(a1.reshape(1, a1.size), a2.reshape(a2.size, 1)).item()

    if a2.ndim == 1:
        # a1.ndim >= 2: treat a2 as an (n, 1) column vector, then drop the
        # trailing length-1 axis that produces in the result.
        return dot(a1, a2.reshape(a2.size, 1)).reshape(_deleteT2(a1.shape, -1))

    if a1.ndim == 1:
        # a2.ndim >= 2: treat a1 as a (1, n) row vector, then drop the leading
        # length-1 axis that produces in the result.
        return dot(a1.reshape(1, a1.size), a2).reshape(_deleteT2(a2.shape, -2))

    if a1.ndim == a2.ndim == 2:
        ret_shape = (a1.shape[0], a2.shape[1])
        if a1.shape[1] == 0:
            # cudamat bug workaround: an empty contraction axis must yield zeros.
            return zeros(ret_shape)
        ret = empty(ret_shape)
        if ret.size != 0:
            # NOTE(review): the operands are passed in swapped order, presumably
            # because cudamat stores matrices column-major, so the row-major
            # product a1.a2 equals the column-major product a2.a1. Confirm.
            _cudamat.dot(a2._base_as_2d(), a1._base_as_2d(), ret._base_as_2d())
        return ret

    # Both operands have ndim >= 2 and at least one has ndim >= 3.
    # This is not necessarily fast: when a2.ndim >= 3 it involves a transpose.
    # Assumes reshape_2d(k) folds the leading k axes into rows (k = -1: all but
    # the last). TODO confirm against the garray implementation.
    a1_2d = a1 if a1.ndim == 2 else a1.reshape_2d(-1)
    if a2.ndim == 2:
        a2_2d = a2
    else:
        # Move a2's contracted (second-to-last) axis to the front and keep the
        # remaining axes in order, so it becomes the row axis of a 2-d matrix.
        axes = (a2.ndim - 2,) + tuple(range(a2.ndim - 2)) + (a2.ndim - 1,)
        a2_2d = a2.transpose(axes).reshape_2d(1)
    ret_shape = _deleteT2(a1.shape, -1) + _deleteT2(a2.shape, -2)
    return dot(a1_2d, a2_2d).reshape(ret_shape)
评论列表
文章目录