@@ -297,7 +297,8 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
297297 }
298298}
299299
300- void cutlass_scaled_sparse_mm_sm90 (torch::Tensor& c, torch::Tensor const & a,
300+ void cutlass_scaled_sparse_mm_sm90 (torch::Tensor& out,
301+ torch::Tensor const & a,
301302 torch::Tensor const & e,
302303 torch::Tensor const & b,
303304 torch::Tensor const & a_scales,
@@ -306,36 +307,35 @@ void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
306307 TORCH_CHECK (a_scales.dtype () == torch::kFloat32 );
307308 TORCH_CHECK (b_scales.dtype () == torch::kFloat32 );
308309 if (bias) {
309- TORCH_CHECK (bias->dtype () == c .dtype (),
310- " currently bias dtype must match output dtype " , c .dtype ());
310+ TORCH_CHECK (bias->dtype () == out .dtype (),
311+ " currently bias dtype must match output dtype " , out .dtype ());
311312 return cutlass_scaled_sparse_mm_sm90_epilogue<ScaledEpilogueBias>(
312- c , a, e, b, a_scales, b_scales, *bias);
313+ out , a, e, b, a_scales, b_scales, *bias);
313314 } else {
314315 return cutlass_scaled_sparse_mm_sm90_epilogue<ScaledEpilogue>(
315- c , a, e, b, a_scales, b_scales);
316+ out , a, e, b, a_scales, b_scales);
316317 }
317318}
318319
319- // void cutlass_scaled_sparse_mm_azp_sm90(torch::Tensor& out, torch::Tensor
320- // const& a,
321- // torch::Tensor const& e,
322- // torch::Tensor const& b,
323- // torch::Tensor const& a_scales,
324- // torch::Tensor const& b_scales,
325- // torch::Tensor const& azp_adj,
326- // c10::optional<torch::Tensor> const& azp,
327- // c10::optional<torch::Tensor> const& bias) {
328- // TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
329- // TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
330-
331- // if (azp) {
332- // return
333- // cutlass_scaled_sparse_mm_sm90_epilogue<ScaledEpilogueBiasAzpToken>(
334- // out, a, e, b, a_scales, b_scales, azp_adj, *azp, bias);
335- // } else {
336- // return cutlass_scaled_sparse_mm_sm90_epilogue<ScaledEpilogueBiasAzp>(
337- // out, a, e, b, a_scales, b_scales, azp_adj, bias);
338- // }
339- // }
320+ void cutlass_scaled_sparse_mm_azp_sm90 (torch::Tensor& out,
321+ torch::Tensor const & a,
322+ torch::Tensor const & e,
323+ torch::Tensor const & b,
324+ torch::Tensor const & a_scales,
325+ torch::Tensor const & b_scales,
326+ torch::Tensor const & azp_adj,
327+ c10::optional<torch::Tensor> const & azp,
328+ c10::optional<torch::Tensor> const & bias) {
329+ TORCH_CHECK (a_scales.dtype () == torch::kFloat32 );
330+ TORCH_CHECK (b_scales.dtype () == torch::kFloat32 );
331+
332+ if (azp) {
333+ return cutlass_scaled_sparse_mm_sm90_epilogue<ScaledEpilogueBiasAzpToken>(
334+ out, a, e, b, a_scales, b_scales, azp_adj, *azp, bias);
335+ } else {
336+ return cutlass_scaled_sparse_mm_sm90_epilogue<ScaledEpilogueBiasAzp>(
337+ out, a, e, b, a_scales, b_scales, azp_adj, bias);
338+ }
339+ }
340340
341341#endif
0 commit comments