Quickstart: a two-level nested copula
This notebook builds a nested Archimedean copula, evaluates its
log-likelihood, and takes an exact gradient with jax.grad — the whole
pipeline in a few cells.
The model is a Frank copula over two Frank "sectors", each with two uniform leaves (4 dimensions total):
          Frank            (outer)
         /     \
     Frank     Frank       (sectors)
     /   \     /   \
   u00   u01 u10   u11     (uniform leaves)
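For reference, the tree corresponds to the CDF below. We write $\varphi_0$ for the outer generator (parameter $\theta_{\text{outer}}$) and $\varphi_1$ for the sector generator (parameter $\theta_{\text{inner}}$, shared by both sectors in the model of step 2). We assume the convention $C(u_1, \dots, u_d) = \varphi\big(\varphi^{-1}(u_1) + \dots + \varphi^{-1}(u_d)\big)$ with $\varphi(0) = 1$, which matches the Frank generator formula in step 1.

$$
C(u_{00}, u_{01}, u_{10}, u_{11})
= \varphi_0\Big(\varphi_0^{-1}\big(C_1(u_{00}, u_{01})\big) + \varphi_0^{-1}\big(C_1(u_{10}, u_{11})\big)\Big),
\qquad
C_1(u, v) = \varphi_1\big(\varphi_1^{-1}(u) + \varphi_1^{-1}(v)\big).
$$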
1. Define a copula family
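A family subclasses Copula via the @copula decorator and only needs to
implement generator: the Archimedean generator $\varphi$, in the convention
where $\varphi(0) = 1$ and $C(u_1, \dots, u_d) = \varphi\big(\varphi^{-1}(u_1) + \dots + \varphi^{-1}(u_d)\big)$.
For Frank, $\varphi(t) = -\frac{1}{\theta}\log\big(1 - (1 - e^{-\theta})e^{-t}\big)$.
Everything else (density, inverse, sampling) is derived automatically.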
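import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
from acopula import compile_model, copula, marginal

@copula
class Frank:
    theta: float

    def generator(self, t):
        # phi(t) = -log(1 - (1 - e^{-theta}) e^{-t}) / theta, written with
        # log1p/expm1 for numerical stability; phi(0) = 1, phi(t) -> 0 as t -> inf
        return -jnp.log1p(jnp.expm1(-self.theta) * jnp.exp(-t)) / self.theta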
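As a sanity check on the generator formula, this sketch uses plain JAX (no acopula) to rebuild the bivariate Frank copula from the same expression as `Frank.generator` and compare it with the textbook closed form $C(u,v) = -\frac{1}{\theta}\log\big(1 + \frac{(e^{-\theta u}-1)(e^{-\theta v}-1)}{e^{-\theta}-1}\big)$. The helpers `frank_phi`, `frank_phi_inv` and `frank_cdf_closed_form` are hypothetical names, and the inverse is derived by hand here, which is not necessarily how acopula obtains it.

```python
import jax.numpy as jnp

def frank_phi(theta, t):
    # Same expression as Frank.generator: maps [0, inf) onto (0, 1]
    return -jnp.log1p(jnp.expm1(-theta) * jnp.exp(-t)) / theta

def frank_phi_inv(theta, u):
    # Hand-derived inverse: solves frank_phi(theta, t) = u for t
    return -jnp.log(jnp.expm1(-theta * u) / jnp.expm1(-theta))

def frank_cdf_closed_form(theta, u, v):
    # Textbook bivariate Frank copula
    num = jnp.expm1(-theta * u) * jnp.expm1(-theta * v)
    return -jnp.log1p(num / jnp.expm1(-theta)) / theta

theta = 5.0
u = jnp.array([0.3, 0.4, 0.9])
v = jnp.array([0.7, 0.8, 0.2])

# Archimedean construction: C(u, v) = phi(phi^{-1}(u) + phi^{-1}(v))
c_generator = frank_phi(theta, frank_phi_inv(theta, u) + frank_phi_inv(theta, v))
c_closed = frank_cdf_closed_form(theta, u, v)

print(frank_phi(theta, 0.0))                     # should be (close to) 1
print(jnp.max(jnp.abs(c_generator - c_closed)))  # should be near 0 (float32 round-off)
```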
2. Describe the model
A model is a plain function (params, obs) -> Node. Calling a copula
instance on its children builds the tree; marginal attaches a distribution
(here a Uniform, since the data is already on the copula scale) and the
observation index it reads. Because the same inner instance is called for
both sectors, the two sectors share a single parameter, params[1].
def model(params, obs):
    outer = Frank(params[0])  # theta_outer
    inner = Frank(params[1])  # theta_inner
    return outer(
        # sector i couples the two leaves obs[i, 0] and obs[i, 1]
        inner(marginal(tfp.distributions.Uniform(0.0, 1.0), obs=obs[i, j])
              for j in range(2))
        for i in range(2)
    )
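To make concrete what the tree built by `model` represents, here is a plain-JAX sketch of the same four-leaf CDF, evaluated directly rather than through acopula. `nested_frank_cdf` and its `(2, 2)` input layout (row `i` holds the leaves of sector `i`, matching `obs[i, j]`) are hypothetical choices for illustration. The sketch also checks the standard sufficient nesting condition for a Frank copula nested in a Frank copula, $\theta_{\text{outer}} \le \theta_{\text{inner}}$, which guarantees the nested function is a valid copula.

```python
import jax.numpy as jnp

def frank_phi(theta, t):
    # Frank generator (same expression as Frank.generator)
    return -jnp.log1p(jnp.expm1(-theta) * jnp.exp(-t)) / theta

def frank_phi_inv(theta, u):
    # Inverse generator; frank_phi_inv(theta, 1.0) == 0
    return -jnp.log(jnp.expm1(-theta * u) / jnp.expm1(-theta))

def nested_frank_cdf(params, u):
    # params = [theta_outer, theta_inner]; u has shape (2, 2), one row per sector
    theta_outer, theta_inner = params[0], params[1]
    # Both sectors use theta_inner: C_1(u_i0, u_i1) for i = 0, 1
    sectors = frank_phi(theta_inner, frank_phi_inv(theta_inner, u).sum(axis=1))
    # The outer Frank copula couples the two sector values
    return frank_phi(theta_outer, frank_phi_inv(theta_outer, sectors).sum())

params = jnp.array([2.0, 5.0])
# Sufficient nesting condition for Frank nested in Frank
assert params[0] <= params[1]

obs = jnp.array([[0.3, 0.7],
                 [0.4, 0.8]])
print(nested_frank_cdf(params, obs))
```

Setting a whole sector to 1 collapses the tree to the other sector's bivariate Frank copula, which is a quick way to test a nested implementation.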
3. Compile it
compile_model traces the function once into a copula tree, builds a
parameter flattener, and returns a jit-compiled, grad/vmap-friendly
evaluator. The template only conveys the parameter structure — its values
are placeholders.
cm = compile_model(model, template=jnp.array([1.0, 1.0]))
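How compile_model's flattener works internally is not shown here. As a rough analogy only, JAX ships a similar utility, `jax.flatten_util.ravel_pytree`, which turns a template pytree into a flat vector plus an `unravel` function; `unravel` depends only on the template's structure, shapes and dtypes, not on its values. The dictionary layout below is a hypothetical example, not acopula's internal representation.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical structured parameters; the values are placeholders
template = {"outer": jnp.array(1.0), "inner": jnp.array(1.0)}
flat, unravel = ravel_pytree(template)
print(flat.shape)  # (2,)

# Any flat vector of the right length maps back onto the same structure.
# Dict keys are flattened in sorted order, so "inner" comes first.
params_tree = unravel(jnp.array([5.0, 2.0]))
print(params_tree["inner"], params_tree["outer"])
```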
4. Evaluate the log-likelihood
obs = jnp.array([[0.3, 0.7],
                 [0.4, 0.8]])
params = jnp.array([2.0, 5.0]) # theta_outer, theta_inner
ll = cm.eval(obs, params)
print(f"log-likelihood: {ll:.6f}")
log-likelihood: -0.958893
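One way to see where such a number comes from: with uniform marginals, the log-likelihood of one observation is the log of the copula density, i.e. the fourth mixed partial derivative of the nested CDF with respect to the four leaves. The plain-JAX sketch below obtains that density by nesting `jax.grad` over the leaves. The helper names are hypothetical, and whether `cm.eval` computes the density this way (or matches this sketch exactly) is an assumption, not something stated above. Because the function is pure JAX, `jax.vmap` maps it over a batch of observations.

```python
import jax
import jax.numpy as jnp

def frank_phi(theta, t):
    return -jnp.log1p(jnp.expm1(-theta) * jnp.exp(-t)) / theta

def frank_phi_inv(theta, u):
    return -jnp.log(jnp.expm1(-theta * u) / jnp.expm1(-theta))

def nested_frank_cdf(params, u):
    # Outer Frank(params[0]) over two Frank(params[1]) sectors; u has shape (2, 2)
    sectors = frank_phi(params[1], frank_phi_inv(params[1], u).sum(axis=1))
    return frank_phi(params[0], frank_phi_inv(params[0], sectors).sum())

def log_density(params, u):
    # Copula density = d^4 C / (du00 du01 du10 du11), via four nested jax.grad calls
    def cdf_of_leaves(a, b, c, d):
        return nested_frank_cdf(params, jnp.array([[a, b], [c, d]]))
    density = jax.grad(jax.grad(jax.grad(jax.grad(cdf_of_leaves, 0), 1), 2), 3)
    return jnp.log(density(u[0, 0], u[0, 1], u[1, 0], u[1, 1]))

params = jnp.array([2.0, 5.0])
obs = jnp.array([[0.3, 0.7],
                 [0.4, 0.8]])
print(log_density(params, obs))

# Batched evaluation: vmap over a leading axis of observations, sharing params
batch = jnp.stack([obs, jnp.array([[0.1, 0.2], [0.6, 0.5]])])
print(jax.vmap(log_density, in_axes=(None, 0))(params, batch))
```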
5. Differentiate it
The log-likelihood is an ordinary JAX function, so its exact
(automatic-differentiation) gradient with respect to the parameters is one
jax.grad call away. This is what makes gradient-based maximum-likelihood
estimation, covered in the next notebook, straightforward.
grad = jax.grad(cm.eval, argnums=1)(obs, params)
print(f"d(ll)/d(theta_outer): {grad[0]:.6f}")
print(f"d(ll)/d(theta_inner): {grad[1]:.6f}")
d(ll)/d(theta_outer): 0.109353
d(ll)/d(theta_inner): -0.356627
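A cheap way to cross-check an autodiff gradient is to compare it with central finite differences. The sketch below does this for a plain-JAX stand-in log-density (redefined so the block runs on its own). The helper names are hypothetical, and this checks the stand-in rather than `cm.eval` itself, though the same pattern applies to any scalar JAX function of `params`.

```python
import jax
import jax.numpy as jnp

def frank_phi(theta, t):
    return -jnp.log1p(jnp.expm1(-theta) * jnp.exp(-t)) / theta

def frank_phi_inv(theta, u):
    return -jnp.log(jnp.expm1(-theta * u) / jnp.expm1(-theta))

def nested_frank_cdf(params, u):
    # Outer Frank(params[0]) over two Frank(params[1]) sectors; u has shape (2, 2)
    sectors = frank_phi(params[1], frank_phi_inv(params[1], u).sum(axis=1))
    return frank_phi(params[0], frank_phi_inv(params[0], sectors).sum())

def log_density(params, u):
    # Log of the fourth mixed partial derivative of the CDF over the four leaves
    def cdf_of_leaves(a, b, c, d):
        return nested_frank_cdf(params, jnp.array([[a, b], [c, d]]))
    density = jax.grad(jax.grad(jax.grad(jax.grad(cdf_of_leaves, 0), 1), 2), 3)
    return jnp.log(density(u[0, 0], u[0, 1], u[1, 0], u[1, 1]))

params = jnp.array([2.0, 5.0])
obs = jnp.array([[0.3, 0.7],
                 [0.4, 0.8]])

# Autodiff gradient with respect to params (argnums=0 by default)
g_auto = jax.grad(log_density)(params, obs)

# Central differences: (f(p + h e_k) - f(p - h e_k)) / (2h) for each parameter k
h = 1e-2
g_fd = jnp.array([
    (log_density(params + h * e, obs) - log_density(params - h * e, obs)) / (2 * h)
    for e in jnp.eye(params.shape[0])
])
print(g_auto)
print(g_fd)
```

The step size trades truncation error (which grows with `h`) against float32 round-off (which grows as `h` shrinks), so agreement to two or three significant digits is the realistic expectation here.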