r"""
:math:`k-\varepsilon` closure parameters, states and computation functions.
This module contains the implementation of the :math:`k-\varepsilon` model
described as a GLS case as in [1]_ as a :class:`~closure.Closure` instance. The
model was translated from Fortran to JAX based on the work of Florian Lemarié
and Manolis Perrot [2]_; the translation was done in part using the work of
Anthony Zhou, Linnia Hawkins and Pierre Gentine [3]_. The parameters of the
closure are available in the :class:`KepsParameters` class, the closure state
in the :class:`KepsState` class. The function :func:`keps_step` computes one
time-step of the closure, which means that it computes the eddy-diffusivity and
viscosity. The module contains other functions that are used by this main one.
These classes and the step function can be obtained with the prefix
:code:`tunax.closures.k_epsilon` or directly with :code:`tunax.closures`.
References
----------
.. [1] L. Umlauf and H. Burchard. A generic length-scale equation for
geophysical turbulence models (2003). Journal of Marine Research 61
pp. 235-265. doi : `10.1357/002224003322005087 <https://www.semanticscholar
.org/paper/A-generic-length-scale-equation-for-geophysical-Umlauf-Burchard/
24fd6403615fc7a6c5d9b6156e4f1e8d4d280af2>`_.
.. [2] M. Perrot and F. Lemarié. Energetically consistent Eddy-Diffusivity
Mass-Flux convective schemes. Part I: Theory and Models (2024). url :
`hal.science/hal-04439113 <https://hal.science/hal-04439113>`_.
.. [3] A. Zhou, L. Hawkins and P. Gentine. Proof-of-concept: Using ChatGPT to
Translate and Modernize an Earth System Model from Fortran to Python/JAX
(2024). url : `arxiv.org/abs/2405.00018
<https://arxiv.org/abs/2405.00018>`_.
"""
from __future__ import annotations
from functools import partial
from typing import Tuple
import equinox as eqx
import jax.numpy as jnp
from jax import lax, jit
from jaxtyping import Float, Array
from tunax.case import Case
from tunax.space import Grid, State
from tunax.functions import tridiag_solve, add_boundaries
from tunax.closure import ClosureParametersAbstract, ClosureStateAbstract
[docs]
class KepsParameters(ClosureParametersAbstract):
    r"""
    Parameters and constants of the :math:`k-\varepsilon` closure.

    The first 19 attributes (:attr:`c1` to :attr:`c_eps3p`) are the
    :math:`k-\varepsilon` coefficients that may be calibrated. They are
    followed by physical constants and numerical floors, by the switches
    selecting the kind of boundary conditions and by the GLS exponents. Every
    attribute up to :attr:`gls_n` is also a constructor parameter with the
    same name and the same default value. The last 19 attributes (``sf_*`` and
    ``lim_am*``) are not constructor parameters: they are computed from the
    :math:`k-\varepsilon` coefficients when the instance is created.

    Attributes
    ----------
    c1 : float, default=5.
        Coefficient :math:`c_1` for the dissipation of the pressure/velocity
        correlation tensor (Umlauf and Burchard notations)
        :math:`[\text{dimensionless}]`.
    c2 : float, default=0.8
        Coefficient :math:`c_2` for the dissipation of the pressure/velocity
        correlation tensor (Umlauf and Burchard notations)
        :math:`[\text{dimensionless}]`.
    c3 : float, default=1.968
        Coefficient :math:`c_3` for the dissipation of the pressure/velocity
        correlation tensor (Umlauf and Burchard notations)
        :math:`[\text{dimensionless}]`.
    c4 : float, default=1.136
        Coefficient :math:`c_4` for the dissipation of the pressure/velocity
        correlation tensor (Umlauf and Burchard notations)
        :math:`[\text{dimensionless}]`.
    c5 : float, default=0.
        Coefficient :math:`c_5` for the dissipation of the pressure/velocity
        correlation tensor (Umlauf and Burchard notations)
        :math:`[\text{dimensionless}]`.
    c6 : float, default=0.4
        Coefficient :math:`c_6` for the dissipation of the pressure/velocity
        correlation tensor (Umlauf and Burchard notations)
        :math:`[\text{dimensionless}]`.
    cb1 : float, default=5.95
        Coefficient :math:`c_{b1}` for the dissipation of the
        buoyancy/velocity correlation tensor (Umlauf and Burchard notations)
        :math:`[\text{dimensionless}]`.
    cb2 : float, default=0.6
        Coefficient :math:`c_{b2}` for the dissipation of the
        buoyancy/velocity correlation tensor (Umlauf and Burchard notations)
        :math:`[\text{dimensionless}]`.
    cb3 : float, default=1.
        Coefficient :math:`c_{b3}` for the dissipation of the
        buoyancy/velocity correlation tensor (Umlauf and Burchard notations)
        :math:`[\text{dimensionless}]`.
    cb4 : float, default=0.
        Coefficient :math:`c_{b4}` for the dissipation of the
        buoyancy/velocity correlation tensor (Umlauf and Burchard notations)
        :math:`[\text{dimensionless}]`.
    cb5 : float, default=0.3333
        Coefficient :math:`c_{b5}` for the dissipation of the
        buoyancy/velocity correlation tensor (Umlauf and Burchard notations)
        :math:`[\text{dimensionless}]`.
    cbb : float, default=0.72
        Coefficient :math:`c_{b}` for the dissipation of the
        buoyancy/velocity correlation tensor (Umlauf and Burchard notations)
        :math:`[\text{dimensionless}]`.
    c_mu0 : float, default=0.5477
        Coefficient :math:`c_\mu^0` which links the mixing length to the
        dissipation (Umlauf and Burchard notations)
        :math:`[\text{dimensionless}]`.
    sig_k : float, default=1.
        Schmidt number :math:`\sigma_k` for the diffusion of TKE (Umlauf and
        Burchard notations) :math:`[\text{dimensionless}]`.
    sig_eps : float, default=1.3
        Schmidt number :math:`\sigma_\varepsilon` for the diffusion of
        :math:`\varepsilon` (Umlauf and Burchard notations)
        :math:`[\text{dimensionless}]`.
    c_eps1 : float, default=1.44
        Coefficient :math:`c_{\varepsilon 1}` of the :math:`\varepsilon`
        equation (Umlauf and Burchard notations)
        :math:`[\text{dimensionless}]`.
    c_eps2 : float, default=1.92
        Coefficient :math:`c_{\varepsilon 2}` of the :math:`\varepsilon`
        equation (Umlauf and Burchard notations)
        :math:`[\text{dimensionless}]`.
    c_eps3m : float, default=-0.4
        Coefficient :math:`c_{\varepsilon 3}^-` of the :math:`\varepsilon`
        equation (Umlauf and Burchard notations)
        :math:`[\text{dimensionless}]`.
    c_eps3p : float, default=1.
        Coefficient :math:`c_{\varepsilon 3}^+` of the :math:`\varepsilon`
        equation (Umlauf and Burchard notations)
        :math:`[\text{dimensionless}]`.
    chk_grav : float, default=1400.
        Charnock coefficient times gravity. It is divided by the gravity of
        the case to recover the Charnock coefficient, so it carries the unit
        of gravity.
    galp : float, default=0.53
        Parameter for Galperin mixing length limitation
        :math:`[\text{dimensionless}]`.
    z0s_min : float, default=1e-2
        Minimal surface roughness length :math:`[\text m]`.
    z0b_min : float, default=1e-2
        Minimal bottom roughness length :math:`[\text m]`.
    z0b : float, default=1e-14
        Bottom roughness length :math:`[\text m]`.
    akt_min : float, default=1e-5
        Minimal and initialization value of eddy-diffusivity
        :math:`\left[\text m^2 \cdot \text s^{-1} \right]`.
    akv_min : float, default=1e-4
        Minimal and initialization value of eddy-viscosity
        :math:`\left[\text m^2 \cdot \text s^{-1} \right]`.
    tke_min : float, default=1e-6
        Minimal and initialization value of turbulent kinetic energy (TKE)
        :math:`\left[\text m^2 \cdot \text s^{-2} \right]`.
    eps_min : float, default=1e-12
        Minimal and initialization value of TKE dissipation
        :math:`\left[\text m^2 \cdot \text s^{-3} \right]`.
    c_mu_min : float, default=0.1
        Minimal and initialization value of :math:`c_\mu` in GLS formalism
        :math:`[\text{dimensionless}]`.
    c_mu_prim_min : float, default=0.1
        Minimal and initialization value of :math:`c_\mu'` in GLS formalism
        :math:`[\text{dimensionless}]`.
    dir_sfc : bool, default=False
        Apply a Dirichlet boundary condition at the surface, else apply a
        Neumann boundary condition.
    dir_btm : bool, default=True
        Apply a Dirichlet boundary condition at the bottom, else apply a
        Neumann boundary condition.
    gls_p : float, default=3
        GLS coefficient :math:`p` to define :math:`k-\varepsilon`
        :math:`[\text{dimensionless}]`.
    gls_m : float, default=1.5
        GLS coefficient :math:`m` to define :math:`k-\varepsilon`
        :math:`[\text{dimensionless}]`.
    gls_n : float, default=-1
        GLS coefficient :math:`n` to define :math:`k-\varepsilon`
        :math:`[\text{dimensionless}]`.
    sf_d0, sf_d1, sf_d2, sf_d3, sf_d4, sf_d5 : float
        Stability function coefficients, not constructor parameters: computed
        from the :math:`k-\varepsilon` coefficients
        :math:`[\text{dimensionless}]`.
    sf_n0, sf_n1, sf_n2 : float
        Stability function coefficients, not constructor parameters: computed
        from the :math:`k-\varepsilon` coefficients
        :math:`[\text{dimensionless}]`.
    sf_nb0, sf_nb1, sf_nb2 : float
        Stability function coefficients, not constructor parameters: computed
        from the :math:`k-\varepsilon` coefficients
        :math:`[\text{dimensionless}]`.
    lim_am0, lim_am1, lim_am2, lim_am3, lim_am4, lim_am5, lim_am6 : float
        Limitation coefficients, not constructor parameters: computed as
        products of the ``sf_d*`` and ``sf_n*`` coefficients
        :math:`[\text{dimensionless}]`.
    """

    # k-epsilon coefficients (Umlauf and Burchard notations)
    c1: float = 5.
    c2: float = .8
    c3: float = 1.968
    c4: float = 1.136
    c5: float = 0.
    c6: float = .4
    cb1: float = 5.95
    cb2: float = .6
    cb3: float = 1.
    cb4: float = 0.
    cb5: float = 0.3333
    cbb: float = 0.72
    c_mu0: float = 0.5477
    sig_k: float = 1.
    sig_eps: float = 1.3
    c_eps1: float = 1.44
    c_eps2: float = 1.92
    c_eps3m: float = -.4
    c_eps3p: float = 1.

    # physical constants and numerical floors
    chk_grav: float = 1400.
    galp: float = .53
    z0s_min: float = 1e-2
    # NOTE(review): the previous documentation announced 1e-4 for this
    # default while the code uses 1e-2 -- confirm which value is intended.
    z0b_min: float = 1e-2
    z0b: float = 1e-14
    akt_min: float = 1e-5
    akv_min: float = 1e-4
    tke_min: float = 1e-6
    eps_min: float = 1e-12
    c_mu_min: float = .1
    c_mu_prim_min: float = .1

    # kind of boundary conditions (Dirichlet if True, else Neumann)
    dir_sfc: bool = False
    dir_btm: bool = True

    # GLS exponents selecting k-epsilon
    gls_p: float = 3
    gls_m: float = 1.5
    gls_n: float = -1

    # coefficients derived from the k-epsilon coefficients in __post_init__
    sf_d0: float = eqx.field(init=False)
    sf_d1: float = eqx.field(init=False)
    sf_d2: float = eqx.field(init=False)
    sf_d3: float = eqx.field(init=False)
    sf_d4: float = eqx.field(init=False)
    sf_d5: float = eqx.field(init=False)
    sf_n0: float = eqx.field(init=False)
    sf_n1: float = eqx.field(init=False)
    sf_n2: float = eqx.field(init=False)
    sf_nb0: float = eqx.field(init=False)
    sf_nb1: float = eqx.field(init=False)
    sf_nb2: float = eqx.field(init=False)
    lim_am0: float = eqx.field(init=False)
    lim_am1: float = eqx.field(init=False)
    lim_am2: float = eqx.field(init=False)
    lim_am3: float = eqx.field(init=False)
    lim_am4: float = eqx.field(init=False)
    lim_am5: float = eqx.field(init=False)
    lim_am6: float = eqx.field(init=False)

    def __post_init__(self):
        """Compute the ``sf_*`` and ``lim_am*`` derived coefficients."""
        # intermediate combinations of the k-epsilon coefficients
        a1 = .66666666667 - .5*self.c2
        a2 = 1 - .5*self.c3
        a3 = 1 - .5*self.c4
        a5 = .5 - .5*self.c6
        nn = .5*self.c1
        nb = self.cb1
        ab1 = 1 - self.cb2
        ab2 = 1 - self.cb3
        ab3 = 2*(1 - self.cb4)
        ab5 = 2*self.cbb*(1-self.cb5)

        # stability function coefficients
        sf_d0 = 36*nn*nn*nn*nb*nb
        sf_d1 = 84*a5*ab3*nn*nn*nb + 36*ab5*nn*nn*nn*nb
        sf_d2 = 9*(ab2*ab2 - ab1*ab1)*nn*nn*nn - 12*(a2*a2 - 3*a3*a3)*nn*nb*nb
        sf_d3 = (
            12*a5*ab3*(a2*ab1 - 3*a3*ab2)*nn +
            12*a5*ab3*(a3*a3 - a2*a2)*nb + 12*ab5*(3*a3*a3 - a2*a2)*nn*nb
        )
        sf_d4 = 48*a5*a5*ab3*ab3*nn + 36*a5*ab3*ab5*nn*nn
        sf_d5 = 3*(a2*a2 - 3*a3*a3)*(ab1*ab1 - ab2*ab2)*nn
        sf_n0 = 36*a1*nn*nn*nb*nb
        sf_n1 = (
            -12*a5*ab3*(ab1 + ab2)*nn*nn +
            8*a5*ab3*(6*a1 - a2 - 3*a3)*nn*nb + 36*a1*ab5*nn*nn*nb
        )
        sf_n2 = 9*a1*(ab2*ab2 - ab1*ab1)*nn*nn
        sf_nb0 = 12*ab3*nn*nn*nn*nb
        sf_nb1 = 12*a5*ab3*ab3*nn*nn
        sf_nb2 = (
            9*a1*ab3*(ab1 - ab2)*nn*nn +
            (6*a1*(a2 - 3*a3) - 4*(a2*a2 - 3*a3*a3))*ab3*nn*nb
        )

        derived = {
            'sf_d0': sf_d0,
            'sf_d1': sf_d1,
            'sf_d2': sf_d2,
            'sf_d3': sf_d3,
            'sf_d4': sf_d4,
            'sf_d5': sf_d5,
            'sf_n0': sf_n0,
            'sf_n1': sf_n1,
            'sf_n2': sf_n2,
            'sf_nb0': sf_nb0,
            'sf_nb1': sf_nb1,
            'sf_nb2': sf_nb2,
            # limitation coefficients: products of the sf_d* and sf_n* terms
            'lim_am0': sf_d0*sf_n0,
            'lim_am1': sf_d0*sf_n1 + sf_d1*sf_n0,
            'lim_am2': sf_d1*sf_n1 + sf_d4*sf_n0,
            'lim_am3': sf_d4*sf_n1,
            'lim_am4': sf_d2*sf_n0,
            'lim_am5': sf_d2*sf_n1 + sf_d3*sf_n0,
            'lim_am6': sf_d3*sf_n1,
        }
        # object.__setattr__ is used instead of a plain assignment, as in the
        # rest of this initialisation, to write the init=False fields.
        for name, value in derived.items():
            object.__setattr__(self, name, value)
[docs]
class KepsState(ClosureStateAbstract):
    r"""
    State of the water column for the :math:`k-\varepsilon` model.

    At construction every variable is a uniform array filled with the
    corresponding minimal value taken from an instance of
    :class:`KepsParameters`.

    Parameters
    ----------
    grid : Grid
        cf. attribute.
    keps_params : KepsParameters
        Provides the initialization values of the variables.

    Attributes
    ----------
    grid : Grid
        Geometry of the water column, should be the same as the one of the
        :class:`~space.State` instance used in the model.
    akt : Float[~jax.Array, 'nz+1']
        Eddy-diffusivity on the interfaces of the cells
        :math:`\left[\text m ^2 \cdot \text s ^{-1}\right]`.
    akv : Float[~jax.Array, 'nz+1']
        Eddy-viscosity on the interfaces of the cells
        :math:`\left[\text m ^2 \cdot \text s ^{-1}\right]`.
    tke : Float[~jax.Array, 'nz+1']
        Turbulent kinetic energy (TKE) denoted :math:`k` on the interfaces of
        the cells :math:`\left[\text m ^2 \cdot \text s ^{-2}\right]`.
    eps : Float[~jax.Array, 'nz+1']
        TKE dissipation denoted :math:`\varepsilon` on the interfaces of the
        cells :math:`\left[\text m ^2 \cdot \text s ^{-3}\right]`.
    c_mu : Float[~jax.Array, 'nz+1']
        :math:`c_\mu` in GLS formalism on the interfaces of the cells
        :math:`[\text{dimensionless}]`.
    c_mu_prim : Float[~jax.Array, 'nz+1']
        :math:`c_\mu'` in GLS formalism on the interfaces of the cells
        :math:`[\text{dimensionless}]`.
    """

    grid: Grid
    akt: Float[Array, 'nz+1']
    akv: Float[Array, 'nz+1']
    tke: Float[Array, 'nz+1']
    eps: Float[Array, 'nz+1']
    c_mu: Float[Array, 'nz+1']
    c_mu_prim: Float[Array, 'nz+1']

    def __init__(self, grid: Grid, keps_params: KepsParameters):
        # one value per cell interface
        n_interfaces = grid.nz + 1

        def uniform(value):
            return jnp.full(n_interfaces, value)

        self.grid = grid
        self.akt = uniform(keps_params.akt_min)
        self.akv = uniform(keps_params.akv_min)
        self.tke = uniform(keps_params.tke_min)
        self.eps = uniform(keps_params.eps_min)
        self.c_mu = uniform(keps_params.c_mu_min)
        self.c_mu_prim = uniform(keps_params.c_mu_prim_min)
[docs]
@partial(jit, static_argnames=('case',))
def keps_step(
        state: State,
        keps_state: KepsState,
        dt: float,
        keps_params: KepsParameters,
        case: Case
    ) -> KepsState:
    r"""
    Advance the :math:`k-\varepsilon` closure by one time-step.

    The goal is to obtain the eddy-diffusivity and the eddy-viscosity at the
    next time-step. The computation is done in 3 stages

    1. The Brunt–Väisälä frequency and the shear are computed from
       :code:`state`, and the boundary conditions are computed.

    2. The equation on :math:`k` is integrated, then the equation on
       :math:`\varepsilon` is integrated using the new :math:`k`.

    3. The eddy-diffusivity and viscosity are computed as diagnostic variables
       and a copy of :code:`keps_state` holding the new values is returned.

    Parameters
    ----------
    state : State
        Current state of the water column.
    keps_state : KepsState
        Current state of the water column for the variables used by
        :math:`k-\varepsilon`.
    dt : float
        Time-step of the forward model :math:`[\text s]`.
    keps_params : KepsParameters
        Values of the parameters used by :math:`k-\varepsilon`
        (time-independent).
    case : Case
        Physical parameters and forcings of the model run.

    Returns
    -------
    keps_state : KepsState
        State of the water column for the variables used by
        :math:`k-\varepsilon` at the next time-step.

    Note
    ----
    This function is jitted with JAX with :code:`case` as a static argument,
    it should make it faster, but the :func:`~jax.jit` decorator can be
    removed.
    """
    grid = state.grid

    # prognostic computations; the same velocities are given as the current
    # and the next-step ones for the shear
    _, bvf = compute_rho_eos(state.t, state.s, grid.zr, case)
    shear2 = compute_shear(state.u, state.v, state.u, state.v, grid.zr)
    boundary_conditions = compute_tke_eps_bc(
        keps_state.tke, grid.hz, keps_params, case
    )

    def integrate(tke_np1, do_tke):
        return advance_turb(
            keps_state.akt, keps_state.akv, keps_state.tke, tke_np1,
            keps_state.eps, keps_state.c_mu, keps_state.c_mu_prim, bvf,
            shear2, grid.hz, dt, *boundary_conditions, keps_params, do_tke
        )

    # integrations: TKE first, then epsilon with the updated TKE
    tke_new = integrate(keps_state.tke, True)
    eps_new = integrate(tke_new, False)

    # diagnostic variables
    akt_new, akv_new, eps_new, c_mu_new, c_mu_prim_new = compute_diag(
        tke_new, eps_new, bvf, shear2, keps_params
    )

    return eqx.tree_at(
        lambda t: (t.akv, t.akt, t.tke, t.eps, t.c_mu, t.c_mu_prim),
        keps_state,
        (akv_new, akt_new, tke_new, eps_new, c_mu_new, c_mu_prim_new),
    )
[docs]
def compute_rho_eos(
        t: Float[Array, 'nz'],
        s: Float[Array, 'nz'],
        zr: Float[Array, 'nz'],
        case: Case
    ) -> Tuple[Float[Array, 'nz'], Float[Array, 'nz+1']]:
    r"""
    Compute density anomaly and Brunt–Väisälä frequency.

    Prognostic computation via linear Equation Of State (EOS) :

    :math:`\rho = \rho_0(1-\alpha (T-T_0) + \beta (S-S_0))`

    :math:`N^2 = - \dfrac g {\rho_0} \partial_z \rho`

    Parameters
    ----------
    t : Float[~jax.Array, 'nz']
        Temperature on the center of the cells :math:`[° \text C]`.
    s : Float[~jax.Array, 'nz']
        Salinity on the center of the cells :math:`[\text{psu}]`.
    zr : Float[~jax.Array, 'nz']
        Depths of cell centers from deepest to shallowest :math:`[\text m]`
    case : Case
        Physical parameters and forcings of the model run. The attributes
        :code:`rho0`, :code:`alpha`, :code:`beta`, :code:`t_rho_ref`,
        :code:`s_rho_ref` and :code:`grav` are read.

    Returns
    -------
    rho : Float[~jax.Array, 'nz']
        Density anomaly :math:`\rho` on the center of the cells
        :math:`\left[\text {kg} \cdot \text m^{-3}\right]`
    bvf : Float[~jax.Array, 'nz+1']
        Brunt–Väisälä frequency squared :math:`N^2` on cell interfaces
        :math:`\left[\text s^{-2}\right]`.

    Note
    ----
    The values are returned in the order :code:`(rho, bvf)`.
    """
    rho0 = case.rho0
    rho = rho0 * (1. - case.alpha*(t-case.t_rho_ref) + \
        case.beta*(s-case.s_rho_ref))
    # finite difference between consecutive cell centers: nz-1 interior values
    cff = 1./(zr[1:]-zr[:-1])
    bvf_in = - cff*case.grav/rho0 * (rho[1:]-rho[:-1])
    # the interior values are completed with 0 and with a copy of the last
    # interior value (presumably first = bottom and last = surface, following
    # the deepest-to-shallowest ordering -- confirm against add_boundaries)
    bvf = add_boundaries(0., bvf_in, bvf_in[-1])
    return rho, bvf
[docs]
def compute_shear(
        u_n: Float[Array, 'nz'],
        v_n: Float[Array, 'nz'],
        u_np1: Float[Array, 'nz'],
        v_np1: Float[Array, 'nz'],
        zr: Float[Array, 'nz']
    ) -> Float[Array, 'nz+1']:
    r"""
    Compute the squared shear used in the shear production terms.

    The discretisation is

    :math:`S_h^2 = \partial_z U^{n+1} \cdot \partial_z U^{n+1/2}`

    where :math:`U^{n+1/2}` is the mean between :math:`U^n` and
    :math:`U^{n+1}`, and where the contributions of the zonal and meridional
    velocities are added.

    Parameters
    ----------
    u_n : Float[~jax.Array, 'nz']
        Current zonal velocity on the center of the cells
        :math:`\left[\text m \cdot \text s^{-1}\right]`.
    v_n : Float[~jax.Array, 'nz']
        Current meridional velocity on the center of the cells
        :math:`\left[\text m \cdot \text s^{-1}\right]`.
    u_np1 : Float[~jax.Array, 'nz']
        Zonal velocity on the center of the cells at the next time step
        :math:`\left[\text m \cdot \text s^{-1}\right]`.
    v_np1 : Float[~jax.Array, 'nz']
        Meridional velocity on the center of the cells at the next time step
        :math:`\left[\text m \cdot \text s^{-1}\right]`.
    zr : Float[~jax.Array, 'nz']
        Depths of cell centers from deepest to shallowest :math:`[\text m]`

    Returns
    -------
    shear2 : Float[~jax.Array, 'nz+1']
        Squared shear :math:`S_h^2` on cell interfaces, completed with 0 at
        both ends :math:`\left[\text s ^{-2}\right]`.
    """
    inv_dz2 = 1.0 / (zr[1:] - zr[:-1])**2

    def component(w_n, w_np1):
        # gradient at n+1 times the gradient of the sum (n) + (n+1), the
        # factor 0.5 turns that sum into the mean
        return 0.5*inv_dz2 * (w_np1[1:]-w_np1[:-1]) * \
            (w_n[1:]+w_np1[1:]-w_n[:-1]-w_np1[:-1])

    shear2_in = component(u_n, u_np1) + component(v_n, v_np1)
    return add_boundaries(0., shear2_in, 0.)
[docs]
def compute_tke_eps_bc(
        tke: Float[Array, 'nz+1'],
        hz: Float[Array, 'nz'],
        keps_params: KepsParameters,
        case: Case
    ) -> Tuple[float, float, float, float]:
    r"""
    Compute top and bottom boundary conditions for TKE and :math:`\varepsilon`.

    Parameters
    ----------
    tke : Float[~jax.Array, 'nz+1']
        Turbulent kinetic energy (TKE) denoted :math:`k` on the interfaces of
        the cells :math:`\left[\text m ^2 \cdot \text s ^{-2}\right]`.
    hz : Float[~jax.Array, 'nz']
        Thickness of cells from deepest to shallowest :math:`[\text m]`.
    keps_params : KepsParameters
        Values of the parameters used by :math:`k-\varepsilon`.
    case : Case
        Physical parameters and forcings of the model run.

    Returns
    -------
    tke_sfc_bc : float
        TKE value for surface boundary condition (Dirichlet
        :math:`\left[\text m ^2 \cdot \text s ^{-2}\right]` or Neumann
        :math:`\left[\text m ^3 \cdot \text s ^{-3}\right]`).
    tke_btm_bc : float
        TKE value for bottom boundary condition (Dirichlet
        :math:`\left[\text m ^2 \cdot \text s ^{-2}\right]` or Neumann
        :math:`\left[\text m ^3 \cdot \text s ^{-3}\right]`).
    eps_sfc_bc : float
        :math:`\varepsilon` value for surface boundary condition (Dirichlet
        :math:`\left[\text m ^2 \cdot \text s ^{-3}\right]` or Neumann
        :math:`\left[\text m ^3 \cdot \text s ^{-4}\right]`).
    eps_btm_bc : float
        :math:`\varepsilon` value for bottom boundary condition (Dirichlet
        :math:`\left[\text m ^2 \cdot \text s ^{-3}\right]` or Neumann
        :math:`\left[\text m ^3 \cdot \text s ^{-4}\right]`).

    Note
    ----
    The kind of boundary condition, Neumann or Dirichlet, is selected by the
    attributes :code:`dir_sfc` and :code:`dir_btm` of :code:`keps_params`. The
    Neumann condition on TKE is a zero flux.
    """
    # constants
    rp, rm, rn = keps_params.gls_p, keps_params.gls_m, keps_params.gls_n
    c_mu0 = keps_params.c_mu0
    cm0inv2 = 1./c_mu0**2
    vkarmn = case.vkarmn
    sig_eps = keps_params.sig_eps
    is_dir_sfc = keps_params.dir_sfc
    is_dir_btm = keps_params.dir_btm

    def eps_conditions(thickness, z0, tke_mean):
        # Dirichlet and Neumann values of epsilon for one boundary, from the
        # thickness of the boundary cell, the roughness length and the mean
        # TKE of the two interfaces of that cell
        lgthsc = vkarmn*(0.5*thickness + z0)
        eps_dir = jnp.maximum(keps_params.eps_min, \
            c_mu0**rp * lgthsc**rn * tke_mean**rm)
        eps_neu = -rn*vkarmn/sig_eps * c_mu0**(rp+1) * \
            tke_mean**(rm+.5) * lgthsc**rn
        return eps_dir, eps_neu

    # velocity scales
    ustar2_sfc = jnp.sqrt(case.ustr_sfc**2 + case.vstr_sfc**2)
    ustar2_bot = jnp.sqrt(case.ustr_btm**2 + case.vstr_btm**2)

    # TKE: Dirichlet value floored by tke_min, or zero Neumann flux
    tke_sfc_bc = jnp.where(
        is_dir_sfc, jnp.maximum(keps_params.tke_min, cm0inv2*ustar2_sfc), 0.0
    )
    tke_btm_bc = jnp.where(
        is_dir_btm, jnp.maximum(keps_params.tke_min, cm0inv2*ustar2_bot), 0.0
    )

    # epsilon at the surface, the roughness length depends on the wind stress
    charnock = keps_params.chk_grav/case.grav
    z0_sfc = jnp.maximum(keps_params.z0s_min, charnock*ustar2_sfc)
    eps_sfc_dir, eps_sfc_neu = eps_conditions(
        hz[-1], z0_sfc, 0.5*(tke[-1]+tke[-2])
    )

    # epsilon at the bottom
    z0_btm = jnp.maximum(keps_params.z0b, keps_params.z0b_min)
    eps_btm_dir, eps_btm_neu = eps_conditions(
        hz[0], z0_btm, 0.5*(tke[0]+tke[1])
    )

    eps_sfc_bc = jnp.where(is_dir_sfc, eps_sfc_dir, eps_sfc_neu)
    eps_btm_bc = jnp.where(is_dir_btm, eps_btm_dir, eps_btm_neu)
    return tke_sfc_bc, tke_btm_bc, eps_sfc_bc, eps_btm_bc
[docs]
def advance_turb(
        akt: Float[Array, 'nz+1'],
        akv: Float[Array, 'nz+1'],
        tke: Float[Array, 'nz+1'],
        tke_np1: Float[Array, 'nz+1'],
        eps: Float[Array, 'nz+1'],
        c_mu: Float[Array, 'nz+1'],
        c_mu_prim: Float[Array, 'nz+1'],
        bvf: Float[Array, 'nz+1'],
        shear2: Float[Array, 'nz+1'],
        hz: Float[Array, 'nz'],
        dt: float,
        tke_sfc_bc: float,
        tke_btm_bc: float,
        eps_sfc_bc: float,
        eps_btm_bc: float,
        keps_params: KepsParameters,
        do_tke: bool
    ) -> Float[Array, 'nz+1']:
    r"""
    Integrate TKE or :math:`\varepsilon` quantities.

    First the shear and buoyancy production are computed, then they are used in
    the building of the tridiagonal problem, the boundary conditions are then
    added and finally the tridiagonal problem is solved. The result is floored
    by :code:`tke_min` or :code:`eps_min`.

    Parameters
    ----------
    akt : Float[~jax.Array, 'nz+1']
        Current eddy-diffusivity on the interfaces of the cells
        :math:`\left[\text m ^2 \cdot \text s ^{-1}\right]`.
    akv : Float[~jax.Array, 'nz+1']
        Current eddy-viscosity on the interfaces of the cells
        :math:`\left[\text m ^2 \cdot \text s ^{-1}\right]`.
    tke : Float[~jax.Array, 'nz+1']
        Current turbulent kinetic energy (TKE) denoted :math:`k` on the
        interfaces of the cells
        :math:`\left[\text m ^2 \cdot \text s ^{-2}\right]`.
    tke_np1 : Float[~jax.Array, 'nz+1']
        Turbulent kinetic energy (TKE) denoted :math:`k` on the interfaces of
        the cells at next step (only used for the :math:`\varepsilon`
        integration) :math:`\left[\text m ^2 \cdot \text s ^{-2}\right]`.
    eps : Float[~jax.Array, 'nz+1']
        Current TKE dissipation denoted :math:`\varepsilon` on the interfaces
        of the cells :math:`\left[\text m ^2 \cdot \text s ^{-3}\right]`.
    c_mu : Float[~jax.Array, 'nz+1']
        Current :math:`c_\mu` in GLS formalism on the interfaces of the cells
        :math:`[\text{dimensionless}]`.
    c_mu_prim : Float[~jax.Array, 'nz+1']
        Current :math:`c_\mu'` in GLS formalism on the interfaces of the cells
        :math:`[\text{dimensionless}]`.
    bvf : Float[~jax.Array, 'nz+1']
        Current Brunt–Väisälä frequency squared :math:`N^2` on cell interfaces
        :math:`\left[\text s^{-2}\right]`.
    shear2 : Float[~jax.Array, 'nz+1']
        Current squared shear :math:`S_h^2` on cell interfaces.
    hz : Float[~jax.Array, 'nz']
        Thickness of cells from deepest to shallowest :math:`[\text m]`.
    dt : float
        Time-step of the forward model :math:`[\text s]`.
    tke_sfc_bc : float
        TKE value for surface boundary condition (Dirichlet
        :math:`\left[\text m ^2 \cdot \text s ^{-2}\right]` or Neumann
        :math:`\left[\text m ^3 \cdot \text s ^{-3}\right]`).
    tke_btm_bc : float
        TKE value for bottom boundary condition (Dirichlet
        :math:`\left[\text m ^2 \cdot \text s ^{-2}\right]` or Neumann
        :math:`\left[\text m ^3 \cdot \text s ^{-3}\right]`).
    eps_sfc_bc : float
        :math:`\varepsilon` value for surface boundary condition (Dirichlet
        :math:`\left[\text m ^2 \cdot \text s ^{-3}\right]` or Neumann
        :math:`\left[\text m ^3 \cdot \text s ^{-4}\right]`).
    eps_btm_bc : float
        :math:`\varepsilon` value for bottom boundary condition (Dirichlet
        :math:`\left[\text m ^2 \cdot \text s ^{-3}\right]` or Neumann
        :math:`\left[\text m ^3 \cdot \text s ^{-4}\right]`).
    keps_params : KepsParameters
        Values of the parameters used by :math:`k-\varepsilon`.
    do_tke : bool
        If :code:`True` solve the equation for TKE, else for
        :math:`\varepsilon`.

    Returns
    -------
    vec : Float[~jax.Array, 'nz+1']
        TKE or :math:`\varepsilon` at next step (depending on :code:`do_tke`).
    """
    # fill the matrix off-diagonal terms for the tridiagonal problem
    cff = -0.5*dt
    ak_vec = jnp.where(do_tke, akv/keps_params.sig_k, akv/keps_params.sig_eps)
    a_in = cff*(ak_vec[1:-1]+ak_vec[:-2]) / hz[:-1]
    c_in = cff*(ak_vec[1:-1]+ak_vec[2:]) / hz[1:]

    # shear and buoyancy production
    s_prod_tke = akv[1:-1]*shear2[1:-1]
    b_prod_tke = -akt[1:-1]*bvf[1:-1]
    s_prod_eps = keps_params.c_eps1*c_mu[1:-1]*tke[1:-1]*shear2[1:-1]
    b_prod_eps = -c_mu_prim[1:-1] * tke[1:-1] * \
        (keps_params.c_eps3m*jnp.maximum(bvf[1:-1], 0) + \
        keps_params.c_eps3p*jnp.minimum(bvf[1:-1], 0))
    s_prod = jnp.where(do_tke, s_prod_tke, s_prod_eps)
    b_prod = jnp.where(do_tke, b_prod_tke, b_prod_eps)

    # diagonal and f term: when the total production is not positive, the
    # buoyancy production is moved from the right-hand side to the diagonal
    cff = 0.5*(hz[:-1] + hz[1:])
    f_tke_in = lax.select(b_prod+s_prod > 0,
                          cff*(tke[1:-1]+dt*(b_prod+s_prod)),
                          cff*(tke[1:-1]+dt*s_prod))
    f_eps_in = lax.select(b_prod+s_prod > 0,
                          cff*(eps[1:-1]+dt*(b_prod+s_prod)),
                          cff*(eps[1:-1]+dt*s_prod))
    f_in = jnp.where(do_tke, f_tke_in, f_eps_in)
    b_tke_in = lax.select((b_prod + s_prod) > 0,
        cff*(1. + dt*eps[1:-1]/tke[1:-1]) - a_in - c_in,
        cff*(1. + dt*(eps[1:-1] - b_prod)/tke[1:-1]) - a_in - c_in)
    b_eps_in = lax.select((b_prod + s_prod) > 0,
        cff*(1. + dt*keps_params.c_eps2*eps[1:-1]/tke_np1[1:-1]) - a_in - c_in,
        cff*(1. + dt*keps_params.c_eps2*eps[1:-1]/tke_np1[1:-1] - \
            dt*b_prod/eps[1:-1]) - a_in - c_in)
    b_in = jnp.where(do_tke, b_tke_in, b_eps_in)

    # surface boundary condition
    dir_sfc = keps_params.dir_sfc
    a_sfc = jnp.where(dir_sfc, 0., -0.5*(ak_vec[-1] + ak_vec[-2]))
    b_sfc = jnp.where(dir_sfc, 1., 0.5*(ak_vec[-1] + ak_vec[-2]))
    sfc_bc = jnp.where(do_tke, tke_sfc_bc, eps_sfc_bc)
    f_sfc = jnp.where(dir_sfc, sfc_bc, hz[-1]*sfc_bc)

    # bottom boundary condition: the matrix coefficients must be selected with
    # the bottom switch, like the right-hand side, otherwise a Neumann row is
    # combined with a Dirichlet value when dir_sfc and dir_btm differ
    dir_btm = keps_params.dir_btm
    b_btm = jnp.where(dir_btm, 1., -0.5*(ak_vec[0] + ak_vec[1]))
    c_btm = jnp.where(dir_btm, 0., 0.5*(ak_vec[0] + ak_vec[1]))
    btm_bc = jnp.where(do_tke, tke_btm_bc, eps_btm_bc)
    f_btm = jnp.where(dir_btm, btm_bc, hz[0]*btm_bc)

    # vectors assembly: each interior vector is completed with its bottom
    # value first and its surface value last
    a = add_boundaries(0., a_in, a_sfc)
    b = add_boundaries(b_btm, b_in, b_sfc)
    c = add_boundaries(c_btm, c_in, 0.)
    f = add_boundaries(f_btm, f_in, f_sfc)

    # solve tridiagonal problem and apply the floor of the solved quantity
    f = tridiag_solve(a, b, c, f)
    vec_min = jnp.where(do_tke, keps_params.tke_min, keps_params.eps_min)
    vec = jnp.maximum(f, vec_min)
    return vec
[docs]
def compute_diag(
        tke: Float[Array, 'nz+1'],
        eps: Float[Array, 'nz+1'],
        bvf: Float[Array, 'nz+1'],
        shear2: Float[Array, 'nz+1'],
        keps_params: KepsParameters
    ) -> Tuple[Float[Array, 'nz+1'], Float[Array, 'nz+1'],
               Float[Array, 'nz+1'], Float[Array, 'nz+1'],
               Float[Array, 'nz+1']]:
    r"""
    Computes the diagnostic variables of :math:`k-\varepsilon` closure.

    This function first applies the Galperin limitation, then it computes
    :math:`c_\mu'` and :math:`c_\mu` with the stability functions, and finally
    it computes the eddy-diffusivity and viscosity with these variables.

    Only the interior interfaces are computed this way. On the bottom and
    surface interfaces, :math:`c_\mu` and :math:`c_\mu'` are set to their
    minimum values and the eddy-diffusivity and viscosity are linearly
    extrapolated from the two nearest interior values (then floored by their
    minimum values). The arrays are therefore expected to hold at least two
    interior interfaces.

    Parameters
    ----------
    tke : Float[~jax.Array, 'nz+1']
        Turbulent kinetic energy (TKE) denoted :math:`k` on the interfaces of
        the cells at next step
        :math:`\left[\text m ^2 \cdot \text s ^{-2}\right]`.
    eps : Float[~jax.Array, 'nz+1']
        TKE dissipation denoted :math:`\varepsilon` on the interfaces of the
        cells at next step :math:`\left[\text m ^2 \cdot \text s ^{-3}\right]`.
    bvf : Float[~jax.Array, 'nz+1']
        Current Brunt–Väisälä frequency squared :math:`N^2` on cell interfaces
        :math:`\left[\text s^{-2}\right]`.
    shear2 : Float[~jax.Array, 'nz+1']
        Current shear squared :math:`S_h^2` on cell interfaces
        :math:`\left[\text s^{-2}\right]`.
    keps_params : KepsParameters
        Values of the parameters used by :math:`k-\varepsilon`.

    Returns
    -------
    akt : Float[Array, 'nz+1']
        Eddy-diffusivity on the interfaces of the cells at next step
        :math:`\left[\text m ^2 \cdot \text s ^{-1}\right]`.
    akv : Float[Array, 'nz+1']
        Eddy-viscosity on the interfaces of the cells at next step
        :math:`\left[\text m ^2 \cdot \text s ^{-1}\right]`.
    eps : Float[Array, 'nz+1']
        TKE dissipation denoted :math:`\varepsilon` on the interfaces
        of the cells at next step, after the Galperin limitation has been
        applied on the interior interfaces
        :math:`\left[\text m ^2 \cdot \text s ^{-3}\right]`.
    c_mu : Float[Array, 'nz+1']
        :math:`c_\mu` in GLS formalism on the interfaces of the cells at next
        step :math:`[\text{dimensionless}]`.
    c_mu_prim : Float[Array, 'nz+1']
        :math:`c_\mu'` in GLS formalism on the interfaces of the cells at next
        step :math:`[\text{dimensionless}]`.
    """
    # parameters
    akv_min = keps_params.akv_min
    akt_min = keps_params.akt_min
    eps_min = keps_params.eps_min
    # coefficients of the denominator of the stability functions
    sf_d0 = keps_params.sf_d0
    sf_d1 = keps_params.sf_d1
    sf_d2 = keps_params.sf_d2
    sf_d3 = keps_params.sf_d3
    sf_d4 = keps_params.sf_d4
    sf_d5 = keps_params.sf_d5
    # coefficients of the numerator of c_mu
    sf_n0 = keps_params.sf_n0
    sf_n1 = keps_params.sf_n1
    sf_n2 = keps_params.sf_n2
    # coefficients of the numerator of c_mu'
    sf_nb0 = keps_params.sf_nb0
    sf_nb1 = keps_params.sf_nb1
    sf_nb2 = keps_params.sf_nb2
    # coefficients of the rational function bounding alpha_m
    lim_am0 = keps_params.lim_am0
    lim_am1 = keps_params.lim_am1
    lim_am2 = keps_params.lim_am2
    lim_am3 = keps_params.lim_am3
    lim_am4 = keps_params.lim_am4
    lim_am5 = keps_params.lim_am5
    lim_am6 = keps_params.lim_am6
    c_mu0 = keps_params.c_mu0
    rp, rm, rn = keps_params.gls_p, keps_params.gls_m, keps_params.gls_n
    # exponents used to get the dissipation from tke and the GLS variable
    e1 = 3 + rp/rn
    e2 = 1.5 + rm/rn
    e3 = -1/rn

    # interior views of the inputs that are never modified below
    tke_in = tke[1:-1]
    bvf_in = bvf[1:-1]

    # minimum value of alpha_n to ensure that alpha_m is positive
    alpha_n_min = 0.5*(- (sf_d1 + sf_nb0) + jnp.sqrt((sf_d1 + sf_nb0)**2 - \
        4.0*sf_d0*(sf_d4 + sf_nb1))) / (sf_d4 + sf_nb1)

    # Galperin limitation : l <= l_lim (bvf is floored to avoid dividing by
    # zero or taking the square root of a negative number)
    l_lim = keps_params.galp*jnp.sqrt(2.0*tke_in / \
        jnp.maximum(1e-14, bvf_in))
    # limitation (use MAX because rn is negative)
    eps_floor = c_mu0**rp * l_lim**rn * tke_in**rm
    eps = eps.at[1:-1].set(jnp.maximum(eps[1:-1], eps_floor))

    # dissipation deduced from the limited eps ; tke and eps are not modified
    # after this point so this value is used for all the remaining steps
    epsilon = c_mu0**e1 * tke_in**e2 * eps[1:-1]**e3
    epsilon = jnp.maximum(epsilon, eps_min)

    # compute alpha_n and alpha_m
    time_scale2 = (tke_in / epsilon)**2
    alpha_m = time_scale2*shear2[1:-1]
    alpha_n = time_scale2*bvf_in
    # limitation of alpha_n and alpha_m
    alpha_n = jnp.minimum(jnp.maximum(0.73*alpha_n_min, alpha_n), 1e10)
    alpha_m_max = (lim_am0 + lim_am1*alpha_n + lim_am2*alpha_n**2 + \
        lim_am3*alpha_n**3) / (lim_am4 + lim_am5*alpha_n + \
        lim_am6*alpha_n**2)
    alpha_m = jnp.minimum(alpha_m, alpha_m_max)

    # compute stability functions (both share the same denominator)
    denom = sf_d0 + sf_d1*alpha_n + sf_d2*alpha_m + sf_d3*alpha_n*alpha_m \
        + sf_d4*alpha_n**2 + sf_d5*alpha_m**2
    inv_denom = 1./denom
    c_mu_in = inv_denom*(sf_n0 + sf_n1*alpha_n + sf_n2*alpha_m)
    c_mu = add_boundaries(keps_params.c_mu_min, c_mu_in, keps_params.c_mu_min)
    c_mu_prim_in = inv_denom*(sf_nb0 + sf_nb1*alpha_n + sf_nb2*alpha_m)
    c_mu_prim = add_boundaries(
        keps_params.c_mu_prim_min, c_mu_prim_in, keps_params.c_mu_prim_min
    )

    # finalize the computation of akv and akt on the interior interfaces
    tke2_over_eps = tke_in**2 / epsilon
    akt_in = jnp.maximum(tke2_over_eps*c_mu_prim[1:-1], akt_min)
    akv_in = jnp.maximum(tke2_over_eps*c_mu[1:-1], akv_min)
    # boundary values : linear extrapolation of the two nearest interior
    # values, floored by the minimum value
    akt_btm = jnp.maximum(1.5*akt_in[0] - 0.5*akt_in[1], akt_min)
    akt_sfc = jnp.maximum(1.5*akt_in[-1] - 0.5*akt_in[-2], akt_min)
    akv_btm = jnp.maximum(1.5*akv_in[0] - 0.5*akv_in[1], akv_min)
    akv_sfc = jnp.maximum(1.5*akv_in[-1] - 0.5*akv_in[-2], akv_min)
    akt = add_boundaries(akt_btm, akt_in, akt_sfc)
    akv = add_boundaries(akv_btm, akv_in, akv_sfc)

    return akt, akv, eps, c_mu, c_mu_prim