"""
Single column forward model core.
This module contains the main class :class:`SingleColumnModel`, which is the core of Tunax for
computing the evolution of a water column of the ocean. It also contains functions that are used
in this computation. The model was translated from Fortran to JAX based on the work of
Florian Lemarié and Manolis Perrot [1]_; the translation was done in part using the work of
Anthony Zhou, Linnia Hawkins and Pierre Gentine [2]_. This class and these functions can be
accessed with the prefix :code:`tunax.model.` or directly with :code:`tunax.`.
References
----------
.. [1] M. Perrot and F. Lemarié. Energetically consistent Eddy-Diffusivity Mass-Flux convective
schemes. Part I: Theory and Models (2024). url :
`hal.science/hal-04439113 <https://hal.science/hal-04439113>`_.
.. [2] A. Zhou, L. Hawkins and P. Gentine. Proof-of-concept: Using ChatGPT to Translate and
Modernize an Earth System Model from Fortran to Python/JAX (2024). url :
`arxiv.org/abs/2405.00018 <https://arxiv.org/abs/2405.00018>`_.
"""
from __future__ import annotations
import inspect
from typing import Tuple, Dict, TypeAlias, cast
import equinox as eqx
import jax
import jax.numpy as jnp
from jax import lax, jit, vmap
from tunax.case import Case, CaseTracable
from tunax.space import Grid, State, Trajectory, ArrNz, ArrNzp1, ArrNt, VARIABLE_NAMES
from tunax.functions import FloatJax, tridiag_solve, add_boundaries
from tunax.closure import ClosureParametersAbstract, ClosureStateAbstract, Closure
from tunax.closures_registry import CLOSURES_REGISTRY
# Carry of one integration step, in this order: the physical state of the column, the state of
# the closure variables, and the current time. It is the tuple returned by the step/run_partial
# methods of SingleColumnModel and threaded through their lax.scan loops.
StatesTime: TypeAlias = Tuple[State, ClosureStateAbstract, float]
"""Type that represent the values that are transformed in an integration step of the model."""
[docs]
class SingleColumnModel(eqx.Module):
    r"""
    Single column forward model core.

    This forward model of tunax is the combination of 4 things : the physical case
    :attr:`case_tracable`, an initial state of the water column :attr:`init_state`, the time
    information with :attr:`nt`, :attr:`dt` and :attr:`p_out` and the abstraction of the chosen
    closure for eddy-diffusivity :attr:`closure`. Given a set of parameters for the closure, one
    can run the model with the method :meth:`run`. The constructor takes all the attributes as
    arguments, with two exceptions : the case is given as a :class:`~case.Case` instance, which is
    converted to a :class:`~case.CaseTracable` instance for JAX purposes, and the closure is given
    only by its name.

    Parameters
    ----------
    nt : int
        cf. :attr:`nt`.
    dt : float
        cf. :attr:`dt`.
    p_out : int
        cf. :attr:`p_out`.
    init_state : State
        cf. :attr:`init_state`.
    case : Case
        Physical case and forcings of the experiment.
    closure_name : str
        Name of the chosen closure, must be a key of
        :attr:`~closures_registry.CLOSURES_REGISTRY`, see its documentation
        for the available closures.
    start_time : float, default=0.
        cf. :attr:`start_time`.
    checkpoint : bool, default=False
        cf. :attr:`checkpoint`.

    Attributes
    ----------
    nt : int
        Number of integration iterations.
    dt : float
        Time-step of integration for every iteration :math:`[\text s]`.
    p_out : int
        Number of time-steps between every output step.
    init_state : State
        Initial physical state of the water column.
    case_tracable : CaseTracable
        Physical case and forcings of the experiment, made tracable for JAX purposes.
    closure : Closure
        Abstraction representing the chosen closure.
    start_time : float, default=0.
        Value of the starting time in :math:`[\text s]`.
    checkpoint : bool, default=False
        Use :func:`~jax.checkpoint` on every integration step of :meth:`run_partial`. Used to
        save memory when computing the gradient, especially on GPUs.

    Note
    ----
    The closure parameters are given only at the end when we compute the model so that the
    calibrations of the parameters are easier.
    """

    nt: int = eqx.field(static=True)
    dt: float
    p_out: int = eqx.field(static=True)
    init_state: State
    case_tracable: CaseTracable
    closure: Closure = eqx.field(static=True)
    start_time: float
    checkpoint: bool = eqx.field(static=True)

    def __init__(
        self,
        nt: int,
        dt: float,
        p_out: int,
        init_state: State,
        case: Case,
        closure_name: str,
        start_time: float = 0.,
        checkpoint: bool = False
    ) -> None:
        self.nt = nt
        self.dt = dt
        self.p_out = p_out
        self.init_state = init_state
        # the closure abstraction is looked up by name in the registry
        self.closure = CLOSURES_REGISTRY[closure_name]
        self.start_time = start_time
        self.checkpoint = checkpoint

        # creation of the CaseTracable class : every attribute of the case is copied, and the
        # tracer forcings are converted to arrays together with a tag describing their type
        grid = self.init_state.grid
        case_attributes = {k: v for k, v in vars(case).items() if not k.startswith('__')}
        for tra in ['t', 's', 'b', 'pt']:
            tra_attr = f'{tra}_forcing'
            tra_type_attr = f'{tra}_forcing_type'
            forcing = getattr(case, tra_attr)
            if forcing is None:
                case_attributes[tra_type_attr] = None
            elif isinstance(forcing, tuple):
                # a tuple is kept as it is and tagged as a boundary forcing
                case_attributes[tra_type_attr] = 'borders'
                case_attributes[tra_attr] = forcing
            elif callable(forcing):
                n_params = len(inspect.signature(forcing).parameters)
                if n_params == 1:
                    # function of the depth only : evaluated once on the cell centers
                    case_attributes[tra_type_attr] = 'constant'
                    vec_fun = vmap(forcing)
                    case_attributes[tra_attr] = grid.hz*vec_fun(grid.zr)
                elif n_params == 2:
                    # function of depth and time : evaluated on the whole (time, depth) mesh.
                    # With the default 'xy' indexing of meshgrid, the first axis of the result
                    # follows `time` and the second one follows `grid.zr`.
                    # NOTE(review): the time axis starts at 0 and not at `start_time` - confirm
                    # that this is the intended convention.
                    case_attributes[tra_type_attr] = 'variable'
                    time = jnp.linspace(0, (self.nt-1)*self.dt, self.nt)
                    zr_grid, time_grid = jnp.meshgrid(grid.zr, time)
                    case_attributes[tra_attr] = grid.hz*forcing(zr_grid, time_grid)
        self.case_tracable = CaseTracable(**case_attributes)

    def step(
        self,
        state: State,
        closure_state: ClosureStateAbstract,
        time: float,
        closure_parameters: ClosureParametersAbstract
    ) -> StatesTime:
        r"""
        Runs one time-step of the model.

        This function first calls the closure to compute the eddy-diffusivity and viscosity, and
        then integrates the equations of tracers and momentum. It returns the new :code:`state`
        and :code:`closure_state` together with the time after the step.

        Parameters
        ----------
        state : State
            State of the water column at the current :code:`time`.
        closure_state : ClosureStateAbstract
            State of the water column for the closure variables at the current :code:`time`.
        time : float
            Time of the current iteration (the method integrates from this time to the next one).
        closure_parameters : ClosureParametersAbstract
            A set of parameters of the used closure.

        Returns
        -------
        state : State
            State of the water column after the integration.
        closure_state : ClosureStateAbstract
            State of the water column for the closure variables after the integration.
        time : float
            Value of the time after the integration.
        """
        dt = self.dt
        case_tracable = self.case_tracable
        # advance closure state (compute eddy-diffusivity and viscosity)
        closure_state = self.closure.step_fun(
            state, closure_state, dt, closure_parameters, case_tracable
        )
        # advance tracers : the index in the time-dependent forcings must be a real integer
        # (typing.cast does nothing at runtime), and it is rounded because `time` is an
        # accumulated sum of `dt` which can fall slightly below the exact multiple
        i_time = jnp.round(time / dt).astype(jnp.int32)
        state = advance_tra_ed(state, closure_state.akt, dt, case_tracable, i_time)
        # advance velocities
        state = advance_dyn_cor_ed(state, closure_state.akv, dt, case_tracable)
        time = time + dt
        return state, closure_state, time

    @jax.checkpoint
    def step_check(
        self,
        state: State,
        closure_state: ClosureStateAbstract,
        time: float,
        closure_parameters: ClosureParametersAbstract
    ) -> StatesTime:
        r"""
        Checkpointed version of :meth:`step` to save memory during the gradient computation.

        Parameters
        ----------
        state : State
            State of the water column at the current :code:`time`.
        closure_state : ClosureStateAbstract
            State of the water column for the closure variables at the current :code:`time`.
        time : float
            Time of the current iteration (the method integrates from this time to the next one).
        closure_parameters : ClosureParametersAbstract
            A set of parameters of the used closure.

        Returns
        -------
        state : State
            State of the water column after the integration.
        closure_state : ClosureStateAbstract
            State of the water column for the closure variables after the integration.
        time : float
            Value of the time after the integration.
        """
        return self.step(state, closure_state, time, closure_parameters)

    def run_partial(
        self,
        state0: State,
        closure_state0: ClosureStateAbstract,
        time0: float,
        n_steps: int,
        closure_parameters: ClosureParametersAbstract
    ) -> StatesTime:
        r"""
        Runs a certain number of time steps.

        Computes a loop of integration for a number of time steps of :code:`n_steps`, and returns
        the last states of the loop. Every step is done with :meth:`step_check` if
        :attr:`checkpoint` is set, and with :meth:`step` otherwise.

        Parameters
        ----------
        state0 : State
            State of the water column at the beginning of the integration loop.
        closure_state0 : ClosureStateAbstract
            State of the water column for the closure variables at the beginning of the
            integration loop.
        time0 : float
            Beginning time of the integration loop.
        n_steps : int
            Number of integration steps.
        closure_parameters : ClosureParametersAbstract
            A set of parameters of the used closure.

        Returns
        -------
        state : State
            State of the water column after a number of :code:`n_steps` integration steps.
        closure_state : ClosureStateAbstract
            State of the water column for the closure variables after a number of
            :code:`n_steps` integration steps.
        time : float
            Value of the time after a number of :code:`n_steps` integration steps.
        """
        step_fun = self.step_check if self.checkpoint else self.step

        def scan_fn(carry: StatesTime, _: None) -> Tuple[StatesTime, None]:
            state, closure_state, time = carry
            return step_fun(state, closure_state, time, closure_parameters), None

        # nothing is scanned over : only the carry is advanced `n_steps` times
        (state, closure_state, time), _ = lax.scan(
            scan_fn, (state0, closure_state0, time0), None, length=n_steps
        )
        return state, closure_state, time

    def _state_concat_to_traj(self, states: State, times: ArrNt) -> Trajectory:
        """
        Convert the concatenations of :class:`~space.State` in a :class:`~space.Trajectory`.

        Used to get a trajectory from the output of the :func:`~jax.lax.scan` function which
        computes the concatenation of several :class:`~space.State` instances. The initial state
        :attr:`init_state` and the time :attr:`start_time` are added at the beginning of the
        trajectory.

        Parameters
        ----------
        states : State
            A concatenation of :class:`~space.State` instances. More specifically, every leaf of
            the pytree is an array on the first axis of all the values of the
            :class:`~space.State` instances.
        times : float :class:`~jax.Array` of shape (nt)
            Values of the different times corresponding to each :class:`~space.State` instance.

        Returns
        -------
        trajectory : Trajectory
            Trajectory corresponding to the concatenation of the states.
        """
        var_dict = {}
        for var in VARIABLE_NAMES:
            var_computed = getattr(states, var)
            if var_computed is not None:
                # the initial value gets a leading axis to be stacked before the outputs
                var0 = jnp.expand_dims(getattr(self.init_state, var), 0)
                var_dict[var] = jnp.concat([var0, var_computed])
        times = jnp.concat([jnp.array([self.start_time]), times])
        return Trajectory(self.init_state.grid, times, **var_dict)

    def run(self, closure_parameters: ClosureParametersAbstract) -> Trajectory:
        r"""
        Main run of the model.

        Computes the integration steps of length :attr:`dt` from the initial state
        :attr:`init_state`, doing an output of the state every :attr:`p_out` steps, with the
        physical case and forcings corresponding to :attr:`case_tracable`. The closure for eddy-
        diffusivity used is :attr:`closure` with the values of the parameters described by the
        parameter :code:`closure_parameters`. The number of output steps is
        :code:`nt // p_out`, which means that if :attr:`nt` is not a multiple of :attr:`p_out`,
        the last :code:`nt % p_out` steps are not integrated.

        Parameters
        ----------
        closure_parameters : ClosureParametersAbstract
            The set of parameters to use for the computation of the closure of eddy-diffusivity.

        Returns
        -------
        trajectory : Trajectory
            The trajectory with the initial state and all the output steps of the integration
            loop.
        """
        init_closure_state = self.closure.state_class(self.init_state.grid, closure_parameters)

        def scan_fn(carry: StatesTime, _: None) -> Tuple[StatesTime, StatesTime]:
            state, closure_state, time = carry
            new_carry = self.run_partial(
                state, closure_state, time, self.p_out, closure_parameters
            )
            # the carry after `p_out` steps is also what is recorded as an output
            return new_carry, new_carry

        n_steps_out = self.nt//self.p_out
        _, (states, _, times) = lax.scan(
            scan_fn,
            (self.init_state, init_closure_state, self.start_time),
            None,
            length=n_steps_out
        )
        return self._state_concat_to_traj(states, jnp.array(times))

    @jit
    def jit_run(self, closure_parameters: ClosureParametersAbstract) -> Trajectory:
        r"""
        Jitted version of :meth:`run`.

        This method does exactly the same as :meth:`run` but :func:`~jax.jit` is applied on it,
        which means that the first call of this method will be the compilation of the function,
        and the next ones will be the compiled execution of the function which are supposed to be
        faster. There will be a compilation each time that this method is called for a new
        "shape" of the :class:`SingleColumnModel` instance (which means that all the leaves of the
        pytree have the same shape as :class:`~jax.Array`), but even if this method is called on
        different instances, if they have the same shape, the compilation will be done only one
        time. Moreover, this method should be used only for direct runs, if one wants to apply
        :func:`~jax.grad` over it, it's better to put the :func:`~jax.jit` outside.

        Parameters
        ----------
        closure_parameters : ClosureParametersAbstract
            cf. :meth:`run`.

        Returns
        -------
        trajectory : Trajectory
            cf. :meth:`run`.
        """
        return self.run(closure_parameters)
[docs]
def lmd_swfrac(hz: ArrNz) -> ArrNzp1:
    r"""
    Compute solar forcing.

    Compute the fraction of solar shortwave flux penetrating to the cell interfaces, modelled as
    the sum of two bands which decay exponentially with the thickness crossed (Jerlov water
    type). The fraction is 1 at the surface interface and every band is set to 0 as soon as the
    exponent of a crossed cell is not greater than -20.

    Parameters
    ----------
    hz : float :class:`~jax.Array` of shape (nz)
        Thickness of cells from deepest to shallowest :math:`[\text m]`.

    Returns
    -------
    swr_frac : float :class:`~jax.Array` of shape (nz+1)
        Fraction of solar penetration through the water column, from the deepest interface to
        the surface :math:`[\text{dimensionless}]`.
    """
    (n_cells,) = hz.shape
    # attenuation lengths of the two bands and share of the first band at the surface
    length_band1 = 0.35
    length_band2 = 23.0
    share_band1 = 0.58
    # exponents of the decay across every cell, for both bands
    expo_band1 = (-1.0 / length_band1) * hz
    expo_band2 = (-1.0 / length_band2) * hz

    def cross_cell(bands, exponents):
        band1, band2 = bands
        expo1, expo2 = exponents
        # a band is cut to zero when the decay across the cell is too strong
        band1 = lax.cond(expo1 > -20, lambda x: x*jnp.exp(expo1), lambda x: 0.*x, band1)
        band2 = lax.cond(expo2 > -20, lambda x: x*jnp.exp(expo2), lambda x: 0.*x, band2)
        return (band1, band2), band1+band2

    # the cells are crossed from the shallowest one to the deepest one
    _, downward_frac = lax.scan(
        cross_cell,
        (share_band1, 1.0 - share_band1),
        (expo_band1[::-1], expo_band2[::-1]),
        length=n_cells
    )
    # back to the deepest-to-shallowest order, with the whole flux at the surface interface
    return jnp.concat((downward_frac[::-1], jnp.array([1.])))
[docs]
def tracer_flux(
tracer: str,
case_tracable: CaseTracable,
grid: Grid,
i_time: int
) -> ArrNz:
r"""
Computes flux of the tracer forcing.
This function get the flux of the forcing at a certain time depending on the type of the
forcing, the flux being the derivative of the forcing along the depth.
Parameters
----------
tracer : str
Name of the tracer variable of the concerned forcing. One of {:code:`'t'`, :code:`'s'`,
:code:`'b'`, :code:`'pt`'}.
case_tracable : CaseTracable
Physical case which contains the forcings type and values.
grid : Grid
Vertical grid of the water column.
i_time : int
Index of the time iteration corresponding to the index in the forcing.
Returns
-------
df : float :class:`~jax.Array` of shape (nz)
Flux of the forcing of the tracer. At each cell it represents the difference between the
input and the ouput flux.
Raises
------
ValueError
If the forcing type is not one of {'borders', 'constant', 'variable'}.
"""
forcing = getattr(case_tracable, f'{tracer}_forcing')
forcing_type = getattr(case_tracable, f'{tracer}_forcing_type')
match forcing_type:
case 'borders':
df = add_boundaries(-forcing[0], jnp.zeros(grid.nz-2), forcing[1])
case 'constant':
df = forcing
case 'variable':
df = forcing[:, i_time]
case _:
mess = f'Forcing type of variable {tracer} should be one of' + \
"{'borders', 'constant', 'variable'}."
raise ValueError(mess)
return df
[docs]
def advance_tra_ed(
    state: State,
    akt: ArrNzp1,
    dt: float,
    case_tracable: CaseTracable,
    i_time: int,
) -> State:
    r"""
    Integrate vertical diffusion term for tracers.

    For every tracer, the right-hand side is built from the current value of the tracer and the
    flux of its forcing, then the diffusion problem is solved and the tracer at the next
    time-step is written in the state. The solved equation is for :math:`C` a tracer :

    :math:`\partial _z ( K_m \partial _z C) + \partial _t C + F = 0`

    where :math:`F` is the representation of the forcings. This equation is solved for every
    tracer indicated in :attr:`~case.Case.eos_tracers` and for the passive tracer if
    :attr:`~case.Case.do_pt` is set.

    Parameters
    ----------
    state : State
        State of the water column at the current iteration.
    akt : float :class:`~jax.Array` of shape (nz+1)
        Eddy-diffusivity on the interfaces of the cells
        :math:`\left[\text m ^2 \cdot \text s ^{-1}\right]`.
    dt : float
        Time-step of the integration step.
    case_tracable : CaseTracable
        Physical case which contains the forcings type and values.
    i_time : int
        Index of the time iteration corresponding to the index in the forcing.

    Returns
    -------
    state : State
        The state of the water column with the values of the tracers after the integration and
        the diffusion by the eddy-diffusivity.
    """
    grid = state.grid
    thickness = grid.hz
    # tracers to integrate : the ones of the equation of state, then the passive tracer
    active_tracers = list(case_tracable.eos_tracers)
    if case_tracable.do_pt:
        active_tracers.append('pt')
    for name in active_tracers:
        forcing_flux = tracer_flux(name, case_tracable, grid, i_time)
        rhs = thickness*getattr(state, name) + dt*forcing_flux
        new_value = diffusion_solver(akt, thickness, rhs, dt)
        # `field=name` binds the current name, the selector is called with the tree only
        state = eqx.tree_at(lambda tree, field=name: getattr(tree, field), state, new_value)
    return state
[docs]
def advance_dyn_cor_ed(
    state: State,
    akv: ArrNzp1,
    dt: float,
    case_tracable: CaseTracable
) -> State:
    r"""
    Integrate vertical diffusion and Coriolis terms for momentum.

    The Coriolis term is computed first, then the surface and bottom momentum forcings are
    added on the last and first cells, and finally the diffusion problem is solved for both
    components of the velocity. The momentum at the next time-step is returned in the state.
    The equation which is solved is :

    :math:`\partial_z (K_v \partial_z U) + F_{\text{cor}}(U) + F = 0`

    where :math:`F_{\text{cor}}` represents the Coriolis effect, and :math:`F` represents the
    effect of the forcings on the momentum.

    Parameters
    ----------
    state : State
        State of the water column at the current iteration.
    akv : float :class:`~jax.Array` of shape (nz+1)
        Eddy-viscosity on the interfaces of the cells
        :math:`\left[\text m ^2 \cdot \text s ^{-1}\right]`.
    dt : float
        Time-step of the integration step.
    case_tracable : CaseTracable
        Physical case which contains the forcings type and values.

    Returns
    -------
    state : State
        The state of the water column with the values of the momentum after the integration and
        the diffusion by the eddy-viscosity.
    """
    # weight used in the Coriolis update below
    # NOTE(review): presumably the implicitness of the Coriolis time-stepping - confirm
    gamma = 0.55
    thickness = state.grid.hz
    u_old, v_old = state.u, state.v

    # 1 - Coriolis term
    rotation = dt * case_tracable.fcor
    rotation_sq = rotation ** 2
    norm = 1 / (1 + gamma * gamma * rotation_sq)
    keep = 1-gamma*(1-gamma)*rotation_sq
    rhs_u = norm * thickness * (keep*u_old + rotation*v_old)
    rhs_v = norm * thickness * (keep*v_old - rotation*u_old)

    # 2 - surface forcing on the last cell, then bottom forcing on the first cell
    rhs_u = rhs_u.at[-1].add(dt * case_tracable.ustr_sfc).at[0].add(-dt * case_tracable.ustr_btm)
    rhs_v = rhs_v.at[-1].add(dt * case_tracable.vstr_sfc).at[0].add(-dt * case_tracable.vstr_btm)

    # 3 - implicit integration of the vertical viscosity
    u_new = diffusion_solver(akv, thickness, rhs_u, dt)
    v_new = diffusion_solver(akv, thickness, rhs_v, dt)

    # 4 - both components are written in the state at once
    return eqx.tree_at(lambda tree: (tree.u, tree.v), state, (u_new, v_new))
[docs]
def diffusion_solver(
    ak: ArrNzp1,
    hz: ArrNz,
    f: ArrNz,
    dt: float
) -> ArrNz:
    r"""
    Solve a diffusion problem with finite volumes.

    The diffusion problems can be written

    :math:`\partial _z (K \partial _z X) + \dfrac f {\Delta t \Delta x} = 0`

    where we are searching for :math:`X` and where :math:`f` represents the temporal derivative
    and forcings. This function builds the three diagonals of the corresponding tridiagonal
    system and then solves it. The values of :code:`ak` on the bottom and surface interfaces
    (first and last entries) are not used.

    Parameters
    ----------
    ak : float :class:`~jax.Array` of shape (nz+1)
        Diffusion at the cell interfaces :math:`K` in
        :math:`\left[\text m ^2 \cdot \text s ^{-1}\right]`.
    hz : float :class:`~jax.Array` of shape (nz)
        Thickness of cells from deepest to shallowest :math:`\left[\text m\right]`.
    f : float :class:`~jax.Array` of shape (nz)
        Right-hand flux of the equation :math:`f` in :math:`[[X] \cdot \text m ]`.
    dt : float
        Time-step of discretisation :math:`[\text s]`.

    Returns
    -------
    x : float :class:`~jax.Array` of shape (nz)
        Solution of the diffusion problem :math:`X` in :math:`\left[[X]\right]`.
    """
    # coupling between the two cells around every inner interface : entry i links the cells i
    # and i+1 through the interface i+1
    coupling = -2.0 * dt * ak[1:-1] / (hz[:-1] + hz[1:])

    # inner cells : coupling with the cell below, with the cell above, and diagonal
    lower_in = coupling[:-1]
    upper_in = coupling[1:]
    diag_in = hz[1:-1] - lower_in - upper_in

    # bottom cell : it is only coupled with the cell above
    upper_btm = coupling[0]
    diag_btm = hz[0] - upper_btm

    # surface cell : it is only coupled with the cell below
    lower_sfc = coupling[-1]
    diag_sfc = hz[-1] - lower_sfc

    # assembly of the three diagonals and resolution
    lower = add_boundaries(0., lower_in, cast(float, lower_sfc))
    diag = add_boundaries(cast(float, diag_btm), diag_in, cast(float, diag_sfc))
    upper = add_boundaries(cast(float, upper_btm), upper_in, 0.)
    return tridiag_solve(lower, diag, upper, f)