{"slug": "poverty-bayes-fitting-million-parameter-models-for-pennies-with-serverless-mcmc", "title": "Poverty Bayes: fitting million-parameter models for pennies with serverless MCMC", "summary": "A Bayesian statistician demonstrates a workflow for fitting million-parameter models on serverless GPUs for pennies, sharply reducing the cost and hassle of Markov chain Monte Carlo (MCMC) inference. The workflow, shown on a hierarchical logistic regression with 100,000 groups and 20 covariates, runs on Modal's rented datacenter GPUs instead of requiring dedicated hardware. This makes large-scale Bayesian inference accessible to researchers who would otherwise have to manage their own GPU infrastructure.", "body_md": "# Poverty Bayes: fitting million-parameter models for pennies with serverless MCMC\n\nIt’s a good time to be an applied probabilist. The deep learning revolution has driven tremendous improvements in FLOPs per dollar, and we Bayesians can easily hop on this train! During grad school, I used to spend nights and weekends babysitting MCMC runs on the GeForce Titan XP in my bedroom (by the way, thank you [NVIDIA Academic Grant Program](https://www.nvidia.com/en-us/industries/higher-education-research/academic-grant-program/)) while trying to keep its waste heat from cooking me as I slept. If you are new to this field, rejoice: all this suffering is a thing of the past. A slew of companies are rushing to the fore with user-friendly platforms for renting GPUs. 
For prototyping, I really enjoy working with Modal, since I’m cheap and too lazy to keep managing my own fleet.\n\nIn this post, I’ll show a workflow for GPU-based inference on Modal with a model that is very large by the standards of Bayesian statisticians \\((\\vert \\theta\\vert \\gt 10^6)\\), by renting a datacenter GPU for only a short time.\n\n## Model & data\n\nWe’ll use synthetic data for this example.\n\nI’ve chosen a hierarchical logistic regression for this post since it has a non-conjugate likelihood, appears commonly in practice, and can easily be given more parameters by increasing the number of covariates and/or the number of groups.\n\nLet \\(i\\), \\(g\\), and \\(k\\) index observations, groups, and covariates, respectively. Furthermore, let \\(x_i \\in \\mathbb{R}^{K}\\), with \\(K = 20\\), be the covariate vector for observation \\(i\\), and let \\(g_i \\in \\{1,\\ldots,100000\\}\\) identify its group. The data-generating process uses a global intercept \\(\\alpha\\), population-level slopes \\(\\beta_k\\), group intercept deviations \\(\\alpha_g\\), and group slope deviations \\(\\gamma_{gk}\\). The binary outcome is generated from\n\n\\[Y_i \\sim \\operatorname{Bernoulli}(p_i), \\qquad \\operatorname{logit}(p_i) = \\alpha + \\alpha_{g_i} + \\sum_{k=1}^{K} x_{ik}(\\beta_k + \\gamma_{g_i k}).\\]\n\nIn other words, this is a logistic regression in which each of the 100,000 groups gets its own random intercept and random slopes on all 20 covariates: 100,000 × 21 = 2.1 million group-level parameters in total.\n\nWe’ll use a non-centered parameterization for the group effects. 
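Instead of sampling each group effect directly from, say, Normal(0, σ_α), the model samples standard-normal variables z and scales them deterministically, which typically spares HMC from exploring the funnel-shaped dependence between a group-level scale and its group effects. The two constructions define the same distribution. The NumPy snippet below only illustrates that identity; the scale 0.8 and seed 0 are arbitrary choices, not part of the model code.

```python
import numpy as np

rng = np.random.default_rng(0)
sigma = 0.8  # arbitrary illustrative scale (the simulation later uses 0.80 for σ_α)
n = 200_000

# Centered: draw the group effects directly with scale sigma.
centered = rng.normal(0.0, sigma, size=n)

# Non-centered: draw standard normals, then scale them deterministically.
z = rng.normal(0.0, 1.0, size=n)
non_centered = sigma * z

# Both sample standard deviations should be close to sigma.
print(round(centered.std(), 2), round(non_centered.std(), 2))
```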
The prior specification is as follows (the second argument of each Normal and HalfNormal is a standard deviation):\n\n\\[\\begin{aligned} \\alpha &\\sim \\operatorname{Normal}(0, 1.5), \\\\ \\beta_k &\\sim \\operatorname{Normal}(0, 1), \\\\ \\sigma_\\alpha &\\sim \\operatorname{HalfNormal}(1), \\\\ \\sigma_{\\gamma,k} &\\sim \\operatorname{HalfNormal}(0.5), \\\\ z_{\\alpha,g} &\\sim \\operatorname{Normal}(0, 1), \\\\ z_{\\gamma,gk} &\\sim \\operatorname{Normal}(0, 1), \\\\ \\alpha_g &= \\sigma_\\alpha z_{\\alpha,g}, \\\\ \\gamma_{gk} &= \\sigma_{\\gamma,k} z_{\\gamma,gk}. \\end{aligned}\\]\n\nThe code below produces a synthetic dataset. Because `α_true` is the baseline log-odds, its value controls the overall proportion of ones in the response (i.e., how sparse `y` is).\n\n``` python\nimport numpy as np\n\nRANDOM_SEED = 827\nrng = np.random.default_rng(RANDOM_SEED)\n\nN = 1_000_000  # Number of data points\nG = 100_000    # Number of groups\nK = 20         # Number of covariates / features\n\ngroup_idx = rng.integers(0, G, size=N, dtype=np.int64)\nX = rng.normal(size=(N, K)).astype(np.float32)\n\nα_true = np.float32(-1.0)\nβ_true = rng.normal(0.0, 0.45, size=K).astype(np.float32)\nσ_α_true = np.float32(0.80)\nσ_γ_true = rng.uniform(0.15, 0.35, size=K).astype(np.float32)\nα_group_true = rng.normal(0.0, σ_α_true, size=G).astype(np.float32)\nγ_group_true = rng.normal(0.0, σ_γ_true, size=(G, K)).astype(np.float32)\n\nη = α_true + α_group_true[group_idx] + np.sum(X * (β_true + γ_group_true[group_idx]), axis=1)\np = 1.0 / (1.0 + np.exp(-η))\ny = rng.binomial(1, p).astype(np.int64)\n\nprint(f\"Average of y: {np.mean(y):.2f}\")\nAverage of y: 0.36\n```\n\nHere, we define our model in PyMC. 
This is a fairly standard model definition without much nuance or many tricks.\n\n``` python\nimport numpy as np\nimport pymc as pm\nimport pytensor\nimport pytensor.tensor as pt\n\npytensor.config.floatX = \"float32\"\n\ndef build_model(X, y, group_idx):\n    # Cast inputs once so the graph uses float32 data and int64 indices.\n    X = np.asarray(X, dtype=np.float32)\n    y = np.asarray(y, dtype=np.int64)\n    group_idx = np.asarray(group_idx, dtype=np.int64)\n\n    N, K = X.shape\n    G = int(group_idx.max()) + 1\n    coords = {\n        \"obs\": np.arange(N),\n        \"group\": np.arange(G),\n        \"covariate\": np.arange(K),\n    }\n\n    with pm.Model(coords=coords) as model:\n        X_data = pm.Data(\"X\", X, dims=(\"obs\", \"covariate\"))\n        group_lookup = pt.as_tensor_variable(group_idx, name=\"group_lookup\")\n\n        α = pm.Normal(\"α\", np.float32(0.0), np.float32(1.5))\n        β = pm.Normal(\"β\", np.float32(0.0), np.float32(1.0), dims=\"covariate\")\n        σ_α = pm.HalfNormal(\"σ_α\", np.float32(1.0))\n        σ_γ = pm.HalfNormal(\"σ_γ\", np.float32(0.5), dims=\"covariate\")\n\n        # Non-centered group effects: standard-normal draws scaled by the group-level SDs.\n        z_α = pm.Normal(\"z_α\", np.float32(0.0), np.float32(1.0), dims=\"group\")\n        z_γ = pm.Normal(\"z_γ\", np.float32(0.0), np.float32(1.0), dims=(\"group\", \"covariate\"))\n        α_group = σ_α * z_α\n        γ_group = σ_γ * z_γ\n\n        # Linear predictor: global + group intercepts, population + group slopes.\n        η = α + α_group[group_lookup] + pm.math.sum(X_data * (β + γ_group[group_lookup]), axis=1)\n        pm.Bernoulli(\"Y\", logit_p=η, observed=y, dims=\"obs\")\n\n    return model\n```\n\nTo illustrate why this model could be challenging to fit on a CPU alone, we profile the gradient evaluation time. The `dlogp` function computes \\(\\frac{\\partial}{\\partial \\theta}[\\log \\tilde{p}]\\), with \\(\\tilde{p}\\) denoting the unnormalized posterior density. NUTS may evaluate it up to about 1000 times per draw (its default maximum tree depth of 10 allows trajectories of up to \\(2^{10} - 1 = 1023\\) steps), and the count grows as the posterior geometry becomes more difficult or more nonlinear. 
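Each of those evaluations powers one step of the leapfrog integrator that HMC and NUTS use to simulate a trajectory. The sketch below is a generic NumPy leapfrog integrator with a unit mass matrix, instrumented to count gradient calls on a toy standard-normal target. It is an illustration, not PyMC's or NumPyro's internal code, and `grad_logp` is a hypothetical callable.

```python
import numpy as np

def leapfrog(theta, momentum, grad_logp, step_size, n_steps):
    # Count gradient evaluations so the cost of one trajectory is explicit.
    grad = grad_logp(theta)
    n_grad = 1
    theta = theta.copy()
    momentum = momentum.copy()
    for _ in range(n_steps):
        momentum = momentum + 0.5 * step_size * grad  # half step for momentum
        theta = theta + step_size * momentum          # full step for position (unit mass matrix)
        grad = grad_logp(theta)                       # one fresh gradient per step, reused next step
        n_grad += 1
        momentum = momentum + 0.5 * step_size * grad  # second half step for momentum
    return theta, momentum, n_grad

# Toy target: standard normal, whose log-density gradient is -theta.
def grad_std_normal(theta):
    return -theta

theta0 = np.ones(3)
p0 = np.zeros(3)
theta1, p1, n_grad = leapfrog(theta0, p0, grad_std_normal, step_size=0.1, n_steps=10)
print(n_grad)  # 11: the initial gradient plus one per leapfrog step
```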
Hamiltonian Monte Carlo, the workhorse of most modern Bayesian programming frameworks, needs these gradient evaluations to draw every Monte Carlo sample.\n\nThe code cell below evaluates the gradient a few times at the model’s initial point and reports the median wall-clock time.\n\n``` python\nimport time\n\nimport numpy as np\n\ndef median_runtime(fn, n=5):\n    fn()  # warm-up call so one-time setup costs are not timed\n    times = []\n    for _ in range(n):\n        start = time.perf_counter()\n        fn()\n        times.append(time.perf_counter() - start)\n    return float(np.median(times)), times\n\nlocal_model = build_model(X, y, group_idx)\nlocal_point = local_model.initial_point()\n\nwith local_model:\n    dlogp_fn = local_model.compile_dlogp()\n\nlocal_grad_median, local_grad_times = median_runtime(lambda: dlogp_fn(local_point))\nprint(f\"Median grad eval time is {local_grad_median:.2f} seconds\")\nMedian grad eval time is 0.20 seconds\n```\n\nLet’s extrapolate to a full chain of 1000 draws. Assuming ~250 gradient evaluations per draw, we would need 0.20 s × 250 = 50 seconds per draw, and 50 s × 1000 = 50,000 seconds (~14 hours) for the full chain! This simply won’t do.\n\nLet’s deploy this remotely and see what we get!\n\n## Running remotely on Modal\n\nFor those of you unfamiliar with the magical world of serverless GPUs, please take a look at Modal. I will probably never be wealthy enough to own a beautiful server rack of datacenter GPUs outright, so this is the best I can get. Modal provides a Python API for spinning up jobs on GPUs with very little friction or waiting.\n\nTo make this example run nicely on a GPU, we’ll make a few adjustments. 
First, we’ll define a container image with JAX (CUDA 12 wheels) and NumPyro, so that PyMC can use NumPyro’s NUTS sampler on an NVIDIA GPU out of the box.\n\n``` python\nimport modal\n\nmodal_image = (\n    modal.Image.debian_slim(python_version=\"3.13\")\n    .uv_pip_install(\n        \"arviz==1.1.0\",\n        \"jax[cuda12]==0.7.2\",\n        \"numpy==2.3.5\",\n        \"numpyro==0.19.0\",\n        \"pandas==2.3.3\",\n        \"pymc==6.0.1\",\n    )\n)\n\napp = modal.App(\"i-should-be-working-right-now\", image=modal_image)\n```\n\nNext come a few helpers: `remote_model` points JAX at the GPU through the `JAX_PLATFORMS` environment variable, sets PyTensor to float32, and builds the model; `median_runtime` and `wait_jax` benchmark execution time while waiting for JAX’s asynchronous dispatch to finish; `profile_dlogp` compares the PyTensor gradient with a JIT-compiled JAX version; and `fit_hierarchical_logistic` runs NumPyro’s NUTS sampler.\n\n``` python\ndef remote_model(X, y, group_idx):\n    import os\n\n    # Select the CUDA backend before JAX initializes its devices.\n    os.environ[\"JAX_PLATFORMS\"] = \"cuda\"\n\n    import jax\n    import jax.numpy as jnp\n    import numpy as np\n    import pytensor\n\n    pytensor.config.floatX = \"float32\"\n    X = np.asarray(X, dtype=np.float32)\n    y = np.asarray(y, dtype=np.int64)\n    group_idx = np.asarray(group_idx, dtype=np.int64)\n\n    gpu_devices = jax.devices(\"gpu\")\n    # Tiny computation to confirm the GPU is actually usable.\n    gpu_check = jax.device_put(jnp.ones((512, 512), dtype=jnp.float32), gpu_devices[0]).sum().block_until_ready()\n    model = build_model(X, y, group_idx)\n\n    info = {\n        \"N\": X.shape[0],\n        \"G\": int(group_idx.max()) + 1,\n        \"K\": X.shape[1],\n        \"jax_backend\": jax.default_backend(),\n        \"jax_gpu_devices\": [str(device) for device in gpu_devices],\n        \"gpu_check_sum\": float(gpu_check),\n        \"x_dtype\": str(X.dtype),\n        \"value_var_dtypes\": {var.name: var.dtype for var in model.value_vars},\n    }\n    print(f\"JAX backend: {info['jax_backend']}; GPU devices: {info['jax_gpu_devices']}\")\n    return model, info\n\ndef median_runtime(fn, n=5, synchronize=None):\n    import time\n\n    import numpy as np\n\n    if synchronize is None:\n        synchronize = lambda result: None\n\n    def timed_run():\n        start = 
time.perf_counter()\n        result = fn()\n        synchronize(result)\n        return time.perf_counter() - start\n\n    synchronize(fn())\n    times = [timed_run() for _ in range(n)]\n    return float(np.median(times)), times\n\ndef wait_jax(result):\n    import jax\n\n    jax.tree_util.tree_map(lambda x: x.block_until_ready(), result)\n\n@app.function(gpu=\"A100\", timeout=2 * 60 * 60)\ndef profile_dlogp(X, y, group_idx, n_evals=5):\n\n    import jax\n    import jax.numpy as jnp\n    from pymc.sampling.jax import get_jaxified_logp\n\n    model, info = remote_model(X, y, group_idx)\n    point = model.initial_point()\n\n    with model:\n        pytensor_dlogp = model.compile_dlogp()\n\n    pytensor_median, pytensor_times = median_runtime(lambda: pytensor_dlogp(point), n_evals)\n\n    values = [jnp.asarray(point[var.name]) for var in model.value_vars]\n    jax_loss = get_jaxified_logp(model)\n    jax_dlogp = jax.jit(jax.value_and_grad(jax_loss))\n\n    wait_jax(jax_dlogp(values))\n    jax_median, jax_times = median_runtime(lambda: jax_dlogp(values), n_evals, wait_jax)\n\n    profile = {\n        \"median_pytensor_dlogp_eval_seconds\": pytensor_median,\n        \"median_jax_dlogp_eval_seconds\": jax_median,\n    }\n    print(f\"Median dlogp eval: PyTensor={pytensor_median:.3f}s, JAX={jax_median:.3f}s\")\n    return {**info, **profile}\n\n@app.function(gpu=\"A100\", timeout=2 * 60 * 60)\ndef fit_hierarchical_logistic(X, y, group_idx, draws=100, tune=100, random_seed=RANDOM_SEED):\n    import time\n    import arviz as az\n    import pymc as pm\n\n    model, info = remote_model(X, y, group_idx)\n\n    with model:\n        start = time.perf_counter()\n        idata = pm.sample(\n            draws,\n            tune=tune,\n            chains=1,\n            nuts_sampler=\"numpyro\",\n            random_seed=random_seed,\n            var_names=[\"α\", \"β\", \"σ_α\", \"σ_γ\"],\n            progressbar=False,\n        )\n        elapsed_seconds = time.perf_counter() - 
start\n\n    β_mean = idata.posterior[\"β\"].mean(dim=(\"chain\", \"draw\")).to_numpy()\n\n    return {\n        **info,\n        \"elapsed_seconds\": elapsed_seconds,\n        \"summary\": az.summary(idata, var_names=[\"α\", \"β\", \"σ_α\", \"σ_γ\"]),\n        \"β_mean\": β_mean.tolist(),\n    }\n```\n\nWith all of this helper logic written, we can finally deploy to the cloud! We’ll start with a short job that only profiles the log-density gradient.\n\n```\nprint(\"Launching Modal GPU dlogp profile\")\nwith app.run():\n    profile_result = profile_dlogp.remote(X, y, group_idx)\nprint(f\"Finished running; profile results: {profile_result}\")\nLaunching Modal GPU dlogp profile\nFinished running; profile results: {'N': 1000000, 'G': 100000, 'K': 20, 'jax_backend': 'gpu', 'jax_gpu_devices': ['cuda:0'], 'gpu_check_sum': 262144.0, 'x_dtype': 'float32', 'value_var_dtypes': {'α': 'float32', 'β': 'float32', 'σ_α_log__': 'float32', 'σ_γ_log__': 'float32', 'z_α': 'float32', 'z_γ': 'float32'}, 'median_pytensor_dlogp_eval_seconds': 0.2271286229999987, 'median_jax_dlogp_eval_seconds': 0.0013876599999989025}\n```\n\nInteresting: on the same Modal machine, the JAX gradient takes about 0.0014 seconds on the GPU versus 0.23 seconds with PyTensor’s default CPU backend, roughly a 160x speedup!\n\nNext, we run a single short chain (100 tuning steps and 100 draws) and retrieve the results.\n\n```\nprint(\"Launching Modal GPU run\")\nwith app.run():\n    result = fit_hierarchical_logistic.remote(X, y, group_idx)\n\nprint(f\"MCMC run finished in {result['elapsed_seconds']:.2f} seconds\")\nMCMC run finished in 272.16 seconds\n```\n\nModal helpfully lists [their prices per second](https://modal.com/pricing); an A100 runs for $0.000583 / second, so this run cost me around 16 cents (272 s × 0.000583 per second ≈ 0.16).\n\n## Parameter Recovery\n\nWe can see all the sampler diagnostics in the posterior summary. 
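The R-hat column compares the variance between chains with the variance within chains, so it only becomes informative once several independently initialized chains are available. Below is a minimal sketch of the classic split-R-hat, the older Gelman-Rubin style estimator rather than the rank-normalized variant ArviZ defaults to. It assumes the draws of one scalar parameter are stored as a NumPy array of shape (chains, draws); `split_rhat` is a hypothetical helper name.

```python
import numpy as np

def split_rhat(draws):
    # draws: array of shape (n_chains, n_draws) for a single scalar parameter.
    n_chains, n_draws = draws.shape
    half = n_draws // 2
    # Split every chain in half so that drift within a chain also inflates R-hat.
    split = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    m, n = split.shape
    chain_means = split.mean(axis=1)
    within = split.var(axis=1, ddof=1).mean()      # W: mean within-chain variance
    between = n * chain_means.var(ddof=1)          # B: between-chain variance
    var_plus = (n - 1) / n * within + between / n  # pooled estimate of posterior variance
    return np.sqrt(var_plus / within)

rng = np.random.default_rng(1)
well_mixed = rng.normal(size=(4, 1000))
stuck = well_mixed + 2.0 * np.arange(4)[:, None]  # each chain stuck around a different value
print(round(split_rhat(well_mixed), 3), round(split_rhat(stuck), 2))
```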
We’d need to run more than one chain to get an \\(\\hat{R}\\) value for this model, which is why that column shows `NaN`.\n\n``` python\nimport pandas as pd\n\nposterior_summary = pd.DataFrame(result[\"summary\"])\nposterior_summary.iloc[0:5]\nparameter   mean     sd  eti89_lb  eti89_ub  ess_bulk  ess_tail  r_hat  \\\n0         α -0.980  0.015    -0.997    -0.952     1.510     5.445    NaN   \n1      β[0] -0.275  0.005    -0.282    -0.266     2.042     5.659    NaN   \n2      β[1]  0.630  0.010     0.612     0.641     1.627     5.374    NaN   \n3      β[2] -0.358  0.007    -0.366    -0.345     2.390     5.593    NaN   \n4      β[3]  0.359  0.006     0.348     0.366     1.834     5.607    NaN   \n\n   mcse_mean  mcse_sd  \n0      0.012    0.008  \n1      0.004    0.003  \n2      0.008    0.005  \n3      0.005    0.003  \n4      0.005    0.003\n```\n\nThe bulk ESS values for the rows shown are also only about 1.5 to 2.4 out of 100 draws, so this short chain is far from converged: treat it as a demonstration of speed and cost, and use longer tuning and several chains for a real analysis.\n\nA quick diagnostic is to compare the posterior means of the fixed effects with the values used to simulate the data. With \\(10^6\\) observations, this model has more than enough data to estimate the fixed effects.\n\n``` python\nimport matplotlib.pyplot as plt\n\nblog_colors = {\n    \"background\": \"#1c1c1d\",\n    \"text\": \"#e8e8e8\",\n    \"accent\": \"#2698ba\",\n    \"point\": \"#ffffff\",\n}\nβ_posterior_mean = np.asarray(result[\"β_mean\"])\nlimits = [\n    min(β_true.min(), β_posterior_mean.min()) - 0.05,\n    max(β_true.max(), β_posterior_mean.max()) + 0.05,\n]\n\nfig, ax = plt.subplots(figsize=(4.8, 3.6))\nfig.patch.set_facecolor(blog_colors[\"background\"])\nax.set_facecolor(blog_colors[\"background\"])\nax.scatter(β_true, β_posterior_mean, s=36, color=blog_colors[\"point\"], alpha=0.82)\nax.plot(limits, limits, color=blog_colors[\"accent\"], linewidth=1.5)\nax.set_xlim(limits); ax.set_ylim(limits); ax.grid(False)\nax.set_xlabel(\"True fixed effect\", color=blog_colors[\"text\"]); ax.set_ylabel(\"Posterior mean fixed effect\", color=blog_colors[\"text\"])\nax.tick_params(colors=blog_colors[\"text\"])\nfor spine in ax.spines.values():\n    
spine.set_color(blog_colors[\"text\"])\n```\n\nNice! The posterior means line up closely with the true values along the identity line. A job well done for MCMC.\n\nI think that Bayes on GPU is tremendously undervalued. If this interests you or you have ideas to chat about, drop me a line.", "url": "https://wpnews.pro/news/poverty-bayes-fitting-million-parameter-models-for-pennies-with-serverless-mcmc", "canonical_source": "https://christopherkrapu.com/blog/2026/poverty-bayes-serverless-mcmc/", "published_at": "2026-05-27 05:16:41+00:00", "updated_at": "2026-05-27 05:57:42.940118+00:00", "lang": "en", "topics": ["machine-learning", "ai-infrastructure", "neural-networks"], "entities": ["NVIDIA", "Modal", "GeForce Titan XP"], "alternates": {"html": "https://wpnews.pro/news/poverty-bayes-fitting-million-parameter-models-for-pennies-with-serverless-mcmc", "markdown": "https://wpnews.pro/news/poverty-bayes-fitting-million-parameter-models-for-pennies-with-serverless-mcmc.md", "text": "https://wpnews.pro/news/poverty-bayes-fitting-million-parameter-models-for-pennies-with-serverless-mcmc.txt", "jsonld": "https://wpnews.pro/news/poverty-bayes-fitting-million-parameter-models-for-pennies-with-serverless-mcmc.jsonld"}}