[ARTICLE · art-26728] src=lesswrong.com ↗ pub= topic=machine-learning verified=true sentiment=↑ positive

Speeding Up JumpReLU SAE Inference with Custom Triton Kernels (2–14× on Real SAEs)

Researchers developed custom Triton kernels that accelerate JumpReLU Sparse Autoencoder inference by 2–14× on real SAEs, exploiting activation sparsity to skip zero entries during matrix multiplication. The kernels use either exact or fixed memory allocation per token, with a configurable max_l0 parameter to balance speed and safety.

20 min read · published Jun 14, 2026

Sparse Autoencoders (SAEs) have become a central tool in mechanistic interpretability research, providing a way to decompose a model's internal activations into sparse, interpretable features. However, extracting these features often requires running the SAE over large volumes of activations across many layers and tokens. This makes SAE inference efficiency a practical bottleneck for interpretability research at scale.

This post focuses on improving the inference efficiency of JumpReLU Sparse Autoencoders, which were introduced by DeepMind in Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders (Rajamanoharan et al). Instead of using a traditional ReLU activation function, these SAEs use JumpReLU, which zeros out activations that fall below a learned per-feature threshold $\theta$. This gives JumpReLU SAEs a variable number of active features per token (commonly written as $L_0$, the count of nonzero activations), unlike TopK SAEs, which fire exactly $k$ features per token.
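As a minimal sketch of the activation described above (plain Python, not the SAE's actual implementation), the following applies a per-feature threshold and counts how many features fire. The function name `jumprelu` and all numbers are illustrative assumptions; the strict `>` comparison at the boundary is also an assumption.

```python
def jumprelu(pre_acts, thresholds):
    """Keep a pre-activation only if it exceeds its feature's threshold.

    Hypothetical helper for illustration; using a strict '>' at the
    boundary is an assumption of this sketch.
    """
    return [z if z > t else 0.0 for z, t in zip(pre_acts, thresholds)]


# Illustrative values: four features, each with its own learned threshold.
pre_acts = [0.9, 0.2, -0.5, 1.4]
thresholds = [0.5, 0.5, 0.1, 1.0]

acts = jumprelu(pre_acts, thresholds)

# L0 = number of features that "fired" (nonzero activations) for this token.
l0 = sum(1 for a in acts if a != 0.0)
```

Because the thresholds differ per feature and the inputs differ per token, `l0` varies from token to token, which is why the kernels below cannot assume a fixed number of active features.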

I use the terms "fire" and "fired" to describe features with non-zero activations.

Traditional JumpReLU SAE implementations compute the decoder step as a dense matrix multiplication (`feature_acts @ W_dec`), but this is wasteful because `feature_acts` is sparse. Instead, you can exploit this sparsity and skip the zero entries entirely during matrix multiplication with a custom Triton kernel.

When a single token passes through a JumpReLU SAE with 65,536 features, the encoder produces a feature activation vector of length 65,536, but only some entries are nonzero.

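To be more concrete, consider a toy SAE with 8 features and feature activations $f \in \mathbb{R}^{8}$. Now suppose that only 2 features are active, say at indices $a$ and $b$, and let $W_{dec} \in \mathbb{R}^{8 \times d_{model}}$ represent the weight matrix of the decoder layer.

We then compute the output with:

$$\hat{x} = f\,W_{dec} = \sum_{i=1}^{8} f_i\,W_{dec}[i, :]$$

Notice how only two of the rows of the decoder matrix are actually used in the computation. The rest are multiplied by 0 and contribute nothing to the output. We could instead just compute:

$$\hat{x} = f_a\,W_{dec}[a, :] + f_b\,W_{dec}[b, :]$$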
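To make the toy case tangible, here is a small pure-Python sketch with made-up numbers (8 features, an output width of 3, features 2 and 5 active). The values and the helper names `dense_decode` and `sparse_decode_ref` are assumptions for illustration, not the post's original example.

```python
# Illustrative values only: 8 features, d_model = 3, features 2 and 5 active.
f = [0.0, 0.0, 1.5, 0.0, 0.0, 2.0, 0.0, 0.0]
W_dec = [[float(i + k) for k in range(3)] for i in range(8)]  # row i = [i, i+1, i+2]


def dense_decode(f, W):
    """Dense reference: every decoder row is touched, even when f[i] == 0."""
    d_model = len(W[0])
    return [sum(f[i] * W[i][k] for i in range(len(f))) for k in range(d_model)]


def sparse_decode_ref(f, W):
    """Same result, but decoder rows are only read for nonzero activations."""
    d_model = len(W[0])
    out = [0.0] * d_model
    for i, v in enumerate(f):
        if v != 0.0:
            for k in range(d_model):
                out[k] += v * W[i][k]
    return out


dense_out = dense_decode(f, W_dec)
sparse_out = sparse_decode_ref(f, W_dec)
```

Both functions return the same vector, but the second one reads 2 decoder rows instead of 8.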

Now imagine the same example, but with the hidden dimension increased from 8 to something much larger: for instance, 65,536 features with 72 active. That would mean you're multiplying ~99.89% of the decoder rows by zero.

If we knew in advance which features are nonzero and their corresponding values, we could skip these zero multiplications and simplify the computation.

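For a single token, this can be divided into two parts:

1. Find which features are nonzero and record their indices and values.
2. Sum the corresponding decoder rows, each scaled by its feature's value.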

When implementing this kernel, my first thought was to begin with a preliminary step that figures out exactly how many features fired for each token so the system could then allocate exactly the memory needed to store the CSR representation (more on that later). However, this process involves a GPU->CPU sync, which causes some slowdown.

As an alternative, you can instead allocate a predetermined, fixed amount of memory for each token using a configurable `max_l0` parameter. This speeds up computation but overallocates memory and introduces an important caveat: `max_l0` must be large enough to avoid errors. For example, if you set `max_l0=10` but one of the tokens in the batch has more than 10 nonzero features, the extra features will be dropped, resulting in information loss.
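The truncation risk can be illustrated with a short pure-Python sketch. The helper `truncate_to_max_l0` is hypothetical, and it drops features in index order for simplicity; in a parallel GPU kernel, which features get dropped would likely depend on the order in which blocks run.

```python
def truncate_to_max_l0(acts, max_l0):
    """Split one token's active features into those that fit and those lost.

    Hypothetical helper: keeps the first max_l0 nonzero entries in index order.
    """
    kept, dropped = [], []
    for idx, val in enumerate(acts):
        if val == 0.0:
            continue
        target = kept if len(kept) < max_l0 else dropped
        target.append((idx, val))
    return kept, dropped


# Three features fire, but only two slots were reserved for this token.
acts = [0.0, 1.0, 0.0, 2.0, 3.0]
kept, dropped = truncate_to_max_l0(acts, max_l0=2)
```

Here feature 4 never reaches the decoder, so the reconstruction is wrong even though nothing crashes.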

Both approaches are covered below. For convenience, let's refer to the kernel that allocates exactly the memory needed for the CSR representation as the **Exact Allocation** kernel and the kernel that allocates a predetermined amount of memory per token as the Fixed Allocation kernel. The Fixed Allocation kernel can also be configured with either `validate=True` or `validate=False`. The `validate=True` version is slightly slower than `validate=False`, but it raises an error if any token fires more features than `max_l0`. This is clearly safer, but if you are certain that no token will exceed `max_l0`, you can use `validate=False` for some speedup.

To skip zero entries during matrix multiplication, we first need to represent the feature activations in Compressed Sparse Row (CSR) format, a standard way of representing sparse matrices that stores only the nonzero values and their indices. For the example above, instead of storing all 8 entries of $f$, CSR stores just the 2 nonzero values and the 2 feature indices at which they occur.
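For a single token, the compressed form amounts to two short lists. The sketch below uses made-up activations (features 2 and 5 active) and a hypothetical helper name `to_sparse`:

```python
def to_sparse(acts):
    """Return (indices, values) of the nonzero entries of one activation row."""
    idx = [i for i, v in enumerate(acts) if v != 0.0]
    val = [acts[i] for i in idx]
    return idx, val


# Illustrative 8-feature activation vector with two active features.
f = [0.0, 0.0, 1.5, 0.0, 0.0, 2.0, 0.0, 0.0]
idx, val = to_sparse(f)
```

The 8-entry vector is replaced by 2 indices and 2 values; for a batch, these per-token lists are concatenated and indexed by row offsets.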

To allocate enough memory for building a CSR representation, we need to know how much memory each token requires (how many features fired per token). A `count_nonzero` kernel handles this:

```python
import triton
import triton.language as tl


@triton.jit
def count_nonzero(feature_acts_ptr, counts_ptr, n_features, BLOCK_F: tl.constexpr):
    pid_token = tl.program_id(0)  # Which token am I working on? (row index)
    pid_d = tl.program_id(1)  # Which chunk of features am I working on? (column index)

    # Compute the feature indices this block is responsible for
    feat_offsets = pid_d * BLOCK_F + tl.arange(0, BLOCK_F)
    mask = feat_offsets < n_features  # Guard against reading past the end of the feature dimension

    # Navigate to this token's features in memory, then to this block's chunk
    feat_ptrs = feature_acts_ptr + pid_token * n_features + feat_offsets
    vals = tl.load(feat_ptrs, mask=mask, other=0.0)  # Load the feature values

    fired = vals != 0.0  # Which features in this chunk are active (nonzero)?
    fired_count = tl.sum(fired.to(tl.int32))  # How many active (nonzero) features in this chunk?

    # Accumulate into this token's count (atomic since multiple blocks write to the same token)
    tl.atomic_add(counts_ptr + pid_token, fired_count)
```

If you're unfamiliar with Triton, the key mental model is that rather than writing a loop that runs sequentially, you write a kernel that describes what one block does, and Triton launches many of these blocks in parallel across the GPU. In this kernel, each block is responsible for a chunk of one token's features. The two `program_id` calls tell each block where it is: `pid_token` identifies which token (which row of the input matrix), and `pid_d` identifies which chunk of that token's features to process.

Also note that pointers in GPU kernels point to the start of a flat block of memory. To reach a particular token's features, we offset into that memory by `pid_token * n_features`. Within that token, we offset further by `pid_d * BLOCK_F` to reach the right chunk. The `mask` guards against reading past the end when `n_features` isn't a clean multiple of `BLOCK_F`.

Finally, since multiple blocks may be counting features for the same token simultaneously, `tl.atomic_add` ensures their partial counts are combined safely.

This `count_nonzero` kernel produces an array `counts` of length $B$, where $B$ is the number of tokens in the batch. The number of active (nonzero) features for token $i$ is stored in `counts[i]`.
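If the block-based view is unfamiliar, it may help to see the same computation written as ordinary nested loops. This is only a sequential CPU emulation with a hypothetical name, `count_nonzero_blocked`; on the GPU, each `(pid_token, pid_d)` pair runs as its own block in parallel.

```python
def count_nonzero_blocked(feature_acts, block_f):
    """Sequential emulation of the 2D launch grid used by the counting kernel."""
    n_tokens = len(feature_acts)
    n_features = len(feature_acts[0])
    n_blocks = -(-n_features // block_f)  # ceiling division, like triton.cdiv

    counts = [0] * n_tokens
    for pid_token in range(n_tokens):        # grid axis 0: one row per token
        for pid_d in range(n_blocks):        # grid axis 1: one chunk of features
            start = pid_d * block_f
            stop = min(start + block_f, n_features)  # plays the role of `mask`
            chunk = feature_acts[pid_token][start:stop]
            # On the GPU this update is an atomic add, since several blocks
            # contribute to the same token's counter concurrently.
            counts[pid_token] += sum(1 for v in chunk if v != 0.0)
    return counts


# Two tokens, five features, blocks of width 2 (the last block is partial).
batch = [
    [0.0, 1.0, 0.0, 2.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 3.0],
]
counts = count_nonzero_blocked(batch, block_f=2)
```

With 5 features and a block width of 2, the third block only covers one valid feature, which is the situation the mask handles in the real kernel.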

We can then use this information to allocate two flat arrays, `flat_idx` and `flat_val`, which hold the active feature indices and their values across the entire batch, with each token's entries stored directly after the previous token's.

You may have noticed that it's not clear which entries belong to which token. For example, `flat_idx[2]` tells us which feature fired, but it doesn't tell us whether it fired for the first token in the batch, the second, the third, and so on.

We can solve this problem by introducing a new array `row_offsets` of length $B + 1$, where `row_offsets[b]` stores the starting index in `flat_idx`/`flat_val` where token $b$'s entries begin. It's computed by taking a cumulative sum of `counts`, so each token's region starts exactly where the previous one ends. For example, if three tokens have 2, 5, and 3 active features, then `counts = [2, 5, 3]` and `row_offsets = [0, 2, 7, 10]`.

Now token 0's entries live at indices 0–1, token 1's at 2–6, token 2's at 7–9, and the final entry (10) tells us the total number of nonzero features across all tokens in the batch.
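The offsets for the three-token example can be reproduced with a running sum. This is a pure-Python sketch; `token_slice` is a hypothetical helper name.

```python
from itertools import accumulate

counts = [2, 5, 3]

# Prepend 0 so that row_offsets[b] is the start of token b's region and
# row_offsets[b + 1] is one past its end.
row_offsets = [0] + list(accumulate(counts))


def token_slice(row_offsets, b):
    """Indices into flat_idx/flat_val that belong to token b."""
    return range(row_offsets[b], row_offsets[b + 1])


token1_positions = list(token_slice(row_offsets, 1))
total_nnz = row_offsets[-1]
```

Storing $B + 1$ offsets rather than $B$ means the end of each region can be read from the next entry, with no special case for the last token.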

We can construct `row_offsets` inside a wrapper function `build_csr` that also handles memory allocation and orchestration. It calls `compute_csr_kernel`, which is the kernel responsible for actually filling `flat_idx` and `flat_val` with the correct values. Note that `flat_idx` and `flat_val` are created with `torch.empty` as pre-allocated storage that `compute_csr_kernel` will write into.

```python
import torch


def build_csr(feature_acts: torch.Tensor, BLOCK_F: int = 1024):
    B, n_features = feature_acts.shape
    device = feature_acts.device

    # Count how many features fired per token
    counts = torch.zeros(B, dtype=torch.int32, device=device)
    grid = (B, triton.cdiv(n_features, BLOCK_F))
    count_nonzero[grid](feature_acts, counts, n_features, BLOCK_F=BLOCK_F)

    # Cumsum over counts gives each token a contiguous region in the flat arrays
    # row_offsets[b] = start index of token b's entries in flat_idx/flat_val
    row_offsets = torch.zeros(B + 1, dtype=torch.int32, device=device)
    row_offsets[1:] = counts.cumsum(0).to(torch.int32)

    # The last entry is the total number of nonzeros. This is used to size the flat arrays
    total_nnz = int(row_offsets[-1].item())  # GPU->CPU sync point
    flat_idx = torch.empty(total_nnz, dtype=torch.int32, device=device)
    flat_val = torch.empty(total_nnz, dtype=feature_acts.dtype, device=device)

    # write_pos is a per-token cursor that coordinates concurrent writes within
    # a token's region. Each block atomically claims the next available slots by
    # bumping write_pos by its count, getting back its starting offset (base).
    write_pos = torch.zeros(B, dtype=torch.int32, device=device)
    compute_csr_kernel[grid](
        feature_acts,
        row_offsets,
        write_pos,
        flat_idx,
        flat_val,
        n_features,
        BLOCK_F=BLOCK_F,
    )
    return flat_idx, flat_val, row_offsets, B
```

```python
@triton.jit
def compute_csr_kernel(
    feature_acts_ptr,
    row_offsets_ptr,
    write_pos_ptr,
    flat_idx_ptr,
    flat_val_ptr,
    n_features,
    BLOCK_F: tl.constexpr,
):
    pid_token = tl.program_id(0)
    pid_d = tl.program_id(1)

    # Same pointer arithmetic as count_nonzero, navigate to this block's chunk
    feat_offsets = pid_d * BLOCK_F + tl.arange(0, BLOCK_F)
    mask = feat_offsets < n_features
    feat_ptrs = feature_acts_ptr + pid_token * n_features + feat_offsets
    vals = tl.load(feat_ptrs, mask=mask, other=0.0)
    fired = vals != 0.0
    fired_int = fired.to(tl.int32)

    # Where does this token's region start in flat_idx/flat_val?
    region_start = tl.load(row_offsets_ptr + pid_token)

    # Atomically claim the next block_count slots within this token's region
    block_count = tl.sum(fired_int)
    base = tl.atomic_add(write_pos_ptr + pid_token, block_count)

    # Assign each active feature a unique slot within the claimed range
    local_rank = tl.cumsum(fired_int) - fired_int
    slots = region_start + base + local_rank

    # Write the feature index and value into the claimed slots
    tl.store(flat_idx_ptr + slots, feat_offsets.to(tl.int32), mask=fired & mask)
    tl.store(flat_val_ptr + slots, vals, mask=fired & mask)
```
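The slot-claiming step is the least obvious part of this kernel, so here is a sequential pure-Python emulation of it. The helper `claim_slots` and the one-element list used as a cursor are assumptions for illustration; in the kernel, the cursor update is a single `tl.atomic_add`, which returns the value before the addition.

```python
from itertools import accumulate


def claim_slots(fired, write_pos):
    """Emulate one block claiming slots in its token's region.

    fired:     list of 0/1 flags for the features in this block's chunk
    write_pos: one-element list acting as the token's shared write cursor
    Returns the slot assigned to each fired feature (None where nothing fired).
    """
    block_count = sum(fired)
    base = write_pos[0]            # old cursor value, as atomic_add would return
    write_pos[0] += block_count    # reserve block_count consecutive slots

    # Exclusive prefix sum: how many fired features come before each position.
    inclusive = list(accumulate(fired))
    local_rank = [c - f for c, f in zip(inclusive, fired)]

    return [base + r if f else None for r, f in zip(local_rank, fired)]


cursor = [0]
slots_a = claim_slots([0, 1, 1, 0], cursor)  # first block to reach the cursor
slots_b = claim_slots([1, 0, 0, 1], cursor)  # second block for the same token
```

Every fired feature gets a distinct slot and the claimed ranges never overlap. Since blocks on a GPU may reach the cursor in any order, a token's entries are not necessarily sorted by feature index, which does not matter for a sum over decoder rows.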

Next, `sparse_decode_kernel` uses this CSR structure to carry out the matrix multiplication step. For each token, it looks up where that token's active features live in `flat_idx`/`flat_val` using `row_offsets`, then loops over them, accumulating the weighted sum of the corresponding decoder rows into a tile of the output.

```python
@triton.jit
def sparse_decode_kernel(
    flat_idx_ptr, flat_val_ptr, row_offsets_ptr,
    W_dec_ptr, out_ptr, d_model,
    BLOCK_D: tl.constexpr,
):
    pid_token = tl.program_id(0)
    pid_d = tl.program_id(1)

    # Find the slice of flat_idx/flat_val belonging to this token
    start = tl.load(row_offsets_ptr + pid_token)
    end = tl.load(row_offsets_ptr + pid_token + 1)
    n = end - start  # Number of active features for this token

    # This block owns a BLOCK_D-wide slice of the output row
    offsets = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    mask = offsets < d_model
    acc = tl.zeros([BLOCK_D], dtype=tl.float32)

    # Loop over this token's active features, accumulating their contribution
    for i in range(n):
        j = start + i
        feat_idx = tl.load(flat_idx_ptr + j)  # Which decoder row?
        feat_val = tl.load(flat_val_ptr + j)  # Scale factor

        # Load the corresponding decoder row (just this block's slice)
        row_ptrs = W_dec_ptr + (feat_idx * d_model) + offsets
        row = tl.load(row_ptrs, mask=mask, other=0.0)
        acc += feat_val.to(tl.float32) * row.to(tl.float32)

    # Write this block's output slice
    tl.store(out_ptr + pid_token * d_model + offsets, acc, mask=mask)
```
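As a CPU-side reference for what the decode kernel computes, the following pure-Python sketch walks the CSR arrays token by token. The function name `decode_from_csr` and the tiny decoder matrix are illustrative assumptions; the real kernel additionally tiles the output across `BLOCK_D`-wide blocks.

```python
def decode_from_csr(flat_idx, flat_val, row_offsets, W_dec):
    """Weighted sum of decoder rows per token, driven by CSR arrays."""
    d_model = len(W_dec[0])
    out = []
    for b in range(len(row_offsets) - 1):
        acc = [0.0] * d_model
        # Only this token's active features are visited.
        for j in range(row_offsets[b], row_offsets[b + 1]):
            row = W_dec[flat_idx[j]]
            for k in range(d_model):
                acc[k] += flat_val[j] * row[k]
        out.append(acc)
    return out


# Illustrative decoder with 3 features and d_model = 2.
W_dec = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]

# Token 0 fires features 0 and 2; token 1 fires feature 1.
flat_idx = [0, 2, 1]
flat_val = [1.0, 0.5, 3.0]
row_offsets = [0, 2, 3]

out = decode_from_csr(flat_idx, flat_val, row_offsets, W_dec)
```

A reference like this can be handy when checking a GPU kernel's output on small inputs before benchmarking it.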

Finally, we put all of these kernels together by wrapping them in a single `sparse_decode()` function that acts as a drop-in replacement for the `@` operator:

```python
def _sparse_decode(flat_idx, flat_val, row_offsets, W_dec, B, BLOCK_D=256):
    d_model = W_dec.shape[1]
    out = torch.zeros((B, d_model), device=W_dec.device, dtype=torch.float32)

    # parallelize over batch rows and d_model tiles
    grid = (B, triton.cdiv(d_model, BLOCK_D))

    sparse_decode_kernel[grid](
        flat_idx, flat_val, row_offsets, W_dec, out, d_model, BLOCK_D=BLOCK_D
    )

    return out


def sparse_decode(feature_acts, W_dec):
    # Triton requires contiguous memory for correct stride arithmetic
    W_dec = W_dec.contiguous()

    flat_idx, flat_val, row_offsets, B = build_csr(feature_acts)
    return _sparse_decode(flat_idx, flat_val, row_offsets, W_dec, B)
```

Recall how in the Exact Allocation kernel, inside `build_csr`, we extracted the total number of nonzero entries across all tokens by retrieving the last entry of `row_offsets`:

```python
total_nnz = int(row_offsets[-1].item())
```

When we call `.item()`, we are forcing the CPU to wait for the GPU to finish the counting pass before it can read `total_nnz` and allocate `flat_idx`/`flat_val`.

Normally the CPU queues up GPU work asynchronously and moves on without waiting, but `.item()` breaks that pipeline by requiring the CPU to stall until the GPU result is ready.

This turns out to be a significant source of slowdown.

The Fixed Allocation kernel works around this by not allocating *exactly* the memory needed in the first place (meaning we don't even need `total_nnz`). Instead, we allocate `max_l0` slots per token, where `max_l0` is a user-specified upper bound on how many features can fire for any single token. This also means we no longer need a separate pass to count the nonzero features per token before computing the CSR structure.

With these changes, the new `build_csr` wrapper function looks like:

```python
def build_csr(feature_acts: torch.Tensor, BLOCK_F: int = 1024, max_l0: int = 512, validate: bool = True):
    B, n_features = feature_acts.shape
    device = feature_acts.device

    # Fixed memory allocation
    capacity = B * max_l0
    flat_idx = torch.empty(capacity, dtype=torch.int32, device=device)
    flat_val = torch.empty(capacity, dtype=feature_acts.dtype, device=device)

    # write_pos serves as both the per-token write cursor during the kernel
    # and the per-token count afterward
    write_pos = torch.zeros(B, dtype=torch.int32, device=device)

    grid = (B, triton.cdiv(n_features, BLOCK_F))
    compute_csr_kernel[grid](
        feature_acts, write_pos, flat_idx, flat_val,
        n_features, max_l0, BLOCK_F=BLOCK_F,
    )

    # Optional safety check. This reintroduces a GPU→CPU sync but catches silent truncation
    if validate:
        max_count = int(write_pos.max().item())
        if max_count > max_l0:
            raise ValueError(
                f"A token fired more than max_l0={max_l0} features "
                f"(max was {max_count}). Increase max_l0."
            )

    # Final cursor value = number of features that fired per token. Clamp it to
    # max_l0 so the decoder never reads past a token's region when overflow
    # entries were dropped (only possible with validate=False).
    counts = torch.clamp(write_pos, max=max_l0)

    return flat_idx, flat_val, counts, B, max_l0
```

As mentioned briefly earlier, if a token fires more features than `max_l0`, those extra features are silently dropped by the overflow guard in the kernel. This can be dangerous because the result is wrong but there's no crash. The `validate=True` default catches this by checking `counts.max()` after the kernel, at the cost of reintroducing a GPU→CPU sync. (However, this is still faster than Exact Allocation in practice.) If you're very confident that your `max_l0` is a safe upper bound for your SAE, then you can pass `validate=False` to skip the check, but this is not recommended.
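The fixed layout and the effect of `validate` can be emulated on the CPU in a few lines. This sketch is sequential and uses a hypothetical name, `build_fixed`; clamping the returned counts to `max_l0` is an assumption of the sketch, made so that a decoder reading `counts[b]` entries stays inside token `b`'s region.

```python
def build_fixed(feature_acts, max_l0, validate=True):
    """Emulate fixed allocation: every token owns max_l0 slots, used or not."""
    n_tokens = len(feature_acts)
    flat_idx = [None] * (n_tokens * max_l0)  # None marks a never-written slot
    flat_val = [None] * (n_tokens * max_l0)
    fired = [0] * n_tokens

    for b, acts in enumerate(feature_acts):
        region_start = b * max_l0
        for i, v in enumerate(acts):
            if v == 0.0:
                continue
            if fired[b] < max_l0:  # overflow guard: extra features are dropped
                flat_idx[region_start + fired[b]] = i
                flat_val[region_start + fired[b]] = v
            fired[b] += 1  # the cursor keeps counting even past max_l0

    if validate and max(fired) > max_l0:
        raise ValueError(f"A token fired more than max_l0={max_l0} features")

    counts = [min(c, max_l0) for c in fired]  # assumption: clamp for the decoder
    return flat_idx, flat_val, counts


batch = [[0.0, 1.0, 0.0, 2.0], [3.0, 4.0, 5.0, 0.0]]
flat_idx, flat_val, counts = build_fixed(batch, max_l0=2, validate=False)
```

With `validate=False`, the second token silently loses feature 2; calling the same function with `validate=True` raises instead.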

The kernel to compute CSR changes minimally. We no longer need `row_offsets` since we know that each token takes up `max_l0` entries in memory, so the lookup for the start of a token's region is replaced by `region_start = pid_token * max_l0`.

```python
@triton.jit
def compute_csr_kernel(
    feature_acts_ptr,
    write_pos_ptr,
    flat_idx_ptr,
    flat_val_ptr,
    n_features,
    max_l0,
    BLOCK_F: tl.constexpr,
):
    pid_token = tl.program_id(0)
    pid_d = tl.program_id(1)

    # Navigate to this block's chunk
    feat_offsets = pid_d * BLOCK_F + tl.arange(0, BLOCK_F)
    mask = feat_offsets < n_features
    feat_ptrs = feature_acts_ptr + pid_token * n_features + feat_offsets
    vals = tl.load(feat_ptrs, mask=mask, other=0.0)

    fired = vals != 0.0
    fired_int = fired.to(tl.int32)

    # Each token owns a fixed region of max_l0 slots
    region_start = pid_token * max_l0

    # Atomically claim the next available slots within this token's region
    block_count = tl.sum(fired_int)
    base = tl.atomic_add(write_pos_ptr + pid_token, block_count)

    # Assign each active feature a unique slot within the claimed range
    local_rank = tl.cumsum(fired_int) - fired_int
    local_slot = base + local_rank

    # Guard against writing past this token's region if L0 exceeds max_l0
    in_region = local_slot < max_l0
    write_mask = fired & mask & in_region

    slots = region_start + local_slot
    tl.store(flat_idx_ptr + slots, feat_offsets.to(tl.int32), mask=write_mask)
    tl.store(flat_val_ptr + slots, vals, mask=write_mask)
```

The decoder kernel then changes in the same way. `row_offsets` is no longer needed, and `counts` replaces the start/end bracket:

```python
@triton.jit
def sparse_decode_kernel(
    flat_idx_ptr,
    flat_val_ptr,
    counts_ptr,
    W_dec_ptr,
    out_ptr,
    d_model,
    max_l0,
    BLOCK_D: tl.constexpr,
):
    pid_token = tl.program_id(0)
    pid_d = tl.program_id(1)

    start = pid_token * max_l0
    n = tl.load(counts_ptr + pid_token)  # Actual number of active features for this token

    offsets = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    mask = offsets < d_model
    acc = tl.zeros([BLOCK_D], dtype=tl.float32)

    # Same loop as before
    for i in range(n):
        j = start + i
        feat_idx = tl.load(flat_idx_ptr + j)
        feat_val = tl.load(flat_val_ptr + j)
        row_ptrs = W_dec_ptr + feat_idx * d_model + offsets
        row = tl.load(row_ptrs, mask=mask, other=0.0)
        acc += feat_val.to(tl.float32) * row.to(tl.float32)

    tl.store(out_ptr + pid_token * d_model + offsets, acc, mask=mask)
```

Writing custom GPU kernels is great, but it's important to make sure that they're actually making the computation faster. I used `triton.testing.do_bench` (`warmup=25, rep=100`, reporting the median) to time these kernels and compared them against dense matrix multiplication (`feature_acts @ W_dec`). All tests were run on an NVIDIA GeForce RTX 4090 GPU.

As a quick summary, the table below shows the relative speedups for an example input configuration (`B = 32`, `n_features = 65536`, `d_model = 768`, `L0 = 64`):

| Method | Time (ms) | Speedup |
|---|---|---|
| Dense cuBLAS | 0.288 | 1.0× |
| `torch.compile` | 0.288 | 1.0× |
| `torch.sparse.mm` | 0.210 | 1.4× |
| Custom, exact allocation | 0.151 | 1.9× |
| Custom, fixed allocation (`validate=False`) | 0.041 | 7.0× |
| Custom, fixed allocation (`validate=True`) | 0.115 | 2.5× |

First, I verified that the custom kernels actually perform matrix multiplication correctly (a custom kernel that is faster but gives the wrong answer doesn't help anyone). In other words, we verify that `sparse_decode(feature_acts, W_dec) == feature_acts @ W_dec` across 486 different inputs using combinations of the parameters below. Note that `sparse_decode()` here is just a wrapper matmul function that uses our custom Triton kernels under the hood.

| Parameter | Values | Count |
|---|---|---|
| kernel implementation | | 2 |
| input dtype | | 3 |
| batch size (tokens) | 1, 4, 32 | 3 |
| SAE dictionary width | 256, 1024, 16384 | 3 |
| output width | 128, 512, 768 | 3 |
| features fired per token | 1, 8, 100 | 3 |

**Total: 2 × 3 × 3 × 3 × 3 × 3 = 486 configurations.** Each asserts that the output is fp32 and matches the dense fp32 reference within `atol=1e-4, rtol=1e-3`.
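For readers unsure how `atol` and `rtol` combine, the sketch below spells out the usual elementwise rule, |actual − expected| ≤ atol + rtol · |expected|, which is the rule `torch.allclose` documents. Whether the test suite uses exactly that function is an assumption; the helper name `allclose` here is just a pure-Python stand-in.

```python
def allclose(actual, expected, atol=1e-4, rtol=1e-3):
    """Elementwise check: |a - e| <= atol + rtol * |e| for every pair."""
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


# Within tolerance: the allowed error near 1.0 is 1e-4 + 1e-3 * 1.0 = 0.0011.
ok = allclose([1.0005, 0.0], [1.0, 0.00005])

# Outside tolerance: an error of 0.002 exceeds 0.0011.
too_far = allclose([1.002], [1.0])

# Number of test configurations: two kernels times five 3-valued parameters.
n_configs = 2 * 3 ** 5
```

The absolute term dominates for outputs near zero, while the relative term scales the allowance for larger values, which matters when inputs are in half precision.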

The preprocessing step of computing a CSR representation adds some computational overhead. It would be interesting to see a direct comparison between `sparse_decode_kernel` and dense matrix multiplication if you didn't have to pay for that overhead (assume that you somehow already have access to a CSR representation).

If you hold some parameters of the input constant (`B=32`, `n_features=65536`, `d_model=768`) while varying `L0` (the number of fired features) as shown in the table below, then how much faster is `sparse_decode_kernel`?

Note that this is **EXCLUDING** the overhead of the CSR preprocessing step (i.e., `compute_csr_kernel`). Also note that `sparse_decode_kernel` is essentially the same between Exact Allocation and Fixed Allocation, so there is no need to differentiate, but for completeness the graph below plots both (they overlap).

| L0 | Active features (% of `n_features`) | Kernel-only speedup vs. dense |
|---|---|---|
| 16 | 0.02% | 25.5× |
| 32 | 0.05% | 18.7× |
| 64 | 0.10% | 12.8× |
| 128 | 0.20% | 8.0× |
| 256 | 0.39% | 5.0× |
| 512 | 0.78% | 3.0× |
| 1024 | 1.56% | 1.7× |
| 4096 | 6.25% | 0.6× |
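The percentage column is simply `L0 / n_features`. The short sketch below reproduces it and also computes how many times fewer decoder rows the sparse path reads, which is only an idealized ratio and not a predicted speedup; the helper name `sparsity_row` is hypothetical.

```python
def sparsity_row(l0, n_features=65536):
    """Return (fraction of features active as a percent string, row-read ratio)."""
    frac = l0 / n_features
    rows_ratio = n_features // l0  # dense reads n_features rows, sparse reads l0
    return f"{frac:.2%}", rows_ratio


rows = {l0: sparsity_row(l0) for l0 in (16, 64, 4096)}
```

At `L0 = 16` the sparse path touches 4096 times fewer rows, yet the measured speedup is 25.5×, and at `L0 = 4096` it touches 16 times fewer rows but is slower than dense. One plausible reading is that dense matmul does far more useful work per row read than a row-at-a-time gather loop, so the row count alone does not determine which is faster.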

We can also vary `n_features` while keeping `B=32`, `L0=64`, and `d_model=768` constant:

| `n_features` | Kernel-only speedup vs. dense |
|---|---|
| 4,096 | 1.5× |
| 16,384 | 4.1× |
| 32,768 | 7.3× |
| 65,536 | 12.8× |
| 131,072 | 22.5× |

So clearly `sparse_decode_kernel` alone is faster than dense matrix multiplication at high sparsity. But of course in practice we probably need to compute CSR as well, which will slow things down somewhat.

The table below shows the relative speedups (relative to dense matmul) for three different input configurations. Here "Kernel only" refers to only `sparse_decode_kernel` (CSR is precomputed), while "Full" refers to the whole pipeline (i.e., including `build_csr`).

| Configuration | Kernel only | Full (exact allocation) | Full (fixed, `validate=False`) | Full (fixed, `validate=True`) |
|---|---|---|---|---|
| B=32, F=65536, D=768, L0=64 | 12.8× | 1.9× | 7.0× | 2.5× |
| B=256, F=65536, D=768, L0=64 | 7.7× | 1.7× | 3.1× | 2.2× |
| B=32, F=131072, D=512, L0=128 | 22.5× | 2.2× | 6.1× | 2.3× |

The graph below shows the speed of the full pipeline (Exact Allocation) and decode-only as you vary sparsity. Here, `L0` sweeps over [16, 32, 64, 128, 256, 512, 1024, 4096, 16384] while holding `B=32`, `n_features=65536`, and `d_model=768` constant.

To be comprehensive, we can also compare our custom kernels to `torch.sparse.mm` (using PyTorch's `to_sparse_csr()`), which uses cuSPARSE internally, and `torch.compile`. This focuses on the same three input configurations as above.

**Note: I found it a little suspicious that this custom kernel would "beat" `torch.sparse.mm`. It turns out this is mostly because of beating `to_sparse_csr()` when building the CSR. There doesn't seem to be much of a difference in speed between the custom kernel and cuSPARSE on the matrix multiplication step alone.**

As expected, `torch.compile` doesn't provide a noticeable speedup, but I wanted to include it anyway for completeness.

Up until now we have been focusing entirely on the speed of the matrix multiplication operation, but at the end of the day we care about SAE inference speed as a whole. This is benchmarked by replacing only the decoder matmul step in a `SAELens` JumpReLU SAE forward pass. The table below focuses on five SAEs across two model families and three dictionary sizes.

| SAE | `n_features` | `d_model` | L0 | | Exact allocation speedup | Fixed (`validate=True`) speedup | Fixed (`validate=False`) speedup |
|---|---|---|---|---|---|---|---|
| Gemma Scope 2B, L20, 65k | 65,536 | 2,304 | 72 | 3.8e-6 | 4.27× | 5.57× | 11.41× |
| Gemma Scope 9B, L20, 65k | 65,536 | 3,584 | 72 | 3.8e-6 | 5.66× | 7.34× | 13.27× |
| Gemma Scope 2B, L12, 65k | 65,536 | 2,304 | 72 | 9.5e-7 | 3.91× | 5.48× | 11.33× |
| Gemma Scope 2B, L12, 262k | 262,144 | 2,304 | 100 | 1.9e-6 | 12.08× | 14.49× | 22.59× |
| Qwen Scope 3.5 2B, L12 | 32,768 | 2,048 | 100 | 4.8e-7 | 1.98× | 2.54× | 5.74× |

The purpose of the Fixed Allocation kernel was to overallocate memory in exchange for speed, so it would be helpful to see exactly how much more memory it uses compared to the Exact Allocation kernel. Surprisingly, it turns out that in practice this overhead is small:

| B | `max_l0` | | Exact allocation (MB) | Fixed allocation (MB) | Overhead (fixed vs. exact) |
|---|---|---|---|---|---|
| 32 | 512 | 218.3 | 218.4 | 218.5 | +0.1 MB |
| 256 | 512 | 277.7 | 277.9 | 278.8 | +0.9 MB |
| 1024 | 512 | 482.3 | 482.9 | 485.6 | +2.7 MB |
| 1024 | 1024 | 482.3 | 482.9 | 490.7 | +7.8 MB |
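A back-of-the-envelope estimate shows why the overhead stays in the low megabytes. The sketch below only sizes the two flat arrays, assuming 4-byte indices and 4-byte values and an actual L0 of 64 for the exact case (the L0 behind the table is not stated, so that number is an assumption). Measured peak memory also includes other buffers and allocator behavior, so this need not match the table row for row.

```python
def csr_bytes(n_tokens, slots_per_token, idx_bytes=4, val_bytes=4):
    """Bytes used by flat_idx + flat_val for a given number of slots per token."""
    return n_tokens * slots_per_token * (idx_bytes + val_bytes)


# Fixed allocation reserves max_l0 slots per token regardless of actual usage.
fixed = csr_bytes(n_tokens=1024, slots_per_token=1024)

# Exact allocation only stores what fired (assumed L0 = 64 per token here).
exact = csr_bytes(n_tokens=1024, slots_per_token=64)

extra_mb = (fixed - exact) / 1e6
```

Even at `B = 1024` and `max_l0 = 1024`, the reserved arrays come to roughly 8 MB, which is small next to the decoder weights and the dense activation matrix.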

While these results are encouraging, there are a few important limitations to be aware of and gaps that I plan to address as I continue working on this project.

First, the above benchmark numbers are not absolute, as these tests were run in a specific environment (WSL2 with GPU clocks not pinned). The primary goal of these benchmarks was to gauge the relative performance of the custom kernels compared to baseline implementations. The actual absolute speed likely differs depending on the hardware and benchmarking setup.

A second limitation, which was discussed earlier but is worth reiterating, is that although the Fixed Allocation kernel with `validate=False` achieves the highest performance, it can silently produce incorrect results if the `max_l0` parameter is set too low. For this reason, using either the Exact Allocation kernel or Fixed Allocation with `validate=True` is likely better for most cases.

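Third, these kernels were designed specifically for sparse matrix multiplication, meaning that once the activations are no longer sparse enough, dense matrix multiplication is actually faster. In the kernel-only benchmark above, this crossover happens somewhere between L0 = 1024 (1.7×) and L0 = 4096 (0.6×) out of 65,536 features.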

Fourth, this implementation focuses exclusively on the decoder inference step of JumpReLU Sparse Autoencoders, but there are likely other sources of inefficiency that could be addressed. For example, future projects could focus on the encoder pass or on support for training through custom backward kernels. Additionally, the current implementation only supports `float32` outputs.

Finally, all experiments were run on an RTX 4090, and performance may differ on other GPU architectures such as the A100 or H100.

In conclusion, this project implements custom Triton kernels for the decoder inference step of JumpReLU SAEs by exploiting the inherent sparsity of the hidden representation. On a sample of real SAEs, this achieves 2.5–14× speedup with the Fixed Allocation (validate=True) kernel, with larger gains at higher dictionary sizes.

The full implementation is available on GitHub.

I welcome feedback! If you have thoughts, questions, or find any issues, feel free to leave a comment or reach out directly. This is also my first GPU kernel project, so if you're experienced with Triton or GPU kernel optimization and see things I could have done better, I would appreciate any suggestions.
