Introduction
In recommendation systems, sparse features such as user-item interaction ids are often used to model user preferences and item characteristics. These sparse features are then mapped to dense representations through large embedding tables.
However, there are a few challenges when working with sparse features in recommendation systems:
- Different samples may have different numbers of interactions, leading to variable-length input data.
- Recommendation systems often use a large number of sparse features.
In a batch of requests, if all sparse features are padded to the same length, the embedding tables will produce many useless embedding vectors. That wastes memory and downstream compute resources. If each sparse feature accesses embedding tables independently, the overhead becomes large when the number of sparse features is large.
TorchRec KeyedJaggedTensor was designed to address these challenges by combining sparse features across samples and across features into one large sparse feature without padding. This eliminates the memory and compute inefficiencies.
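To make the padding cost concrete, here is a minimal pure-Python sketch that counts how many embedding lookups a padded layout would issue compared with a jagged layout. The batch contents and the helper name count_lookups are made up for illustration; real TorchRec code operates on tensors rather than Python lists.

```python
# Hypothetical batch: one sparse feature, four samples with different numbers of ids.
batch = [[10, 11, 12, 13, 14], [20], [], [30, 31]]


def count_lookups(samples):
    """Return (padded, jagged) embedding lookup counts for one feature."""
    max_len = max(len(ids) for ids in samples)
    padded = max_len * len(samples)  # every sample padded to the longest one
    jagged = sum(len(ids) for ids in samples)  # only real ids are looked up
    return padded, jagged


padded, jagged = count_lookups(batch)
print(f"padded lookups: {padded}, jagged lookups: {jagged}")
```

In this toy batch the padded layout issues 20 lookups while the jagged layout issues 8, and the gap grows with the spread between the shortest and longest samples.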
Despite its efficiency, KeyedJaggedTensor has several caveats and can be used incorrectly, resulting in worse system performance. One key issue in GPU systems is synchronization. In this blog post, I would like to discuss the main pitfalls of KeyedJaggedTensor and how to use it efficiently on GPU.
TorchRec Data Types
TorchRec modules use specific input and output data types to represent sparse features efficiently, including:
- JaggedTensor: a wrapper around the lengths or offsets tensor and the values tensor for a single sparse feature.
- KeyedJaggedTensor: a wrapper that represents multiple sparse features and can be thought of as multiple JaggedTensors.
- KeyedTensor: a wrapper around torch.Tensor that allows access to tensor values through keys.
KeyedJaggedTensor can be constructed from a dictionary of JaggedTensors, where the keys are the feature names. Passing a KeyedJaggedTensor to an EmbeddingBagCollection produces a KeyedTensor, whose embeddings can be accessed through keys.
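As a rough mental model of how several per-feature jagged inputs collapse into one flattened structure, the following pure-Python sketch flattens a dictionary of per-key sample lists into keys, values, and lengths. It assumes the key-major layout (all samples of one key, then all samples of the next key); the function name flatten_features and the feature names are hypothetical, and this is not TorchRec's implementation.

```python
from typing import Dict, List, Tuple


def flatten_features(
    features: Dict[str, List[List[int]]],
) -> Tuple[List[str], List[int], List[int]]:
    """Flatten per-key jagged lists into key-major keys, values, and lengths."""
    keys, values, lengths = [], [], []
    for key, samples in features.items():
        keys.append(key)
        for ids in samples:  # one entry per sample in the batch
            values.extend(ids)
            lengths.append(len(ids))
    return keys, values, lengths


# Hypothetical batch of three samples with two features.
features = {
    "feature_a": [[1, 2], [], [3]],
    "feature_b": [[4], [5, 6], [7, 8, 9]],
}
keys, values, lengths = flatten_features(features)
print(keys)
print(values)
print(lengths)
```

Note that no padding appears anywhere: an empty sample simply contributes a length of 0.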
Synchronizations With TorchRec KeyedJaggedTensor
When using KeyedJaggedTensor
, the critical question is what the output symbolic shape will be for a given input KeyedJaggedTensor
. Such an operation could be using KeyedJaggedTensor
to access EmbeddingBagCollection
, or getting the value tensor corresponding to a specific key in KeyedJaggedTensor
. This symbolic shape cannot be derived from the symbolic shapes of the value, lengths, or offsets tensors in KeyedJaggedTensor
. In other words, any operation that uses KeyedJaggedTensor
as input is data dependent, and the output shape can only be determined from the actual lengths data at runtime. If the value, lengths, and offsets tensors are on GPU, TorchRec has to copy the lengths from GPU to CPU to infer the output shape, which introduces synchronization and can hurt performance.
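The data dependence can be seen in a small pure-Python sketch. Two hypothetical batches have values and lengths of identical shapes, yet slicing out the same key yields outputs of different sizes. The helper values_for_key is an illustrative stand-in, assuming a key-major layout, and not a TorchRec function.

```python
def values_for_key(values, lengths, num_keys, key_index):
    """Slice the values of one key out of a key-major flattened layout."""
    batch_size = len(lengths) // num_keys
    start_slot = key_index * batch_size
    # Both sums read the *contents* of lengths, not just its shape.
    start = sum(lengths[:start_slot])
    size = sum(lengths[start_slot:start_slot + batch_size])
    return values[start:start + size]


# Two hypothetical batches: 2 keys x 2 samples, 6 values each,
# so values and lengths have identical shapes in both.
batch_x = ([1, 2, 3, 4, 5, 6], [1, 1, 2, 2])
batch_y = ([1, 2, 3, 4, 5, 6], [2, 3, 1, 0])

out_x = values_for_key(*batch_x, num_keys=2, key_index=1)
out_y = values_for_key(*batch_y, num_keys=2, key_index=1)
print(len(out_x), len(out_y))
```

On GPU, the two sums would be computed from device memory, and turning their results into a Python shape is what forces the host to wait for the device.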
To mitigate this problem, KeyedJaggedTensor saves key metadata in lists when it is constructed. For some metadata, such as the lengths per key (the total number of values for each key), the values can be determined directly from the corresponding JaggedTensor without looking at the flattened lengths tensor. In this way, even after the original JaggedTensors are no longer available, the output shape can still be determined without reading the actual lengths data, which avoids GPU-CPU synchronization when KeyedJaggedTensor is on GPU.
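A minimal sketch of this caching idea, in pure Python, might look as follows. The names length_per_key and offset_per_key are chosen for illustration, and the per-key inputs are hypothetical. The point is only that the count of values per key is known at construction time, so later slicing needs no access to the lengths data.

```python
from itertools import accumulate

# Hypothetical per-key inputs. Only the *number* of values per key is needed,
# which for a tensor would be shape metadata rather than tensor contents.
values_per_key = {"feature_a": [1, 2, 3], "feature_b": [4, 5, 6, 7, 8, 9]}

keys = list(values_per_key)
length_per_key = [len(values_per_key[k]) for k in keys]
offset_per_key = [0, *accumulate(length_per_key)]

flat_values = [v for k in keys for v in values_per_key[k]]


def slice_key(key):
    """Return one key's values using only the cached per-key offsets."""
    i = keys.index(key)
    return flat_values[offset_per_key[i]:offset_per_key[i + 1]]


print(length_per_key, offset_per_key)
print(slice_key("feature_b"))
```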
However, this is not how the key metadata is derived in the current implementation of KeyedJaggedTensor. In eager mode, the metadata is derived from the actual data in the lengths tensors, which causes GPU-CPU synchronization if KeyedJaggedTensor is on GPU. In TorchDynamo compile mode, the metadata is not computed and saved during construction to avoid that synchronization. This only defers the synchronization to the point when KeyedJaggedTensor is used in an operation, so it does not really solve the problem.
One might ask why the current KeyedJaggedTensor implementation does not just derive the key metadata from the JaggedTensor metadata, since that would avoid GPU-CPU synchronization altogether. The reason is that constructing this metadata from JaggedTensor objects cannot be traced into a computation graph in TorchDynamo compile mode, at least for now, because it involves Python list appending and other bookkeeping that the computation graph cannot support. KeyedJaggedTensor is not a torch.Tensor after all; only the metadata in a torch.Tensor can be symbolically traced in TorchDynamo compile mode. In non-TorchDynamo compile mode, the current implementation may record the construction in the computation graph, but GPU-CPU synchronization is still unavoidable if the JaggedTensors are on GPU.
Consequently, if KeyedJaggedTensor is constructed from GPU JaggedTensors and used in compile mode, GPU-CPU synchronization is inevitable.
An Illustration of KeyedJaggedTensor Metadata Problems
Suppose a batch has two sparse features, user_clicked_item_ids and user_viewed_item_ids, and two samples:
- Sample 0: user_clicked_item_ids = [10, 11], user_viewed_item_ids = [20]
- Sample 1: user_clicked_item_ids = [12], user_viewed_item_ids = [21, 22, 23]
These can be flattened into one KeyedJaggedTensor as:

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Values and lengths are laid out key by key: all samples of the first key,
# then all samples of the second key.
kjt = KeyedJaggedTensor(
    keys=["user_clicked_item_ids", "user_viewed_item_ids"],
    values=torch.tensor([10, 11, 12, 20, 21, 22, 23]),
    lengths=torch.tensor([2, 1, 1, 3]),
)
```

Here, lengths stores the number of values per key per sample, with all samples of one key listed before the next key:
- user_clicked_item_ids for sample 0 has length 2
- user_clicked_item_ids for sample 1 has length 1
- user_viewed_item_ids for sample 0 has length 1
- user_viewed_item_ids for sample 1 has length 3
This small example is enough to see why KeyedJaggedTensor is data dependent. If a downstream operator needs the output shape for user_viewed_item_ids, it has to know that the total length for that key is sum([1, 3]) = 4. When those lengths live on GPU, TorchRec has to copy them back to CPU to determine the shape, which creates the synchronization.
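For the example above, the per-key totals and the per-key offsets into the flattened values can be worked out as in the following pure-Python sketch. It assumes lengths is laid out key by key, and the variable names are illustrative rather than taken from TorchRec. Every number it produces comes from reading the contents of lengths, which is exactly the step that requires a device-to-host copy when lengths lives on GPU.

```python
from itertools import accumulate

keys = ["user_clicked_item_ids", "user_viewed_item_ids"]
lengths = [2, 1, 1, 3]  # assumed layout: all samples of key 0, then key 1
batch_size = len(lengths) // len(keys)

# Summing requires reading the values stored in lengths, not just its shape.
length_per_key = [
    sum(lengths[i * batch_size:(i + 1) * batch_size]) for i in range(len(keys))
]
offset_per_key = [0, *accumulate(length_per_key)]

print(dict(zip(keys, length_per_key)))
print(offset_per_key)
```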
If we only want to derive metadata from the JaggedTensors themselves, the idea is much simpler. Conceptually, TorchRec could build the key metadata directly from the per-key JaggedTensor objects:

```python
import torch
from torchrec.sparse.jagged_tensor import JaggedTensor

jts = {
    "user_clicked_item_ids": JaggedTensor(
        values=torch.tensor([10, 11, 12]), lengths=torch.tensor([2, 1])
    ),
    "user_viewed_item_ids": JaggedTensor(
        values=torch.tensor([20, 21, 22, 23]), lengths=torch.tensor([1, 3])
    ),
}

# The number of values per key is shape metadata of the values tensor,
# so reading it does not require the contents of the lengths tensor.
lengths_per_key = []
for key in ["user_clicked_item_ids", "user_viewed_item_ids"]:
    jt = jts[key]
    lengths_per_key.append(jt.values().numel())
```
In this illustration, the metadata is derived from the JaggedTensor structure itself, not from the flattened KeyedJaggedTensor values. That is the shape information we would want to preserve, because it can be known before any downstream operation touches the actual sparse values. However, this pattern is not suitable for TorchDynamo compile mode today, because the Python-side bookkeeping needed to collect the metadata is not traceable into the computation graph.
Conclusions
To use KeyedJaggedTensor efficiently in a GPU system, it should be constructed from CPU JaggedTensors and then moved to GPU in eager mode. This usually means KeyedJaggedTensor should be constructed in the data preprocessing stage. In a model running on GPU, KeyedJaggedTensor should only be used as model input. One should avoid constructing KeyedJaggedTensor from GPU JaggedTensors, especially inside the model. In this way, the GPU operations in the model can run asynchronously without GPU-CPU synchronization, which results in the best performance.
References
Synchronizations With TorchRec KeyedJaggedTensor
https://leimao.github.io/blog/TorchRec-KeyedJaggedTensor-Synchronizations/