API reference
Generated from the docstrings. The everyday surface is small: define a copula family, compile a model, then evaluate or sample.
Defining a copula
A family subclasses Copula — usually via the @copula decorator for
dataclass-style parameters — and implements generator. Everything else
(density, inverse, sampling) is derived automatically. Call a copula instance
on its children to build the model tree.
copula
copula(cls)
Class decorator for defining a copula family with dataclass-style parameters.
Generates an `__init__` from the type-annotated class attributes, makes the class inherit from Copula, and registers the parameters, while preserving every user-defined method. Instances accept positional or keyword arguments.
Examples:
```python
@copula
class Clayton:
    theta: float

    def generator(self, u):
        return (1.0 + u) ** (-1.0 / self.theta)

c1 = Clayton(2.0)        # positional
c2 = Clayton(theta=2.0)  # keyword
```
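A rough pure-Python sketch of what the decorator is documented to do, built on `dataclasses`. `Copula` and `copula_sketch` here are hypothetical stand-ins, not the library's code. Storing `param_names` is only one possible way to "register" parameters.

```python
import dataclasses


class Copula:
    """Hypothetical stand-in for the library's Copula base class."""

    def generator(self, u):
        raise NotImplementedError("subclasses implement the generator")


def copula_sketch(cls):
    """Hypothetical re-creation of the documented @copula behavior."""
    # 1. Generate __init__ (plus __repr__/__eq__) from the annotated attributes.
    data_cls = dataclasses.dataclass(cls)
    # 2. Record parameter names (a stand-in for "registering" them).
    param_names = tuple(f.name for f in dataclasses.fields(data_cls))
    # 3. Inherit from Copula; user methods on cls come first in the MRO,
    #    so they override the base-class stubs.
    return type(
        cls.__name__,
        (data_cls, Copula),
        {"__doc__": cls.__doc__, "__module__": cls.__module__, "param_names": param_names},
    )


@copula_sketch
class Clayton:
    theta: float

    def generator(self, u):
        return (1.0 + u) ** (-1.0 / self.theta)


c1 = Clayton(2.0)        # positional
c2 = Clayton(theta=2.0)  # keyword
print(c1 == c2, isinstance(c1, Copula), Clayton.param_names)
```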
Copula
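Base class for Archimedean copulas in \(\psi\)-notation, where \(\psi = \varphi^{-1}\) is the inverse of the classical generator \(\varphi\).
Subclasses must implement generator(u), which returns \(\psi(u)\),
where \(u \in [0, \infty)\) and \(\psi:[0,\infty) \to [0,1]\) is decreasing and convex
(strictly decreasing wherever it is positive), with \(\psi(0)=1\) and \(\psi(u) \to 0\) as \(u \to \infty\).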
Methods:
| Name | Description |
|---|---|
| generator | Return the Archimedean generator \(\psi(u)\). |
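In \(\psi\)-notation a bivariate Archimedean copula is \(C(u,v) = \psi(\psi^{-1}(u) + \psi^{-1}(v))\). Its density follows from derivatives of \(\psi\) alone: \(c(u,v) = \psi''(s) / (\psi'(\psi^{-1}(u))\,\psi'(\psi^{-1}(v)))\) with \(s = \psi^{-1}(u) + \psi^{-1}(v)\). The sketch below hand-codes these derivatives for the Clayton generator from the @copula example and checks the result against the textbook Clayton density. The library derives the derivatives automatically, and the helper names here are hypothetical.

```python
THETA = 2.0  # matches Clayton(2.0) in the @copula example


def psi(t, theta=THETA):
    # Clayton generator in psi-notation: psi(t) = (1 + t)^(-1/theta)
    return (1.0 + t) ** (-1.0 / theta)


def psi_inv(u, theta=THETA):
    # Closed-form inverse: psi^{-1}(u) = u^(-theta) - 1
    return u ** (-theta) - 1.0


def psi_d1(t, theta=THETA):
    # psi'(t) = -a (1 + t)^(-a - 1), a = 1/theta
    a = 1.0 / theta
    return -a * (1.0 + t) ** (-a - 1.0)


def psi_d2(t, theta=THETA):
    # psi''(t) = a (a + 1) (1 + t)^(-a - 2)
    a = 1.0 / theta
    return a * (a + 1.0) * (1.0 + t) ** (-a - 2.0)


def density_from_generator(u, v, theta=THETA):
    # c(u, v) = psi''(s) / (psi'(psi^{-1}(u)) * psi'(psi^{-1}(v)))
    tu, tv = psi_inv(u, theta), psi_inv(v, theta)
    return psi_d2(tu + tv, theta) / (psi_d1(tu, theta) * psi_d1(tv, theta))


def clayton_density_textbook(u, v, theta=THETA):
    # (1 + theta) (uv)^(-1-theta) (u^-theta + v^-theta - 1)^(-2 - 1/theta)
    return ((1.0 + theta) * (u * v) ** (-1.0 - theta)
            * (u ** -theta + v ** -theta - 1.0) ** (-2.0 - 1.0 / theta))


print(density_from_generator(0.5, 0.5), clayton_density_textbook(0.5, 0.5))
```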
generator
generator(u: Array) -> Array
Return the Archimedean generator \(\psi(u)\).
Subclasses must implement this. \(\psi\) maps \([0,\infty) \to [0,1]\), is decreasing and convex, and satisfies \(\psi(0)=1\) and \(\psi(u) \to 0\) as \(u \to \infty\). Everything else (density, inverse, sampling) is derived from it automatically.
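Because only \(\psi\) is supplied, \(\psi^{-1}\) must be recovered from it when no closed form is available. Below is a hypothetical sketch that does this by bisection; the library's actual inversion method is not documented here. It also builds the \(d\)-dimensional CDF \(C(u_1, \ldots, u_d) = \psi(\sum_i \psi^{-1}(u_i))\). The sketch assumes \(\psi(t)\) drops below \(u\) for some finite \(t\), and it is checked against the Clayton closed form \(\psi^{-1}(u) = u^{-\theta} - 1\).

```python
def clayton_psi(theta):
    # Generator from the @copula example: psi(t) = (1 + t)^(-1/theta)
    return lambda t: (1.0 + t) ** (-1.0 / theta)


def psi_inverse(psi, u, iters=200):
    """Hypothetical helper: invert a decreasing psi on [0, inf) by bisection."""
    if not 0.0 < u <= 1.0:
        raise ValueError("u must lie in (0, 1]")
    lo, hi = 0.0, 1.0
    while psi(hi) > u:          # grow the bracket until psi(hi) <= u
        hi *= 2.0
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if psi(mid) > u:
            lo = mid            # psi is decreasing: the root lies to the right
        else:
            hi = mid
    return 0.5 * (lo + hi)


def archimedean_cdf(psi, us):
    # C(u_1, ..., u_d) = psi(sum_i psi^{-1}(u_i))
    return psi(sum(psi_inverse(psi, u) for u in us))


psi = clayton_psi(2.0)
print(psi_inverse(psi, 0.3))             # close to 0.3**-2 - 1
print(archimedean_cdf(psi, [0.4, 1.0]))  # uniform margins: close to 0.4
```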
marginal
marginal(dist: Any, *, obs: Any, censored: bool = False) -> Leaf
DSL primitive for leaf marginals.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dist | Any | A (tfp) distribution object with a … | required |
| obs | Any | An indexing expression into the obs placeholder, e.g. … | required |
| censored | bool | If True, the variable is censored: it contributes to the copula argument, but the density is not differentiated with respect to it. | False |

Returns:
| Type | Description |
|---|---|
| Leaf | A Leaf annotated with (dist, obs_index, censored). |
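To illustrate the censoring rule, take a bivariate model in which \(x_1\) is observed and \(x_2\) is right-censored. The copula couples survival functions \(u_i = S_i(x_i)\), which is assumed here to mirror the survival=True setting. Both leaves enter the copula argument \(s = \psi^{-1}(u_1) + \psi^{-1}(u_2)\), but only \(u_1\) is differentiated: \(L = f_1(x_1)\,\psi'(s)/\psi'(\psi^{-1}(u_1))\). A pure-Python sketch with the Clayton generator from the @copula example and hypothetical Exponential marginals; how the library assembles this internally may differ.

```python
import math

THETA = 2.0          # Clayton parameter (illustrative)
RATES = (1.0, 0.5)   # hypothetical Exponential marginal rates


def psi(t):
    # Clayton generator from the @copula example
    return (1.0 + t) ** (-1.0 / THETA)


def psi_d1(t):
    # psi'(t)
    return -(1.0 / THETA) * (1.0 + t) ** (-1.0 / THETA - 1.0)


def psi_inv(u):
    return u ** (-THETA) - 1.0


def exp_sf(x, rate):
    # Survival function S(x) = P(X > x) of an Exponential(rate) marginal
    return math.exp(-rate * x)


def exp_logpdf(x, rate):
    return math.log(rate) - rate * x


def loglik_x2_censored(x1, x2):
    """x1 observed, x2 right-censored; the copula couples survival functions."""
    u1, u2 = exp_sf(x1, RATES[0]), exp_sf(x2, RATES[1])
    s = psi_inv(u1) + psi_inv(u2)                       # both leaves enter the argument
    dC_du1 = psi_d1(s) / psi_d1(psi_inv(u1))            # differentiate in u1 only
    return math.log(dC_du1) + exp_logpdf(x1, RATES[0])  # no density factor for x2


print(loglik_x2_censored(0.7, 1.2))
```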
Compiling and evaluating
compile_model
compile_model(model_fn: Callable, *, template: Any, method: str = 'auto', survival: bool = False, with_censored_mask: bool = False, ils_method: str = 'cohen', ils_params: Optional[dict] = None, log_jet: bool = False) -> CompiledModel
Compile a copula composition function into a reusable evaluator.
Compilation runs the user function once under oryx harvest to discover the structural graph (using sentinel template values). It then builds a flattener that maps parameter pytrees to the canonical flat array, constructs the appropriate log-likelihood evaluator (bell / single_layer / nested), and JIT-compiles it.
Reuse the returned CompiledModel across many parameter sweeps. JAX caches the compiled program by function identity (and by the shapes and dtypes of its inputs), so repeated calls with new parameter values of the same shape and dtype hit the cache. The historical pitfall, where set_params rebuilt the graph object on every call and so invalidated the cache, is avoided by construction.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_fn | Callable | A function … | required |
| template | Any | A parameter pytree whose structure matches what … | required |
| method | str | Density-evaluation backend: … | 'auto' |
| survival | bool | Treat marginals as survival functions … | False |
| with_censored_mask | bool | If True, the returned model's ll_fn accepts a per-observation censoring mask as a third argument. | False |
| ils_method | str | Inverse-Laplace solver used only by … | 'cohen' |
| ils_params | Optional[dict] | Extra params for that solver; likewise only used by … | None |

Returns:
| Type | Description |
|---|---|
| CompiledModel | A CompiledModel holding the structural graph, the parameter flattener, and the jit'd log-likelihood. |
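Here is a minimal JAX sketch of the cache behavior described above. It uses a hypothetical toy log-likelihood, not the library's evaluator. A Python side effect inside a jitted function runs only while JAX traces it, so the recorded shapes show exactly when a retrace happens:

- new values with the same shapes reuse the existing trace;
- a new input shape forces a retrace;
- jitting a fresh closure on every call (the pitfall mentioned above) retraces every time.

```python
import jax
import jax.numpy as jnp

trace_log = []  # one entry per trace; Python side effects run only while tracing


def loglik(x, theta):
    trace_log.append(x.shape)
    return jnp.sum(jnp.log1p(theta * x))


fast = jax.jit(loglik)          # wrap ONCE and reuse, like a CompiledModel

fast(jnp.ones(3), 0.5)
fast(jnp.ones(3), 2.0)          # new values, same shapes/dtypes: cache hit
print(len(trace_log))           # 1

fast(jnp.ones(4), 2.0)          # new shape: retrace
print(len(trace_log))           # 2

# Anti-pattern: a fresh jitted closure per call defeats the cache.
for theta in (0.5, 1.0, 2.0):
    jax.jit(lambda x: loglik(x, theta))(jnp.ones(3))
print(len(trace_log))           # 5
```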
CompiledModel (dataclass)
Immutable compiled representation of a nested-copula model.
Built by compile_model. Holds the structural graph, the flattening function for parameter pytrees, and a jit'd log-likelihood evaluator. Reuse a single CompiledModel across parameter sweeps to avoid repeated XLA compilation: only the input values change (not their shapes or dtypes), so every call after the first hits the cache.
Attributes:
| Name | Type | Description |
|---|---|---|
| graph | Any | The frozen … |
| param_shape | Tuple[int, ...] | Expected flat parameter shape, e.g. … |
| with_censored_mask | bool | True iff this compilation supports a per-observation censoring mask as a third argument to ll_fn. |
| survival | bool | True iff this compilation interprets observations as survival data; the copula models the joint survival … |

Methods:
| Name | Description |
|---|---|
| eval | Evaluate the log-likelihood at params. |
| sample | Generate n samples from the copula at params. |
| flatten | Convert a parameter pytree into the canonical flat array. |
| ll_fn | Direct handle on the jit'd log-likelihood. |
| visualize | |
| as_networkx | |
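flatten maps a parameter pytree to one flat array, and a matching inverse restores the structure. The sketch below shows the same idea with `jax.flatten_util.ravel_pytree`, using a hypothetical two-level parameter layout. The library's canonical ordering is its own and may differ from the sorted-dict-key order JAX uses here.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical parameter pytree for a two-level nested model.
params = {
    "outer": {"theta": jnp.asarray(1.5)},
    "inner": [{"theta": jnp.asarray(3.0)}, {"theta": jnp.asarray(4.0)}],
}

flat, unravel = ravel_pytree(params)
# Leaves in pytree order: inner[0], inner[1], outer (dict keys are sorted).
print(flat)

restored = unravel(flat * 2.0)   # same structure, new values
print(restored["outer"]["theta"])
```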
eval
eval(obs: Array, params: Any, *, censored_mask=None) -> Array
Evaluate the log-likelihood at params for a single
observation (or a leading-batched array).
Convenience wrapper that flattens params and dispatches to
the underlying jit'd ll_fn.
sample
sample(key: Array, n: int, params: Any, *, post_widder_k: int = 8, max_cdf_x: float = 1000000.0, method: str = 'marshall_olkin') -> Array
Generate n samples from the copula at params.
Respects the survival flag set at compile time: when survival=True, the leaf inversion uses \(F^{-1}(1 - U)\), so the samples are drawn from the same joint distribution that the likelihood scores.
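The default method='marshall_olkin' presumably refers to the Marshall-Olkin frailty construction. When \(\psi\) is the Laplace transform of a positive variable \(V\), setting \(U_i = \psi(E_i / V)\) with \(E_i \sim \mathrm{Exp}(1)\) i.i.d. yields the copula. The pure-Python sketch below applies this to the Clayton generator from the @copula example, for which \(V \sim \mathrm{Gamma}(1/\theta, 1)\). It also includes the survival-flag inversion \(F^{-1}(1-U)\) for a hypothetical Exponential marginal. The library's JAX implementation may differ.

```python
import math
import random


def sample_clayton_mo(n, theta, rng):
    """Marshall-Olkin sampling for a bivariate Clayton copula.

    psi(t) = (1 + t)^(-1/theta) is the Laplace transform of V ~ Gamma(1/theta, 1),
    so U_i = psi(E_i / V) with E_i ~ Exp(1) i.i.d. has the Clayton copula.
    """
    out = []
    for _ in range(n):
        v = rng.gammavariate(1.0 / theta, 1.0)          # frailty (shape, scale)
        e1, e2 = rng.expovariate(1.0), rng.expovariate(1.0)
        out.append(((1.0 + e1 / v) ** (-1.0 / theta), (1.0 + e2 / v) ** (-1.0 / theta)))
    return out


def exp_quantile(p, rate):
    # F^{-1}(p) for an Exponential(rate) marginal
    return -math.log1p(-p) / rate


def to_marginal(u, rate, survival):
    # survival=True mirrors the documented F^{-1}(1 - U) leaf inversion
    return exp_quantile(1.0 - u if survival else u, rate)


def kendall_tau(pairs):
    concordant = discordant = 0
    for i in range(len(pairs)):
        for j in range(i + 1, len(pairs)):
            s = (pairs[i][0] - pairs[j][0]) * (pairs[i][1] - pairs[j][1])
            if s > 0:
                concordant += 1
            elif s < 0:
                discordant += 1
    return (concordant - discordant) / (concordant + discordant)


rng = random.Random(0)
us = sample_clayton_mo(800, theta=2.0, rng=rng)
print(kendall_tau(us))   # theory for Clayton: theta / (theta + 2) = 0.5
xs = [to_marginal(u, rate=1.0, survival=True) for u, _ in us]
```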
ll_fn
ll_fn(*args, **kwargs)
Direct handle on the jit'd log-likelihood.
Its signature depends on the with_censored_mask flag chosen at compile time:
with_censored_mask=False: ll_fn(obs, flat_params)
with_censored_mask=True: ll_fn(obs, flat_params, censored_mask)
Use this when you want raw control for jax.value_and_grad, jax.hessian, jax.vmap, etc. Prefer eval for one-shot calls with a parameter pytree.
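The sketch below shows the raw-control workflow with a hypothetical stand-in `ll_fn` that has the documented with_censored_mask=False signature `ll_fn(obs, flat_params)`. It hand-codes a bivariate Clayton log-density and sums over the batch; whether the real ll_fn returns a sum or per-observation values is not specified here.

```python
import jax
import jax.numpy as jnp


def clayton_logpdf(u, v, theta):
    # Bivariate Clayton log-density, theta > 0
    a = u ** -theta + v ** -theta - 1.0
    return (jnp.log1p(theta)
            - (1.0 + theta) * (jnp.log(u) + jnp.log(v))
            - (2.0 + 1.0 / theta) * jnp.log(a))


@jax.jit
def ll_fn(obs, flat_params):
    # Hypothetical stand-in: obs has shape (n, 2); flat_params is [theta].
    theta = flat_params[0]
    return jnp.sum(clayton_logpdf(obs[:, 0], obs[:, 1], theta))


obs = jnp.array([[0.5, 0.5], [0.2, 0.3], [0.9, 0.7]])
flat_params = jnp.array([2.0])

value, grad = jax.value_and_grad(ll_fn, argnums=1)(obs, flat_params)
hess = jax.hessian(ll_fn, argnums=1)(obs, flat_params)        # shape (1, 1)
per_obs = jax.vmap(lambda o: ll_fn(o[None, :], flat_params))(obs)  # shape (3,)
print(value, grad, hess.shape, per_obs.shape)
```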
visualize
visualize(*, include_leaves: bool = True, layout: str = 'hierarchical', with_labels: bool = True, graphviz_prog: str = 'dot', annotations=None, ax=None)
Composition registry
For nested copulas that mix families, register a closed-form composition \(\psi_{\text{outer}}^{-1} \circ \psi_{\text{inner}}\), or choose the fallback strategy used when none is registered.
register_composition
register_composition(outer_cls: type, inner_cls: type, fn: Callable)
Register a stable closed-form composition \(h(t) = \psi_{\text{outer}}^{-1}(\psi_{\text{inner}}(t))\).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| outer_cls | type | The outer (parent) copula class. | required |
| inner_cls | type | The inner (child) copula class. | required |
| fn | Callable | A JAX-traceable function fn(t, parent_cop, child_cop) -> scalar. | required |
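Take two Clayton generators of the form \(\psi(t) = (1+t)^{-1/\theta}\), as in the @copula example. Their composition has the closed form \(h(t) = (1+t)^{\theta_{\text{outer}}/\theta_{\text{inner}}} - 1\). This form never builds the intermediate \(\psi_{\text{inner}}(t)\), which a literal composition computes and which can underflow for large \(t\). The sketch below implements such an fn with the documented `fn(t, parent_cop, child_cop)` signature; the `Clayton` dataclass is a hypothetical stand-in. For nested Clayton, \(\theta_{\text{inner}} \ge \theta_{\text{outer}}\) is the usual sufficient nesting condition.

```python
from dataclasses import dataclass


@dataclass
class Clayton:
    """Hypothetical stand-in for a Clayton family; only .theta is used below."""
    theta: float


def clayton_clayton(t, parent_cop, child_cop):
    # With psi(t) = (1 + t)^(-1/theta) and psi^{-1}(u) = u^(-theta) - 1:
    #   h(t) = psi_outer^{-1}(psi_inner(t)) = (1 + t)^(theta_outer / theta_inner) - 1
    # Only arithmetic operators are used, so the same code traces under JAX.
    return (1.0 + t) ** (parent_cop.theta / child_cop.theta) - 1.0


# With the real library this would be registered as (not executed here):
# register_composition(Clayton, Clayton, clayton_clayton)

outer, inner = Clayton(1.5), Clayton(4.0)   # inner theta >= outer theta
t = 0.8
naive = ((1.0 + t) ** (-1.0 / inner.theta)) ** (-outer.theta) - 1.0  # literal composition
print(clayton_clayton(t, outer, inner), naive)
```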
set_composition_fallback
set_composition_fallback(mode: str)
Set the fallback method used when no registered composition exists.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| mode | str | "implicit" (\(O(d_c^3)\), stable, the default) or "direct_jet" (\(O(d_c^2)\), may overflow at extreme values). | required |
Configuration
set_stable_log
set_stable_log(enabled: bool) -> None
Enable a higher-order-stable log for the Bell polynomial density.
When enabled=True, _poly_power_alpha routes its three jnp.log calls through acopula._stable_log.safe_log. That function uses a custom JVP rule to keep jax.hessian and jax.jacfwd(jax.jacrev(...)) finite for Taylor coefficients as small as ~1e-200 (e.g. nested Clayton at Kaplan-Meier pseudo-observations near the marginal floor). The forward value and first derivative are unchanged.
Default is False. Call once at startup, before constructing the model.
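The failure mode that set_stable_log targets can be reproduced with plain `jnp.log` in float64. This sketch does not reproduce the library's safe_log rule. At \(x = 10^{-200}\) the first derivative \(1/x\) is still representable, but the second derivative \(-1/x^2 = -10^{400}\) overflows.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)   # float64, as in the scenario above

x = jnp.asarray(1e-200)                      # tiny Taylor coefficient
d1 = jax.grad(jnp.log)(x)                    # 1/x = 1e200: still representable
d2 = jax.grad(jax.grad(jnp.log))(x)          # -1/x**2 = -1e400: overflows float64
print(d1, d2)
```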
set_compile_cache_dir
set_compile_cache_dir(path: str | PathLike, *, min_compile_time_seconds: float = 1.0) -> str
Enable the JAX persistent compilation cache at path.
Creates the directory if it does not exist, then sets the three JAX config options the cache needs in order to actually capture the slow compiles:
jax_compilation_cache_dir (the cache location)
jax_persistent_cache_min_entry_size_bytes (set to 0, so entries of any size are cached)
jax_persistent_cache_min_compile_time_secs (set from min_compile_time_seconds; only compiles slower than this threshold are written)
The default threshold of 1 second keeps the cache small while still capturing the Frank/MIMIC compiles we care about.
Must be called before the first jax.jit / jax.pmap / jax.vmap trace through your code; call it at program start-up, right after importing jax.
Returns:
| Type | Description |
|---|---|
| str | The absolute cache directory path (useful for logging). |
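The sketch below applies the configuration this function is documented to set, calling `jax.config.update` directly. `enable_persistent_cache` is a hypothetical name, not the library's implementation.

```python
import os
import tempfile

import jax


def enable_persistent_cache(path, min_compile_time_seconds=1.0):
    """Hypothetical re-creation of the documented set_compile_cache_dir behavior."""
    path = os.path.abspath(os.fspath(path))
    os.makedirs(path, exist_ok=True)                                    # create if missing
    jax.config.update("jax_compilation_cache_dir", path)               # cache location
    jax.config.update("jax_persistent_cache_min_entry_size_bytes", 0)  # any entry size
    jax.config.update("jax_persistent_cache_min_compile_time_secs",
                      min_compile_time_seconds)                        # only slow compiles
    return path


cache_dir = enable_persistent_cache(os.path.join(tempfile.mkdtemp(), "jax-cache"))
print(cache_dir)
```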
Advanced: custom generator precision
Most families never touch these — they are optional override hooks on Copula
for families whose high-order derivatives suffer float64 cancellation (e.g. AMH
past d≈30). Each returns None by default, which selects Taylor-mode AD
(jet). generator_taylor_coefficients (and its log-space variant
log_generator_taylor_coefficients) is the recommended, general hook, used by
the Bell density and the nesting composition. log_generator_kth_derivative is
a legacy hook consulted only by the flat single-layer likelihood path.
generator_taylor_coefficients
generator_taylor_coefficients(t: Array, k_max: int) -> Array
Return Taylor coefficients \([\psi(t),\, \psi'(t)/1!,\, \ldots,\, \psi^{(k_{\max})}(t)/k_{\max}!]\).
Override this when the closed-form \(\psi\) has a numerically stable series representation that avoids the cancellations Taylor-mode AD on the closed form would produce at high derivative order. For example, AMH's \(\psi\) is the Laplace transform of \(\mathrm{Geometric}(1-\theta)\), so all Taylor coefficients are sums of same-sign terms (no cancellation).
The framework consults this hook everywhere it would otherwise call jet_array.jet on self.generator to get a Taylor expansion: the root density (bell._root_assembly) and the nesting composition (compose.compute_composition_taylor). Returning None (the default) tells those call sites to fall back to the jet path. That path is correct but loses precision when \(\psi\)'s derivatives have large alternating-sign cancellations (e.g. AMH past \(d \approx 30\) in float64).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| t | Array | Scalar input point. | required |
| k_max | int | Highest derivative order to compute (inclusive); the returned array has shape (k_max + 1,). | required |

Returns:
| Type | Description |
|---|---|
| Array | JAX array of shape (k_max + 1,) holding the Taylor coefficients, or None to fall back to jet. |
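The sketch below is a pure-Python version of the AMH series mentioned above. It assumes \(0 \le \theta < 1\) and the AMH generator \(\psi(t) = (1-\theta)/(e^t - \theta)\). Expanding \(\psi\) as the Laplace transform of a Geometric\((1-\theta)\) frailty on \(\{1, 2, \ldots\}\) gives \(\psi^{(k)}(t)/k! = \frac{(-1)^k}{k!}\sum_{m \ge 1}(1-\theta)\theta^{m-1} m^k e^{-mt}\), a sum of same-sign terms. A real override would return a JAX array of shape (k_max + 1,); the function name here is hypothetical.

```python
import math


def amh_taylor_coefficients(t, k_max, theta, n_terms=200):
    """c_k = psi^{(k)}(t) / k!, k = 0..k_max, for the AMH generator.

    psi(t) = (1 - theta) / (e^t - theta) = sum_{m>=1} (1 - theta) theta^(m-1) e^(-m t).
    Every term inside the sum is positive, so nothing cancels. Assumes
    0 <= theta < 1, t > 0 and a moderate k_max (plain float64; n_terms must
    cover the decay of the series).
    """
    coeffs = []
    for k in range(k_max + 1):
        total = math.fsum((1.0 - theta) * theta ** (m - 1) * m ** k * math.exp(-m * t)
                          for m in range(1, n_terms + 1))
        coeffs.append((-1) ** k * total / math.factorial(k))
    return coeffs


t, theta = 0.5, 0.5
c = amh_taylor_coefficients(t, 3, theta)
psi = (1 - theta) / (math.exp(t) - theta)
print(c[0], psi)   # the k = 0 coefficient is psi(t) itself
```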
log_generator_taylor_coefficients
log_generator_taylor_coefficients(t: Array, k_max: int) -> Optional[Tuple[Array, Array]]
Log-space variant of generator_taylor_coefficients.
Return (sign, log_abs), two arrays of shape (k_max + 1,) such that sign * exp(log_abs) equals the coefficients \(\psi^{(k)}(t)/k!\) that generator_taylor_coefficients would return.
The Bell density consumes the Taylor coefficients internally as (sign, log|c|) pairs. A family whose normalized coefficients are computed in the log domain (e.g. via logsumexp over a frailty series) should therefore return them here directly, rather than through the linear generator_taylor_coefficients. That avoids exponentiating to a possibly overflowing float64 value only for the framework to take its log again, which extends the usable derivative order at high d.
The Bell path prefers this hook when it returns non-None, then falls back to the linear hook, then to jet. Returning None (the default) selects that fallback chain.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| t | Array | Scalar input point. | required |
| k_max | int | Highest derivative order (inclusive). | required |

Returns:
| Type | Description |
|---|---|
| Optional[Tuple[Array, Array]] | (sign, log_abs) as described above, or None to use the fallback chain. |
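Below is a log-domain counterpart, again a hypothetical pure-Python sketch for AMH with \(0 < \theta < 1\) and \(\psi(t) = (1-\theta)/(e^t - \theta)\). Each geometric-frailty term \(\log(1-\theta) + (m-1)\log\theta + k\log m - mt\) is kept as a log weight and combined with logsumexp, and \(\log k!\) is then subtracted. The un-normalized sum, which overflows float64 at high \(k\), is never formed.

```python
import math


def logsumexp(values):
    # Numerically stable log(sum(exp(v))) for a list of floats
    top = max(values)
    return top + math.log(math.fsum(math.exp(v - top) for v in values))


def amh_log_taylor_coefficients(t, k_max, theta, n_terms=2000):
    """(sign, log_abs) with sign[k] * exp(log_abs[k]) = psi^{(k)}(t) / k! for AMH.

    Assumes 0 < theta < 1 and t > 0; n_terms must cover the decay of the series.
    """
    ms = range(1, n_terms + 1)
    log_w = [math.log1p(-theta) + (m - 1) * math.log(theta) - m * t for m in ms]
    signs, log_abs = [], []
    for k in range(k_max + 1):
        terms = [lw + k * math.log(m) for m, lw in zip(ms, log_w)]
        signs.append(1.0 if k % 2 == 0 else -1.0)              # psi^{(k)} has sign (-1)^k
        log_abs.append(logsumexp(terms) - math.lgamma(k + 1))  # divide by k! in log space
    return signs, log_abs


signs, log_abs = amh_log_taylor_coefficients(0.5, 400, theta=0.5)
print(signs[400], math.isfinite(log_abs[400]))   # 1.0 True
```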
log_generator_kth_derivative
log_generator_kth_derivative(t: Array, k: int) -> Array
Return \(\log|\psi^{(k)}(t)|\), the log absolute \(k\)-th derivative.
Override this with a closed-form expression to bypass jet for specific families. For example, Clayton has a closed form for every \(k\) in terms of log-gamma functions.
The default returns None, which tells the framework to use Taylor-mode AD via jet.
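Take the Clayton generator from the @copula example, \(\psi(t) = (1+t)^{-1/\theta}\). Repeated differentiation gives \(\psi^{(k)}(t) = (-1)^k \frac{\Gamma(a+k)}{\Gamma(a)}(1+t)^{-a-k}\) with \(a = 1/\theta\). Hence \(\log|\psi^{(k)}(t)| = \log\Gamma(a+k) - \log\Gamma(a) - (a+k)\log(1+t)\). The sketch below is plain Python; a JAX override would use `jax.scipy.special.gammaln` and `jnp.log1p` instead. The function name is hypothetical.

```python
import math


def clayton_log_kth_derivative(t, k, theta):
    """log|psi^{(k)}(t)| for psi(t) = (1 + t)^(-1/theta).

    psi^{(k)}(t) = (-1)^k * Gamma(a + k) / Gamma(a) * (1 + t)^(-a - k), a = 1/theta.
    """
    a = 1.0 / theta
    return math.lgamma(a + k) - math.lgamma(a) - (a + k) * math.log1p(t)


print(math.isfinite(clayton_log_kth_derivative(0.3, 500, theta=2.0)))   # True
```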