
Commit 3761cdd

Merge branch 'main' into fix-fsdp2-default-version
2 parents: 2252af4 + ef780bf

File tree: 9 files changed, +317 -192 lines

benchmark_v2/framework/benchmark_config.py

Lines changed: 51 additions & 30 deletions
@@ -5,6 +5,7 @@
 from functools import lru_cache
 from typing import Any
 
+from transformers.generation.configuration_utils import CompileConfig
 from transformers.utils.import_utils import is_flash_attn_2_available, is_kernels_available
 
 

@@ -61,8 +62,7 @@ def __init__(
         sequence_length: int = 128,
         num_tokens_to_generate: int = 128,
         attn_implementation: str = "eager",
-        compile_mode: str | None = None,
-        compile_options: dict[str, Any] | None = None,
+        compile_kwargs: dict[str, Any] | None = None,
         kernelize: bool = False,
         name: str | None = None,
         skip_validity_check: bool = False,

@@ -79,8 +79,11 @@ def __init__(
         # Generation parameters
         self.attn_implementation = attn_implementation
         # Optimization parameters
-        self.compile_mode = compile_mode
-        self.compile_options = compile_options if compile_options is not None else {}
+        if compile_kwargs is None:
+            self.compile_config = None
+        else:
+            compile_kwargs["fullgraph"] = compile_kwargs.get("fullgraph", True)
+            self.compile_config = CompileConfig(**compile_kwargs)
         self.kernelize = kernelize
         # Constant parameters
         self.dtype = "torch.bfloat16"

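A minimal usage sketch of the new constructor path (not part of the diff; it assumes the repository's benchmark_v2 package is importable, since it is not published on PyPI). compile_kwargs is forwarded to transformers' CompileConfig, with fullgraph forced to True unless the caller overrides it:

    # Hedged sketch; assumed import path from the repo root.
    from benchmark_v2.framework.benchmark_config import BenchmarkConfig

    # Kwargs are passed straight through to CompileConfig(**compile_kwargs).
    cfg = BenchmarkConfig(
        attn_implementation="sdpa",
        compile_kwargs={"mode": "max-autotune-no-cudagraphs"},
    )
    print(cfg.compile_config.fullgraph)  # True, injected by __init__
    print(cfg.compile_config.mode)       # "max-autotune-no-cudagraphs"

    # Omitting compile_kwargs (or passing None) keeps compilation disabled.
    assert BenchmarkConfig(attn_implementation="eager").compile_config is None
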
@@ -92,22 +95,41 @@ def __init__(
     def check_validity(self, skip_validity_check: bool = False) -> None:
         if skip_validity_check:
             return
-        # Check FA is installed
-        is_fa = self.attn_implementation == "flash_attention_2"
-        if is_fa and not is_fa2_or_kernel_available():
-            logger.warning("Flash attention is not available. Defaulting to SDPA.")
+
+        # If flash_attention_2 is selected but not available, default to SDPA
+        if self.attn_implementation == "flash_attention_2" and not is_fa2_or_kernel_available():
+            logger.error("Flash attention is not available. Defaulting to SDPA.")
             self.attn_implementation = "sdpa"
-        # Flash attention does not support compile mode, so we turn it off # FIXME: it would be better to support it
-        if is_fa and self.compile_mode is not None:
-            logger.warning("Flash attention does not support compile mode. Turning off compile mode.")
-            self.compile_mode = None
-        # Handle continuous batching cases
-        if self.continuous_batching:
-            if self.attn_implementation == "flex_attention":
-                logger.error(
-                    "Disabling continuous batching because of invalid configuration: flex attention is not supported."
-                )
-                self.continuous_batching = False
+
+        # The combination of flash_attention_2, compile and generate is not supported # FIXME: support it
+        if (
+            not self.continuous_batching
+            and self.attn_implementation == "flash_attention_2"
+            and self.compile_config is not None
+        ):
+            logger.error(
+                "The combination of flash_attention_2, compile and generate is not supported. Turning off compile."
+            )
+            self.compile_config = None
+
+        # Continuous batching does not support flex attention as an attention implementation # FIXME: support it
+        if self.attn_implementation == "flex_attention" and self.continuous_batching:
+            logger.error(
+                "Disabling continuous batching because of invalid configuration: flex attention is not supported."
+            )
+            self.continuous_batching = False
+
+        # Continuous batching supports compile mode "default" or "max-autotune-no-cudagraphs"
+        if (
+            self.continuous_batching
+            and self.compile_config is not None
+            and self.compile_config.mode not in ["default", "max-autotune-no-cudagraphs"]
+        ):
+            logger.error(
+                f"You have continuous batching and compile enabled, but {self.compile_config.mode = } is not supported."
+                " Supported modes are: default, max-autotune-no-cudagraphs. Changing to default."
+            )
+            self.compile_config.mode = "default"
 
     @property
     def hash(self) -> str:

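An illustrative sketch of the reworked downgrade rules, continuing the example above. This assumes check_validity() runs during construction, which the skip_validity_check constructor argument suggests; both fixups log errors instead of raising:

    # An unsupported compile mode under continuous batching is coerced to "default".
    cb_cfg = BenchmarkConfig(
        attn_implementation="sdpa",
        continuous_batching=True,
        compile_kwargs={"mode": "max-autotune"},  # not in the supported list
    )
    assert cb_cfg.compile_config.mode == "default"

    # flash_attention_2 + compile + regular generate drops compilation entirely
    # (assumes flash attention is installed; otherwise sdpa is substituted first).
    fa_cfg = BenchmarkConfig(attn_implementation="flash_attention_2", compile_kwargs={})
    assert fa_cfg.compile_config is None
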
@@ -120,7 +142,7 @@ def infer_name(self, compact: bool = True) -> str:
             gpu_monitor_str = "monitored" if self.gpu_monitoring else "unmonitored"
             dimensions_str = f"b{self.batch_size}_s{self.sequence_length}_n{self.num_tokens_to_generate}"
             attn_code = self.attn_implementation
-            compile_str = f"compiled_{self.compile_mode}" if self.compile_mode is not None else "uncompiled"
+            compile_str = f"compiled_{self.compile_config.mode}" if self.compile_config is not None else "uncompiled"
             kernelize_str = "kernelized" if self.kernelize else "unkernelized"
             continuous_batching_str = "cb" if self.continuous_batching else "generate"
             sep = "-"

@@ -129,7 +151,7 @@ def infer_name(self, compact: bool = True) -> str:
             gpu_monitor_str = ("with" if self.gpu_monitoring else "no") + " GPU monitoring"
             dimensions_str = f"batch size {self.batch_size}, sequence length {self.sequence_length}, {self.num_tokens_to_generate} generated tokens"
             attn_code = f"{self.attn_implementation} attention"
-            compile_str = "compiled" if self.compile_mode is not None else "not compiled"
+            compile_str = "compiled" if self.compile_config is not None else "not compiled"
             kernelize_str = "kernelized" if self.kernelize else "not kernelized"
             continuous_batching_str = "continuous batching" if self.continuous_batching else "regular generate"
             sep = ", "

@@ -148,8 +170,7 @@ def to_dict(self) -> dict[str, Any]:
             "sequence_length": self.sequence_length,
             "num_tokens_to_generate": self.num_tokens_to_generate,
             "attn_implementation": self.attn_implementation,
-            "compile_mode": self.compile_mode,
-            "compile_options": self.compile_options | {},  # to avoid inplace modification of the original dict
+            "compile_kwargs": self.compile_config.to_dict() if self.compile_config is not None else None,
             "kernelize": self.kernelize,
         }

@@ -164,8 +185,7 @@ def from_dict(cls, data: dict[str, Any], skip_validity_check: bool = False) -> "BenchmarkConfig":
             sequence_length=data.get("sequence_length", 128),
             num_tokens_to_generate=data.get("num_tokens_to_generate", 128),
             attn_implementation=data.get("attn_implementation", "eager"),
-            compile_mode=data.get("compile_mode"),
-            compile_options=data.get("compile_options"),
+            compile_kwargs=data.get("compile_kwargs"),
             kernelize=data.get("kernelize", False),
             name=data.get("name"),
             skip_validity_check=skip_validity_check,

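A round-trip sketch, reusing cfg from the first example: compile settings now serialize under a single "compile_kwargs" key via CompileConfig.to_dict(), which from_dict() feeds straight back into the constructor:

    data = cfg.to_dict()
    assert data["compile_kwargs"]["mode"] == "max-autotune-no-cudagraphs"
    restored = BenchmarkConfig.from_dict(data, skip_validity_check=True)
    assert restored.compile_config.mode == cfg.compile_config.mode
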
@@ -218,27 +238,28 @@ def get_config_by_level(level: int) -> list[BenchmarkConfig]:
         # Usually there is not much to gain by compiling with other modes, but we allow it for level 4
         compile_modes = BenchmarkConfig.all_compiled_modes if level >= 4 else [None, "default"]
         for cm in compile_modes:
+            compile_kwargs = {"mode": cm} if cm is not None else None
             for kernelize_on in {False, KERNELIZATION_AVAILABLE}:
                 for cb_on in [False, True]:
                     configs.append(
                         BenchmarkConfig(
                             attn_implementation=attn_implementation,
-                            compile_mode=cm,
+                            compile_kwargs=compile_kwargs,
                             kernelize=kernelize_on,
                             continuous_batching=cb_on,
                         )
                     )
         return configs
     # Otherwise, we add the configs for the given level
     if level >= 0:
-        configs.append(BenchmarkConfig(attn_implementation="flex_attention", compile_mode="default"))
+        configs.append(BenchmarkConfig(attn_implementation="flex_attention", compile_kwargs={}))
     if level >= 1:
         configs.append(BenchmarkConfig(attn_implementation="flash_attention_2"))
-        configs.append(BenchmarkConfig(attn_implementation="eager", compile_mode="default"))
+        configs.append(BenchmarkConfig(attn_implementation="eager", compile_kwargs={}))
         configs.append(BenchmarkConfig(attn_implementation="flash_attention_2", continuous_batching=True))
     if level >= 2:
-        configs.append(BenchmarkConfig(attn_implementation="sdpa", compile_mode="default"))
-        configs.append(BenchmarkConfig(attn_implementation="flex_attention", compile_mode="default", kernelize=True))
+        configs.append(BenchmarkConfig(attn_implementation="sdpa", compile_kwargs={}))
+        configs.append(BenchmarkConfig(attn_implementation="flex_attention", compile_kwargs={}, kernelize=True))
         configs.append(BenchmarkConfig(attn_implementation="flash_attention_2", kernelize=True))
         configs.append(BenchmarkConfig(attn_implementation="sdpa", continuous_batching=True))
     return configs
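
A hedged sketch of consuming the level grid (get_config_by_level appears to live in the same module; treat the import as an assumption):

    from benchmark_v2.framework.benchmark_config import get_config_by_level  # assumed location

    # Every generated entry now goes through the compile_kwargs path;
    # level >= 4 additionally sweeps all compile modes.
    for config in get_config_by_level(level=1):
        print(config.infer_name(compact=True))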
