# Direct methods: method name -> (use_umfpack, permc_spec) for spla.spsolve.
_DIRECT_METHODS = {
    'spsolve_umfpack_mmd_ata': (True, 'MMD_ATA'),
    'spsolve_umfpack_colamd': (True, 'COLAMD'),
    'spsolve_superlu_mmd_ata': (False, 'MMD_ATA'),
    'spsolve_superlu_colamd': (False, 'COLAMD'),
}

# Krylov solvers that return ``(x, info)``. They are looked up on ``spla`` by
# name so that only the requested solver's attribute is ever accessed.
_KRYLOV_METHODS = frozenset(
    ['bicg', 'bicgstab', 'cg', 'cgs', 'gmres', 'lgmres', 'minres', 'qmr'])

# Least-squares solvers controlled by ``atol``/``btol``; ``x`` is element 0
# of their result tuple.
_LSQ_METHODS = frozenset(['lsqr', 'lsmr'])


def _krylov_tol_kwargs(solver, tol):
    """Return the keyword dict passing *tol* as *solver*'s relative tolerance.

    SciPy 1.12 renamed the Krylov solvers' ``tol`` argument to ``rtol`` (and
    SciPy 1.14 removed ``tol``), so use whichever name this SciPy accepts.
    """
    import inspect

    try:
        params = inspect.signature(solver).parameters
    except (TypeError, ValueError):
        # Signature not introspectable: assume the modern keyword.
        params = {'rtol': None}
    if 'rtol' in params:
        return {'rtol': tol}
    return {'tol': tol}


def solve(A, b, method, tol=1e-3):
    """General sparse solver interface: solve ``A x = b`` with a chosen SciPy routine.

    Parameters
    ----------
    A : sparse matrix (or whatever the selected ``scipy.sparse.linalg``
        routine accepts)
        System matrix.
    b : array_like
        Right-hand side.
    method : str
        One of

        - ``spsolve_umfpack_mmd_ata`` : ``spsolve(use_umfpack=True, permc_spec='MMD_ATA')``
        - ``spsolve_umfpack_colamd``  : ``spsolve(use_umfpack=True, permc_spec='COLAMD')``
        - ``spsolve_superlu_mmd_ata`` : ``spsolve(use_umfpack=False, permc_spec='MMD_ATA')``
        - ``spsolve_superlu_colamd``  : ``spsolve(use_umfpack=False, permc_spec='COLAMD')``
        - ``bicg``, ``bicgstab``, ``cg``, ``cgs``, ``gmres``, ``lgmres``,
          ``minres``, ``qmr`` : Krylov solvers of the same name
        - ``lsqr``, ``lsmr`` : least-squares solvers of the same name

        NOTE(review): UMFPACK needs the optional scikit-umfpack package, and
        SciPy's UMFPACK code path does not appear to use ``permc_spec``, so
        the two ``spsolve_umfpack_*`` variants likely behave the same --
        confirm against the installed SciPy version.
    tol : float, optional
        Relative tolerance for the Krylov solvers (passed as ``rtol`` or, on
        older SciPy, ``tol``). Used as both ``atol`` and ``btol`` for
        ``lsqr``/``lsmr``. Ignored by the direct ``spsolve_*`` methods.

    Returns
    -------
    ndarray
        The solution. For Krylov methods this is the solver's final iterate;
        if the solver reports ``info != 0`` (not converged or breakdown), a
        ``RuntimeWarning`` is emitted and that iterate is still returned.

    Raises
    ------
    ValueError
        If *method* is not one of the names above.
    """
    if method in _DIRECT_METHODS:
        use_umfpack, permc_spec = _DIRECT_METHODS[method]
        return spla.spsolve(A, b, use_umfpack=use_umfpack,
                            permc_spec=permc_spec)

    if method in _KRYLOV_METHODS:
        solver = getattr(spla, method)
        x, info = solver(A, b, **_krylov_tol_kwargs(solver, tol))
        if info != 0:
            import warnings

            # info > 0: iteration limit hit; info < 0: illegal input/breakdown.
            warnings.warn(
                '%s did not converge (info=%d); returning last iterate'
                % (method, info),
                RuntimeWarning,
                stacklevel=2,
            )
        return x

    if method in _LSQ_METHODS:
        solver = getattr(spla, method)
        return solver(A, b, atol=tol, btol=tol)[0]

    raise ValueError('UnknownSolverType: %r' % (method,))
评论列表
文章目录