Source: christopherkrapu.com

Poverty Bayes: fitting million-parameter models for pennies with serverless MCMC

A Bayesian statistician demonstrates how to fit million-parameter models on serverless GPUs for pennies, sharply reducing the cost and hassle of Markov chain Monte Carlo (MCMC) inference. The workflow, demonstrated on a hierarchical logistic regression with 1,000,000 observations, 100,000 groups, and 20 covariates, runs on datacenter GPUs rented from Modal instead of requiring dedicated hardware. This makes large-scale Bayesian inference accessible to researchers who would otherwise have to manage their own GPU infrastructure.

9 min read · published May 27, 2026

It’s a good time to be an applied probabilist. The deep learning revolution has led to tremendous improvements in the $ / flops department, and we Bayesians can easily hop on this train! During grad school, I used to spend nights and weekends babysitting MCMC runs on my GeForce Titan XP running in my bedroom (by the way, thank you NVIDIA Academic Grant Program) while simultaneously trying to keep the waste heat from cooking me as I slept. If you are a newcomer to this field, rejoice in the knowledge that all this suffering is a thing of the past. A slew of companies are rushing to the fore with user-friendly platforms for renting GPUs. For prototyping, I really enjoy working with Modal since I’m cheap and I’m too lazy to keep managing my own fleet.

In this post, I’ll show a workflow for using GPU-based inference on Modal for a model that is very large by the standards of Bayesian statisticians (\(\vert \theta \vert > 10^6\)), by deploying to a datacenter GPU and renting it only for a short time.

Model & data

We’ll use synthetic data for this example.

I’ve chosen a hierarchical logistic regression for this post since it has a non-conjugate likelihood, appears commonly in practice, and can easily be assigned more parameters by increasing the number of covariates and/or the number of groups.

Let \(i\), \(g\), and \(k\) index observations, groups, and covariates, respectively. Furthermore, let \(x_i \in \mathbb{R}^{20}\) be the covariate vector for observation \(i\), and let \(g_i \in \{1,\ldots,100000\}\) identify its group. The data-generating process uses a global intercept \(\alpha\), population-level slopes \(\beta_k\), group intercept deviations \(\alpha_g\), and group slope deviations \(\gamma_{gk}\). The binary outcome is generated from

\[
Y_i \sim \operatorname{Bernoulli}(p_i), \qquad \operatorname{logit}(p_i) = \alpha + \alpha_{g_i} + \sum_{k=1}^{K} x_{ik}\left(\beta_k + \gamma_{g_i k}\right).
\]

Essentially, this is a logistic regression with random slopes for 20 covariates and a random intercept for each of 100,000 groups in the data.

We’ll use a non-centered parameterization for the group effects. The prior specification is

\[
\begin{aligned}
\alpha &\sim \operatorname{Normal}(0, 1.5), \\
\beta_k &\sim \operatorname{Normal}(0, 1), \\
\sigma_\alpha &\sim \operatorname{HalfNormal}(1), \\
\sigma_{\gamma,k} &\sim \operatorname{HalfNormal}(0.5), \\
z_{\alpha,g} &\sim \operatorname{Normal}(0, 1), \\
z_{\gamma,gk} &\sim \operatorname{Normal}(0, 1), \\
\alpha_g &= \sigma_\alpha z_{\alpha,g}, \\
\gamma_{gk} &= \sigma_{\gamma,k} z_{\gamma,gk}.
\end{aligned}
\]

The code below produces a synthetic dataset; we can control the overall sparsity of the response (the fraction of ones in y) with the value of α_true.

import numpy as np

RANDOM_SEED = 827
rng = np.random.default_rng(RANDOM_SEED)

N = 1_000_000 # Number of data points
G = 100_000   # Number of groups
K = 20      # Number of covariates / features

group_idx = rng.integers(0, G, size=N, dtype=np.int64)
X = rng.normal(size=(N, K)).astype(np.float32)

α_true = np.float32(-1.0)
β_true = rng.normal(0.0, 0.45, size=K).astype(np.float32)
σ_α_true = np.float32(0.80)
σ_γ_true = rng.uniform(0.15, 0.35, size=K).astype(np.float32)
α_group_true = rng.normal(0.0, σ_α_true, size=G).astype(np.float32)
γ_group_true = rng.normal(0.0, σ_γ_true, size=(G, K)).astype(np.float32)

η = α_true + α_group_true[group_idx] + np.sum(X * (β_true + γ_group_true[group_idx]), axis=1)
p = 1.0 / (1.0 + np.exp(-η))
y = rng.binomial(1, p).astype(np.int64)

print(f"Average of y: {np.mean(y):.2f}")
Average of y: 0.36
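The simulation above draws group effects directly from Normal(0, σ) (the centered form), whereas the model below follows the non-centered prior and samples standard-normal z values that are then scaled by σ. Here is a minimal sketch checking that the two constructions agree in distribution; σ = 0.8 and the sample size are illustrative choices, not part of the original workflow.

```python
import numpy as np

rng = np.random.default_rng(0)
σ = 0.8        # illustrative group-level scale (same value as σ_α_true above)
n = 200_000    # illustrative number of simulated group effects

# Centered: draw the group effect directly at scale σ.
centered = rng.normal(0.0, σ, size=n)

# Non-centered: draw a standard normal, then scale it by σ.
z = rng.normal(0.0, 1.0, size=n)
non_centered = σ * z

# Both standard deviations should be close to 0.8.
print(f"centered sd={centered.std():.3f}, non-centered sd={non_centered.std():.3f}")
```

The distributions match; the benefit of the non-centered form is geometric. The sampler moves in z-space, whose prior scale does not shrink with σ, which typically avoids the funnel-shaped posterior that centered hierarchical models produce when each group has limited data (here about 10 observations per group on average).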

Here, we define our model in PyMC. This is a fairly standard model definition without much nuance or many tricks.

import pymc as pm
import pytensor
import pytensor.tensor as pt

pytensor.config.floatX = "float32"

def build_model(X, y, group_idx):

    X = np.asarray(X, dtype=np.float32)
    y = np.asarray(y, dtype=np.int64)
    group_idx = np.asarray(group_idx, dtype=np.int64)

    N, K = X.shape
    G = int(group_idx.max()) + 1
    coords = {
        "obs": np.arange(N),
        "group": np.arange(G),
        "covariate": np.arange(K),
    }

    with pm.Model(coords=coords) as model:
        X_data = pm.Data("X", X, dims=("obs", "covariate"))
        group_lookup = pt.as_tensor_variable(group_idx, name="group_lookup")

        α = pm.Normal("α", np.float32(0.0), np.float32(1.5))
        β = pm.Normal("β", np.float32(0.0), np.float32(1.0), dims="covariate")
        σ_α = pm.HalfNormal("σ_α", np.float32(1.0))
        σ_γ = pm.HalfNormal("σ_γ", np.float32(0.5), dims="covariate")

        z_α = pm.Normal("z_α", np.float32(0.0), np.float32(1.0), dims="group")
        z_γ = pm.Normal("z_γ", np.float32(0.0), np.float32(1.0), dims=("group", "covariate"))
        α_group = σ_α * z_α 
        γ_group = σ_γ * z_γ

        η = α + α_group[group_lookup] + pm.math.sum(X_data * (β + γ_group[group_lookup]), axis=1)
        pm.Bernoulli("Y", logit_p=η, observed=y, dims="obs")

    return model
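As a sanity check on the \(\vert \theta \vert > 10^6\) claim, here is a hypothetical helper, count_free_parameters, that tallies the sampled parameters implied by build_model from the array shapes alone. It is plain Python and does not call PyMC; the positive scales are counted once each even though PyMC samples them on the log scale.

```python
def count_free_parameters(G: int, K: int) -> dict:
    """Count sampled parameters in build_model for G groups and K covariates (illustrative)."""
    sizes = {
        "α": 1,         # global intercept
        "β": K,         # population-level slopes
        "σ_α": 1,       # sampled as σ_α_log__ on the unconstrained scale
        "σ_γ": K,       # sampled as σ_γ_log__ on the unconstrained scale
        "z_α": G,       # non-centered group intercepts
        "z_γ": G * K,   # non-centered group slopes
    }
    sizes["total"] = sum(sizes.values())
    return sizes


counts = count_free_parameters(G=100_000, K=20)
print(f"{counts['total']:,}")  # 2,100,042
```

The z_γ block alone contributes two million parameters, so the model’s size is driven almost entirely by G × K.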

To illustrate why this model could be challenging to fit using only a CPU, we profile the gradient evaluation time. The dlogp function computes \(\frac{\partial}{\partial \theta}\left[\log \tilde{p}(\theta)\right]\), where \(\tilde{p}\) denotes the unnormalized posterior density. It may be evaluated up to roughly 1000 times per sample, with the count increasing with the difficulty of the problem in terms of posterior geometry and/or nonlinearity. Hamiltonian Monte Carlo, the workhorse of most modern Bayesian programming frameworks, requires these gradient evaluations to draw Monte Carlo samples.
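To see where those gradient evaluations go, here is a minimal leapfrog integrator, the inner loop of HMC, on a toy standard-normal target. CountingGradient and leapfrog are illustrative names, not PyMC or NumPyro internals. Each leapfrog step costs one gradient evaluation, and NUTS with its usual default maximum tree depth of 10 can take up to 2^10 - 1 = 1023 steps per draw, which is roughly where the "up to 1000 times per sample" figure comes from.

```python
import numpy as np


class CountingGradient:
    """Wrap a gradient function and count how many times it is evaluated."""

    def __init__(self, grad_fn):
        self.grad_fn = grad_fn
        self.calls = 0

    def __call__(self, theta):
        self.calls += 1
        return self.grad_fn(theta)


def leapfrog(theta, momentum, grad_logp, step_size, n_steps):
    """Leapfrog integration of Hamiltonian dynamics for a target with gradient grad_logp.

    Costs n_steps + 1 gradient evaluations: one at the starting point, then one
    per step (each end-of-step gradient is reused for the next momentum update).
    """
    grad = grad_logp(theta)
    momentum = momentum + 0.5 * step_size * grad       # half step for momentum
    for step in range(n_steps):
        theta = theta + step_size * momentum           # full step for position
        grad = grad_logp(theta)                         # one gradient per step
        if step < n_steps - 1:
            momentum = momentum + step_size * grad      # full step for momentum
    momentum = momentum + 0.5 * step_size * grad        # final half step for momentum
    return theta, momentum


# Toy target: a standard normal, so grad log p(theta) = -theta.
grad_logp = CountingGradient(lambda theta: -theta)
rng = np.random.default_rng(0)
theta0, p0 = rng.normal(size=3), rng.normal(size=3)
theta1, p1 = leapfrog(theta0, p0, grad_logp, step_size=0.1, n_steps=10)
print(grad_logp.calls)  # 11
```

Real samplers usually cache the gradient at the start of a trajectory, so the cost is close to one gradient per leapfrog step. With a 0.2 second gradient, even modest trajectories add up quickly.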

The code cell below evaluates the gradient a few times and records the median wall-clock time.

import time

def median_runtime(fn, n=5):
    fn()
    times = []
    for _ in range(n):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return float(np.median(times)), times

local_model = build_model(X, y, group_idx)
local_point = local_model.initial_point()

with local_model:
    dlogp_fn = local_model.compile_dlogp()

local_grad_median, local_grad_times = median_runtime(lambda: dlogp_fn(local_point))
print(f"Median grad eval time is {local_grad_median:.2f} seconds")
Median grad eval time is 0.20 seconds

If we extrapolate to a full chain of 1000 samples and assume ~250 gradient evaluations per sample, we would need 50 seconds per sample and 50,000 seconds (~14 hours) to run the chain, and that is before counting the tuning phase! This simply won’t do.
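The back-of-the-envelope estimate above, written as a tiny helper. extrapolate_chain_time is an illustrative name; it assumes gradient evaluations dominate the run time and ignores tuning draws.

```python
def extrapolate_chain_time(grad_seconds, grads_per_draw, n_draws):
    """Rough wall-clock estimate (seconds, hours) for one chain, assuming gradients dominate the cost."""
    seconds = grad_seconds * grads_per_draw * n_draws
    return seconds, seconds / 3600


seconds, hours = extrapolate_chain_time(grad_seconds=0.20, grads_per_draw=250, n_draws=1000)
print(f"{seconds:,.0f} s, about {hours:.1f} h")  # 50,000 s, about 13.9 h
```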

Let’s deploy this remotely and see what we get!

Running remotely on Modal

For those of you unfamiliar with the magical world of serverless GPU, please take a look at Modal. I will probably never be wealthy enough to outright own a beautiful server rack of datacenter GPUs, so this is the best I can get. Basically, Modal provides APIs for spinning up jobs on GPUs with very little friction and little downtime.

To make this example run nicely on a GPU, we’ll make a few adjustments. First, we’ll build a container image that switches to the JAX + NumPyro backend, which works out of the box with an NVIDIA GPU.

import modal

modal_image = (
    modal.Image.debian_slim(python_version="3.13")
    .uv_pip_install(
        "arviz==1.1.0",
        "jax[cuda12]==0.7.2",
        "numpy==2.3.5",
        "numpyro==0.19.0",
        "pandas==2.3.3",
        "pymc==6.0.1",
    )
)

app = modal.App("i-should-be-working-right-now", image=modal_image)

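Next, we’ll set an environment variable so that JAX targets the CUDA devices, set PyTensor’s float precision, define a helper to benchmark execution time, and apply some JAX-isms (jit, value_and_grad, block_until_ready) to compile the compute graph and evaluate it. The final function runs the full NUTS sampler on the GPU through NumPyro.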

def remote_model(X, y, group_idx):
    import os

    os.environ["JAX_PLATFORMS"] = "cuda"

    import jax
    import jax.numpy as jnp
    import numpy as np
    import pytensor

    pytensor.config.floatX = "float32"
    X = np.asarray(X, dtype=np.float32)
    y = np.asarray(y, dtype=np.int64)
    group_idx = np.asarray(group_idx, dtype=np.int64)

    gpu_devices = jax.devices("gpu")
    gpu_check = jax.device_put(jnp.ones((512, 512), dtype=jnp.float32), gpu_devices[0]).sum().block_until_ready()
    model = build_model(X, y, group_idx)

    info = {
        "N": X.shape[0],
        "G": int(group_idx.max()) + 1,
        "K": X.shape[1],
        "jax_backend": jax.default_backend(),
        "jax_gpu_devices": [str(device) for device in gpu_devices],
        "gpu_check_sum": float(gpu_check),
        "x_dtype": str(X.dtype),
        "value_var_dtypes": {var.name: var.dtype for var in model.value_vars},
    }
    print(f"JAX backend: {info['jax_backend']}; GPU devices: {info['jax_gpu_devices']}")
    return model, info

def median_runtime(fn, n=5, synchronize=None):
    import time

    import numpy as np

    if synchronize is None:
        synchronize = lambda result: None

    def timed_run():
        start = time.perf_counter()
        result = fn()
        synchronize(result)
        return time.perf_counter() - start

    synchronize(fn())
    times = [timed_run() for _ in range(n)]
    return float(np.median(times)), times

def wait_jax(result):
    import jax

    jax.tree_util.tree_map(lambda x: x.block_until_ready(), result)

@app.function(gpu="A100", timeout=2 * 60 * 60)
def profile_dlogp(X, y, group_idx, n_evals=5):

    import jax
    import jax.numpy as jnp
    from pymc.sampling.jax import get_jaxified_logp

    model, info = remote_model(X, y, group_idx)
    point = model.initial_point()

    with model:
        pytensor_dlogp = model.compile_dlogp()

    pytensor_median, pytensor_times = median_runtime(lambda: pytensor_dlogp(point), n_evals)

    values = [jnp.asarray(point[var.name]) for var in model.value_vars]
    jax_loss = get_jaxified_logp(model)
    jax_dlogp = jax.jit(jax.value_and_grad(jax_loss))

    wait_jax(jax_dlogp(values))
    jax_median, jax_times = median_runtime(lambda: jax_dlogp(values), n_evals, wait_jax)

    profile = {
        "median_pytensor_dlogp_eval_seconds": pytensor_median,
        "median_jax_dlogp_eval_seconds": jax_median,
    }
    print(f"Median dlogp eval: PyTensor={pytensor_median:.3f}s, JAX={jax_median:.3f}s")
    return {**info, **profile}

@app.function(gpu="A100", timeout=2 * 60 * 60)
def fit_hierarchical_logistic(X, y, group_idx, draws=100, tune=100, random_seed=RANDOM_SEED):
    import time
    import arviz as az
    import pymc as pm

    model, info = remote_model(X, y, group_idx)

    with model:
        start = time.perf_counter()
        idata = pm.sample(
            draws,
            tune=tune,
            chains=1,
            nuts_sampler="numpyro",
            random_seed=random_seed,
            var_names=["α", "β", "σ_α", "σ_γ"],
            progressbar=False,
        )
        elapsed_seconds = time.perf_counter() - start

    β_mean = idata.posterior["β"].mean(dim=("chain", "draw")).to_numpy()

    return {
        **info,
        "elapsed_seconds": elapsed_seconds,
        "summary": az.summary(idata, var_names=["α", "β", "σ_α", "σ_γ"]),
        "β_mean": β_mean.tolist(),
    }

With all of this helper logic written, we can finally deploy to the cloud! We will start by running a short job to just profile the logp gradient function.

print("Launching Modal GPU dlogp profile")
with app.run():
    profile_result = profile_dlogp.remote(X, y, group_idx)
print(f"Finished running; profile results: {profile_result}")
Launching Modal GPU dlogp profile
Finished running; profile results: {'N': 1000000, 'G': 100000, 'K': 20, 'jax_backend': 'gpu', 'jax_gpu_devices': ['cuda:0'], 'gpu_check_sum': 262144.0, 'x_dtype': 'float32', 'value_var_dtypes': {'α': 'float32', 'β': 'float32', 'σ_α_log__': 'float32', 'σ_γ_log__': 'float32', 'z_α': 'float32', 'z_γ': 'float32'}, 'median_pytensor_dlogp_eval_seconds': 0.2271286229999987, 'median_jax_dlogp_eval_seconds': 0.0013876599999989025}

Interesting: the gradient takes about 0.0014 seconds with JAX on the GPU versus about 0.23 seconds with PyTensor on the CPU. That is roughly a 160x speedup!
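The speedup can be read straight off the profile dictionary printed above; this snippet just copies the two median timings from that output.

```python
# Median dlogp timings copied from the Modal profile output above.
profile = {
    "median_pytensor_dlogp_eval_seconds": 0.2271286229999987,
    "median_jax_dlogp_eval_seconds": 0.0013876599999989025,
}

speedup = profile["median_pytensor_dlogp_eval_seconds"] / profile["median_jax_dlogp_eval_seconds"]
print(f"Speedup: {speedup:.0f}x")  # Speedup: 164x
```

Both timings were measured inside the same Modal container, so the PyTensor number reflects that machine’s CPU rather than a local workstation.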

Next, we run the Markov chain to completion and retrieve the results.

print("Launching Modal GPU run")
with app.run():
    result = fit_hierarchical_logistic.remote(X, y, group_idx)

print(f"MCMC run finished in {result['elapsed_seconds']:.2f} seconds")
MCMC run finished in 272.16 seconds

Modal helpfully lists its prices per second; an A100 costs $0.000583 per second, so the 272 seconds of sampling cost me roughly 16 cents. Container startup and model compilation add some billed time on top of that.
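The cost arithmetic, using the A100 rate and the elapsed time quoted above. This covers only the timed pm.sample call; image startup, model compilation, and data transfer are not included.

```python
A100_PRICE_PER_SECOND = 0.000583  # USD per second, the Modal A100 rate quoted above
elapsed_seconds = 272.16          # sampling time reported by fit_hierarchical_logistic

cost = A100_PRICE_PER_SECOND * elapsed_seconds
print(f"Sampling cost: ${cost:.3f}")  # Sampling cost: $0.159
```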

Parameter Recovery

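The posterior summary shows the sampler diagnostics. We’d need to run more than one chain to get \(\hat{R}\) for this model, which is why the r_hat column in the table below is NaN. The bulk ESS values are also very low (roughly 1.5 to 2.4 from 100 draws), so a longer run would be needed before trusting these intervals.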

import pandas as pd

posterior_summary = pd.DataFrame(result["summary"])
posterior_summary.iloc[0:5]
  parameter    mean     sd  eti89_lb  eti89_ub  ess_bulk  ess_tail  r_hat  mcse_mean  mcse_sd
0         α  -0.980  0.015    -0.997    -0.952     1.510     5.445    NaN      0.012    0.008
1      β[0]  -0.275  0.005    -0.282    -0.266     2.042     5.659    NaN      0.004    0.003
2      β[1]   0.630  0.010     0.612     0.641     1.627     5.374    NaN      0.008    0.005
3      β[2]  -0.358  0.007    -0.366    -0.345     2.390     5.593    NaN      0.005    0.003
4      β[3]   0.359  0.006     0.348     0.366     1.834     5.607    NaN      0.005    0.003
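For intuition about what the missing r_hat column would measure, here is a minimal sketch of the classic split-\(\hat{R}\): split each chain in half, then compare between-chain and within-chain variance. This is the textbook version, not the rank-normalized variant that ArviZ reports, and split_rhat is an illustrative name; like ArviZ, it refuses to run on a single chain.

```python
import numpy as np


def split_rhat(draws: np.ndarray) -> float:
    """Classic split-R-hat for one scalar parameter; draws has shape (chains, n_draws)."""
    draws = np.asarray(draws, dtype=float)
    if draws.ndim != 2 or draws.shape[0] < 2:
        raise ValueError("need an array of shape (chains >= 2, n_draws)")
    half = draws.shape[1] // 2
    # Split every chain into halves so drift within a chain also inflates R-hat.
    halves = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n = halves.shape[1]
    chain_means = halves.mean(axis=1)
    between = n * chain_means.var(ddof=1)           # B: variance of the chain means, scaled by n
    within = halves.var(axis=1, ddof=1).mean()      # W: average within-chain variance
    var_plus = (n - 1) / n * within + between / n   # pooled posterior variance estimate
    return float(np.sqrt(var_plus / within))


rng = np.random.default_rng(1)
mixed = rng.normal(size=(4, 1000))                       # four chains from the same distribution
stuck = mixed + np.array([[0.0], [0.0], [3.0], [3.0]])   # two chains offset into another region
print(round(split_rhat(mixed), 3), round(split_rhat(stuck), 3))
```

The first value should be very close to 1, while the second lands well above the commonly used 1.01 warning threshold.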

A quick sanity check is to compare the posterior means of the fixed effects with the values used to simulate the data. With \(10^6\) observations, this model has more than enough data to estimate the 20 fixed effects.

import matplotlib.pyplot as plt

blog_colors = {
    "background": "#1c1c1d",
    "text": "#e8e8e8",
    "accent": "#2698ba",
    "point": "#ffffff",
}
β_posterior_mean = np.asarray(result["β_mean"])
limits = [
    min(β_true.min(), β_posterior_mean.min()) - 0.05,
    max(β_true.max(), β_posterior_mean.max()) + 0.05,
]

fig, ax = plt.subplots(figsize=(4.8, 3.6))
fig.patch.set_facecolor(blog_colors["background"])
ax.set_facecolor(blog_colors["background"])
ax.scatter(β_true, β_posterior_mean, s=36, color=blog_colors["point"], alpha=0.82)
ax.plot(limits, limits, color=blog_colors["accent"], linewidth=1.5)
ax.set_xlim(limits)
ax.set_ylim(limits)
ax.grid(False)
ax.set_xlabel("True fixed effect", color=blog_colors["text"])
ax.set_ylabel("Posterior mean fixed effect", color=blog_colors["text"])
ax.tick_params(colors=blog_colors["text"])
for spine in ax.spines.values():
    spine.set_color(blog_colors["text"])
plt.show()

Nice! The true values and posterior means line up closely along the 1:1 line. A job well done for MCMC.

I think that Bayes on GPU is tremendously undervalued. If this interests you or you have ideas to chat about, drop me a line.
