def __init__(self, data, categorical_parameters, continuous_parameters, func=None, order=1):
    """Fit one interpolation function per (key group, value column) pair.

    Every column of ``data`` that is neither a categorical nor a continuous
    parameter is treated as a value column to be approximated.  If there are
    categorical parameters, ``data`` is split with ``DataFrame.groupby`` on
    them and each non-empty group is fitted separately; otherwise the whole
    table is fitted under the key ``None``.

    Parameters
    ----------
    data : pandas.DataFrame
        Table holding the key, parameter and value columns.
    categorical_parameters : sequence of column names
        Columns to group by.  An empty sequence (or ``None``) means the whole
        table is fitted as a single group.
    continuous_parameters : sequence of 1 or 2 column names
        Columns the interpolation is performed over.
    func : optional
        Stored unchanged as ``self.func``; it is not used by this method.
    order : int, default 1
        Spline degree passed to SciPy.  ``0`` selects nearest-neighbour
        interpolation and is only allowed with two continuous parameters.

    Attributes set
    --------------
    key_columns, parameter_columns, func
        The corresponding constructor arguments, stored as given.
    interpolations : dict
        ``{group_key: {value_column: callable}}``.  ``group_key`` is ``None``
        when there are no key columns, otherwise whatever iterating the
        ``groupby`` object yields (a scalar or a tuple, depending on the
        number of key columns and the pandas version).  The callables are:

        * 1 parameter: an ``InterpolatedUnivariateSpline``, called as ``f(x)``;
        * 2 parameters, ``order >= 1``: ``RectBivariateSpline(...).ev``,
          called as ``f(p0, p1)`` with values of the first and second
          continuous parameter respectively;
        * 2 parameters, ``order == 0``: a ``NearestNDInterpolator``, called
          with points whose columns follow ``continuous_parameters`` order.

    Raises
    ------
    ValueError
        If there are not 1 or 2 continuous parameters, if ``order == 0`` is
        requested with a single continuous parameter, or if a 2-parameter
        spline's parameters do not form a complete grid free of missing
        values.  SciPy/pandas may raise their own errors for other invalid
        input (e.g. an unsupported ``order`` or too few points).
    """
    self.key_columns = categorical_parameters
    self.parameter_columns = continuous_parameters
    self.func = func

    if len(self.parameter_columns) not in [1, 2]:
        raise ValueError("Only interpolation over 1 or 2 variables is supported")
    if len(self.parameter_columns) == 1 and order == 0:
        raise ValueError("Order 0 only supported for 2d interpolation")

    # Everything that is neither a key nor a parameter is a value to approximate.
    value_columns = sorted(
        data.columns.difference(set(self.key_columns or ()) | set(self.parameter_columns))
    )

    if self.key_columns:
        # Fit each combination of key values separately.
        sub_tables = data.groupby(self.key_columns)
    else:
        # No key columns: fit the whole table as a single group.
        sub_tables = [(None, data)]

    self.interpolations = {}
    for key, base_table in sub_tables:
        # Skip empty groups defensively; there is nothing to fit.
        if base_table.empty:
            continue

        # Named `fitted`, not `func`, so the constructor argument is not shadowed.
        fitted = {}
        if len(self.parameter_columns) == 1:
            # 1 variable: the spline needs increasing abscissae, so sort once per group.
            param = self.parameter_columns[0]
            table = base_table.sort_values(by=param)
            for value_column in value_columns:
                fitted[value_column] = interpolate.InterpolatedUnivariateSpline(
                    table[param], table[value_column], k=order
                )
        elif order == 0:
            # 2 variables, nearest neighbour: scattered points, no grid required.
            points = base_table[list(self.parameter_columns)].values
            for value_column in value_columns:
                fitted[value_column] = interpolate.NearestNDInterpolator(
                    x=points, y=base_table[value_column].values
                )
        else:
            # 2 variables, spline: reshape into a rectangular grid (pivot sorts both axes).
            index, column = self.parameter_columns
            for value_column in value_columns:
                grid = base_table.pivot(index=index, columns=column, values=value_column)
                # Missing (index, column) combinations show up as NaN after the pivot;
                # RectBivariateSpline does not check for them and would fit garbage.
                if grid.isnull().values.any():
                    raise ValueError(
                        f"Cannot fit a spline for {value_column!r} (key={key!r}): "
                        f"values on the {index!r} x {column!r} grid are incomplete "
                        f"or contain NaN"
                    )
                fitted[value_column] = interpolate.RectBivariateSpline(
                    x=grid.index.values,
                    y=grid.columns.values,
                    z=grid.values,
                    kx=order,
                    ky=order,
                ).ev
        self.interpolations[key] = fitted
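# Stray scraped web-page text, not part of the code: "Comment list"
# Stray scraped web-page text, not part of the code: "Table of contents"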