Censored survival likelihood
A core feature for survival analysis is that leaves can be right-censored. A censored dimension still enters the copula argument (through its survival function), but it is not differentiated in the density: acopula assembles the correct mixed partial derivative over the observed dimensions only.
Here we use a bivariate Frank copula over two Weibull event times, where the second time is censored.
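For reference, the standard bivariate derivation behind this example is shown below. It assumes that the copula couples the survival functions, $P(T_1 > t_1, T_2 > t_2) = C(S_1(t_1), S_2(t_2))$, which is the convention survival=True points to. It is a sketch of textbook reasoning, not a transcription of acopula's internals. Using $S_i' = -f_i$:

$$
\begin{aligned}
L_{\text{obs}}(t_1, t_2) &= \frac{\partial^2}{\partial t_1\,\partial t_2}\, C\big(S_1(t_1), S_2(t_2)\big)
= c\big(S_1(t_1), S_2(t_2)\big)\, f_1(t_1)\, f_2(t_2), \\
L_{\text{cens}}(t_1, t_2) &= -\frac{\partial}{\partial t_1}\, C\big(S_1(t_1), S_2(t_2)\big)
= \frac{\partial C}{\partial u_1}\big(S_1(t_1), S_2(t_2)\big)\, f_1(t_1),
\end{aligned}
$$

Here $c = \partial^2 C / \partial u_1\,\partial u_2$. The censored dimension keeps $S_2(t_2)$ inside $C$, but it contributes neither an $f_2$ factor nor a derivative in $u_2$.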
```python
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
from acopula import compile_model, copula, marginal

@copula
class Frank:
    theta: float

    def generator(self, t):
        # Frank generator psi(t) = -log(1 + (e^{-theta} - 1) e^{-t}) / theta
        return -jnp.log1p(jnp.expm1(-self.theta) * jnp.exp(-t)) / self.theta
```
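The Frank class only supplies the generator $\psi(t) = -\log\big(1 + (e^{-\theta} - 1)e^{-t}\big)/\theta$. An Archimedean copula is then $C(u_1, \dots, u_d) = \psi\big(\sum_i \psi^{-1}(u_i)\big)$. The sketch below checks that this construction reproduces the textbook bivariate Frank CDF. The closed-form inverse psi_inv is derived by hand and is an assumption of this sketch, not something taken from acopula, which presumably handles the inversion itself. All function names here are illustrative.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64, matching acopula's setting

THETA = 4.0

def psi(t, theta):
    # Same formula as Frank.generator: psi(0) = 1 and psi(t) -> 0 as t -> inf
    return -jnp.log1p(jnp.expm1(-theta) * jnp.exp(-t)) / theta

def psi_inv(u, theta):
    # Hand-derived inverse: solve psi(t) = u for t
    return -jnp.log(jnp.expm1(-theta * u) / jnp.expm1(-theta))

def archimedean_cdf(us, theta):
    # C(u_1, ..., u_d) = psi(sum_i psi^{-1}(u_i))
    return psi(jnp.sum(psi_inv(us, theta)), theta)

def frank_cdf_closed_form(u, v, theta):
    # Textbook bivariate Frank copula CDF
    return -jnp.log1p(
        jnp.expm1(-theta * u) * jnp.expm1(-theta * v) / jnp.expm1(-theta)
    ) / theta

u, v = 0.3, 0.7
c_gen = float(archimedean_cdf(jnp.array([u, v]), THETA))
c_ref = float(frank_cdf_closed_form(u, v, THETA))
print(f"generator construction: {c_gen:.10f}, closed form: {c_ref:.10f}")
```

Because the Frank generator has a closed-form inverse, the two routes agree to floating-point precision.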
Build the model two ways
We build the same Frank-over-Weibull model twice; the only difference is whether the second leaf is marked censored. Passing survival=True makes leaf inversion use the survival convention $F^{-1}(1-u)$. acopula enables jax_enable_x64 at import, so the distribution parameters are float64.
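For a Weibull with concentration $k$ and scale $\lambda$, $S(t) = \exp(-(t/\lambda)^k)$. The survival convention $F^{-1}(1-u)$ is therefore exactly $S^{-1}(u)$: a copula coordinate $u$ is read as a survival probability. Below is a minimal sketch. The Weibull formulas are written out by hand instead of calling TFP, on the assumption that they match the parametrization of tfp.distributions.Weibull. The helper names are illustrative.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

K, LAM = 1.5, 1.0  # concentration k and scale lambda, as in make_model

def weibull_cdf(t):
    # F(t) = 1 - exp(-(t / lambda)^k)
    return -jnp.expm1(-(t / LAM) ** K)

def weibull_survival(t):
    # S(t) = 1 - F(t) = exp(-(t / lambda)^k)
    return jnp.exp(-(t / LAM) ** K)

def weibull_quantile(p):
    # F^{-1}(p) = lambda * (-log(1 - p))^(1/k)
    return LAM * (-jnp.log1p(-p)) ** (1.0 / K)

def survival_inverse(u):
    # Survival convention: t = F^{-1}(1 - u), the time whose survival probability is u
    return weibull_quantile(1.0 - u)

times = jnp.array([0.6, 1.2])
u = weibull_survival(times)        # copula coordinates under the survival convention
recovered = survival_inverse(u)    # should give back the original times
print("u =", u, "recovered times =", recovered)
```

Larger times map to smaller $u$ here, since survival probabilities decrease with time.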
```python
def make_model(censored_second: bool):
    # Both leaves share one Weibull marginal with float64 parameters
    weib = tfp.distributions.Weibull(
        concentration=jnp.float64(1.5), scale=jnp.float64(1.0))

    def model(theta, t):
        c = Frank(theta[0])
        return c([
            marginal(weib, obs=t[0]),
            # Only this leaf's censoring flag differs between the two models
            marginal(weib, obs=t[1], censored=censored_second),
        ])

    return model

theta = jnp.array([4.0])
times = jnp.array([0.6, 1.2])  # the second time is the (possibly) censored one

cm_obs = compile_model(make_model(False), template=theta, survival=True)
cm_cens = compile_model(make_model(True), template=theta, survival=True)
```
Score the same data both ways
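With the second event observed, we get the full joint density. With it right-censored, we get a mixed term instead: a density in the first time but only a survival probability in the second, because the copula is differentiated with respect to the first coordinate only. For these data the censored log-likelihood is the lower of the two.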
```python
print(f"log-likelihood, both observed: {cm_obs.eval(times, theta):.6f}")
print(f"log-likelihood, 2nd right-censored: {cm_cens.eval(times, theta):.6f}")
```

```
log-likelihood, both observed: -1.398859
log-likelihood, 2nd right-censored: -2.268904
```
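To make the two numbers above concrete, the sketch below recomputes both log-likelihoods by hand in plain JAX, without acopula. It assumes the survival-copula convention $u_i = S_i(t_i)$ and the closed-form Frank CDF. It builds the copula terms with jax.grad: one derivative in $u_1$ when the second time is censored, and the mixed partial in $u_1$ and $u_2$ when both times are observed. Under those assumptions it should reproduce the printed values up to rounding. The helper names are illustrative.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

K, LAM = 1.5, 1.0  # Weibull concentration and scale, as in make_model

def frank_cdf(u, v, theta):
    # Bivariate Frank copula CDF C(u, v)
    return -jnp.log1p(
        jnp.expm1(-theta * u) * jnp.expm1(-theta * v) / jnp.expm1(-theta)
    ) / theta

def weibull_logpdf(t):
    # log f(t) = log(k / lambda) + (k - 1) log(t / lambda) - (t / lambda)^k
    return jnp.log(K / LAM) + (K - 1.0) * jnp.log(t / LAM) - (t / LAM) ** K

def weibull_survival(t):
    return jnp.exp(-(t / LAM) ** K)

def loglik(times, theta, censored_second):
    t1, t2 = times[0], times[1]
    u1, u2 = weibull_survival(t1), weibull_survival(t2)  # survival-copula coordinates
    dC_du1 = jax.grad(frank_cdf, argnums=0)
    if censored_second:
        # Differentiate in the observed coordinate only; u2 stays a probability
        copula_term = dC_du1(u1, u2, theta)
        marginal_term = weibull_logpdf(t1)
    else:
        # Mixed partial d^2 C / du1 du2 is the copula density c(u1, u2)
        copula_term = jax.grad(dC_du1, argnums=1)(u1, u2, theta)
        marginal_term = weibull_logpdf(t1) + weibull_logpdf(t2)
    return jnp.log(copula_term) + marginal_term

times = jnp.array([0.6, 1.2])
theta = 4.0
ll_obs = float(loglik(times, theta, censored_second=False))
ll_cens = float(loglik(times, theta, censored_second=True))
print(f"both observed: {ll_obs:.6f}, 2nd right-censored: {ll_cens:.6f}")
```

Nesting jax.grad in this way is one simple way to obtain "the mixed partial over the observed dimensions". It differentiates once per observed coordinate and leaves censored coordinates untouched.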
The dependence parameter is still learnable through the censored contribution, because gradients with respect to theta flow through it:
```python
g = jax.grad(cm_cens.eval, argnums=1)(times, theta)
print(f"d(ll_censored)/d(theta): {g[0]:.6f}")
```

```
d(ll_censored)/d(theta): -0.224742
```
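A quick sanity check on that gradient is to compare autodiff against a central finite difference on a hand-written version of the censored log-likelihood. This sketch again assumes $u_i = S_i(t_i)$ and the closed-form Frank CDF, and it does not call acopula. The function names are illustrative. The Weibull marginals do not depend on $\theta$, so the whole gradient comes from the $\partial C / \partial u_1$ term.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

K, LAM = 1.5, 1.0  # Weibull concentration and scale, as in make_model

def frank_cdf(u, v, theta):
    # Bivariate Frank copula CDF C(u, v)
    return -jnp.log1p(
        jnp.expm1(-theta * u) * jnp.expm1(-theta * v) / jnp.expm1(-theta)
    ) / theta

def censored_loglik(theta, times):
    # log[f_1(t_1) * dC/du_1(S_1(t_1), S_2(t_2))]; the second time is right-censored
    t1, t2 = times[0], times[1]
    u1 = jnp.exp(-(t1 / LAM) ** K)
    u2 = jnp.exp(-(t2 / LAM) ** K)
    log_f1 = jnp.log(K / LAM) + (K - 1.0) * jnp.log(t1 / LAM) - (t1 / LAM) ** K
    dC_du1 = jax.grad(frank_cdf, argnums=0)(u1, u2, theta)
    return jnp.log(dC_du1) + log_f1

times = jnp.array([0.6, 1.2])
theta = 4.0
g = jax.grad(censored_loglik)(theta, times)  # autodiff through the nested grad
h = 1e-5
fd = (censored_loglik(theta + h, times) - censored_loglik(theta - h, times)) / (2 * h)
print(f"autodiff: {float(g):.6f}, finite difference: {float(fd):.6f}")
```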