In this episode, I apply the frequentist and Bayesian approaches to a very basic discrete choice problem: no prices, no endogeneity, no random coefficients. Instead, we explore one of the benefits of the Bayesian approach: weaker assumptions on the distribution of the unobserved taste shocks.
Discrete Choice Basics
A crowd of individuals chooses among $J$ inside options and an outside option, indexed $0$. The choice set is $\mathcal{J} = \{0, 1, \dots, J\}$. Choosing option $j$ gives individual $i$ the utility: \begin{align*} u_{ij} &= v_j + \varepsilon_{ij} \end{align*}
where $v_j$ can be thought of as the objective (systematic) value of alternative $j$, common to all individuals, while $\varepsilon_{ij}$ is the taste shock individual $i$ experiences for option $j$. Each individual chooses the option with the highest utility.
The Frequentist Approach: Maximum Likelihood Estimation
The approach typically followed consists of:
- Making an assumption on the distribution of the unobserved component $\varepsilon_{ij}$
- Writing the likelihood function and maximizing it to recover the vector of systematic utilities $\widehat{v}$
The usual assumptions are:
- Logit: $\varepsilon_{ij} \sim_{\text{iid}} \text{Gumbel}$ (type I extreme value)
- Probit: $\varepsilon_{ij} \sim_{\text{iid}} \mathcal{N}(0, 1)$
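Under either assumption, the choice rule is the same: each individual picks the option with the highest utility $u_{ij}$. Below is a minimal NumPy sketch (NumPy instead of JAX only to keep it lightweight; the seed, sample size and the helper names `simulate_choices` and `market_shares` are illustrative, not part of the original code) that simulates choices under both shock distributions for the same $v$. Standard Gumbel shocks have a larger variance ($\pi^2/6$) than $\mathcal{N}(0,1)$ shocks, so the same systematic utilities imply different market shares: the distributional assumption shapes how $v$ should be interpreted.

```python
import numpy as np

def simulate_choices(v, n, shock="logit", seed=0):
    """Simulate choices from u_ij = v_j + eps_ij and return each individual's chosen index.

    v: systematic utilities, with v[0] = 0 for the outside option.
    shock: "logit" for iid standard Gumbel shocks, "probit" for iid N(0, 1) shocks.
    """
    rng = np.random.default_rng(seed)
    if shock == "logit":
        eps = rng.gumbel(loc=0.0, scale=1.0, size=(n, len(v)))
    elif shock == "probit":
        eps = rng.standard_normal(size=(n, len(v)))
    else:
        raise ValueError(f"unknown shock type: {shock}")
    utilities = v[None, :] + eps          # (n, J+1) matrix of utilities
    return np.argmax(utilities, axis=1)   # each individual picks the highest-utility option

def market_shares(choices, n_options):
    """Empirical market shares: fraction of individuals choosing each option."""
    return np.bincount(choices, minlength=n_options) / len(choices)

v_sys = np.array([0.0, 1.0, 2.0, 2.2, 2.5, 1.5])  # same utilities as the DGP in this post
shares_logit = market_shares(simulate_choices(v_sys, 10_000, "logit"), len(v_sys))
shares_probit = market_shares(simulate_choices(v_sys, 10_000, "probit"), len(v_sys))
print("logit shares: ", shares_logit.round(3))
print("probit shares:", shares_probit.round(3))
```

With the smaller-variance normal shocks, choices concentrate more heavily on the best option ($v_4 = 2.5$) than under logit.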
```python
import jax
import jax.numpy as jnp
import blackjax  # Bayesian samplers
import optax  # optimization library for JAX
import math
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
```
```python
# Data-generating process
N = 10_000  # 10,000 individuals: the simulated data carry little identifying variation
            # (only the market shares), so we need many observations for every
            # alternative to be chosen often enough.
J = 5  # 5 inside alternatives, plus the outside option
key = jax.random.key(123)
eps_logit = jax.random.gumbel(key, shape=(N, J + 1))  # iid standard Gumbel taste shocks
# Vector of systematic utilities; the first entry is the outside option, normalized to 0
v = jnp.array([0.0, 1.0, 2.0, 2.2, 2.5, 1.5])[None, :]
U_logit = v + eps_logit
### The choice data is an N x 1 vector (here 10,000 x 1) recording each individual's choice.
### We implicitly assume that all agents know the set of available options perfectly.
### The econometrician observes only the choice data.
choice_logit = jnp.argmax(U_logit, axis=1)  # each individual picks the highest-utility option
### For computational ease, we one-hot encode the choices: row i has a 1 in the column
### of the option individual i chose and 0 elsewhere.
choice_logit_oh = jnp.eye(J + 1)[choice_logit]
```
Under logit shocks, a well-known result gives the probability that any individual $i$ chooses alternative $j$ in closed form: \begin{align*} \text{Pr}(i \text{ chooses } j) \equiv s_j &= \frac{\exp(v_j)}{1 + \sum_{k=1}^J \exp(v_k)} \end{align*}
where we have normalized the utility of the outside option to $0$, $v_0 = 0$ (only utility differences are identified). Economically, $s_j$ is the market share of alternative $j$. This gives us the likelihood function. Write $d_i \in \mathcal{J}$ for the recorded choice of agent $i$ and $\mathcal{L}(v; d) = \sum_{i=1}^N \log \text{Pr}(i \text{ chooses } d_i \mid v)$ for the log-likelihood. The MLE solves: \begin{align*} \max_v \frac{1}{N}\mathcal{L}(v; d) &= \max_v \frac{1}{N} \sum_{i=1}^N \log\Big(\text{Pr}(i \text{ chooses } d_i \mid v)\Big)\\ &= \max_v \frac{1}{N} \sum_{i=1}^N \log\Bigg(\frac{\exp(v_{d_i})}{1 + \sum_{k=1}^J \exp(v_k)}\Bigg) \end{align*}
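In this particular model the maximization even has a closed-form solution, which is a useful check on the numerical optimizer. The derivative of the average log-likelihood with respect to $v_j$ is $\hat{s}_j - s_j(v)$, where $\hat{s}_j$ is the observed fraction of individuals choosing $j$. Setting it to zero equates predicted and observed shares, so $\widehat{v}_j = \log(\hat{s}_j / \hat{s}_0)$. A NumPy sketch, assuming choices are stored as integer indices (like `choice_logit`) and that every option is chosen at least once; the helper names, seed and simulated data are hypothetical:

```python
import numpy as np

def mean_log_likelihood(v_inside, choices):
    """Average logit log-likelihood with v_0 = 0; choices are integer indices in {0, ..., J}."""
    v_all = np.concatenate([[0.0], v_inside])
    log_probs = v_all - np.log(np.exp(v_all).sum())  # log s_j for j = 0, ..., J
    return log_probs[choices].mean()

def logit_mle_closed_form(choices, n_options):
    """Logit MLE of the inside utilities: log of observed shares relative to the outside share."""
    shares = np.bincount(choices, minlength=n_options) / len(choices)
    return np.log(shares[1:] / shares[0])  # assumes every option is chosen at least once

# Hypothetical simulated data with the same utilities as this post (seed is arbitrary)
rng = np.random.default_rng(2)
v_true_sim = np.array([1.0, 2.0, 2.2, 2.5, 1.5])
utilities = np.concatenate([[0.0], v_true_sim])[None, :] + rng.gumbel(size=(10_000, 6))
choices_sim = np.argmax(utilities, axis=1)

v_hat_cf = logit_mle_closed_form(choices_sim, 6)
print("closed-form MLE:", v_hat_cf.round(3))
print("mean log-likelihood at the MLE:", mean_log_likelihood(v_hat_cf, choices_sim))
```

Because the log-likelihood is strictly concave in $v$, this stationary point is the unique maximum, so a converged numerical optimizer should land on it up to its tolerance.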
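```python
@jax.jit
def neg_log_likelihood_logit(v, choices_oh):
    v = jnp.concatenate([jnp.array([0.0]), v])  # prepend the outside option, v_0 = 0
    log_probs = jax.nn.log_softmax(v)  # log choice probabilities, log(s_j)
    # Select each individual's log-probability via the one-hot matrix; return the mean negative log-likelihood
    return -jnp.mean(jnp.sum(choices_oh * log_probs, axis=-1))
```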
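The closed-form shares are exactly a softmax of $(v_0, v_1, \dots, v_J)$ with $v_0 = 0$, which is why `neg_log_likelihood_logit` prepends a zero and calls `jax.nn.log_softmax`. The NumPy sketch below (its `log_softmax` is a hypothetical re-implementation using the log-sum-exp trick, not JAX's code) checks the equivalence and shows why the log-sum-exp form is preferred: the naive formula overflows once utilities get large.

```python
import numpy as np

def logit_shares_formula(v_inside):
    """s_j = exp(v_j) / (1 + sum_k exp(v_k)) for j = 0..J, with v_0 = 0."""
    denom = 1.0 + np.exp(v_inside).sum()
    return np.concatenate([[1.0], np.exp(v_inside)]) / denom

def log_softmax(x):
    """Numerically stable log-softmax: subtract the max before exponentiating."""
    m = x.max()
    return x - (m + np.log(np.exp(x - m).sum()))

v_inside = np.array([1.0, 2.0, 2.2, 2.5, 1.5])
v_all_demo = np.concatenate([[0.0], v_inside])  # prepend v_0 = 0
print(np.allclose(np.log(logit_shares_formula(v_inside)), log_softmax(v_all_demo)))

# With large utilities, exp() overflows in the naive formula; log-softmax stays finite
big = np.array([0.0, 800.0, 801.0])
with np.errstate(all="ignore"):
    naive = np.log(np.exp(big) / np.exp(big).sum())
print("naive:", naive)
print("stable:", log_softmax(big))
```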
```python
def optimize_adam(
    f,
    init_params,
    args=(),
    kwargs=None,
    learning_rate=1e-3,
    max_steps=1_000,
    grad_tol=1e-6,
    return_history=True,
):
    if kwargs is None:
        kwargs = {}
    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(init_params)
    params = init_params
    history = []
    loss_and_grad_fn = jax.value_and_grad(lambda p: f(p, *args, **kwargs))

    @jax.jit
    def update_step(params, opt_state):
        loss, grads = loss_and_grad_fn(params)
        grad_norm = jnp.linalg.norm(grads)
        updates, new_opt_state = optimizer.update(grads, opt_state, params)
        new_params = optax.apply_updates(params, updates)
        return new_params, new_opt_state, loss, grad_norm

    for step in range(max_steps):
        params, opt_state, loss, grad_norm = update_step(params, opt_state)
        if return_history:
            history.append((loss, grad_norm))
        if grad_norm <= grad_tol:  # stop once the gradient is numerically zero
            break

    if return_history:
        return params, jnp.array(history)
    return params
```
```python
init_v = jnp.ones(J)  # starting values for the J inside utilities
v_hat, history = optimize_adam(
    neg_log_likelihood_logit,
    init_v,
    args=(choice_logit_oh,),
    learning_rate=0.05,
    max_steps=10_000,
)
print(f'v MLE obtained by numerical optimization: {v_hat}')
```
```
v MLE obtained by numerical optimization: [0.962938 2.0065932 2.1694036 2.479695 1.4999783]
```
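Now we compute the standard errors from the inverse Hessian of the (total) log-likelihood $\mathcal{L}(v; d)$ at the MLE, where $d$ stands for the choice data: \begin{align*} \widehat{\text{Var}}(\widehat{v}) &= \Bigg(-\frac{\partial^2}{\partial v \, \partial v'}\mathcal{L}(v; d)\Big|_{v = \widehat{v}}\Bigg)^{-1}\\ \text{SE}(\widehat{v}_j) &= \sqrt{\big[\widehat{\text{Var}}(\widehat{v})\big]_{jj}} \end{align*}
Since our code differentiates the *mean* negative log-likelihood, we divide the inverse of its Hessian by $N$.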
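```python
# Hessian of the mean negative log-likelihood, evaluated at the MLE
H = jax.hessian(
    lambda theta: neg_log_likelihood_logit(theta, choice_logit_oh)
)(v_hat)
# The total negative log-likelihood is N times the mean, so its inverse Hessian is inv(H) / N
cov_hat = jnp.linalg.inv(H) / N
se_hat = jnp.sqrt(jnp.diag(cov_hat))
print("SE(v_hat):", se_hat)
```
```
SE(v_hat): [0.07075374 0.06410967 0.06353633 0.06266137 0.06656851]
```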
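For logit, the Hessian has a known form that depends on the data only through $\widehat{v}$: with $s = (s_1, \dots, s_J)$ the model shares at $\widehat{v}$, the Hessian of the total negative log-likelihood is $N\big(\text{diag}(s) - s s'\big)$, and the Sherman-Morrison formula inverts it as $\big(\text{diag}(1/s) + \mathbf{1}\mathbf{1}'/s_0\big)/N$. Hence $\text{SE}(\widehat{v}_j) = \sqrt{(1/s_j + 1/s_0)/N}$. A NumPy sketch that plugs in the MLE printed above and compares with the JAX-computed standard errors (variable names are illustrative):

```python
import numpy as np

N_obs = 10_000
v_hat_printed = np.array([0.962938, 2.0065932, 2.1694036, 2.479695, 1.4999783])  # MLE from above
se_printed = np.array([0.07075374, 0.06410967, 0.06353633, 0.06266137, 0.06656851])  # SEs from above

# Model-implied shares at v_hat, with v_0 = 0
expv = np.exp(np.concatenate([[0.0], v_hat_printed]))
s_all = expv / expv.sum()
s0, s = s_all[0], s_all[1:]

# Hessian of the total negative log-likelihood w.r.t. the inside utilities: N * (diag(s) - s s')
hessian_total = N_obs * (np.diag(s) - np.outer(s, s))
cov_numeric = np.linalg.inv(hessian_total)

# Sherman-Morrison closed-form inverse: (diag(1/s) + 1 1' / s_0) / N
cov_closed = (np.diag(1.0 / s) + np.ones((5, 5)) / s0) / N_obs
se_closed = np.sqrt(np.diag(cov_closed))
print("closed-form SE:", se_closed.round(5))
```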
The Bayesian Approach: NUTS
Since we already explored many different samplers in the first episode of this series, this time I focus on a single sampler: NUTS, with its step size and inverse mass matrix tuned by BlackJAX's window adaptation. First, we need to write the log-density of the posterior distribution. We specify a wide, independent normal prior on each inside utility: \begin{align*} v_j \sim \mathcal{N}(0, \sigma^2), \qquad j = 1, \dots, J, \qquad \sigma^2=100 \end{align*}
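Using Bayes' law, with $L(v; d) = \exp\big(\mathcal{L}(v; d)\big)$ the likelihood and $\pi(v)$ the prior: \begin{align*} \pi(v|d) &\propto L(v; d) \pi(v) \end{align*}
So, up to an additive constant, the log-density is: \begin{align*} \log(\pi(v|d)) = \mathcal{L}(v; d) + \sum_{j=1}^{J} \log(\pi(v_j)) + \text{const} \end{align*}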
Where: \begin{align*} \pi(v_j) &= \frac{1}{\sqrt{2\pi\sigma^2}}e^{-\frac{v_j^2}{2\sigma^2}}\\ \log(\pi(v_j)) &= -\frac{1}{2}\log(2\pi\sigma^2) - \frac{v_j^2}{2\sigma^2}\\ \sum_{j=1}^{J} \log(\pi(v_j)) &= -\frac{J}{2}\log(2\pi\sigma^2) - \frac{1}{2} \sum_{j=1}^{J} \Big(\frac{v_j}{\sigma} \Big)^2 \end{align*}
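Just as we drop the denominator (the marginal likelihood) in Bayes' law, the first term does not depend on $v$, so we drop it from our function: NUTS only uses gradients of the log-density and differences in log-density between points, and neither is affected by an additive constant.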
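A quick numerical check that dropping the constant shifts the log-prior by the same amount for every $v$, namely $-\frac{J}{2}\log(2\pi\sigma^2)$. This is a NumPy sketch; `log_prior_full` and `log_prior_dropped` are hypothetical helper names:

```python
import numpy as np

sigma = 10.0   # prior standard deviation (sigma^2 = 100)
J_inside = 5   # number of inside utilities with a prior

def log_prior_full(v):
    """Exact sum of N(0, sigma^2) log-densities over the inside utilities."""
    return np.sum(-0.5 * np.log(2 * np.pi * sigma**2) - v**2 / (2 * sigma**2))

def log_prior_dropped(v):
    """Same prior with the v-independent constant dropped, as in log_posterior_v."""
    return -0.5 * np.sum((v / sigma) ** 2)

const = -0.5 * J_inside * np.log(2 * np.pi * sigma**2)
rng = np.random.default_rng(3)
for _ in range(3):
    v_draw = rng.normal(size=J_inside)
    gap = log_prior_full(v_draw) - log_prior_dropped(v_draw)
    print(round(gap, 6), round(const, 6))  # the gap equals the constant for every draw
```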
```python
@jax.jit
def log_posterior_v(v, X, sigma_prior):
    # Multiply by the number of observations because neg_log_likelihood_logit
    # returns the *mean* negative log-likelihood, while the posterior needs the sum
    log_likelihood = -X.shape[0] * neg_log_likelihood_logit(v, choices_oh=X)
    log_prior = -0.5 * jnp.sum((v / sigma_prior) ** 2)  # normal prior, constant dropped
    return log_likelihood + log_prior
```
```python
def run_blackjax_chain(rng_key, kernel, initial_state, num_samples):
    keys = jax.random.split(rng_key, num_samples)  # one PRNG key per MCMC step

    @jax.jit
    def inference_loop(state, keys):
        def one_step(state, key):
            state, info = kernel(key, state)
            return state, (state.position, info)  # carry the state; record position and diagnostics
        return jax.lax.scan(one_step, state, keys)

    final_state, (positions, infos) = inference_loop(initial_state, keys)
    return positions, infos, final_state
```
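`jax.lax.scan` threads the sampler state from one step to the next and stacks the per-step outputs into arrays. Its semantics match the pure-Python loop below (a reference sketch of the documented behavior for a single array output, not JAX's implementation, which also handles pytrees and compiles the loop). `toy_kernel` is a hypothetical random-walk stand-in for the NUTS step function passed as `kernel`:

```python
import numpy as np

def scan_reference(f, init, xs):
    """Pure-Python reference for jax.lax.scan: thread a carry through f, stack per-step outputs."""
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

# Hypothetical toy "kernel": a random-walk step driven by a per-step seed (stand-in for NUTS)
def toy_kernel(state, seed):
    step = np.random.default_rng(seed).normal()
    new_state = state + step
    return new_state, new_state  # carry the new state and record it as the draw

final_state, draws = scan_reference(toy_kernel, 0.0, np.arange(5))
print(draws.shape, final_state == draws[-1])
```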
```python
@jax.jit
def logdensity_fn(v):
    # Posterior log-density with prior standard deviation sigma = 10 (sigma^2 = 100)
    return log_posterior_v(v, choice_logit_oh, 10)

key_warmup, key_sample = jax.random.split(key)
num_warmup = 1_000
num_draws = 30_000

# Window adaptation tunes the NUTS step size and the diagonal inverse mass matrix
adapt = blackjax.window_adaptation(
    blackjax.nuts,
    logdensity_fn,
    target_acceptance_rate=0.80,
)
adaptation_result, warmup_info = adapt.run(
    key_warmup,
    init_v,
    num_warmup,
)
adapted_state = adaptation_result.state
tuned_parameters = adaptation_result.parameters
print("Tuned parameters:")
print(tuned_parameters)

# Build a NUTS kernel with the tuned parameters and sample, starting from the adapted state
blackjax_nuts_adapted = blackjax.nuts(
    logdensity_fn,
    **tuned_parameters,
)
v_samples_blackjax_nuts_adapted, info_blackjax_nuts_adapted, final_state_blackjax_nuts_adapted = run_blackjax_chain(
    key_sample,
    blackjax_nuts_adapted.step,
    adapted_state,
    num_draws,
)
```
```
Tuned parameters:
{'step_size': Array(0.26106367, dtype=float32, weak_type=True), 'inverse_mass_matrix': Array([0.00464145, 0.00419285, 0.00378266, 0.00384861, 0.00468612], dtype=float32)}
```
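Window adaptation estimates the diagonal inverse mass matrix from the warmup draws, essentially as a (regularized) estimate of the posterior variance of each $v_j$. Its square root should therefore be close to the posterior standard deviations printed in the next cell. A quick NumPy check using only the printed numbers:

```python
import numpy as np

# Diagonal inverse mass matrix printed by the adaptation step
inv_mass = np.array([0.00464145, 0.00419285, 0.00378266, 0.00384861, 0.00468612])
# Posterior standard deviations printed after sampling
posterior_sd = np.array([0.07085792, 0.06392232, 0.06330264, 0.06249685, 0.0662698])

implied_sd = np.sqrt(inv_mass)  # standard deviations implied by the tuned mass matrix
print("sqrt(inverse mass):", implied_sd.round(4))
print("posterior sd:      ", posterior_sd.round(4))
```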
```python
print(jnp.mean(v_samples_blackjax_nuts_adapted, axis=0))  # posterior means
print(jnp.std(v_samples_blackjax_nuts_adapted, axis=0))   # posterior standard deviations
```
```
[0.96365964 2.0075467 2.170246 2.48079 1.5006465 ] [0.07085792 0.06392232 0.06330264 0.06249685 0.0662698 ]
```
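The posterior standard deviation measures uncertainty about $v$; how precisely the 30,000 correlated draws pin down the posterior *mean* depends instead on their effective sample size (ESS). Below is a simplified single-chain ESS estimator (a sketch that truncates the autocorrelation sum at the first non-positive lag; it is not the exact estimator BlackJAX or ArviZ implement, and a library estimator is preferable in practice), checked on a hypothetical AR(1) chain whose ESS is approximately $n(1-\phi)/(1+\phi)$:

```python
import numpy as np

def autocorrelation(x):
    """Sample autocorrelation function of a 1-D chain, computed via FFT."""
    n = len(x)
    x = x - x.mean()
    f = np.fft.rfft(x, n=2 * n)                  # zero-pad to avoid circular wrap-around
    acov = np.fft.irfft(f * np.conj(f))[:n] / n  # autocovariance at lags 0..n-1
    return acov / acov[0]

def effective_sample_size(x):
    """ESS = n / (1 + 2 * sum of autocorrelations), truncated at the first non-positive lag."""
    rho = autocorrelation(x)
    tau = 1.0
    for k in range(1, len(rho)):
        if rho[k] <= 0:
            break
        tau += 2 * rho[k]
    return len(x) / tau

# AR(1) test chain: x_t = phi * x_{t-1} + noise, integrated autocorrelation time (1 + phi) / (1 - phi)
rng = np.random.default_rng(4)
n_draws, phi = 30_000, 0.5
chain = np.empty(n_draws)
chain[0] = rng.normal()
for t in range(1, n_draws):
    chain[t] = phi * chain[t - 1] + rng.normal()

ess = effective_sample_size(chain)
mcse = chain.std() / np.sqrt(ess)  # Monte Carlo standard error of the chain mean
print(f"ESS: {ess:.0f} (theory: {n_draws * (1 - phi) / (1 + phi):.0f}), MCSE: {mcse:.4f}")
```

Applied column by column to `v_samples_blackjax_nuts_adapted`, the same function would give one ESS per utility.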
```python
param_names = [r"$v_1$", r"$v_2$", r"$v_3$", r"$v_4$", r"$v_5$"]
df = pd.DataFrame(v_samples_blackjax_nuts_adapted, columns=param_names)
df_long = df.melt(
    var_name="Parameter",
    value_name="Posterior draw"
)
# Force the ordering explicitly
df_long["Parameter"] = pd.Categorical(
    df_long["Parameter"],
    categories=param_names,
    ordered=True
)

fig, ax = plt.subplots(figsize=(13, 4))
sns.histplot(
    data=df_long,
    x="Posterior draw",
    hue="Parameter",
    hue_order=param_names,
    bins=250,
    kde=True,
    stat="density",
    common_norm=False,
    #element="step",
    fill=True,
    ax=ax,
)
ax.set_title("Posterior distributions from BlackJax NUTS")
ax.set_xlabel(r"Utility value relative to $v_0=0$")
ax.set_ylabel("Posterior density")
sns.move_legend(
    ax,
    "center left",
    bbox_to_anchor=(1.01, 0.5),
    title="Parameter",
    frameon=False,
)
fig.tight_layout()
plt.show()
```
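We recover well-behaved posterior distributions whose means and standard deviations closely match the frequentist estimates: the posterior means differ from the MLE by about 0.001, and the posterior standard deviations differ from the standard errors by at most about 0.0003. This is expected here: with $N = 10{,}000$ observations and a diffuse prior ($\sigma = 10$), the likelihood dominates the prior, and the posterior is approximately normal, centered at the MLE with covariance close to the inverse Hessian (the Bernstein-von Mises result). Comparing the posterior with a single "true" $v$ is less natural in the Bayesian framework, since it treats $v$ as random and recovers a full posterior distribution rather than a fixed value; still, the posterior concentrates around the values used to simulate the data.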
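To put numbers on this agreement, the sketch below uses only values printed earlier in this post and expresses the differences in units of the frequentist standard errors (NumPy; variable names are illustrative):

```python
import numpy as np

# Values printed earlier in this post
v_mle = np.array([0.962938, 2.0065932, 2.1694036, 2.479695, 1.4999783])
se_mle = np.array([0.07075374, 0.06410967, 0.06353633, 0.06266137, 0.06656851])
post_mean = np.array([0.96365964, 2.0075467, 2.170246, 2.48079, 1.5006465])
post_sd = np.array([0.07085792, 0.06392232, 0.06330264, 0.06249685, 0.0662698])
v_true = np.array([1.0, 2.0, 2.2, 2.5, 1.5])  # inside utilities used in the data-generating process

mean_gap_in_se = (post_mean - v_mle) / se_mle  # posterior mean vs MLE, in SE units
sd_ratio = post_sd / se_mle                    # posterior sd relative to the frequentist SE
z_true = (v_mle - v_true) / se_mle             # MLE distance from the DGP values, in SE units

print("mean gap (SE units):", mean_gap_in_se.round(3))
print("sd / SE ratio:      ", sd_ratio.round(3))
print("MLE - truth (SE):   ", z_true.round(2))
```

The posterior means sit a few hundredths of a standard error from the MLE, and both estimates lie within one standard error of the simulated values.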