Skip to content

Commit a238514

Browse files
zhiqchen-amd authored and zhaoan12-prc committed
add new version #1328
1 parent bfe4aca commit a238514

File tree

8 files changed: 52 additions and 33 deletions.

3rdparty/aiter/aiter.patch

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,10 @@
3636

3737
--- aiter/jit/core.py
3838
+++ aiter/jit/core.py
39-
@@ -66,37 +66,18 @@
40-
AITER_ROOT_DIR = os.path.abspath(f"{this_dir}/../../")
41-
AITER_LOG_MORE = int(os.getenv("AITER_LOG_MORE", 0))
39+
@@ -168,36 +168,18 @@
40+
# config_env end here
4241

43-
-find_aiter = importlib.util.find_spec("aiter")
42+
find_aiter = importlib.util.find_spec("aiter")
4443
-if find_aiter is not None:
4544
- if find_aiter.submodule_search_locations:
4645
- package_path = find_aiter.submodule_search_locations[0]
@@ -82,4 +81,27 @@
8281
)
8382

8483

84+
--- aiter/__init__.py
85+
+++ aiter/__init__.py
86+
@@ -77,3 +77,4 @@
87+
from .ops.trans_ragged_layout import *
88+
from .ops.sample import *
89+
from . import mla
90+
+from . import paged_attn
91+
92+
--- aiter/ops/gemm_op_a8w8.py
93+
+++ aiter/ops/gemm_op_a8w8.py
94+
@@ -425,9 +425,11 @@
95+
WQ: Tensor,
96+
x_scale: Tensor,
97+
w_scale: Tensor,
98+
- dtype: torch.dtype = dtypes.bf16,
99+
+ dtype: torch.dtype = None,
100+
isBpreshuffled: bool = False,
101+
) -> torch.Tensor:
102+
+ if dtype is None:
103+
+ dtype = torch.bfloat16
104+
assert dtype in [
105+
dtypes.bf16,
106+
dtypes.fp16,
85107

open_source/bazel/arch_select.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def subscribe_deps():
6565
def whl_deps():
6666
return select({
6767
"@//:using_cuda12": ["torch==2.6.0+cu126"],
68-
"@//:using_rocm": ["pyrsmi", "amdsmi@https://sinian-metrics-platform.oss-cn-hangzhou.aliyuncs.com/kis%2FAMD%2Famd_smi%2Fali%2Famd_smi.tar", "aiter@https://sinian-metrics-platform.oss-cn-hangzhou.aliyuncs.com/kis/AMD/RTP/aiter-0.1.5%2Bgit.007fe7aa.date.202510272053-py3-none-any.whl"],
68+
"@//:using_rocm": ["pyrsmi", "amdsmi@https://sinian-metrics-platform.oss-cn-hangzhou.aliyuncs.com/kis%2FAMD%2Famd_smi%2Fali%2Famd_smi.tar", "aiter@https://sinian-metrics-platform.oss-cn-hangzhou.aliyuncs.com/kis/AMD/aiter/0.1.6/aiter-0.1.6%2Bgit.329d07ba.date.202511061625-py3-none-any.whl"],
6969
"//conditions:default": ["torch==2.1.2"],
7070
})
7171

open_source/deps/git.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def git_deps():
1010
git_repository(
1111
name = "aiter_src",
1212
remote = "https://github.com/ROCm/aiter.git",
13-
commit = "007fe7aa070d827bbdad398a578f403057a34e87", # add several ds shapes to fp4 tuned config (#1131)
13+
commit = "329d07ba5d77f7d6b2a0557174288c5707f95e5f", # [Triton] DS a16w8 GEMM and fused reduce_rms_fp8_group_quant (#1328)
1414
recursive_init_submodules = True,
1515
patches = ["//3rdparty/aiter:aiter.patch", "//3rdparty/aiter:gemm_a8w8.patch"],
1616
patch_cmds = [

open_source/deps/http.bzl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@ def http_deps():
6060

6161
http_archive(
6262
name = "aiter",
63-
sha256 = "236197b1e55e546ab80a19a2f42cfb69075ff058c8b33341b70bc513e488febd",
63+
sha256 = "cf1ac18a72e08f38133cf8891a1484d694b482925f1196dda398fd10c19586f2",
6464
urls = [
65-
"https://sinian-metrics-platform.oss-cn-hangzhou.aliyuncs.com/kis/AMD/RTP/aiter-0.1.5%2Bgit.007fe7aa.date.202510272053-py3-none-any.whl",
65+
"https://sinian-metrics-platform.oss-cn-hangzhou.aliyuncs.com/kis/AMD/aiter/0.1.6/aiter-0.1.6%2Bgit.329d07ba.date.202511061625-py3-none-any.whl",
6666
],
6767
type = "zip",
6868
build_file = clean_dep("//:BUILD.aiter"),

open_source/deps/requirements_lock_rocm.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ aiosignal==1.3.1 \
114114
--hash=sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc \
115115
--hash=sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17
116116
# via aiohttp
117-
aiter @ https://sinian-metrics-platform.oss-cn-hangzhou.aliyuncs.com/kis/AMD/RTP/aiter-0.1.5%2Bgit.007fe7aa.date.202510272053-py3-none-any.whl \
118-
--hash=sha256:236197b1e55e546ab80a19a2f42cfb69075ff058c8b33341b70bc513e488febd
117+
aiter @ https://sinian-metrics-platform.oss-cn-hangzhou.aliyuncs.com/kis/AMD/aiter/0.1.6/aiter-0.1.6%2Bgit.329d07ba.date.202511061625-py3-none-any.whl \
118+
--hash=sha256:cf1ac18a72e08f38133cf8891a1484d694b482925f1196dda398fd10c19586f2
119119
# via -r open_source/deps/requirements_rocm.txt
120120
aliyun-python-sdk-core==2.15.2 \
121121
--hash=sha256:54f66a53e193c61c5e16ea4505a0cab43543f8ad2ef22833f69c4d5e5151c17d

open_source/deps/requirements_rocm.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,5 @@ https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/torch-2.4.1%2Brocm6.4.1.gi
44
https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/torchvision-0.19.0%2Brocm6.4.1.git4d41ad71-cp310-cp310-linux_x86_64.whl
55
pyrsmi
66
pyyaml
7-
https://sinian-metrics-platform.oss-cn-hangzhou.aliyuncs.com/kis/AMD/RTP/aiter-0.1.5%2Bgit.007fe7aa.date.202510272053-py3-none-any.whl
7+
https://sinian-metrics-platform.oss-cn-hangzhou.aliyuncs.com/kis/AMD/aiter/0.1.6/aiter-0.1.6%2Bgit.329d07ba.date.202511061625-py3-none-any.whl
88
https://sinian-metrics-platform.oss-cn-hangzhou.aliyuncs.com/kis%2FAMD%2Famd_smi%2Fali%2Famd_smi.tar

rtp_llm/cpp/rocm/custom_ar/custom_ar_comm.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,16 @@ bool CustomAllReduceComm::checkAllReduceAvailable(size_t elts_total_num, DataTyp
4343

4444
void CustomAllReduceComm::allReduce(torch::Tensor& input_tensor, torch::Tensor& output_tensor) {
4545
if (at::hip::currentStreamCaptureStatusMayInitCtx() != at::hip::CaptureStatus::None) {
46-
aiter::all_reduce_reg(fa_, input_tensor, output_tensor, false);
46+
aiter::all_reduce(fa_, input_tensor, output_tensor, false, std::nullopt);
4747
} else {
48-
aiter::all_reduce_unreg(fa_, input_tensor, buffer_, output_tensor);
48+
aiter::all_reduce(fa_, input_tensor, output_tensor, false, buffer_);
4949
}
5050
}
5151

5252
void CustomAllReduceComm::registerGraphBuffers() {
5353
auto handle_and_offset = aiter::get_graph_buffer_ipc_meta(fa_); // tuple<tensor, vector<int64_t>> -> vector<tensor> size=2
54-
auto handle = handle_and_offset[0];
55-
auto offset = handle_and_offset[1];
54+
auto handle = std::get<0>(handle_and_offset);
55+
auto offset = std::get<1>(handle_and_offset);
5656

5757
auto _handles = all_gather(handle.data_ptr(), handle.element_size() * handle.numel(), at::hip::getCurrentHIPStream().stream());
5858
auto _offsets = all_gather(offset.data_ptr(), offset.element_size() * offset.numel(), at::hip::getCurrentHIPStream().stream());

rtp_llm/cpp/rocm/rocmFmhaWrapper.cc

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -221,13 +221,12 @@ uint32_t rocmFmhaWrapper::runCKFmha(void* q,
221221
// nullptr, // o_acc_buf.GetDeviceBuffer(),
222222
softmax_lse_,
223223
output,
224+
seqstart_q, //seqstart_q_ptr
225+
seqstart_k, //seqstart_k_ptr
226+
nullptr, //seqlen_q_ptr
227+
nullptr, //seqlen_k_ptr
224228
nullptr, //cu_seqlen_q_ptr
225-
nullptr, //cu_seqlen_kv_ptr
226-
seqstart_q,
227-
seqstart_k,
228-
nullptr, // seqlen_kpads
229-
nullptr, //seqstart_padded_q_ptr
230-
nullptr, //seqstart_padded_k_ptr
229+
nullptr, // cu_seqlen_k_ptr
231230
shape_seqlen_q,
232231
shape_seqlen_k,
233232
batch,
@@ -489,13 +488,12 @@ uint32_t rocmFmhaWrapper::runCKFmhaV2(void* q,
489488
// nullptr, // o_acc_buf.GetDeviceBuffer(),
490489
softmax_lse_,
491490
output,
491+
seqstart_q, //seqstart_q_ptr
492+
seqstart_k, //seqstart_k_ptr
493+
nullptr, //seqlen_q_ptr
494+
nullptr, //seqlen_k_ptr
492495
nullptr, //cu_seqlen_q_ptr
493-
nullptr, //cu_seqlen_kv_ptr
494-
seqstart_q,
495-
seqstart_k,
496-
nullptr, // seqlen_kpads
497-
nullptr, //seqstart_padded_q_ptr
498-
nullptr, //seqstart_padded_k_ptr
496+
nullptr, // cu_seqlen_k_ptr
499497
shape_seqlen_q,
500498
shape_seqlen_k,
501499
batch,
@@ -759,13 +757,12 @@ uint32_t rocmFmhaWrapper::runCKFmhaMLA(void* q,
759757
// nullptr, // o_acc_buf.GetDeviceBuffer(),
760758
softmax_lse_,
761759
output,
760+
seqstart_q, //seqstart_q_ptr
761+
seqstart_k, //seqstart_k_ptr
762+
nullptr, //seqlen_q_ptr
763+
nullptr, //seqlen_k_ptr
762764
nullptr, //cu_seqlen_q_ptr
763-
nullptr, //cu_seqlen_kv_ptr
764-
seqstart_q,
765-
seqstart_k,
766-
nullptr, // seqlen_kpads
767-
nullptr, //seqstart_padded_q_ptr
768-
nullptr, //seqstart_padded_k_ptr
765+
nullptr, // cu_seqlen_k_ptr
769766
shape_seqlen_q,
770767
shape_seqlen_k,
771768
batch,

0 commit comments

Comments
 (0)