{"slug": "synchronizations-with-torchrec-keyedjaggedtensor", "title": "Synchronizations With TorchRec KeyedJaggedTensor", "summary": "TorchRec's KeyedJaggedTensor, designed to efficiently combine sparse features in recommendation systems without padding, introduces GPU-CPU synchronization that degrades system performance. The data type's output shape cannot be derived from symbolic tensor shapes, forcing TorchRec to copy lengths data from GPU to CPU at runtime to determine the shape. To avoid this synchronization cost, KeyedJaggedTensor should save key metadata in lists during construction rather than deriving it from actual lengths data, but the current implementation fails to do so in both eager and TorchDynamo compile modes.", "body_md": "# Synchronizations With TorchRec KeyedJaggedTensor\n\nIntroduction\n\nIn recommendation systems, sparse features such as user-item interaction IDs are often used to model user preferences and item characteristics. These sparse features are then mapped to dense representations through large embedding tables.\n\nHowever, there are a few challenges when working with sparse features in recommendation systems:\n\n- Different samples may have different numbers of interactions, leading to variable-length input data.\n- Recommendation systems often use a large number of sparse features.\n\nIn a batch of requests, if all sparse features are padded to the same length, the embedding tables will produce many useless embedding vectors. That wastes memory and downstream compute resources. If each sparse feature accesses embedding tables independently, the overhead becomes large when the number of sparse features is large.\n\nTorchRec [`KeyedJaggedTensor`](https://github.com/meta-pytorch/torchrec/blob/release/v1.6.0/torchrec/sparse/jagged_tensor.py#L1823) was designed to address these challenges by combining sparse features across samples and across features into one large sparse feature without padding. 
This eliminates the memory and compute inefficiencies.\n\nDespite its efficiency, `KeyedJaggedTensor` has several caveats and can be used incorrectly, resulting in worse system performance. One key issue in GPU systems is synchronization. In this blog post, I would like to discuss the main pitfalls of `KeyedJaggedTensor` and how to use it efficiently on GPU.\n\nTorchRec Data Types\n\nTorchRec modules use dedicated input/output data types to represent sparse features efficiently, including:\n\n- `JaggedTensor`: a wrapper around the lengths or offsets tensor and the values tensor for a single sparse feature.\n- `KeyedJaggedTensor`: a wrapper that represents multiple sparse features and can be thought of as multiple `JaggedTensor`s.\n- `KeyedTensor`: a wrapper around `torch.Tensor` that allows access to tensor values through keys.\n\n`KeyedJaggedTensor` can be constructed from a dictionary of `JaggedTensor`s, where the keys are the feature names. Passing a `KeyedJaggedTensor` to `EmbeddingBagCollection` produces a `KeyedTensor`, whose embeddings can be accessed through keys.\n\nSynchronizations With TorchRec `KeyedJaggedTensor`\n\nWhen an operation consumes a `KeyedJaggedTensor`, the critical question is what the symbolic shape of its output will be. Such an operation could be looking up an `EmbeddingBagCollection` with the `KeyedJaggedTensor`, or getting the values tensor for a specific key in the `KeyedJaggedTensor`. This symbolic shape cannot be derived from the symbolic shapes of the values, lengths, or offsets tensors in the `KeyedJaggedTensor`. In other words, any operation that uses `KeyedJaggedTensor` as input is data dependent, and the output shape can only be determined from the actual lengths data at runtime. 
If the values, lengths, and offsets tensors are on GPU, TorchRec has to copy the lengths from GPU to CPU to infer the output shape, which introduces synchronization and can hurt performance.\n\nTo mitigate this problem, `KeyedJaggedTensor` saves key metadata in Python lists when it is constructed. In principle, some metadata, such as the length per key, can be determined directly from the corresponding `JaggedTensor` without looking at the flattened lengths tensor. In this way, even after the original `JaggedTensor`s are no longer available, the output shape can still be determined without reading the actual lengths data, which avoids GPU-CPU synchronization when `KeyedJaggedTensor` is on GPU.\n\nHowever, this is not how the key metadata is derived in the current implementation of `KeyedJaggedTensor`. In eager mode, the metadata is derived from the actual data in the lengths tensors, which causes GPU-CPU synchronization if `KeyedJaggedTensor` is on GPU. In TorchDynamo compile mode, the metadata is not computed and saved during construction, in order to avoid that synchronization. This only defers the synchronization to the point when `KeyedJaggedTensor` is used in an operation, so it does not really solve the problem.\n\nOne might ask why the current `KeyedJaggedTensor` implementation does not just derive the key metadata from the `JaggedTensor` metadata, since that would avoid GPU-CPU synchronization altogether. The reason is that constructing this metadata from `JaggedTensor` objects cannot be traced into the computation graph in TorchDynamo compile mode, at least for now, because it involves Python list appending and other bookkeeping that the computation graph cannot support. `KeyedJaggedTensor` is not a `torch.Tensor` after all; only the metadata in a `torch.Tensor` can be symbolically traced in TorchDynamo compile mode. 
In non-TorchDynamo compile mode, the current implementation may record the construction in the computation graph, but GPU-CPU synchronization is still unavoidable if the `JaggedTensor`s are on GPU.\n\nConsequently, if `KeyedJaggedTensor` is constructed from GPU `JaggedTensor`s and used in compile mode, GPU-CPU synchronization is inevitable.\n\nAn Illustration of `KeyedJaggedTensor` Metadata Problems\n\nSuppose a batch has two sparse features, `user_clicked_item_ids` and `user_viewed_item_ids`, and two samples:\n\n- Sample 0: `user_clicked_item_ids = [10, 11]`, `user_viewed_item_ids = [20]`\n- Sample 1: `user_clicked_item_ids = [12]`, `user_viewed_item_ids = [21, 22, 23]`\n\nThese can be flattened into one `KeyedJaggedTensor` as:\n\n``` python\nimport torch\nfrom torchrec.sparse.jagged_tensor import KeyedJaggedTensor\n\nkjt = KeyedJaggedTensor(\n    keys=[\"user_clicked_item_ids\", \"user_viewed_item_ids\"],\n    # Values are laid out key by key, then sample by sample within each key.\n    values=torch.tensor([10, 11, 12, 20, 21, 22, 23]),\n    # One length per (key, sample) pair, in the same key-major order.\n    lengths=torch.tensor([2, 1, 1, 3]),\n)\n```\n\nHere, `lengths` stores the number of values for each feature in each sample, ordered key by key:\n\n- `user_clicked_item_ids` for sample 0 has length `2`\n- `user_clicked_item_ids` for sample 1 has length `1`\n- `user_viewed_item_ids` for sample 0 has length `1`\n- `user_viewed_item_ids` for sample 1 has length `3`\n\n`values` follows the same key-major order, so all values of `user_clicked_item_ids` come before those of `user_viewed_item_ids`.\n\nThis small example is enough to see why `KeyedJaggedTensor` is data dependent. If a downstream operator needs the output shape for `user_viewed_item_ids`, it has to know that the total length for that key is `sum([1, 3]) = 4`. When those lengths live on GPU, TorchRec has to copy them back to CPU to determine the shape, which creates the synchronization.\n\nIf we only want to derive metadata from the `JaggedTensor`s themselves, the idea is much simpler. 
Conceptually, TorchRec could build the key metadata directly from the per-key `JaggedTensor` objects:\n\n``` python\nimport torch\nfrom torchrec.sparse.jagged_tensor import JaggedTensor\n\njts = {\n    \"user_clicked_item_ids\": JaggedTensor(\n        values=torch.tensor([10, 11, 12]), lengths=torch.tensor([2, 1])\n    ),\n    \"user_viewed_item_ids\": JaggedTensor(\n        values=torch.tensor([20, 21, 22, 23]), lengths=torch.tensor([1, 3])\n    ),\n}\n\n# The length per key is the number of values for that key. It is read from\n# the shape of the values tensor, so the lengths data is never touched.\nlengths_per_key = []\nfor key in [\"user_clicked_item_ids\", \"user_viewed_item_ids\"]:\n    jt = jts[key]\n    lengths_per_key.append(jt.values().shape[0])\n```\n\nIn this illustration, the metadata is derived from the shape of each `JaggedTensor`'s values tensor, not from the data in the flattened `KeyedJaggedTensor` lengths tensor. That is the shape information we would want to preserve, because it can be known before any downstream operation touches the actual sparse values. However, this pattern is not suitable for TorchDynamo compile mode today, because the Python-side bookkeeping needed to collect the metadata is not traceable into the computation graph.\n\nConclusions\n\nTo use `KeyedJaggedTensor` efficiently in a GPU system, it should be constructed from CPU `JaggedTensor`s and then moved to GPU in eager mode. This usually means `KeyedJaggedTensor` should be constructed in the data preprocessing stage. In a model running on GPU, `KeyedJaggedTensor` should only be used as model input. One should avoid constructing `KeyedJaggedTensor` from GPU `JaggedTensor`s, especially inside the model. 
In this way, the GPU operations in the model can run asynchronously without GPU-CPU synchronization, which results in the best performance.\n\nReferences\n\nSynchronizations With TorchRec KeyedJaggedTensor\n\n[https://leimao.github.io/blog/TorchRec-KeyedJaggedTensor-Synchronizations/](https://leimao.github.io/blog/TorchRec-KeyedJaggedTensor-Synchronizations/)", "url": "https://wpnews.pro/news/synchronizations-with-torchrec-keyedjaggedtensor", "canonical_source": "https://leimao.github.io/blog/TorchRec-KeyedJaggedTensor-Synchronizations/", "published_at": "2026-06-05 07:00:00+00:00", "updated_at": "2026-06-06 22:37:29.681347+00:00", "lang": "en", "topics": ["machine-learning", "neural-networks", "ai-infrastructure", "ai-tools"], "entities": ["TorchRec", "KeyedJaggedTensor", "Meta"], "alternates": {"html": "https://wpnews.pro/news/synchronizations-with-torchrec-keyedjaggedtensor", "markdown": "https://wpnews.pro/news/synchronizations-with-torchrec-keyedjaggedtensor.md", "text": "https://wpnews.pro/news/synchronizations-with-torchrec-keyedjaggedtensor.txt", "jsonld": "https://wpnews.pro/news/synchronizations-with-torchrec-keyedjaggedtensor.jsonld"}}