Skip to content

Commit 020e713

Browse files
authored
[FPQuant] MXFP8 and MXFP4 backwards support (#41897)
* FP-Quant backwards * fp-quant v0.3.0 docker * availability version bump * fp_quant==0.3.1 * fp_quant v0.3.2
1 parent 371ef0f commit 020e713

File tree

5 files changed

+20
-5
lines changed

5 files changed

+20
-5
lines changed

docker/transformers-quantization-latest-gpu/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ RUN python3 -m pip uninstall -y flash-attn
8181
RUN cd transformers && python3 setup.py develop
8282

8383
# Add fp-quant for quantization testing
84-
RUN python3 -m pip install --no-cache-dir "fp-quant>=0.2.0"
84+
RUN python3 -m pip install --no-cache-dir "fp-quant>=0.3.2"
8585

8686
# Low usage or incompatible lib, will enable later on
8787

src/transformers/integrations/fp_quant.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ def adapt_fp_quant_config(config: FPQuantConfig):
3535

3636
if config.backward_dtype == "bf16":
3737
backward_dtype = FPQuantDtype.BF16
38+
elif config.backward_dtype == "mxfp8":
39+
backward_dtype = FPQuantDtype.MXFP8
40+
elif config.backward_dtype == "mxfp4":
41+
backward_dtype = FPQuantDtype.MXFP4
3842
else:
3943
raise ValueError(f"Unsupported backward dtype: {config.backward_dtype}")
4044

src/transformers/utils/import_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -973,13 +973,13 @@ def is_quark_available() -> bool:
973973
@lru_cache
974974
def is_fp_quant_available():
975975
is_available, fp_quant_version = _is_package_available("fp_quant", return_version=True)
976-
return is_available and version.parse(fp_quant_version) >= version.parse("0.2.0")
976+
return is_available and version.parse(fp_quant_version) >= version.parse("0.3.2")
977977

978978

979979
@lru_cache
980980
def is_qutlass_available():
981981
is_available, qutlass_version = _is_package_available("qutlass", return_version=True)
982-
return is_available and version.parse(qutlass_version) >= version.parse("0.1.0")
982+
return is_available and version.parse(qutlass_version) >= version.parse("0.2.0")
983983

984984

985985
@lru_cache

src/transformers/utils/quantization_config.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1601,8 +1601,12 @@ def post_init(self):
16011601
else:
16021602
raise ValueError("Only 'mxfp4' and 'nvfp4' are supported for forward_dtype for now.")
16031603

1604-
if self.backward_dtype != "bf16":
1605-
raise ValueError("Only 'bf16' is supported for backward_dtype for now.")
1604+
if self.backward_dtype not in ["bf16", "mxfp8", "mxfp4"]:
1605+
raise ValueError("Only 'bf16', 'mxfp8' and 'mxfp4' are supported for backward_dtype for now.")
1606+
1607+
if self.backward_dtype != "bf16" and self.forward_dtype != "mxfp4":
1608+
raise ValueError("Only 'mxfp4' forward is compatible with non-bf16 backwards for now.")
1609+
16061610
if self.transform_init not in ["hadamard", "identity", "gsr"]:
16071611
raise ValueError("Only 'hadamard', 'identity' and 'gsr' are supported for transform_init.")
16081612

tests/quantization/fp_quant_integration/test_fp_quant.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,13 @@ def getQuantizationConfig(cls):
163163
return FPQuantConfig(forward_dtype="mxfp4", pseudoquantization=False)
164164

165165

166+
@require_qutlass
167+
class FPQuantNVFP4Test(FPQuantBaseTest):
168+
@classmethod
169+
def getQuantizationConfig(cls):
170+
return FPQuantConfig(forward_dtype="nvfp4", pseudoquantization=False)
171+
172+
166173
@require_qutlass
167174
class FPQuantMXFP4GS128Test(FPQuantBaseTest):
168175
@classmethod

0 commit comments

Comments
 (0)