# Speeding Up JumpReLU SAE Inference with Custom Triton Kernels (2–14× on Real SAEs)

> Source: <https://www.lesswrong.com/posts/8gZspSs4WFtpfki9i/speeding-up-jumprelu-sae-inference-with-custom-triton>
> Published: 2026-06-14 04:00:04+00:00

Sparse Autoencoders (SAEs) have become a central tool in mechanistic interpretability research, providing a way to decompose a model's internal activations into sparse, interpretable features. However, extracting these features often requires running the SAE over large volumes of activations across many layers and tokens. This makes SAE inference efficiency a practical bottleneck for interpretability research at scale.

This post focuses on improving the inference efficiency of JumpReLU Sparse Autoencoders, which were introduced by DeepMind in [Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders](https://arxiv.org/abs/2407.14435) (Rajamanoharan et al.). Instead of using a traditional ReLU activation function, these SAEs use JumpReLU, which zeros out activations that fall below a learned per-feature threshold $\theta$. This gives JumpReLU SAEs a variable number of active features per token (commonly written as $L_0$, the count of nonzero activations), unlike TopK SAEs, which fire exactly $k$ features per token.

*I use the terms "fire" and "fired" to describe features with non-zero activations.*

Traditional JumpReLU SAE implementations compute the decoder step as a dense matrix multiplication (`feature_acts @ W_dec`), but this is wasteful because of the sparsity of `feature_acts`. Instead, you can exploit this sparsity property and skip the zero entries entirely during matrix multiplication with a custom Triton kernel.

When a single token passes through a JumpReLU SAE with 65,536 features, the encoder produces a feature activation vector of length 65,536, but only some entries are nonzero.

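To be more concrete, consider a toy SAE with 8 features and feature activations $f \in \mathbb{R}^{8}$. Now suppose that only 2 features are active, say features $i$ and $j$, where $W_\text{dec} \in \mathbb{R}^{8 \times d_\text{model}}$ represents the weight matrix of the decoder layer.

We then compute the output with:

$$\hat{x} = f\,W_\text{dec} = \sum_{m=1}^{8} f_m\,W_\text{dec}[m, :]$$

Notice how only two of the rows of the decoder matrix were actually used in the computation. The rest were multiplied by 0 and contributed nothing to the output. We could instead just compute:

$$\hat{x} = f_i\,W_\text{dec}[i, :] + f_j\,W_\text{dec}[j, :]$$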

Now imagine this same example but increase the hidden dimension from 8 to something much greater, for instance 65,536 features with 72 of them active. That would mean you're multiplying *~99.89%* of rows by zero.

If we knew in advance which features are nonzero and their corresponding values, we could skip these zero multiplications and simplify the computation.
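To make the toy case concrete, here is a minimal pure-Python sketch. The activation values and decoder rows are made up for illustration (they are not the numbers from the original example), and `dense_decode` and `skip_zero_decode` are hypothetical helper names.

```python
# Toy decoder: 8 features, d_model = 3. All numbers are illustrative.
f = [0.0, 0.0, 1.5, 0.0, 0.0, 0.0, 2.0, 0.0]  # only features 2 and 6 fired
W_dec = [[float(i + j) for j in range(3)] for i in range(8)]  # row i belongs to feature i


def dense_decode(f, W_dec):
    """Weighted sum over every decoder row, including the zero-weight ones."""
    d_model = len(W_dec[0])
    out = [0.0] * d_model
    for i, act in enumerate(f):
        for j in range(d_model):
            out[j] += act * W_dec[i][j]
    return out


def skip_zero_decode(f, W_dec):
    """Same sum, but rows whose activation is zero are never touched."""
    d_model = len(W_dec[0])
    out = [0.0] * d_model
    for i, act in enumerate(f):
        if act == 0.0:
            continue
        for j in range(d_model):
            out[j] += act * W_dec[i][j]
    return out


dense_out = dense_decode(f, W_dec)
sparse_out = skip_zero_decode(f, W_dec)
print(dense_out, sparse_out)
```

Both functions return the same vector, but the second one reads only 2 of the 8 decoder rows.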

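For a single token, this can be divided into two parts:

1. Find which features fired, and record their indices and values.
2. Sum the corresponding decoder rows, each scaled by its feature's value.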

When implementing this kernel, my first thought was to begin with a preliminary step that figures out exactly how many features fired for each token so the system could then allocate exactly the memory needed to store the CSR representation (more on that later). However, this process involves a GPU->CPU sync, which causes some slowdown.

As an alternative, you can instead allocate some predetermined/fixed amount of memory for each token using a configurable `max_l0` parameter. This speeds up computation but overallocates memory and introduces an important caveat: `max_l0` must be large enough to avoid errors. For example, if you set `max_l0=10` but one of the tokens in the batch has more than 10 nonzero features, those extra features will be dropped, resulting in information loss.

Both approaches are covered below. For convenience, let's refer to the kernel that allocates exactly the memory needed for the CSR representation as the **Exact Allocation** kernel and the kernel that allocates a predetermined amount of memory per token as the **Fixed Allocation** kernel. The Fixed Allocation kernel can also be configured with either `validate=True` or `validate=False`. The `validate=True` version is slightly slower than `validate=False`, but it raises an error if any token fires more features than `max_l0`. This is clearly safer, but if you are 100% sure that no token will exceed `max_l0`, then you can use `validate=False` for some speedup.
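The trade-off can be sketched for a single token in plain Python. `pack_fixed` is a hypothetical helper written for this illustration and is not part of the actual implementation; it only mimics the behaviour described above.

```python
def pack_fixed(row, max_l0, validate=True):
    """Pack the nonzero entries of one token into at most max_l0 slots.

    Returns the (index, value) pairs that were kept and how many were dropped.
    """
    active = [(i, v) for i, v in enumerate(row) if v != 0.0]
    if validate and len(active) > max_l0:
        raise ValueError(f"token fired {len(active)} features, but max_l0={max_l0}")
    kept = active[:max_l0]  # anything past max_l0 is silently dropped
    return kept, len(active) - len(kept)


row = [0.0, 1.0, 0.0, 2.0, 3.0, 0.0, 4.0]  # 4 active features
kept, dropped = pack_fixed(row, max_l0=3, validate=False)
print(kept, dropped)
```

With `validate=False` the fourth feature disappears without any error, while `validate=True` raises instead. This sketch always drops the last features in index order; on the GPU, which features get dropped depends on the order in which blocks claim slots.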

To skip zero entries during matrix multiplication, we need to first represent the feature activations in **Compressed Sparse Row (CSR) format**, which is a standard way of representing sparse matrices that stores only the nonzero values and their indices. For the example above, instead of storing all 8 entries of the feature activation vector, CSR stores just the indices of the two active features and their two values.
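As a small illustration for a single token, the sketch below extracts the index and value lists from an 8-entry activation vector. The numbers are made up, and `to_sparse` is a hypothetical helper name.

```python
def to_sparse(vec):
    """Return the positions and values of the nonzero entries of vec."""
    indices = [i for i, v in enumerate(vec) if v != 0.0]
    values = [vec[i] for i in indices]
    return indices, values


f = [0.0, 0.0, 1.5, 0.0, 0.0, 0.0, 2.0, 0.0]  # illustrative activations
indices, values = to_sparse(f)
print(indices, values)
```

Eight stored numbers shrink to two indices and two values. For a whole batch, the per-token lists are concatenated, which is where the row offsets described later come in.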

To allocate enough memory for building a CSR representation, we need to know how much memory each token requires (how many features fired per token). A `count_nonzero` kernel handles this:

``` python
import triton
import triton.language as tl


@triton.jit
def count_nonzero(feature_acts_ptr, counts_ptr, n_features, BLOCK_F: tl.constexpr):
    pid_token = tl.program_id(0)  # Which token am I working on? (row index)
    pid_d = tl.program_id(1)  # Which chunk of features am I working on? (column index)

    # Compute the feature indices this block is responsible for
    feat_offsets = pid_d * BLOCK_F + tl.arange(0, BLOCK_F)
    mask = feat_offsets < n_features  # Guard against reading past the end of the feature dimension

    # Navigate to this token's features in memory, then to this block's chunk
    feat_ptrs = feature_acts_ptr + pid_token * n_features + feat_offsets
    vals = tl.load(feat_ptrs, mask=mask, other=0.0)  # Load the feature values

    fired = vals != 0.0  # Which features in this chunk are active (nonzero)?
    fired_count = tl.sum(fired.to(tl.int32))  # How many active (nonzero) features in this chunk?

    # Accumulate into this token's count (atomic since multiple blocks write to the same token)
    tl.atomic_add(counts_ptr + pid_token, fired_count)
```

If you're unfamiliar with Triton, the key mental model is that rather than writing a loop that runs sequentially, you write a kernel that describes what *one block* does and Triton launches many of these blocks in parallel across the GPU. In this kernel, each block is responsible for a chunk of one token's features. The two `program_id` calls tell each block where it is: `pid_token` identifies which token (which row of the input matrix), and `pid_d` identifies which chunk of that token's features to process.

Also note that pointers in GPU kernels point to the *start* of a flat block of memory. To reach a particular token's features, we offset into that memory by `pid_token * n_features`. Within that token, we offset further by `pid_d * BLOCK_F` to reach the right chunk. The `mask` guards against reading past the end when `n_features` isn't a clean multiple of `BLOCK_F`.

Finally, since multiple blocks may be counting features for the same token simultaneously, `tl.atomic_add` ensures their partial counts are combined safely.

This `count_nonzero` kernel produces an array `counts` of length $B$, where $B$ is the number of tokens in the batch. The number of active (nonzero) features for the $i$-th token is stored in `counts[i]`.
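For readers who want to trace the indexing without a GPU, here is a sequential pure-Python emulation of the same grid. It is only a sketch: `count_nonzero_reference` is a hypothetical name, the two outer loops stand in for the parallel blocks, and the plain `+=` stands in for the atomic add.

```python
def count_nonzero_reference(feature_acts, n_tokens, n_features, block_f):
    """feature_acts is a flat, row-major list of n_tokens * n_features values."""
    counts = [0] * n_tokens
    n_blocks = -(-n_features // block_f)  # ceiling division, like triton.cdiv
    for pid_token in range(n_tokens):  # grid axis 0
        for pid_d in range(n_blocks):  # grid axis 1
            fired_count = 0
            for lane in range(block_f):
                feat = pid_d * block_f + lane
                if feat >= n_features:  # plays the role of `mask`
                    continue
                if feature_acts[pid_token * n_features + feat] != 0.0:
                    fired_count += 1
            counts[pid_token] += fired_count  # tl.atomic_add on the GPU
    return counts


# 2 tokens with 5 features each, stored as one flat list (illustrative values)
acts = [0.0, 1.0, 0.0, 2.0, 0.0,
        3.0, 0.0, 0.0, 0.0, 0.0]
counts = count_nonzero_reference(acts, n_tokens=2, n_features=5, block_f=4)
print(counts)
```

With `block_f=4` and 5 features, the second block of each token covers positions 4 to 7, and the bounds check skips the three positions that do not exist.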

We can then use this information to allocate two flat arrays, `flat_idx` and `flat_val`, which hold the active feature indices and their values across the entire batch, with each token's entries stored one after another.

You may have noticed that it's not clear which entries belong to which token. For example, `flat_idx[2]` tells us that some feature fired, but it doesn't tell us whether this was for the first token in the batch, the second, the third, and so on.

We can solve this problem by introducing a new array `row_offsets` of length $B + 1$, where `row_offsets[b]` stores the starting index in `flat_idx`/`flat_val` where token $b$'s entries begin. It's computed by taking a cumulative sum of `counts`, so each token's region starts exactly where the previous one ends. For example, if three tokens have 2, 5, and 3 active features, then `counts = [2, 5, 3]` and `row_offsets = [0, 2, 7, 10]`.

Now token 0's entries live at indices 0–1, token 1's at 2–6, token 2's at 7–9, and the final entry (10) tells us the total number of nonzero features across all tokens in the batch.
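The three-token example can be checked with a few lines of standard-library Python. This is only a sketch of the cumulative-sum step (the real wrapper does the same thing with `torch.cumsum` on the GPU), and `token_slices` is a name introduced here for illustration.

```python
from itertools import accumulate

counts = [2, 5, 3]  # active features per token

# Prepend 0 so that row_offsets[b] is the start of token b's region
row_offsets = [0] + list(accumulate(counts))

# Token b owns the half-open range [row_offsets[b], row_offsets[b + 1])
token_slices = [(row_offsets[b], row_offsets[b + 1]) for b in range(len(counts))]
print(row_offsets, token_slices)
```

Because the ranges are half-open, token 1's slice `(2, 7)` covers indices 2 through 6, matching the description above.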

We can construct `row_offsets` inside a wrapper function `build_csr` that also handles memory and orchestration. It calls `compute_csr_kernel`, which is the kernel responsible for actually filling `flat_idx` and `flat_val` with the correct values. Note that `flat_idx` and `flat_val` are created with `torch.empty`, so they start out as uninitialized, pre-allocated storage that `compute_csr_kernel` will write into.

``` python
import torch


def build_csr(feature_acts: torch.Tensor, BLOCK_F: int = 1024):
    B, n_features = feature_acts.shape
    device = feature_acts.device

    # Count how many features fired per token
    counts = torch.zeros(B, dtype=torch.int32, device=device)
    grid = (B, triton.cdiv(n_features, BLOCK_F))
    count_nonzero[grid](feature_acts, counts, n_features, BLOCK_F=BLOCK_F)

    # Cumsum over counts gives each token a contiguous region in the flat arrays
    # row_offsets[b] = start index of token b's entries in flat_idx/flat_val
    row_offsets = torch.zeros(B + 1, dtype=torch.int32, device=device)
    row_offsets[1:] = counts.cumsum(0).to(torch.int32)

    # The last entry is the total number of nonzeros. This is used to size the flat arrays
    total_nnz = int(row_offsets[-1].item())  # GPU->CPU sync point
    flat_idx = torch.empty(total_nnz, dtype=torch.int32, device=device)
    flat_val = torch.empty(total_nnz, dtype=feature_acts.dtype, device=device)

    # write_pos is a per-token cursor that coordinates concurrent writes within
    # a token's region. Each block atomically claims the next available slots by
    # bumping write_pos by its count, getting back its starting offset (base).
    write_pos = torch.zeros(B, dtype=torch.int32, device=device)
    compute_csr_kernel[grid](
        feature_acts,
        row_offsets,
        write_pos,
        flat_idx,
        flat_val,
        n_features,
        BLOCK_F=BLOCK_F,
    )
    return flat_idx, flat_val, row_offsets, B


@triton.jit
def compute_csr_kernel(
    feature_acts_ptr,
    row_offsets_ptr,
    write_pos_ptr,
    flat_idx_ptr,
    flat_val_ptr,
    n_features,
    BLOCK_F: tl.constexpr,
):
    pid_token = tl.program_id(0)
    pid_d = tl.program_id(1)

    # Same pointer arithmetic as count_nonzero, navigate to this block's chunk
    feat_offsets = pid_d * BLOCK_F + tl.arange(0, BLOCK_F)
    mask = feat_offsets < n_features
    feat_ptrs = feature_acts_ptr + pid_token * n_features + feat_offsets
    vals = tl.load(feat_ptrs, mask=mask, other=0.0)
    fired = vals != 0.0
    fired_int = fired.to(tl.int32)

    # Where does this token's region start in flat_idx/flat_val?
    region_start = tl.load(row_offsets_ptr + pid_token)

    # Atomically claim the next block_count slots within this token's region
    block_count = tl.sum(fired_int)
    base = tl.atomic_add(write_pos_ptr + pid_token, block_count)

    # Assign each active feature a unique slot within the claimed range
    # (cumsum minus the element itself is an exclusive prefix sum)
    local_rank = tl.cumsum(fired_int) - fired_int
    slots = region_start + base + local_rank

    # Write the feature index and value into the claimed slots
    tl.store(flat_idx_ptr + slots, feat_offsets.to(tl.int32), mask=fired & mask)
    tl.store(flat_val_ptr + slots, vals, mask=fired & mask)
```
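The slot assignment is the subtle part of this kernel, so here is a sequential pure-Python sketch of the same bookkeeping. `fill_csr_reference` is a hypothetical name and the inputs are illustrative. Each chunk of `block_f` features plays the role of one block, `write_pos` is the per-token cursor, and `local_rank` is the exclusive prefix sum that gives every fired feature its own slot.

```python
from itertools import accumulate


def fill_csr_reference(rows, row_offsets, block_f):
    total = row_offsets[-1]
    flat_idx, flat_val = [None] * total, [None] * total
    write_pos = [0] * len(rows)
    for t, row in enumerate(rows):
        for start in range(0, len(row), block_f):  # one iteration per "block"
            chunk = row[start:start + block_f]
            fired = [1 if v != 0.0 else 0 for v in chunk]
            # Exclusive prefix sum: how many fired features precede each lane
            local_rank = [c - f for c, f in zip(accumulate(fired), fired)]
            base = write_pos[t]  # the old value that atomic_add returns
            write_pos[t] += sum(fired)  # claim sum(fired) slots
            for lane, v in enumerate(chunk):
                if fired[lane]:
                    slot = row_offsets[t] + base + local_rank[lane]
                    flat_idx[slot] = start + lane
                    flat_val[slot] = v
    return flat_idx, flat_val


rows = [[0.0, 1.0, 0.0, 2.0, 0.0, 3.0],
        [4.0, 0.0, 0.0, 0.0, 5.0, 0.0]]
row_offsets = [0, 3, 5]  # cumulative counts for 3 and 2 active features
flat_idx, flat_val = fill_csr_reference(rows, row_offsets, block_f=4)
print(flat_idx, flat_val)
```

In this sequential version the entries within a token come out sorted by feature index. On the GPU, blocks can claim slots in any order, so the order within a token's region is not guaranteed, which is fine because the decoder only sums over them.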

Next, `sparse_decode_kernel` uses this CSR structure to carry out the matrix multiplication step. For each token, it looks up where that token's active features live in `flat_idx`/`flat_val` using `row_offsets`, then loops over them, accumulating the weighted sum of the corresponding decoder rows into a tile of the output.

``` python
@triton.jit
def sparse_decode_kernel(
    flat_idx_ptr, flat_val_ptr, row_offsets_ptr,
    W_dec_ptr, out_ptr, d_model,
    BLOCK_D: tl.constexpr,
):
    pid_token = tl.program_id(0)
    pid_d = tl.program_id(1)

    # Find the slice of flat_idx/flat_val belonging to this token
    start = tl.load(row_offsets_ptr + pid_token)
    end = tl.load(row_offsets_ptr + pid_token + 1)
    n = end - start  # Number of active features for this token

    # This block owns a BLOCK_D-wide slice of the output row
    offsets = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    mask = offsets < d_model
    acc = tl.zeros([BLOCK_D], dtype=tl.float32)

    # Loop over this token's active features, accumulating their contribution
    for i in range(n):
        j = start + i
        feat_idx = tl.load(flat_idx_ptr + j)  # Which decoder row?
        feat_val = tl.load(flat_val_ptr + j)  # Scale factor

        # Load the corresponding decoder row (just this block's slice)
        row_ptrs = W_dec_ptr + (feat_idx * d_model) + offsets
        row = tl.load(row_ptrs, mask=mask, other=0.0)
        acc += feat_val.to(tl.float32) * row.to(tl.float32)

    # Write this block's output slice
    tl.store(out_ptr + pid_token * d_model + offsets, acc, mask=mask)
```
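As a CPU reference for what this kernel computes, the sketch below performs the same per-token loop in plain Python. `sparse_decode_reference` is a hypothetical name, the data is made up, and the tiling over `d_model` is collapsed into an ordinary inner loop.

```python
def sparse_decode_reference(flat_idx, flat_val, row_offsets, W_dec):
    d_model = len(W_dec[0])
    out = []
    for t in range(len(row_offsets) - 1):
        acc = [0.0] * d_model
        # Only this token's slice of the flat arrays is visited
        for j in range(row_offsets[t], row_offsets[t + 1]):
            row = W_dec[flat_idx[j]]  # decoder row of the active feature
            for d in range(d_model):
                acc[d] += flat_val[j] * row[d]
        out.append(acc)
    return out


# 4 features, d_model = 2 (illustrative values)
W_dec = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]
flat_idx = [0, 3, 2]  # token 0 fired features 0 and 3, token 1 fired feature 2
flat_val = [2.0, 1.0, 0.5]
row_offsets = [0, 2, 3]
out = sparse_decode_reference(flat_idx, flat_val, row_offsets, W_dec)
print(out)
```

The work per token scales with the number of active features rather than with the dictionary size, which is the source of the speedups reported later.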

Finally, we put all of these kernels together by wrapping them in a single `sparse_decode()` function that acts as a drop-in replacement for `@`:

``` python
def _sparse_decode(flat_idx, flat_val, row_offsets, W_dec, B, BLOCK_D=256):
    d_model = W_dec.shape[1]
    out = torch.zeros((B, d_model), device=W_dec.device, dtype=torch.float32)

    # parallelize over batch rows and d_model tiles
    grid = (B, triton.cdiv(d_model, BLOCK_D))

    sparse_decode_kernel[grid](
        flat_idx, flat_val, row_offsets, W_dec, out, d_model, BLOCK_D=BLOCK_D
    )

    return out


def sparse_decode(feature_acts, W_dec):
    # The kernels index with hard-coded row-major strides (row * n_cols + col),
    # so both inputs need to be contiguous for the pointer arithmetic to be correct
    feature_acts = feature_acts.contiguous()
    W_dec = W_dec.contiguous()

    flat_idx, flat_val, row_offsets, B = build_csr(feature_acts)
    return _sparse_decode(flat_idx, flat_val, row_offsets, W_dec, B)
```

Recall how in the Exact Allocation kernel, inside `build_csr` we extracted the total number of nonzero entries across all tokens by retrieving the last entry of `row_offsets`:

```
total_nnz = int(row_offsets[-1].item())
```

When we call `.item()`, we are forcing the CPU to wait for the GPU to finish the counting pass before it can read `total_nnz` and allocate `flat_idx`/`flat_val`.

Normally the CPU queues up GPU work asynchronously and moves on without waiting, but `.item()` breaks that pipeline by requiring the CPU to stall until the GPU result is ready.

This turns out to be a significant source of slowdown.

The Fixed Allocation kernel works around this by not allocating *exactly* the memory needed in the first place (meaning we don't even need `total_nnz`). Instead, we allocate `max_l0` slots per token, where `max_l0` is a user-specified upper bound on how many features can fire for any single token. This also means we no longer need a separate pass to count the nonzero features per token before computing the CSR structure.

With these changes, the new `build_csr` wrapper function looks like:

``` python
def build_csr(feature_acts: torch.Tensor, BLOCK_F: int = 1024, max_l0: int = 512, validate: bool = True):
    B, n_features = feature_acts.shape
    device = feature_acts.device

    # Fixed memory allocation
    capacity = B * max_l0
    flat_idx = torch.empty(capacity, dtype=torch.int32, device=device)
    flat_val = torch.empty(capacity, dtype=feature_acts.dtype, device=device)

    # write_pos serves as both the per-token write cursor during the kernel
    # and the per-token count afterward
    write_pos = torch.zeros(B, dtype=torch.int32, device=device)

    grid = (B, triton.cdiv(n_features, BLOCK_F))
    compute_csr_kernel[grid](
        feature_acts, write_pos, flat_idx, flat_val,
        n_features, max_l0, BLOCK_F=BLOCK_F,
    )

    # Final cursor value = number of features that fired per token
    # (this can exceed max_l0 if the token overflowed its region)
    counts = write_pos

    # Optional safety check. This reintroduces a GPU→CPU sync but catches silent truncation
    if validate:
        max_count = int(counts.max().item())  # read once, so only one sync
        if max_count > max_l0:
            raise ValueError(
                f"A token fired more than max_l0={max_l0} features "
                f"(max was {max_count}). Increase max_l0."
            )

    return flat_idx, flat_val, counts, B, max_l0
```

As mentioned briefly earlier, if a token fires more features than `max_l0`, those extra features are silently dropped by the overflow guard in the kernel. This can be dangerous because the result is wrong but there's no crash. The `validate=True` default catches this by checking `counts.max()` after the kernel, at the cost of reintroducing a GPU→CPU sync. (*However, this is still faster than Exact Allocation in practice.*) If you're very confident that your `max_l0` is a safe upper bound for your SAE, then you can pass `validate=False` to skip the check, but this is not recommended.

The kernel to compute CSR changes minimally. We no longer need `row_offsets` since we know that each token takes up `max_l0` entries in memory, so the lookup for the start of a token's region is replaced by `region_start = pid_token * max_l0`.

``` python
@triton.jit
def compute_csr_kernel(
    feature_acts_ptr,
    write_pos_ptr,
    flat_idx_ptr,
    flat_val_ptr,
    n_features,
    max_l0,
    BLOCK_F: tl.constexpr,
):
    pid_token = tl.program_id(0)
    pid_d = tl.program_id(1)

    # Navigate to this block's chunk
    feat_offsets = pid_d * BLOCK_F + tl.arange(0, BLOCK_F)
    mask = feat_offsets < n_features
    feat_ptrs = feature_acts_ptr + pid_token * n_features + feat_offsets
    vals = tl.load(feat_ptrs, mask=mask, other=0.0)

    fired = vals != 0.0
    fired_int = fired.to(tl.int32)

    # Each token owns a fixed region of max_l0 slots
    region_start = pid_token * max_l0

    # Atomically claim the next available slots within this token's region
    block_count = tl.sum(fired_int)
    base = tl.atomic_add(write_pos_ptr + pid_token, block_count)

    # Assign each active feature a unique slot within the claimed range
    local_rank = tl.cumsum(fired_int) - fired_int
    local_slot = base + local_rank

    # Guard against writing past this token's region if L0 exceeds max_l0
    in_region = local_slot < max_l0
    write_mask = fired & mask & in_region

    slots = region_start + local_slot
    tl.store(flat_idx_ptr + slots, feat_offsets.to(tl.int32), mask=write_mask)
    tl.store(flat_val_ptr + slots, vals, mask=write_mask)
```

The decoder kernel then changes in the same way. `row_offsets` is no longer needed, and `counts` replaces the start/end bracket:

``` python
@triton.jit
def sparse_decode_kernel(
    flat_idx_ptr,
    flat_val_ptr,
    counts_ptr,
    W_dec_ptr,
    out_ptr,
    d_model,
    max_l0,
    BLOCK_D: tl.constexpr,
):
    pid_token = tl.program_id(0)
    pid_d = tl.program_id(1)

    start = pid_token * max_l0
    n = tl.load(counts_ptr + pid_token)  # Number of features that fired for this token
    # Only max_l0 slots were written per token, so never read past this token's
    # region even if the token overflowed (relevant when validate=False)
    n = tl.minimum(n, max_l0)

    offsets = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    mask = offsets < d_model
    acc = tl.zeros([BLOCK_D], dtype=tl.float32)

    # Same loop as before
    for i in range(n):
        j = start + i
        feat_idx = tl.load(flat_idx_ptr + j)
        feat_val = tl.load(flat_val_ptr + j)
        row_ptrs = W_dec_ptr + feat_idx * d_model + offsets
        row = tl.load(row_ptrs, mask=mask, other=0.0)
        acc += feat_val.to(tl.float32) * row.to(tl.float32)

    tl.store(out_ptr + pid_token * d_model + offsets, acc, mask=mask)
```
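The fixed layout can be traced on the CPU with a small pure-Python sketch. `fixed_decode_reference` is a hypothetical name and the data is made up. As a defensive choice in this sketch, the per-token count is clamped to `max_l0` before reading, so an overflowed token never reads into its neighbour's region.

```python
def fixed_decode_reference(flat_idx, flat_val, counts, W_dec, max_l0):
    d_model = len(W_dec[0])
    out = []
    for t, count in enumerate(counts):
        start = t * max_l0  # every token owns exactly max_l0 slots
        n = min(count, max_l0)  # defensive clamp: only max_l0 slots were written
        acc = [0.0] * d_model
        for j in range(start, start + n):
            row = W_dec[flat_idx[j]]
            for d in range(d_model):
                acc[d] += flat_val[j] * row[d]
        out.append(acc)
    return out


W_dec = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]
max_l0 = 2
flat_idx = [0, 3, 2, 0]  # token 0 uses slots 0-1, token 1 uses slot 2, slot 3 is unused
flat_val = [2.0, 1.0, 0.5, 0.0]

out = fixed_decode_reference(flat_idx, flat_val, [2, 1], W_dec, max_l0)
# Pretend token 0 actually fired 3 features: its count exceeds max_l0
out_overflow = fixed_decode_reference(flat_idx, flat_val, [3, 1], W_dec, max_l0)
print(out, out_overflow)
```

Without the clamp, the overflowed token would read slot 2, which belongs to token 1, and silently add another token's feature to its output.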

Writing custom GPU kernels is great, but it's important to make sure that they're actually making the computation faster. I used `triton.testing.do_bench` (`warmup=25, rep=100`, reporting the median) to time these kernels and compared them against dense matrix multiplication (`feature_acts @ W_dec`). All tests were run on an NVIDIA GeForce RTX 4090 GPU.

As a quick summary, the table below shows the relative speedups for an example input configuration (`B = 32`, `n_features = 65536`, `d_model = 768`, `L0 = 64`):

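| Method | Median time (ms) | Speedup |
|---|---|---|
| Dense cuBLAS | 0.288 | 1.0× |
| `torch.compile` | 0.288 | 1.0× |
| `torch.sparse.mm` | 0.210 | 1.4× |
| Custom: exact allocation | 0.151 | 1.9× |
| Custom: fixed allocation (`validate=False`) | 0.041 | 7.0× |
| Custom: fixed allocation (`validate=True`) | 0.115 | 2.5× |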
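The speedup column is just the dense time divided by each method's time. The sketch below reproduces it for the three custom-kernel rows, using the timings from the table above; the dictionary keys are labels chosen here for illustration.

```python
dense_ms = 0.288  # dense cuBLAS baseline from the table

timings_ms = {
    "exact allocation": 0.151,
    "fixed allocation (faster row)": 0.041,
    "fixed allocation (slower row)": 0.115,
}

# Speedup relative to the dense baseline, rounded to one decimal like the table
speedups = {name: round(dense_ms / ms, 1) for name, ms in timings_ms.items()}
print(speedups)
```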

First, I verified that the custom kernels actually perform matrix multiplication correctly (a custom kernel that is faster but gives the wrong answer doesn't help anyone). In other words, we verify that `sparse_decode(feature_acts, W_dec) == feature_acts @ W_dec` across 486 different inputs using combinations of the parameters below. Note that `sparse_decode()` here is just a wrapper matmul function that uses our custom Triton kernels under the hood.

| Parameter | Meaning | Values | Count |
|---|---|---|---|
| | kernel implementation | | 2 |
| | input dtype | | 3 |
| `B` | batch size (tokens) | 1, 4, 32 | 3 |
| `n_features` | SAE dictionary width | 256, 1024, 16384 | 3 |
| `d_model` | output width | 128, 512, 768 | 3 |
| `L0` | features fired per token | 1, 8, 100 | 3 |

**Total: 2 × 3 × 3 × 3 × 3 × 3 = 486 configurations.** Each asserts output is fp32 and matches the dense fp32 reference within `atol=1e-4, rtol=1e-3`.
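The configuration count and the tolerance rule can be sketched in a few lines. Two things here are assumptions: the kernel and dtype options are represented by placeholder integers because their names are not listed above, and `within_tolerance` assumes the usual allclose-style rule `|actual - expected| <= atol + rtol * |expected|`, which may differ in detail from the check used in the real test suite.

```python
from itertools import product

param_grid = {
    "kernel": range(2),  # placeholder for the 2 kernel implementations
    "dtype": range(3),  # placeholder for the 3 input dtypes
    "B": (1, 4, 32),
    "n_features": (256, 1024, 16384),
    "d_model": (128, 512, 768),
    "L0": (1, 8, 100),
}

# Every combination of one value per parameter
configs = list(product(*param_grid.values()))
n_configs = len(configs)


def within_tolerance(actual, expected, atol=1e-4, rtol=1e-3):
    """Elementwise allclose-style check for a single pair of numbers."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


print(n_configs, within_tolerance(1.0005, 1.0), within_tolerance(1.002, 1.0))
```

For a reference value of 1.0, the allowed deviation is `1e-4 + 1e-3 = 1.1e-3`, so 1.0005 passes and 1.002 fails.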

The preprocessing step of computing a CSR representation adds some computational overhead. It would be interesting to see a direct comparison between `sparse_decode_kernel` and dense matrix multiplication if you didn't have to pay for that overhead (assume that you somehow already have access to a CSR representation).

If you hold some parameters of the input constant (`B=32`, `n_features=65536`, `d_model=768`) while varying `L0` (the number of fired features) as shown in the table below, then how much faster is `sparse_decode_kernel`?

Note that this is **EXCLUDING** the overhead of the CSR preprocessing step (i.e., `compute_csr_kernel`). Also note that `sparse_decode_kernel` is essentially the same between Exact Allocation and Fixed Allocation so there is no need to differentiate, but for completeness the graph below plots both (they overlap).

| `L0` | Active fraction (`L0 / n_features`) | Speedup vs. dense |
|---|---|---|
| 16 | 0.02% | 25.5× |
| 32 | 0.05% | 18.7× |
| 64 | 0.10% | 12.8× |
| 128 | 0.20% | 8.0× |
| 256 | 0.39% | 5.0× |
| 512 | 0.78% | 3.0× |
| 1024 | 1.56% | 1.7× |
| 4096 | 6.25% | 0.6× |
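The percentage column is the fraction of features that are active, `L0 / n_features`. A short sketch that recomputes it for `n_features=65536`, along with the ~99.89% figure quoted earlier for 72 active features (variable names are illustrative):

```python
n_features = 65536
l0_values = [16, 32, 64, 128, 256, 512, 1024, 4096]

# Fraction of features that are active, as a percentage with two decimals
active_pct = {l0: f"{100 * l0 / n_features:.2f}%" for l0 in l0_values}

# Share of decoder rows multiplied by zero when 72 of 65,536 features fire
skipped_pct = f"{100 * (1 - 72 / n_features):.2f}%"
print(active_pct, skipped_pct)
```

Read together with the speedup column, the kernel-only advantage shrinks steadily as the active fraction grows and disappears somewhere between 1.56% and 6.25% for this configuration.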

We can also vary `n_features` while keeping `B=32`, `L0=64`, and `d_model=768` constant:

| `n_features` | Speedup vs. dense |
|---|---|
| 4,096 | 1.5× |
| 16,384 | 4.1× |
| 32,768 | 7.3× |
| 65,536 | 12.8× |
| 131,072 | 22.5× |

So clearly `sparse_decode_kernel` alone is faster than dense matrix multiplication at high sparsity. But of course in practice we probably need to compute CSR as well, which will slow things down somewhat.

The table below shows the relative speedups (relative to dense matmul) for three different input configurations. Here "Kernel only" refers to only `sparse_decode_kernel` (CSR is precomputed), while "Full" refers to the whole pipeline (i.e., including `build_csr`).

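| Configuration | Kernel only | Full, Exact Allocation | Full, Fixed Allocation (`validate=False`) | Full, Fixed Allocation (`validate=True`) |
|---|---|---|---|---|
| B=32, F=65536, D=768, L0=64 | 12.8× | 1.9× | 7.0× | 2.5× |
| B=256, F=65536, D=768, L0=64 | 7.7× | 1.7× | 3.1× | 2.2× |
| B=32, F=131072, D=512, L0=128 | 22.5× | 2.2× | 6.1× | 2.3× |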

The graph below shows the speed of the full pipeline (Exact Allocation) and decode-only as you vary sparsity. Here, `L0` sweeps over [16, 32, 64, 128, 256, 512, 1024, 4096, 16384] while holding `B=32`, `n_features=65536`, and `d_model=768` constant.

To be comprehensive, we can also compare our custom kernels to `torch.sparse.mm` (using PyTorch's `to_sparse_csr()`), which uses cuSPARSE internally, and `torch.compile`. This focuses on the same three input configurations as above.

**Note: I found it a little suspicious that this custom kernel would "beat" `torch.sparse.mm`. It turns out this is mostly because of beating `to_sparse_csr()` when building the CSR. There doesn't seem to be much of a difference in speed between the custom kernel and cuSPARSE on the matrix multiplication step alone.**

As expected, `torch.compile` doesn't provide a noticeable speedup, but I wanted to include it anyway for completeness.

Up until now we have been focusing entirely on the speed of the matrix multiplication operation, but at the end of the day we care about SAE inference speed as a whole. This is benchmarked by replacing only the decoder matmul step in a `SAELens` JumpReLU SAE forward pass. The table below focuses on five SAEs across two model families and three dictionary sizes.

| SAE | `n_features` | `d_model` | `L0` | | Exact Allocation | Fixed Allocation (`validate=True`) | Fixed Allocation (`validate=False`) |
|---|---|---|---|---|---|---|---|
| Gemma Scope 2B, L20, 65k | 65,536 | 2,304 | 72 | 3.8e-6 | 4.27× | 5.57× | 11.41× |
| Gemma Scope 9B, L20, 65k | 65,536 | 3,584 | 72 | 3.8e-6 | 5.66× | 7.34× | 13.27× |
| Gemma Scope 2B, L12, 65k | 65,536 | 2,304 | 72 | 9.5e-7 | 3.91× | 5.48× | 11.33× |
| Gemma Scope 2B, L12, 262k | 262,144 | 2,304 | 100 | 1.9e-6 | 12.08× | 14.49× | 22.59× |
| Qwen Scope 3.5 2B, L12 | 32,768 | 2,048 | 100 | 4.8e-7 | 1.98× | 2.54× | 5.74× |

The purpose of the Fixed Allocation kernel was to overallocate memory in exchange for speed, so it would be helpful to see exactly how much more memory it uses compared to the Exact Allocation kernel. Surprisingly, it turns out that in practice this overhead is small:

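| `B` | `max_l0` | Dense (MB) | Exact Allocation (MB) | Fixed Allocation (MB) | Fixed vs. Exact |
|---|---|---|---|---|---|
| 32 | 512 | 218.3 | 218.4 | 218.5 | +0.1 MB |
| 256 | 512 | 277.7 | 277.9 | 278.8 | +0.9 MB |
| 1024 | 512 | 482.3 | 482.9 | 485.6 | +2.7 MB |
| 1024 | 1024 | 482.3 | 482.9 | 490.7 | +7.8 MB |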

While these results are encouraging, there are a few important limitations to be aware of and gaps that I plan to address as I continue working on this project.

First, the above benchmark numbers are not absolute, as these tests were run in a specific environment (WSL2 with GPU clocks not pinned). The primary goal of these benchmarks was to gauge the relative performance of the custom kernels compared to baseline implementations. The actual absolute speed likely differs depending on the hardware and benchmarking setup.

A second limitation, which was discussed earlier but is worth reiterating, is that although the Fixed Allocation kernel with `validate=False` achieves the highest performance, it can silently produce incorrect results if the `max_l0` parameter is set too low. For this reason, using either the Exact Allocation kernel or Fixed Allocation with `validate=True` is likely better for most cases.

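Third, these kernels were designed specifically for sparse matrix multiplication, so they only help when the activations are sparse enough. As the fraction of active features grows, dense matrix multiplication eventually becomes faster: in the kernel-only sweep above (`n_features=65536`), the speedup falls to 1.7× at `L0=1024` and to 0.6× (slower than dense) at `L0=4096`.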

Fourth, this implementation focuses exclusively on the decoder inference step of JumpReLU Sparse Autoencoders, but there are likely other sources of inefficiency that could be addressed. For example, future projects could focus on the encoder pass or support for training through custom backward kernels. Additionally, the current implementation only supports `float32` outputs.

Finally, all experiments were run on an RTX 4090, and performance may differ on other GPU architectures such as the A100 or H100.

In conclusion, this project implements custom Triton kernels for the decoder inference step of JumpReLU SAEs by exploiting the inherent sparsity of the hidden representation. On a sample of real SAEs, this achieves 2.5–14× speedup with the Fixed Allocation (validate=True) kernel, with larger gains at higher dictionary sizes.

The full implementation is available on [GitHub](https://github.com/dtiourine/jumprelu-sae-kernels/tree/main).

*I welcome feedback! If you have thoughts, questions, or find any issues, feel free to leave a comment or reach out directly. This is also my first GPU kernel project, so if you're experienced with Triton or GPU kernel optimization and see things I could have done better, I would appreciate any suggestions.*
