Source code for mooonpy.fitting.fitting

# -*- coding: utf-8 -*-
from typing import Optional, List, Tuple
import warnings
import numpy as np
from lmfit import Model, Parameters


[docs] class CurveFit: """ Base class for fitting curves, subclasses have specific tools, plot and attributes """ default_fit_kws = { 'method': 'lstsq', 'max_nfev': 5000 } def_limit = (-np.inf, np.inf) # limits = defaultdict(default_factory=def_limit) def __init__(self, x, y, name=None, function=None, ic=None, limits=None, ): """ :param x: 1D array :param y: 1D array :param name: str :param model: lmfit.models.Model (or None) """ self.x = x self.y = y self.name = name self.model = Model(function) self.function = function self.params = Parameters() self.fitted_params = None if ic is None: self.ic = {} else: self.ic = ic if limits is None: self.limits = {} else: self.limits = limits self.result = None
[docs] def guess_ic(self,guess=None): pass # do nothing, can't guess. Subclass override
[docs] def guess_limit(self): pass # do nothing, can't guess. Subclass override
[docs] def make_params(self): params = Parameters() for param, value in self.ic.items(): limit = self.limits.get(param, self.def_limit) params.add(param, value=float(value), min=limit[0], max=limit[1], vary=not isinstance(value, str)) # no vary if value is string self.params = params
[docs] def fit_model(self): with warnings.catch_warnings(): warnings.simplefilter("ignore") # Ignore UserWarning: Using UFloat objects with std_dev==0 may give unexpected results. result = self.model.fit(self.y, self.params, x=self.x, **self.default_fit_kws) self.result = result fitted_params = self.get_fit_params() return fitted_params, result
[docs] def get_fit_params(self): fitted_params = {} for key, param in self.result.params.items(): fitted_params[key] = param.value self.fitted_params = fitted_params return self.fitted_params
[docs] def run_default(self): if not self.ic: self.guess_ic() if not self.ic: # if above cannot make IC ... raise Exception('ic is required, no parameters to fit') if not self.limits: self.guess_limit() self.make_params() fitted_params, result = self.fit_model() return fitted_params, result
[docs] def best_guess(self, guess_list=None, mode='r2'): if guess_list is None: return self.run_default() residuals = [] fits = [] results = [] for guess in guess_list: self.guess_ic(guess) self.guess_limit() self.make_params() fitted_params, result = self.run_default() fits.append(fitted_params) results.append(result) if mode == 'r': # closest area residuals.append(np.sum(result.residual)) elif mode == 'r2': # smallest error residuals.append(np.sum(result.residual * result.residual)) ## something based on number of args? min_ind = np.argmin(residuals) self.fitted_params = fits[min_ind] self.result = results[min_ind] return self.fitted_params, self.result
def __call__(self, x): if self.result is None: return None # default param call? else: return self.result.eval(x=x)