Source code for database

"""
Abstraction for calibration databases.


This module contains the objects that are used in Tunax to describe a :class:`Database` of
observations used for a calibration. By *datas* (:class:`Data`), we refer to the union of a
trajectory and a physical case, which represents a measurement or a reference such as a Large Eddy
Simulation (LES). By *observations* (:class:`Obs`), we refer to the union of a trajectory and the
Tunax model which corresponds to it. These classes can be accessed with the prefix
:code:`tunax.database.` or directly with :code:`tunax.`.

"""

from __future__ import annotations
import warnings
from typing import Union, Optional, Tuple, List, Dict, TypeAlias, cast

import yaml
import xarray as xr
import equinox as eqx
import numpy as np
import jax.numpy as jnp
from h5py import File as H5pyFile
from h5py import Dataset as H5pyDataset
from h5py import Group as H5pyGroup
from jaxtyping import Array, Float

from tunax.space import Grid, Trajectory, TRACERS_NAMES, VARIABLE_NAMES
from tunax.case import Case
from tunax.functions import _format_to_single_line
from tunax.model import SingleColumnModel

DimsType: TypeAlias = Tuple[Optional[int]]
"""Type that represent the dimensions on which load the datas from a file."""


def _get_var_jl(
        jl_file: H5pyFile,
        var_names: Dict[str, str],
        var: str,
        n: int,
        dims: Union[DimsType, Dict[str, DimsType]] = (None,),
        suffix: str = ''
    ) -> Float[Array, 'n']:
    """
    Retrieve the value of a variable in a .jld2 file as a one-dimensional array of length n.

    The function first reads the raw data stored at the path :code:`var_names[var]` followed by
    :code:`suffix`. It then indexes this raw data with :code:`dims` to obtain a one-dimensional
    array. Finally it removes the borders by keeping the middle part of length :code:`n`.

    Parameters
    ----------
    jl_file : H5pyFile
        A .jld2 file loaded with the package h5py.
    var_names : Dict[str, str]
        The registry used to find the variable in the file: the keys are the Tunax names of the
        variables and the values are the names in the file (which are actually paths).
    var : str
        The Tunax name of the variable to search.
    n : int
        The expected length of the variable. The function keeps the middle part of length n of
        the array extracted from the file.
    dims : DimsType or Dict[str, DimsType], default=(None,)
        The indexing applied on the raw data. If it is a dictionary, its keys are the Tunax names
        of the variables (as in var_names) and its values are the indexing of each variable. The
        tuple must have one entry per axis of the raw data: None keeps the axis, an integer
        reduces the axis to this index.
    suffix : str, default=''
        A string appended at the end of the path found in var_names.

    Returns
    -------
    arr : float array of shape (n)
        The middle part of length n of the selected data, as returned by indexing the dataset.

    Raises
    ------
    ValueError
        If the tuple dims doesn't have as many elements as the number of axes of the raw data
        from the file.
    ValueError
        If the selected data is shorter than n.

    Warns
    -----
    Odd shift
        If n hasn't the same parity as the length of the selected data. Then the removed borders
        are not symmetric: one more point is removed at the end of the array than at its start.
    """
    # `cast` does nothing at runtime; the string forward reference avoids evaluating the h5py
    # class name just for the benefit of the type checker.
    jl_var = cast('H5pyDataset', jl_file[f'{var_names[var]}{suffix}'])
    if isinstance(dims, dict):
        dims = dims[var]
    if len(jl_var.shape) != len(dims):
        raise ValueError(_format_to_single_line(f"""
            The tuple parameter `dims` of value {dims} must have the length of the number of
            dimension of the variable {var} array in the `jl_file`.
        """))
    dims_slice = tuple(slice(None) if x is None else x for x in dims)
    jl_var_1d = jl_var[dims_slice]
    double_shift = jl_var_1d.shape[0] - n
    if double_shift < 0:
        raise ValueError(_format_to_single_line(f"""
            The array of the variable {var} in the `jl_file` has a length of
            {jl_var_1d.shape[0]}, which is smaller than the expected length `n` of value {n}.
        """))
    shift = double_shift//2
    if double_shift%2 == 1:
        warnings.warn(_format_to_single_line(f"""
            The length array from the `jl_file` of the variable {var} minus `n` is an odd number :
            the removed boundaries are taken 1 point thinner on the bottom side than on the surface
            side.
        """))
    # An explicit end index is required: with `shift == 0` a negative end index `-shift` is `-0`,
    # i.e. `0`, which would select an empty array instead of the whole data. `shift + n` covers
    # both the even case (`len - shift`) and the odd case (`len - shift - 1`).
    return jl_var_1d[shift:shift + n]


class Data(eqx.Module):
    """
    Abstraction to represent an element of the database from the point of view of Tunax.

    This abstraction is the link between the time-series of :class:`Trajectory` and a physical
    situation described by :class:`Case`. It can optionally contain metadatas. Typically this
    class appears when one wants to import the different elements of a database of observations
    or simulations. The constructor takes all the attributes as parameters.

    Attributes
    ----------
    trajectory : Trajectory
        The time-series of the variables that represent this data.
    case : Case
        The physical case that represent this data.
    metadatas : Dict[str, float], default={}
        Some metadatas that we want to use later. It can be some values of the :attr:`case` that
        we want to set by hand later.

    Raises
    ------
    ValueError
        If the :attr:`~space.Trajectory.time` of :attr:`trajectory` is not built with constant
        time-steps.

    """

    trajectory: Trajectory
    case: Case
    metadatas: Dict[str, float] = eqx.field(static=True)

    def __init__(
            self,
            trajectory: Trajectory,
            case: Case,
            metadatas: Optional[Dict[str, float]] = None
        ) -> None:
        time = trajectory.time
        steps = time[1:] - time[:-1]
        # With zero or one step there is nothing to compare (and `steps[0]` would be out of
        # range for a single-snapshot trajectory), so the check only runs from two steps on.
        if steps.size > 1 and not jnp.all(steps == steps[0]):
            raise ValueError('Tunax only handle constant output time-steps')
        self.trajectory = trajectory
        self.case = case
        # A fresh dictionary per instance: a `{}` default argument would be shared.
        if metadatas is None:
            metadatas = {}
        self.metadatas = metadatas

    @classmethod
    def from_nc_yaml(
            cls,
            nc_path: str,
            yaml_path: str,
            var_names: Dict[str, str],
            eos_tracers: str = 't',
            do_pt: bool = False
        ) -> Data:
        """
        Create a :class:`Data` instance from a *netcdf* and a *yaml* files.

        This class method builds a trajectory from the :code:`.nc` file :code:`nc_path` and
        builds the physical parameters from the configuration file :code:`yaml_path`.
        :code:`var_names` is used to do the link between Tunax name convention and the one from
        the used database.

        Parameters
        ----------
        nc_path : str
            Path of the *netcdf* file that contains the time-series of the observation
            trajectory. The file should contain at least the three dimensions
            :attr:`~space.Grid.zr`, :attr:`~space.Grid.zw` and :attr:`~space.Trajectory.time`.
            The time-series can be created with default values if they are not present in the
            file (only for :attr:`space.Trajectory.u` and :attr:`space.Trajectory.v`). Otherwise,
            they must have the good dimensions described in :class:`~space.Trajectory`.
        yaml_path : str
            Path of the *yaml* file that contains the parameters and forcing that describe the
            observation. The parameters should be float numbers and directly accessible from the
            root of the file with a key. Only the parameters that are described in
            :class:`~case.Case` will be taken into account.
        var_names : Dict[str, str]
            Link between the convention names in Tunax and the ones in the database. The keys are
            the Tunax names and the values are the names in the database. It works for variables
            of the :class:`~space.Trajectory` and for the parameters of :class:`~case.Case`. It
            must at least contain entries for :attr:`~space.Grid.zr`, :attr:`~space.Grid.zw` and
            :attr:`~space.Trajectory.time`. For :code:`u` and :code:`v`, a missing entry or an
            empty string means that the time-series is filled with zeros.
        eos_tracers : str, default='t'
            Tracers used for the equation of state, cf. :attr:`~case.Case.eos_tracers`.
        do_pt : bool, default=False
            Compute or not a passive tracer, cf. :attr:`~case.Case.do_pt`.

        Returns
        -------
        data : Data
            An object that represents these files. Its :attr:`metadatas` is the whole content of
            the *yaml* file.
        """
        ds = xr.load_dataset(nc_path)

        # dimensions
        zr = jnp.array(ds[var_names['zr']].values)
        zw = jnp.array(ds[var_names['zw']].values)
        grid = Grid(zr, zw)
        time = jnp.array(ds[var_names['time']].values)
        nt, = time.shape
        nz = grid.nz

        # velocities: `.get` so that a missing entry behaves like the empty-string marker
        # instead of raising a KeyError
        u_name = var_names.get('u', '')
        if u_name == '':
            u = jnp.full((nt, nz), 0.)
        else:
            u = jnp.array(ds[u_name].values)
        v_name = var_names.get('v', '')
        if v_name == '':
            v = jnp.full((nt, nz), 0.)
        else:
            v = jnp.array(ds[v_name].values)

        # tracers: only the ones that are mapped in var_names
        tracers_dict = {
            tracer: jnp.array(ds[var_names[tracer]].values)
            for tracer in TRACERS_NAMES
            if tracer in var_names
        }

        # writing trajectory
        trajectory = Trajectory(grid, time, u, v, **tracers_dict)

        # writing the case
        with open(yaml_path, 'r', encoding='utf-8') as f:
            metadatas = yaml.safe_load(f)
        case = Case()
        case_attributes = list(vars(case).keys())

        def get_pytree_fun(att: str):
            # a factory binds `att` now: a bare lambda in the loop would bind it late
            return lambda t: getattr(t, att)

        for att in case_attributes:
            if att in var_names:
                case = eqx.tree_at(get_pytree_fun(att), case, metadatas[var_names[att]])
        case = eqx.tree_at(lambda t: t.eos_tracers, case, eos_tracers)
        case = eqx.tree_at(lambda t: t.do_pt, case, do_pt)

        return cls(trajectory, case, metadatas)

    @classmethod
    def from_jld2(
            cls,
            jld2_path: str,
            names_mapping: Dict[str, Dict[str, str]],
            nz: Optional[int] = None,
            dims: Union[DimsType, Dict[str, DimsType]] = (None,),
            eos_tracers: str = 't',
            do_pt: bool = False
        ) -> Data:
        """
        Creates a :class:`Data` instance from a :code:`.jld2` file.

        For the scalar parameters, the values must be registered in the file simply with their
        path in the file, separated with :code:`/` in the same string. For the time-series and
        the time, the arrays of each time step must be registered with a path that ends with the
        reference of the time. For the other variables, just register the normal path. The time
        is approximated to the order of the second. The file is closed before returning.

        Parameters
        ----------
        jld2_path : str
            Path of the :code:`.jld2` file that contains the time-series of the observation
            trajectory and the physical parameters and forcings.
        names_mapping: Dict[str, Dict[str, str]]
            Contains the link between the Tunax names of variables and the path of the variables
            in the file. There are 3 first entries :

            - :code:`variables` : for all the variables corresponding to the
              :class:`space.Grid` and the :class:`space.Trajectory`. For the grid attributes
              (:attr:`space.Grid.zr` and :attr:`space.Grid.zw`) the path should correspond
              directly to the array in the file. For the time and the time-series, the given path
              corresponds to a group whose keys are the references of the times (an integer
              written as a string); under each reference is stored the array of the variable (or
              the value of the time) at this time. The 2D arrays are then rebuilt by
              concatenation. The references of the times are taken from the group of the time.
            - :code:`parameters` : for all the scalar entries corresponding directly to the
              parameters of :class:`case.Case`. The entry :code:`nz` is read when the parameter
              :code:`nz` is not given.
            - :code:`metadatas` : for the scalar entries that we want to keep in the
              :attr:`metadatas` for later. Only the numpy floating or integer values are kept.
        nz : int, optional, default=None
            Expected number of steps of the grid of the water column. The method removes the
            borders of the raw data from the file to keep only the middle part of this length.
            If nothing is entered for this parameter, the value is read in the file at the path
            given by the :code:`nz` entry of :code:`parameters`.
        dims : DimsType or Dict[str, DimsType], default=(None,)
            It contains the dimensions on which to search the right arrays. If it's a dictionary,
            the keys are the names of the variables in terms of Tunax and the values are the
            dimensions for each variable. Then we have a Tuple of int or Nones which corresponds
            to every axis of the raw data from the file. If an axis is indexed with None, it
            means that we keep this dimension, if an axis is indexed with an integer, it means
            that we reduce this axis to the value of the raw data on this index.
        eos_tracers : str, default='t'
            Tracers used for the equation of state, cf. :attr:`~case.Case.eos_tracers`.
        do_pt : bool, default=False
            Compute or not a passive tracer, cf. :attr:`~case.Case.do_pt`.

        Returns
        -------
        data : Data
            An object that represents this file.
        """
        var_map = names_mapping['variables']
        par_map = names_mapping['parameters']
        metadata_map = names_mapping['metadatas']

        # Everything is read inside the `with` block so that the file handle is always released,
        # even if a path is missing; the objects are built afterwards from in-memory values.
        with H5pyFile(jld2_path, 'r') as jl:
            # retrieve the right value of nz
            if nz is None:
                ds = cast(H5pyDataset, jl[par_map['nz']])
                nz = int(ds[()])

            # grid
            zr = jnp.array(_get_var_jl(jl, var_map, 'zr', nz, dims))
            zw = jnp.array(_get_var_jl(jl, var_map, 'zw', nz+1, dims))

            # time: the keys are sorted by their integer value but kept as stored in the file,
            # so that references written with leading zeros can still be looked up
            time_group = var_map['time']
            gr = cast(H5pyGroup, jl[time_group])
            time_str_list = sorted(gr.keys(), key=int)
            time_float_list = []
            for time_str in time_str_list:
                ds = cast(H5pyDataset, jl[f'{time_group}/{time_str}'])
                # truncation to an integer number of seconds
                time_float_list.append(float(int(ds[()])))
            time = jnp.array(time_float_list)

            # variables: one row per time reference
            variables_dict = {}
            for var_name in VARIABLE_NAMES:
                if var_name not in var_map:
                    continue
                var_list = [
                    _get_var_jl(jl, var_map, var_name, nz, dims, f'/{time_str}')
                    for time_str in time_str_list
                ]
                variables_dict[var_name] = jnp.vstack(var_list)

            # parameters: only the entries whose name is found in `vars(Case)`
            params = {}
            case_params = vars(Case)
            for par_name, jl_name in par_map.items():
                if par_name in case_params:
                    ds = cast(H5pyDataset, jl[jl_name])
                    params[par_name] = float(ds[()])

            # metadatas: only the numeric scalars are kept
            metadatas = {}
            for metadata_name, jl_name in metadata_map.items():
                ds = cast(H5pyDataset, jl[jl_name])
                jl_val = ds[()]
                if isinstance(jl_val, (np.floating, np.integer)):
                    metadatas[metadata_name] = float(jl_val)

        trajectory = Trajectory(Grid(zr, zw), time, **variables_dict)
        case = Case(eos_tracers=eos_tracers, do_pt=do_pt, **params)
        return cls(trajectory, case, metadatas)

    def cut(self, out_nt_cut: int) -> List[Data]:
        """
        Cuts the :attr:`Trajectory` in sub-trajectories, cf. :meth:`space.Trajectory.cut`.

        Parameters
        ----------
        out_nt_cut : int
            Number of output steps of the sub-trajectories.

        Returns
        -------
        traj_list : List[Data]
            List of :class:`Data` instances with the sub-trajectories in the chronological
            order. They all share the :attr:`case` and the :attr:`metadatas` of this instance.
        """
        traj_list = self.trajectory.cut(out_nt_cut)
        return [Data(traj, self.case, self.metadatas) for traj in traj_list]
class Weights(eqx.Module):
    """
    Weights applied on each variable when computing the loss function.

    Every weight defaults to zero, so only the variables whose weight is explicitly given are
    taken into account. The constructor takes all the attributes as parameters.

    Attributes
    ----------
    weight_u : float, default=0.
        Weight on zonal velocity.
    weight_v : float, default=0.
        Weight on meridional velocity.
    weight_t : float, default=0.
        Weight on temperature.
    weight_s : float, default=0.
        Weight on salinity.
    weight_b : float, default=0.
        Weight on buoyancy.
    weight_pt : float, default=0.
        Weight on passive tracer.

    """

    # velocities
    weight_u: float = 0.0
    weight_v: float = 0.0
    # tracers
    weight_t: float = 0.0
    weight_s: float = 0.0
    weight_b: float = 0.0
    weight_pt: float = 0.0
class Obs(eqx.Module):
    """
    This class represents an element of the database from the point of view of the loss function.

    This class prepares everything to make the loss function able to compute the loss for this
    element of the database. Indeed this class makes the link between the :class:`Trajectory`
    corresponding to this element, a model (with a grid and time parameters) corresponding to
    this trajectory, and the weights that we want to put on each variable. The constructor takes
    all the attributes as parameters.

    Attributes
    ----------
    trajectory: Trajectory
        The time-series of the variables that represent this observation.
    model: SingleColumnModel
        A model built on this trajectory and on a physical case with the time and geometrical
        parameters.
    weights: Weights
        The weights to give to the loss function.

    """

    trajectory: Trajectory
    model: SingleColumnModel
    weights: Weights

    @classmethod
    def from_data(
            cls,
            data: Data,
            dt: float,
            weights: Weights,
            checkpoint: bool = False,
            closure_name: str = 'k-epsilon'
        ) -> Obs:
        """
        Create a Obs instance from a Data one adding :class:`Weights` and a :code:`dt`.

        This function builds the other time parameters of the model from the trajectory: the
        number of integration steps from the total duration, the number of steps between two
        outputs from the first output time-step, the initial state from the first snapshot and
        the start time from the first time.

        Parameters
        ----------
        data : Data
            A data containing the trajectory and the physical case that we want to apply on our
            model. Its trajectory must contain at least two times.
        dt : float
            The integration time-step that we want for our model.
        weights : Weights
            The weights to give to the loss function.
        checkpoint : bool, default=False
            Use the :func:`~jax.checkpoint` on the partial run method. Used to save memory when
            computing the gradient, especially on GPUs.
        closure_name : str, default='k-epsilon'
            Value forwarded to :class:`SingleColumnModel` in place of the previously hard-coded
            :code:`'k-epsilon'` (presumably the name of the turbulence closure, to be confirmed
            against the model constructor).

        Returns
        -------
        obs : Obs
            The observation made of the trajectory of :code:`data`, the model built on it and
            :code:`weights`.
        """
        time = data.trajectory.time
        start_time = float(time[0])
        out_dt = float(time[1] - time[0])
        # Rounding rather than truncating: a ratio that should be an integer can be evaluated
        # slightly below it in floating point (e.g. 0.3/0.1 gives 2.999...), and truncation would
        # then silently lose one step.
        nt = round(float(time[-1] - time[0])/dt)
        p_out = round(out_dt/dt)
        init_state = data.trajectory.extract_state(0)
        model = SingleColumnModel(
            nt, dt, p_out, init_state, data.case, closure_name, start_time, checkpoint
        )
        # `cls` so that a subclass gets an instance of its own type
        return cls(data.trajectory, model, weights)
class Database(eqx.Module):
    """
    A database made of a collection of observations.

    The constructor takes all the attributes as parameters.

    Attributes
    ----------
    observations : List[Obs]
        The observations of the database, which can have different forcings, geometries and time
        configurations.

    """

    observations: List[Obs]