# -*- coding: utf-8 -*-
from typing import Optional, List, Tuple
import warnings
import numpy as np
from lmfit import Model, Parameters
from lmfit.models import LinearModel
[docs]
class CurveFit:
    """
    Base class for fitting curves, subclasses have specific tools, plot and attributes
    """
    default_fit_kws = {
        'method': 'lstsq',
        'max_nfev': 5000
    }
    def_limit = (-np.inf, np.inf)
    # limits = defaultdict(default_factory=def_limit)
    def __init__(self, x, y, name=None, function=None, ic=None, limits=None, ):
        """
        :param x: 1D array
        :param y: 1D array
        :param name: str
        :param model: lmfit.models.Model (or None)
        """
        self.x = x
        self.y = y
        self.name = name
        self.model = Model(function)
        self.function = function
        self.params = Parameters()
        self.fitted_params = None
        if ic is None:
            self.ic = {}
        else:
            self.ic = ic
        if limits is None:
            self.limits = {}
        else:
            self.limits = limits
        self.result = None
[docs]
    def guess_ic(self,guess=None):
        pass  # do nothing, can't guess. Subclass override 
[docs]
    def guess_limit(self):
        pass  # do nothing, can't guess. Subclass override 
[docs]
    def make_params(self):
        params = Parameters()
        for param, value in self.ic.items():
            limit = self.limits.get(param, self.def_limit)
            params.add(param, value=float(value), min=limit[0], max=limit[1],
                       vary=not isinstance(value, str)) # no vary if value is string
        self.params = params 
[docs]
    def fit_model(self):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # Ignore UserWarning: Using UFloat objects with std_dev==0 may give unexpected results.
            result = self.model.fit(self.y, self.params, x=self.x, **self.default_fit_kws)
        self.result = result
        fitted_params = self.get_fit_params()
        return fitted_params, result 
[docs]
    def get_fit_params(self):
        fitted_params = {}
        for key, param in self.result.params.items():
            fitted_params[key] = param.value
        self.fitted_params = fitted_params
        return self.fitted_params 
[docs]
    def run_default(self):
        if not self.ic:
            self.guess_ic()
        if not self.ic:  # if above cannot make IC ...
            raise Exception('ic is required, no parameters to fit')
        if not self.limits:
            self.guess_limit()
        self.make_params()
        fitted_params, result = self.fit_model()
        return fitted_params, result 
[docs]
    def best_guess(self, guess_list=None, mode='r2'):
        if guess_list is None:
            return self.run_default()
        residuals = []
        fits = []
        results = []
        for guess in guess_list:
            self.guess_ic(guess)
            self.guess_limit()
            self.make_params()
            fitted_params, result = self.run_default()
            fits.append(fitted_params)
            results.append(result)
            if mode == 'r':  # closest area
                residuals.append(np.sum(result.residual))
            elif mode == 'r2':  # smallest error
                residuals.append(np.sum(result.residual * result.residual))
            ## something based on number of args?
        min_ind = np.argmin(residuals)
        self.fitted_params = fits[min_ind]
        self.result = results[min_ind]
        return self.fitted_params, self.result 
    def __call__(self, x):
        if self.result is None:
            return None  # default param call?
        else:
            return self.result.eval(x=x) 
[docs]
class Linear(CurveFit):
    """
    Linear Function Fit
    """
    def __init__(self, x, y, name=None, function=None, ic=None, limits=None, ):
        self.x = x
        self.y = y
        self.name = name
        self.model = LinearModel()
        self.function = function
        self.params = Parameters()
        self.fitted_params = None
        result = self.model.fit(y, x=x)
        self.result = result
        fitted_params = self.get_fit_params()
        self.fitted_params = fitted_params