
Commit 598ee70

[Performance][Hopper] Avoid M dim padding to 4x for most cases (due to cuda graphs padding)
Signed-off-by: Alexander Matveev <[email protected]>
1 parent df4d3a4 commit 598ee70

1 file changed

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 21 additions & 14 deletions
@@ -115,20 +115,27 @@ def _padded_cutlass(
         dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple)
     )
 
-    padded_shape = [padded, *qx.shape[1:]]
-    padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype)
-    padded_qx[0 : qx.shape[0], ...].copy_(qx)
-
-    padded_x_scale_shape = [*x_scale.shape[1:], padded]
-    padded_x_scale = torch.ones(
-        padded_x_scale_shape, device=x_scale.device, dtype=x_scale.dtype
-    ).permute(-1, -2)
-    padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
-
-    output = cutlass_scaled_mm(
-        padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype
-    )
-    return output[0 : qx.shape[0], ...]
+    has_pad = padded > dim
+
+    if has_pad:
+        padded_shape = [padded, *qx.shape[1:]]
+        padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype)
+        padded_qx[0 : qx.shape[0], ...].copy_(qx)
+
+        padded_x_scale_shape = [*x_scale.shape[1:], padded]
+        padded_x_scale = torch.ones(
+            padded_x_scale_shape, device=x_scale.device, dtype=x_scale.dtype
+        ).permute(-1, -2)
+        padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
+
+        output = cutlass_scaled_mm(
+            padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype
+        )
+        return output[0 : qx.shape[0], ...]
+    else:
+        return cutlass_scaled_mm(
+            qx, weight, x_scale, weight_scale, block_size, output_dtype
+        )
 
 
 def _padded_cutlass_fake(
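Since CUDA graph capture already pads batch sizes (as the commit title notes), M is usually a multiple of 4 and the new has_pad check makes the zero-fill-and-copy path the exception rather than the rule. A minimal standalone sketch of the round-up arithmetic and the fast-path condition; round_up and the toy shapes are illustrative, not part of the vLLM API:

import torch

def round_up(dim: int, pad_multiple: int) -> int:
    # Same expression the diff uses to compute `padded`.
    return dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple)

# M already a multiple of 4: has_pad is False, no padded copies are made.
qx = torch.randn(16, 128)
padded = round_up(qx.shape[0], 4)
assert not (padded > qx.shape[0])

# M not a multiple of 4: the original padded path still runs.
assert round_up(13, 4) == 16

On the fast path the inputs go to cutlass_scaled_mm unmodified, saving two allocations and two device copies per call.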

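One layout detail that this commit leaves unchanged: padded_x_scale is allocated with its dimensions reversed and then permuted back, which yields a column-major buffer. A small sketch of the resulting strides; the shapes here are made up, and the assumption that the CUTLASS block-scaled kernel expects column-major scale factors is mine, not stated in the diff:

import torch

padded_m, n_blocks = 20, 8
scale = torch.ones(n_blocks, padded_m).permute(-1, -2)
print(scale.shape)     # torch.Size([20, 8])
print(scale.stride())  # (1, 20) -- column-major: rows are adjacent in memory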