Fitting a copula by maximum likelihood¶
Because the copula log-likelihood is a differentiable JAX function, fitting it is just gradient-based optimisation — no bespoke EM or numerical quadrature. Here we simulate data from a known two-level Clayton copula ($d=10$, two sectors of five leaves each) and recover its parameters with the Adam optimiser.
1. A Clayton family + a copula-scale marginal¶
In [1]:
Copied!
import jax
import jax.numpy as jnp
import jax.random as jrandom
import optax
from acopula import compile_model, copula, marginal
class Uniform:
    """Identity marginal for data that already lives on the copula (unit-cube) scale.

    The quantile function and the CDF are both the identity map, and the
    log-density is the constant zero.
    """

    def quantile(self, u):
        """Return ``u`` unchanged."""
        return u

    def cdf(self, x):
        """Return ``x`` unchanged."""
        return x

    def log_prob(self, x):
        """Return ``0.0`` regardless of ``x``."""
        return 0.0
@copula
class Clayton:
    """Clayton family, specified by its generator and the generator's inverse."""

    # Dependence parameter; used as a divisor in ``generator`` and as a
    # multiplier inside ``generator_inv``.
    theta: float

    def generator(self, t):
        """Evaluate ``(1 + t) ** (-1 / theta)``."""
        exponent = -1.0 / self.theta
        return jnp.power(1.0 + t, exponent)

    def generator_inv(self, u):
        """Evaluate ``u ** (-theta) - 1`` as ``expm1(-theta * log(u))``."""
        # Closed form, smooth on (0, 1); providing it avoids the oryx-derived
        # inverse that can NaN under higher-order AD.
        log_u = jnp.log(u)
        return jnp.expm1(-self.theta * log_u)
import jax
import jax.numpy as jnp
import jax.random as jrandom
import optax
from acopula import compile_model, copula, marginal
class Uniform:
    """Identity marginal for data that already lives on the copula (unit-cube) scale.

    The quantile function and the CDF are both the identity map, and the
    log-density is the constant zero.
    """

    def quantile(self, u):
        """Return ``u`` unchanged."""
        return u

    def cdf(self, x):
        """Return ``x`` unchanged."""
        return x

    def log_prob(self, x):
        """Return ``0.0`` regardless of ``x``."""
        return 0.0
@copula
class Clayton:
    """Clayton family, specified by its generator and the generator's inverse."""

    # Dependence parameter; used as a divisor in ``generator`` and as a
    # multiplier inside ``generator_inv``.
    theta: float

    def generator(self, t):
        """Evaluate ``(1 + t) ** (-1 / theta)``."""
        exponent = -1.0 / self.theta
        return jnp.power(1.0 + t, exponent)

    def generator_inv(self, u):
        """Evaluate ``u ** (-theta) - 1`` as ``expm1(-theta * log(u))``."""
        # Closed form, smooth on (0, 1); providing it avoids the oryx-derived
        # inverse that can NaN under higher-order AD.
        log_u = jnp.log(u)
        return jnp.expm1(-self.theta * log_u)
2. A two-level model factory¶
A root Clayton copula joins d // group_size sector-level Clayton copulas, each
over group_size uniform leaves. The parameters are a dict with keys theta_root and theta_sector.
In [2]:
Copied!
def make_model(d=10, group_size=5):
    """Build a two-level Clayton model over ``d`` uniform leaves.

    The returned callable ``model(params, u)`` splits the leaves into
    ``d // group_size`` consecutive sectors of ``group_size`` leaves each,
    applies one shared sector-level ``Clayton`` to every sector and joins the
    sectors under a root ``Clayton``.

    Args:
        d: Total number of leaves, i.e. how many entries of ``u`` are read.
        group_size: Number of leaves in each sector.

    Returns:
        A function ``model(params, u)`` where ``params`` is a mapping with
        keys ``"theta_root"`` and ``"theta_sector"`` and ``u`` is indexed by
        the integers ``0 .. d - 1``.

    Raises:
        ValueError: If ``d`` is not a multiple of ``group_size``. Plain floor
            division would otherwise drop the trailing leaves silently and
            the model would ignore part of ``u``.
    """
    n_sectors, leftover = divmod(d, group_size)
    if leftover:
        raise ValueError(
            f"d ({d}) must be a multiple of group_size ({group_size})"
        )

    def model(params, u):
        root = Clayton(params["theta_root"])
        # One Clayton instance is shared by all sectors, so every sector has
        # the same theta.
        sector = Clayton(params["theta_sector"])
        sectors = [
            sector(
                [
                    marginal(Uniform(), obs=u[s * group_size + j])
                    for j in range(group_size)
                ]
            )
            for s in range(n_sectors)
        ]
        return root(sectors)

    return model
# Problem size: ten leaves in sectors of five.
d = 10
group_size = 5
# Ground-truth Clayton parameters, keyed the way the model reads them.
true_params = dict(theta_root=2.0, theta_sector=5.0)
model = make_model(d=d, group_size=group_size)
def make_model(d=10, group_size=5):
    """Build a two-level Clayton model over ``d`` uniform leaves.

    The returned callable ``model(params, u)`` splits the leaves into
    ``d // group_size`` consecutive sectors of ``group_size`` leaves each,
    applies one shared sector-level ``Clayton`` to every sector and joins the
    sectors under a root ``Clayton``.

    Args:
        d: Total number of leaves, i.e. how many entries of ``u`` are read.
        group_size: Number of leaves in each sector.

    Returns:
        A function ``model(params, u)`` where ``params`` is a mapping with
        keys ``"theta_root"`` and ``"theta_sector"`` and ``u`` is indexed by
        the integers ``0 .. d - 1``.

    Raises:
        ValueError: If ``d`` is not a multiple of ``group_size``. Plain floor
            division would otherwise drop the trailing leaves silently and
            the model would ignore part of ``u``.
    """
    n_sectors, leftover = divmod(d, group_size)
    if leftover:
        raise ValueError(
            f"d ({d}) must be a multiple of group_size ({group_size})"
        )

    def model(params, u):
        root = Clayton(params["theta_root"])
        # One Clayton instance is shared by all sectors, so every sector has
        # the same theta.
        sector = Clayton(params["theta_sector"])
        sectors = [
            sector(
                [
                    marginal(Uniform(), obs=u[s * group_size + j])
                    for j in range(group_size)
                ]
            )
            for s in range(n_sectors)
        ]
        return root(sectors)

    return model
# Problem size: ten leaves in sectors of five.
d = 10
group_size = 5
# Ground-truth Clayton parameters, keyed the way the model reads them.
true_params = dict(theta_root=2.0, theta_sector=5.0)
model = make_model(d=d, group_size=group_size)
3. Simulate data¶
compile_model(..., method="bell") builds the Bell-polynomial likelihood
evaluator; cm.sample draws from the copula (here via the Rosenblatt
transform).
In [3]:
Copied!
# Compile the model with the Bell-polynomial likelihood evaluator; the true
# parameters serve as the parameter template.
cm = compile_model(model, method="bell", template=true_params)

# Draw 1000 observations via the Rosenblatt transform, then clip them into
# [1e-3, 1 - 1e-3] so no value sits at 0 or 1 (Clayton.generator_inv takes
# log(u)).
data = jnp.clip(
    cm.sample(jrandom.PRNGKey(0), 1000, true_params, method="rosenblatt"),
    1e-3,
    1 - 1e-3,
)
print("simulated data shape:", data.shape)
# Compile the model with the Bell-polynomial likelihood evaluator; the true
# parameters serve as the parameter template.
cm = compile_model(model, method="bell", template=true_params)

# Draw 1000 observations via the Rosenblatt transform, then clip them into
# [1e-3, 1 - 1e-3] so no value sits at 0 or 1 (Clayton.generator_inv takes
# log(u)).
data = jnp.clip(
    cm.sample(jrandom.PRNGKey(0), 1000, true_params, method="rosenblatt"),
    1e-3,
    1 - 1e-3,
)
print("simulated data shape:", data.shape)
simulated data shape: (1000, 10)
4. Fit by gradient descent¶
We maximise the mean log-likelihood with Adam, optimising in log-space so
the Clayton $\theta$ stays positive. cm.ll_fn(x, flat_params) scores one
observation; jax.vmap batches it over the dataset.
In [4]:
Copied!
# Optimise log(theta) so that theta = exp(log_theta) is always positive.
# Every parameter starts at theta = 1, i.e. log(theta) = 0.
log_theta = {name: jnp.log(jnp.array(1.0)) for name in true_params}


def neg_mean_ll(log_theta):
    """Return the negative mean log-likelihood of ``data`` at ``exp(log_theta)``.

    ``log_theta`` is a dict of log-parameters with the same keys as
    ``true_params``.
    """
    theta = {name: jnp.exp(value) for name, value in log_theta.items()}
    flat_params = cm.flatten(theta)

    def score(x):
        # Log-likelihood of a single observation.
        return cm.ll_fn(x, flat_params)

    per_observation = jax.vmap(score)(data)
    return -jnp.mean(per_observation)


opt = optax.adam(0.05)
state = opt.init(log_theta)
loss_and_grad = jax.jit(jax.value_and_grad(neg_mean_ll))

for step in range(300):
    # The reported loss is evaluated before this step's update is applied.
    loss, grads = loss_and_grad(log_theta)
    updates, state = opt.update(grads, state)
    log_theta = optax.apply_updates(log_theta, updates)
    if step % 60 == 0:
        print(f"step {step:3d} NLL = {float(loss):.4f}")
# Optimise log(theta) so that theta = exp(log_theta) is always positive.
# Every parameter starts at theta = 1, i.e. log(theta) = 0.
log_theta = {name: jnp.log(jnp.array(1.0)) for name in true_params}


def neg_mean_ll(log_theta):
    """Return the negative mean log-likelihood of ``data`` at ``exp(log_theta)``.

    ``log_theta`` is a dict of log-parameters with the same keys as
    ``true_params``.
    """
    theta = {name: jnp.exp(value) for name, value in log_theta.items()}
    flat_params = cm.flatten(theta)

    def score(x):
        # Log-likelihood of a single observation.
        return cm.ll_fn(x, flat_params)

    per_observation = jax.vmap(score)(data)
    return -jnp.mean(per_observation)


opt = optax.adam(0.05)
state = opt.init(log_theta)
loss_and_grad = jax.jit(jax.value_and_grad(neg_mean_ll))

for step in range(300):
    # The reported loss is evaluated before this step's update is applied.
    loss, grads = loss_and_grad(log_theta)
    updates, state = opt.update(grads, state)
    log_theta = optax.apply_updates(log_theta, updates)
    if step % 60 == 0:
        print(f"step {step:3d} NLL = {float(loss):.4f}")
step 0 NLL = -5.2357
step 60 NLL = -9.1465
step 120 NLL = -9.1787
step 180 NLL = -9.1787
step 240 NLL = -9.1787
5. Recovered parameters¶
In [5]:
Copied!
# Map the optimised log-parameters back to the theta scale as plain floats.
fitted = {
    name: float(jnp.exp(log_value))
    for name, log_value in log_theta.items()
}

# Fixed-width table: header row, then one row per parameter.
print(f"{'parameter':<14}{'true':>8}{'fitted':>10}")
for k in true_params:
    print(f"{k:<14}{true_params[k]:>8.2f}{fitted[k]:>10.3f}")
# Map the optimised log-parameters back to the theta scale as plain floats.
fitted = {
    name: float(jnp.exp(log_value))
    for name, log_value in log_theta.items()
}

# Fixed-width table: header row, then one row per parameter.
print(f"{'parameter':<14}{'true':>8}{'fitted':>10}")
for k in true_params:
    print(f"{k:<14}{true_params[k]:>8.2f}{fitted[k]:>10.3f}")
parameter         true    fitted
theta_root        2.00     1.958
theta_sector      5.00     4.840