Triton implementation of GPT/LLAMA models. The objective of this project is to understand how much performance can be squeezed out by implementing the full GPT block in a single Triton kernel.
Performance
The Triton implementation is faster and more memory-efficient than the HuggingFace Transformers implementation.
```shell
python3 bench.py
```

Latency

| precision | HuggingFace GPT | Triton GPT |
|---|---|---|
| fp32 | 1800 ms | - |
| tf32 | 631.35 ms | 462.63 ms |
| mixed precision (fp16) | 510.80 ms | 273 ms |
| fp16 | 301.92 ms | - |
Time taken to process a batch of 512 sequences of length 300 on a single A100 40 GB.
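For reference, forward-pass latencies like the ones above are usually measured with CUDA events. The harness below is an illustrative sketch only (it is not the repo's actual `bench.py`); `model` is assumed to be any GPT forward callable, and 50257 is the GPT-2 vocab size.

```python
import torch

def time_forward(model, batch_size=512, seq_len=300, n_iters=10):
    # random GPT-2 token IDs of shape (batch, seqlen)
    x = torch.randint(0, 50257, (batch_size, seq_len), device="cuda")
    # warm-up so kernel compilation / autotuning is not counted
    for _ in range(3):
        model(x)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(n_iters):
        model(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / n_iters  # milliseconds per forward pass
```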
Max Batch Size
| implementation | max batch size |
|---|---|
| HuggingFace GPT | 1024 |
| Triton GPT | 2048 |
Only power-of-2 batch sizes were considered. Both runs used seqlen=300 with mixed precision enabled.
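A power-of-2 max batch size can be probed by doubling the batch until the forward pass runs out of GPU memory. The sketch below is illustrative only (assuming `model` is a forward callable and PyTorch >= 1.13 for `torch.cuda.OutOfMemoryError`), matching the mixed-precision setting used for the table above.

```python
import torch

def max_power_of_2_batch(model, seq_len=300, start=1):
    batch_size, largest_ok = start, None
    while True:
        try:
            x = torch.randint(0, 50257, (batch_size, seq_len), device="cuda")
            with torch.autocast("cuda", dtype=torch.float16):  # mixed precision, as in the table
                model(x)
            largest_ok = batch_size
            batch_size *= 2
        except torch.cuda.OutOfMemoryError:
            return largest_ok
        finally:
            torch.cuda.empty_cache()
```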
MFU
```python
from gpt import compute_mfu

# fwd MFU
# HuggingFace GPT (fp16)
compute_mfu(2 * 124 * 10**6 * 512 * 512 / 0.302, gpu="h100")
# 21.76%

# HuggingFace GPT (mixed precision)
compute_mfu(2 * 124 * 10**6 * 512 * 512 / 0.510, gpu="h100")
# 12.88%

# triton (mixed precision)
compute_mfu(2 * 124 * 10**6 * 512 * 512 / 0.273, gpu="h100")
# 24.07%
```
Supported Features
- fused implementation of several components of the GPT block (e.g. `dropout(wte(x) + wpe(x))`, `dropout(wx + b)`, `gelu(wx + b)`); see the fused-kernel sketch after this list
- flash attention v1 algorithm
- GPT-2 implementation in Triton
- support for loading pre-trained HuggingFace GPT-2 weights
- KV cache & sampling for the inference loop (see the decode-loop sketch below)
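As an illustration of the kind of fusion listed above, here is a minimal sketch (not the repo's actual kernel) of a fused `gelu(wx + b)` in Triton: a tiled matmul whose epilogue adds the bias and applies the GPT-2 tanh-approximate GELU before writing the result back, assuming fp32 inputs.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def fused_linear_gelu_kernel(
    x_ptr, w_ptr, b_ptr, out_ptr,
    M, N, K,
    stride_xm, stride_xk, stride_wk, stride_wn, stride_om, stride_on,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # each program instance computes one (BLOCK_M, BLOCK_N) tile of out = gelu(x @ w + b)
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
    w_ptrs = w_ptr + offs_k[:, None] * stride_wk + offs_n[None, :] * stride_wn

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        x_mask = (offs_m[:, None] < M) & (offs_k[None, :] + k < K)
        w_mask = (offs_k[:, None] + k < K) & (offs_n[None, :] < N)
        x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0)
        w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
        acc += tl.dot(x_tile, w_tile)
        x_ptrs += BLOCK_K * stride_xk
        w_ptrs += BLOCK_K * stride_wk

    # epilogue: bias add + GPT-2 tanh-approximate GELU, before writing to global memory
    bias = tl.load(b_ptr + offs_n, mask=offs_n < N, other=0.0)
    z = acc + bias[None, :]
    inner = 0.7978845608028654 * (z + 0.044715 * z * z * z)   # sqrt(2/pi) * (z + 0.044715 z^3)
    gelu = z * (1.0 / (1.0 + tl.exp(-2.0 * inner)))           # z * sigmoid(2*inner) == 0.5*z*(1+tanh(inner))

    out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(out_ptrs, gelu, mask=out_mask)


def fused_linear_gelu(x, w, b):
    # x: (M, K), w: (K, N), b: (N,) -- fp32 tensors
    M, K = x.shape
    _, N = w.shape
    out = torch.empty((M, N), device=x.device, dtype=torch.float32)
    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    fused_linear_gelu_kernel[grid](
        x, w, b, out, M, N, K,
        x.stride(0), x.stride(1), w.stride(0), w.stride(1), out.stride(0), out.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
    )
    return out
```

Keeping the bias add and GELU in the matmul epilogue avoids materialising the intermediate `wx + b` tensor in global memory, which is the main point of fusing these elementwise ops into one kernel.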
TODO
- implement back-propagation of the GPT block in Triton (i.e. solving the math problem)
- implement paged attention from the vLLM project in Triton
- implement flash attention v2 & v3
- add kernels for LLAMA-3.1
- implement AdamW in Triton (with FSDP-stage2 support)
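For the KV cache & sampling item under Supported Features, a decode loop typically looks like the following. This is an illustrative sketch only; the `model(ids, kv_cache=...)` signature is hypothetical and not the repo's actual inference API.

```python
import torch

@torch.no_grad()
def generate(model, prompt_ids, max_new_tokens=50, temperature=1.0, top_k=50):
    # prefill: run the full prompt once; the model returns logits plus its per-layer K/V cache
    logits, kv_cache = model(prompt_ids, kv_cache=None)       # hypothetical signature
    tokens = prompt_ids
    for _ in range(max_new_tokens):
        next_logits = logits[:, -1, :] / temperature          # logits of the last position
        if top_k is not None:                                  # top-k filtering before sampling
            kth = torch.topk(next_logits, top_k).values[:, -1:]
            next_logits = next_logits.masked_fill(next_logits < kth, float("-inf"))
        probs = torch.softmax(next_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        tokens = torch.cat([tokens, next_token], dim=1)
        # decode step: feed only the new token; cached K/V avoids re-processing the past
        logits, kv_cache = model(next_token, kv_cache=kv_cache)
    return tokens
```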
Installation
```shell
pip3 install -r requirements.txt
# `numpy<2` is a hard requirement for running on CPU;
# otherwise Triton gives garbage - likely some bug in Triton
```

Running tests
```shell
# you can run the following command on CPU
TRITON_INTERPRET=1 pytest -sv test.py

# you can run the following command on GPU
pytest -sv test.py
```
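As a reference for the style of checks in `test.py`, here is a small self-contained test that compares the fused `gelu(wx + b)` sketch from the Supported Features section against a PyTorch reference. It is illustrative only and assumes that sketch is saved as `fused_gelu.py`, which is not a file in this repo.

```python
import pytest
import torch

from fused_gelu import fused_linear_gelu  # hypothetical module holding the sketch above


@pytest.mark.parametrize("M,K,N", [(64, 128, 256), (50, 96, 72)])
def test_fused_linear_gelu(M, K, N):
    torch.manual_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"  # CPU path needs TRITON_INTERPRET=1
    x = torch.randn(M, K, device=device)
    w = torch.randn(K, N, device=device)
    b = torch.randn(N, device=device)

    out = fused_linear_gelu(x, w, b)
    ref = torch.nn.functional.gelu(x @ w + b, approximate="tanh")

    # loose tolerances since tl.dot may use tf32 on GPU
    torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
```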