@@ -509,14 +509,13 @@ def _fbgemm_grouped_gemm_ws(
             num_tiles = num_m_tiles * NUM_N_TILES
 
             if USE_TMA_STORE:
-                with tl.async_task([0]):
-                    c_desc_ptr = tl.make_tensor_descriptor(
-                        c_ptr + M_start_offset * N,
-                        shape=[m_size, N],
-                        # pyre-ignore
-                        strides=[N, 1],
-                        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
-                    )
+                c_desc_ptr = tl.make_tensor_descriptor(
+                    c_ptr + M_start_offset * N,
+                    shape=[m_size, N],
+                    # pyre-ignore
+                    strides=[N, 1],
+                    block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+                )
 
             # Move across tiles
             next_iterated_tiles = iterated_tiles + num_tiles
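For context on the store path: `tl.make_tensor_descriptor` builds the output TMA descriptor on the device, and Triton services the descriptor's scratch allocation through a host-side hook. A minimal sketch of that host setup, assuming the `triton.set_allocator` API from recent Triton releases (the allocator signature and dtype choice are illustrative, not FBGEMM code):

```python
import torch
import triton

# Assumed host-side hook: device-side tensor descriptors need scratch
# memory, which Triton requests through a registered allocator.
def alloc_fn(size: int, alignment: int, stream):
    # Alignment is satisfied here by the CUDA caching allocator.
    return torch.empty(size, device="cuda", dtype=torch.int8)

triton.set_allocator(alloc_fn)
```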
@@ -534,72 +533,59 @@ def _fbgemm_grouped_gemm_ws(
                     m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                     n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                     for k_offset in range(0, K, BLOCK_SIZE_K):
-                        with tl.async_task([0]):
-                            a = tl._experimental_descriptor_load(
-                                a_desc_ptr,
-                                [m_offset, k_offset],
-                                [BLOCK_SIZE_M, BLOCK_SIZE_K],
-                                dtype,
-                            )
-                            b = tl._experimental_descriptor_load(
-                                b_desc_ptr,
-                                [n_offset, k_offset],
-                                [BLOCK_SIZE_N, BLOCK_SIZE_K],
-                                dtype,
-                            )
-                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                            if USE_FAST_ACCUM:
-                                accumulator = tl.dot(a, b.T, accumulator)
-                            else:
-                                accumulator += tl.dot(a, b.T)
+                        a = tl._experimental_descriptor_load(
+                            a_desc_ptr,
+                            [m_offset, k_offset],
+                            [BLOCK_SIZE_M, BLOCK_SIZE_K],
+                            dtype,
+                        )
+                        b = tl._experimental_descriptor_load(
+                            b_desc_ptr,
+                            [n_offset, k_offset],
+                            [BLOCK_SIZE_N, BLOCK_SIZE_K],
+                            dtype,
+                        )
+                        if USE_FAST_ACCUM:
+                            accumulator = tl.dot(a, b.T, accumulator)
+                        else:
+                            accumulator += tl.dot(a, b.T)
 
                     if USE_TMA_STORE:
-                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                            m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
-                            n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
-                            # pyre-ignore
-                            c_desc_ptr.store(
-                                [m_offset, n_offset],
-                                accumulator.to(c_ptr.dtype.element_ty),
-                            )
+                        m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
+                        n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
+                        # pyre-ignore
+                        c_desc_ptr.store(
+                            [m_offset, n_offset],
+                            accumulator.to(c_ptr.dtype.element_ty),
+                        )
                     elif FUSE_SCATTER_ADD:
-                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
-                                0, BLOCK_SIZE_M
-                            )
-                            mask = offs_am < m_size
-                            m_offsets = tl.load(
-                                scatter_add_indices + M_start_offset + offs_am,
-                                mask=mask,
-                                cache_modifier=".ca",
-                            )
-                            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
-                                0, BLOCK_SIZE_N
-                            )
-                            c = accumulator.to(c_ptr.dtype.element_ty)
-                            tl.atomic_add(
-                                c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
-                                c,
-                                mask=mask[:, None],
-                                sem="relaxed",
-                            )
+                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                        mask = offs_am < m_size
+                        m_offsets = tl.load(
+                            scatter_add_indices + M_start_offset + offs_am,
+                            mask=mask,
+                            cache_modifier=".ca",
+                        )
+                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                        c = accumulator.to(c_ptr.dtype.element_ty)
+                        tl.atomic_add(
+                            c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
+                            c,
+                            mask=mask[:, None],
+                            sem="relaxed",
+                        )
                     else:
-                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
-                                0, BLOCK_SIZE_M
-                            )
-                            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
-                                0, BLOCK_SIZE_N
-                            )
-                            c = accumulator.to(c_ptr.dtype.element_ty)
-                            tl.store(
-                                c_ptr
-                                + (M_start_offset + offs_am[:, None]) * N
-                                + offs_bn[None, :],
-                                c,
-                                mask=offs_am[:, None] < m_size,
-                                cache_modifier=".cs",
-                            )
+                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                        c = accumulator.to(c_ptr.dtype.element_ty)
+                        tl.store(
+                            c_ptr
+                            + (M_start_offset + offs_am[:, None]) * N
+                            + offs_bn[None, :],
+                            c,
+                            mask=offs_am[:, None] < m_size,
+                            cache_modifier=".cs",
+                        )
                     tidx += NUM_SMS
 
             iterated_tiles += num_tiles
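The three epilogues above differ only in how a finished tile leaves the accumulator: a TMA store through `c_desc_ptr`, an atomic scatter-add keyed by `scatter_add_indices`, or a plain masked `tl.store`. A rough PyTorch reference for what the `FUSE_SCATTER_ADD` path computes per group (hypothetical check code with illustrative names, not part of the kernel):

```python
import torch

def scatter_add_reference(c, a, b, indices):
    # a: [m_size, K], b: [N, K] -> partial: [m_size, N].
    partial = (a.float() @ b.float().T).to(c.dtype)
    # Row i of the group's result is added into c[indices[i]], matching
    # the kernel's tl.atomic_add at m_offsets[:, None] * N + offs_bn.
    c.index_add_(0, indices, partial)
    return c
```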
@@ -841,14 +827,13 @@ def _fbgemm_grouped_gemm_fp8_rowwise_ws(
             num_tiles = num_m_tiles * NUM_N_TILES
 
             if USE_TMA_STORE:
-                with tl.async_task([0]):
-                    c_desc_ptr = tl.make_tensor_descriptor(
-                        c_ptr + M_start_offset * N,
-                        shape=[m_size, N],
-                        # pyre-ignore
-                        strides=[N, 1],
-                        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
-                    )
+                c_desc_ptr = tl.make_tensor_descriptor(
+                    c_ptr + M_start_offset * N,
+                    shape=[m_size, N],
+                    # pyre-ignore
+                    strides=[N, 1],
+                    block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+                )
 
             # Move across tiles
             next_iterated_tiles = iterated_tiles + num_tiles
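As in the bf16 kernel, the descriptor covers only the current group's rows: the base pointer is advanced by `M_start_offset * N` elements, and the logical shape is `[m_size, N]` with row-major strides `[N, 1]`. In host-side terms it addresses the sub-matrix below (illustrative only; `c`, `total_M`, and the sizes are hypothetical):

```python
import torch

total_M, N, M_start_offset, m_size = 1024, 512, 256, 128  # example sizes
c = torch.empty(total_M, N, device="cuda", dtype=torch.bfloat16)
# The block the descriptor addresses: this group's rows of the flat output.
c_group = c[M_start_offset : M_start_offset + m_size]  # shape [m_size, N], strides (N, 1)
```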
@@ -867,107 +852,85 @@ def _fbgemm_grouped_gemm_fp8_rowwise_ws(
                     m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                     n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                     for k_offset in range(0, K, BLOCK_SIZE_K):
-                        with tl.async_task([0]):
-                            a = tl._experimental_descriptor_load(
-                                a_desc_ptr,
-                                [m_offset, k_offset],
-                                [BLOCK_SIZE_M, BLOCK_SIZE_K],
-                                dtype,
-                            )
-                            b = tl._experimental_descriptor_load(
-                                b_desc_ptr,
-                                [n_offset, k_offset],
-                                [BLOCK_SIZE_N, BLOCK_SIZE_K],
-                                dtype,
-                            )
-                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                            if USE_FAST_ACCUM:
-                                accumulator = tl.dot(a, b.T, accumulator)
-                            else:
-                                accumulator += tl.dot(a, b.T)
+                        a = tl._experimental_descriptor_load(
+                            a_desc_ptr,
+                            [m_offset, k_offset],
+                            [BLOCK_SIZE_M, BLOCK_SIZE_K],
+                            dtype,
+                        )
+                        b = tl._experimental_descriptor_load(
+                            b_desc_ptr,
+                            [n_offset, k_offset],
+                            [BLOCK_SIZE_N, BLOCK_SIZE_K],
+                            dtype,
+                        )
+                        if USE_FAST_ACCUM:
+                            accumulator = tl.dot(a, b.T, accumulator)
+                        else:
+                            accumulator += tl.dot(a, b.T)
 
                     if USE_TMA_LOAD_ON_SCALES:
-                        with tl.async_task([0]):
-                            b_scale = tl._experimental_descriptor_load(
-                                b_scale_desc_ptr,
-                                [n_offset],
-                                [BLOCK_SIZE_N],
-                                tl.float32,
-                            )
-
-                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
-                                0, BLOCK_SIZE_M
-                            )
-                            a_scale = tl.load(
-                                a_scale_ptr + M_start_offset + offs_am[:, None],
-                                mask=offs_am[:, None] < m_size,
-                                cache_modifier=".ca",
-                            )
-                            c = accumulator.to(tl.float32) * a_scale * b_scale[None, :]
+                        b_scale = tl._experimental_descriptor_load(
+                            b_scale_desc_ptr,
+                            [n_offset],
+                            [BLOCK_SIZE_N],
+                            tl.float32,
+                        )
+
+                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                        a_scale = tl.load(
+                            a_scale_ptr + M_start_offset + offs_am[:, None],
+                            mask=offs_am[:, None] < m_size,
+                            cache_modifier=".ca",
+                        )
+                        c = accumulator.to(tl.float32) * a_scale * b_scale[None, :]
                     else:
-                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
-                                0, BLOCK_SIZE_M
-                            )
-                            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
-                                0, BLOCK_SIZE_N
-                            )
-                            a_scale = tl.load(
-                                a_scale_ptr + M_start_offset + offs_am[:, None],
-                                mask=offs_am[:, None] < m_size,
-                                cache_modifier=".ca",
-                            )
-                            b_scale = tl.load(
-                                b_scale_ptr + N_start_offset + offs_bn[None, :],
-                                cache_modifier=".ca",
-                            )
-                            c = accumulator.to(tl.float32) * a_scale * b_scale
+                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                        a_scale = tl.load(
+                            a_scale_ptr + M_start_offset + offs_am[:, None],
+                            mask=offs_am[:, None] < m_size,
+                            cache_modifier=".ca",
+                        )
+                        b_scale = tl.load(
+                            b_scale_ptr + N_start_offset + offs_bn[None, :],
+                            cache_modifier=".ca",
+                        )
+                        c = accumulator.to(tl.float32) * a_scale * b_scale
 
                     if USE_TMA_STORE:
-                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                            m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
-                            n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
-                            # pyre-ignore
-                            c_desc_ptr.store(
-                                [m_offset, n_offset], c.to(c_ptr.dtype.element_ty)
-                            )
+                        m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
+                        n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
+                        # pyre-ignore
+                        c_desc_ptr.store(
+                            [m_offset, n_offset], c.to(c_ptr.dtype.element_ty)
+                        )
                     elif FUSE_SCATTER_ADD:
-                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
-                                0, BLOCK_SIZE_M
-                            )
-                            mask = offs_am < m_size
-                            m_offsets = tl.load(
-                                scatter_add_indices + M_start_offset + offs_am,
-                                mask=mask,
-                                cache_modifier=".ca",
-                            )
-                            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
-                                0, BLOCK_SIZE_N
-                            )
-                            tl.atomic_add(
-                                c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
-                                c,
-                                mask=mask[:, None],
-                                sem="relaxed",
-                            )
+                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                        mask = offs_am < m_size
+                        m_offsets = tl.load(
+                            scatter_add_indices + M_start_offset + offs_am,
+                            mask=mask,
+                            cache_modifier=".ca",
+                        )
+                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                        tl.atomic_add(
+                            c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
+                            c,
+                            mask=mask[:, None],
+                            sem="relaxed",
+                        )
                     else:
-                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
-                                0, BLOCK_SIZE_M
-                            )
-                            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
-                                0, BLOCK_SIZE_N
-                            )
-                            tl.store(
-                                c_ptr
-                                + (M_start_offset + offs_am[:, None]) * N
-                                + offs_bn[None, :],
-                                c,
-                                mask=offs_am[:, None] < m_size,
-                                cache_modifier=".cs",
-                            )
+                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                        tl.store(
+                            c_ptr
+                            + (M_start_offset + offs_am[:, None]) * N
+                            + offs_bn[None, :],
+                            c,
+                            mask=offs_am[:, None] < m_size,
+                            cache_modifier=".cs",
+                        )
                     tidx += NUM_SMS
 
             iterated_tiles += num_tiles
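Rowwise fp8 scaling folds the per-row scale of A and the per-column scale of the output (one scale per row of B) into the fp32 accumulator before the downcast: `c[m, n] = acc[m, n] * a_scale[m] * b_scale[n]`. A PyTorch sketch of that epilogue math (hypothetical reference code, not part of the kernel):

```python
import torch

def fp8_rowwise_epilogue(accumulator, a_scale, b_scale, out_dtype=torch.bfloat16):
    # accumulator: [m_size, N] fp32 partial products from the fp8 GEMM.
    # a_scale: [m_size] per-row scales of A; b_scale: [N] per-column scales.
    c = accumulator * a_scale[:, None] * b_scale[None, :]
    return c.to(out_dtype)
```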