interpolate.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:pytrip 作者: pytrip 项目源码 文件源码
def __get_1d_function(x, y, kind):
    """
    Train 1-D interpolator

    :param x: x-coordinates of data points (np.interp and
        InterpolatedUnivariateSpline both expect them in increasing order)
    :param y: y-coordinates of data points, same length as x
    :param kind: 'linear' or 'spline' interpolation type; only consulted
        when more than two data points are given
    :return: Interpolator callable object
    :raises Exception: if x and y have different lengths
    :raises ValueError: if no data points are given, or if kind is not
        'linear' / 'spline' for a data set with more than two points
    :raises ImportError: if kind is 'spline' and scipy is not installed
    """
    def fun_interp(t):
        # piecewise-linear interpolation; outside [x[0], x[-1]] np.interp
        # returns the end values y[0] / y[-1]
        return np.interp(t, x, y)

    # input consistency check
    if len(x) != len(y):
        logger.error("len(x) = {:d}, len(y) = {:d}. Both should be equal".format(len(x), len(y)))
        raise Exception("1-D interpolation: X and Y should have the same shape")

    # empty data set: there is nothing to interpolate; fail here instead of
    # returning a callable that only breaks when it is first evaluated
    if len(y) == 0:
        raise ValueError("1-D interpolation: at least one data point is required")

    # 1-element data set, return fixed value
    if len(y) == 1:
        def fun_const(t):
            """
            Helper function returning y[0] for every evaluation point

            :param t: array-like or scalar
            :return: array of constant values if t is an array, or a constant scalar if t is a scalar
            """
            try:
                result = np.ones_like(t) * y[0]  # t is array-like (or a numeric scalar)
            except TypeError:
                result = y[0]  # fallback: return the raw constant if the array product fails
            return result
        return fun_const

    # 2-element data set: too few points for a spline, use linear interpolation from numpy
    if len(y) == 2:
        return fun_interp

    if kind == 'spline':
        # 3-rd degree spline interpolation, passing through all points
        try:
            from scipy.interpolate import InterpolatedUnivariateSpline
        except ImportError:
            logger.error("Please install scipy on your platform to be able to use spline-based interpolation")
            raise
        # fall back to 2-nd degree spline if only 3 points are present
        k = 2 if len(y) == 3 else 3
        return InterpolatedUnivariateSpline(x, y, k=k)

    if kind == 'linear':
        return fun_interp

    # must raise an exception instance: raising a bare string is itself a
    # TypeError in Python 3 and would hide this message; '{}' (not '{:s}')
    # also keeps the message intact when kind is not a str
    raise ValueError("Unsupported interpolation type {}.".format(kind))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号