def krylovMethod(self, tol=1e-8):
    """
    Compute the stationary distribution ``pi`` with the GMRES Krylov solver.

    The balance equations ``pi (P - I) = 0`` are transposed into a linear
    system whose first equation is replaced by the normalizing condition
    ``sum(pi) = 1``. That system is solved with
    :func:`scipy.sparse.linalg.gmres`, which searches a Krylov subspace for
    the vector with minimal residual. The result is stored in the attribute
    ``pi``.

    Parameters
    ----------
    tol : float, optional (default=1e-8)
        Relative tolerance passed to GMRES. A lower tolerance gives a more
        accurate estimate of ``pi``.

    Raises
    ------
    RuntimeError
        If GMRES returns a nonzero ``info`` code.

    Notes
    -----
    For large state spaces, this method may not always give a solution.
    Code due to http://stackoverflow.com/questions/21308848/

    Examples
    --------
    >>> P = np.array([[0.5,0.5],[0.6,0.4]])
    >>> mc = markovChain(P)
    >>> mc.krylovMethod()
    >>> print(mc.pi)
    [0.54545455 0.45454545]
    """
    import inspect  # local: only used to pick the tolerance keyword of gmres

    P = self.getIrreducibleTransitionMatrix()

    # A single-state chain trivially has pi = [1.0].
    if P.shape == (1, 1):
        self.pi = np.array([1.0])
        return

    size = P.shape[0]
    dP = P - eye(size)
    # The rows of dP.T are the balance equations; they are linearly dependent
    # (P - I is singular for a stochastic P), so the first one is replaced by
    # the normalizing condition sum(pi) = 1. The row is built as an explicit
    # 1 x size array so vstack does not depend on 1-D promotion rules.
    A = vstack([np.ones((1, size)), dP.T[1:, :]]).tocsr()
    rhs = np.zeros((size,))
    rhs[0] = 1

    # SciPy 1.12 renamed gmres' ``tol`` to ``rtol`` and 1.14 removed ``tol``;
    # use whichever keyword the installed version accepts.
    try:
        uses_rtol = 'rtol' in inspect.signature(gmres).parameters
    except (TypeError, ValueError):
        uses_rtol = False
    tol_kwargs = {'rtol': tol} if uses_rtol else {'tol': tol}

    pi, info = gmres(A, rhs, **tol_kwargs)
    if info != 0:
        raise RuntimeError("gmres did not converge")
    self.pi = pi
# Python gmres() example source code
def _declare_options(self):
    """
    Declare this solver's options and override inherited defaults.

    Runs before keyword arguments are processed in the init method. Declares
    ``solver`` (restricted to the keys of ``_SOLVER_TYPES``, default
    ``'gmres'``) and ``restart`` (int, default 20), then sets ``maxiter`` to
    1000 and ``atol`` to 1e-12.
    """
    opts = self.options

    solver_names = tuple(_SOLVER_TYPES.keys())
    opts.declare('solver', default='gmres', values=solver_names,
                 desc='function handle for actual solver')

    restart_help = ('Number of iterations between restarts. Larger values increase '
                    'iteration cost, but may be necessary for convergence. This '
                    'option applies only to gmres.')
    opts.declare('restart', default=20, types=int, desc=restart_help)

    # Replace the defaults inherited from the base class.
    overrides = (('maxiter', 1000), ('atol', 1.0e-12))
    for name, value in overrides:
        opts[name] = value
def __init__(self, M, ifunc=gmres, tol=0):
    """
    Store the operator ``M`` together with an iterative solver and tolerance.

    Parameters
    ----------
    M : matrix-like or operator
        Must provide ``shape``. If it has no ``dtype`` attribute, ``M * x``
        for a 1-D zero vector ``x`` is evaluated to discover the result dtype.
    ifunc : callable, optional
        Iterative solver stored as ``self.ifunc`` (default ``gmres``).
    tol : float, optional
        Stored as ``self.tol``. Values ``<= 0`` are replaced by
        ``2 * eps`` of the operator's dtype.
    """
    self.M = M
    # Determine the dtype first: the default tolerance depends on it, and M
    # is not guaranteed to expose ``dtype`` itself.
    if hasattr(M, 'dtype'):
        self.dtype = M.dtype
    else:
        x = np.zeros(M.shape[1])
        self.dtype = (M * x).dtype
    self.shape = M.shape

    if tol <= 0:
        # when tol=0, ARPACK uses machine tolerance as calculated
        # by LAPACK's _LAMCH function. We should match this
        tol = 2 * np.finfo(self.dtype).eps
    self.ifunc = ifunc
    self.tol = tol
def __init__(self, A, M, sigma, ifunc=gmres, tol=0):
    """
    Build the shifted operator ``A - sigma * M`` and store solver settings.

    When ``M`` is None the shift uses the identity, i.e. ``A - sigma * I``.
    The operator is stored as ``self.OP``.

    Parameters
    ----------
    A : operator
        Must provide ``matvec`` and ``shape``.
    M : operator or None
        Must provide ``matvec`` when not None.
    sigma : scalar
        Shift applied to ``M`` (or to the identity).
    ifunc : callable, optional
        Iterative solver stored as ``self.ifunc`` (default ``gmres``).
    tol : float, optional
        Stored as ``self.tol``. Values ``<= 0`` are replaced by ``2 * eps``
        of the shifted operator's dtype.
    """
    self.A = A
    self.M = M
    self.sigma = sigma

    def mult_func(x):
        return A.matvec(x) - sigma * M.matvec(x)

    def mult_func_M_None(x):
        return A.matvec(x) - sigma * x

    matvec = mult_func_M_None if M is None else mult_func
    # Probe the result dtype of the shifted product with a zero vector.
    x = np.zeros(A.shape[1])
    dtype = matvec(x).dtype
    self.OP = LinearOperator(self.A.shape, matvec, dtype=dtype)
    self.shape = A.shape

    if tol <= 0:
        # when tol=0, ARPACK uses machine tolerance as calculated
        # by LAPACK's _LAMCH function. We should match this.
        # Use the shifted operator's dtype, not A.dtype: A may be integer
        # valued (np.finfo would raise) or narrower than A - sigma * M.
        tol = 2 * np.finfo(dtype).eps
    self.ifunc = ifunc
    self.tol = tol
def _tolerance_keyword(solver):
    """Return the relative-tolerance keyword accepted by a SciPy Krylov solver."""
    import inspect  # local: only needed to inspect the solver signature

    # SciPy 1.12 renamed ``tol`` to ``rtol`` for the Krylov solvers and 1.14
    # removed ``tol``; older releases only accept ``tol``.
    try:
        params = inspect.signature(solver).parameters
    except (TypeError, ValueError):
        return 'tol'
    return 'rtol' if 'rtol' in params else 'tol'


def solve(A, b, method, tol=1e-3):
    """General sparse solver interface: solve ``A x = b`` with ``method``.

    Parameters
    ----------
    A : sparse matrix or operator
        System matrix (the ``spsolve_*`` methods call ``spla.spsolve``).
    b : array_like
        Right-hand side.
    method : str
        One of

        - spsolve_umfpack_mmd_ata, spsolve_umfpack_colamd:
          ``spsolve`` with ``use_umfpack=True`` and ``permc_spec``
          ``'MMD_ATA'`` / ``'COLAMD'``
        - spsolve_superlu_mmd_ata, spsolve_superlu_colamd:
          ``spsolve`` with ``use_umfpack=False`` and the same orderings
        - bicg, bicgstab, cg, cgs, gmres, lgmres, minres, qmr:
          Krylov solvers; ``tol`` is their relative tolerance
        - lsqr, lsmr: least-squares solvers; ``tol`` is used for both
          ``atol`` and ``btol``
    tol : float, optional
        Tolerance for the iterative methods; not used by ``spsolve_*``.

    Returns
    -------
    The solution ``x``. For the iterative methods this is the first element
    of the SciPy result; the convergence flag is discarded, so a
    non-converged iterate may be returned.

    Raises
    ------
    ValueError
        If ``method`` is not recognised (message ``'UnknownSolverType'``).
    """
    direct_solvers = {
        'spsolve_umfpack_mmd_ata': (True, 'MMD_ATA'),
        'spsolve_umfpack_colamd': (True, 'COLAMD'),
        'spsolve_superlu_mmd_ata': (False, 'MMD_ATA'),
        'spsolve_superlu_colamd': (False, 'COLAMD'),
    }
    krylov_solvers = ('bicg', 'bicgstab', 'cg', 'cgs', 'gmres', 'lgmres',
                      'minres', 'qmr')
    least_squares_solvers = ('lsqr', 'lsmr')

    if method in direct_solvers:
        use_umfpack, permc_spec = direct_solvers[method]
        return spla.spsolve(A, b, use_umfpack=use_umfpack, permc_spec=permc_spec)
    if method in krylov_solvers:
        solver = getattr(spla, method)
        return solver(A, b, **{_tolerance_keyword(solver): tol})[0]
    if method in least_squares_solvers:
        return getattr(spla, method)(A, b, atol=tol, btol=tol)[0]
    # ValueError subclasses Exception, so existing ``except Exception`` still works.
    raise ValueError('UnknownSolverType')