def forward_cpu(self, x):
    """Compute the batched matrix product on the CPU.

    Args:
        x: A pair ``(a, b)`` of arrays (presumably NumPy ndarrays). The
            leading axis of each is treated as the batch axis: sample ``i``
            is computed from ``a[i]`` and ``b[i]`` via ``_matmul``.

    Returns:
        A one-element tuple holding an array of shape
        ``self._output_shape(a, b)`` whose dtype is the promotion of
        ``a.dtype`` and ``b.dtype``.
    """
    a, b = x
    batch_size = a.shape[0]
    shape = self._output_shape(a, b)
    # numpy.find_common_type was deprecated in NumPy 1.25 and removed in
    # NumPy 2.0. With only array dtypes (no scalar types), result_type is
    # the documented replacement and yields the same promoted dtype.
    ret_dtype = numpy.result_type(a.dtype, b.dtype)
    ret = numpy.empty(shape, dtype=ret_dtype)
    for i in six.moves.range(batch_size):
        # transa/transb are forwarded to _matmul; presumably they request
        # transposing a[i] / b[i] before the product -- confirm in _matmul.
        ret[i] = _matmul(
            a[i], b[i], transa=self.transa, transb=self.transb)
    # Outputs are returned as a tuple (note the trailing comma).
    return ret,
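# [scraped web-page residue, not code] 评论列表 ("comment list")
# [scraped web-page residue, not code] 文章目录 ("table of contents")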