Skip to content
93 changes: 67 additions & 26 deletions QEfficient/transformers/models/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,7 +909,7 @@ def __init__(
self,
model: nn.Module,
continuous_batching: bool = False,
qaic_config: Optional[dict] = None,
ccl_enabled: bool = False,
**kwargs,
):
"""
Expand All @@ -932,11 +932,11 @@ def __init__(
self.model = model
self.config = model.config

self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config)

self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs)
self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs)
self.continuous_batching = continuous_batching
self.ccl_enabled = ccl_enabled
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None
self.input_shapes, self.output_names = None, None

@property
Expand All @@ -955,7 +955,7 @@ def model_name(self) -> str:
return mname

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Optional[dict] = None, **kwargs):
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
"""
Load a QEfficient multimodal model for dual QPC from a pretrained HuggingFace model or local path.

Expand All @@ -980,11 +980,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Option
logger.warning("Updating low_cpu_mem_usage=False")

kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
ccl_enabled = kwargs.pop("ccl_enabled", None)

model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
return cls(
model,
pretrained_model_name_or_path=pretrained_model_name_or_path,
qaic_config=qaic_config,
ccl_enabled=ccl_enabled,
**kwargs,
)

Expand Down Expand Up @@ -1090,6 +1092,8 @@ def compile(
compile_dir: Optional[str] = None,
*,
prefill_seq_len: Optional[int] = None,
comp_ctx_lengths_prefill: Optional[List[int]] = None,
comp_ctx_lengths_decode: Optional[List[int]] = None,
ctx_len: Optional[int] = None,
batch_size: int = 1,
full_batch_size: Optional[int] = None,
Expand Down Expand Up @@ -1174,10 +1178,20 @@ def compile(

output_names = self.model.get_output_names(kv_offload=True)

# if ccl_enabled is True read Compute-Context-Length lists
if self.ccl_enabled:
if comp_ctx_lengths_prefill is None or comp_ctx_lengths_decode is None:
logger.warning(
"Please set comp_ctx_lengths_prefill and comp_ctx_lengths_decode with a proper list of context lengths. Using non-CCL default model."
)
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
)

# For supporting VLLM and Disaggregated with CCL
if "comp_ctx_lengths_prefill" in compiler_options:
self.comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill")
self.comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode")
if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
self.comp_ctx_lengths_decode = comp_ctx_lengths_decode

specializations, compiler_options = self.model.get_specializations(
batch_size=batch_size,
Expand Down Expand Up @@ -1600,7 +1614,7 @@ class _QEFFAutoModelForImageTextToTextSingleQPC(QEFFTransformersBase, Multimodal
def __init__(
self,
model: nn.Module,
qaic_config: Optional[dict] = None,
ccl_enabled: bool = False,
**kwargs,
):
"""
Expand All @@ -1622,8 +1636,6 @@ def __init__(
raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
super().__init__(model, **kwargs)

self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config)

# to handle internvl models
if hasattr(self.model.config, "llm_config") and hasattr(self.model.config, "vision_config"):
self.model.config.llm_config.use_cache = True
Expand All @@ -1635,12 +1647,13 @@ def __init__(
else:
self.model.config.use_cache = True
self.hash_params["qeff_auto_class"] = self.__class__.__name__
self.ccl_enabled = ccl_enabled
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None

@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path,
qaic_config: Optional[dict] = None,
*args,
**kwargs,
):
Expand Down Expand Up @@ -1671,6 +1684,8 @@ def from_pretrained(
logger.warning("Updating low_cpu_mem_usage=False")

kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
ccl_enabled = kwargs.pop("ccl_enabled", None)

from transformers import AutoConfig

config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
Expand All @@ -1681,7 +1696,7 @@ def from_pretrained(
return cls(
model,
pretrained_model_name_or_path=pretrained_model_name_or_path,
qaic_config=qaic_config,
ccl_enabled=ccl_enabled,
**kwargs,
)

Expand Down Expand Up @@ -1725,6 +1740,8 @@ def compile(
*,
prefill_seq_len: Optional[int] = None,
ctx_len: Optional[int] = None,
comp_ctx_lengths_prefill: Optional[List[int]] = None,
comp_ctx_lengths_decode: Optional[List[int]] = None,
batch_size: int = 1,
full_batch_size: Optional[int] = None,
kv_cache_batch_size: Optional[int] = None,
Expand Down Expand Up @@ -1794,10 +1811,20 @@ def compile(
kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size
output_names = self.model.get_output_names()

# if ccl_enabled is True read Compute-Context-Length lists
if self.ccl_enabled:
if comp_ctx_lengths_prefill is None or comp_ctx_lengths_decode is None:
logger.warning(
"Please set comp_ctx_lengths_prefill and comp_ctx_lengths_decode with a proper list of context lengths. Using non-CCL default model."
)
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
)

# For supporting VLLM and Disaggregated with CCL
if "comp_ctx_lengths_prefill" in compiler_options:
self.comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill")
self.comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode")
if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
self.comp_ctx_lengths_decode = comp_ctx_lengths_decode

# Get specializations from modelling file
# TODO: expose this via the auto class as well
Expand Down Expand Up @@ -2180,7 +2207,7 @@ def __new__(
model: nn.Module,
kv_offload: Optional[bool] = True,
continuous_batching: bool = False,
qaic_config: Optional[dict] = None,
ccl_enabled: bool = False,
**kwargs,
):
"""
Expand All @@ -2204,10 +2231,10 @@ def __new__(
"""
if kv_offload:
return _QEffAutoModelForImageTextToTextDualQPC(
model, continuous_batching, qaic_config=qaic_config, **kwargs
model, continuous_batching, ccl_enabled=ccl_enabled, **kwargs
)
else:
return _QEFFAutoModelForImageTextToTextSingleQPC(model, qaic_config=qaic_config, **kwargs)
return _QEFFAutoModelForImageTextToTextSingleQPC(model, ccl_enabled=ccl_enabled, **kwargs)

@classmethod
@with_replaced_quantizers
Expand Down Expand Up @@ -2257,14 +2284,15 @@ def from_pretrained(
logger.warning("Updating low_cpu_mem_usage=False")

kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
ccl_enabled = kwargs.pop("ccl_enabled", None)

model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
return cls(
model,
kv_offload=kv_offload,
continuous_batching=continuous_batching,
pretrained_model_name_or_path=pretrained_model_name_or_path,
qaic_config=qaic_config,
ccl_enabled=ccl_enabled,
**kwargs,
)

Expand Down Expand Up @@ -2317,6 +2345,7 @@ def __init__(
model: nn.Module,
continuous_batching: bool = False,
qaic_config: Optional[dict] = None,
ccl_enabled: bool = False,
**kwargs,
):
"""
Expand Down Expand Up @@ -2363,8 +2392,6 @@ def __init__(
# Set use_cache=True to get KV values as output during ONNX export
model.config.use_cache = True

self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config)

super().__init__(model, qaic_config=qaic_config, **kwargs)
self.num_layers = model.config.num_hidden_layers
self.continuous_batching = continuous_batching
Expand All @@ -2373,6 +2400,8 @@ def __init__(
self.is_tlm = transformed

self.hash_params["qeff_auto_class"] = self.__class__.__name__
self.ccl_enabled = ccl_enabled
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None

# ---Sampling---
# Note: SamplerTransform should be applied after all other transforms
Expand Down Expand Up @@ -2465,6 +2494,7 @@ def from_pretrained(
logger.warning("Updating low_cpu_mem_usage=False")

kv_offload = kwargs.pop("kv_offload", None)
ccl_enabled = kwargs.pop("ccl_enabled", None)

kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
Expand All @@ -2478,14 +2508,15 @@ def from_pretrained(
model,
kv_offload=kv_offload,
pretrained_model_name_or_path=pretrained_model_name_or_path,
qaic_config=qaic_config,
ccl_enabled=ccl_enabled,
**kwargs,
)
return cls(
model,
continuous_batching=continuous_batching,
qaic_config=qaic_config,
pretrained_model_name_or_path=pretrained_model_name_or_path,
ccl_enabled=ccl_enabled,
**kwargs,
)

Expand Down Expand Up @@ -2814,6 +2845,8 @@ def compile(
*,
prefill_seq_len: int = 32,
ctx_len: int = 128,
comp_ctx_lengths_prefill: Optional[List[int]] = None,
comp_ctx_lengths_decode: Optional[List[int]] = None,
batch_size: int = 1,
full_batch_size: Optional[int] = None,
kv_cache_batch_size: Optional[int] = None,
Expand Down Expand Up @@ -2905,10 +2938,18 @@ def compile(

"""

# if ccl_enabled is True read Compute-Context-Length lists
if self.ccl_enabled:
if comp_ctx_lengths_prefill is None or comp_ctx_lengths_decode is None:
logger.warning(
"Please set comp_ctx_lengths_prefill and comp_ctx_lengths_decode with a proper list of context lengths. Using non-CCL default model."
)
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
)

# For supporting VLLM and Disaggregated with CCL
if "comp_ctx_lengths_prefill" in compiler_options and "comp_ctx_lengths_decode" in compiler_options:
comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill")
comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode")
if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
if isinstance(comp_ctx_lengths_prefill, str):
import ast

Expand Down
9 changes: 1 addition & 8 deletions QEfficient/utils/check_ccl_specializations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,7 @@
# -----------------------------------------------------------------------------


def process_ccl_specializations(qaic_config):
if qaic_config is None:
return None, None
ccl_prefill = qaic_config.pop("comp_ctx_lengths_prefill", None)
ccl_decode = qaic_config.pop("comp_ctx_lengths_decode", None)
ctx_len = qaic_config.pop("ctx_len", None)
prefill_seq_len = qaic_config.pop("prefill_seq_len", 128)

def process_ccl_specializations(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len):
if ccl_prefill is None or ccl_decode is None:
return None, None

Expand Down
50 changes: 50 additions & 0 deletions examples/performance/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,56 @@ python on_device_sampling.py \
--top-p 0.89
```

### Compute-Context-Length

Compute-Context-Length (CCL) calculates the context length dynamically during inference to get the best performance within each context-length window.

#### compute_context_length/basic_inference.py
Configure the CCL parameters: 1) `ccl-enabled`: activates the CCL feature, 2) `comp-ctx-lengths-prefill`: list of context lengths to be used during prefill, and 3) `comp-ctx-lengths-decode`: list of context lengths to be used during decoding.

**Usage for Text-only models:**
```bash
python compute_context_length/basic_inference.py \
--model-name meta-llama/Llama-3.1-8B \
--num-cores 16 \
--prefill-seq-len 32 \
--ctx-len 1024 \
--ccl-enabled \
--comp-ctx-lengths-prefill 500,1000 \
--comp-ctx-lengths-decode 512,1024
```

**Usage for VLM models such as mllama and llava:**
```bash
python compute_context_length/vlm_inference.py \
--model-name meta-llama/Llama-3.2-11B-Vision-Instruct \
--hf-token "" \
--num-cores 16 \
--prefill-seq-len 32 \
--ctx-len 8192 \
--img-size 560 \
--ccl-enabled \
--comp-ctx-lengths-prefill 4096 \
--comp-ctx-lengths-decode 6144,8192
```

**Usage with other MoE and Multimodal models:**
For the various models available in the `compute_context_length` directory, such as gemma3, gpt_oss, granite_vision, internvl, llama4_cb, llama4_multi_image, llama4, mistral3, molmo, qwen2_5_vl, qwen2_5_vl_cb, and qwen3moe, use the corresponding inference script and change only the model name and CCL configuration inside that script. The following shows how to run each one:
```bash
python compute_context_length/gemma3.py
python compute_context_length/gpt_oss.py
python compute_context_length/granite_vision.py
python compute_context_length/internvl.py
python compute_context_length/llama4_cb.py
python compute_context_length/llama4_multi_image.py
python compute_context_length/llama4.py
python compute_context_length/mistral3.py
python compute_context_length/molmo.py
python compute_context_length/qwen2_5_vl.py
python compute_context_length/qwen2_5_vl_cb.py
python compute_context_length/qwen3moe.py
```

## Performance Tips

1. **Speculative Decoding**: Best for long-form generation where draft model is much faster than target
Expand Down
5 changes: 4 additions & 1 deletion examples/performance/compute_context_length/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ python vlm_inference.py \
Basic CCL usage with text-only language models.

**Supported Models:**
- Llama (3.2, 3.3)
- Llama (3.2, 3.3, swiftkv)
- Gemma/Gemma-2
- Mistral
- Phi/Phi-3
Expand All @@ -77,6 +77,9 @@ Basic CCL usage with text-only language models.
- GPT-2, GPT-J
- CodeGen
- OLMo-2
- Mistral/Mixtral
- Qwen2
- Falcon

**Command-Line Arguments:**
- `--model-name`: HuggingFace model ID (default: meta-llama/Llama-3.2-1B)
Expand Down
Loading
Loading