Synchronizations With TorchRec KeyedJaggedTensor

TorchRec's KeyedJaggedTensor, designed to efficiently combine sparse features in recommendation systems without padding, introduces GPU-CPU synchronization that degrades system performance. The data type's output shape cannot be derived from symbolic tensor shapes, forcing TorchRec to copy lengths data from GPU to CPU at runtime to determine the shape. To avoid this synchronization cost, KeyedJaggedTensor should save key metadata in lists during construction rather than deriving it from actual lengths data, but the current implementation fails to do so in both eager and TorchDynamo compile modes.

6 min read · published Jun 5, 2026

Introduction

In recommendation systems, sparse features such as user-item interaction ids are often used to model user preferences and item characteristics. These sparse features are then mapped to dense representations through large embedding tables.

However, there are a few challenges when working with sparse features in recommendation systems:

  • Different samples may have different numbers of interactions, leading to variable-length input data.
  • There are often lots of sparse features being used in recommendation systems.

In a batch of requests, if all sparse features are padded to the same length, the embedding tables will produce many useless embedding vectors. That wastes memory and downstream compute resources. If each sparse feature accesses embedding tables independently, the overhead becomes large when the number of sparse features is large.

TorchRec KeyedJaggedTensor was designed to address these challenges by combining sparse features across samples and across features into one large sparse feature without padding. This eliminates the memory and compute inefficiencies.
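As a rough illustration of the saving, the following pure-Python sketch flattens a small batch into one values list and one lengths list and compares the element count with a padded layout. The feature names, the batch, and the helpers flatten_batch and padded_size are made up for this example. The key-by-key ordering mirrors how KeyedJaggedTensor lays out its values, but nothing here calls TorchRec.

```python
from typing import Dict, List, Tuple


def flatten_batch(
    batch: Dict[str, List[List[int]]]
) -> Tuple[List[str], List[int], List[int]]:
    """Flatten {feature: [ids of sample 0, ids of sample 1, ...]} without padding."""
    keys = list(batch)
    values: List[int] = []
    lengths: List[int] = []
    for key in keys:  # key by key: all samples of one feature first
        for ids in batch[key]:
            values.extend(ids)
            lengths.append(len(ids))
    return keys, values, lengths


def padded_size(batch: Dict[str, List[List[int]]]) -> int:
    """Number of id slots if every (feature, sample) pair is padded to the max length."""
    max_len = max(len(ids) for per_sample in batch.values() for ids in per_sample)
    num_slots = sum(len(per_sample) for per_sample in batch.values())
    return max_len * num_slots


batch = {
    "clicked": [[10, 11], [12]],
    "viewed": [[20], [21, 22, 23]],
}
keys, values, lengths = flatten_batch(batch)
print(values)                            # [10, 11, 12, 20, 21, 22, 23]
print(lengths)                           # [2, 1, 1, 3]
print(len(values), padded_size(batch))   # 7 12
```

Even in this tiny batch the padded layout needs 12 slots for 7 real ids, and every padded slot would trigger a useless embedding lookup.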

Despite its efficiency, KeyedJaggedTensor has several caveats and can be used incorrectly, resulting in worse system performance. One key issue in GPU systems is synchronization. In this blog post, I would like to discuss the main pitfalls of KeyedJaggedTensor and how to use it efficiently on GPU.

TorchRec Data Types

TorchRec modules use specific input/output data types to efficiently represent sparse features, including:

  • JaggedTensor: a wrapper around the lengths or offsets tensor and the values tensor for a single sparse feature.
  • KeyedJaggedTensor: a wrapper that represents multiple sparse features and can be thought of as multiple JaggedTensors.
  • KeyedTensor: a wrapper around torch.Tensor that allows access to tensor values through keys.

KeyedJaggedTensor can be constructed from a dictionary of JaggedTensors, where the keys are the feature names. Passing a KeyedJaggedTensor through an EmbeddingBagCollection produces a KeyedTensor, whose embeddings can be accessed through keys.

Synchronizations With TorchRec KeyedJaggedTensor

When using KeyedJaggedTensor, the critical question is what the output symbolic shape will be for an operation that takes a KeyedJaggedTensor as input. Such an operation could be using KeyedJaggedTensor to access EmbeddingBagCollection, or getting the value tensor corresponding to a specific key in KeyedJaggedTensor. This symbolic shape cannot be derived from the symbolic shapes of the value, lengths, or offsets tensors in KeyedJaggedTensor. In other words, any operation that uses KeyedJaggedTensor as input is data dependent, and the output shape can only be determined from the actual lengths data at runtime. If the value, lengths, and offsets tensors are on GPU, TorchRec has to copy the lengths from GPU to CPU to infer the output shape, which introduces synchronization and can hurt performance.
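To make the data dependence concrete, here is a small pure-Python sketch with invented numbers. Two flattened batches have exactly the same container sizes (four lengths, seven values in total), which is all a shape-only analysis could see, yet the number of values belonging to each key differs. The helper per_key_sizes is hypothetical and assumes lengths are flattened key by key.

```python
from typing import List


def per_key_sizes(lengths: List[int], num_keys: int) -> List[int]:
    """Total number of values per key, given key-by-key flattened lengths."""
    batch_size = len(lengths) // num_keys
    return [
        sum(lengths[k * batch_size:(k + 1) * batch_size])
        for k in range(num_keys)
    ]


# Two batches with 2 keys, batch size 2, and 7 values in total.
lengths_a = [2, 1, 1, 3]
lengths_b = [3, 3, 0, 1]

# What a shape-only view can observe is identical for both batches ...
same_shapes = (
    len(lengths_a) == len(lengths_b) and sum(lengths_a) == sum(lengths_b)
)
print(same_shapes)                            # True

# ... yet the per-key output sizes differ, so they depend on the data.
print(per_key_sizes(lengths_a, num_keys=2))   # [3, 4]
print(per_key_sizes(lengths_b, num_keys=2))   # [6, 1]
```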

To mitigate this problem, the idea is for KeyedJaggedTensor to save key metadata in Python lists when it is constructed. For some metadata, such as lengths per key, the values can be determined directly from the corresponding JaggedTensor without looking at the flattened lengths tensor. In this way, even after the original JaggedTensors are no longer available, the output shape can still be determined without reading the actual lengths data, which avoids GPU-CPU synchronization when KeyedJaggedTensor is on GPU.
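The idea can be sketched in plain Python. MiniKJT below is a hypothetical stand-in, not the TorchRec class: it records the per-key sizes as host-side integers while it is being built from per-key inputs, and later slices out the values of one key without ever reading the lengths data. The feature names and numbers are assumptions for illustration.

```python
from typing import Dict, List


class MiniKJT:
    """Hypothetical stand-in for a keyed jagged container (not the TorchRec class)."""

    def __init__(self, keys, values, lengths, length_per_key):
        self.keys = keys
        self.values = values
        self.lengths = lengths                # imagine this lives on the GPU
        self.length_per_key = length_per_key  # plain Python ints on the host

    @classmethod
    def from_per_key(cls, per_key: Dict[str, Dict[str, List[int]]]) -> "MiniKJT":
        keys, values, lengths, length_per_key = [], [], [], []
        for key, jt in per_key.items():
            keys.append(key)
            values.extend(jt["values"])
            lengths.extend(jt["lengths"])
            # Only the size of the values container is read here,
            # never the contents of the lengths.
            length_per_key.append(len(jt["values"]))
        return cls(keys, values, lengths, length_per_key)

    def values_for(self, key: str) -> List[int]:
        """Slice out one key using the saved metadata only."""
        i = self.keys.index(key)
        start = sum(self.length_per_key[:i])
        return self.values[start:start + self.length_per_key[i]]


kjt = MiniKJT.from_per_key({
    "clicked": {"values": [10, 11, 12], "lengths": [2, 1]},
    "viewed": {"values": [20, 21, 22, 23], "lengths": [1, 3]},
})
print(kjt.length_per_key)        # [3, 4]
print(kjt.values_for("viewed"))  # [20, 21, 22, 23]
```

Because length_per_key is filled in from container sizes at construction time, values_for never needs the lengths, which is the property that would avoid a GPU-CPU copy in the real setting.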

However, this is not how the key metadata is derived in the current implementation of KeyedJaggedTensor. In eager mode, the metadata is derived from the actual data in the lengths tensors, which causes GPU-CPU synchronization if KeyedJaggedTensor is on GPU. In TorchDynamo compile mode, the metadata is not computed and saved during construction to avoid that synchronization. This only defers the synchronization to the point when KeyedJaggedTensor is used in an operation, so it does not really solve the problem.
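The difference between paying the cost at construction and deferring it can be mimicked with a counter. In this sketch, DeviceLengths is a hypothetical stand-in for a GPU tensor in which every host read counts as one synchronization. EagerKJT and LazyKJT are likewise invented names that only imitate the two behaviors described above; they are not TorchRec code.

```python
from typing import List, Optional


class DeviceLengths:
    """Hypothetical stand-in for a GPU tensor: each host read counts as one sync."""

    def __init__(self, data: List[int]):
        self._data = data
        self.syncs = 0

    def to_host(self) -> List[int]:
        self.syncs += 1  # a real GPU-to-CPU copy would block here
        return list(self._data)


def length_per_key(lengths: DeviceLengths, num_keys: int) -> List[int]:
    host = lengths.to_host()
    stride = len(host) // num_keys
    return [sum(host[k * stride:(k + 1) * stride]) for k in range(num_keys)]


class EagerKJT:
    def __init__(self, lengths: DeviceLengths, num_keys: int):
        # Metadata is derived from the data right away: sync at construction.
        self.meta = length_per_key(lengths, num_keys)


class LazyKJT:
    def __init__(self, lengths: DeviceLengths, num_keys: int):
        self._lengths, self._num_keys = lengths, num_keys
        self._meta: Optional[List[int]] = None  # nothing computed yet

    @property
    def meta(self) -> List[int]:
        if self._meta is None:  # sync happens at first use instead
            self._meta = length_per_key(self._lengths, self._num_keys)
        return self._meta


eager_lengths = DeviceLengths([2, 1, 1, 3])
EagerKJT(eager_lengths, num_keys=2)
print(eager_lengths.syncs)  # 1

lazy_lengths = DeviceLengths([2, 1, 1, 3])
lazy = LazyKJT(lazy_lengths, num_keys=2)
print(lazy_lengths.syncs)   # 0
print(lazy.meta)            # [3, 4]
print(lazy_lengths.syncs)   # 1
```

Both variants end up with exactly one host read. The lazy one only moves it from construction to first use.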

One might ask why the current KeyedJaggedTensor implementation does not just derive the key metadata from the JaggedTensor metadata, since that would avoid GPU-CPU synchronization altogether. The reason is that constructing this metadata from JaggedTensor objects cannot be traced into a computation graph in TorchDynamo compile mode, at least for now, because it involves Python list appending and other bookkeeping that the computation graph cannot support. KeyedJaggedTensor is not a torch.Tensor after all; only the metadata in a torch.Tensor can be symbolically traced in TorchDynamo compile mode. In non-TorchDynamo compile mode, the current implementation may record the construction in the computation graph, but GPU-CPU synchronization is still unavoidable if the JaggedTensors are on GPU.

Consequently, if KeyedJaggedTensor is constructed from GPU JaggedTensors and used in compile mode, GPU-CPU synchronization is inevitable.

An Illustration of KeyedJaggedTensor Metadata Problems

Suppose a batch has two sparse features, user_clicked_item_ids and user_viewed_item_ids, and two samples:

  • Sample 0: user_clicked_item_ids = [10, 11], user_viewed_item_ids = [20]
  • Sample 1: user_clicked_item_ids = [12], user_viewed_item_ids = [21, 22, 23]

These can be flattened into one KeyedJaggedTensor as:

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

kjt = KeyedJaggedTensor(
    keys=["user_clicked_item_ids", "user_viewed_item_ids"],
    # Values are stored key by key: all samples of the first key, then the second.
    values=torch.tensor([10, 11, 12, 20, 21, 22, 23]),
    lengths=torch.tensor([2, 1, 1, 3]),
)
```

Here, lengths stores the number of values per key per sample, ordered key by key:

  • user_clicked_item_ids for sample 0 has length 2
  • user_clicked_item_ids for sample 1 has length 1
  • user_viewed_item_ids for sample 0 has length 1
  • user_viewed_item_ids for sample 1 has length 3

This small example is enough to see why KeyedJaggedTensor is data dependent. If a downstream operator needs the output shape for user_viewed_item_ids, it has to know that the lengths for that key sum to 1 + 3 = 4. When those lengths live on GPU, TorchRec has to copy them back to CPU to determine the shape, which creates the synchronization.
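The arithmetic behind that lookup can be written out in plain Python. This is only a sketch of the bookkeeping, assuming values and lengths are flattened key by key (all samples of the first key, then all samples of the second). The variable names length_per_key and offset_per_key are chosen for readability and the snippet does not call TorchRec.

```python
from itertools import accumulate

keys = ["user_clicked_item_ids", "user_viewed_item_ids"]
values = [10, 11, 12, 20, 21, 22, 23]  # flattened key by key
lengths = [2, 1, 1, 3]                 # per key, per sample

stride = len(lengths) // len(keys)     # batch size, here 2

# Summing the lengths of each key requires reading the lengths *data*,
# which is the step that forces a GPU-to-CPU copy when they live on GPU.
length_per_key = [
    sum(lengths[k * stride:(k + 1) * stride]) for k in range(len(keys))
]
offset_per_key = [0] + list(accumulate(length_per_key))

i = keys.index("user_viewed_item_ids")
viewed = values[offset_per_key[i]:offset_per_key[i + 1]]

print(length_per_key)  # [3, 4]
print(offset_per_key)  # [0, 3, 7]
print(viewed)          # [20, 21, 22, 23]
```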

If we only want to derive metadata from the JaggedTensors themselves, the idea is much simpler. Conceptually, TorchRec could build the key metadata directly from the per-key JaggedTensor objects:

```python
import torch
from torchrec.sparse.jagged_tensor import JaggedTensor

jts = {
    "user_clicked_item_ids": JaggedTensor(
        values=torch.tensor([10, 11, 12]), lengths=torch.tensor([2, 1])
    ),
    "user_viewed_item_ids": JaggedTensor(
        values=torch.tensor([20, 21, 22, 23]), lengths=torch.tensor([1, 3])
    ),
}

lengths_per_key = []
for key in ["user_clicked_item_ids", "user_viewed_item_ids"]:
    jt = jts[key]
    # Total number of values for this key, read from the tensor shape only;
    # the contents of jt.lengths() are never inspected.
    lengths_per_key.append(jt.values().numel())
```

In this illustration, the metadata ([3, 4] here) is derived from the JaggedTensor structure itself, namely the shape of each values tensor, not from the data stored in the flattened KeyedJaggedTensor lengths tensor. That is the shape information we would want to preserve, because it can be known before any downstream operation touches the actual sparse values. However, this pattern is not suitable for TorchDynamo compile mode today, because the Python-side bookkeeping needed to collect the metadata is not traceable into the computation graph.

Conclusions

To use KeyedJaggedTensor efficiently in a GPU system, it should be constructed from CPU JaggedTensors and then moved to GPU in eager mode. This usually means KeyedJaggedTensor should be constructed in the data preprocessing stage. In a model running on GPU, KeyedJaggedTensor should only be used as model input. One should avoid constructing KeyedJaggedTensor from GPU JaggedTensors, especially inside the model. In this way, the GPU operations in the model can run asynchronously without GPU-CPU synchronization, which results in the best performance.

References

Synchronizations With TorchRec KeyedJaggedTensor

https://leimao.github.io/blog/TorchRec-KeyedJaggedTensor-Synchronizations/
