{"slug": "speeding-up-jumprelu-sae-inference-with-custom-triton-kernels-2-14x-on-real-saes", "title": "Speeding Up JumpReLU SAE Inference with Custom Triton Kernels (2–14× on Real SAEs)", "summary": "Researchers developed custom Triton kernels that accelerate JumpReLU Sparse Autoencoder inference by 2–14× on real SAEs, exploiting activation sparsity to skip zero entries during matrix multiplication. The kernels use either exact or fixed memory allocation per token, with a configurable max_l0 parameter to balance speed and safety.", "body_md": "Sparse Autoencoders (SAEs) have become a central tool in mechanistic interpretability research, providing a way to decompose a model's internal activations into sparse, interpretable features. However, extracting these features often requires running the SAE over large volumes of activations across many layers and tokens. This makes SAE inference efficiency a practical bottleneck for interpretability research at scale.\n\nThis post focuses on improving the inference efficiency of JumpReLU Sparse Autoencoders, which were introduced by DeepMind in [Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders](https://arxiv.org/abs/2407.14435) (Rajamanoharan et al.). Instead of using a traditional ReLU activation function, these SAEs use JumpReLU, which zeros out activations that fall below a learned per-feature threshold $\\theta$. This gives JumpReLU SAEs a variable number of active features per token (commonly written as $L_0$, the count of nonzero activations), unlike TopK SAEs, which fire exactly $k$ features per token.\n\n*I use the terms \"fire\" and \"fired\" to describe features with nonzero activations.*\n\nStandard JumpReLU SAE implementations compute the decoder step as a dense matrix multiplication (`feature_acts @ W_dec`), but this is wasteful because most entries of `feature_acts` are zero. 
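\n\nFor reference, the JumpReLU activation from the paper can be written as below, together with the decoder step this post accelerates. This is a sketch in standard notation. I use the row-vector convention of `feature_acts @ W_dec` (the paper uses column vectors). $H$ is the Heaviside step function and $\\odot$ is elementwise multiplication. The decoder bias $\\mathbf{b}_{\\text{dec}}$ is shown for completeness, even though the kernels below only replace the $\\mathbf{f} W_{\\text{dec}}$ product:\n\n$$\n\\mathbf{f} = \\operatorname{JumpReLU}_{\\boldsymbol{\\theta}}\\left(\\mathbf{x} W_{\\text{enc}} + \\mathbf{b}_{\\text{enc}}\\right), \\qquad \\operatorname{JumpReLU}_{\\boldsymbol{\\theta}}(\\mathbf{z}) = \\mathbf{z} \\odot H(\\mathbf{z} - \\boldsymbol{\\theta})\n$$\n\n$$\nL_0 = \\|\\mathbf{f}\\|_0 = \\left|\\{\\, i : f_i \\neq 0 \\,\\}\\right|, \\qquad \\hat{\\mathbf{x}} = \\mathbf{f} W_{\\text{dec}} + \\mathbf{b}_{\\text{dec}}\n$$\n\n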
Instead, you can exploit this sparsity and skip the zero entries entirely during the matrix multiplication with a custom Triton kernel.\n\nWhen a single token passes through a JumpReLU SAE with 65,536 features, the encoder produces a feature activation vector of length 65,536, but only a small number of its entries are nonzero.\n\nTo be more concrete, consider a toy SAE with 8 features and feature activations $\\mathbf{f} \\in \\mathbb{R}^{8}$. Suppose only 2 features are active, say features $i$ and $j$, and let $W_{\\text{dec}} \\in \\mathbb{R}^{8 \\times d_{\\text{model}}}$ be the weight matrix of the decoder layer. We then compute the output with:\n\n$$\n\\hat{\\mathbf{x}} = \\mathbf{f} W_{\\text{dec}} = \\sum_{k=1}^{8} f_k \\, W_{\\text{dec}}[k, :]\n$$\n\nNotice how only two rows of the decoder matrix actually contribute to the output. The other six are multiplied by 0 and contribute nothing. We could instead just compute:\n\n$$\n\\hat{\\mathbf{x}} = f_i \\, W_{\\text{dec}}[i, :] + f_j \\, W_{\\text{dec}}[j, :]\n$$\n\nNow imagine the same example with a much larger hidden dimension, for instance 65,536 features of which 72 are active. That would mean multiplying *~99.89%* of the rows by zero ($1 - 72/65{,}536 \\approx 0.9989$).\n\nIf we knew in advance which features are nonzero and what their values are, we could skip these zero multiplications entirely.\n\nFor a single token, this can be divided into two parts:\n\n1. Find the indices and values of the nonzero features.\n2. Sum the corresponding decoder rows, each scaled by its feature value.\n\nWhen implementing this kernel, my first idea was to begin with a preliminary pass that counts exactly how many features fired for each token. The system could then allocate exactly the memory needed to store the CSR representation (more on that later). However, this requires a GPU→CPU sync, which causes some slowdown.\n\nAs an alternative, you can allocate a predetermined, fixed amount of memory for each token using a configurable `max_l0` parameter. This speeds up computation but overallocates memory, and it comes with an important caveat: `max_l0` must be large enough. For example, if you set `max_l0=10` but one of the tokens in the batch has more than 10 nonzero features, the extra features will be dropped, resulting in information loss.\n\nBoth approaches are covered below. 
For convenience, let's refer to the kernel that allocates exactly the memory needed for the CSR representation as the **Exact Allocation** kernel, and the kernel that allocates a predetermined amount of memory per token as the **Fixed Allocation** kernel. The Fixed Allocation kernel can also be configured with either `validate=True` or `validate=False`. The `validate=True` version is slightly slower than `validate=False`, but it raises an error if any token fires more features than `max_l0`. This is clearly safer. If you are 100% sure that no token will exceed `max_l0`, you can use `validate=False` for some extra speed.\n\nTo skip zero entries during matrix multiplication, we first need to represent the feature activations in **Compressed Sparse Row (CSR) format**. CSR is a standard way of representing sparse matrices that stores only the nonzero values and their column indices. For the toy example above, instead of storing all 8 entries of $\\mathbf{f}$, CSR stores just the indices $[i, j]$ and the values $[f_i, f_j]$.\n\nTo allocate enough memory for the CSR representation, we need to know how much memory each token requires, i.e., how many features fired per token. A `count_nonzero` kernel handles this:\n\n``` python\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef count_nonzero(feature_acts_ptr, counts_ptr, n_features, BLOCK_F: tl.constexpr):\n    pid_token = tl.program_id(0)  # Which token am I working on? (row index)\n    pid_d = tl.program_id(1)  # Which chunk of features am I working on? (column index)\n\n    # Compute the feature indices this block is responsible for\n    feat_offsets = pid_d * BLOCK_F + tl.arange(0, BLOCK_F)\n    mask = feat_offsets < n_features  # Guard against reading past the end of the feature dimension\n\n    # Navigate to this token's features in memory, then to this block's chunk\n    feat_ptrs = feature_acts_ptr + pid_token * n_features + feat_offsets\n    vals = tl.load(feat_ptrs, mask=mask, other=0.0)  # Load the feature values\n\n    fired = vals != 0.0  # Which features in this chunk are active (nonzero)?\n    fired_count = tl.sum(fired.to(tl.int32))  # How many active features in this chunk?\n\n    # Accumulate into this token's count (atomic since multiple blocks write to the same token)\n    tl.atomic_add(counts_ptr + pid_token, fired_count)\n```\n\nIf you're unfamiliar with Triton, the key mental model is this. Rather than writing a loop that runs sequentially, you write a kernel that describes what *one block* (program instance) does, and Triton launches many of these blocks in parallel across the GPU. In this kernel, each block is responsible for a chunk of one token's features. The two `program_id` calls tell each block where it is: `pid_token` identifies which token (which row of the input matrix), and `pid_d` identifies which chunk of that token's features to process.\n\nAlso note that pointers in GPU kernels point to the *start* of a flat block of memory. To reach a particular token's features, we offset into that memory by `pid_token * n_features`. Within that token, we offset further by `pid_d * BLOCK_F` to reach the right chunk. The `mask` guards against reading past the end when `n_features` isn't a clean multiple of `BLOCK_F`.\n\nFinally, since multiple blocks may be counting features for the same token simultaneously, `tl.atomic_add` ensures their partial counts are combined safely.\n\nThis `count_nonzero` kernel produces an array `counts` of length $B$, where $B$ is the number of tokens in the batch. 
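\n\nAs a worked example of the launch grid, take the `BLOCK_F = 1024` default used by `build_csr` below and the $B = 32$, 65,536-feature configuration from the benchmarks later in this post. Each token's feature row is split into $\\lceil n_{\\text{features}} / \\texttt{BLOCK\\_F} \\rceil$ chunks (which is what `triton.cdiv` computes), so:\n\n$$\n\\text{grid} = \\left(B,\\ \\left\\lceil \\frac{n_{\\text{features}}}{\\texttt{BLOCK\\_F}} \\right\\rceil\\right) = \\left(32,\\ \\left\\lceil \\frac{65{,}536}{1{,}024} \\right\\rceil\\right) = (32, 64)\n$$\n\nThat is $32 \\times 64 = 2{,}048$ program instances. Each entry of `counts` receives 64 `tl.atomic_add` contributions, one per chunk of that token's features (many of them adding zero).\n\n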
The number of active (nonzero) features for the $i$-th token is stored in `counts[i]`.\n\nWe can then use this information to allocate two flat arrays, `flat_idx` and `flat_val`, which hold the active feature indices and their values across the entire batch. Conceptually, `flat_idx` is the concatenation of every token's active feature indices, and `flat_val` is the concatenation of the matching activation values.\n\nYou may have noticed that it's not clear which entries belong to which token. For example, `flat_idx[2]` tells us which feature fired, but not whether it fired for the first token in the batch, the second, the third, etc.\n\nWe can solve this problem by introducing a new array `row_offsets` of length $B + 1$, where `row_offsets[b]` stores the index in `flat_idx`/`flat_val` at which token $b$'s entries begin. It's computed by taking a cumulative sum of `counts` (with a leading 0), so each token's region starts exactly where the previous one ends. For example, if three tokens have 2, 5, and 3 active features, then `counts = [2, 5, 3]` and `row_offsets = [0, 2, 7, 10]`.\n\nNow token 0's entries live at indices 0–1, token 1's at 2–6, token 2's at 7–9, and the final entry (10) tells us the total number of nonzero features across all tokens in the batch.\n\nWe can construct `row_offsets` inside a wrapper function `build_csr` that also handles memory allocation and orchestration. It calls `compute_csr_kernel`, the kernel responsible for actually filling `flat_idx` and `flat_val` with the correct values. 
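\n\nInside `compute_csr_kernel`, each active feature's destination slot is the sum of three parts: the start of its token's region, the offset its block claimed from the per-token `write_pos` cursor, and its rank among the active features in its own chunk (an exclusive prefix sum). Let $\\text{fired}_m \\in \\{0, 1\\}$ be element $m$ of a block's chunk, let $b$ be the token, and let $\\text{base}$ be the value returned by that block's `tl.atomic_add` on `write_pos[b]`:\n\n$$\n\\text{rank}_m = \\sum_{m' < m} \\text{fired}_{m'} = \\operatorname{cumsum}(\\text{fired})_m - \\text{fired}_m, \\qquad \\text{slot}_m = \\texttt{row\\_offsets}[b] + \\text{base} + \\text{rank}_m\n$$\n\nFor example, if a chunk has $\\text{fired} = [0, 1, 1, 0, 1]$, its inclusive cumulative sum is $[0, 1, 2, 2, 3]$, so the three active elements get ranks $0, 1, 2$. The blocks of one token can claim their ranges in any order, so a token's entries are not necessarily sorted by feature index. The decode step only sums them, so this does not change the result beyond floating-point rounding.\n\n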
Note that `flat_idx` and `flat_val` are created with `torch.empty`. They are uninitialized, pre-allocated storage that `compute_csr_kernel` writes into.\n\n``` python\ndef build_csr(feature_acts: torch.Tensor, BLOCK_F: int = 1024):\n    feature_acts = feature_acts.contiguous()  # The kernels assume a row-major [B, n_features] layout\n    B, n_features = feature_acts.shape\n    device = feature_acts.device\n\n    # Count how many features fired per token\n    counts = torch.zeros(B, dtype=torch.int32, device=device)\n    grid = (B, triton.cdiv(n_features, BLOCK_F))\n    count_nonzero[grid](feature_acts, counts, n_features, BLOCK_F=BLOCK_F)\n\n    # Cumsum over counts gives each token a contiguous region in the flat arrays\n    # row_offsets[b] = start index of token b's entries in flat_idx/flat_val\n    row_offsets = torch.zeros(B + 1, dtype=torch.int32, device=device)\n    row_offsets[1:] = counts.cumsum(0).to(torch.int32)\n\n    # The last entry is the total number of nonzeros. This is used to size the flat arrays\n    total_nnz = int(row_offsets[-1].item())  # GPU->CPU sync point\n    flat_idx = torch.empty(total_nnz, dtype=torch.int32, device=device)\n    flat_val = torch.empty(total_nnz, dtype=feature_acts.dtype, device=device)\n\n    # write_pos is a per-token cursor that coordinates concurrent writes within\n    # a token's region. Each block atomically claims the next available slots by\n    # bumping write_pos by its count, getting back its starting offset (base).\n    write_pos = torch.zeros(B, dtype=torch.int32, device=device)\n    compute_csr_kernel[grid](\n        feature_acts,\n        row_offsets,\n        write_pos,\n        flat_idx,\n        flat_val,\n        n_features,\n        BLOCK_F=BLOCK_F,\n    )\n    return flat_idx, flat_val, row_offsets, B\n\n\n@triton.jit\ndef compute_csr_kernel(\n    feature_acts_ptr,\n    row_offsets_ptr,\n    write_pos_ptr,\n    flat_idx_ptr,\n    flat_val_ptr,\n    n_features,\n    BLOCK_F: tl.constexpr,\n):\n    pid_token = tl.program_id(0)\n    pid_d = tl.program_id(1)\n\n    # Same pointer arithmetic as count_nonzero: navigate to this block's chunk\n    feat_offsets = pid_d * BLOCK_F + tl.arange(0, BLOCK_F)\n    mask = feat_offsets < n_features\n    feat_ptrs = feature_acts_ptr + pid_token * n_features + feat_offsets\n    vals = tl.load(feat_ptrs, mask=mask, other=0.0)\n\n    fired = vals != 0.0\n    fired_int = fired.to(tl.int32)\n\n    # Where does this token's region start in flat_idx/flat_val?\n    region_start = tl.load(row_offsets_ptr + pid_token)\n\n    # Atomically claim the next block_count slots within this token's region\n    block_count = tl.sum(fired_int)\n    base = tl.atomic_add(write_pos_ptr + pid_token, block_count)\n\n    # Assign each active feature a unique slot within the claimed range\n    local_rank = tl.cumsum(fired_int) - fired_int\n    slots = region_start + base + local_rank\n\n    # Write the feature index and value into the claimed slots\n    tl.store(flat_idx_ptr + slots, feat_offsets.to(tl.int32), mask=fired & mask)\n    tl.store(flat_val_ptr + slots, vals, mask=fired & mask)\n```\n\nNext, `sparse_decode_kernel` uses this CSR structure to carry out the matrix multiplication step. 
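\n\nIn equation form, for token $b$ and output coordinate $d$, the decode step computes the following sum over that token's slice of the CSR arrays (without a bias, matching `feature_acts @ W_dec`):\n\n$$\n\\texttt{out}[b, d] = \\sum_{k = \\texttt{row\\_offsets}[b]}^{\\texttt{row\\_offsets}[b+1] - 1} \\texttt{flat\\_val}[k] \\cdot W_{\\text{dec}}\\left[\\texttt{flat\\_idx}[k],\\ d\\right]\n$$\n\nWith $L_0$ active features for token $b$, this takes $L_0 \\cdot d_{\\text{model}}$ multiply-adds instead of the $n_{\\text{features}} \\cdot d_{\\text{model}}$ of the dense product, and it reads only $L_0$ rows of $W_{\\text{dec}}$.\n\n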
For each token, it looks up where that token's active features live in `flat_idx`/`flat_val` using `row_offsets`. It then loops over them, accumulating the weighted sum of the corresponding decoder rows into a tile of the output.\n\n``` python\n@triton.jit\ndef sparse_decode_kernel(\n    flat_idx_ptr, flat_val_ptr, row_offsets_ptr,\n    W_dec_ptr, out_ptr, d_model,\n    BLOCK_D: tl.constexpr,\n):\n    pid_token = tl.program_id(0)\n    pid_d = tl.program_id(1)\n\n    # Find the slice of flat_idx/flat_val belonging to this token\n    start = tl.load(row_offsets_ptr + pid_token)\n    end = tl.load(row_offsets_ptr + pid_token + 1)\n    n = end - start  # Number of active features for this token\n\n    # This block owns a BLOCK_D-wide slice of the output row\n    offsets = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)\n    mask = offsets < d_model\n    acc = tl.zeros([BLOCK_D], dtype=tl.float32)\n\n    # Loop over this token's active features, accumulating their contribution\n    for i in range(n):\n        j = start + i\n        feat_idx = tl.load(flat_idx_ptr + j)  # Which decoder row?\n        feat_val = tl.load(flat_val_ptr + j)  # Scale factor\n\n        # Load the corresponding decoder row (just this block's slice)\n        row_ptrs = W_dec_ptr + (feat_idx * d_model) + offsets\n        row = tl.load(row_ptrs, mask=mask, other=0.0)\n        acc += feat_val.to(tl.float32) * row.to(tl.float32)\n\n    # Write this block's output slice\n    tl.store(out_ptr + pid_token * d_model + offsets, acc, mask=mask)\n```\n\nFinally, we put all of these kernels together by wrapping them in a single `sparse_decode()` function that acts as a drop-in replacement for `feature_acts @ W_dec` (it always returns `float32`):\n\n``` python\ndef _sparse_decode(flat_idx, flat_val, row_offsets, W_dec, B, BLOCK_D=256):\n    d_model = W_dec.shape[1]\n    out = torch.zeros((B, d_model), device=W_dec.device, dtype=torch.float32)\n\n    # Parallelize over batch rows and d_model tiles\n    grid = (B, triton.cdiv(d_model, BLOCK_D))\n    sparse_decode_kernel[grid](\n        flat_idx, flat_val, row_offsets, W_dec, out, d_model, BLOCK_D=BLOCK_D\n    )\n    return out\n\n\ndef sparse_decode(feature_acts, W_dec):\n    # The kernel's pointer arithmetic assumes a contiguous, row-major W_dec\n    W_dec = W_dec.contiguous()\n    flat_idx, flat_val, row_offsets, B = build_csr(feature_acts)\n    return _sparse_decode(flat_idx, flat_val, row_offsets, W_dec, B)\n```\n\nRecall that in the Exact Allocation kernel, `build_csr` extracts the total number of nonzero entries across all tokens by reading the last entry of `row_offsets`:\n\n``` python\ntotal_nnz = int(row_offsets[-1].item())\n```\n\nCalling `.item()` forces the CPU to wait for the GPU to finish the counting pass before it can read `total_nnz` and allocate `flat_idx`/`flat_val`. Normally the CPU queues up GPU work asynchronously and moves on without waiting. `.item()` breaks that pipeline by making the CPU stall until the GPU result is ready. This turns out to be a significant source of slowdown.\n\nThe Fixed Allocation kernel works around this by not even 
allocating *exactly* the memory needed in the first place (so we never need `total_nnz`). Instead, we allocate `max_l0` slots per token, where `max_l0` is a user-specified upper bound on how many features can fire for any single token. This also means we no longer need a separate pass to count the nonzero features before building the CSR structure.\n\nWith these changes, the new `build_csr` wrapper function looks like:\n\n``` python\ndef build_csr(feature_acts: torch.Tensor, BLOCK_F: int = 1024, max_l0: int = 512, validate: bool = True):\n    feature_acts = feature_acts.contiguous()  # The kernels assume a row-major [B, n_features] layout\n    B, n_features = feature_acts.shape\n    device = feature_acts.device\n\n    # Fixed memory allocation: max_l0 slots per token, sized without a GPU→CPU sync\n    capacity = B * max_l0\n    flat_idx = torch.empty(capacity, dtype=torch.int32, device=device)\n    flat_val = torch.empty(capacity, dtype=feature_acts.dtype, device=device)\n\n    # write_pos serves as both the per-token write cursor during the kernel\n    # and the per-token count afterward\n    write_pos = torch.zeros(B, dtype=torch.int32, device=device)\n\n    grid = (B, triton.cdiv(n_features, BLOCK_F))\n    compute_csr_kernel[grid](\n        feature_acts, write_pos, flat_idx, flat_val,\n        n_features, max_l0, BLOCK_F=BLOCK_F,\n    )\n\n    # Final cursor value = number of features that fired per token.\n    # It exceeds max_l0 if a token overflowed its region (extra writes are dropped).\n    counts = write_pos\n\n    # Optional safety check. This reintroduces a GPU→CPU sync but catches silent truncation\n    if validate:\n        max_count = int(counts.max().item())\n        if max_count > max_l0:\n            raise ValueError(\n                f\"A token fired more than max_l0={max_l0} features \"\n                f\"(max was {max_count}). Increase max_l0.\"\n            )\n\n    return flat_idx, flat_val, counts, B, max_l0\n```\n\nAs mentioned earlier, if a token fires more features than `max_l0`, the overflow guard in the kernel silently drops the extra features. This is dangerous because the result is wrong but nothing crashes. 
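\n\nThe effect of the overflow guard can be written down directly. Suppose token $b$ has $L_0$ active features. The kernel stores only $\\min(L_0, \\texttt{max\\_l0})$ of them, namely whichever ones received the lowest slots, which depends on the order in which blocks claim their ranges. The cursor `write_pos[b]` still ends at $L_0$:\n\n$$\n\\text{stored} = \\min(L_0,\\ \\texttt{max\\_l0}), \\qquad \\text{dropped} = \\max(0,\\ L_0 - \\texttt{max\\_l0}), \\qquad \\texttt{write\\_pos}[b] = L_0\n$$\n\nFor instance, take the earlier example of `max_l0=10` and a token that fires 12 features. Then 10 features are stored, 2 are dropped, and `write_pos` ends at 12. A final count above `max_l0` ($12 > 10$) is exactly what the `validate=True` check looks for.\n\n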
The `validate=True` default catches this by checking the largest per-token count after the kernel runs. This reintroduces a GPU→CPU sync (*in practice it is still faster than Exact Allocation*). If you're very confident that your `max_l0` is a safe upper bound for your SAE, you can pass `validate=False` to skip the check, but this is not recommended.\n\nThe CSR kernel changes minimally. We no longer need `row_offsets`: since each token owns exactly `max_l0` entries in memory, the lookup of a token's region start is replaced by `region_start = pid_token * max_l0`.\n\n``` python\n@triton.jit\ndef compute_csr_kernel(\n    feature_acts_ptr,\n    write_pos_ptr,\n    flat_idx_ptr,\n    flat_val_ptr,\n    n_features,\n    max_l0,\n    BLOCK_F: tl.constexpr,\n):\n    pid_token = tl.program_id(0)\n    pid_d = tl.program_id(1)\n\n    # Navigate to this block's chunk\n    feat_offsets = pid_d * BLOCK_F + tl.arange(0, BLOCK_F)\n    mask = feat_offsets < n_features\n    feat_ptrs = feature_acts_ptr + pid_token * n_features + feat_offsets\n    vals = tl.load(feat_ptrs, mask=mask, other=0.0)\n\n    fired = vals != 0.0\n    fired_int = fired.to(tl.int32)\n\n    # Each token owns a fixed region of max_l0 slots\n    region_start = pid_token * max_l0\n\n    # Atomically claim the next available slots within this token's region\n    block_count = tl.sum(fired_int)\n    base = tl.atomic_add(write_pos_ptr + pid_token, block_count)\n\n    # Assign each active feature a unique slot within the claimed range\n    local_rank = tl.cumsum(fired_int) - fired_int\n    local_slot = base + local_rank\n\n    # Guard against writing past this token's region if L0 exceeds max_l0\n    in_region = local_slot < max_l0\n    write_mask = fired & mask & in_region\n\n    slots = region_start + local_slot\n    tl.store(flat_idx_ptr + slots, feat_offsets.to(tl.int32), mask=write_mask)\n    tl.store(flat_val_ptr + slots, vals, mask=write_mask)\n```\n\nThe decoder kernel changes in the same way. 
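\n\nSide by side, the two layouts differ only in where token $b$'s region starts and where it ends. Let $c_b$ be the number of features stored for token $b$. The half-open range of slots the decoder reads is then:\n\n$$\n\\text{Exact: } \\big[\\texttt{row\\_offsets}[b],\\ \\texttt{row\\_offsets}[b+1]\\big) \\qquad \\text{Fixed: } \\big[\\,b \\cdot \\texttt{max\\_l0},\\ b \\cdot \\texttt{max\\_l0} + c_b\\,\\big)\n$$\n\nThe fixed layout always reserves $B \\cdot \\texttt{max\\_l0}$ slots, each holding a 4-byte `int32` index plus one value in the dtype of `feature_acts`. Assuming `float32` activations (4 bytes), $B = 1024$ and $\\texttt{max\\_l0} = 1024$ reserve $1024 \\cdot 1024 \\cdot 8 = 8{,}388{,}608$ bytes ($\\approx 8.4$ MB) for `flat_idx` and `flat_val`. The Exact Allocation kernel, by contrast, allocates only $\\sum_b c_b$ slots.\n\n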
`row_offsets` is no longer needed, and `counts` replaces the start/end bracket:\n\n``` python\n@triton.jit\ndef sparse_decode_kernel(\n    flat_idx_ptr,\n    flat_val_ptr,\n    counts_ptr,\n    W_dec_ptr,\n    out_ptr,\n    d_model,\n    max_l0,\n    BLOCK_D: tl.constexpr,\n):\n    pid_token = tl.program_id(0)\n    pid_d = tl.program_id(1)\n\n    start = pid_token * max_l0\n    n = tl.load(counts_ptr + pid_token)  # Number of features that fired for this token\n    n = tl.minimum(n, max_l0)  # Never read past this token's max_l0-slot region\n\n    offsets = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)\n    mask = offsets < d_model\n    acc = tl.zeros([BLOCK_D], dtype=tl.float32)\n\n    # Same loop as before\n    for i in range(n):\n        j = start + i\n        feat_idx = tl.load(flat_idx_ptr + j)\n        feat_val = tl.load(flat_val_ptr + j)\n        row_ptrs = W_dec_ptr + feat_idx * d_model + offsets\n        row = tl.load(row_ptrs, mask=mask, other=0.0)\n        acc += feat_val.to(tl.float32) * row.to(tl.float32)\n\n    tl.store(out_ptr + pid_token * d_model + offsets, acc, mask=mask)\n```\n\nWriting custom GPU kernels is great, but it's important to check that they actually make the computation faster. I used `triton.testing.do_bench` (`warmup=25, rep=100`, reporting the median) to time these kernels and compared them against dense matrix multiplication (`feature_acts @ W_dec`). All tests were run on an NVIDIA GeForce RTX 4090 GPU.\n\nAs a quick summary, the table below shows the relative speedups for an example input configuration (`B = 32`, `n_features = 65536`, `d_model = 768`, `L0 = 64`):\n\n| Method | Median time (ms) | Speedup vs. dense |\n|---|---|---|\n| Dense cuBLAS | 0.288 | 1.0× |\n|  | 0.288 | 1.0× |\n|  | 0.210 | 1.4× |\n| Custom, Exact Allocation | 0.151 | 1.9× |\n| Custom, Fixed Allocation (`validate=False`) | 0.041 | 7.0× |\n| Custom, Fixed Allocation (`validate=True`) | 0.115 | 2.5× |\n\nFirst, I verified that the custom kernels actually compute the matrix product correctly (a custom kernel that is faster but gives the wrong answer doesn't help anyone). 
In other words, we check that `sparse_decode(feature_acts, W_dec)` matches `feature_acts @ W_dec` (within floating-point tolerance) across 486 different inputs built from combinations of the parameters below. Note that `sparse_decode()` here is just a wrapper matmul function that uses our custom Triton kernels under the hood.\n\n| Parameter | Values | Count |\n|---|---|---|\n| Kernel implementation |  | 2 |\n| Input dtype |  | 3 |\n| Batch size (tokens) | 1, 4, 32 | 3 |\n| SAE dictionary width | 256, 1024, 16384 | 3 |\n| Output width | 128, 512, 768 | 3 |\n| Features fired per token | 1, 8, 100 | 3 |\n\n**Total: 2 × 3 × 3 × 3 × 3 × 3 = 486 configurations.** Each test asserts that the output is fp32 and matches the dense fp32 reference within `atol=1e-4, rtol=1e-3`.\n\nThe CSR preprocessing step adds some computational overhead. It would be interesting to compare `sparse_decode_kernel` directly against dense matrix multiplication without paying for that overhead (i.e., assuming you somehow already have the CSR representation).\n\nSuppose you hold some input parameters constant (`B=32`, `n_features=65536`, `d_model=768`) while varying `L0` (the number of fired features), as shown in the table below. How much faster is `sparse_decode_kernel`?\n\nNote that this **excludes** the overhead of the CSR preprocessing step (i.e., `compute_csr_kernel`). 
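\n\nTo read the density column of the table that follows: density is $L_0 / n_{\\text{features}}$, and the idealized reduction in multiply-adds is its inverse. The dense product does $B \\cdot n_{\\text{features}} \\cdot d_{\\text{model}}$ multiply-adds and the sparse kernel does $B \\cdot L_0 \\cdot d_{\\text{model}}$. At $L_0 = 64$ and $n_{\\text{features}} = 65{,}536$:\n\n$$\n\\frac{L_0}{n_{\\text{features}}} = \\frac{64}{65{,}536} \\approx 0.098\\%, \\qquad \\frac{B \\cdot n_{\\text{features}} \\cdot d_{\\text{model}}}{B \\cdot L_0 \\cdot d_{\\text{model}}} = \\frac{65{,}536}{64} = 1{,}024\n$$\n\nThe measured kernel-only speedup at this setting (12.8× in the table below) is far smaller than this 1,024× reduction in arithmetic, so operation counts alone do not predict the observed speedup.\n\n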
Also note that `sparse_decode_kernel` is essentially the same for Exact Allocation and Fixed Allocation, so there is no need to differentiate them. For completeness, the graph below plots both (they overlap).\n\n| `L0` | Density (`L0 / n_features`) | Speedup (kernel only) |\n|---|---|---|\n| 16 | 0.02% | 25.5× |\n| 32 | 0.05% | 18.7× |\n| 64 | 0.10% | 12.8× |\n| 128 | 0.20% | 8.0× |\n| 256 | 0.39% | 5.0× |\n| 512 | 0.78% | 3.0× |\n| 1024 | 1.56% | 1.7× |\n| 4096 | 6.25% | 0.6× |\n\nWe can also vary `n_features` while keeping `B=32`, `L0=64`, and `d_model=768` constant:\n\n| `n_features` | Speedup (kernel only) |\n|---|---|\n| 4,096 | 1.5× |\n| 16,384 | 4.1× |\n| 32,768 | 7.3× |\n| 65,536 | 12.8× |\n| 131,072 | 22.5× |\n\nSo `sparse_decode_kernel` alone is clearly faster than dense matrix multiplication at high sparsity. In practice, though, we also need to build the CSR representation, which eats into that advantage.\n\nThe table below shows speedups relative to dense matmul for three input configurations (F = `n_features`, D = `d_model`). Here \"Kernel only\" refers to `sparse_decode_kernel` alone (CSR precomputed), while \"Full\" refers to the whole pipeline (`build_csr` plus the decode kernel).\n\n| Configuration | Kernel only | Full, Exact Allocation | Full, Fixed (`validate=False`) | Full, Fixed (`validate=True`) |\n|---|---|---|---|---|\n| B=32, F=65536, D=768, L0=64 | 12.8× | 1.9× | 7.0× | 2.5× |\n| B=256, F=65536, D=768, L0=64 | 7.7× | 1.7× | 3.1× | 2.2× |\n| B=32, F=131072, D=512, L0=128 | 22.5× | 2.2× | 6.1× | 2.3× |\n\nThe graph below shows the speed of the full pipeline (Exact Allocation) and of decode alone as sparsity varies. Here, `L0` sweeps over [16, 32, 64, 128, 256, 512, 1024, 4096, 16384] while `B=32`, `n_features=65536`, and `d_model=768` are held constant.\n\nTo be comprehensive, we can also compare our custom kernels to `torch.sparse.mm` and to `torch.compile`, on the same three input configurations as above. The `torch.sparse.mm` comparison uses PyTorch's `to_sparse_csr()` and runs cuSPARSE internally.\n\n**Note: I found it a little suspicious that this custom kernel would \"beat\" `torch.sparse.mm`. It turns out this is mostly because it beats `to_sparse_csr()` when building the CSR. There doesn't seem to be much of a difference in speed between the custom kernel and cuSPARSE on the matrix multiplication step alone.**\n\nAs expected, `torch.compile` doesn't provide a noticeable speedup, but I included it for completeness.\n\nUp until now we have focused entirely on the speed of the matrix multiplication, but at the end of the day we care about SAE inference speed as a whole. I benchmarked this by replacing only the decoder matmul step in a `SAELens` JumpReLU SAE forward pass. The table below covers five SAEs across two model families and three dictionary sizes.\n\n| SAE | `n_features` | `d_model` | L0 |  | Exact Allocation | Fixed (`validate=True`) | Fixed (`validate=False`) |\n|---|---|---|---|---|---|---|---|\n| Gemma Scope 2B, L20, 65k | 65,536 | 2,304 | 72 | 3.8e-6 | 4.27× | 5.57× | 11.41× |\n| Gemma Scope 9B, L20, 65k | 65,536 | 3,584 | 72 | 3.8e-6 | 5.66× | 7.34× | 13.27× |\n| Gemma Scope 2B, L12, 65k | 65,536 | 2,304 | 72 | 9.5e-7 | 3.91× | 5.48× | 11.33× |\n| Gemma Scope 2B, L12, 262k | 262,144 | 2,304 | 100 | 1.9e-6 | 12.08× | 14.49× | 22.59× |\n| Qwen Scope 3.5 2B, L12 | 32,768 | 2,048 | 100 | 4.8e-7 | 1.98× | 2.54× | 5.74× |\n\nThe purpose of the Fixed Allocation kernel was to trade extra memory for speed, so it is worth checking exactly how much more memory it uses than the Exact Allocation kernel. Surprisingly, in practice this overhead is small:\n\n| B | `max_l0` | Baseline (MB) | Exact Allocation (MB) | Fixed Allocation (MB) | Fixed vs. Exact |\n|---|---|---|---|---|---|\n| 32 | 512 | 218.3 | 218.4 | 218.5 | +0.1 MB |\n| 256 | 512 | 277.7 | 277.9 | 278.8 | +0.9 MB |\n| 1024 | 512 | 482.3 | 482.9 | 485.6 | +2.7 MB |\n| 1024 | 1024 | 482.3 | 482.9 | 490.7 | +7.8 MB |\n\nWhile these results are encouraging, there are a few important limitations to be aware of and gaps that I plan to address as I continue working on this project.\n\nFirst, the benchmark numbers above are not absolute, as these tests were run in a specific environment (WSL2 with GPU clocks not pinned). 
The primary goal of these benchmarks was to gauge the relative performance of the custom kernels compared to baseline implementations. Absolute speeds will likely differ depending on the hardware and benchmarking setup.\n\nA second limitation, discussed earlier but worth reiterating, concerns the Fixed Allocation kernel with `validate=False`. Although it achieves the highest performance, it can silently produce incorrect results if `max_l0` is set too low. For this reason, using either the Exact Allocation kernel or Fixed Allocation with `validate=True` is likely the better choice in most cases.\n\nThird, these kernels are designed specifically for sparse matrix multiplication, so once the activations become dense enough, dense matrix multiplication is actually faster. In the kernel-only `L0` sweep above (65,536 features), the crossover lies between `L0=1024` (1.56% density, 1.7×) and `L0=4096` (6.25% density, 0.6×).\n\nFourth, this implementation focuses exclusively on the decoder step of JumpReLU SAE inference, but there are likely other sources of inefficiency that could be addressed. For example, future projects could focus on the encoder pass or on support for training through custom backward kernels. Additionally, the current implementation only supports `float32` outputs.\n\nFinally, all experiments were run on an RTX 4090, and performance may differ on other GPU architectures such as the A100 or H100.\n\nIn conclusion, this project implements custom Triton kernels for the decoder step of JumpReLU SAE inference by exploiting the inherent sparsity of the hidden representation. On a sample of real SAEs, this achieves a 2.5–14× speedup with the Fixed Allocation (`validate=True`) kernel, with larger gains for larger dictionary sizes.\n\nThe full implementation is available on [GitHub](https://github.com/dtiourine/jumprelu-sae-kernels/tree/main).\n\n*I welcome feedback! If you have thoughts, questions, or find any issues, feel free to leave a comment or reach out directly. 
This is also my first GPU kernel project, so if you're experienced with Triton or GPU kernel optimization and see things I could have done better, I would appreciate any suggestions.*", "url": "https://wpnews.pro/news/speeding-up-jumprelu-sae-inference-with-custom-triton-kernels-2-14x-on-real-saes", "canonical_source": "https://www.lesswrong.com/posts/8gZspSs4WFtpfki9i/speeding-up-jumprelu-sae-inference-with-custom-triton", "published_at": "2026-06-14 04:00:04+00:00", "updated_at": "2026-06-14 04:30:22.276755+00:00", "lang": "en", "topics": ["machine-learning", "ai-research", "ai-infrastructure", "large-language-models", "neural-networks"], "entities": ["DeepMind", "JumpReLU", "Triton", "Sparse Autoencoders", "Rajamanoharan et al"], "alternates": {"html": "https://wpnews.pro/news/speeding-up-jumprelu-sae-inference-with-custom-triton-kernels-2-14x-on-real-saes", "markdown": "https://wpnews.pro/news/speeding-up-jumprelu-sae-inference-with-custom-triton-kernels-2-14x-on-real-saes.md", "text": "https://wpnews.pro/news/speeding-up-jumprelu-sae-inference-with-custom-triton-kernels-2-14x-on-real-saes.txt", "jsonld": "https://wpnews.pro/news/speeding-up-jumprelu-sae-inference-with-custom-triton-kernels-2-14x-on-real-saes.jsonld"}}