
Commit a1652a6

dshi7 authored and facebook-github-bot committed
Deprecate tl.async_task from fbgemm
Summary: see D86119952

Differential Revision: D86319606
1 parent 17a7653 commit a1652a6
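
In short: each `with tl.async_task(...)` block in the two warp-specialized kernels (`_fbgemm_grouped_gemm_ws` and `_fbgemm_grouped_gemm_fp8_rowwise_ws`) is unwrapped and its body dedented one level; the loads, accumulation, and store epilogues are otherwise unchanged. A minimal sketch of the pattern (schematic only, with `...` standing in for the descriptor arguments shown in the diff; `tl.async_task` here appears to be the experimental Triton annotation that pins statements to producer/consumer warp groups):

    # Before: loads annotated for the producer task, math for the
    # consumer warp groups, via the deprecated tl.async_task API.
    with tl.async_task([0]):
        a = tl._experimental_descriptor_load(a_desc_ptr, ...)
    with tl.async_task([1, NUM_CONSUMER_GROUPS]):
        accumulator += tl.dot(a, b.T)

    # After: plain straight-line Triton; task assignment is left
    # to the compiler.
    a = tl._experimental_descriptor_load(a_desc_ptr, ...)
    accumulator += tl.dot(a, b.T)
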


fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py

Lines changed: 135 additions & 172 deletions
@@ -509,14 +509,13 @@ def _fbgemm_grouped_gemm_ws(
             num_tiles = num_m_tiles * NUM_N_TILES

             if USE_TMA_STORE:
-                with tl.async_task([0]):
-                    c_desc_ptr = tl.make_tensor_descriptor(
-                        c_ptr + M_start_offset * N,
-                        shape=[m_size, N],
-                        # pyre-ignore
-                        strides=[N, 1],
-                        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
-                    )
+                c_desc_ptr = tl.make_tensor_descriptor(
+                    c_ptr + M_start_offset * N,
+                    shape=[m_size, N],
+                    # pyre-ignore
+                    strides=[N, 1],
+                    block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+                )

             # Move across tiles
             next_iterated_tiles = iterated_tiles + num_tiles
@@ -534,72 +533,59 @@ def _fbgemm_grouped_gemm_ws(
                 m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                 n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                 for k_offset in range(0, K, BLOCK_SIZE_K):
-                    with tl.async_task([0]):
-                        a = tl._experimental_descriptor_load(
-                            a_desc_ptr,
-                            [m_offset, k_offset],
-                            [BLOCK_SIZE_M, BLOCK_SIZE_K],
-                            dtype,
-                        )
-                        b = tl._experimental_descriptor_load(
-                            b_desc_ptr,
-                            [n_offset, k_offset],
-                            [BLOCK_SIZE_N, BLOCK_SIZE_K],
-                            dtype,
-                        )
-                    with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                        if USE_FAST_ACCUM:
-                            accumulator = tl.dot(a, b.T, accumulator)
-                        else:
-                            accumulator += tl.dot(a, b.T)
+                    a = tl._experimental_descriptor_load(
+                        a_desc_ptr,
+                        [m_offset, k_offset],
+                        [BLOCK_SIZE_M, BLOCK_SIZE_K],
+                        dtype,
+                    )
+                    b = tl._experimental_descriptor_load(
+                        b_desc_ptr,
+                        [n_offset, k_offset],
+                        [BLOCK_SIZE_N, BLOCK_SIZE_K],
+                        dtype,
+                    )
+                    if USE_FAST_ACCUM:
+                        accumulator = tl.dot(a, b.T, accumulator)
+                    else:
+                        accumulator += tl.dot(a, b.T)

                 if USE_TMA_STORE:
-                    with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                        m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
-                        n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
-                        # pyre-ignore
-                        c_desc_ptr.store(
-                            [m_offset, n_offset],
-                            accumulator.to(c_ptr.dtype.element_ty),
-                        )
+                    m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
+                    n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
+                    # pyre-ignore
+                    c_desc_ptr.store(
+                        [m_offset, n_offset],
+                        accumulator.to(c_ptr.dtype.element_ty),
+                    )
                 elif FUSE_SCATTER_ADD:
-                    with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
-                            0, BLOCK_SIZE_M
-                        )
-                        mask = offs_am < m_size
-                        m_offsets = tl.load(
-                            scatter_add_indices + M_start_offset + offs_am,
-                            mask=mask,
-                            cache_modifier=".ca",
-                        )
-                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
-                            0, BLOCK_SIZE_N
-                        )
-                        c = accumulator.to(c_ptr.dtype.element_ty)
-                        tl.atomic_add(
-                            c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
-                            c,
-                            mask=mask[:, None],
-                            sem="relaxed",
-                        )
+                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                    mask = offs_am < m_size
+                    m_offsets = tl.load(
+                        scatter_add_indices + M_start_offset + offs_am,
+                        mask=mask,
+                        cache_modifier=".ca",
+                    )
+                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                    c = accumulator.to(c_ptr.dtype.element_ty)
+                    tl.atomic_add(
+                        c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
+                        c,
+                        mask=mask[:, None],
+                        sem="relaxed",
+                    )
                 else:
-                    with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
-                            0, BLOCK_SIZE_M
-                        )
-                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
-                            0, BLOCK_SIZE_N
-                        )
-                        c = accumulator.to(c_ptr.dtype.element_ty)
-                        tl.store(
-                            c_ptr
-                            + (M_start_offset + offs_am[:, None]) * N
-                            + offs_bn[None, :],
-                            c,
-                            mask=offs_am[:, None] < m_size,
-                            cache_modifier=".cs",
-                        )
+                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                    c = accumulator.to(c_ptr.dtype.element_ty)
+                    tl.store(
+                        c_ptr
+                        + (M_start_offset + offs_am[:, None]) * N
+                        + offs_bn[None, :],
+                        c,
+                        mask=offs_am[:, None] < m_size,
+                        cache_modifier=".cs",
+                    )
                 tidx += NUM_SMS

             iterated_tiles += num_tiles
@@ -841,14 +827,13 @@ def _fbgemm_grouped_gemm_fp8_rowwise_ws(
             num_tiles = num_m_tiles * NUM_N_TILES

             if USE_TMA_STORE:
-                with tl.async_task([0]):
-                    c_desc_ptr = tl.make_tensor_descriptor(
-                        c_ptr + M_start_offset * N,
-                        shape=[m_size, N],
-                        # pyre-ignore
-                        strides=[N, 1],
-                        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
-                    )
+                c_desc_ptr = tl.make_tensor_descriptor(
+                    c_ptr + M_start_offset * N,
+                    shape=[m_size, N],
+                    # pyre-ignore
+                    strides=[N, 1],
+                    block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+                )

             # Move across tiles
             next_iterated_tiles = iterated_tiles + num_tiles
@@ -867,107 +852,85 @@ def _fbgemm_grouped_gemm_fp8_rowwise_ws(
                 m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                 n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                 for k_offset in range(0, K, BLOCK_SIZE_K):
-                    with tl.async_task([0]):
-                        a = tl._experimental_descriptor_load(
-                            a_desc_ptr,
-                            [m_offset, k_offset],
-                            [BLOCK_SIZE_M, BLOCK_SIZE_K],
-                            dtype,
-                        )
-                        b = tl._experimental_descriptor_load(
-                            b_desc_ptr,
-                            [n_offset, k_offset],
-                            [BLOCK_SIZE_N, BLOCK_SIZE_K],
-                            dtype,
-                        )
-                    with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                        if USE_FAST_ACCUM:
-                            accumulator = tl.dot(a, b.T, accumulator)
-                        else:
-                            accumulator += tl.dot(a, b.T)
+                    a = tl._experimental_descriptor_load(
+                        a_desc_ptr,
+                        [m_offset, k_offset],
+                        [BLOCK_SIZE_M, BLOCK_SIZE_K],
+                        dtype,
+                    )
+                    b = tl._experimental_descriptor_load(
+                        b_desc_ptr,
+                        [n_offset, k_offset],
+                        [BLOCK_SIZE_N, BLOCK_SIZE_K],
+                        dtype,
+                    )
+                    if USE_FAST_ACCUM:
+                        accumulator = tl.dot(a, b.T, accumulator)
+                    else:
+                        accumulator += tl.dot(a, b.T)

                 if USE_TMA_LOAD_ON_SCALES:
-                    with tl.async_task([0]):
-                        b_scale = tl._experimental_descriptor_load(
-                            b_scale_desc_ptr,
-                            [n_offset],
-                            [BLOCK_SIZE_N],
-                            tl.float32,
-                        )
-
-                    with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
-                            0, BLOCK_SIZE_M
-                        )
-                        a_scale = tl.load(
-                            a_scale_ptr + M_start_offset + offs_am[:, None],
-                            mask=offs_am[:, None] < m_size,
-                            cache_modifier=".ca",
-                        )
-                        c = accumulator.to(tl.float32) * a_scale * b_scale[None, :]
+                    b_scale = tl._experimental_descriptor_load(
+                        b_scale_desc_ptr,
+                        [n_offset],
+                        [BLOCK_SIZE_N],
+                        tl.float32,
+                    )
+
+                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                    a_scale = tl.load(
+                        a_scale_ptr + M_start_offset + offs_am[:, None],
+                        mask=offs_am[:, None] < m_size,
+                        cache_modifier=".ca",
+                    )
+                    c = accumulator.to(tl.float32) * a_scale * b_scale[None, :]
                 else:
-                    with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
-                            0, BLOCK_SIZE_M
-                        )
-                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
-                            0, BLOCK_SIZE_N
-                        )
-                        a_scale = tl.load(
-                            a_scale_ptr + M_start_offset + offs_am[:, None],
-                            mask=offs_am[:, None] < m_size,
-                            cache_modifier=".ca",
-                        )
-                        b_scale = tl.load(
-                            b_scale_ptr + N_start_offset + offs_bn[None, :],
-                            cache_modifier=".ca",
-                        )
-                        c = accumulator.to(tl.float32) * a_scale * b_scale
+                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                    a_scale = tl.load(
+                        a_scale_ptr + M_start_offset + offs_am[:, None],
+                        mask=offs_am[:, None] < m_size,
+                        cache_modifier=".ca",
+                    )
+                    b_scale = tl.load(
+                        b_scale_ptr + N_start_offset + offs_bn[None, :],
+                        cache_modifier=".ca",
+                    )
+                    c = accumulator.to(tl.float32) * a_scale * b_scale

                 if USE_TMA_STORE:
-                    with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                        m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
-                        n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
-                        # pyre-ignore
-                        c_desc_ptr.store(
-                            [m_offset, n_offset], c.to(c_ptr.dtype.element_ty)
-                        )
+                    m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
+                    n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
+                    # pyre-ignore
+                    c_desc_ptr.store(
+                        [m_offset, n_offset], c.to(c_ptr.dtype.element_ty)
+                    )
                 elif FUSE_SCATTER_ADD:
-                    with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
-                            0, BLOCK_SIZE_M
-                        )
-                        mask = offs_am < m_size
-                        m_offsets = tl.load(
-                            scatter_add_indices + M_start_offset + offs_am,
-                            mask=mask,
-                            cache_modifier=".ca",
-                        )
-                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
-                            0, BLOCK_SIZE_N
-                        )
-                        tl.atomic_add(
-                            c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
-                            c,
-                            mask=mask[:, None],
-                            sem="relaxed",
-                        )
+                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                    mask = offs_am < m_size
+                    m_offsets = tl.load(
+                        scatter_add_indices + M_start_offset + offs_am,
+                        mask=mask,
+                        cache_modifier=".ca",
+                    )
+                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                    tl.atomic_add(
+                        c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
+                        c,
+                        mask=mask[:, None],
+                        sem="relaxed",
+                    )
                 else:
-                    with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
-                            0, BLOCK_SIZE_M
-                        )
-                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
-                            0, BLOCK_SIZE_N
-                        )
-                        tl.store(
-                            c_ptr
-                            + (M_start_offset + offs_am[:, None]) * N
-                            + offs_bn[None, :],
-                            c,
-                            mask=offs_am[:, None] < m_size,
-                            cache_modifier=".cs",
-                        )
+                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                    tl.store(
+                        c_ptr
+                        + (M_start_offset + offs_am[:, None]) * N
+                        + offs_bn[None, :],
+                        c,
+                        mask=offs_am[:, None] < m_size,
+                        cache_modifier=".cs",
+                    )
                 tidx += NUM_SMS

             iterated_tiles += num_tiles
