如何很好地使用Cython更快地求解微分方程?

发布于 2021-01-29 16:57:47

我想减少Scipy的odeint解微分方程所花费的时间。

为了练习,我使用了科学计算Python涵盖的示例
作为模板。因为odeint接受一个函数f作为参数,所以我将此函数编写为静态类型的Cython版本,并希望odeint的运行时间会大大减少。

该函数f包含在名为的文件中ode.pyx,如下所示:

import numpy as np
cimport numpy as np
from libc.math cimport sin, cos

def f(y, t, params):
  cdef double theta = y[0], omega = y[1]
  cdef double Q = params[0], d = params[1], Omega = params[2]
  cdef double derivs[2]
  derivs[0] = omega
  derivs[1] = -omega/Q + np.sin(theta) + d*np.cos(Omega*t)
  return derivs

def fCMath(y, double t, params):
  cdef double theta = y[0], omega = y[1]
  cdef double Q = params[0], d = params[1], Omega = params[2]
  cdef double derivs[2]
  derivs[0] = omega
  derivs[1] = -omega/Q + sin(theta) + d*cos(Omega*t)
  return derivs

然后,我创建一个文件setup.py来编写函数:

from distutils.core import setup
from Cython.Build import cythonize

setup(ext_modules=cythonize('ode.pyx'))

求解微分方程的脚本(也包含的Python版本f)称为solveODE.py,其外观为:

import ode
import numpy as np
from scipy.integrate import odeint
import time

def f(y, t, params):
    theta, omega = y
    Q, d, Omega = params
    derivs = [omega,
             -omega/Q + np.sin(theta) + d*np.cos(Omega*t)]
    return derivs

params = np.array([2.0, 1.5, 0.65])
y0 = np.array([0.0, 0.0])
t = np.arange(0., 200., 0.05)

start_time = time.time()
odeint(f, y0, t, args=(params,))
print("The Python Code took: %.6s seconds" % (time.time() - start_time))

start_time = time.time()
odeint(ode.f, y0, t, args=(params,))
print("The Cython Code took: %.6s seconds ---" % (time.time() - start_time))

start_time = time.time()
odeint(ode.fCMath, y0, t, args=(params,))
print("The Cython Code incorpoarting two of DavidW_s suggestions took: %.6s seconds ---" % (time.time() - start_time))

然后,我运行:

python setup.py build_ext --inplace
python solveODE.py

在终端。

python版本的时间约为0.055秒,而Cython版本的时间约为0.04秒。

是否有人建议改进我解决微分方程的尝试,最好不要使用Cython修改odeint例程本身?

编辑

我将DavidW的建议合并到了两个文件中ode.pyxsolveODE.py用这些建议运行代码仅花费了大约0.015秒。

关注者
0
被浏览
51
1 个回答
  • 面试哥
    面试哥 2021-01-29
    为面试而生,有面试问题,就找面试哥。

    进行最简单的更改(可能会为您带来很多好处)是使用C数学库sin并对cos单个数字而不是数字进行运算。调用它numpy以及确定不是数组所花费的时间相当昂贵。

    from libc.math cimport sin, cos
    
        # later
        -omega/Q + sin(theta) + d*cos(Omega*t)
    

    我很想为输入分配一个类型d(在不更改接口的情况下,其他任何输入都不容易输入):

    def f(y, double t, params):
    

    我想我也会像在Python版本中一样返回一个列表。我认为使用C数组不会带来很多好处。



知识点
面圈网VIP题库

面圈网VIP题库全新上线,海量真题题库资源。 90大类考试,超10万份考试真题开放下载啦

去下载看看