@@ -686,7 +686,7 @@ def convert(
686686 missing_keys = None ,
687687 ** kwargs ,
688688 ) -> dict [str , torch .Tensor ]:
689- if len (input_dict ) != 2 :
689+ if len (input_dict ) < 2 :
690690 # in case of no scales, the weights are not quantized, so we return the weights as is
691691 return {
692692 full_layer_name : input_dict ["weight$" ][0 ]
@@ -702,15 +702,18 @@ def convert(
702702
703703 rows , cols = quantized .shape [- 2 :]
704704 block_size = self .hf_quantizer .quantization_config .weight_block_size
705+ if block_size is None :
706+ block_size = (quantized .shape [- 2 ], quantized .shape [- 1 ])
705707
706708 block_m , block_n = block_size
709+
707710 if rows % block_m != 0 or cols % block_n != 0 :
708711 raise ValueError (
709712 f"Matrix dimensions ({ rows } , { cols } ) must be divisible by block sizes ({ block_m } , { block_n } )."
710713 )
711-
714+ quantized = quantized . to ( scales . dtype )
712715 reshaped = quantized .reshape (- 1 , rows // block_m , block_m , cols // block_n , block_n )
713- expanded_scales = scales .to ( torch . float32 ). reshape (- 1 , rows // block_m , cols // block_n )
716+ expanded_scales = scales .reshape (- 1 , rows // block_m , cols // block_n )
714717 expanded_scales = expanded_scales .unsqueeze (- 1 ).unsqueeze (2 )
715718 dequantized = reshaped * expanded_scales
716719
0 commit comments