Poverty Bayes: fitting million-parameter models for pennies with serverless MCMC

It's a good time to be an applied probabilist. The deep learning revolution has led to tremendous improvements in the $ / flops department, and we Bayesians can easily hop on this train!

During grad school, I used to spend nights and weekends babysitting MCMC runs on my GeForce Titan XP running in my bedroom (by the way, thank you [NVIDIA Academic Grant Program](https://www.nvidia.com/en-us/industries/higher-education-research/academic-grant-program/)!) while simultaneously trying to keep the waste heat from cooking me as I slept. If you are a newcomer to this field, rejoice in the knowledge that all this suffering is a thing of the past. A slew of companies are rushing to the fore with user-friendly platforms for renting GPUs. For prototyping, I really enjoy working with Modal, since I'm cheap and I'm too lazy to keep managing my own fleet.

In this post, I'll show a workflow for running GPU-based inference on Modal for a model that is very large by the standards of Bayesian statisticians (\(\vert \theta \vert \gt 10^6\)), by deploying to a datacenter GPU and renting it only for a short time.

## Model & data

We'll use synthetic data for this example.
I've chosen a hierarchical logistic regression for this post since it has a non-conjugate likelihood, appears commonly in practice, and can easily be assigned more parameters by increasing the number of covariates and/or the number of groups.

Let \(i\), \(g\), and \(k\) denote the indices over observations, groups, and covariates. Furthermore, let \(x_i \in \mathbb{R}^{20}\) be the covariate vector for observation \(i\), and let \(g_i \in \{1, \ldots, 100000\}\) identify its group. The data-generating process uses population-level slopes \(\beta_k\), group intercept deviations \(\alpha_g\), and group slope deviations \(\gamma_{gk}\). The binary outcome is generated from

\[
Y_i \sim \operatorname{Bernoulli}(p_i), \qquad \operatorname{logit}(p_i) = \alpha + \alpha_{g_i} + \sum_{k=1}^{K} x_{ik} \left(\beta_k + \gamma_{g_i k}\right).
\]

Essentially, this is a logistic regression with random slopes for 20 covariates and a random intercept for each of 100,000 groups in the data. We'll use a non-centered parameterization for the group effects. The prior specification is

\[
\begin{aligned}
\alpha &\sim \operatorname{Normal}(0, 1.5), \\
\beta_k &\sim \operatorname{Normal}(0, 1), \\
\sigma_\alpha &\sim \operatorname{HalfNormal}(1), \\
\sigma_{\gamma,k} &\sim \operatorname{HalfNormal}(0.5), \\
z_{\alpha,g} &\sim \operatorname{Normal}(0, 1), \\
z_{\gamma,gk} &\sim \operatorname{Normal}(0, 1), \\
\alpha_g &= \sigma_\alpha z_{\alpha,g}, \\
\gamma_{gk} &= \sigma_{\gamma,k} z_{\gamma,gk}.
\end{aligned}
\]

The code below produces a synthetic dataset; we can control the overall sparsity of the response with the value of `α_true`.
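To check that this model really has more than a million parameters, we can count the sampled unknowns in the specification above. This is a worked count based on the priors as written, treating \(\alpha_g\) and \(\gamma_{gk}\) as deterministic transforms of the \(z\) terms rather than as separate parameters, with \(G = 100{,}000\) and \(K = 20\):

\[
\begin{aligned}
\vert \theta \vert &= \underbrace{1}_{\alpha} + \underbrace{K}_{\beta} + \underbrace{1}_{\sigma_\alpha} + \underbrace{K}_{\sigma_\gamma} + \underbrace{G}_{z_\alpha} + \underbrace{G K}_{z_\gamma} \\
&= 2 + 2K + G(1 + K) \\
&= 2 + 40 + 100{,}000 \times 21 = 2{,}100{,}042.
\end{aligned}
\]

Nearly all of these come from the \(G \times K\) block of slope offsets \(z_{\gamma,gk}\), so the parameter count grows multiplicatively with the number of groups and covariates.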
```python
import numpy as np

RANDOM_SEED = 827
rng = np.random.default_rng(RANDOM_SEED)

N = 1_000_000  # Number of data points
G = 100_000  # Number of groups
K = 20  # Number of covariates / features

group_idx = rng.integers(0, G, size=N, dtype=np.int64)
X = rng.normal(size=(N, K)).astype(np.float32)

# Population-level parameters
α_true = np.float32(-1.0)
β_true = rng.normal(0.0, 0.45, size=K).astype(np.float32)
σ_α_true = np.float32(0.80)
σ_γ_true = rng.uniform(0.15, 0.35, size=K).astype(np.float32)

# Group-level deviations for the intercept and for each slope
α_group_true = rng.normal(0.0, σ_α_true, size=G).astype(np.float32)
γ_group_true = rng.normal(0.0, σ_γ_true, size=(G, K)).astype(np.float32)

# Linear predictor: global intercept + group intercept + group-specific slopes
η = α_true + α_group_true[group_idx] + np.sum(X * (β_true + γ_group_true[group_idx]), axis=1)
p = 1.0 / (1.0 + np.exp(-η))
y = rng.binomial(1, p).astype(np.int64)
print(f"Average of y: {np.mean(y):.2f}")
```

```
Average of y: 0.36
```

Here, we define our model in PyMC. This is a fairly standard model definition without much nuance or many tricks.

```python
import pymc as pm
import pytensor
import pytensor.tensor as pt

pytensor.config.floatX = "float32"


def build_model(X, y, group_idx):
    X = np.asarray(X, dtype=np.float32)
    y = np.asarray(y, dtype=np.int64)
    group_idx = np.asarray(group_idx, dtype=np.int64)

    N, K = X.shape
    G = int(group_idx.max() + 1)

    coords = {
        "obs": np.arange(N),
        "group": np.arange(G),
        "covariate": np.arange(K),
    }

    with pm.Model(coords=coords) as model:
        X_data = pm.Data("X", X, dims=("obs", "covariate"))
        group_lookup = pt.as_tensor_variable(group_idx, name="group_lookup")

        # Population-level priors
        α = pm.Normal("α", np.float32(0.0), np.float32(1.5))
        β = pm.Normal("β", np.float32(0.0), np.float32(1.0), dims="covariate")
        σ_α = pm.HalfNormal("σ_α", np.float32(1.0))
        σ_γ = pm.HalfNormal("σ_γ", np.float32(0.5), dims="covariate")

        # Non-centered group effects: standard-normal offsets scaled by the group-level SDs
        z_α = pm.Normal("z_α", np.float32(0.0), np.float32(1.0), dims="group")
        z_γ = pm.Normal("z_γ", np.float32(0.0), np.float32(1.0), dims=("group", "covariate"))

        α_group = σ_α * z_α
        γ_group = σ_γ * z_γ

        η = α + α_group[group_lookup] + pm.math.sum(X_data * (β + γ_group[group_lookup]), axis=1)
        pm.Bernoulli("Y", logit_p=η, observed=y, dims="obs")

    return model
```

To illustrate why this model could be challenging to fit using only a CPU, we profile the gradient evaluation time. The `dlogp` function represents \(\frac{\partial}{\partial \theta} \log \tilde{p}\), with \(\tilde{p}\) denoting the unnormalized posterior density. It may be evaluated up to roughly 1000 times per sample (\(2^{10} - 1 = 1023\) leapfrog steps at the usual NUTS default maximum tree depth of 10), with the evaluation count increasing with the difficulty of the problem in terms of posterior geometry and/or nonlinearity. Hamiltonian Monte Carlo, the workhorse of most modern Bayesian programming frameworks, requires these gradient evaluations to draw Monte Carlo samples. The code cell below runs the gradient a few times and records the time taken.

```python
import time


def median_runtime(fn, n=5):
    fn()  # Warm-up call so one-time setup costs are not timed
    times = []
    for _ in range(n):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return float(np.median(times)), times


local_model = build_model(X, y, group_idx)
local_point = local_model.initial_point()

with local_model:
    dlogp_fn = local_model.compile_dlogp()

local_grad_median, local_grad_times = median_runtime(lambda: dlogp_fn(local_point))
print(f"Median grad eval time is {local_grad_median:.2f} seconds")
```

```
Median grad eval time is 0.20 seconds
```

If we extrapolate to the time required for running a full chain of 1000 samples, assuming ~250 gradient evaluations per sample, we would need 50 seconds per sample (0.2 s × 250) and 50,000 seconds (~14 hours) to run a full chain! This simply won't do. Let's deploy this remotely and see what we get.

## Running remotely on Modal

For those of you unfamiliar with the magical world of serverless GPUs, please take a look at Modal. I will probably never be wealthy enough to outright own a beautiful server rack of datacenter GPUs, so this is the best I can get. Basically, Modal provides APIs for spinning up jobs on GPUs with very little friction and little downtime. To make this example run nicely on GPU, we'll make a few adjustments.
First, we'll toggle it to use the JAX + NumPyro backend so that it works out of the box with an NVIDIA GPU.

```python
import modal

# Container image with a CUDA-enabled JAX build plus the PyMC / NumPyro stack
modal_image = modal.Image.debian_slim(python_version="3.13").uv_pip_install(
    "arviz==1.1.0",
    "jax[cuda12]==0.7.2",
    "numpy==2.3.5",
    "numpyro==0.19.0",
    "pandas==2.3.3",
    "pymc==6.0.1",
)
app = modal.App("i-should-be-working-right-now", image=modal_image)
```

We'll set a few flags for the float precision and the devices, define a helper to benchmark the execution time, and apply some JAX-isms to prep the compute graph and evaluate it.

```python
def remote_model(X, y, group_idx):
    import os

    # Must be set before JAX is imported so that it targets the GPU
    os.environ["JAX_PLATFORMS"] = "cuda"

    import jax
    import jax.numpy as jnp
    import numpy as np
    import pytensor

    pytensor.config.floatX = "float32"

    X = np.asarray(X, dtype=np.float32)
    y = np.asarray(y, dtype=np.int64)
    group_idx = np.asarray(group_idx, dtype=np.int64)

    # Small GPU round-trip to confirm the device is actually usable
    gpu_devices = jax.devices("gpu")
    gpu_check = (
        jax.device_put(jnp.ones((512, 512), dtype=jnp.float32), gpu_devices[0])
        .sum()
        .block_until_ready()
    )

    model = build_model(X, y, group_idx)
    info = {
        "N": X.shape[0],
        "G": int(group_idx.max() + 1),
        "K": X.shape[1],
        "jax_backend": jax.default_backend(),
        "jax_gpu_devices": [str(device) for device in gpu_devices],
        "gpu_check_sum": float(gpu_check),
        "x_dtype": str(X.dtype),
        "value_var_dtypes": {var.name: var.dtype for var in model.value_vars},
    }
    print(f"JAX backend: {info['jax_backend']}; GPU devices: {info['jax_gpu_devices']}")
    return model, info


def median_runtime(fn, n=5, synchronize=None):
    import time

    import numpy as np

    if synchronize is None:
        synchronize = lambda result: None

    def timed_run():
        start = time.perf_counter()
        result = fn()
        synchronize(result)  # Wait for asynchronous work before stopping the clock
        return time.perf_counter() - start

    synchronize(fn())  # Warm-up call (triggers compilation) before timing
    times = [timed_run() for _ in range(n)]
    return float(np.median(times)), times


def wait_jax(result):
    import jax

    # JAX dispatches asynchronously; block on every array in the output pytree
    jax.tree_util.tree_map(lambda x: x.block_until_ready(), result)


@app.function(gpu="A100", timeout=2 * 60 * 60)
def profile_dlogp(X, y, group_idx, n_evals=5):
    import jax
    import jax.numpy as jnp
    from pymc.sampling.jax import get_jaxified_logp

    model, info = remote_model(X, y, group_idx)
    point = model.initial_point()

    # Baseline: PyTensor's compiled gradient, which runs on the CPU
    with model:
        pytensor_dlogp = model.compile_dlogp()
    pytensor_median, pytensor_times = median_runtime(lambda: pytensor_dlogp(point), n_evals)

    # Same gradient, translated to JAX and JIT-compiled for the GPU
    values = [jnp.asarray(point[var.name]) for var in model.value_vars]
    jax_loss = get_jaxified_logp(model)
    jax_dlogp = jax.jit(jax.value_and_grad(jax_loss))
    wait_jax(jax_dlogp(values))
    jax_median, jax_times = median_runtime(lambda: jax_dlogp(values), n_evals, wait_jax)

    profile = {
        "median_pytensor_dlogp_eval_seconds": pytensor_median,
        "median_jax_dlogp_eval_seconds": jax_median,
    }
    print(f"Median dlogp eval: PyTensor={pytensor_median:.3f}s, JAX={jax_median:.3f}s")
    return {**info, **profile}


@app.function(gpu="A100", timeout=2 * 60 * 60)
def fit_hierarchical_logistic(X, y, group_idx, draws=100, tune=100, random_seed=RANDOM_SEED):
    import time

    import arviz as az
    import pymc as pm

    model, info = remote_model(X, y, group_idx)
    with model:
        start = time.perf_counter()
        idata = pm.sample(
            draws,
            tune=tune,
            chains=1,
            nuts_sampler="numpyro",
            random_seed=random_seed,
            var_names=["α", "β", "σ_α", "σ_γ"],
            progressbar=False,
        )
        elapsed_seconds = time.perf_counter() - start

    β_mean = idata.posterior["β"].mean(dim=("chain", "draw")).to_numpy()
    return {
        **info,
        "elapsed_seconds": elapsed_seconds,
        "summary": az.summary(idata, var_names=["α", "β", "σ_α", "σ_γ"]),
        "β_mean": β_mean.tolist(),
    }
```

With all of this helper logic written, we can finally deploy to the cloud! We will start by running a short job that just profiles the logp gradient function.
```python
print("Launching Modal GPU dlogp profile")
with app.run():
    profile_result = profile_dlogp.remote(X, y, group_idx)
print(f"Finished running; profile results: {profile_result}")
```

```
Launching Modal GPU dlogp profile
Finished running; profile results: {'N': 1000000, 'G': 100000, 'K': 20, 'jax_backend': 'gpu', 'jax_gpu_devices': ['cuda:0'], 'gpu_check_sum': 262144.0, 'x_dtype': 'float32', 'value_var_dtypes': {'α': 'float32', 'β': 'float32', 'σ_α_log__': 'float32', 'σ_γ_log__': 'float32', 'z_α': 'float32', 'z_γ': 'float32'}, 'median_pytensor_dlogp_eval_seconds': 0.2271286229999987, 'median_jax_dlogp_eval_seconds': 0.0013876599999989025}
```

Interesting: the gradient takes about 0.0014 seconds on GPU versus 0.23 seconds on CPU, roughly a 160x speedup!

Next, we run the Markov chain to completion and retrieve the results.

```python
print("Launching Modal GPU run")
with app.run():
    result = fit_hierarchical_logistic.remote(X, y, group_idx)
print(f"MCMC run finished in {result['elapsed_seconds']:.2f} seconds")
```

```
MCMC run finished in 272.16 seconds
```

Modal helpfully lists their [prices per second](https://modal.com/pricing); an A100 runs for $0.000583 / second, meaning the sampling in this run cost me around 16 cents (272 s × $0.000583/s ≈ $0.16).

## Parameter Recovery

We can see all the sampler diagnostics in the posterior summary. We'd need to run more chains to get the \(\hat{R}\) value for this model.

```python
import pandas as pd

posterior_summary = pd.DataFrame(result["summary"])
posterior_summary.iloc[0:5]
```

| parameter | mean | sd | eti89_lb | eti89_ub | ess_bulk | ess_tail | r_hat | mcse_mean | mcse_sd |
|---|---|---|---|---|---|---|---|---|---|
| α | -0.980 | 0.015 | -0.997 | -0.952 | 1.510 | 5.445 | NaN | 0.012 | 0.008 |
| β[0] | -0.275 | 0.005 | -0.282 | -0.266 | 2.042 | 5.659 | NaN | 0.004 | 0.003 |
| β[1] | 0.630 | 0.010 | 0.612 | 0.641 | 1.627 | 5.374 | NaN | 0.008 | 0.005 |
| β[2] | -0.358 | 0.007 | -0.366 | -0.345 | 2.390 | 5.593 | NaN | 0.005 | 0.003 |
| β[3] | 0.359 | 0.006 | 0.348 | 0.366 | 1.834 | 5.607 | NaN | 0.005 | 0.003 |

Note that this single short chain (100 tuning steps and 100 draws with the defaults above) yields bulk ESS values of only about 1.5 to 2.4, so the posterior standard deviations and intervals here should be taken with a grain of salt; a longer run would be needed before relying on them.

A quick diagnostic is to compare the posterior mean of the fixed effects with the values used to simulate the data.
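That comparison can also be condensed into a few numbers. The sketch below uses a hypothetical helper, `fixed_effect_recovery` (it is not part of PyMC or ArviZ), together with made-up toy arrays rather than this post's results; it reports the largest absolute error, the root-mean-square error, and the Pearson correlation between the true and estimated coefficients.

```python
import numpy as np


def fixed_effect_recovery(true_values, posterior_means):
    """Summarize how closely posterior means track the simulation values.

    Hypothetical helper for illustration: returns the largest absolute
    error, the root-mean-square error, and the Pearson correlation.
    """
    true_values = np.asarray(true_values, dtype=np.float64)
    posterior_means = np.asarray(posterior_means, dtype=np.float64)
    if true_values.shape != posterior_means.shape:
        raise ValueError("true_values and posterior_means must have the same shape")

    errors = posterior_means - true_values
    return {
        "max_abs_error": float(np.max(np.abs(errors))),
        "rmse": float(np.sqrt(np.mean(errors**2))),
        "correlation": float(np.corrcoef(true_values, posterior_means)[0, 1]),
    }


# Toy check with made-up numbers (not the results from this post)
toy_true = np.array([-0.3, 0.6, -0.35, 0.35])
toy_est = np.array([-0.28, 0.63, -0.36, 0.36])
print(fixed_effect_recovery(toy_true, toy_est))
```

Applied to this post's objects, the call would presumably be `fixed_effect_recovery(β_true, result["β_mean"])`. Reporting the maximum error next to the RMSE helps distinguish a uniformly small discrepancy from a single badly recovered coefficient.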
With \(10^6\) observations, this model has more than enough data to estimate the fixed effects.

```python
import matplotlib.pyplot as plt

blog_colors = {
    "background": "#1c1c1d",
    "text": "#e8e8e8",
    "accent": "#2698ba",
    "point": "#ffffff",
}

β_posterior_mean = np.asarray(result["β_mean"])
limits = (
    min(β_true.min(), β_posterior_mean.min()) - 0.05,
    max(β_true.max(), β_posterior_mean.max()) + 0.05,
)

fig, ax = plt.subplots(figsize=(4.8, 3.6))
fig.patch.set_facecolor(blog_colors["background"])
ax.set_facecolor(blog_colors["background"])

# One point per covariate, plus the y = x line that perfect recovery would follow
ax.scatter(β_true, β_posterior_mean, s=36, color=blog_colors["point"], alpha=0.82)
ax.plot(limits, limits, color=blog_colors["accent"], linewidth=1.5)

ax.set_xlim(limits)
ax.set_ylim(limits)
ax.grid(False)
ax.set_xlabel("True fixed effect", color=blog_colors["text"])
ax.set_ylabel("Posterior mean fixed effect", color=blog_colors["text"])
ax.tick_params(colors=blog_colors["text"])
for spine in ax.spines.values():
    spine.set_color(blog_colors["text"])
```

Nice! The true values and posterior mean estimates line up perfectly. A job well done for MCMC.

I think that Bayes on GPU is tremendously undervalued. If this interests you or you have ideas to chat about, drop me a line.