Skip to content

Commit 56fb59b

Browse files
authored
[TRITON] Change all MI300 references to gfx942 (#1509)
* Change all MI300 references to gfx942 * removing the redundant _ARCH_TO_DEVICE to directly use get_arch
1 parent 2741e6f commit 56fb59b

File tree

67 files changed

+59
-71
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+59
-71
lines changed

aiter/ops/triton/_triton_kernels/batched_gemm_a16wfp4.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def _get_config(
322322
K: int,
323323
):
324324
if not hasattr(_get_config, "_config_dict"):
325-
dev = arch_info.get_device()
325+
dev = arch_info.get_arch()
326326
_get_config._config_dict = {}
327327
fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-BATCHED_GEMM_PREQUANT-AFP4WFP4.json"
328328
with open(fpath, "r") as file:
@@ -331,7 +331,7 @@ def _get_config(
331331

332332
key = f"{N}_{K}"
333333
if key not in _get_config._config_dict.keys():
334-
dev = arch_info.get_device()
334+
dev = arch_info.get_arch()
335335
fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-BATCHED_GEMM_PREQUANT-AFP4WFP4-N={N}-K={2*K}.json"
336336
if os.path.exists(fpath):
337337
with open(fpath, "r") as file:

aiter/ops/triton/_triton_kernels/batched_gemm_a8w8.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def _get_config(
204204
K: int,
205205
):
206206
if not hasattr(_get_config, "_config_dict"):
207-
dev = arch_info.get_device()
207+
dev = arch_info.get_arch()
208208
fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-BATCHED_GEMM-A8W8.json"
209209
print(f"fpath={fpath}")
210210
with open(fpath, "r") as file:

aiter/ops/triton/_triton_kernels/batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def _get_config(
208208
K: int,
209209
):
210210
if not hasattr(_get_config, "_config_dict"):
211-
dev = arch_info.get_device()
211+
dev = arch_info.get_arch()
212212
_get_config._config_dict = {}
213213
fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-BATCHED_GEMM-A8W8-A_PER_TOKEN_GROUP_PREQUANT_W_PER_BATCHED_TENSOR_QUANT.json"
214214
with open(fpath, "r") as file:
@@ -217,7 +217,7 @@ def _get_config(
217217

218218
key = f"{N}_{K}"
219219
if key not in _get_config._config_dict.keys():
220-
dev = arch_info.get_device()
220+
dev = arch_info.get_arch()
221221
fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-BATCHED_GEMM-A8W8-A_PER_TOKEN_GROUP_PREQUANT_W_PER_BATCHED_TENSOR_QUANT-N={N}-K={K}.json"
222222
if os.path.exists(fpath):
223223
with open(fpath, "r") as file:

aiter/ops/triton/_triton_kernels/batched_gemm_afp4wfp4.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def _get_config(
331331
K: int,
332332
):
333333
if not hasattr(_get_config, "_config_dict"):
334-
dev = arch_info.get_device()
334+
dev = arch_info.get_arch()
335335
_get_config._config_dict = {}
336336
fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-BATCHED_GEMM-AFP4WFP4.json"
337337
with open(fpath, "r") as file:
@@ -340,7 +340,7 @@ def _get_config(
340340

341341
key = f"{N}_{K}"
342342
if key not in _get_config._config_dict.keys():
343-
dev = arch_info.get_device()
343+
dev = arch_info.get_arch()
344344
fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-BATCHED_GEMM-AFP4WFP4-N={N}-K={2*K}.json"
345345
if os.path.exists(fpath):
346346
with open(fpath, "r") as file:

aiter/ops/triton/_triton_kernels/batched_gemm_bf16.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def _get_config(
183183
K: int,
184184
):
185185
if not hasattr(_get_config, "_config_dict"):
186-
dev = arch_info.get_device()
186+
dev = arch_info.get_arch()
187187
fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-BATCHED_GEMM-A16W16.json"
188188
print(f"fpath={fpath}")
189189
with open(fpath, "r") as file:

aiter/ops/triton/_triton_kernels/extend_attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ def _fwd_kernel(
323323
@functools.lru_cache(maxsize=1024)
324324
def _get_config(HEAD_SIZE, dtype):
325325
if not hasattr(_get_config, "_config_dict"):
326-
dev = arch_info.get_device()
326+
dev = arch_info.get_arch()
327327
_get_config._config_dict = {}
328328
fpath = f"{AITER_TRITON_CONFIGS_PATH}/{dev}-EXTEND_ATTENTION.json"
329329
with open(fpath, "r") as file:

aiter/ops/triton/_triton_kernels/ff_a16w16_fused_gated.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def _get_config(
204204
K: int,
205205
):
206206
if not hasattr(_get_config, "_config_dict"):
207-
dev = arch_info.get_device()
207+
dev = arch_info.get_arch()
208208
_get_config._config_dict = {}
209209
fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-FF-A16W16-fused.json"
210210
with open(fpath, "r") as file:
@@ -213,7 +213,7 @@ def _get_config(
213213

214214
key = f"{N}_{K}"
215215
if key not in _get_config._config_dict.keys():
216-
dev = arch_info.get_device()
216+
dev = arch_info.get_arch()
217217
fpath = (
218218
f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-FF-A16W16-fused-N={N}-K={K}.json"
219219
)

aiter/ops/triton/_triton_kernels/ff_a16w16_fused_ungated.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def _get_config(
168168
K: int,
169169
):
170170
if not hasattr(_get_config, "_config_dict"):
171-
dev = arch_info.get_device()
171+
dev = arch_info.get_arch()
172172
_get_config._config_dict = {}
173173
fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-FF-A16W16-fused.json"
174174
with open(fpath, "r") as file:
@@ -177,7 +177,7 @@ def _get_config(
177177

178178
key = f"{N}_{K}"
179179
if key not in _get_config._config_dict.keys():
180-
dev = arch_info.get_device()
180+
dev = arch_info.get_arch()
181181
fpath = (
182182
f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-FF-A16W16-fused-N={N}-K={K}.json"
183183
)

aiter/ops/triton/_triton_kernels/fused_gemm_a8w8_blockscale_a16w16.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def _get_config(
410410
K: int,
411411
):
412412
if not hasattr(_get_config, "_config_dict"):
413-
dev = arch_info.get_device()
413+
dev = arch_info.get_arch()
414414
_get_config._config_dict = {}
415415
fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-FUSED-GEMM-A8W8_BLOCKSCALE-A16W16.json"
416416
with open(fpath, "r") as file:
@@ -419,7 +419,7 @@ def _get_config(
419419

420420
key = f"{N_fp8}_{N_bf16}_{K}"
421421
if key not in _get_config._config_dict.keys():
422-
dev = arch_info.get_device()
422+
dev = arch_info.get_arch()
423423
fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-FUSED-GEMM-A8W8_BLOCKSCALE-A16W16-N8={N_fp8}-N16={N_bf16}-K={K}.json"
424424
if os.path.exists(fpath):
425425
with open(fpath, "r") as file:

aiter/ops/triton/_triton_kernels/fused_gemm_afp4wfp4_a16w16.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -825,7 +825,7 @@ def _get_config(
825825
):
826826
shuffle_filename_suffix = "" if not shuffle else "_PRESHUFFLED"
827827
if not hasattr(_get_config, "_config_dict"):
828-
dev = arch_info.get_device()
828+
dev = arch_info.get_arch()
829829
_get_config._config_dict = {}
830830
fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-FUSED-GEMM-AFP4WFP4{shuffle_filename_suffix}-A16W16.json"
831831
with open(fpath, "r") as file:
@@ -834,7 +834,7 @@ def _get_config(
834834

835835
key = f"{N_fp4}_{N_bf16}_{K}"
836836
if key not in _get_config._config_dict.keys():
837-
dev = arch_info.get_device()
837+
dev = arch_info.get_arch()
838838
fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-FUSED-GEMM-AFP4WFP4{shuffle_filename_suffix}-A16W16-N4={N_fp4}-N16={N_bf16}-K={2*K}.json"
839839
if os.path.exists(fpath):
840840
with open(fpath, "r") as file:

0 commit comments

Comments
 (0)