Feature request
Across the codebase there are places where tensor ops like .item() and .nonzero() are used. The PyTorch docs state that these operations cause a host↔device synchronization when the tensor is on the GPU. This can significantly hurt performance by blocking the CPU while long GPU kernels run.
For reference, one such instance, which has been fixed in #42433, is given below:

```python
cache_position: torch.Tensor = torch.arange(
    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
```
past_key_values is an instance of CacheLayerMixin. If someone is using StaticLayer, the get_seq_len() method computes its return value with tensor operations, so it returns a 0-d tensor, and hence past_seen_tokens is a 0-d tensor. But torch.arange() expects a Number for its start and end arguments, so PyTorch copies the value to the CPU to build the arange tensor. This amounts to an implicit call to .item().
If past_seen_tokens is on the GPU, this triggers a cudaStreamSynchronize, which can block the CPU while a large kernel is running on the GPU. A profiler trace of this call demonstrates the sync. Additionally, this also appeared to cause graph breaks when used with torch.compile().
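As a minimal, hedged sketch of how this implicit sync can be surfaced (the 0-d tensor value and sequence length here are made-up stand-ins, and a CUDA device is assumed):

```python
import torch

# Made-up stand-in for what StaticLayer's get_seq_len() would return:
# a 0-d tensor living on the GPU.
past_seen_tokens = torch.tensor(8, device="cuda")
seq_len = 4

# Warn whenever an op implicitly synchronizes with the device.
torch.cuda.set_sync_debug_mode("warn")

# torch.arange expects Numbers for start/end, so the 0-d tensor is implicitly
# .item()-ed, which issues a cudaStreamSynchronize under the hood.
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device="cuda")

torch.cuda.set_sync_debug_mode("default")
```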
The fix is the change given below:
```python
cache_position: torch.Tensor = torch.arange(
    inputs_embeds.shape[1], device=inputs_embeds.device
) + past_seen_tokens
```
Now the arange call gets a proper int to work with, and past_seen_tokens is added to the resulting tensor; assuming both live on the GPU, no sync is needed. This is also torch.compile() friendly, as required for StaticLayer.
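To sanity-check the fixed pattern, a sketch like the following can be used (again a CUDA device is assumed, and the shapes are made up); set_sync_debug_mode("error") makes PyTorch raise if any implicit sync sneaks in:

```python
import torch

# Made-up stand-ins mirroring the snippet above.
past_seen_tokens = torch.tensor(8, device="cuda")  # 0-d tensor on the GPU
inputs_embeds = torch.randn(1, 4, 16, device="cuda")

# Raise an error on any implicit host<->device synchronization.
torch.cuda.set_sync_debug_mode("error")

# arange gets a plain Python int; the tensor-tensor add happens on the GPU,
# so no device->host copy (and no cudaStreamSynchronize) is needed.
cache_position = torch.arange(
    inputs_embeds.shape[1], device=inputs_embeds.device
) + past_seen_tokens

torch.cuda.set_sync_debug_mode("default")
```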
Another reference: PR #40060.
What to look for to identify similar instances:
- For torch ops that require Number/float/int inputs, check whether the input can be a tensor.
- Explicit calls to `.item()`, `.nonzero()`, or similar operations that can cause a sync. Check the PyTorch docs for more such ops.
- Looking at profiles can help identify such instances (see the profiler sketch after this list).
- `torch.cuda.set_sync_debug_mode("warn")` might also help find more such instances.
- If the issue is present in modeling files, be sure to run `make fix-copies` to apply the changes to other similar files after confirming the changes with the maintainers.
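As a hedged illustration of the profiling approach (the ops below are arbitrary examples, and a CUDA device is assumed), sync-causing calls such as aten::nonzero and aten::item show up in the recorded trace alongside cudaStreamSynchronize entries:

```python
import torch
from torch.profiler import profile, ProfilerActivity

x = torch.randn(1024, device="cuda")

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    idx = x.nonzero()     # output size is data-dependent, so the CPU must wait for the GPU
    val = x.sum().item()  # explicit device->host copy, another sync point

# Sync-heavy ops stand out in the summary table (or in an exported chrome trace).
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```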
NOTE: While the above reference instances could be fixed with minor refactoring, others might need larger-scale changes, so it is better to consult with the maintainers before putting effort into such changes.