Source code for fitter
"""
Abstraction usefull for the calibration of the closures.
This module contains the classes :class:`FittableParameter` and :class:`FittableParametersSet` which
are used to make a link between the optimization part and the closures structures. These class can
be obtained by the prefix :code:`tunax.fitter.` or directly by :code:`tunax.`.
"""
from __future__ import annotations
from typing import Dict
import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Float, Array
from tunax.closure import ClosureParametersAbstract, Closure
from tunax.closures_registry import CLOSURES_REGISTRY
[docs]
class FittableParameter(eqx.Module):
"""
Calibration configuration for one parameter.
An instance of this class must be created for every parameter of the
closure that will be calibrated and for every not default values during
calibration. The constructor takes all the attributes as parameters.
Attributes
----------
do_fit : bool
The parameter will be calibrated.
val : float, default=0.
If :attr:`do_fit` : the initial value for calibration (at the first
step of the calibration) ; if :attr:`do_fit` is false : the constant
value to take for this parameter if it's not the default one in the
closure.
"""
do_fit: bool
val: float = 0.
[docs]
class FittableParametersSet(eqx.Module):
"""
Complete closure calibration parameters.
This class is the set of all the configurations on the closure parameters
for the calibration. It makes the link between the array on which the
optimizer works and the closure parameters class.
Parameters
----------
coef_fit_dict : Dict[str, FittableParameter]
cf. :attr:`coef_fit_dict`.
closure_name : str
Name of the chosen closure, must be a key of
:attr:`~closures_registry.CLOSURES_REGISTRY`, see its documentation
for the available closures.
Attributes
----------
coef_fit_dico : Dict[str, FittableParameter]
The set of all the configurations of all the parameters that will be
calibrated and the one constants but not with the default value of the
closure.
closure : Closure
The abstraction that represent the used closure.
"""
coef_fit_dict: Dict[str, FittableParameter]
closure: Closure
def __init__(
self,
coef_fit_dict: Dict[str, FittableParameter],
closure_name: str
) -> None:
self.coef_fit_dict = coef_fit_dict
self.closure = CLOSURES_REGISTRY[closure_name]
@property
def n_calib(self) -> int:
"""
Number of variables that are calibrated.
Returns
-------
nc : int
Number of variables that are calibrated.
"""
nc = 0
for coef_fit in self.coef_fit_dict.values():
if coef_fit.do_fit:
nc += 1
return nc
[docs]
def fit_to_closure(
self,
x: Float[Array, 'nc']
) -> ClosureParametersAbstract:
"""
Transforms an fitted array in a set of closure parameters.
This method copy the fixed non-default values of :attr:`coef_fit_dict`
and copy the calibrated values from :code:`x`. Which is simply the
parameters values in the order that is indicated by
:attr:`coef_fit_dict`.
Parameters
----------
x : float :class:`~jax.Array` of shape (nc)
The array on which the optimize works to find the best values. It
is the array of the parameters that are calibrated.
Returns
-------
clo_params : ClosureParametersAbstract
The instance of the closure parameters class (child class of
:class:`~closure.ClosureParametersAbstract`) with the modifications
of the calibration step.
"""
clo_coef_dico = {}
i_x = 0
for coef_name, coef_fit in self.coef_fit_dict.items():
if coef_fit.do_fit:
clo_coef_dico[coef_name] = x[i_x]
i_x += 1
else:
clo_coef_dico[coef_name] = coef_fit.val
return self.closure.parameters_class(**clo_coef_dico)
[docs]
def gen_init_val(self) -> Float[Array, 'nc']:
"""
Produce the fitted array for the first calibration step.
This method simply copy the initial values of the calibrated
coefficients in an array :code:`x` which will be used as the first
calibration step for the optimizer.
Returns
-------
x : float :class:`~jax.Array` of shape (nc)
The initial vector for the optimizer at the first step of
calibration.
"""
x = []
for coef_fit in self.coef_fit_dict.values():
if coef_fit.do_fit:
x.append(coef_fit.val)
return jnp.array(x)