21 | 21 | "weight_quant_matmul", |
22 | 22 | "fused_moe", |
23 | 23 | "linear", |
| 24 | + "dynamic_quant", |
| 25 | + "linear_w8a8", |
| 26 | + "rms_norm_w8a8", |
| 27 | + "add_rms_norm_w8a8", |
24 | 28 | ] |
25 | 29 | |
26 | 30 | |
@@ -58,7 +62,7 @@ def apply_rotary_pos_emb( |
58 | 62 | cos_sin_cache: Optional[Tensor], |
59 | 63 | ) -> Tuple[Tensor, Tensor]: |
60 | 64 | """ |
61 | | - Applies rotary position embeddings to the query and key tensors. |
| 65 | + Apply rotary position embeddings to the query and key tensors. |
62 | 66 | |
63 | 67 | Rotary position embedding is a method of embedding positional information into |
64 | 68 | self-attention computations without increasing the model size. |
@@ -613,3 +617,110 @@ def linear( |
613 | 617 | Tensor: The output tensor of linear computation. |
614 | 618 | """ |
615 | 619 | return vendor_ops_registry["linear"](x, weight, bias, all_reduce) |
| 620 | + |
| 621 | + |
| 622 | +def dynamic_quant( |
| 623 | + x: Tensor, quant_dtype: torch.dtype, quant_granularity: str = "PER_TOKEN" |
| 624 | +) -> Tuple[Tensor, float]: |
| 625 | + """ |
| 626 | + Perform dynamic quantization on a tensor. |
| 627 | + |
| 628 | + Args: |
| 629 | + x (Tensor): The input tensor to be quantized. |
| 630 | + quant_dtype (torch.dtype): The data type to which the tensor should be quantized. |
| 631 | + quant_granularity (str, optional): The granularity of quantization. Defaults to "PER_TOKEN". |
| 632 | + Options include: |
| 633 | + - "PER_TOKEN": Quantize each token (row) independently. |
| 634 | + - "PER_CHANNEL": Quantize each channel independently. |
| 635 | + - "PER_TENSOR": Quantize the entire tensor as a whole. |
| 636 | + |
| 637 | + Returns: |
| 638 | + Tuple[Tensor, float]: A tuple containing: |
| 639 | + - The quantized tensor. |
| 640 | + - The scaling factor used during quantization. |
| 641 | + |
| 642 | + """ |
| 643 | + return vendor_ops_registry["dynamic_quant"](x, quant_dtype, quant_granularity) |
| 644 | + |
| 645 | + |
| 646 | +def linear_w8a8( |
| 647 | + a: Tensor, |
| 648 | + b: Tensor, |
| 649 | + rms_scale: float, |
| 650 | + linear_scale: float, |
| 651 | + out_dtype: torch.dtype, |
| 652 | + quant_dtype: torch.dtype, |
| 653 | + bias: Tensor, |
| 654 | +) -> Tensor: |
| 655 | + """ |
| 656 | + Perform a linear transformation on two quantized input tensors. |
| 657 | + |
| 658 | + Args: |
| 659 | + a (Tensor): The first quantized input tensor. |
| 660 | + b (Tensor): The second quantized input tensor. |
| 661 | + rms_scale (float): The quantization scale of a (typically the activation scale returned by rms_norm_w8a8). |
| 662 | + linear_scale (float): The quantization scale of b. |
| 663 | + out_dtype (torch.dtype): The target data type for the output tensor. |
| 664 | + quant_dtype (torch.dtype): The data type of the quantized input tensors. |
| 665 | + bias (Tensor): The bias tensor to be added to the output. |
| 666 | + |
| 667 | + Returns: |
| 668 | + Tensor: The output tensor after applying the linear transformation. |
| 669 | + """ |
| 670 | + return vendor_ops_registry["linear_w8a8"]( |
| 671 | + a, b, rms_scale, linear_scale, out_dtype, quant_dtype, bias |
| 672 | + ) |
| 673 | + |
| 674 | + |
| 675 | +def rms_norm_w8a8( |
| 676 | + hidden_states: Tensor, |
| 677 | + weight: Tensor, |
| 678 | + epsilon: float, |
| 679 | + quant_dtype: torch.dtype, |
| 680 | +) -> Tuple[Tensor, float]: |
| 681 | + """ |
| 682 | + Apply RMS normalization to the input tensor and quantize the result. |
| 683 | + |
| 684 | + Args: |
| 685 | + hidden_states (Tensor): The input tensor to be normalized and quantized. |
| 686 | + weight (Tensor): The scaling weight applied to the normalized tensor. |
| 687 | + epsilon (float): A value added to the denominator for numerical stability during normalization. |
| 688 | + quant_dtype (torch.dtype): The target data type for the quantized result. |
| 689 | + |
| 690 | + Returns: |
| 691 | + Tuple[Tensor, float]: A tuple containing: |
| 692 | + - The RMS-normalized and quantized tensor. |
| 693 | + - The scaling factor used during quantization. |
| 694 | + """ |
| 695 | + return vendor_ops_registry["rms_norm_w8a8"]( |
| 696 | + hidden_states, weight, epsilon, quant_dtype |
| 697 | + ) |
| 698 | + |
| 699 | + |
| 700 | +def add_rms_norm_w8a8( |
| 701 | + hidden_states: Tensor, |
| 702 | + residual: Tensor, |
| 703 | + weight: Tensor, |
| 704 | + epsilon: float, |
| 705 | + quant_dtype: torch.dtype, |
| 706 | +) -> Tuple[Tensor, float, Tensor]: |
| 707 | + """ |
| 708 | + Apply RMS normalization to the input tensor, add a residual connection, |
| 709 | + and quantize the result. |
| 710 | + |
| 711 | + Args: |
| 712 | + hidden_states (Tensor): The input tensor to be normalized and quantized. |
| 713 | + residual (Tensor): The residual tensor to be added to the normalized tensor. |
| 714 | + weight (Tensor): The scaling weight applied to the normalized tensor. |
| 715 | + epsilon (float): A value added to the denominator for numerical stability during normalization. |
| 716 | + quant_dtype (torch.dtype): The target data type for the quantized result. |
| 717 | + |
| 718 | + Returns: |
| 719 | + Tuple[Tensor, float, Tensor]: A tuple containing: |
| 720 | + - The RMS-normalized, residual-added, and quantized tensor. |
| 721 | + - The scaling factor used during quantization. |
| 722 | + - The updated residual tensor after the addition. |
| 723 | + """ |
| 724 | + return vendor_ops_registry["add_rms_norm_w8a8"]( |
| 725 | + hidden_states, residual, weight, epsilon, quant_dtype |
| 726 | + ) |