convert: add dequant function for compressed_tensor (kimi-k2-thinking) #17064
Changes from 4 commits
```diff
@@ -333,6 +333,38 @@ def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor)
 
             return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T
 
+        # ref: https://github.com/vllm-project/compressed-tensors/blob/52792be02ec09e59f3517104e755a02d0e003fbb/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
+        def dequant_compressed_tensor(weight: Tensor, scale: Tensor) -> Tensor:
+            weights_config = quant_config["config_groups"]["group_0"]["weights"]
+            group_size = weights_config["group_size"]
+            num_bits = weights_config["num_bits"]
+            # only tested with https://huggingface.co/moonshotai/Kimi-K2-Thinking/blob/main/config.json
+            # TODO: extend this if other configurations are needed
+            assert(group_size == 32)
+            assert(num_bits == 4)
+            assert(quant_config["format"] == "pack-quantized")
+
+            pack_factor = group_size // num_bits
+            mask = (1 << num_bits) - 1
+            unpacked = torch.zeros(
+                (weight.shape[0], weight.shape[1] * pack_factor),
+                dtype=torch.int32,
+            )
+            if self.lazy:
+                unpacked = LazyTorchTensor.from_eager(unpacked)
+            else:
+                unpacked = unpacked.to(weight.device)  # is this needed?
+            for i in range(pack_factor):
+                unpacked[:, i::pack_factor] = (weight >> (num_bits * i)) & mask
+            # TODO: may need to unpad
+            unpacked = unpacked - (mask + 1) // 2  # convert uint4 to int4 (shift scale)
+            scale = scale.to(torch.float32)
+            scale = scale.unsqueeze(2)
+            unpacked = unpacked.to(torch.float32)
+            unpacked = unpacked.reshape(-1, unpacked.shape[1] // group_size, group_size)
+            dequantized = (unpacked * scale).reshape(-1, unpacked.shape[1] * group_size)
+            return dequantized
+
         if quant_method == "bitnet":
             for name in self.model_tensors.keys():
                 if name.endswith(".weight_scale"):
```
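For context on the layout the new function decodes: each 32-bit word of a `*_packed` tensor carries `32 // num_bits` quantized values (eight 4-bit values here, which the code computes as `pack_factor = group_size // num_bits` since `group_size == 32` is asserted), with value `i` sitting at bit offset `num_bits * i`, and one scale per group of `group_size` values along the row. The snippet below is a self-contained toy round-trip of that layout, not part of the PR; shapes are made up, and int64 is used for the packed words purely to keep the toy free of signed-overflow concerns (real checkpoints store them as int32, and the bit manipulation is identical).

```python
# Toy round-trip of the pack-quantized layout (illustration only, not PR code).
import torch

num_bits, group_size = 4, 32
pack_factor = 32 // num_bits                  # 8 values per packed word
mask = (1 << num_bits) - 1                    # 0xF

w = torch.randint(-8, 8, (4, 64), dtype=torch.int64)        # int4 values in [-8, 7]
scale = torch.rand(4, 64 // group_size) + 0.5                # one scale per group of 32

# pack: shift into unsigned range, then interleave pack_factor values per word
u = w + (mask + 1) // 2
packed = torch.zeros(4, 64 // pack_factor, dtype=torch.int64)
for i in range(pack_factor):
    packed |= u[:, i::pack_factor] << (num_bits * i)

# unpack + dequantize, mirroring dequant_compressed_tensor above
unpacked = torch.zeros(4, 64, dtype=torch.int64)
for i in range(pack_factor):
    unpacked[:, i::pack_factor] = (packed >> (num_bits * i)) & mask
unpacked = unpacked - (mask + 1) // 2                        # back to signed int4
assert bool((unpacked == w).all())

deq = (unpacked.float().reshape(4, -1, group_size) * scale.unsqueeze(2)).reshape(4, 64)
print(deq.shape)  # torch.Size([4, 64])
```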
```diff
@@ -371,6 +403,22 @@ def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor)
                             ".scales",
                         )
                     ]
+        elif quant_method == "compressed-tensors":
```

Collaborator: Might want to check for …
```diff
+            for name in self.model_tensors.keys():
+                if name.endswith("_packed"):
+                    base_name = name.removesuffix("_packed")
+                    packed = self.model_tensors[base_name + "_packed"]
+                    scale = self.model_tensors[base_name + "_scale"]
+                    # TODO: use _shape for unpadding if necessary
+                    new_tensors[base_name] = lambda p=packed, s=scale: dequant_compressed_tensor(p(), s())
+                    tensors_to_remove += [
+                        base_name + n
+                        for n in (
+                            "_packed",
+                            "_scale",
+                            "_shape",
+                        )
+                    ]
         else:
             raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
 
```
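One small but important detail in the branch above: the per-tensor dequant callback is bound via lambda default arguments (`lambda p=packed, s=scale: ...`). Because `packed` and `scale` are reassigned on every loop iteration, a plain closure would make every entry dequantize the last tensors seen; default arguments freeze the values at each iteration. A generic toy illustration of the difference (names are made up, unrelated to the PR):

```python
# Late binding vs. default-argument binding for lambdas created in a loop.
late_bound = [lambda: name for name in ("a", "b", "c")]
frozen = [lambda n=name: n for name in ("a", "b", "c")]

print([f() for f in late_bound])  # ['c', 'c', 'c'] -- every closure sees the final value
print([f() for f in frozen])      # ['a', 'b', 'c'] -- defaults captured per iteration
```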
Lazy tensors don't handle `__setitem__` correctly, I think (or it causes eager evaluation). That's because the function returns `None`, and so the change tree can't really be updated with how it's currently implemented.

Prefer explicit concatenation instead if possible (like with `torch.cat`, `torch.stack`, etc.). (This should help with memory usage.)

Alternatively, there are other ways to unpack without concatenation, like the broadcasting shifts done in `gguf-py/gguf/quants.py`.
Hmm yeah, I need to go offline in the next few minutes. Feel free to push directly to this branch if you have any suggestions!
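As a rough sketch of the broadcasting-shift alternative suggested above (an illustration only, not the actual follow-up commit; it assumes `weight` is the int32 `*_packed` tensor and mirrors the shift/mask/shift-scale steps of `dequant_compressed_tensor`):

```python
import torch
from torch import Tensor


# Hypothetical helper (not from the PR): unpack packed words with a single
# broadcasted shift instead of per-slice __setitem__ writes.
def unpack_int4_broadcast(weight: Tensor, num_bits: int = 4) -> Tensor:
    pack_factor = 32 // num_bits               # values stored per 32-bit word
    mask = (1 << num_bits) - 1
    shifts = torch.arange(pack_factor, dtype=torch.int32) * num_bits  # [0, 4, ..., 28]
    # (rows, cols, 1) >> (pack_factor,) broadcasts to (rows, cols, pack_factor)
    unpacked = (weight.unsqueeze(-1) >> shifts) & mask
    unpacked = unpacked - (mask + 1) // 2      # uint4 -> int4 (shift scale)
    # flatten so nibble i of word c lands at column c * pack_factor + i,
    # matching the i::pack_factor layout of the PR's loop
    return unpacked.reshape(weight.shape[0], -1)
```

Because nothing is written in place, the unpack stays a single expression over `weight`, which should sidestep the `__setitem__` issue with lazy tensors mentioned above.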