
 SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
 QUANTIZATION_SCHEME_MAP_TYPE = Dict[str, Optional[Dict[str, QuantizationArgs]]]
+
+
 class CompressedTensorsConfig(QuantizationConfig):

-    def __init__(self,
-                 target_scheme_map: Dict[str, Any],
-                 ignore: List[str],
-                 quant_format: str,
-                 kv_cache_scheme: Optional[Dict[str, Any]] = None,
-                 sparsity_scheme_map: Optional[Dict[str, SparsityCompressionConfig]] = None,
-                 config: Optional[Dict[str, Any]] = None,
-                 ):
+    def __init__(
+        self,
+        target_scheme_map: Dict[str, Any],
+        ignore: List[str],
+        quant_format: str,
+        kv_cache_scheme: Optional[Dict[str, Any]] = None,
+        sparsity_scheme_map: Optional[Dict[str,
+                                           SparsityCompressionConfig]] = None,
+        config: Optional[Dict[str, Any]] = None,
+    ):

         self.ignore = ignore
         self.quant_format = quant_format
@@ -92,8 +96,10 @@ def get_quant_method(
     def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
         ignore: List[str] = cast(List[str], config.get("ignore", []))
         quant_format = cast(str, config.get("format"))
-        target_scheme_map = cls._quantization_scheme_map_from_config(config=config)
-        sparsity_scheme_map = cls._sparsity_scheme_map_from_config(config=config)
+        target_scheme_map = cls._quantization_scheme_map_from_config(
+            config=config)
+        sparsity_scheme_map = cls._sparsity_scheme_map_from_config(
+            config=config)

         return cls(
             target_scheme_map=target_scheme_map,
@@ -102,26 +108,30 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
             sparsity_scheme_map=sparsity_scheme_map,
             config=config,
         )
-
+
     @classmethod
-    def _sparsity_scheme_map_from_config(cls, config: Dict[str, Any]) -> Dict[str, SparsityCompressionConfig]:
+    def _sparsity_scheme_map_from_config(
+            cls, config: Dict[str,
+                              Any]) -> Dict[str, SparsityCompressionConfig]:
         """
         :param config: The `quantization_config` dictionary from config.json
         :return: A dictionary mapping target layer names to their corresponding
             sparsity compression configurations
         """
-        if (sparsity_config := config.get(SPARSITY_CONFIG_NAME)) is None:
+        if (sparsity_config := config.get(SPARSITY_CONFIG_NAME)) is None:
             return dict()
-
-        sparsity_config = SparsityCompressionConfig.model_validate(sparsity_config)
+
+        sparsity_config = SparsityCompressionConfig.model_validate(
+            sparsity_config)
         sparse_scheme_map: Dict[str, SparsityCompressionConfig] = {
             target: sparsity_config
             for target in sparsity_config.targets or list()
         }
         return sparse_scheme_map

     @classmethod
-    def _quantization_scheme_map_from_config(cls, config: Dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
+    def _quantization_scheme_map_from_config(
+            cls, config: Dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
         """
         :param config: The `quantization_config` dictionary from config.json
         :return: A dictionary mapping target layer names to their corresponding
@@ -144,7 +154,8 @@ def _quantization_scheme_map_from_config(cls, config: Dict[str, Any]) -> QUANTIZ
             targets = quant_config.get("targets")
             for target in targets:
                 target_scheme_map[target] = {}
-                target_scheme_map[target]["weights"] = QuantizationArgs.model_validate(
+                target_scheme_map[target][
+                    "weights"] = QuantizationArgs.model_validate(
                         quant_config.get("weights"))

                 target_scheme_map[target]["input_activations"] = None
@@ -158,7 +169,8 @@ def _quantization_scheme_map_from_config(cls, config: Dict[str, Any]) -> QUANTIZ
                         assert target_scheme_map[target][
                             "weights"].type == QuantizationType.FLOAT
                     else:
-                        target_scheme_map[target]["input_activations"] = QuantizationArgs.model_validate(
+                        target_scheme_map[target][
+                            "input_activations"] = QuantizationArgs.model_validate(
                                 quant_config.get("input_activations"))
         return target_scheme_map

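For context, here is a rough sketch (not part of the diff) of the kind of `quantization_config` dictionary these two classmethods consume and the maps they derive from it. The field names and values below are illustrative assumptions, not taken from a real checkpoint:

    # Hypothetical config.json fragment (assumed shape, for illustration only).
    example_config = {
        "format": "float-quantized",   # assumed format string
        "ignore": ["lm_head"],
        "config_groups": {
            "group_0": {
                "targets": ["Linear"],
                "weights": {"num_bits": 8, "type": "float",
                            "strategy": "channel"},
                "input_activations": {"num_bits": 8, "type": "float",
                                      "strategy": "token"},
            }
        },
        "sparsity_config": {
            "format": "dense",
            "sparsity_structure": "2:4",
            "targets": ["Linear"],
        },
    }

    # _quantization_scheme_map_from_config(config=example_config) would map
    # "Linear" to validated QuantizationArgs for weights/input_activations;
    # _sparsity_scheme_map_from_config(config=example_config) would map
    # "Linear" to the validated SparsityCompressionConfig.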
@@ -359,7 +371,7 @@ def get_scheme(
         # TODO (@robertgshaw): add compressed-tensors as dep
         # so we do not have to re-write these functions
         # need to make accelerate optional in ct to do this
-
+
         matched_target = find_matched_target(
             layer_name=layer_name,
             module=layer,
@@ -369,42 +381,37 @@ def get_scheme(
         weight_quant = scheme_dict.get("weights")
         input_quant = scheme_dict.get("input_activations")

-        sparsity_scheme: Optional[SparsityCompressionConfig] = self.sparsity_scheme_map.get(matched_target)
-
-        if self.supports_cutlass_24(
-            weight_quant=weight_quant,
-            input_quant=input_quant,
-            sparsity_scheme=sparsity_scheme
-        ):
-            # Have a valid sparsity scheme and the layer is supported by the Cutlass 2:4 Kernel
-            needs_decompression = sparsity_scheme.format != CompressionFormat.dense.value
-            is_quantized = weight_quant is not None or input_quant is not None
-            scheme = CompressedTensors24(
-                layer_name=layer_name,
-                quantized=is_quantized,
-                do_decompress=needs_decompression,
-                weight_quant=weight_quant,
-                input_quant=input_quant,
-                config=self.config,
-            )
+        sparsity_scheme: Optional[
+            SparsityCompressionConfig] = self.sparsity_scheme_map.get(
+                matched_target)
+
+        if self.supports_cutlass_24(weight_quant=weight_quant,
+                                    input_quant=input_quant,
+                                    sparsity_scheme=sparsity_scheme):
+            # Have a valid sparsity scheme
+            # Validate layer is supported by Cutlass 2:4 Kernel
+            scheme = CompressedTensors24(quantized=weight_quant is not None
+                                         or input_quant is not None,
+                                         weight_quant=weight_quant,
+                                         input_quant=input_quant)
         else:
-            # Find the quant_scheme
+            # Find the quant_scheme
             scheme = self._get_scheme_from_parts(
                 weight_quant=weight_quant,
                 input_quant=input_quant,
-            )
+            )

         # Raise error if device does not support the scheme
         # (e.g. fp8 needs ada lovelace)
         self._check_scheme_supported(scheme.get_min_capability())
         return scheme
-
+
     @staticmethod
     def supports_cutlass_24(
-        weight_quant: Optional[QuantizationArgs],
-        input_quant: Optional[QuantizationArgs],
-        sparsity_scheme: Optional[SparsityCompressionConfig]=None
-    ) -> bool:
+            weight_quant: Optional[QuantizationArgs],
+            input_quant: Optional[QuantizationArgs],
+            sparsity_scheme: Optional[SparsityCompressionConfig] = None
+    ) -> bool:
         """
         Check if the layer is supported by the Cutlass 2:4 Kernel
         Conditions:
@@ -418,39 +425,37 @@ def supports_cutlass_24(
         :return: True if the layer is supported by the Cutlass 2:4 Kernel
             False otherwise
         """
-
-        if (
-            sparsity_scheme is None or
-            sparsity_scheme.sparsity_structure != SparsityStructure.TWO_FOUR.value
-        ):
+        is_valid_sparsity = (sparsity_scheme is not None
+                             and sparsity_scheme.sparsity_structure
+                             == SparsityStructure.TWO_FOUR.value
+                             and sparsity_scheme.format == "dense")
+        if not is_valid_sparsity:
             return False
-
+
         # Unquantized cases are supported
         if weight_quant is None and input_quant is None:
             return True
-
+
         # Weight only quantization is not-supported
         if weight_quant is not None and input_quant is None:
             return False
-
+
         supported_weight_quant_strategies = [
             QuantizationStrategy.TENSOR.value,
             QuantizationStrategy.CHANNEL.value
         ]

         if weight_quant.strategy not in supported_weight_quant_strategies:
             return False
-
+
         supported_input_quant_strategies = [
-            QuantizationStrategy.TENSOR.value,
-            QuantizationStrategy.TOKEN.value
+            QuantizationStrategy.TENSOR.value, QuantizationStrategy.TOKEN.value
         ]
-
+
         if input_quant.strategy not in supported_input_quant_strategies:
             return False
-
-        return weight_quant.num_bits == input_quant.num_bits == 8

+        return weight_quant.num_bits == input_quant.num_bits == 8


 class CompressedTensorsLinearMethod(LinearMethodBase):
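As a quick sketch of the gate that `supports_cutlass_24` implements (not part of the diff; the constructor arguments below are assumptions about the compressed-tensors pydantic models, shown only to illustrate the conditions):

    # Hypothetical objects; field names assumed from compressed-tensors.
    w8 = QuantizationArgs(num_bits=8, type="float", strategy="channel")
    a8 = QuantizationArgs(num_bits=8, type="float", strategy="token")
    sparse_24 = SparsityCompressionConfig(format="dense",
                                          sparsity_structure="2:4",
                                          targets=["Linear"])

    # Dense-format 2:4 sparsity + 8-bit weights and activations -> True
    CompressedTensorsConfig.supports_cutlass_24(weight_quant=w8,
                                                input_quant=a8,
                                                sparsity_scheme=sparse_24)

    # Weight-only quantization is rejected -> False
    CompressedTensorsConfig.supports_cutlass_24(weight_quant=w8,
                                                input_quant=None,
                                                sparsity_scheme=sparse_24)

    # No (or non-2:4) sparsity scheme -> False
    CompressedTensorsConfig.supports_cutlass_24(weight_quant=w8,
                                                input_quant=a8,
                                                sparsity_scheme=None)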