@@ -115,20 +115,27 @@ def _padded_cutlass(
115115 dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple )
116116 )
117117
118- padded_shape = [padded , * qx .shape [1 :]]
119- padded_qx = torch .zeros (padded_shape , device = qx .device , dtype = qx .dtype )
120- padded_qx [0 : qx .shape [0 ], ...].copy_ (qx )
121-
122- padded_x_scale_shape = [* x_scale .shape [1 :], padded ]
123- padded_x_scale = torch .ones (
124- padded_x_scale_shape , device = x_scale .device , dtype = x_scale .dtype
125- ).permute (- 1 , - 2 )
126- padded_x_scale [0 : x_scale .shape [0 ], ...].copy_ (x_scale )
127-
128- output = cutlass_scaled_mm (
129- padded_qx , weight , padded_x_scale , weight_scale , block_size , output_dtype
130- )
131- return output [0 : qx .shape [0 ], ...]
118+ has_pad = padded > dim
119+
120+ if has_pad :
121+ padded_shape = [padded , * qx .shape [1 :]]
122+ padded_qx = torch .zeros (padded_shape , device = qx .device , dtype = qx .dtype )
123+ padded_qx [0 : qx .shape [0 ], ...].copy_ (qx )
124+
125+ padded_x_scale_shape = [* x_scale .shape [1 :], padded ]
126+ padded_x_scale = torch .ones (
127+ padded_x_scale_shape , device = x_scale .device , dtype = x_scale .dtype
128+ ).permute (- 1 , - 2 )
129+ padded_x_scale [0 : x_scale .shape [0 ], ...].copy_ (x_scale )
130+
131+ output = cutlass_scaled_mm (
132+ padded_qx , weight , padded_x_scale , weight_scale , block_size , output_dtype
133+ )
134+ return output [0 : qx .shape [0 ], ...]
135+ else :
136+ return cutlass_scaled_mm (
137+ qx , weight , x_scale , weight_scale , block_size , output_dtype
138+ )
132139
133140
134141def _padded_cutlass_fake (
0 commit comments