def test_find_common_type_boolean(self):
# Ticket #1695
assert_(np.find_common_type([], ['?', '?']) == '?')
Example source code using Python's find_common_type()
def test_scalar_loses1(self):
res = np.find_common_type(['f4', 'f4', 'i2'], ['f8'])
assert_(res == 'f4')
def test_scalar_loses2(self):
res = np.find_common_type(['f4', 'f4'], ['i8'])
assert_(res == 'f4')
def test_scalar_wins(self):
res = np.find_common_type(['f4', 'f4', 'i2'], ['c8'])
assert_(res == 'c8')
def test_scalar_wins2(self):
res = np.find_common_type(['u4', 'i4', 'i4'], ['f4'])
assert_(res == 'f8')
def test_where_type(self):
# Test the type conservation with where
x = np.arange(4, dtype=np.int32)
y = np.arange(4, dtype=np.float32) * 2.2
test = where(x > 1.5, y, x).dtype
control = np.find_common_type([np.int32, np.float32], [])
assert_equal(test, control)
def test_find_common_type_boolean(self):
# Ticket #1695
assert_(np.find_common_type([], ['?', '?']) == '?')
def test_scalar_loses1(self):
res = np.find_common_type(['f4', 'f4', 'i2'], ['f8'])
assert_(res == 'f4')
def test_scalar_loses2(self):
res = np.find_common_type(['f4', 'f4'], ['i8'])
assert_(res == 'f4')
def test_scalar_wins(self):
res = np.find_common_type(['f4', 'f4', 'i2'], ['c8'])
assert_(res == 'c8')
def test_scalar_wins2(self):
res = np.find_common_type(['u4', 'i4', 'i4'], ['f4'])
assert_(res == 'f8')
def test_where_type(self):
# Test the type conservation with where
x = np.arange(4, dtype=np.int32)
y = np.arange(4, dtype=np.float32) * 2.2
test = where(x > 1.5, y, x).dtype
control = np.find_common_type([np.int32, np.float32], [])
assert_equal(test, control)
def test_find_common_type_boolean(self):
# Ticket #1695
assert_(np.find_common_type([], ['?', '?']) == '?')
def test_scalar_loses1(self):
res = np.find_common_type(['f4', 'f4', 'i2'], ['f8'])
assert_(res == 'f4')
def test_scalar_loses2(self):
res = np.find_common_type(['f4', 'f4'], ['i8'])
assert_(res == 'f4')
def test_scalar_wins(self):
res = np.find_common_type(['f4', 'f4', 'i2'], ['c8'])
assert_(res == 'c8')
def test_scalar_wins2(self):
res = np.find_common_type(['u4', 'i4', 'i4'], ['f4'])
assert_(res == 'f8')
def test_where_type(self):
# Test the type conservation with where
x = np.arange(4, dtype=np.int32)
y = np.arange(4, dtype=np.float32) * 2.2
test = where(x > 1.5, y, x).dtype
control = np.find_common_type([np.int32, np.float32], [])
assert_equal(test, control)
def directsum(tensors,labels,axes=()):
    '''
    The directsum of a couple of tensors.

    Along each axis listed in ``axes`` every tensor must have the same
    dimension and spans the full extent of the result; along every other axis
    the tensors are placed one after another, so their dimensions add up.

    Parameters
    ----------
    tensors : list of Tensor
        The tensors to be directsummed.
    labels : list of Label
        The labels of the directsum. When the tensors carry quantum numbers
        (``qnon``), the ``qns`` attribute of these labels is overwritten in place.
    axes : list of integer, optional
        The axes along which the directsum is block diagonal.

    Returns
    -------
    Tensor
        The directsum of the tensors.

    Raises
    ------
    ValueError
        When ``tensors`` is empty.
    '''
    # Materialize once: the tensors are traversed several times below.
    tensors=list(tensors)
    if len(tensors)==0:
        raise ValueError('directsum error: no tensors to directsum.')
    first=tensors[0]
    assert first.ndim>len(axes)
    ndim,qnon=first.ndim,first.qnon
    shps=[first.shape[axis] for axis in axes]
    # Axes not listed in `axes`: the tensors are laid out one after another here.
    alters=set(range(ndim))-set(axes)
    shape,dtype=list(first.shape),first.dtype
    for tensor in tensors[1:]:
        assert tensor.ndim==ndim and tensor.qnon==qnon and [tensor.shape[axis] for axis in axes]==shps
        for alter in alters: shape[alter]+=tensor.shape[alter]
        # Pairwise promotion replaces the deprecated np.find_common_type and,
        # unlike the varargs np.result_type, has no limit on the tensor count.
        dtype=np.promote_types(dtype,tensor.dtype)
    data=np.zeros(tuple(shape),dtype=dtype)
    # Running start index of the next block along each stacked axis.
    offsets=[0]*ndim
    for tensor in tensors:
        slices=tuple(slice(offsets[axis],offsets[axis]+tensor.shape[axis]) if axis in alters else slice(None) for axis in range(ndim))
        data[slices]=tensor[...]
        for alter in alters: offsets[alter]+=tensor.shape[alter]
    if qnon:
        for alter in alters:
            labels[alter].qns=QuantumNumbers.union([tensor.labels[alter].qns for tensor in tensors])
        for axis in axes:
            # Take the quantum numbers from the first tensor in `tensors`
            # (all tensors were asserted to have equal dimensions on these axes).
            labels[axis].qns=first.labels[axis].qns
    return Tensor(data,labels=labels)
def solve(A,b,rtol=10**-8):
    '''
    Solve the matrix equation A*x=b by QR decomposition.

    For a tall or square ``A`` (nrow>=ncol), ``A=QR`` and ``R x = Q^H b`` is
    solved by back substitution. For a wide ``A`` (nrow<ncol), ``A^H=QR`` so
    ``A=R^H Q^H``; ``R^H y = b`` is solved by forward substitution and
    ``x = Q y`` is returned. In both cases the candidate is accepted only if
    ``A x`` matches ``b`` according to ``np.allclose`` with the given ``rtol``.

    Parameters
    ----------
    A : 2d ndarray
        The coefficient matrix.
    b : 1d ndarray
        The ordinate values.
    rtol : np.float64
        The relative tolerance of the solution.

    Returns
    -------
    1d ndarray
        The solution. Its dtype is promoted from the QR factors and ``b``, so
        integer inputs give a floating-point solution.

    Raises
    ------
    LinAlgError
        When no solution exists, or when the triangular solve breaks down
        (e.g. a zero diagonal entry of R for a rank-deficient ``A``).
    '''
    assert A.ndim==2
    nrow,ncol=A.shape
    # A zero diagonal entry of R produces inf/nan; silence the division
    # warnings and let the final consistency check raise LinAlgError instead.
    with np.errstate(divide='ignore',invalid='ignore'):
        if nrow>=ncol:
            q,r=sl.qr(A,mode='economic',check_finite=False)
            # Q^H b needs the conjugate transpose; plain Q^T is wrong for complex A.
            temp=q.conj().T.dot(b)
            # Promote with the QR factors' dtype (float for integer A) so the
            # substitution below is not truncated to integers.
            result=np.zeros(ncol,dtype=np.result_type(temp.dtype,r.dtype))
            # Back substitution on the upper-triangular R, last row first.
            for i,ri in enumerate(r[::-1]):
                result[-1-i]=(temp[-1-i]-ri[ncol-i:].dot(result[ncol-i:]))/ri[-1-i]
        else:
            q,r=sl.qr(A.conj().T,mode='economic',check_finite=False)
            temp=np.zeros(nrow,dtype=np.result_type(r.dtype,b.dtype))
            # Forward substitution on the lower-triangular R^H.
            for i,ri in enumerate(r.conj().T):
                temp[i]=(b[i]-ri[:i].dot(temp[:i]))/ri[i]
            result=q.dot(temp)
    if not np.allclose(A.dot(result),b,rtol=rtol):
        raise sl.LinAlgError('solve error: no solution.')
    return result