"""
Single column forward model core.
This module contains the main class :class:`SingleColumnModel`, which is the
core of Tunax for computing the evolution of a water column of the ocean. It
also contains the functions that are used in this computation. The model was
translated from Fortran to JAX based on the work of Florian Lemarié and Manolis
Perrot [1]_; the translation was done in part using the work of Anthony Zhou,
Linnia Hawkins and Pierre Gentine [2]_. This class and these functions can be
obtained by the prefix :code:`tunax.model.` or directly by :code:`tunax.`.
References
----------
.. [1] M. Perrot and F. Lemarié. Energetically consistent Eddy-Diffusivity
Mass-Flux convective schemes. Part I: Theory and Models (2024). url :
`hal.science/hal-04439113 <https://hal.science/hal-04439113>`_.
.. [2] A. Zhou, L. Hawkins and P. Gentine. Proof-of-concept: Using ChatGPT to
Translate and Modernize an Earth System Model from Fortran to Python/JAX
(2024). url : `arxiv.org/abs/2405.00018
<https://arxiv.org/abs/2405.00018>`_.
"""
from __future__ import annotations
import warnings
from typing import Tuple, List
from functools import partial
import equinox as eqx
import jax.numpy as jnp
from jax import lax, jit
from jaxtyping import Float, Array
from tunax.case import Case
from tunax.space import State, Trajectory
from tunax.functions import (
tridiag_solve, add_boundaries, _format_to_single_line
)
from tunax.closure import (
ClosureParametersAbstract, ClosureStateAbstract, Closure
)
from tunax.closures_registry import CLOSURES_REGISTRY
[docs]
class SingleColumnModel(eqx.Module):
    r"""
    Single column forward model core.

    This forward model of Tunax is the combination of four things: the
    physical case :attr:`case`, an initial state of the water column
    :attr:`init_state`, the time information given by :code:`time_frame`,
    :attr:`dt` and :code:`out_dt`, and the abstraction of the chosen closure
    for eddy-diffusivity :attr:`closure`. Given a set of parameters for the
    closure, the model is run with the method
    :meth:`compute_trajectory_with`.

    Parameters
    ----------
    time_frame : float
        Total time of the simulation :math:`[\text h]`.
    dt : float
        Time-step of integration for every iteration :math:`[\text s]`.
    out_dt : float
        Time-step for the output writing :math:`[\text s]`.
    init_state : State
        cf. attribute.
    case : Case
        cf. attribute.
    closure_name : str
        Name of the chosen closure, must be a key of
        :attr:`~closures_registry.CLOSURES_REGISTRY`, see its documentation
        for the available closures.

    Attributes
    ----------
    nt : int
        Number of integration iterations, i.e. :code:`time_frame*3600/dt`
        truncated to an integer.
    dt : float
        Time-step of integration for every iteration :math:`[\text s]`.
    n_out : int
        Number of time-steps between two outputs, i.e. :code:`out_dt/dt`.
    init_state : State
        Initial physical state of the water column.
    case : Case
        Physical case and forcings of the experiment.
    closure : Closure
        Abstraction representing the chosen closure.

    Warnings
    --------
    - If :code:`time_frame` is not proportional to the time-step :attr:`dt`.
    - If :code:`time_frame` is not proportional to the out time-step
      :code:`out_dt`.

    Raises
    ------
    ValueError
        If :code:`out_dt` is not proportional to the time step :attr:`dt`.
    ValueError
        If :code:`closure_name` is not registered in
        :attr:`~closures_registry.CLOSURES_REGISTRY`.

    Note
    ----
    To make this forward model compatible with the fitter part of Tunax, the
    parameters of the closure are only given at the call of the run with
    :meth:`compute_trajectory_with`.
    """

    nt: int
    dt: float
    n_out: int
    init_state: State
    case: Case
    closure: Closure

    def __init__(
            self,
            time_frame: float,
            dt: float,
            out_dt: float,
            init_state: State,
            case: Case,
            closure_name: str
        ) -> None:
        # time parameters transformation: both ratios are kept as floats here
        # so that their integrality can be checked before truncation
        n_out = out_dt/dt
        nt = time_frame*3600/dt

        # warnings and errors on time parameters coherence
        if not n_out.is_integer():
            raise ValueError('`out_dt` should be a multiple of `dt`.')
        if nt % n_out != 0:
            warnings.warn(_format_to_single_line("""
                The `time_frame`is not proportional to the out time-step
                `out_dt`, the last step will be computed a few before the
                `time_frame`.
            """))
        if not nt.is_integer():
            warnings.warn(_format_to_single_line("""
                The `time_frame`is not proportional to the time-step `dt`, the
                last step will be computed a few before the time_frame.
            """))
        if closure_name not in CLOSURES_REGISTRY:
            raise ValueError(_format_to_single_line("""
                `closure_name` not registerd in CLOSURES_REGISTRY.
            """))

        # write attributes
        self.nt = int(nt)
        self.dt = dt
        self.n_out = int(n_out)
        self.init_state = init_state
        self.case = case
        self.closure = CLOSURES_REGISTRY[closure_name]

    def compute_trajectory_with(
            self,
            closure_parameters: ClosureParametersAbstract
        ) -> Trajectory:
        """
        Run the model with a specific set of closure parameters.

        This method is the main one for running the model. It calls
        :attr:`nt` times the function :func:`step` and regularly records the
        current state to build the :class:`~space.Trajectory` output. A state
        is recorded *before* the step of every iteration whose index is a
        multiple of :attr:`n_out`: the first record is therefore the initial
        state, and the state reached after the very last step is not
        recorded.

        Parameters
        ----------
        closure_parameters : ClosureParametersAbstract
            A set of parameters of the used closure.

        Returns
        -------
        trajectory : Trajectory
            Timeseries of the evolution of the variables of the model every
            :code:`out_dt`.
        """
        # initialize the model
        states_list: List[State] = []
        state = self.init_state
        closure_state = self.closure.state_class(
            self.init_state.grid, closure_parameters
        )
        # the solar penetration profile only depends on the grid, so it is
        # computed once before the time loop
        swr_frac = lmd_swfrac(self.init_state.grid.hz)

        # loop the model
        for i_t in range(self.nt):
            if i_t % self.n_out == 0:
                states_list.append(state)
            state, closure_state = step(
                self.dt, self.case, self.closure, state, closure_state,
                closure_parameters, swr_frac
            )

        # The time axis is built from the integer indices of the recorded
        # iterations and then scaled by `dt`. A float-step `arange` over
        # `nt*dt` can gain or lose one element through rounding and would
        # then no longer match the number of recorded states.
        time = jnp.arange(0, self.nt, self.n_out) * self.dt

        # generate trajectory: one row per recorded state
        t_rows = jnp.vstack([rec.t for rec in states_list])
        s_rows = jnp.vstack([rec.s for rec in states_list])
        u_rows = jnp.vstack([rec.u for rec in states_list])
        v_rows = jnp.vstack([rec.v for rec in states_list])
        trajectory = Trajectory(
            self.init_state.grid, time, t_rows, s_rows, u_rows, v_rows
        )
        return trajectory
[docs]
@partial(jit, static_argnames=('dt', 'case', 'closure'))
def step(
        dt: float,
        case: Case,
        closure: Closure,
        state: State,
        closure_state: ClosureStateAbstract,
        closure_parameters: ClosureParametersAbstract,
        swr_frac: Float[Array, 'nz+1']
    ) -> Tuple[State, ClosureStateAbstract]:
    r"""
    Run one time-step of the model.

    The closure is called first to update the eddy-diffusivity and the
    eddy-viscosity, then the tracer and momentum equations are integrated
    with these coefficients. The updated :code:`state` and
    :code:`closure_state` are returned.

    Parameters
    ----------
    dt : float
        Time-step of integration for every iteration :math:`[\text s]`.
    case : Case
        Physical case and forcings of the experiment.
    closure : Closure
        Abstraction representing the chosen closure.
    state : State
        Current state of the water column.
    closure_state : ClosureStateAbstract
        Current state of the water column for the closure variables.
    closure_parameters : ClosureParametersAbstract
        A set of parameters used with the :code:`closure`.
    swr_frac : Float[~jax.Array, 'nz+1']
        Fraction of solar penetration through the water column
        :math:`[\text{dimensionless}]`.

    Returns
    -------
    state : State
        State of the water column at next time-step.
    closure_state : ClosureStateAbstract
        State of the water column at next time-step for the closure variables.

    Note
    ----
    This function is jitted with JAX with :code:`dt`, :code:`case` and
    :code:`closure` as static arguments; it should make it faster, but the
    :func:`~jax.jit` decorator can be removed.
    """
    hz = state.grid.hz

    # closure first: it provides the diffusivity `akt` and viscosity `akv`
    # used by the two integrations below
    closure_state = closure.step_fun(
        state, closure_state, dt, closure_parameters, case
    )

    # tracers (temperature and salinity)
    next_t, next_s = advance_tra_ed(
        state.t, state.s, closure_state.akt, swr_frac, hz, dt, case
    )

    # momentum (zonal and meridional velocities)
    next_u, next_v = advance_dyn_cor_ed(
        state.u, state.v, hz, closure_state.akv, dt, case
    )

    # build the new state out-of-place, replacing the four prognostic fields
    state = eqx.tree_at(
        lambda tree: (tree.t, tree.s, tree.u, tree.v),
        state,
        (next_t, next_s, next_u, next_v)
    )
    return state, closure_state
[docs]
def lmd_swfrac(hz: Float[Array, 'nz']) -> Float[Array, 'nz+1']:
    r"""
    Compute solar forcing.

    Compute the fraction of solar shortwave flux penetrating to each cell
    interface, due to exponential decay in Jerlov water type. The flux is
    split in two bands that decay independently from the surface downwards.
    This function is called once before running the model.

    Parameters
    ----------
    hz : Float[~jax.Array, 'nz']
        Thickness of cells from deepest to shallowest :math:`[\text m]`.

    Returns
    -------
    swr_frac : Float[~jax.Array, 'nz+1']
        Fraction of solar penetration through the water column
        :math:`[\text{dimensionless}]`, from the bottom interface to the
        surface interface (where it is 1).
    """
    n_cells, = hz.shape

    # two-band decay: `weight_1` of the flux decays over the scale `scale_1`
    # and the remainder over the scale `scale_2`
    scale_1 = 0.35
    scale_2 = 23.0
    weight_1 = 0.58

    # per-cell exponents (negative) of each band
    exponent_1 = (-1.0 / scale_1) * hz
    exponent_2 = (-1.0 / scale_2) * hz

    def _damp(band, exponent):
        # below the -20 cut-off the band is treated as fully extinguished
        return lax.cond(
            exponent > -20,
            lambda x: x*jnp.exp(exponent),
            lambda x: 0.*x,
            band
        )

    def _descend(bands, k):
        # k = 1 is the surface cell, k = n_cells is the deepest one
        i_cell = n_cells - k
        band_1 = _damp(bands[0], exponent_1[i_cell])
        band_2 = _damp(bands[1], exponent_2[i_cell])
        return (band_1, band_2), band_1 + band_2

    _, top_down = lax.scan(
        _descend, (weight_1, 1.0 - weight_1), jnp.arange(1, n_cells+1)
    )

    # the scan goes from the surface to the bottom: flip it to the
    # bottom-to-surface ordering and append the surface value
    return jnp.concat((top_down[::-1], jnp.array([1.])))
[docs]
def advance_tra_ed(
        t: Float[Array, 'nz'],
        s: Float[Array, 'nz'],
        akt: Float[Array, 'nz+1'],
        swr_frac: Float[Array, 'nz+1'],
        hz: Float[Array, 'nz'],
        dt: float,
        case: Case
    ) -> Tuple[Float[Array, 'nz'], Float[Array, 'nz']]:
    r"""
    Integrate vertical diffusion term for tracers.

    First the flux divergences are computed taking in account the forcings.
    Then the diffusion equation of the tracers system is solved, and the
    tracers at next time-step are returned. The solved equation is for
    :math:`C` a tracer :

    :math:`\partial _z ( K_m \partial _z C) + \partial _t C + F = 0`

    where :math:`F` is the representation of the forcings.

    Parameters
    ----------
    t : Float[~jax.Array, 'nz']
        Current temperature on the center of the cells :math:`[° \text C]`.
    s : Float[~jax.Array, 'nz']
        Current salinity on the center of the cells :math:`[\text{psu}]`.
    akt : Float[~jax.Array, 'nz+1']
        Current eddy-diffusivity :math:`K_m` on the interfaces of the cells
        :math:`\left[\text m ^2 \cdot \text s ^{-1}\right]`.
    swr_frac : Float[~jax.Array, 'nz+1']
        Fraction of solar penetration through the water column
        :math:`[\text{dimensionless}]`.
    hz : Float[~jax.Array, 'nz']
        Thickness of cells from deepest to shallowest :math:`[\text m]`.
    dt : float
        Time-step of the integration step :math:`[\text s]`.
    case : Case
        Physical case and forcings of the experiment. The attributes
        :code:`rflx_sfc_max`, :code:`tflx_sfc`, :code:`tflx_btm`,
        :code:`sflx_sfc` and :code:`sflx_btm` are read.

    Returns
    -------
    t : Float[Array, 'nz']
        Temperature on the center of the cells at next step
        :math:`[° \text C]`.
    s : Float[Array, 'nz']
        Salinity on the center of the cells at next step :math:`[\text{psu}]`.
    """
    # 1 - Fluxes at the cell interfaces (index 0: bottom, index -1: surface)
    # Temperature: penetrating solar flux, plus the surface flux at the top
    # interface, and no flux through the bottom interface
    fc_t = case.rflx_sfc_max * swr_frac
    fc_t = fc_t.at[-1].add(case.tflx_sfc)
    fc_t = fc_t.at[0].set(0.)
    # apply flux divergence
    dft = hz*t + dt*(fc_t[1:] - fc_t[:-1])

    # Salinity: only the surface flux at the top interface.
    # JAX arrays are immutable, `.at[...].set(...)` returns a new array which
    # must be rebound, otherwise the surface salinity flux is silently lost.
    fc_s = jnp.zeros(t.shape[0]+1)
    fc_s = fc_s.at[-1].set(case.sflx_sfc)
    # apply flux divergence
    dfs = hz*s + dt*(fc_s[1:] - fc_s[:-1])

    # 2 - Implicit integration for vertical diffusion
    # Temperature (the bottom flux is applied on the deepest cell)
    dft = dft.at[0].add(-dt * case.tflx_btm)
    t = diffusion_solver(akt, hz, dft, dt)
    # Salinity (the bottom flux is applied on the deepest cell)
    dfs = dfs.at[0].add(-dt * case.sflx_btm)
    s = diffusion_solver(akt, hz, dfs, dt)

    return t, s
[docs]
def advance_dyn_cor_ed(
        u: Float[Array, 'nz'],
        v: Float[Array, 'nz'],
        hz: Float[Array, 'nz'],
        akv: Float[Array, 'nz+1'],
        dt: float,
        case: Case
    ) -> Tuple[Float[Array, 'nz'], Float[Array, 'nz']]:
    r"""
    Integrate vertical diffusion and Coriolis terms for momentum.

    The Coriolis term is computed first, then the surface and bottom momentum
    forcings are added, and finally the diffusion equation is solved. The
    momentum at the next time-step is returned. The equation which is solved
    is :

    :math:`\partial_z (K_v \partial_z U) + F_{\text{cor}}(U) + F = 0`

    where :math:`F_{\text{cor}}` represents the Coriolis effect, and
    :math:`F` represents the effect of the forcings.

    Parameters
    ----------
    u : Float[~jax.Array, 'nz']
        Current zonal velocity on the center of the cells
        :math:`\left[\text m \cdot \text s^{-1}\right]`.
    v : Float[~jax.Array, 'nz']
        Current meridional velocity on the center of the cells
        :math:`\left[\text m \cdot \text s^{-1}\right]`.
    hz : Float[~jax.Array, 'nz']
        Thickness of cells from deepest to shallowest :math:`[\text m]`.
    akv : Float[~jax.Array, 'nz+1']
        Current eddy-viscosity :math:`K_v` on the interfaces of the cells
        :math:`\left[\text m ^2 \cdot \text s ^{-1}\right]`.
    dt : float
        Time-step of the integration step :math:`[\text s]`.
    case : Case
        Physical case and forcings of the experiment. The attributes
        :code:`fcor`, :code:`ustr_sfc`, :code:`vstr_sfc`, :code:`ustr_btm`
        and :code:`vstr_btm` are read.

    Returns
    -------
    u : Float[Array, 'nz']
        Zonal velocity on the center of the cells at the next time step
        :math:`\left[\text m \cdot \text s^{-1}\right]`.
    v : Float[Array, 'nz']
        Meridional velocity on the center of the cells at the next time step
        :math:`\left[\text m \cdot \text s^{-1}\right]`.
    """
    # weight used in the time discretisation of the Coriolis term
    gamma_cor = 0.55
    fcor = case.fcor

    # 1 - Coriolis term: both components share the same scaling and the
    # same weight on the current velocity
    dt_f_sq = (dt * fcor) ** 2
    scaling = 1 / (1 + gamma_cor * gamma_cor * dt_f_sq)
    rhs_u = scaling * hz * (
        (1-gamma_cor*(1-gamma_cor)*dt_f_sq)*u + dt*fcor*v
    )
    rhs_v = scaling * hz * (
        (1-gamma_cor*(1-gamma_cor)*dt_f_sq)*v - dt*fcor*u
    )

    # 2 - Forcings: surface stress on the shallowest cell, then bottom
    # stress on the deepest cell
    rhs_u = rhs_u.at[-1].add(dt * case.ustr_sfc)
    rhs_v = rhs_v.at[-1].add(dt * case.vstr_sfc)
    rhs_u = rhs_u.at[0].add(-dt * case.ustr_btm)
    rhs_v = rhs_v.at[0].add(-dt * case.vstr_btm)

    # 3 - Implicit integration for vertical viscosity
    u_next = diffusion_solver(akv, hz, rhs_u, dt)
    v_next = diffusion_solver(akv, hz, rhs_v, dt)
    return u_next, v_next
[docs]
def diffusion_solver(
        ak: Float[Array, 'nz+1'],
        hz: Float[Array, 'nz'],
        f: Float[Array, 'nz'],
        dt: float
    ) -> Float[Array, 'nz']:
    r"""
    Solve a diffusion problem with finite volumes.

    The diffusion problems can be written

    :math:`\partial _z (K \partial _z X) + \dfrac f {\Delta t \Delta x} = 0`

    where we are searching for :math:`X` and where :math:`f` represents the
    temporal derivative and forcings. This function builds the three
    diagonals of the corresponding tridiagonal system and then solves it.

    Parameters
    ----------
    ak : Float[~jax.Array, 'nz+1']
        Diffusion at the cell interfaces :math:`K` in
        :math:`\left[\text m ^2 \cdot \text s ^{-1}\right]`.
    hz : Float[~jax.Array, 'nz']
        Thickness of cells from deepest to shallowest
        :math:`\left[\text m\right]`.
    f : Float[~jax.Array, 'nz']
        Right-hand flux of the equation :math:`f` in
        :math:`[[X] \cdot \text m ]`.
    dt : float
        Time-step of discretisation :math:`[\text s]`.

    Returns
    -------
    x : Float[~jax.Array, 'nz']
        Solution of the diffusion problem
        :math:`X` in :math:`\left[[X]\right]`.
    """
    def _coupling(k_iface, h_first, h_second):
        # off-diagonal coefficient across one interface between two cells
        return -2.0 * dt * k_iface / (h_first + h_second)

    # interior rows: coupling with the cell below (lower diagonal) and with
    # the cell above (upper diagonal)
    lower_in = _coupling(ak[1:-2], hz[:-2], hz[1:-1])
    upper_in = _coupling(ak[2:-1], hz[2:], hz[1:-1])
    diag_in = hz[1:-1] - lower_in - upper_in

    # bottom row: only coupled with the cell above
    upper_btm = _coupling(ak[1], hz[1], hz[0])
    diag_btm = hz[0] - upper_btm

    # surface row: only coupled with the cell below
    lower_sfc = _coupling(ak[-2], hz[-2], hz[-1])
    diag_sfc = hz[-1] - lower_sfc

    # assemble the three diagonals and solve the system
    lower = add_boundaries(0., lower_in, lower_sfc)
    diag = add_boundaries(diag_btm, diag_in, diag_sfc)
    upper = add_boundaries(upper_btm, upper_in, 0.)
    return tridiag_solve(lower, diag, upper, f)