gnumpy.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:CNNbasedMedicalSegmentation 作者: BRML 项目源码 文件源码
def dot(a1, a2):
    """Return the dot product of ``a1`` and ``a2``.

    Both arguments are first converted with ``as_garray``. Only the
    matrix-matrix product is computed directly; every other supported
    combination of ranks is rewritten into that case.

    Supported rank combinations:

    * either operand 0-d: the plain product ``a1 * a2``;
    * 1-d with 1-d: the inner product, obtained through ``.item()`` (or
      ``sum(a1**2)`` when both operands are the very same object);
    * 2-d with 1-d: ``a2`` is treated like a column vector;
    * 1-d with 2-d: ``a1`` is treated like a row vector;
    * both operands at least 2-d: the last axis of ``a1`` is contracted
      with the second-to-last axis of ``a2``.

    Raises:
        ValueError: if the axes to be contracted differ in length.
        NotImplementedError: for the remaining rank combinations (a 1-d
            operand paired with an operand of three or more axes).
    """
    # internally: for matrix-matrix multiplies only; vectors are treated like
    # special cases.
    a1 = as_garray(a1)
    a2 = as_garray(a2)

    if a1.ndim == 0 or a2.ndim == 0:
        return a1 * a2

    if a1.ndim == a2.ndim == 1:
        if a1 is a2:
            # NOTE(review): `sum` is presumably this module's own reduction
            # rather than the builtin -- confirm.
            return sum(a1**2)
        return dot(a1.reshape(1, a1.size), a2.reshape(a2.size, 1)).item()

    if a1.ndim == 2 and a2.ndim == 1:
        # treat a2 like a column vector
        return dot(a1, a2.reshape(a2.size, 1)).ravel()

    if a1.ndim == 1 and a2.ndim == 2:
        # treat a1 like a row vector
        return dot(a1._add_axes(2), a2)[0]

    # a2.shape[-2] only exists when a2 has at least two axes. Without the
    # guard, an (N>=3)-d a1 with a 1-d a2 died here with an IndexError instead
    # of reaching the NotImplementedError at the bottom.
    if a2.ndim >= 2 and a1.shape[-1] != a2.shape[-2]:
        raise ValueError('arrays not aligned for dot product. a dot product was requested of arrays with shapes %s and %s' % (a1.shape, a2.shape))

    if a1.ndim == a2.ndim == 2:
        retShape = (a1.shape[0], a2.shape[1])
        if a1.shape[1] == 0:
            return zeros(retShape)  # cudamat bug workaround
        ret = empty(retShape)
        if ret.size != 0:
            # The operands are handed over in swapped order.
            # NOTE(review): presumably because of cudamat's memory layout --
            # confirm.
            _cudamat.dot(a2._base_as_2d(), a1._base_as_2d(), ret._base_as_2d())
        return ret

    if a1.ndim >= 2 and a2.ndim >= 2:
        # this is not necessarily fast, because if a2.ndim>=3 then it involves
        # a transpose
        a12 = a1.reshape_2d(-1) if a1.ndim != 2 else a1
        if a2.ndim != 2:
            # Bring the contracted (second-to-last) axis of a2 to the front.
            # `range` replaces the Python-2-only `xrange`; inside tuple() the
            # result is the same on Python 2 and 3.
            axis_order = (a2.ndim - 2,) + tuple(range(a2.ndim - 2)) + (a2.ndim - 1,)
            a22 = a2.transpose(axis_order).reshape_2d(1)
        else:
            a22 = a2
        retShape = _deleteT2(a1.shape, -1) + _deleteT2(a2.shape, -2)
        return dot(a12, a22).reshape(retShape)

    raise NotImplementedError('dot with arguments of shapes %s and %s' % (a1.shape, a2.shape))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号