After a discussion with Yeon-Koo Che and a great presentation from Kohei Onzo, I got curious about Bayesian econometrics, which is not typically taught in an empirical IO class. In this first notebook, I will try to improve my understanding of Bayesian methods by working through the simplest example of inference I know: inferring a mean from a dataset of $0$s and $1$s.
This notebook has benefited from the editing and coding skills of my AI-OpenClaw assistant, Hope.
I use Google JAX as my preferred workflow when writing Python code. JAX compiles array code just-in-time and can run it on accelerators such as the GPU, which can lead to massive performance gains. The coding style changes a bit relative to standard Python, but this should not disturb the reader too much. Check https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html to learn about Google JAX, or https://jax.quantecon.org/intro.html for the economics crowd.
import jax
import jax.numpy as jnp
import blackjax #Bayesian samplers
import optax #optimization library for JAX
import matplotlib.pyplot as plt
import seaborn as sns
N = 1000
key = jax.random.key(123)
X = jax.random.binomial(
key,
n=1,
p=0.65,
shape=(N,)
)
### Take a look at our data sample:
print(X[:10])
[1. 1. 1. 1. 1. 1. 1. 1. 0. 0.]
Warm-Up: the frequentist approach
For the data-generating process: I have chosen a true parameter $p = \text{Pr}(X_i = 1) = 0.65$.
Of course, the maximum likelihood estimate is available in closed form in this case. The model is: \begin{align*} X_i \sim \text{Bernoulli}(p) \end{align*}
The likelihood and log-likelihood are: \begin{align*} \mathcal{L}(p|X) &= \prod_{i=1}^N p^{X_i} (1-p)^{1-X_i}\\ &= p^{\sum_i X_i} \cdot (1-p)^{N - \sum_i X_i}\\ \log\left(\mathcal{L}(p|X) \right) &= \left(\sum_i X_i\right) \log(p) + \left(N - \sum_i X_i\right) \log(1-p) \end{align*}
The MLE is obtained by solving: \begin{align*} \widehat{p}^{MLE} = \arg\max_{p} \log\left(\mathcal{L}(p|X) \right) \end{align*}
The first-order condition yields, unsurprisingly: \begin{align*} \widehat{p}^{MLE} &= \frac{\sum_i X_i}{N} \end{align*}
Then, the standard error is obtained as follows: \begin{align*} \text{Var}[\widehat{p}^{MLE}] &= \frac{\widehat{p}(1-\widehat{p})}{N}\\ \text{SE}[\widehat{p}^{MLE}] &= \sqrt{\frac{\widehat{p}(1-\widehat{p})}{N}} \end{align*}
p_hat_analytical = jnp.mean(X)
SE_p_hat_analytical = jnp.sqrt((p_hat_analytical*(1-p_hat_analytical))/N)
print(f'p MLE: {p_hat_analytical}')
print(f'p SE: {SE_p_hat_analytical}')
p MLE: 0.6530000567436218
p SE: 0.01505293883383274
| $\widehat{p}$ |
|---|
| 0.653 |
| (0.015) |
But let's pretend that the closed form is not available to us.
@jax.jit  # just-in-time (JIT) compile the function below, which speeds it up significantly
def neg_log_likelihood(p, X):
    N = X.shape[0]
    return -(jnp.sum(X)*jnp.log(p) + (N - jnp.sum(X))*jnp.log(1-p))
neg_log_likelihood(0.65, X)
Array(645.58954, dtype=float32)
A first difficulty comes from the fact that standard optimizers do not work well with domain constraints (like, in our case, $p \in [0, 1]$). So we optimize over an unconstrained transformation of the parameter, and that will require using the delta method to recover standard errors.
@jax.jit
def neg_log_likelihood_transform(theta, X):
    p = jax.nn.sigmoid(theta)  # maps the real line to (0, 1)
    return neg_log_likelihood(p, X)
plt.plot(jnp.linspace(-5, 5, 1000), neg_log_likelihood_transform(jnp.linspace(-5.0, 5.0, 1000), X))
plt.show()
def optimize_adam(
    f,
    init_params,
    args=(),
    kwargs=None,
    learning_rate=1e-3,
    max_steps=1_000,
    grad_tol=1e-6,
    return_history=True,
):
    """A simple Adam optimizer with gradient-norm stopping."""
    if kwargs is None:
        kwargs = {}
    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(init_params)
    params = init_params
    history = []
    loss_and_grad = jax.value_and_grad(lambda p: f(p, *args, **kwargs))
    for step in range(max_steps):
        loss, grads = loss_and_grad(params)
        grad_norm = jnp.linalg.norm(grads)
        if return_history:
            history.append((loss, grad_norm))
        if grad_norm <= grad_tol:
            break
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
    if return_history:
        return params, jnp.array(history)
    return params
init_theta = jnp.array(0.0)
theta_hat, history = optimize_adam(
neg_log_likelihood_transform,
init_theta,
args=(X,),
learning_rate=0.05,
max_steps=100,
)
p_hat = jax.nn.sigmoid(theta_hat)
print(f'p MLE obtained by numerical optimization: {p_hat}')
p MLE obtained by numerical optimization: 0.6533495187759399
The delta method:
Let $\widehat{\theta}$ denote the maximizer of the transformed likelihood. In large samples, approximately: \begin{align*} \widehat{\theta} \sim \mathcal{N}(\theta, \text{Var}[\widehat{\theta}]) \end{align*}
We are interested in the distribution of $\widehat{p} = g(\widehat{\theta}) = \frac{1}{1 + e^{-\widehat{\theta}}}$. A first-order Taylor expansion around $\theta$ gives: \begin{align*} g(\widehat{\theta}) &\approx g(\theta) + g'(\theta) \cdot (\widehat{\theta} - \theta)\\ \Rightarrow g(\widehat{\theta}) &\sim \mathcal{N}\left(g(\theta), [g'(\theta)]^2 \text{Var}[\widehat{\theta}] \right) \end{align*}
So: \begin{align*} \widehat{\text{Var}}\left[ g(\widehat{\theta}) \right] &= [g'(\widehat{\theta})]^2 \widehat{\text{Var}}[\widehat{\theta}] \end{align*}
In the present case: \begin{align*} p = g(\theta) &= \frac{1}{1+e^{-\theta}}\\ g'(\theta) &= p(1-p) \end{align*}
So: \begin{align*} \text{SE}(\widehat{p}) &= \sqrt{\widehat{\text{Var}}\left[ g(\widehat{\theta}) \right]}\\ &= g'(\widehat{\theta}) \sqrt{\widehat{\text{Var}}[\widehat{\theta}]}\\ &= \widehat{p}(1-\widehat{p}) \cdot \text{SE}(\widehat{\theta}) \end{align*}
The only ingredient we still need is $\text{SE}(\widehat{\theta})$. In our simple scalar case, it is the square root of the inverse observed information, that is, of the inverse second derivative of the negative log-likelihood evaluated at $\widehat{\theta}$: \begin{align*} \text{SE}(\widehat{\theta}) &= \sqrt{\widehat{\text{Var}}(\widehat{\theta})}\\ &= \sqrt{\Bigg(-\frac{d^2}{d\theta^2}\log \mathcal{L}(\theta; X)\Big |_{\theta = \widehat{\theta}}\Bigg)^{-1}} \end{align*}
# Hessian of the negative log-likelihood with respect to theta, evaluated at theta_hat
hess_theta = jax.hessian(
lambda theta: neg_log_likelihood_transform(theta, X)
)(theta_hat)
# Variance of theta_hat is inverse observed information
var_theta = 1.0 / hess_theta
se_theta = jnp.sqrt(var_theta)
# Delta method: p = sigmoid(theta)
se_p = p_hat*(1-p_hat)*se_theta
print(f"SE(p_hat): {se_p}")
SE(p_hat): 0.015049383044242859
This is very close to our earlier analytical result. It is heavy work to infer a mean, but it is a correct skeleton for any form of numerical optimization in econometrics using the frequentist approach. Now let us turn to a Bayesian approach.
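To see why the two standard errors agree, here is a short derivation, offered as a sketch. It writes $S = \sum_i X_i$ and $\sigma(\theta) = \frac{1}{1+e^{-\theta}}$, a shorthand introduced here for brevity. In the logit parameterization, the observed information is available in closed form: \begin{align*} \log \mathcal{L}(\theta; X) &= S \log \sigma(\theta) + (N - S)\log(1-\sigma(\theta))\\ \frac{d}{d\theta}\log \mathcal{L}(\theta; X) &= S - N\sigma(\theta)\\ -\frac{d^2}{d\theta^2}\log \mathcal{L}(\theta; X) &= N\sigma(\theta)(1-\sigma(\theta)) \end{align*}
Evaluating at $\widehat{\theta}$ and applying the delta method: \begin{align*} \widehat{\text{Var}}(\widehat{\theta}) &= \frac{1}{N\widehat{p}(1-\widehat{p})}\\ \text{SE}(\widehat{p}) &= \widehat{p}(1-\widehat{p}) \cdot \frac{1}{\sqrt{N\widehat{p}(1-\widehat{p})}} = \sqrt{\frac{\widehat{p}(1-\widehat{p})}{N}} \end{align*}
This is exactly the analytical formula. The small gap between the two printed standard errors therefore comes only from the numerically optimized $\widehat{p}$ (about 0.6533) differing slightly from the sample mean (0.653).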
Bayesian inference
Again, this example is so simple that we can work it out analytically. In Bayesian inference, we treat $p$ as a random object. Since $p$ lies in $[0, 1]$, a flexible and convenient choice of prior is the $\text{Beta}(\alpha, \beta)$ distribution, which is conjugate to the Bernoulli likelihood:
\begin{align*} p \sim \text{Beta}(\alpha, \beta) \end{align*}
This means that the distribution of $p$ has density: \begin{align*} \pi(p) &= \frac{\Gamma(\alpha + \beta)}{\Gamma(\alpha)\Gamma(\beta)} p^{\alpha-1}(1-p)^{\beta-1} \end{align*}
Our goal is to compute the posterior distribution $\pi(p|X)$, which incorporates the information carried by the data. The posterior is obtained through Bayes' rule: \begin{align*} \pi(p|X) &= \frac{\mathcal{L}(p|X)\pi(p)}{f(X)}\\ &\propto \mathcal{L}(p|X)\pi(p) \end{align*}
Here, $\mathcal{L}$ is the likelihood function we derived in the frequentist segment of this notebook, $\pi(p)$ is the prior, and $f(X)$ is the marginal likelihood of the data, which does not depend on $p$ and can therefore be ignored. In this super-simple case, we can actually obtain a closed form for the posterior: \begin{align*} \pi(p|X) &\propto p^{\sum_i X_i}(1-p)^{N - \sum_i X_i} p^{\alpha-1}(1-p)^{\beta-1}\\ &= p^{\sum_i X_i + \alpha-1}(1-p)^{N - \sum_i X_i + \beta-1} \end{align*}
Therefore, the posterior is: \begin{align*} p|X &\sim \text{Beta}\left(\alpha + \sum_i X_i, \beta+N - \sum_i X_i\right) \end{align*}
The posterior mean is given by: \begin{align*} \text{E}[p|X] &= \frac{\alpha+\sum_i X_i}{\alpha+\beta + N} \end{align*}
The prior enters only through the fixed constants $\alpha$ and $\beta$, while the data terms grow with $N$: this is why the prior becomes almost irrelevant as the number of observations grows large.
Let's now explore what we would have to do if we did not have this closed form available to us. We implement three different sampling strategies:
- Random-walk Metropolis-Hastings
- Hamiltonian Monte Carlo (HMC)
- No-U-Turn Sampler (NUTS)
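To make the vanishing influence of the prior concrete, here is a minimal sketch that evaluates the posterior mean formula for a fixed share of ones (0.653, as in our sample) and a growing $N$. The helper `posterior_mean`, the grid of sample sizes and the informative $\text{Beta}(20, 20)$ prior are illustrative assumptions, not part of the analysis in this notebook.

```python
import jax.numpy as jnp

def posterior_mean(successes, n_obs, alpha_prior, beta_prior):
    """Mean of the Beta(alpha + S, beta + N - S) posterior."""
    return (alpha_prior + successes) / (alpha_prior + beta_prior + n_obs)

sample_share = 0.653                         # share of ones, as in our sample
n_grid = jnp.array([10.0, 100.0, 1_000.0, 100_000.0])
successes = sample_share * n_grid            # hypothetical success counts for each N

mean_flat = posterior_mean(successes, n_grid, 1.0, 1.0)     # uniform prior
mean_tight = posterior_mean(successes, n_grid, 20.0, 20.0)  # hypothetical prior centered on 0.5

for n, m_flat, m_tight in zip(n_grid, mean_flat, mean_tight):
    print(f"N = {int(n):>6d}: flat prior {float(m_flat):.4f}, Beta(20, 20) prior {float(m_tight):.4f}")
```

With few observations, the informative prior pulls the posterior mean towards 0.5. As $N$ grows, both posterior means approach the sample share.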
Random-Walk Metropolis-Hastings
As in the frequentist section, we want to avoid the $p \in [0, 1]$ constraint, so we reparameterize: \begin{align*} p &= \frac{1}{1+e^{-\theta}} = \sigma(\theta)\\ \Leftrightarrow \theta &= \log\left( \frac{p}{1-p}\right) = \sigma^{-1}(p) \end{align*}
The derivative of this transformation, which will serve as the Jacobian of the change of variables, is: \begin{align*} \frac{dp}{d\theta} &= p(1-p) \end{align*}
Now, we want to compute the density of the transformed variable, $\pi(\theta|X)$, in a way that preserves the probabilities implied by the posterior of $p$:
\begin{align*} \text{Pr}(\theta \in A|X) &= \text{Pr}(p \in \sigma(A)|X) \quad \forall A \subseteq \mathbb{R} \end{align*}
In differential form (with a slight abuse of notation, $\pi(\cdot|X)$ denotes the density of $\theta$ on the left and that of $p$ on the right): \begin{align*} \pi(\theta|X) d\theta &= \pi(\sigma(\theta)|X) dp \end{align*}
Using $\frac{dp}{d\theta} = p(1-p)$ and taking logs: \begin{align*} \log \pi(\theta|X) &= \log \pi(\sigma(\theta)|X) + \log \Big|\frac{dp}{d\theta} \Big|\\ &= \log \pi(\sigma(\theta)|X) + \log(\sigma(\theta)) + \log(1-\sigma(\theta)) \end{align*}
@jax.jit
def log_posterior_theta(theta, X, alpha_prior, beta_prior):
    """
    Unnormalized log posterior for theta = logit(p).
    Prior is Beta(alpha_prior, beta_prior) on p.
    We sample theta, where p = sigmoid(theta).
    Therefore:
        log posterior(theta)
        =
        log likelihood(p)
        + log prior(p)
        + log Jacobian
    """
    S = jnp.sum(X)
    N = X.shape[0]
    log_p = jax.nn.log_sigmoid(theta)
    log_1mp = jax.nn.log_sigmoid(-theta)
    log_likelihood = S * log_p + (N - S) * log_1mp
    log_prior = (alpha_prior - 1) * log_p + (beta_prior - 1) * log_1mp
    log_jacobian = log_p + log_1mp
    return log_likelihood + log_prior + log_jacobian
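The Jacobian term is easy to forget, so here is a small sketch of what it does. It normalizes the $\theta$-density on a grid, with and without the Jacobian, and computes the implied mean of $p = \sigma(\theta)$. The function names, the grid and the counts $S = 653$, $N = 1000$ (chosen to match our sample share) are assumptions made for this illustration.

```python
import jax
import jax.numpy as jnp

def log_density_theta(theta, S, N, alpha_prior, beta_prior, with_jacobian):
    """Unnormalized log density of theta = logit(p), optionally with the Jacobian term."""
    log_p = jax.nn.log_sigmoid(theta)
    log_1mp = jax.nn.log_sigmoid(-theta)
    out = (S + alpha_prior - 1) * log_p + (N - S + beta_prior - 1) * log_1mp
    if with_jacobian:
        out = out + log_p + log_1mp  # log |dp/dtheta| = log p + log(1 - p)
    return out

def implied_mean_of_p(with_jacobian, S=653.0, N=1000.0, alpha_prior=1.0, beta_prior=1.0):
    """Mean of p = sigmoid(theta) under the grid-normalized theta-density."""
    theta_grid = jnp.linspace(-2.0, 3.0, 5001)
    log_dens = log_density_theta(theta_grid, S, N, alpha_prior, beta_prior, with_jacobian)
    weights = jnp.exp(log_dens - jnp.max(log_dens))  # subtract the max for numerical stability
    weights = weights / jnp.sum(weights)
    return jnp.sum(weights * jax.nn.sigmoid(theta_grid))

exact_mean = (1.0 + 653.0) / (2.0 + 1000.0)  # mean of the exact Beta(654, 348) posterior
mean_with = implied_mean_of_p(True)
mean_without = implied_mean_of_p(False)
print(f"exact: {exact_mean:.5f}, with Jacobian: {float(mean_with):.5f}, without: {float(mean_without):.5f}")
```

With the Jacobian, the implied mean should match the exact posterior mean $654/1002$. Without it, we would effectively be sampling from $\text{Beta}(\alpha + S - 1, \beta + N - S - 1)$ instead, whose mean is $653/1000$. The difference is tiny here because $N$ is large, but it is a genuine bias.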
The MH algorithm works as follows. Start from some initial guess $\theta^{(0)}$. At step $t$:
- Draw a random increment $\varepsilon^{(t)} \sim \mathcal{N}(0, \tau^2)$
- Propose a candidate $\theta^* = \theta^{(t)} + \varepsilon^{(t)}$
- Compare posterior densities: \begin{align*} r &= \frac{\pi(\theta^*|X)}{\pi(\theta^{(t)}|X)} \end{align*}
- If $r \geq 1$, accept $\theta^*$ and set $\theta^{(t+1)} = \theta^*$
- If $r < 1$, accept $\theta^*$ with probability $r$; otherwise reject it and set $\theta^{(t+1)} = \theta^{(t)}$
Because $r$ is a ratio, the normalizing constant $f(X)$ cancels, so the unnormalized posterior is all we need. This process returns a chain of $\theta$ draws. "Burn" some share of the first draws, which correspond to the period during which the chain is still moving away from the initial guess and towards the high-density region of the posterior.
# ----------------------------
# Random-walk Metropolis-Hastings on theta
# ----------------------------
def random_walk_mh_theta(
    log_posterior,
    initial_theta,
    key,
    num_samples,
    proposal_sd,
    args=(),
):
    """
    Random-walk Metropolis-Hastings on unconstrained theta.
    Proposal:
        theta_star = theta_current + Normal(0, proposal_sd^2)
    Since the proposal is symmetric, the MH acceptance ratio is:
        posterior(theta_star) / posterior(theta_current)
    """
    samples = []
    accepts = []
    theta_current = initial_theta
    log_post_current = log_posterior(theta_current, *args)
    for t in range(num_samples):
        key, proposal_key, accept_key = jax.random.split(key, 3)
        theta_proposal = theta_current + proposal_sd * jax.random.normal(proposal_key)
        log_post_proposal = log_posterior(theta_proposal, *args)
        # log(r) = log(pi(theta^*|X)) - log(pi(theta^(t)|X))
        log_accept_ratio = log_post_proposal - log_post_current
        # accept status: True with probability min(1, r)
        accept = jnp.log(jax.random.uniform(accept_key)) < log_accept_ratio
        theta_current = jnp.where(accept, theta_proposal, theta_current)
        log_post_current = jnp.where(accept, log_post_proposal, log_post_current)
        samples.append(theta_current)
        accepts.append(accept)
    return jnp.array(samples), jnp.array(accepts)
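The Python `for` loop above is easy to read, but it dispatches every iteration separately. As a sketch of an alternative, the same random-walk MH update can be written with `jax.lax.scan`, which lets JAX compile the whole chain. The function `rw_mh_scan` and the standard-normal toy target are hypothetical names and choices made for this illustration.

```python
import jax
import jax.numpy as jnp

def rw_mh_scan(log_density, initial_theta, key, num_samples, proposal_sd):
    """Random-walk Metropolis-Hastings for a scalar theta, written as a scan."""
    def step(carry, step_key):
        theta, log_dens = carry
        proposal_key, accept_key = jax.random.split(step_key)
        theta_star = theta + proposal_sd * jax.random.normal(proposal_key)
        log_dens_star = log_density(theta_star)
        # Symmetric proposal: log r = log pi(theta_star) - log pi(theta)
        accept = jnp.log(jax.random.uniform(accept_key)) < log_dens_star - log_dens
        theta = jnp.where(accept, theta_star, theta)
        log_dens = jnp.where(accept, log_dens_star, log_dens)
        return (theta, log_dens), (theta, accept)

    keys = jax.random.split(key, num_samples)  # one key per iteration
    init = (initial_theta, log_density(initial_theta))
    _, (samples, accepts) = jax.lax.scan(step, init, keys)
    return samples, accepts

# Toy target (an assumption for this sketch): a standard normal, up to a constant
toy_log_density = lambda theta: -0.5 * theta**2
samples, accepts = rw_mh_scan(toy_log_density, jnp.array(0.0), jax.random.key(0), 20_000, 2.4)
print(f"mean: {float(jnp.mean(samples)):.3f}, variance: {float(jnp.var(samples)):.3f}")
print(f"acceptance rate: {float(jnp.mean(accepts)):.3f}")
```

To apply it to our problem, one would pass a closure such as `lambda theta: log_posterior_theta(theta, X, alpha_prior, beta_prior)` as `log_density`.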
# ----------------------------
# Run MH
# ----------------------------
key, subkey = jax.random.split(key)
# I choose a Uniform distribution on [0, 1] as my prior, i.e. Beta(1, 1).
alpha_prior = 1
beta_prior = 1
num_samples = 10_000
burn_in = 2_000
theta_samples_mh, accepts_theta = random_walk_mh_theta(
    log_posterior_theta,
    initial_theta=jnp.array(0.0),
    key=subkey,  # use the fresh subkey rather than reusing `key`
    num_samples=num_samples,
    proposal_sd=0.05,
    args=(X, alpha_prior, beta_prior),
)
theta_samples_post = theta_samples_mh[burn_in:]
p_samples_mh = jax.nn.sigmoid(theta_samples_post) #transform back to the original p
The procedure above samples $\theta$ from the posterior distribution $\pi(\theta \mid X)$. There is a "burn-in" number of samples, during which the procedure converges; we eliminate those from our final sample.
sns.histplot(
p_samples_mh,
bins=30,
alpha=0.4,
label="Random-walk MH samples",
kde=True
)
plt.show()
This looks nice, and well centered on the "true" value 0.65. But Metropolis-Hastings is a very basic technique for Bayesian inference. The proposal process, as indicated, follows a random walk, which is potentially inefficient.
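One way to quantify this inefficiency is the autocorrelation of the chain: the stickier the chain, the fewer "effectively independent" draws it contains. The sketch below uses hypothetical helpers (`autocorrelation`, `crude_ess`, `ar1_chain`) and a simulated AR(1) chain as a stand-in for sticky MCMC output. The effective-sample-size formula $N(1-\rho_1)/(1+\rho_1)$ is a crude approximation that is only exact for an AR(1) process.

```python
import jax
import jax.numpy as jnp

def autocorrelation(chain, lag):
    """Sample autocorrelation of a 1-D chain at a positive integer lag."""
    centered = chain - jnp.mean(chain)
    return jnp.sum(centered[:-lag] * centered[lag:]) / jnp.sum(centered**2)

def crude_ess(chain):
    """Crude effective sample size, treating the chain as an AR(1) process."""
    rho_1 = autocorrelation(chain, 1)
    return chain.shape[0] * (1.0 - rho_1) / (1.0 + rho_1)

def ar1_chain(key, rho, num_samples):
    """Stationary AR(1) chain with unit variance, a stand-in for sticky MCMC draws."""
    noise = jax.random.normal(key, (num_samples,))
    def step(x, eps):
        x = rho * x + jnp.sqrt(1.0 - rho**2) * eps
        return x, x
    _, chain = jax.lax.scan(step, jnp.array(0.0), noise)
    return chain

key_sticky, key_iid = jax.random.split(jax.random.key(0))
sticky_chain = ar1_chain(key_sticky, 0.9, 20_000)
iid_chain = jax.random.normal(key_iid, (20_000,))
print(f"lag-1 autocorrelation: sticky {float(autocorrelation(sticky_chain, 1)):.3f}, "
      f"iid {float(autocorrelation(iid_chain, 1)):.3f}")
print(f"crude ESS: sticky {float(crude_ess(sticky_chain)):.0f}, iid {float(crude_ess(iid_chain)):.0f}")
```

Calling `crude_ess` on the post-burn-in $\theta$ draws of each sampler gives a rough basis for comparing them.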
Hamiltonian Monte Carlo
In comparison, HMC proposes "intelligently": it uses the gradient of the log posterior to steer proposals towards regions of high posterior density. Conceptually, this is closer to a gradient-descent type of algorithm, augmented with a momentum variable $r$. The Hamiltonian function is: \begin{align*} H(\theta, r) &= -\log(\pi(\theta|X)) + \frac{r^2}{2} \end{align*}
Start from some initial guess $\theta^{(0)}$. Choose $\epsilon$ as the step size and $L$ as the number of steps per iteration.
At step $t$:
- Draw a momentum $r^{(t)} \sim \mathcal{N}(0, 1)$
- Take one half-step in momentum: \begin{align*} r \leftarrow r + \frac{\epsilon}{2} \nabla_\theta \log \pi(\theta|X) \end{align*}
- Alternate full position updates $\theta \leftarrow \theta + \epsilon r$ and full momentum updates $r \leftarrow r + \epsilon \nabla_\theta \log \pi(\theta|X)$ (the leapfrog integrator), for $L$ position updates in total, replacing the last momentum update by a half-step.
- This generates a proposal $(\theta^*, r^*)$.
- Compare the Hamiltonians: \begin{align*} \Delta H &= H(\theta^{(t)}, r^{(t)}) - H(\theta^*, r^*) \end{align*}
- Accept the proposal with probability \begin{align*} \alpha = \min\left\{1, \exp(\Delta H)\right\}. \end{align*} Equivalently, draw $u \sim \text{Uniform}(0,1)$ and accept when \begin{align*} \log u < H(\theta^{(t)}, r^{(t)}) - H(\theta^*, r^*). \end{align*}
This process returns a chain of $\theta$ draws. "Burn" some share of the first draws, which correspond to the period during which the chain is still moving away from the initial guess and towards the high-density region of the posterior.
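Two properties of the leapfrog integrator make this acceptance rule work well: it nearly conserves the Hamiltonian (so $\Delta H$ stays close to zero and most proposals are accepted), and it is reversible. The sketch below checks both on a toy standard-normal target. The function names and the starting point are assumptions made for this illustration.

```python
import jax
import jax.numpy as jnp

def leapfrog(log_density, theta, r, step_size, num_steps):
    """Leapfrog integration: half momentum step, alternating full steps, half momentum step."""
    grad_fn = jax.grad(log_density)
    r = r + 0.5 * step_size * grad_fn(theta)
    for _ in range(num_steps - 1):
        theta = theta + step_size * r
        r = r + step_size * grad_fn(theta)
    theta = theta + step_size * r
    r = r + 0.5 * step_size * grad_fn(theta)
    return theta, r

def hamiltonian(log_density, theta, r):
    """H(theta, r) = -log pi(theta) + r^2 / 2."""
    return -log_density(theta) + 0.5 * r**2

# Toy target (an assumption for this sketch): a standard normal, up to a constant
toy_log_density = lambda theta: -0.5 * theta**2
theta0, r0 = jnp.array(1.0), jnp.array(0.5)

theta1, r1 = leapfrog(toy_log_density, theta0, r0, 0.05, 20)
delta_H = hamiltonian(toy_log_density, theta0, r0) - hamiltonian(toy_log_density, theta1, r1)

# Reversibility: flip the momentum and integrate again to return to the start
theta_back, r_back = leapfrog(toy_log_density, theta1, -r1, 0.05, 20)
print(f"Delta H: {float(delta_H):.2e}")
print(f"theta after the return trip: {float(theta_back):.5f} (started at {float(theta0):.5f})")
```

A larger step size makes each trajectory cheaper but increases $|\Delta H|$, which lowers the acceptance rate. This trade-off is what step-size tuning tries to balance.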
def hmc_sampler(
    log_posterior,
    initial_theta,
    key,
    num_samples,
    step_size,
    num_leapfrog_steps,
    args=(),
):
    grad_log_posterior = jax.grad(lambda theta: log_posterior(theta, *args))

    @jax.jit
    def leapfrog_step(carry, _):
        theta, r = carry
        theta = theta + step_size * r
        r = r + step_size * grad_log_posterior(theta)
        return (theta, r), None

    @jax.jit
    def one_hmc_step(theta_current, key):
        key_momentum, key_accept = jax.random.split(key)
        r_current = jax.random.normal(key_momentum)
        theta = theta_current
        r = r_current
        # Half momentum step
        r = r + 0.5 * step_size * grad_log_posterior(theta)
        # Full leapfrog steps
        (theta, r), _ = jax.lax.scan(
            leapfrog_step,
            (theta, r),
            xs=None,
            length=num_leapfrog_steps - 1,
        )
        # Final position step
        theta = theta + step_size * r
        # Final half momentum step
        r = r + 0.5 * step_size * grad_log_posterior(theta)
        theta_proposal = theta
        r_proposal = r
        current_H = (
            -log_posterior(theta_current, *args)
            + 0.5 * r_current**2
        )
        proposal_H = (
            -log_posterior(theta_proposal, *args)
            + 0.5 * r_proposal**2
        )
        log_accept_ratio = current_H - proposal_H
        accept = jnp.log(jax.random.uniform(key_accept)) < log_accept_ratio
        theta_next = jnp.where(
            accept,
            theta_proposal,
            theta_current,
        )
        return theta_next, (theta_next, accept)

    keys = jax.random.split(key, num_samples)
    final_theta, (theta_samples, accepts) = jax.lax.scan(
        one_hmc_step,
        initial_theta,
        keys,
    )
    return theta_samples, accepts
# ----------------------------
# Run HMC
# ----------------------------
key, subkey = jax.random.split(key)
num_samples = 10_000
burn_in = 1_000
theta_samples_hmc, accepts_hmc = hmc_sampler(
    log_posterior_theta,
    initial_theta=jnp.array(0.0),
    key=subkey,  # use the fresh subkey rather than reusing `key`
    num_samples=num_samples,
    step_size=0.05,
    num_leapfrog_steps=20,
    args=(X, alpha_prior, beta_prior),
)
theta_samples_hmc_post = theta_samples_hmc[burn_in:]
p_samples_hmc = jax.nn.sigmoid(theta_samples_hmc_post)
sns.histplot(
p_samples_hmc,
bins=30,
alpha=0.4,
kde=True,
label="HMC samples",
)
plt.show()
Instead of our home-made (and probably under-optimized) samplers, we may use the ones provided by the BlackJAX library. This makes it very easy to swap samplers in and out. First, we use the HMC sampler provided by this library.
def run_blackjax_chain(rng_key, kernel, initial_state, num_samples):
    keys = jax.random.split(rng_key, num_samples)

    @jax.jit
    def inference_loop(state, keys):
        def one_step(state, key):
            state, info = kernel(key, state)
            return state, (state.position, info)
        return jax.lax.scan(one_step, state, keys)

    final_state, (positions, infos) = inference_loop(initial_state, keys)
    return positions, infos, final_state

@jax.jit
def logdensity_fn(theta):
    return log_posterior_theta(theta, X, alpha_prior, beta_prior)
key, key_bj_hmc, key_bj_nuts = jax.random.split(key, 3)  # fresh keys, and `key` itself is refreshed for later use
initial_theta = jnp.array(0.0)
inverse_mass_matrix = jnp.array([1.0])  # we will have opportunities to explore this object in future notebooks; ignore it for now
num_samples = 10_000
burn_in = 1_000
blackjax_hmc = blackjax.hmc(
logdensity_fn,
step_size=0.05,
inverse_mass_matrix=inverse_mass_matrix,
num_integration_steps=30,
)
blackjax_hmc_state = blackjax_hmc.init(initial_theta)
theta_samples_blackjax_hmc, info_blackjax_hmc, final_state_blackjax_hmc = run_blackjax_chain(
key_bj_hmc,
blackjax_hmc.step,
blackjax_hmc_state,
num_samples,
)
theta_samples_blackjax_hmc_post = theta_samples_blackjax_hmc[burn_in:]
p_samples_blackjax_hmc = jax.nn.sigmoid(theta_samples_blackjax_hmc_post)
On my computer, this runs almost instantaneously.
sns.histplot(p_samples_blackjax_hmc,
             bins=30,
             kde=True)
plt.show()
BlackJAX NUTS without adaptation
NUTS chooses the trajectory length dynamically. We still need to give it a step size and an inverse mass matrix, but we no longer specify a fixed number of integration steps.
blackjax_nuts = blackjax.nuts(
logdensity_fn,
step_size=0.05,
inverse_mass_matrix=inverse_mass_matrix,
)
blackjax_nuts_state = blackjax_nuts.init(initial_theta)
theta_samples_blackjax_nuts, info_blackjax_nuts, final_state_blackjax_nuts = run_blackjax_chain(
key_bj_nuts,
blackjax_nuts.step,
blackjax_nuts_state,
num_samples,
)
theta_samples_blackjax_nuts_post = theta_samples_blackjax_nuts[burn_in:]
p_samples_blackjax_nuts = jax.nn.sigmoid(theta_samples_blackjax_nuts_post)
sns.histplot(p_samples_blackjax_nuts,
bins=30,
kde=True)
plt.show()
BlackJAX NUTS with warmup/adaptation
The step size and the inverse mass matrix are tuning parameters: the step size sets how far each leapfrog step moves, and the inverse mass matrix rescales the momentum to match the scale of the posterior. During warmup, the chain moves towards the high-density region while BlackJAX adapts the step size and the inverse mass matrix. After warmup is over, we freeze those tuning parameters and draw posterior samples.
key, key_warmup, key_sample = jax.random.split(key, 3)  # fresh keys, so that warmup and sampling do not reuse an earlier sampler's key
num_warmup = 1_000
num_draws = 15_000
adapt = blackjax.window_adaptation(
blackjax.nuts,
logdensity_fn,
target_acceptance_rate=0.80,
)
adaptation_result, warmup_info = adapt.run(
key_warmup,
initial_theta,
num_warmup,
)
adapted_state = adaptation_result.state
tuned_parameters = adaptation_result.parameters
print("Tuned parameters:")
print(tuned_parameters)
blackjax_nuts_adapted = blackjax.nuts(
logdensity_fn,
**tuned_parameters,
)
theta_samples_blackjax_nuts_adapted, info_blackjax_nuts_adapted, final_state_blackjax_nuts_adapted = run_blackjax_chain(
key_sample,
blackjax_nuts_adapted.step,
adapted_state,
num_draws,
)
p_samples_blackjax_nuts_adapted = jax.nn.sigmoid(theta_samples_blackjax_nuts_adapted)
Tuned parameters:
{'step_size': Array(0.96027654, dtype=float32, weak_type=True), 'inverse_mass_matrix': Array([0.00414182], dtype=float32)}
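As a rough sanity check on the tuned inverse mass matrix, we can compare it with the large-sample posterior variance of $\theta$. This assumes, as I understand this Stan-style window adaptation, that the inverse mass matrix is estimated from the variance of the warmup draws. Using the observed information of the logit model, $N\widehat{p}(1-\widehat{p})$: \begin{align*} \text{Var}[\theta|X] &\approx \frac{1}{N\widehat{p}(1-\widehat{p})}\\ &= \frac{1}{1000 \times 0.653 \times 0.347}\\ &\approx 0.0044 \end{align*}
This is close to the tuned value of about $0.0041$: the adaptation has essentially learned the scale of the posterior.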
sns.histplot(p_samples_blackjax_nuts_adapted,
bins=30,
kde=True)
plt.show()
Final Comparison
alpha_post = alpha_prior + jnp.sum(X)
beta_post = beta_prior + X.shape[0] - jnp.sum(X)
p_grid = jnp.linspace(0.599, 0.699, 1_000)
posterior_pdf = jax.scipy.stats.beta.pdf(
jnp.asarray(p_grid),
float(alpha_post),
float(beta_post),
)
plt.figure(figsize=(9, 5))
sns.kdeplot(p_samples_mh, alpha=0.75, label="Random-walk MH")
sns.kdeplot(p_samples_hmc, alpha=0.75, label="Homemade HMC")
sns.kdeplot(p_samples_blackjax_hmc, alpha=0.75, label="BlackJAX HMC")
sns.kdeplot(p_samples_blackjax_nuts_adapted, alpha=0.75, label="BlackJAX NUTS adapted")
plt.plot(p_grid, posterior_pdf, label="Exact Beta posterior")
plt.axvline(float(p_hat_analytical), linestyle="--", label="MLE")
plt.xlabel("p")
plt.ylabel("density")
plt.title("Posterior draws vs exact posterior")
plt.legend()
plt.show()
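To complement the visual comparison with numbers, one can compare a 95% credible interval computed from posterior draws with the frequentist Wald interval $\widehat{p} \pm 1.96 \cdot \text{SE}$. The sketch below is self-contained: as a stand-in for MCMC output, it draws directly from the exact $\text{Beta}(654, 348)$ posterior implied by a uniform prior and 653 ones out of 1000. The helper names and these counts are assumptions made for this illustration.

```python
import jax
import jax.numpy as jnp

def credible_interval(p_samples, level=0.95):
    """Equal-tailed credible interval computed from posterior draws."""
    tail = (1.0 - level) / 2.0
    return jnp.quantile(p_samples, jnp.array([tail, 1.0 - tail]))

def wald_interval(p_hat, n_obs, z=1.96):
    """Frequentist Wald interval p_hat +/- z * SE."""
    se = jnp.sqrt(p_hat * (1.0 - p_hat) / n_obs)
    return jnp.array([p_hat - z * se, p_hat + z * se])

# Stand-in for MCMC output: direct draws from the exact Beta(654, 348) posterior
stand_in_draws = jax.random.beta(jax.random.key(0), 654.0, 348.0, shape=(20_000,))
ci_bayes = credible_interval(stand_in_draws)
ci_wald = wald_interval(0.653, 1000)
print(f"95% credible interval: [{float(ci_bayes[0]):.4f}, {float(ci_bayes[1]):.4f}]")
print(f"95% Wald interval:     [{float(ci_wald[0]):.4f}, {float(ci_wald[1]):.4f}]")
```

In the notebook, `credible_interval` could be applied to any of the `p_samples_*` arrays. With a flat prior and $N = 1000$, the two intervals should nearly coincide, even though their interpretations differ.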