
Commit 72f4577

Clean up code comments
1 parent 6d574af commit 72f4577

5 files changed, +32 -42 lines

benchmarks/cutlass_benchmarks/utils.py

Lines changed: 0 additions & 5 deletions
@@ -52,11 +52,6 @@ def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
     a = torch.randn((m, k), device='cuda') * 5
     b = torch.randn((n, k), device='cuda').t() * 5
 
-    # # Initialize a to all ones
-    # a = torch.ones((m, k), device='cuda')
-    # # Initialize b to all ones
-    # b = torch.ones((n, k), device='cuda')
-
     b = prune_to_2_4(b.t()).t()
 
     if dtype == torch.int8:
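
For context, prune_to_2_4 enforces the 2:4 structured-sparsity pattern (at most two nonzeros in every contiguous group of four elements) that the sparse kernels in this commit require. A minimal sketch of that pattern, assuming the helper keeps the two largest-magnitude values per group and that the tensor's element count is divisible by 4 (prune_to_2_4_sketch is a hypothetical name, not the benchmark's implementation):

import torch

def prune_to_2_4_sketch(t: torch.Tensor) -> torch.Tensor:
    # Group elements in fours along the last dim and zero the two
    # smallest-magnitude entries of each group (assumes numel % 4 == 0).
    g = t.reshape(-1, 4)
    _, idx = torch.topk(g.abs(), k=2, dim=1, largest=False)
    mask = torch.ones_like(g, dtype=torch.bool)
    mask.scatter_(1, idx, False)
    return (g * mask).reshape(t.shape)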

csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu

Lines changed: 26 additions & 26 deletions
@@ -297,7 +297,8 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
   }
 }
 
-void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
+void cutlass_scaled_sparse_mm_sm90(torch::Tensor& out,
+                                   torch::Tensor const& a,
                                    torch::Tensor const& e,
                                    torch::Tensor const& b,
                                    torch::Tensor const& a_scales,
@@ -306,36 +307,35 @@ void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
   if (bias) {
-    TORCH_CHECK(bias->dtype() == c.dtype(),
-                "currently bias dtype must match output dtype ", c.dtype());
+    TORCH_CHECK(bias->dtype() == out.dtype(),
+                "currently bias dtype must match output dtype ", out.dtype());
     return cutlass_scaled_sparse_mm_sm90_epilogue<ScaledEpilogueBias>(
-        c, a, e, b, a_scales, b_scales, *bias);
+        out, a, e, b, a_scales, b_scales, *bias);
   } else {
     return cutlass_scaled_sparse_mm_sm90_epilogue<ScaledEpilogue>(
-        c, a, e, b, a_scales, b_scales);
+        out, a, e, b, a_scales, b_scales);
   }
 }
 
-// void cutlass_scaled_sparse_mm_azp_sm90(torch::Tensor& out, torch::Tensor
-// const& a,
-//                                 torch::Tensor const& e,
-//                                 torch::Tensor const& b,
-//                                 torch::Tensor const& a_scales,
-//                                 torch::Tensor const& b_scales,
-//                                 torch::Tensor const& azp_adj,
-//                                 c10::optional<torch::Tensor> const& azp,
-//                                 c10::optional<torch::Tensor> const& bias) {
-//   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
-//   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
-
-//   if (azp) {
-//     return
-//     cutlass_scaled_sparse_mm_sm90_epilogue<ScaledEpilogueBiasAzpToken>(
-//         out, a, e, b, a_scales, b_scales, azp_adj, *azp, bias);
-//   } else {
-//     return cutlass_scaled_sparse_mm_sm90_epilogue<ScaledEpilogueBiasAzp>(
-//         out, a, e, b, a_scales, b_scales, azp_adj, bias);
-//   }
-// }
+void cutlass_scaled_sparse_mm_azp_sm90(torch::Tensor& out,
+                                       torch::Tensor const& a,
+                                       torch::Tensor const& e,
+                                       torch::Tensor const& b,
+                                       torch::Tensor const& a_scales,
+                                       torch::Tensor const& b_scales,
+                                       torch::Tensor const& azp_adj,
+                                       c10::optional<torch::Tensor> const& azp,
+                                       c10::optional<torch::Tensor> const& bias) {
+  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+
+  if (azp) {
+    return cutlass_scaled_sparse_mm_sm90_epilogue<ScaledEpilogueBiasAzpToken>(
+        out, a, e, b, a_scales, b_scales, azp_adj, *azp, bias);
+  } else {
+    return cutlass_scaled_sparse_mm_sm90_epilogue<ScaledEpilogueBiasAzp>(
+        out, a, e, b, a_scales, b_scales, azp_adj, bias);
+  }
+}
 
 #endif
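
The newly enabled azp path handles asymmetric zero-point quantization: for an activation quantized as q = round(x / s) + azp, the product (q - azp) @ b equals q @ b minus a precomputed correction azp_adj (the zero point times the column sums of b), which the ScaledEpilogueBiasAzp epilogues fold into the GEMM. A hedged numeric sketch of that identity, with illustrative names and shapes rather than the kernel's actual API:

import torch

m, k, n = 4, 8, 3
q = torch.randint(0, 255, (m, k)).float()  # quantized activation (emulated)
b = torch.randn(k, n)
azp = 128.0                                # example per-tensor zero point

# Correction term: zero point times the column sums of b.
azp_adj = azp * b.sum(dim=0)

# q @ b - azp_adj equals (q - azp) @ b, so the epilogue can apply the
# zero-point correction without materializing q - azp.
assert torch.allclose(q @ b - azp_adj, (q - azp) @ b, atol=1e-3)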

csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh

Lines changed: 0 additions & 5 deletions
@@ -364,8 +364,6 @@ struct cutlass_3x_gemm {
   using ElementAB = ElementAB_;
   using ElementD = ElementD_;
   using ElementAcc = AccType;
-  // typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
-  // float>::type;
 
   using EpilogueDescriptor =
       cutlass::epilogue::collective::detail::EpilogueDescriptor<
@@ -432,9 +430,6 @@ void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
   int64_t ldb = b.stride(1);
   int64_t ldc = out.stride(1);
 
-  // using StrideB = Stride<int64_t, Int<1>, int64_t>;
-  // using StrideC = typename Gemm::StrideC;
-
   using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
   using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;
   using StrideB = typename Gemm::GemmKernel::StrideB;
csrc/sparse/cutlass/sparse_scaled_mm_entry.cu

Lines changed: 2 additions & 2 deletions
@@ -52,8 +52,8 @@ void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
 
   // Check for strides and alignment
   TORCH_CHECK(a.stride(1) == 1);  // Row-major
-  // TORCH_CHECK(b.stride(0) == 1 && c.stride(0) == 1);  // Column-major
-  // TORCH_CHECK(c.stride(0) % 16 == 0);  // 16 Byte Alignment
+  TORCH_CHECK(b.stride(0) == 1 && c.stride(0) == 1);  // Column-major
+  TORCH_CHECK(c.stride(1) % 16 == 0);  // 16 Byte Alignment
   TORCH_CHECK(b.stride(1) % 16 == 0);  // 16 Byte Alignment
   TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
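
These re-enabled checks pin down the layouts the kernel expects: a row-major A and column-major B/C whose leading dimensions are multiples of 16 elements (16-byte alignment for 1-byte element types such as int8). A quick way to see which tensors pass, under those assumptions:

import torch

a = torch.empty(128, 64, dtype=torch.int8)      # row-major [M, K]
b = torch.empty(64, 256, dtype=torch.int8).t()  # column-major, shape (256, 64)

assert a.stride(1) == 1          # row-major: unit stride along the last dim
assert b.stride(0) == 1          # column-major: unit stride along the first dim
assert b.stride(1) % 16 == 0     # leading dim 64 -> 16-byte aligned at 1B/elem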

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py

Lines changed: 4 additions & 4 deletions
@@ -163,11 +163,11 @@ def apply_weights(self,
         input_scale = layer.input_scale
         q_input = x
 
-        out = ops.cutlass_scaled_sparse_mm(a=layer.weight,
+        out = ops.cutlass_scaled_sparse_mm(a=q_input,
+                                           b=layer.weight,
                                            e=layer.meta,
-                                           b=q_input.t(),
-                                           scale_a=layer.weight_scale,
-                                           scale_b=input_scale,
+                                           scale_a=input_scale,
+                                           scale_b=layer.weight_scale,
                                            out_dtype=self.output_dtype,
                                            bias=bias)
         assert out.is_contiguous()
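
The fix puts each operand in its natural slot: the quantized activation is now a, the 2:4 compressed weight (with its metadata e) is b, and each scale travels with its own operand. In dense terms the call computes roughly out = (a * scale_a) @ (b * scale_b) + bias; a hedged reference of those semantics, as a dense emulation with an assumed name rather than the real kernel:

import torch

def scaled_sparse_mm_reference(a, b, scale_a, scale_b, out_dtype, bias=None):
    # Dense emulation of the scaled matmul: dequantize both operands,
    # multiply, then add the optional bias and cast to the output dtype.
    out = (a.float() * scale_a) @ (b.float() * scale_b)
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)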
