# Synchronizations With TorchRec KeyedJaggedTensor

> Source: <https://leimao.github.io/blog/TorchRec-KeyedJaggedTensor-Synchronizations/>
> Published: 2026-06-05 07:00:00+00:00

## Introduction

In recommendation systems, sparse features such as user-item interaction ids are often used to model user preferences and item characteristics. These sparse features are then mapped to dense representations through large embedding tables.

However, there are a few challenges when working with sparse features in recommendation systems:

- Different samples may have different numbers of interactions, leading to variable-length input data.
- There are often lots of sparse features being used in recommendation systems.

In a batch of requests, if all sparse features are padded to the same length, the embedding tables will produce many useless embedding vectors. That wastes memory and downstream compute resources. If each sparse feature accesses embedding tables independently, the overhead becomes large when the number of sparse features is large.
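
To make the padding cost concrete, here is a minimal sketch in plain Python. The batch contents and the helper names `padded_size` and `jagged_size` are made up for illustration; they are not part of TorchRec.

```python
# Hypothetical batch: one sparse feature, one list of ids per sample.
batch = [[10, 11], [12], [13, 14, 15, 16, 17]]


def padded_size(samples):
    """Number of slots when every sample is padded to the longest one."""
    max_len = max(len(s) for s in samples)
    return len(samples) * max_len


def jagged_size(samples):
    """Number of slots when values are stored back to back without padding."""
    return sum(len(s) for s in samples)


padded = padded_size(batch)  # 3 samples * 5 slots each
jagged = jagged_size(batch)  # 2 + 1 + 5 values
wasted_fraction = 1 - jagged / padded
print(padded, jagged, round(wasted_fraction, 2))
```

In this toy batch almost half of the padded slots would hold no real id, and each of those slots would still trigger an embedding lookup.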

TorchRec [`KeyedJaggedTensor`](https://github.com/meta-pytorch/torchrec/blob/release/v1.6.0/torchrec/sparse/jagged_tensor.py#L1823) was designed to address these challenges by combining sparse features across samples and across features into one large sparse feature without padding. This eliminates the memory and compute inefficiencies.

Despite its efficiency, `KeyedJaggedTensor` has several caveats and can be used incorrectly, resulting in worse system performance. One key issue in GPU systems is synchronization. In this blog post, I would like to discuss the main pitfalls of `KeyedJaggedTensor` and how to use it efficiently on GPU.

## TorchRec Data Types

TorchRec modules use dedicated input/output data types to represent sparse features efficiently, including:

- `JaggedTensor`: a wrapper around the lengths or offsets tensor and the values tensor for a single sparse feature.
- `KeyedJaggedTensor`: a wrapper that represents multiple sparse features and can be thought of as multiple `JaggedTensor`s.
- `KeyedTensor`: a wrapper around `torch.Tensor` that allows access to tensor values through keys.

`KeyedJaggedTensor` can be constructed from a dictionary of `JaggedTensor`s, where the keys are the feature names. Passing a `KeyedJaggedTensor` to an `EmbeddingBagCollection` produces a `KeyedTensor`, whose embeddings can be accessed through keys.
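
The relationship between the values, lengths, and offsets of a single jagged feature can be sketched with plain Python lists. This is only an illustration of the layout, not TorchRec code; the helper `sample_values` is a hypothetical name.

```python
from itertools import accumulate

values = [10, 11, 12]  # ids of one sparse feature, stored back to back
lengths = [2, 1]       # sample 0 has two ids, sample 1 has one

# offsets[i] is where sample i starts; the last entry is the total count.
offsets = [0, *accumulate(lengths)]


def sample_values(values, offsets, i):
    """Slice out the ids that belong to sample i."""
    return values[offsets[i]:offsets[i + 1]]


per_sample = [sample_values(values, offsets, i) for i in range(len(lengths))]
print(offsets)
print(per_sample)
```

Lengths and offsets carry the same information, which is why a jagged wrapper only needs one of them.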

## Synchronizations With TorchRec `KeyedJaggedTensor`

When using `KeyedJaggedTensor`, the critical question is what the symbolic output shape of an operation will be for a given input `KeyedJaggedTensor`. Such an operation could be looking up an `EmbeddingBagCollection` with the `KeyedJaggedTensor`, or getting the values tensor corresponding to a specific key in the `KeyedJaggedTensor`. This symbolic shape cannot be derived from the symbolic shapes of the values, lengths, or offsets tensors in the `KeyedJaggedTensor`. In other words, any operation that uses `KeyedJaggedTensor` as input is data dependent, and the output shape can only be determined from the actual lengths data at runtime. If the values, lengths, and offsets tensors are on GPU, TorchRec has to copy the lengths from GPU to CPU to infer the output shape, which introduces synchronization and can hurt performance.
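
The data dependence can be seen with two toy batches whose tensors would have identical shapes. The sketch below uses plain Python lists and assumes the lengths are laid out key by key (all samples of the first key, then all samples of the second); `length_per_key` here is an illustrative helper, not the TorchRec implementation.

```python
def length_per_key(lengths, num_keys):
    """Total number of values per key, assuming a key-major lengths layout."""
    batch_size = len(lengths) // num_keys
    return [
        sum(lengths[k * batch_size:(k + 1) * batch_size])
        for k in range(num_keys)
    ]


# Two batches with two keys and two samples each. Both hold 7 values and
# 4 lengths, so the values and lengths tensors would have the same shapes.
lengths_a = [2, 1, 1, 3]
lengths_b = [3, 3, 1, 0]

sizes_a = length_per_key(lengths_a, num_keys=2)
sizes_b = length_per_key(lengths_b, num_keys=2)
print(sizes_a, sizes_b)
```

The per-key sizes differ even though the shapes match, so the only way to obtain them from the flattened tensors is to read the contents of `lengths`.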

To mitigate this problem, `KeyedJaggedTensor` saves key metadata in lists when it is constructed. For some metadata, such as lengths per key, the values can be determined directly from the corresponding `JaggedTensor` without looking at the flattened lengths tensor. In this way, even after the original `JaggedTensor`s are no longer available, the output shape can still be determined without reading the actual lengths data, which avoids GPU-CPU synchronization when `KeyedJaggedTensor` is on GPU.
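
A rough sketch of this idea, using a hypothetical `TinyKeyedJagged` class built on plain Python lists (it is not the TorchRec class, and its attribute names are my own): the per-key sizes are cached at construction from container sizes only, and later lookups never read the lengths.

```python
from itertools import accumulate


class TinyKeyedJagged:
    """Hypothetical stand-in for a keyed jagged container (not TorchRec code)."""

    def __init__(self, per_key_values, per_key_lengths):
        self.keys = list(per_key_values)
        # Flattened, key-major storage.
        self.values = [v for k in self.keys for v in per_key_values[k]]
        self.lengths = [n for k in self.keys for n in per_key_lengths[k]]
        # Metadata cached from container sizes only; the lengths are never read.
        self.length_per_key = [len(per_key_values[k]) for k in self.keys]
        self.offset_per_key = [0, *accumulate(self.length_per_key)]

    def values_for(self, key):
        """Slice the values of one key using only the cached metadata."""
        i = self.keys.index(key)
        return self.values[self.offset_per_key[i]:self.offset_per_key[i + 1]]


tiny = TinyKeyedJagged(
    per_key_values={
        "user_clicked_item_ids": [10, 11, 12],
        "user_viewed_item_ids": [20, 21, 22, 23],
    },
    per_key_lengths={
        "user_clicked_item_ids": [2, 1],
        "user_viewed_item_ids": [1, 3],
    },
)
viewed = tiny.values_for("user_viewed_item_ids")
print(tiny.length_per_key, tiny.offset_per_key, viewed)
```

With real tensors, the container size used here would correspond to a tensor shape, which is known on the host without copying any data from the device.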

However, this is not how the key metadata is derived in the current implementation of `KeyedJaggedTensor`. In eager mode, the metadata is derived from the actual data in the lengths tensors, which causes GPU-CPU synchronization if `KeyedJaggedTensor` is on GPU. In TorchDynamo compile mode, the metadata is not computed and saved during construction to avoid that synchronization. This only defers the synchronization to the point when `KeyedJaggedTensor` is used in an operation, so it does not really solve the problem.
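
The difference between the two behaviors can be mimicked with a small counter. The class below is a hypothetical model in plain Python, assuming a key-major lengths layout; it only counts how often the lengths data is read, where each read stands in for a device-to-host copy on GPU.

```python
class DeferredMetadata:
    """Hypothetical container that counts how often the lengths data is read."""

    def __init__(self, lengths, num_keys, compute_at_construction):
        self._lengths = lengths
        self._num_keys = num_keys
        self.lengths_reads = 0
        self._length_per_key = None
        if compute_at_construction:
            self._length_per_key = self._read_lengths()

    def _read_lengths(self):
        # With GPU tensors, this read would be a device-to-host copy.
        self.lengths_reads += 1
        batch_size = len(self._lengths) // self._num_keys
        return [
            sum(self._lengths[k * batch_size:(k + 1) * batch_size])
            for k in range(self._num_keys)
        ]

    def length_per_key(self):
        # Compute lazily on first use if construction skipped it.
        if self._length_per_key is None:
            self._length_per_key = self._read_lengths()
        return self._length_per_key


eager = DeferredMetadata([2, 1, 1, 3], num_keys=2, compute_at_construction=True)
deferred = DeferredMetadata([2, 1, 1, 3], num_keys=2, compute_at_construction=False)
reads_after_construction = (eager.lengths_reads, deferred.lengths_reads)

eager.length_per_key()
deferred.length_per_key()
reads_after_first_use = (eager.lengths_reads, deferred.lengths_reads)
print(reads_after_construction, reads_after_first_use)
```

Both variants end up reading the lengths exactly once; deferring the computation only moves the read from construction to first use.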

One might ask why the current `KeyedJaggedTensor` implementation does not just derive the key metadata from the `JaggedTensor` metadata, since that would avoid GPU-CPU synchronization altogether. The reason is that constructing this metadata from `JaggedTensor` objects cannot be traced into the computation graph in TorchDynamo compile mode, at least for now, because it involves Python list appending and other bookkeeping that the computation graph cannot support. `KeyedJaggedTensor` is not a `torch.Tensor` after all; only the metadata in a `torch.Tensor` can be symbolically traced in TorchDynamo compile mode. In non-TorchDynamo compile mode, the current implementation may record the construction in the computation graph, but GPU-CPU synchronization is still unavoidable if the `JaggedTensor`s are on GPU.

Consequently, if `KeyedJaggedTensor` is constructed from GPU `JaggedTensor`s and used in compile mode, GPU-CPU synchronization is inevitable.

## An Illustration of `KeyedJaggedTensor` Metadata Problems

Suppose a batch has two sparse features, `user_clicked_item_ids` and `user_viewed_item_ids`, and two samples:

- Sample 0: `user_clicked_item_ids = [10, 11]`, `user_viewed_item_ids = [20]`
- Sample 1: `user_clicked_item_ids = [12]`, `user_viewed_item_ids = [21, 22, 23]`

These can be flattened into one `KeyedJaggedTensor` as:

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

kjt = KeyedJaggedTensor(
    keys=["user_clicked_item_ids", "user_viewed_item_ids"],
    # Values are stored key by key: all samples of the first key, then the second.
    values=torch.tensor([10, 11, 12, 20, 21, 22, 23]),
    # Lengths follow the same order: [clicked s0, clicked s1, viewed s0, viewed s1].
    lengths=torch.tensor([2, 1, 1, 3]),
)
```

Here, `values` and `lengths` are laid out key by key, and `lengths` stores the number of values per sample for each key:

- `user_clicked_item_ids` for sample 0 has length `2`
- `user_clicked_item_ids` for sample 1 has length `1`
- `user_viewed_item_ids` for sample 0 has length `1`
- `user_viewed_item_ids` for sample 1 has length `3`

This small example is enough to see why `KeyedJaggedTensor` is data dependent. If a downstream operator needs the output shape for `user_viewed_item_ids`, it has to know that the total number of values for that key is `sum([1, 3]) = 4`. When those lengths live on GPU, TorchRec has to move them back to CPU to determine the shape, which creates the synchronization.

If we only want to derive metadata from the `JaggedTensor`s themselves, the idea is much simpler. Conceptually, TorchRec could build the key metadata directly from the per-key `JaggedTensor` objects:

```python
import torch
from torchrec.sparse.jagged_tensor import JaggedTensor

jts = {
    "user_clicked_item_ids": JaggedTensor(
        values=torch.tensor([10, 11, 12]), lengths=torch.tensor([2, 1])
    ),
    "user_viewed_item_ids": JaggedTensor(
        values=torch.tensor([20, 21, 22, 23]), lengths=torch.tensor([1, 3])
    ),
}

# The number of values per key is a tensor shape, so no lengths data is read.
lengths_per_key = []
for key in ["user_clicked_item_ids", "user_viewed_item_ids"]:
    jt = jts[key]
    lengths_per_key.append(jt.values().shape[0])
```

In this illustration, the metadata is derived from the `JaggedTensor` structure itself, not from the flattened `KeyedJaggedTensor` values. That is the shape information we would want to preserve, because it can be known before any downstream operation touches the actual sparse values. However, this pattern is not suitable for TorchDynamo compile mode today, because the Python-side bookkeeping needed to collect the metadata is not traceable into the computation graph.

## Conclusions

To use `KeyedJaggedTensor` efficiently in a GPU system, it should be constructed from CPU `JaggedTensor`s and then moved to GPU in eager mode. This usually means `KeyedJaggedTensor` should be constructed in the data preprocessing stage. In a model running on GPU, `KeyedJaggedTensor` should only be used as model input. One should avoid constructing `KeyedJaggedTensor` from GPU `JaggedTensor`s, especially inside the model. In this way, the GPU operations in the model can run asynchronously without GPU-CPU synchronization, which results in the best performance.

