 from functools import lru_cache
 from typing import Any

+from transformers.generation.configuration_utils import CompileConfig
 from transformers.utils.import_utils import is_flash_attn_2_available, is_kernels_available

@@ -61,8 +62,7 @@ def __init__(
         sequence_length: int = 128,
         num_tokens_to_generate: int = 128,
         attn_implementation: str = "eager",
-        compile_mode: str | None = None,
-        compile_options: dict[str, Any] | None = None,
+        compile_kwargs: dict[str, Any] | None = None,
         kernelize: bool = False,
         name: str | None = None,
         skip_validity_check: bool = False,
@@ -79,8 +79,11 @@ def __init__(
         # Generation parameters
         self.attn_implementation = attn_implementation
         # Optimization parameters
-        self.compile_mode = compile_mode
-        self.compile_options = compile_options if compile_options is not None else {}
+        if compile_kwargs is None:
+            self.compile_config = None
+        else:
+            compile_kwargs["fullgraph"] = compile_kwargs.get("fullgraph", True)
+            self.compile_config = CompileConfig(**compile_kwargs)
         self.kernelize = kernelize
         # Constant parameters
         self.dtype = "torch.bfloat16"
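The two old attributes (`compile_mode`, `compile_options`) are folded into a single `compile_config`, built by forwarding `compile_kwargs` to `CompileConfig` with `fullgraph` defaulting to `True`. A minimal standalone sketch of that mapping (the `build_compile_config` helper is made up for illustration and copies the dict instead of writing the default back into the caller's dict, as the in-place version above does):

```python
from transformers.generation.configuration_utils import CompileConfig


def build_compile_config(compile_kwargs: dict | None) -> CompileConfig | None:
    """Hypothetical helper mirroring the __init__ logic above."""
    if compile_kwargs is None:
        return None  # no compile_kwargs -> compilation stays disabled
    # fullgraph defaults to True unless the caller overrides it
    return CompileConfig(**{"fullgraph": True, **compile_kwargs})


print(build_compile_config(None))                 # None: torch.compile is not used
print(build_compile_config({}))                   # CompileConfig with fullgraph=True and library defaults
print(build_compile_config({"mode": "default"}))  # extra keys are forwarded as-is
```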
@@ -92,22 +95,41 @@ def __init__(
     def check_validity(self, skip_validity_check: bool = False) -> None:
         if skip_validity_check:
             return
-        # Check FA is installed
-        is_fa = self.attn_implementation == "flash_attention_2"
-        if is_fa and not is_fa2_or_kernel_available():
-            logger.warning("Flash attention is not available. Defaulting to SDPA.")
+
+        # If flash_attention_2 is selected but not available, default to SDPA
+        if self.attn_implementation == "flash_attention_2" and not is_fa2_or_kernel_available():
+            logger.error("Flash attention is not available. Defaulting to SDPA.")
             self.attn_implementation = "sdpa"
-        # Flash attention does not support compile mode, so we turn it off  # FIXME: it would be better to support it
-        if is_fa and self.compile_mode is not None:
-            logger.warning("Flash attention does not support compile mode. Turning off compile mode.")
-            self.compile_mode = None
-        # Handle continuous batching cases
-        if self.continuous_batching:
-            if self.attn_implementation == "flex_attention":
-                logger.error(
-                    "Disabling continuous batching because of invalid configuration: flex attention is not supported."
-                )
-                self.continuous_batching = False
+
+        # The combination of flash_attention_2, compile and generate is not supported  # FIXME: support it
+        if (
+            not self.continuous_batching
+            and self.attn_implementation == "flash_attention_2"
+            and self.compile_config is not None
+        ):
+            logger.error(
+                "The combination of flash_attention_2, compile and generate is not supported. Turning off compile."
+            )
+            self.compile_config = None
+
+        # Continuous batching does not support flex attention as an attention implementation  # FIXME: support it
+        if self.attn_implementation == "flex_attention" and self.continuous_batching:
+            logger.error(
+                "Disabling continuous batching because of invalid configuration: flex attention is not supported."
+            )
+            self.continuous_batching = False
+
+        # Continuous batching supports compile mode "default" or "max-autotune-no-cudagraphs"
+        if (
+            self.continuous_batching
+            and self.compile_config is not None
+            and self.compile_config.mode not in ["default", "max-autotune-no-cudagraphs"]
+        ):
+            logger.error(
+                f"You have continuous batching and compile enabled, but {self.compile_config.mode=} is not supported."
+                " Supported modes are: default, max-autotune-no-cudagraphs. Changing to default."
+            )
+            self.compile_config.mode = "default"

     @property
     def hash(self) -> str:
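The reworked `check_validity` resolves conflicting options in a fixed order: unavailable flash attention falls back to SDPA, flash attention plus compile is only rejected on the regular `generate` path, flex attention disables continuous batching, and continuous batching coerces unsupported compile modes back to `"default"`. A hedged restatement of those fallbacks on plain values (the `resolve` helper is hypothetical, and the flash-attention availability check is left out):

```python
def resolve(attn: str, compile_mode: str | None, continuous_batching: bool):
    """Hypothetical helper: same fallback order as check_validity above, without logging."""
    if attn == "flash_attention_2" and not continuous_batching and compile_mode is not None:
        compile_mode = None  # FA2 + compile + generate is not supported: drop compile
    if attn == "flex_attention" and continuous_batching:
        continuous_batching = False  # flex attention is not supported with continuous batching
    if continuous_batching and compile_mode not in (None, "default", "max-autotune-no-cudagraphs"):
        compile_mode = "default"  # continuous batching only supports these two compile modes
    return attn, compile_mode, continuous_batching


assert resolve("flash_attention_2", "default", False) == ("flash_attention_2", None, False)
assert resolve("flex_attention", None, True) == ("flex_attention", None, False)
assert resolve("sdpa", "max-autotune", True) == ("sdpa", "default", True)
```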
@@ -120,7 +142,7 @@ def infer_name(self, compact: bool = True) -> str:
             gpu_monitor_str = "monitored" if self.gpu_monitoring else "unmonitored"
             dimensions_str = f"b{self.batch_size}_s{self.sequence_length}_n{self.num_tokens_to_generate}"
             attn_code = self.attn_implementation
-            compile_str = f"compiled_{self.compile_mode}" if self.compile_mode is not None else "uncompiled"
+            compile_str = f"compiled_{self.compile_config.mode}" if self.compile_config is not None else "uncompiled"
             kernelize_str = "kernelized" if self.kernelize else "unkernelized"
             continuous_batching_str = "cb" if self.continuous_batching else "generate"
             sep = "-"
@@ -129,7 +151,7 @@ def infer_name(self, compact: bool = True) -> str:
             gpu_monitor_str = ("with" if self.gpu_monitoring else "no") + " GPU monitoring"
             dimensions_str = f"batch size {self.batch_size}, sequence length {self.sequence_length}, {self.num_tokens_to_generate} generated tokens"
             attn_code = f"{self.attn_implementation} attention"
-            compile_str = "compiled" if self.compile_mode is not None else "not compiled"
+            compile_str = "compiled" if self.compile_config is not None else "not compiled"
             kernelize_str = "kernelized" if self.kernelize else "not kernelized"
             continuous_batching_str = "continuous batching" if self.continuous_batching else "regular generate"
             sep = ", "
@@ -148,8 +170,7 @@ def to_dict(self) -> dict[str, Any]:
             "sequence_length": self.sequence_length,
             "num_tokens_to_generate": self.num_tokens_to_generate,
             "attn_implementation": self.attn_implementation,
-            "compile_mode": self.compile_mode,
-            "compile_options": self.compile_options | {},  # to avoid inplace modification of the original dict
+            "compile_kwargs": self.compile_config.to_dict() if self.compile_config is not None else None,
             "kernelize": self.kernelize,
         }

@@ -164,8 +185,7 @@ def from_dict(cls, data: dict[str, Any], skip_validity_check: bool = False) -> "
             sequence_length=data.get("sequence_length", 128),
             num_tokens_to_generate=data.get("num_tokens_to_generate", 128),
             attn_implementation=data.get("attn_implementation", "eager"),
-            compile_mode=data.get("compile_mode"),
-            compile_options=data.get("compile_options"),
+            compile_kwargs=data.get("compile_kwargs"),
             kernelize=data.get("kernelize", False),
             name=data.get("name"),
             skip_validity_check=skip_validity_check,
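`to_dict` now stores the serialized config under `"compile_kwargs"`, and `from_dict` hands that dict straight back to `__init__`, so the round trip relies on `CompileConfig.to_dict()` producing keys its own constructor accepts. A small sketch of that assumption at the `CompileConfig` level (the `BenchmarkConfig` wiring is exactly the code in the hunks above):

```python
from transformers.generation.configuration_utils import CompileConfig

cc = CompileConfig(mode="default", fullgraph=True)
kwargs = cc.to_dict()               # what to_dict() stores under "compile_kwargs"
restored = CompileConfig(**kwargs)  # what from_dict -> __init__ effectively rebuilds
assert restored.mode == "default" and restored.fullgraph is True
```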
@@ -218,27 +238,28 @@ def get_config_by_level(level: int) -> list[BenchmarkConfig]:
             # Usually there is not much to gain by compiling with other modes, but we allow it for level 4
             compile_modes = BenchmarkConfig.all_compiled_modes if level >= 4 else [None, "default"]
             for cm in compile_modes:
+                compile_kwargs = {"mode": cm} if cm is not None else None
                 for kernelize_on in {False, KERNELIZATION_AVAILABLE}:
                     for cb_on in [False, True]:
                         configs.append(
                             BenchmarkConfig(
                                 attn_implementation=attn_implementation,
-                                compile_mode=cm,
+                                compile_kwargs=compile_kwargs,
                                 kernelize=kernelize_on,
                                 continuous_batching=cb_on,
                             )
                         )
         return configs
     # Otherwise, we add the configs for the given level
     if level >= 0:
-        configs.append(BenchmarkConfig(attn_implementation="flex_attention", compile_mode="default"))
+        configs.append(BenchmarkConfig(attn_implementation="flex_attention", compile_kwargs={}))
     if level >= 1:
         configs.append(BenchmarkConfig(attn_implementation="flash_attention_2"))
-        configs.append(BenchmarkConfig(attn_implementation="eager", compile_mode="default"))
+        configs.append(BenchmarkConfig(attn_implementation="eager", compile_kwargs={}))
         configs.append(BenchmarkConfig(attn_implementation="flash_attention_2", continuous_batching=True))
     if level >= 2:
-        configs.append(BenchmarkConfig(attn_implementation="sdpa", compile_mode="default"))
-        configs.append(BenchmarkConfig(attn_implementation="flex_attention", compile_mode="default", kernelize=True))
+        configs.append(BenchmarkConfig(attn_implementation="sdpa", compile_kwargs={}))
+        configs.append(BenchmarkConfig(attn_implementation="flex_attention", compile_kwargs={}, kernelize=True))
         configs.append(BenchmarkConfig(attn_implementation="flash_attention_2", kernelize=True))
         configs.append(BenchmarkConfig(attn_implementation="sdpa", continuous_batching=True))
     return configs
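For reference, a hedged usage sketch of the level presets after this change; `get_config_by_level`, `BenchmarkConfig.infer_name`, and `compile_config` all appear in the diff above, but the benchmark framework's import path is not shown in these hunks, so it is omitted here:

```python
# get_config_by_level and BenchmarkConfig come from the benchmark framework module
# patched above (import path not shown in this diff).
configs = get_config_by_level(level=2)
for cfg in configs:
    compiled = cfg.compile_config is not None
    print(f"{cfg.infer_name(compact=True)} (compiled={compiled})")
```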