
API reference

Generated from the docstrings. The everyday surface is small: define a copula family, compile a model, then evaluate or sample.

Defining a copula

A family subclasses Copula — usually via the @copula decorator for dataclass-style parameters — and implements generator. Everything else (density, inverse, sampling) is derived automatically. Call a copula instance on its children to build the model tree.

copula

copula(cls)

Class decorator for defining a copula family with dataclass-style parameters.

Generates an __init__ from the type-annotated class attributes, makes the class inherit from Copula, and registers the parameters — while preserving every user-defined method. Instances accept positional or keyword arguments.

Examples:

@copula
class Clayton:
    theta: float

    def generator(self, u):
        return (1.0 + u) ** (-1.0 / self.theta)

c1 = Clayton(2.0)        # positional
c2 = Clayton(theta=2.0)  # keyword
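The mechanics described above can be sketched with the standard library alone. dataclasses generates the positional/keyword `__init__` from the annotations and keeps the user's methods, and a thin subclass adds the base class. ToyCopula and toy_copula are hypothetical stand-ins, not the library's Copula or copula, and the real decorator may register parameters differently.

```python
import dataclasses


class ToyCopula:
    """Hypothetical stand-in for the library's Copula base class."""

    def param_names(self):
        # In this sketch the "registered" parameters are just the dataclass fields.
        return tuple(f.name for f in dataclasses.fields(self))


def toy_copula(cls):
    # dataclasses writes __init__ (positional or keyword) from the annotations
    # and leaves every user-defined method on cls untouched.
    data_cls = dataclasses.dataclass(cls)
    # A thin subclass adds the base class without rewriting the user's class body.
    return type(cls.__name__, (data_cls, ToyCopula), {})


@toy_copula
class Clayton:
    theta: float

    def generator(self, u):
        return (1.0 + u) ** (-1.0 / self.theta)


c1 = Clayton(2.0)        # positional
c2 = Clayton(theta=2.0)  # keyword
```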

Copula

Base class for Archimedean copulas in \(\psi\)-notation, where \(\psi = \varphi^{-1}\) is the inverse of the classical generator \(\varphi\).

Subclasses must implement generator(u), which returns \(\psi(u)\), where \(u \in [0, \infty)\) and \(\psi:[0,\infty) \to (0,1]\) is strictly decreasing and convex with \(\psi(0)=1\) and \(\psi(u) \to 0\) as \(u \to \infty\).

Methods:

  • generator: Return the Archimedean generator \(\psi(u)\).

generator

generator(u: Array) -> Array

Return the Archimedean generator \(\psi(u)\).

Subclasses must implement this. \(\psi\) maps \([0,\infty) \to (0,1]\), is strictly decreasing and convex, with \(\psi(0)=1\). Everything else (density, inverse, sampling) is derived from it automatically.
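To make "derived automatically" concrete, here is a small standalone JAX sketch, not the library's internals. Given only the ψ-form Clayton generator from the @copula example (θ fixed at 2 for illustration), it recovers \(\psi^{-1}\) by Newton's method and assembles the copula CDF \(C(u_1,\ldots,u_d)=\psi\big(\sum_i \psi^{-1}(u_i)\big)\). The names psi, psi_inv and copula_cdf are hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

THETA = 2.0  # illustrative Clayton parameter


def psi(t):
    # psi-form Clayton generator, as in the @copula example.
    return (1.0 + t) ** (-1.0 / THETA)


def psi_inv(u, n_iter=60):
    # Newton's method on f(t) = psi(t) - u, started at t = 0 where f(0) >= 0.
    # psi is decreasing and convex, so the iterates rise monotonically to the root.
    u = jnp.asarray(u, dtype=jnp.float64)
    dpsi = jax.grad(psi)

    def step(_, t):
        return t - (psi(t) - u) / dpsi(t)

    return jax.lax.fori_loop(0, n_iter, step, jnp.zeros_like(u))


def copula_cdf(us):
    # C(u_1, ..., u_d) = psi(psi^{-1}(u_1) + ... + psi^{-1}(u_d)).
    return psi(jnp.sum(jax.vmap(psi_inv)(us)))
```

For this parameterisation the inverse also has the closed form \(\psi^{-1}(u) = u^{-\theta} - 1\), which makes a convenient check on the Newton iterate.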

marginal

marginal(dist: Any, *, obs: Any, censored: bool = False) -> Leaf

DSL primitive for leaf marginals.

Parameters:

  • dist (Any, required): A (tfp) distribution object with a quantile method.
  • obs (Any, required): An indexing expression into the obs placeholder, e.g. obs[i, j, k].
  • censored (bool, default False): If True, the variable is censored (it contributes to the copula argument, but the density is not differentiated with respect to it).

Returns:

  • Leaf: A Leaf annotated with (dist, obs_index, censored).
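What "contributes to the copula argument but is not differentiated" means can be sketched for a bivariate survival-copula model. Each leaf maps its observation through the marginal survival function; an observed leaf contributes a derivative of C and its marginal density, while a censored leaf only enters C's argument. This is a standalone illustration under stated assumptions (closed-form bivariate Clayton, Exp(0.5) marginals, right censoring of the second variable), not the library's leaf code; all names are hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

THETA = 2.0  # illustrative Clayton parameter
RATE = 0.5   # illustrative Exp(RATE) marginal for both variables


def clayton_cdf(u1, u2):
    # Closed-form bivariate Clayton copula C(u1, u2).
    return (u1 ** -THETA + u2 ** -THETA - 1.0) ** (-1.0 / THETA)


def surv(t):
    # Marginal survival function S(t) = 1 - F(t).
    return jnp.exp(-RATE * t)


def log_pdf(t):
    # Marginal log-density log f(t).
    return jnp.log(RATE) - RATE * t


def loglik(t1, t2, censored2):
    # Leaves: each observation enters the copula through S(t).
    u1, u2 = surv(t1), surv(t2)
    if censored2:  # Python-level branch, for readability only
        # Censored leaf: u2 is in C's argument, but C is differentiated in u1 only.
        return jnp.log(jax.grad(clayton_cdf, argnums=0)(u1, u2)) + log_pdf(t1)
    # Both observed: mixed partial of C plus both marginal log-densities.
    c12 = jax.grad(jax.grad(clayton_cdf, argnums=0), argnums=1)(u1, u2)
    return jnp.log(c12) + log_pdf(t1) + log_pdf(t2)
```

With marginals on the survival scale this is the usual right-censored contribution, since \(P(T_1 \in dt_1, T_2 > t_2) = -\partial_{t_1} C(S_1(t_1), S_2(t_2))\).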

Compiling and evaluating

compile_model

compile_model(model_fn: Callable, *, template: Any, method: str = 'auto', survival: bool = False, with_censored_mask: bool = False, ils_method: str = 'cohen', ils_params: Optional[dict] = None, log_jet: bool = False) -> CompiledModel

Compile a copula composition function into a reusable evaluator.

The compilation runs the user function once under oryx harvest to discover the structural graph (with sentinel template values), builds a flattener that maps parameter pytrees to the canonical flat array, constructs the appropriate log-likelihood evaluator (bell / single_layer / integral), and JIT-compiles it.

Reuse the returned CompiledModel across many parameter sweeps. JAX caches the compiled program by function identity (together with the shapes and dtypes of the inputs), so repeated calls with new parameter values of the same shape hit the cache. The historical pitfall, where set_params rebuilt the graph object on every call and so invalidated the cache, is avoided by construction.
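The caching behaviour is plain jax.jit semantics and can be observed with a trace counter. A jitted function built once is traced once and then reused for new parameter values of the same shape and dtype. Rebuilding the function object per call (the set_params pitfall) retraces every time. make_loglik and its toy body are hypothetical stand-ins for a compiled log-likelihood.

```python
import jax
import jax.numpy as jnp

trace_log = []


def make_loglik():
    # Hypothetical stand-in for building a log-likelihood evaluator.
    def ll(obs, flat_params):
        trace_log.append("trace")  # Python side effect: runs only while JAX traces
        return jnp.sum(-flat_params[0] * obs)

    return jax.jit(ll)


obs = jnp.ones(4)
sweep = [jnp.array([theta]) for theta in (0.5, 1.0, 2.0)]

# Build once, reuse across the sweep: one trace, then cache hits.
jitted = make_loglik()
for p in sweep:
    jitted(obs, p)
traces_reused = len(trace_log)

# Rebuild per call: each new function object misses the cache and retraces.
trace_log.clear()
for p in sweep:
    make_loglik()(obs, p)
traces_rebuilt = len(trace_log)
```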

Parameters:

  • model_fn (Callable, required): A function (params, u) -> Node returning a Copula call that produces the structural graph.
  • template (Any, required): A parameter pytree whose structure matches what model_fn expects. Values are used only to drive oryx's tracer; pass any sensible defaults (e.g. {"theta": 1.0}).
  • method (str, default 'auto'): Density-evaluation backend:
      ◦ "auto" (default): "bell" for depth > 1, else "single_layer".
      ◦ "bell": polynomial-powering scan (recommended for nested copulas; the only backend that supports with_censored_mask).
      ◦ "single_layer": closed-form unnested formula (flat copulas only).
      ◦ "integral": numerical inverse-Laplace integration; the only backend that uses ils_method / ils_params.
  • survival (bool, default False): Treat marginals as survival functions S(t) = 1 - F(t) rather than CDFs. Required for the HACSurv-style joint survival likelihood.
  • with_censored_mask (bool, default False): If True, the returned ll_fn takes a third censored_mask argument so callers can vmap per-observation masks. Forces method="bell", since the other backends do not support dynamic censoring.
  • ils_method (str, default 'cohen'): Inverse-Laplace solver used only by method="integral", e.g. "cohen" (default), "euler", "post_widder". Ignored by "bell" and "single_layer", which assemble the density analytically and never invert a Laplace transform.
  • ils_params (Optional[dict], default None): Extra parameters for that solver; likewise used only by method="integral" and ignored by the other backends.
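For a flat (single-layer) copula the density has a standard closed form in terms of the generator. Differentiating \(C(u)=\psi\big(\sum_i \psi^{-1}(u_i)\big)\) once in each coordinate gives \(c(u)=\psi^{(d)}\big(\sum_i \psi^{-1}(u_i)\big)\,\prod_i (\psi^{-1})'(u_i)\). The sketch below evaluates that formula with nested jax.grad for the ψ-form Clayton generator (θ = 2, chosen for illustration) and checks it against mixed partials of the closed-form CDF. It illustrates the kind of formula the "single_layer" option describes; it is not the library's code, and nested jax.grad stands in for whatever derivative machinery the library actually uses.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

THETA = 2.0  # illustrative Clayton parameter


def psi(t):
    return (1.0 + t) ** (-1.0 / THETA)


def psi_inv(u):
    # Closed-form inverse of this psi.
    return u ** (-THETA) - 1.0


def nth_derivative(f, n):
    # Nested jax.grad: fine for small n, costly and cancellation-prone for large n.
    for _ in range(n):
        f = jax.grad(f)
    return f


def flat_log_density(u):
    # log c(u) = log[ psi^{(d)}(sum_i psi_inv(u_i)) * prod_i psi_inv'(u_i) ].
    d = u.shape[0]
    t = jnp.sum(psi_inv(u))
    psi_d = nth_derivative(psi, d)(t)
    dinv = jax.vmap(jax.grad(psi_inv))(u)
    # psi^{(d)} has sign (-1)^d and every psi_inv' < 0, so the product is positive.
    return jnp.log(psi_d * jnp.prod(dinv))
```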

Returns:

  • CompiledModel: A CompiledModel holding the frozen graph, flattener, and jit'd log-likelihood.

CompiledModel dataclass

Immutable compiled representation of a nested-copula model.

Built by compile_model. Holds the structural graph, the flattening function for parameter pytrees, and a jit'd log-likelihood evaluator. Reuse a single CompiledModel across parameter sweeps to avoid repeated XLA compilation; only the input values change, so the cache hits.

Attributes:

  • graph (Any): The frozen Node tree describing the copula composition. Read-only handle for downstream tooling (visualisation, structural diagnostics).
  • param_shape (Tuple[int, ...]): Expected flat parameter shape, e.g. (2,) for a two-parameter model. Useful for sanity checks before calling ll_fn with raw arrays.
  • with_censored_mask (bool): True iff this compilation supports a per-observation censoring mask as a third argument to ll_fn.
  • survival (bool): True iff this compilation interprets observations as survival data: the copula is applied to the marginal survival functions S(t) = 1 - F(t), so it models the joint survival function. Stored here so sample() can apply the matching F^{-1}(1 - U) inversion at the leaves and return data drawn from the same joint distribution that the likelihood scores.

Methods:

  • eval: Evaluate the log-likelihood at params for a single observation (or a leading-batched array).
  • sample: Generate n samples from the copula at params.
  • flatten: Convert a parameter pytree into the canonical flat array.
  • ll_fn: Direct handle on the jit'd log-likelihood.
  • visualize
  • as_networkx

eval

eval(obs: Array, params: Any, *, censored_mask=None) -> Array

Evaluate the log-likelihood at params for a single observation (or a leading-batched array).

Convenience wrapper that flattens params and dispatches to the underlying jit'd ll_fn.

sample

sample(key: Array, n: int, params: Any, *, post_widder_k: int = 8, max_cdf_x: float = 1000000.0, method: str = 'marshall_olkin') -> Array

Generate n samples from the copula at params.

Respects the survival flag set at compile time — when survival=True the leaf inversion uses F^{-1}(1 - U) so the samples are drawn from the same joint distribution that the likelihood scores.
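Marshall-Olkin sampling applies when ψ is the Laplace transform of a positive frailty V. It draws V, draws independent Exp(1) variables E_i, and sets U_i = ψ(E_i / V). The sketch below does this for the ψ-form Clayton generator, whose frailty is Gamma(1/θ, 1), then applies the survival-mode leaf inversion F^{-1}(1 - U) for Exp(rate) marginals. It illustrates the technique named by method='marshall_olkin' and is not the library's sampler; the function names, the exponential marginals, and θ = 2 are assumptions.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def sample_clayton_mo(key, n, d, theta):
    # psi(t) = (1 + t)^(-1/theta) is the Laplace transform of V ~ Gamma(1/theta, rate 1).
    k_v, k_e = jax.random.split(key)
    v = jax.random.gamma(k_v, 1.0 / theta, shape=(n, 1))
    e = jax.random.exponential(k_e, shape=(n, d))
    return (1.0 + e / v) ** (-1.0 / theta)  # U_i = psi(E_i / V), uniform margins


def survival_leaf_inverse(u, rate):
    # survival=True: T = F^{-1}(1 - U); for Exp(rate), F^{-1}(p) = -log(1 - p) / rate.
    return -jnp.log(u) / rate


def kendall_tau(x, y):
    # O(n^2) sample Kendall's tau over all ordered pairs (no ties assumed).
    sx = jnp.sign(x[:, None] - x[None, :])
    sy = jnp.sign(y[:, None] - y[None, :])
    n = x.shape[0]
    return jnp.sum(sx * sy) / (n * (n - 1))


u = sample_clayton_mo(jax.random.PRNGKey(0), 2000, 2, theta=2.0)
t = survival_leaf_inverse(u, rate=0.5)
```

For Clayton, Kendall's tau is θ/(θ + 2), so the sample estimate should sit near 0.5. Because the leaf inversion is strictly monotone in each coordinate, it leaves tau unchanged.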

flatten

flatten(params: Any) -> Array

Convert a parameter pytree into the canonical flat array.
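One standard JAX way to obtain such a canonical flat array is jax.flatten_util.ravel_pytree, which orders leaves by pytree traversal (sorted keys for dicts) and also returns an inverse function. It is shown here only as an analogy. The library's own flatten may order or transform parameters differently, and the parameter names below are made up.

```python
from jax.flatten_util import ravel_pytree

# Hypothetical two-parameter nested model.
params = {"theta_outer": 1.0, "theta_inner": 3.0}

# flat has shape (2,), the role param_shape plays for a CompiledModel;
# unravel maps a flat array back to the original dict structure.
flat, unravel = ravel_pytree(params)
```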

ll_fn

ll_fn(*args, **kwargs)

Direct handle on the jit'd log-likelihood.

Its signature matches the compilation settings:

  • with_censored_mask=False: ll_fn(obs, flat_params)
  • with_censored_mask=True: ll_fn(obs, flat_params, censored_mask)

Use this when you want raw control for jax.value_and_grad, jax.hessian, jax.vmap, etc. Prefer eval for one-shot calls with a parameter pytree.
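Here is a sketch of the raw-control pattern, using a hypothetical toy_ll_fn with the with_censored_mask=True signature. jax.vmap maps it over observations and masks while sharing the flat parameters (in_axes=(0, None, 0)), and jax.value_and_grad differentiates the summed log-likelihood. The toy body is an assumption for illustration, not the library's evaluator: a bivariate Clayton on survival-scale pseudo-observations that picks the partial derivative of C matching the mask with jnp.where.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def clayton_cdf(u1, u2, theta):
    return (u1 ** -theta + u2 ** -theta - 1.0) ** (-1.0 / theta)


def toy_ll_fn(obs, flat_params, censored_mask):
    # Hypothetical stand-in for ll_fn(obs, flat_params, censored_mask).
    theta = flat_params[0]
    u1, u2 = obs[0], obs[1]
    d1 = jax.grad(clayton_cdf, argnums=0)(u1, u2, theta)  # only u2 censored
    d2 = jax.grad(clayton_cdf, argnums=1)(u1, u2, theta)  # only u1 censored
    d12 = jax.grad(jax.grad(clayton_cdf, argnums=0), argnums=1)(u1, u2, theta)
    c = clayton_cdf(u1, u2, theta)                        # both censored
    # Dynamic censoring: select with jnp.where so the mask can be traced and vmapped.
    dens = jnp.where(censored_mask[0],
                     jnp.where(censored_mask[1], c, d2),
                     jnp.where(censored_mask[1], d1, d12))
    return jnp.log(dens)


def total_ll(flat_params, obs, mask):
    per_obs = jax.vmap(toy_ll_fn, in_axes=(0, None, 0))(obs, flat_params, mask)
    return jnp.sum(per_obs)


value_and_grad_fn = jax.jit(jax.value_and_grad(total_ll))
```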

visualize

visualize(*, include_leaves: bool = True, layout: str = 'hierarchical', with_labels: bool = True, graphviz_prog: str = 'dot', annotations=None, ax=None)

as_networkx

as_networkx(include_leaves: bool = True, annotations=None)

Composition registry

For nested copulas that mix families, register a closed-form composition \(\psi_{\text{outer}}^{-1} \circ \psi_{\text{inner}}\), or choose the fallback strategy used when none is registered.

register_composition

register_composition(outer_cls: type, inner_cls: type, fn: Callable)

Register a stable closed-form composition h(t) = psi_outer_inv(psi_inner(t)).

Parameters:

  • outer_cls (type, required): The outer (parent) copula class.
  • inner_cls (type, required): The inner (child) copula class.
  • fn (Callable, required): A JAX-traceable function fn(t, parent_cop, child_cop) -> scalar.
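As an illustration of what fn might look like, consider two Clayton nodes in the ψ-form used on this page, \(\psi_\theta(t)=(1+t)^{-1/\theta}\). Their composition has the closed form \(h(t)=(1+t)^{\theta_{\text{outer}}/\theta_{\text{inner}}}-1\), which can be evaluated without ever forming \(\psi_{\text{inner}}(t)\). The Clayton dataclass below is a hypothetical stand-in for a @copula family exposing a theta attribute. With the library, the function would be passed as register_composition(Clayton, Clayton, clayton_clayton).

```python
import dataclasses

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


@dataclasses.dataclass(frozen=True)
class Clayton:
    # Stand-in family: psi(t) = (1 + t)^(-1/theta), psi^{-1}(u) = u^(-theta) - 1.
    theta: float

    def generator(self, t):
        return (1.0 + t) ** (-1.0 / self.theta)

    def generator_inv(self, u):
        return u ** (-self.theta) - 1.0


def clayton_clayton(t, parent_cop, child_cop):
    # h(t) = psi_outer^{-1}(psi_inner(t)) = (1 + t)^(theta_outer / theta_inner) - 1,
    # evaluated in log space so psi_inner(t), which can underflow, is never formed.
    return jnp.expm1(parent_cop.theta / child_cop.theta * jnp.log1p(t))
```

For same-family Clayton nesting, the usual sufficient nesting condition \(\theta_{\text{outer}} \le \theta_{\text{inner}}\) is exactly what makes \(h'(t) \propto (1+t)^{\theta_{\text{outer}}/\theta_{\text{inner}}-1}\) completely monotone.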

set_composition_fallback

set_composition_fallback(mode: str)

Set the fallback method used when no registered composition exists.

Parameters:

  • mode (str, required): "implicit" (O(d_c³), stable; the default fallback) or "direct_jet" (O(d_c²), may overflow at extreme values).

Configuration

set_stable_log

set_stable_log(enabled: bool) -> None

Enable higher-order-stable log for the Bell polynomial density.

When enabled=True, _poly_power_alpha routes its three jnp.log calls through acopula._stable_log.safe_log, which uses a custom JVP rule to keep jax.hessian / jacfwd(jacrev) finite for Taylor coefficients that reach ~1e-200 (e.g. nested Clayton at Kaplan-Meier pseudo-observations near the marginal floor). Forward value and first derivative are unchanged.

Default is False. Call once at startup, before constructing the model.
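One way to build a log whose value and first derivative match jnp.log while its second derivative stays finite is a pair of jax.custom_jvp functions. The log's JVP rule routes 1/x through a reciprocal whose own derivative is capped. This sketches the general idea only: the actual rule in acopula._stable_log.safe_log is not shown on this page and may differ, and the cap value 1e300 is an arbitrary assumption.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


@jax.custom_jvp
def _recip(x):
    return 1.0 / x


@_recip.defjvp
def _recip_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = 1.0 / x
    # The true derivative -1/x^2 overflows to -inf once |x| < ~1e-154;
    # cap its magnitude at a large finite value instead.
    return y, -jnp.minimum(y * y, 1e300) * x_dot


@jax.custom_jvp
def safe_log(x):
    return jnp.log(x)


@safe_log.defjvp
def _safe_log_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # Same value and first derivative as jnp.log; only higher orders differ,
    # because differentiating this rule goes through _recip's capped JVP.
    return jnp.log(x), x_dot * _recip(x)
```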

set_compile_cache_dir

set_compile_cache_dir(path: str | PathLike, *, min_compile_time_seconds: float = 1.0) -> str

Enable JAX persistent compile cache at path.

Creates the directory if it does not exist. Sets the three JAX config knobs required for the cache to actually capture the slow compiles:

  • jax_compilation_cache_dir: the cache location.
  • jax_persistent_cache_min_entry_size_bytes: 0, so entries of every size are cached.
  • jax_persistent_cache_min_compile_time_secs: only compiles slower than this threshold get written.

The default threshold of 1 second keeps the cache small while still capturing the Frank/MIMIC compiles we care about.

Must be called before JAX compiles anything in your program, that is, before the first jax.jit or jax.pmap call (or any other JAX computation) runs; call it at program start-up, right after importing jax.

Returns:

  • str: The absolute cache directory path (useful for logging).
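For reference, the three knobs listed above correspond to plain jax.config.update calls. The hypothetical helper below mirrors the described behaviour (create the directory, set the knobs, return the absolute path) only to show what those settings are. In practice, call set_compile_cache_dir itself.

```python
import os
import tempfile

import jax


def enable_compile_cache(path, min_compile_time_seconds=1.0):
    # Hypothetical mirror of set_compile_cache_dir, for illustration only.
    path = os.path.abspath(os.fspath(path))
    os.makedirs(path, exist_ok=True)
    jax.config.update("jax_compilation_cache_dir", path)
    jax.config.update("jax_persistent_cache_min_entry_size_bytes", 0)
    jax.config.update("jax_persistent_cache_min_compile_time_secs", min_compile_time_seconds)
    return path


cache_dir = enable_compile_cache(os.path.join(tempfile.mkdtemp(), "jax_cache"))
```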

Advanced: custom generator precision

Most families never touch these — they are optional override hooks on Copula for families whose high-order derivatives suffer float64 cancellation (e.g. AMH past d≈30). Each returns None by default, which selects Taylor-mode AD (jet). generator_taylor_coefficients (and its log-space variant log_generator_taylor_coefficients) is the recommended, general hook, used by the Bell density and the nesting composition. log_generator_kth_derivative is a legacy hook consulted only by the flat single-layer likelihood path.

generator_inv

generator_inv(t: Array) -> Array

generator_taylor_coefficients

generator_taylor_coefficients(t: Array, k_max: int) -> Array

Return Taylor coefficients \([\psi(t),\, \psi'(t)/1!,\, \ldots,\, \psi^{(k_{\max})}(t)/k_{\max}!]\).

Override this when the closed-form \(\psi\) has a numerically stable series representation that avoids the cancellations Taylor-mode AD on the closed form would produce at high derivative order. For example, AMH's \(\psi\) is the Laplace transform of \(\mathrm{Geometric}(1-\theta)\), so all Taylor coefficients are sums of same-sign terms (no cancellation):

\[\psi^{(k)}(t) = (-1)^k \sum_{x=1}^{\infty} x^k\,(1-\theta)\,\theta^{x-1}\,e^{-tx}.\]

The framework consults this hook everywhere it would otherwise call jet_array.jet on self.generator to get a Taylor expansion: the root density (bell._root_assembly) and the nesting composition (compose.compute_composition_taylor). Returning None (the default) tells those call sites to fall back to the jet path, which is correct but loses precision when \(\psi\)'s derivatives have large alternating-sign cancellations (e.g. AMH past \(d \approx 30\) in float64).

Parameters:

  • t (Array, required): Scalar input point.
  • k_max (int, required): Highest derivative order to compute (inclusive); the returned array has shape (k_max + 1,).

Returns:

  • Array: JAX array of shape (k_max + 1,) containing the Taylor coefficients, or None to fall back to jet.
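Here is a sketch of the core computation such an override might perform for AMH, under stated assumptions. The series above is truncated at a fixed number of terms (n_terms=400, adequate only when \(\theta e^{-t}\) is comfortably below 1), and every term is positive, so the sum has no cancellation. The function names are hypothetical, and the code is written as a free function rather than a Copula method. The check compares it against nested jax.grad of the closed-form AMH generator \(\psi(t)=(1-\theta)/(e^t-\theta)\) at low order, where that is still accurate.

```python
import math

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def amh_psi(t, theta):
    # Closed-form AMH generator in psi-form (used here only as a reference).
    return (1.0 - theta) / (jnp.exp(t) - theta)


def amh_taylor_coefficients(t, theta, k_max, n_terms=400):
    # c_k = psi^{(k)}(t) / k! from the truncated Geometric(1 - theta) series.
    # Assumes 0 < theta < 1 and k_max <= 170 (float factorials).
    x = jnp.arange(1, n_terms + 1, dtype=jnp.float64)
    k = jnp.arange(k_max + 1, dtype=jnp.float64)
    weights = (1.0 - theta) * theta ** (x - 1.0) * jnp.exp(-t * x)  # all positive
    moments = jnp.sum(x[None, :] ** k[:, None] * weights[None, :], axis=1)
    factorials = jnp.array([float(math.factorial(i)) for i in range(k_max + 1)])
    signs = jnp.where(k % 2 == 0, 1.0, -1.0)  # (-1)^k
    return signs * moments / factorials
```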

log_generator_taylor_coefficients

log_generator_taylor_coefficients(t: Array, k_max: int) -> Optional[Tuple[Array, Array]]

Log-space variant of generator_taylor_coefficients.

Return (sign, log_abs), two arrays of shape (k_max + 1,) with

\[\psi^{(k)}(t) / k! \;=\; \mathrm{sign}[k]\,\exp(\mathrm{log\_abs}[k]).\]

The Bell density consumes the Taylor coefficients internally as (sign, log|c|) pairs. A family whose normalized coefficients are computed in the log domain (e.g. via logsumexp over a frailty series) should therefore return them here directly rather than through the linear generator_taylor_coefficients. This avoids exponentiating to a possibly overflowing float64 only for the framework to take the log again, which extends the usable derivative order at high d.

The Bell path prefers this hook when it returns non-None, then falls back to the linear hook, then to jet. Returning None (the default) selects that fallback chain.

Parameters:

  • t (Array, required): Scalar input point.
  • k_max (int, required): Highest derivative order (inclusive).

Returns:

  • Optional[Tuple[Array, Array]]: (sign, log_abs) arrays of shape (k_max + 1,), or None.
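A log-domain counterpart can be sketched as another hypothetical free function for AMH, under the same truncated-series assumption. Here n_terms=6000, so the series is still captured for θ = 0.9 near t = 0. Using logsumexp keeps every log|c_k| finite even where the coefficient itself would overflow float64, which is the situation this hook exists for.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln, logsumexp

jax.config.update("jax_enable_x64", True)


def amh_log_taylor_coefficients(t, theta, k_max, n_terms=6000):
    # (sign, log_abs) with psi^{(k)}(t) / k! = sign[k] * exp(log_abs[k]),
    # from the truncated Geometric(1 - theta) series; assumes 0 < theta < 1.
    x = jnp.arange(1, n_terms + 1, dtype=jnp.float64)
    k = jnp.arange(k_max + 1, dtype=jnp.float64)
    log_w = jnp.log1p(-theta) + (x - 1.0) * jnp.log(theta) - t * x
    log_terms = k[:, None] * jnp.log(x)[None, :] + log_w[None, :]
    log_abs = logsumexp(log_terms, axis=1) - gammaln(k + 1.0)
    sign = jnp.where(k % 2 == 0, 1.0, -1.0)
    return sign, log_abs
```

At θ = 0.9 and t = 0.01, log|c_340| from this sketch lies above 709.8, the log of float64's largest value. The linear hook could not return that coefficient, while the (sign, log_abs) pair still represents it.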

log_generator_kth_derivative

log_generator_kth_derivative(t: Array, k: int) -> Array

Return \(\log|\psi^{(k)}(t)|\), the log absolute \(k\)-th derivative.

Override this with a closed-form expression to bypass jet for specific families. For example, Clayton can use

\[\log\Gamma(k + 1/\theta) - \log\Gamma(1/\theta) - (k + 1/\theta)\,\log(1+t).\]

The default returns None, which tells the framework to use Taylor-mode AD via jet.
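A quick way to check such an override is to compare it with nested jax.grad at low order. The sketch writes the Clayton expression above as a free function for the ψ-form generator \((1+t)^{-1/\theta}\) with θ fixed at 2. The names are hypothetical, and a real override would be a method on the family reading self.theta.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln

jax.config.update("jax_enable_x64", True)

THETA = 2.0  # illustrative Clayton parameter


def clayton_psi(t):
    return (1.0 + t) ** (-1.0 / THETA)


def clayton_log_kth_derivative(t, k):
    # log|psi^{(k)}(t)| = lgamma(k + 1/theta) - lgamma(1/theta) - (k + 1/theta) log(1 + t)
    a = 1.0 / THETA
    return gammaln(k + a) - gammaln(a) - (k + a) * jnp.log1p(t)
```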