File tree Expand file tree Collapse file tree 5 files changed +20
-5
lines changed
docker/transformers-quantization-latest-gpu
tests/quantization/fp_quant_integration Expand file tree Collapse file tree 5 files changed +20
-5
lines changed Original file line number Diff line number Diff line change @@ -81,7 +81,7 @@ RUN python3 -m pip uninstall -y flash-attn
8181RUN cd transformers && python3 setup.py develop
8282
8383# Add fp-quant for quantization testing
84- RUN python3 -m pip install --no-cache-dir "fp-quant>=0.2.0 "
84+ RUN python3 -m pip install --no-cache-dir "fp-quant>=0.3.2 "
8585
8686# Low usage or incompatible lib, will enable later on
8787
Original file line number Diff line number Diff line change @@ -35,6 +35,10 @@ def adapt_fp_quant_config(config: FPQuantConfig):
3535
3636 if config .backward_dtype == "bf16" :
3737 backward_dtype = FPQuantDtype .BF16
38+ elif config .backward_dtype == "mxfp8" :
39+ backward_dtype = FPQuantDtype .MXFP8
40+ elif config .backward_dtype == "mxfp4" :
41+ backward_dtype = FPQuantDtype .MXFP4
3842 else :
3943 raise ValueError (f"Unsupported backward dtype: { config .backward_dtype } " )
4044
Original file line number Diff line number Diff line change @@ -973,13 +973,13 @@ def is_quark_available() -> bool:
973973@lru_cache
974974def is_fp_quant_available ():
975975 is_available , fp_quant_version = _is_package_available ("fp_quant" , return_version = True )
976- return is_available and version .parse (fp_quant_version ) >= version .parse ("0.2.0 " )
976+ return is_available and version .parse (fp_quant_version ) >= version .parse ("0.3.2 " )
977977
978978
979979@lru_cache
980980def is_qutlass_available ():
981981 is_available , qutlass_version = _is_package_available ("qutlass" , return_version = True )
982- return is_available and version .parse (qutlass_version ) >= version .parse ("0.1 .0" )
982+ return is_available and version .parse (qutlass_version ) >= version .parse ("0.2 .0" )
983983
984984
985985@lru_cache
Original file line number Diff line number Diff line change @@ -1601,8 +1601,12 @@ def post_init(self):
16011601 else :
16021602 raise ValueError ("Only 'mxfp4' and 'nvfp4' are supported for forward_dtype for now." )
16031603
1604- if self .backward_dtype != "bf16" :
1605- raise ValueError ("Only 'bf16' is supported for backward_dtype for now." )
1604+ if self .backward_dtype not in ["bf16" , "mxfp8" , "mxfp4" ]:
1605+ raise ValueError ("Only 'bf16', 'mxfp8' and 'mxfp4' are supported for backward_dtype for now." )
1606+
1607+ if self .backward_dtype != "bf16" and self .forward_dtype != "mxfp4" :
1608+ raise ValueError ("Only 'mxfp4' forward is compatible with non-bf16 backwards for now." )
1609+
16061610 if self .transform_init not in ["hadamard" , "identity" , "gsr" ]:
16071611 raise ValueError ("Only 'hadamard', 'identity' and 'gsr' are supported for transform_init." )
16081612
Original file line number Diff line number Diff line change @@ -163,6 +163,13 @@ def getQuantizationConfig(cls):
163163 return FPQuantConfig (forward_dtype = "mxfp4" , pseudoquantization = False )
164164
165165
166+ @require_qutlass
167+ class FPQuantNVFP4Test (FPQuantBaseTest ):
168+ @classmethod
169+ def getQuantizationConfig (cls ):
170+ return FPQuantConfig (forward_dtype = "nvfp4" , pseudoquantization = False )
171+
172+
166173@require_qutlass
167174class FPQuantMXFP4GS128Test (FPQuantBaseTest ):
168175 @classmethod
You can’t perform that action at this time.
0 commit comments