Bayesian Econometrics for Empirical IO -- A Journey

Episode 2: Multinomial Discrete Choice

In this episode, I apply the frequentist and Bayesian approaches to a very basic discrete choice problem: no prices, no endogeneity, no random coefficients. Instead, we explore one of the benefits that come with the Bayesian approach: a lighter assumption on the distribution of the unobserved taste shocks.

Discrete Choice Basics

A crowd of individuals can choose between $J$ options or an outside option. The choice set is $\mathcal{J} = \{0, 1, \dots, J\}$. Choosing option $j$ brings utility: \begin{align*} u_{ij} &= v_j + \varepsilon_{ij} \end{align*}

Where $v_j$ can be thought of as the objective/systematic value of alternative $j$, while $\varepsilon_{ij}$ reflects the taste shock encountered by individual $i$ when facing option $j$.

The Frequentist Approach: Maximum Likelihood Estimation

The approach typically followed consists of:

  • Making an assumption on the distribution of the unobserved component $\varepsilon_{ij}$
  • Writing the likelihood function and maximizing it to recover the vector of systematic utilities $\widehat{v}$

Assumptions are generally:

  • Logit: $\varepsilon_{ij} \sim_{\text{iid}} \text{Gumbel}$
  • Probit: $\varepsilon_{ij} \sim_{\text{iid}} \mathcal{N}(0, 1)$
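
To see what the distributional assumption changes in practice, here is a minimal simulation sketch in plain NumPy (not the JAX code used in the rest of this post). The vector `v_demo`, the sample size `n_demo`, and the seed are illustrative assumptions. It draws both kinds of shocks for the same systematic utilities and compares the implied choice shares.

```python
import numpy as np

rng = np.random.default_rng(0)          # illustrative seed (assumption)
n_demo = 200_000                        # illustrative number of simulated individuals
v_demo = np.array([0.0, 1.0, 2.0, 2.2, 2.5, 1.5])  # first entry: outside option

def simulated_shares(shocks):
    """Share of individuals choosing each alternative, given an (n, J+1) array of shocks."""
    choices = np.argmax(v_demo[None, :] + shocks, axis=1)
    return np.bincount(choices, minlength=v_demo.size) / shocks.shape[0]

shares_logit = simulated_shares(rng.gumbel(size=(n_demo, v_demo.size)))
shares_probit = simulated_shares(rng.standard_normal(size=(n_demo, v_demo.size)))

print("logit shocks :", np.round(shares_logit, 3))
print("probit shocks:", np.round(shares_probit, 3))
```

The standard Gumbel distribution has variance $\pi^2/6$ and a heavier right tail than $\mathcal{N}(0, 1)$, so for the same $v$ the low-value alternatives (here the outside option) are chosen more often under logit shocks than under standard normal shocks.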
In [1]:
import jax
import jax.numpy as jnp
import blackjax #Bayesian samplers
import optax #optimization library for JAX

import math
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
In [2]:
#Data Generating Process

N = 10_000 #10,000 individuals: the data exercise we're about to simulate has little identifying variation (only the market shares)
            #So we need to have a lot of observations so each alternative gets chosen enough.
J = 5 #5 alternatives

key = jax.random.key(123)
eps_logit = jax.random.gumbel(key, shape=(N,J+1))

v = jnp.array([0.0, 1.0, 2.0, 2.2, 2.5, 1.5])[None, :] #we specify this as our vector of systematic utilities.
                                                         #the first entry is for the outside option

U_logit = v + eps_logit

### Our choice data is a length-N vector (N = 10,000) containing the choice recorded for each individual.
### We assume implicitly that all agents know the set of options available to them perfectly.
### The econometrician will only observe the choice data.
choice_logit = jnp.argmax(U_logit, axis=1)

### For computational ease, we also store each choice as an indicator vector rather than an index: one-hot encoding
choice_logit_oh = jnp.eye(J+1)[choice_logit]

Under the logit shocks, a well-known result allows us to write the probability that any individual $i$ chooses alternative $j$ in closed form: \begin{align*} \text{Pr}(i \text{ chooses } j) \equiv s_j &= \frac{\exp(v_j)}{1 + \sum_{k=1}^J \exp(v_k)} \end{align*}
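
As a quick numerical check of this formula, the sketch below (plain NumPy; `v_demo` and `logit_shares` are illustrative names, not part of the notebook) computes the shares and verifies two properties: they sum to one, and $\log(s_j) - \log(s_0) = v_j$, so the systematic utilities can be read directly off the shares.

```python
import numpy as np

def logit_shares(v_inside):
    """Choice probabilities (s_0, s_1, ..., s_J) when v_0 is normalized to 0."""
    v_full = np.concatenate([[0.0], np.asarray(v_inside, dtype=float)])
    expv = np.exp(v_full - v_full.max())   # shift by the max for numerical stability
    return expv / expv.sum()

v_demo = np.array([1.0, 2.0, 2.2, 2.5, 1.5])   # illustrative inside-option utilities
s = logit_shares(v_demo)

print("shares:", np.round(s, 4))
print("sum of shares:", s.sum())
print("log(s_j) - log(s_0):", np.log(s[1:]) - np.log(s[0]))
```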

Where we have normalized the utility of the outside option to $0$: $v_0 = 0$. Economically, $s_j$ is the market share of alternative $j$. This allows us to write the likelihood function. Write $d_i$ for the recorded choice of agent $i$. The MLE solves: \begin{align*} &\max_v \frac{1}{N} \sum_{i=1}^N \log\Big(\text{Pr}(i \text{ chooses } d_i \mid v)\Big)\\ = &\max_v \frac{1}{N} \sum_{i=1}^N \log\Bigg(\frac{\exp(v_{d_i})}{1 + \sum_{k=1}^J \exp(v_k)}\Bigg) \end{align*}
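
Because every individual faces the same $v$, this objective depends on the data only through how often each alternative is chosen. The sketch below (plain NumPy, with made-up choices `d_demo` and an arbitrary `v_demo`, both assumptions for illustration) checks that the per-individual average equals the share-weighted sum $\sum_j \frac{n_j}{N} \log s_j$, where $n_j$ counts the individuals choosing $j$.

```python
import numpy as np

rng = np.random.default_rng(1)                      # illustrative seed
v_demo = np.array([0.0, 0.5, -0.3, 1.2])            # arbitrary utilities, v_0 = 0 first
d_demo = rng.integers(0, v_demo.size, size=1_000)   # made-up recorded choices d_i

log_s = v_demo - np.log(np.exp(v_demo).sum())       # log choice probabilities

# Objective as written: average over individuals of log Pr(i chooses d_i | v)
obj_individual = np.mean(log_s[d_demo])

# Same objective computed from the empirical shares n_j / N
emp_shares = np.bincount(d_demo, minlength=v_demo.size) / d_demo.size
obj_shares = np.sum(emp_shares * log_s)

print(obj_individual, obj_shares)
```

This is why the comment in the data generating cell says that market shares are the only identifying variation.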

In [3]:
@jax.jit
def neg_log_likelihood_logit(v, choices_oh):
    v = jnp.concatenate([jnp.array([0.0]), v]) #add the outside option
    log_probs = jax.nn.log_softmax(v)
    return -jnp.mean(jnp.sum(choices_oh * log_probs, axis=-1))
In [4]:
def optimize_adam(
    f,
    init_params,
    args=(),
    kwargs=None,
    learning_rate=1e-3,
    max_steps=1_000,
    grad_tol=1e-6,
    return_history=True,
):
    if kwargs is None:
        kwargs = {}

    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(init_params)
    params = init_params
    history = []

    loss_and_grad_fn = jax.value_and_grad(lambda p: f(p, *args, **kwargs))
    
    @jax.jit
    def update_step(params, opt_state):
        loss, grads = loss_and_grad_fn(params)
        grad_norm = jnp.linalg.norm(grads)
        updates, new_opt_state = optimizer.update(grads, opt_state, params)
        new_params = optax.apply_updates(params, updates)
        return new_params, new_opt_state, loss, grad_norm

    for step in range(max_steps):
        params, opt_state, loss, grad_norm = update_step(params, opt_state)

        if return_history:
            history.append((loss, grad_norm))

        if grad_norm <= grad_tol:
            break

    if return_history:
        return params, jnp.array(history)

    return params
In [5]:
init_v = jnp.ones(J)

v_hat, history = optimize_adam(
    neg_log_likelihood_logit,
    init_v,
    args=(choice_logit_oh,),
    learning_rate=0.05,
    max_steps=10_000,
)
print(f'v MLE obtained by numerical optimization: {v_hat}')
v MLE obtained by numerical optimization: [0.962938  2.0065932 2.1694036 2.479695  1.4999783]
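
In this intercept-only model, the numerical result can be cross-checked without an optimizer: the likelihood is maximized when predicted shares equal observed shares, which gives $\widehat{v}_j = \log(\widehat{s}_j) - \log(\widehat{s}_0)$, with $\widehat{s}_j$ the observed share of alternative $j$. The sketch below simulates its own data in plain NumPy (the seed and variable names are assumptions, so its numbers will differ slightly from the JAX run above) and verifies that the gradient of the mean log-likelihood vanishes at this closed-form value.

```python
import numpy as np

rng = np.random.default_rng(123)                 # illustrative seed, not the JAX key above
N_demo, J_demo = 10_000, 5
v_true = np.array([0.0, 1.0, 2.0, 2.2, 2.5, 1.5])

# Simulate choices under logit shocks
utilities = v_true[None, :] + rng.gumbel(size=(N_demo, J_demo + 1))
choices = np.argmax(utilities, axis=1)
emp_shares = np.bincount(choices, minlength=J_demo + 1) / N_demo

# Closed-form MLE: log share ratio relative to the outside option
v_closed = np.log(emp_shares[1:]) - np.log(emp_shares[0])

# Gradient of the mean log-likelihood w.r.t. v_j is (observed share_j - predicted share_j)
expv = np.exp(np.concatenate([[0.0], v_closed]))
pred_shares = expv / expv.sum()
grad = emp_shares[1:] - pred_shares[1:]

print("closed-form MLE:", np.round(v_closed, 3))
print("max |gradient|:", np.abs(grad).max())
```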

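Now we compute the standard errors, where $d$ below stands for the choice data and $\mathcal{L}(v; d)$ for the log-likelihood. \begin{align*} \text{SE}(\widehat{v}) &= \sqrt{\text{diag}\Big(\widehat{\text{Var}}(\widehat{v})\Big)}\\ &= \sqrt{\text{diag}\Bigg(\Bigg(-\frac{\partial^2 \mathcal{L}(v; d)}{\partial v \, \partial v'}\Big|_{v = \widehat{v}}\Bigg)^{-1}\Bigg)} \end{align*} Since the code works with the mean negative log-likelihood, its Hessian equals $-\frac{1}{N}$ times the Hessian of $\mathcal{L}$, hence the division by $N$ below.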

In [6]:
H = jax.hessian(
    lambda theta: neg_log_likelihood_logit(theta, choice_logit_oh)
)(v_hat)

cov_hat = jnp.linalg.inv(H) / N
se_hat = jnp.sqrt(jnp.diag(cov_hat))

print("SE(v_hat):", se_hat)
SE(v_hat): [0.07075374 0.06410967 0.06353633 0.06266137 0.06656851]
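
For the logit model, the Hessian also has a simple analytical form, which gives a sanity check on the automatic-differentiation result. The per-observation information matrix for the inside options is $\text{diag}(s) - ss'$, and its inverse is $\text{diag}(1/s) + \frac{1}{s_0}\mathbf{1}\mathbf{1}'$, so $\text{SE}(\widehat{v}_j) = \sqrt{(1/s_j + 1/s_0)/N}$. The sketch below evaluates this at the true $v$ from the data generating process rather than at $\widehat{v}$ (an assumption made to keep the block self-contained), so it should land close to, not exactly on, the numbers printed above.

```python
import numpy as np

N_demo = 10_000
v_true = np.array([0.0, 1.0, 2.0, 2.2, 2.5, 1.5])   # outside option first
s = np.exp(v_true) / np.exp(v_true).sum()            # logit shares
s_in, s_out = s[1:], s[0]

# Fisher information of the full sample for (v_1, ..., v_J)
info = N_demo * (np.diag(s_in) - np.outer(s_in, s_in))
se_matrix = np.sqrt(np.diag(np.linalg.inv(info)))

# Closed form obtained from the Sherman-Morrison formula
se_closed = np.sqrt((1.0 / s_in + 1.0 / s_out) / N_demo)

print("SE from matrix inverse:", np.round(se_matrix, 4))
print("SE from closed form   :", np.round(se_closed, 4))
```

The common $1/s_0$ term shows why a rarely chosen outside option inflates every standard error: each $v_j$ is measured relative to $v_0$.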

The Bayesian Approach: NUTS

Since we already explored many different samplers in the first episode of this series, this time I will focus on only one sampler: NUTS, with its step size and mass matrix tuned during warmup. First, we need to write the log-density of the posterior distribution. Let's specify a wide, independent prior on each $v_j$, $j = 1, \dots, J$: \begin{align*} v_j \sim \mathcal{N}(0, \sigma^2) \qquad \sigma^2=100 \end{align*}

Using Bayes' law: \begin{align*} \pi(v|d) &\propto L(v; d) \pi(v) \end{align*}

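So, writing $\mathcal{L}(v; d) = \log L(v; d)$ for the log-likelihood, the log-density is, up to an additive constant: \begin{align*} \log(\pi(v|d)) = \mathcal{L}(v; d) + \sum_{j=1}^{J} \log(\pi(v_j)) + \text{const} \end{align*}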

Where: \begin{align*} \pi(v_j) &= \frac{1}{\sqrt{2\pi\sigma^2}}e^{-\frac{v_j^2}{2\sigma^2}}\\ \log(\pi(v_j)) &= -\frac{1}{2}\log(2\pi\sigma^2) - \frac{v_j^2}{2\sigma^2}\\ \sum_{j=1}^{J} \log(\pi(v_j)) &= -\frac{J}{2}\log(2\pi\sigma^2) - \frac{1}{2} \sum_{j=1}^{J} \Big(\frac{v_j}{\sigma} \Big)^2 \end{align*}

Note that, just as we drop the denominator in Bayes' law, the constant term here does not depend on $v$, so we drop it from our function.
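
A quick way to convince yourself that the constant is harmless: the full Gaussian log-prior and the version without the constant differ by the same amount at every $v$, so comparisons between any two points are unaffected. A minimal sketch in plain NumPy, with arbitrary test points and illustrative function names (assumptions):

```python
import numpy as np

sigma = 10.0                                   # prior standard deviation (sigma^2 = 100)

def log_prior_full(v):
    """Sum of N(0, sigma^2) log-densities, normalizing constant included."""
    return np.sum(-0.5 * np.log(2.0 * np.pi * sigma**2) - v**2 / (2.0 * sigma**2))

def log_prior_kernel(v):
    """Same log-prior with the constant dropped."""
    return -0.5 * np.sum((v / sigma) ** 2)

v_a = np.array([1.0, 2.0, 2.2, 2.5, 1.5])      # arbitrary test point
v_b = np.array([0.0, -3.0, 4.0, 0.5, 10.0])    # another arbitrary test point

gap_a = log_prior_full(v_a) - log_prior_kernel(v_a)
gap_b = log_prior_full(v_b) - log_prior_kernel(v_b)
print(gap_a, gap_b)   # equal up to floating-point rounding: the gap does not depend on v
```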

In [7]:
@jax.jit
def log_posterior_v(v, X, sigma_prior):
    
    log_likelihood = -N*neg_log_likelihood_logit(v, choices_oh=X) #multiplied by -N because our original 
                                                                  #function returns the mean negative log-likelihood
    log_prior = -0.5 * jnp.sum((v / sigma_prior) ** 2)

    return log_likelihood + log_prior
In [8]:
def run_blackjax_chain(rng_key, kernel, initial_state, num_samples):
    keys = jax.random.split(rng_key, num_samples)
    @jax.jit
    def inference_loop(state, keys):
        def one_step(state, key):
            state, info = kernel(key, state)
            return state, (state.position, info)

        return jax.lax.scan(one_step, state, keys)

    final_state, (positions, infos) = inference_loop(initial_state, keys)
    return positions, infos, final_state
In [9]:
@jax.jit
def logdensity_fn(v):
    return log_posterior_v(v, choice_logit_oh, 10)
In [10]:
key_warmup, key_sample = jax.random.split(key)

num_warmup = 1_000
num_draws = 30_000

adapt = blackjax.window_adaptation(
    blackjax.nuts,
    logdensity_fn,
    target_acceptance_rate=0.80,
)

adaptation_result, warmup_info = adapt.run(
    key_warmup,
    init_v,
    num_warmup,
)

adapted_state = adaptation_result.state
tuned_parameters = adaptation_result.parameters

print("Tuned parameters:")
print(tuned_parameters)

blackjax_nuts_adapted = blackjax.nuts(
    logdensity_fn,
    **tuned_parameters,
)

v_samples_blackjax_nuts_adapted, info_blackjax_nuts_adapted, final_state_blackjax_nuts_adapted = run_blackjax_chain(
    key_sample,
    blackjax_nuts_adapted.step,
    adapted_state,
    num_draws,
)
Tuned parameters:
{'step_size': Array(0.26106367, dtype=float32, weak_type=True), 'inverse_mass_matrix': Array([0.00464145, 0.00419285, 0.00378266, 0.00384861, 0.00468612],      dtype=float32)}
In [11]:
print(jnp.mean(v_samples_blackjax_nuts_adapted, axis=0))
print(jnp.std(v_samples_blackjax_nuts_adapted, axis=0))
[0.96365964 2.0075467  2.170246   2.48079    1.5006465 ]
[0.07085792 0.06392232 0.06330264 0.06249685 0.0662698 ]
In [12]:
param_names = [r"$v_1$", r"$v_2$", r"$v_3$", r"$v_4$", r"$v_5$"]

df = pd.DataFrame(v_samples_blackjax_nuts_adapted, columns=param_names)

df_long = df.melt(
    var_name="Parameter",
    value_name="Posterior draw"
)

# Force the ordering explicitly
df_long["Parameter"] = pd.Categorical(
    df_long["Parameter"],
    categories=param_names,
    ordered=True
)

fig, ax = plt.subplots(figsize=(13, 4))

sns.histplot(
    data=df_long,
    x="Posterior draw",
    hue="Parameter",
    hue_order=param_names,
    bins=250,
    kde=True,
    stat="density",
    common_norm=False,
    #element="step",
    fill=True,
    ax=ax,
)

ax.set_title("Posterior distributions from BlackJax NUTS")
ax.set_xlabel(r"Utility value relative to $v_0=0$")
ax.set_ylabel("Posterior density")

sns.move_legend(
    ax,
    "center left",
    bbox_to_anchor=(1.01, 0.5),
    title="Parameter",
    frameon=False,
)

fig.tight_layout()
plt.show()
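[Figure: Posterior distributions from BlackJax NUTS. Overlaid histograms with density curves for $v_1$ to $v_5$; x-axis: utility value relative to $v_0=0$; y-axis: posterior density.]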

We recover well-behaved posterior distributions that line up almost exactly with the frequentist estimates. In the Bayesian framework, the notion of a "true" $v$ is a bit meaningless, since estimation recovers the posterior distribution of $v$ rather than a point estimate of some true, fixed parameter value. Still, the summaries are directly comparable: the posterior means differ from the MLE by about 0.001 (for example, 0.9637 against 0.9629 for $v_1$), and the posterior standard deviations differ from the standard errors by less than 0.001 (0.0709 against 0.0708 for $v_1$).
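
One way to see why the two approaches agree so closely here: with $N = 10{,}000$ observations and a prior variance of $100$, the prior adds a precision of only $1/\sigma^2 = 0.01$ per parameter, which is negligible next to the information in the data. The sketch below uses a Gaussian (Laplace) approximation to the posterior, which is a stand-in for the NUTS run rather than a reproduction of it, and evaluates the curvature at the true $v$ to stay self-contained (an assumption). It compares the approximate posterior standard deviation with and without the prior.

```python
import numpy as np

N_demo, sigma_prior = 10_000, 10.0
v_true = np.array([0.0, 1.0, 2.0, 2.2, 2.5, 1.5])   # outside option first
s = np.exp(v_true) / np.exp(v_true).sum()
s_in = s[1:]

# Curvature of the negative log-likelihood (data information)
info_data = N_demo * (np.diag(s_in) - np.outer(s_in, s_in))
# Curvature of the negative log-prior for independent N(0, sigma^2) priors
info_prior = np.eye(s_in.size) / sigma_prior**2

sd_likelihood_only = np.sqrt(np.diag(np.linalg.inv(info_data)))
sd_posterior = np.sqrt(np.diag(np.linalg.inv(info_data + info_prior)))

print("likelihood only:", np.round(sd_likelihood_only, 5))
print("with prior     :", np.round(sd_posterior, 5))
print("max relative gap:", np.max(1 - sd_posterior / sd_likelihood_only))
```

With a much smaller sample or a much tighter prior, the two rows would separate and the posterior would no longer mirror the MLE.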