
Commit f091f0c

[feature] add dlinfer w8a8 support. (#151)
1 parent 75984c0 commit f091f0c

File tree

2 files changed, +178 -1 lines changed
  dlinfer/ops/llm.py
  dlinfer/vendor/maca/maca_ops.py

dlinfer/ops/llm.py

Lines changed: 112 additions & 1 deletion
@@ -21,6 +21,10 @@
     "weight_quant_matmul",
     "fused_moe",
     "linear",
+    "dynamic_quant",
+    "linear_w8a8",
+    "rms_norm_w8a8",
+    "add_rms_norm_w8a8",
 ]

@@ -58,7 +62,7 @@ def apply_rotary_pos_emb(
     cos_sin_cache: Optional[Tensor],
 ) -> Tuple[Tensor, Tensor]:
     """
-    Applies rotary position embeddings to the query and key tensors.
+    Apply rotary position embeddings to the query and key tensors.
 
     Rotary position embedding is a method of embedding positional information into
     self-attention computations without increasing the model size.

@@ -613,3 +617,110 @@ def linear(
         Tensor: The output tensor of linear computation.
     """
     return vendor_ops_registry["linear"](x, weight, bias, all_reduce)
+
+
+def dynamic_quant(
+    x: Tensor, quant_dtype: torch.dtype, quant_granularity: str = "PER_TOKEN"
+) -> Tuple[Tensor, float]:
+    """
+    Perform dynamic quantization on a tensor.
+
+    Args:
+        x (Tensor): The input tensor to be quantized.
+        quant_dtype (torch.dtype): The data type to which the tensor should be quantized.
+        quant_granularity (str, optional): The granularity of quantization. Defaults to "PER_TOKEN".
+            Options include:
+            - "PER_TOKEN": Quantize each token independently.
+            - "PER_CHANNEL": Quantize each channel independently.
+            - "PER_TENSOR": Quantize the entire tensor as a whole.
+
+    Returns:
+        Tuple[Tensor, float]: A tuple containing:
+            - The quantized tensor.
+            - The scaling factor used during quantization.
+
+    """
+    return vendor_ops_registry["dynamic_quant"](x, quant_dtype, quant_granularity)
+
+
+def linear_w8a8(
+    a: Tensor,
+    b: Tensor,
+    rms_scale: float,
+    linear_scale: float,
+    out_dtype: torch.dtype,
+    quant_dtype: torch.dtype,
+    bias: Tensor,
+) -> Tensor:
+    """
+    Perform a linear transformation on two quantized input tensors.
+
+    Args:
+        a (Tensor): The first quantized input tensor.
+        b (Tensor): The second quantized input tensor.
+        rms_scale (float): The scaling factor for a.
+        linear_scale (float): The scaling factor for b.
+        out_dtype (torch.dtype): The target data type for the output tensor.
+        quant_dtype (torch.dtype): The data type of the quantized input tensors.
+        bias (Tensor): The bias tensor to be added to the output.
+
+    Returns:
+        Tensor: The output tensor after applying the linear transformation.
+    """
+    return vendor_ops_registry["linear_w8a8"](
+        a, b, rms_scale, linear_scale, out_dtype, quant_dtype, bias
+    )
+
+
+def rms_norm_w8a8(
+    hidden_states: Tensor,
+    weight: Tensor,
+    epsilon: float,
+    quant_dtype: torch.dtype,
+) -> Tuple[Tensor, float]:
+    """
+    Apply RMS normalization to the input tensor and quantize the result.
+
+    Args:
+        hidden_states (Tensor): The input tensor to be normalized and quantized.
+        weight (Tensor): The scaling weight applied to the normalized tensor.
+        epsilon (float): A value added to the denominator for numerical stability during normalization.
+        quant_dtype (torch.dtype): The target data type for the quantized result.
+
+    Returns:
+        Tuple[Tensor, float]: A tuple containing:
+            - The RMS-normalized and quantized tensor.
+            - The scaling factor used during quantization.
+    """
+    return vendor_ops_registry["rms_norm_w8a8"](
+        hidden_states, weight, epsilon, quant_dtype
+    )
+
+
+def add_rms_norm_w8a8(
+    hidden_states: Tensor,
+    residual: Tensor,
+    weight: Tensor,
+    epsilon: float,
+    quant_dtype: torch.dtype,
+) -> Tuple[Tensor, float, Tensor]:
+    """
+    Apply RMS normalization to the input tensor, add a residual connection,
+    and quantize the result.
+
+    Args:
+        hidden_states (Tensor): The input tensor to be normalized and quantized.
+        residual (Tensor): The residual tensor to be added to the normalized tensor.
+        weight (Tensor): The scaling weight applied to the normalized tensor.
+        epsilon (float): A value added to the denominator for numerical stability during normalization.
+        quant_dtype (torch.dtype): The target data type for the quantized result.
+
+    Returns:
+        Tuple[Tensor, float, Tensor]: A tuple containing:
+            - The RMS-normalized, residual-added, and quantized tensor.
+            - The scaling factor used during quantization.
+            - The residual tensor.
+    """
+    return vendor_ops_registry["add_rms_norm_w8a8"](
+        hidden_states, residual, weight, epsilon, quant_dtype
+    )
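
Taken together, the four new entry points cover the weight-and-activation int8 (w8a8) linear path: normalize and quantize the activations per token, run the int8 GEMM against an already-quantized weight, then fold the residual back in before the next sub-layer's norm. The sketch below shows one plausible way a decoder block could chain them; the import path, tensor shapes, device, and the pre-quantized weight names (w_q, w_scale) are illustrative assumptions, not part of this commit.

import torch
from dlinfer.ops.llm import rms_norm_w8a8, linear_w8a8, add_rms_norm_w8a8

# Assumed setup (not from the commit): fp16 activations, an int8 weight
# quantized offline, and its per-output-channel scale.
bs, seq_len, hidden = 1, 16, 4096
hidden_states = torch.randn(bs, seq_len, hidden, dtype=torch.float16, device="cuda")
residual = torch.zeros_like(hidden_states)
norm_weight = torch.ones(hidden, dtype=torch.float16, device="cuda")
w_q = torch.randint(-128, 128, (hidden, hidden), dtype=torch.int8, device="cuda")
w_scale = torch.full((hidden,), 1e-2, dtype=torch.float32, device="cuda")

# 1) RMSNorm fused with per-token int8 quantization of the activations.
x_q, act_scale = rms_norm_w8a8(hidden_states, norm_weight, 1e-6, torch.int8)

# 2) int8 x int8 matmul, rescaled to fp16 using the activation and weight scales.
out = linear_w8a8(x_q, w_q, act_scale, w_scale, torch.float16, torch.int8, None)

# 3) For the next sub-layer: add the residual, renormalize, and requantize in one call.
x_q2, act_scale2, residual = add_rms_norm_w8a8(out, residual, norm_weight, 1e-6, torch.int8)

dynamic_quant covers the same activation-quantization step on its own, for call sites where the normalization is not fused with quantization.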

dlinfer/vendor/maca/maca_ops.py

Lines changed: 66 additions & 0 deletions
@@ -27,6 +27,10 @@
     "moe_gating_topk_softmax",
     "linear",
     "weight_quant_matmul",
+    "dynamic_quant",
+    "linear_w8a8",
+    "rms_norm_w8a8",
+    "add_rms_norm_w8a8",
 ]

@@ -403,3 +407,65 @@ def weight_quant_matmul(
     if bias is not None:
         output += bias
     return output
+
+
+@register_ops(vendor_ops_registry)
+def dynamic_quant(
+    x: Tensor, quant_dtype: torch.dtype, quant_granularity: str = "PER_TOKEN"
+):
+    assert quant_dtype == torch.int8
+    assert quant_granularity == "PER_TOKEN"
+    x, input_scale, _ = vllm._custom_ops.scaled_int8_quant(x, None)
+    return x, input_scale
+
+
+@register_ops(vendor_ops_registry)
+def linear_w8a8(
+    a: Tensor,
+    b: Tensor,
+    rms_scale: float,
+    linear_scale: float,
+    out_dtype: torch.dtype,
+    quant_dtype: torch.dtype = torch.int8,
+    bias: Tensor = None,
+):
+    assert quant_dtype == torch.int8
+    bs, seq_len, head_size = a.size()
+    out = vllm._custom_ops.cutlass_scaled_mm(
+        a.view(-1, head_size),
+        b,
+        scale_a=rms_scale,
+        scale_b=linear_scale,
+        out_dtype=out_dtype,
+        bias=bias,
+    )
+    out = out.view(bs, seq_len, -1)
+    return out
+
+
+@register_ops(vendor_ops_registry)
+def rms_norm_w8a8(
+    hidden_states: Tensor,
+    weight: Tensor,
+    epsilon: float,
+    quant_dtype: torch.dtype = torch.int8,
+):
+    assert quant_dtype == torch.int8
+    x = torch.empty_like(hidden_states)
+    vllm._custom_ops.rms_norm(x, hidden_states, weight, epsilon)
+    x, input_scale, _ = vllm._custom_ops.scaled_int8_quant(x, None)
+    return x, input_scale
+
+
+@register_ops(vendor_ops_registry)
+def add_rms_norm_w8a8(
+    hidden_states: Tensor,
+    residual: Tensor,
+    weight: Tensor,
+    epsilon: float,
+    quant_dtype: torch.dtype = torch.int8,
+):
+    assert quant_dtype == torch.int8
+    vllm._custom_ops.fused_add_rms_norm(hidden_states, residual, weight, epsilon)
+    x, input_scale, _ = vllm._custom_ops.scaled_int8_quant(hidden_states, None)
+    return x, input_scale, residual
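
For reference, the fused MACA kernels above reduce to a few lines of eager PyTorch. The sketch below is an unfused equivalent of the rms_norm_w8a8 / dynamic_quant semantics (RMSNorm followed by symmetric per-token int8 quantization); the helper names and the claim that this matches the vllm._custom_ops kernels up to rounding are assumptions made for a sanity check, not code from this commit.

import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor:
    # Plain RMSNorm computed in fp32: x / sqrt(mean(x^2) + eps) * weight.
    x32 = x.float()
    variance = x32.pow(2).mean(dim=-1, keepdim=True)
    return (x32 * torch.rsqrt(variance + epsilon)).to(x.dtype) * weight

def per_token_int8_quant_ref(x: torch.Tensor):
    # Symmetric per-token quantization: one scale per token row, values saturated to int8.
    scale = x.abs().amax(dim=-1, keepdim=True).float() / 127.0
    x_q = torch.clamp(torch.round(x.float() / scale), -128, 127).to(torch.int8)
    return x_q, scale

# Unfused reference for rms_norm_w8a8(hidden_states, weight, eps, torch.int8):
#   x_q_ref, scale_ref = per_token_int8_quant_ref(rms_norm_ref(hidden_states, weight, 1e-6))
# which should agree with the fused vendor op up to int8 rounding error.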
