Skip to content

Commit d8d2482

Browse files
committed
[CIR][AMDGPU] Add CIR lowering for amdgcn wave reduce intrinsics
1 parent 51870fc commit d8d2482

File tree

3 files changed

+396
-1
lines changed

3 files changed

+396
-1
lines changed

clang/lib/CIR/CodeGen/CIRGenBuiltinAMDGPU.cpp

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,40 @@ using namespace clang;
2020
using namespace clang::CIRGen;
2121
using namespace cir;
2222

23+
static llvm::StringRef getIntrinsicNameforWaveReduction(unsigned BuiltinID) {
24+
switch (BuiltinID) {
25+
default:
26+
llvm_unreachable("Unknown BuiltinID for wave reduction");
27+
case clang::AMDGPU::BI__builtin_amdgcn_wave_reduce_add_u32:
28+
case clang::AMDGPU::BI__builtin_amdgcn_wave_reduce_add_u64:
29+
return "amdgcn.wave.reduce.add";
30+
case clang::AMDGPU::BI__builtin_amdgcn_wave_reduce_sub_u32:
31+
case clang::AMDGPU::BI__builtin_amdgcn_wave_reduce_sub_u64:
32+
return "amdgcn.wave.reduce.sub";
33+
case clang::AMDGPU::BI__builtin_amdgcn_wave_reduce_min_i32:
34+
case clang::AMDGPU::BI__builtin_amdgcn_wave_reduce_min_i64:
35+
return "amdgcn.wave.reduce.min";
36+
case clang::AMDGPU::BI__builtin_amdgcn_wave_reduce_min_u32:
37+
case clang::AMDGPU::BI__builtin_amdgcn_wave_reduce_min_u64:
38+
return "amdgcn.wave.reduce.umin";
39+
case clang::AMDGPU::BI__builtin_amdgcn_wave_reduce_max_i32:
40+
case clang::AMDGPU::BI__builtin_amdgcn_wave_reduce_max_i64:
41+
return "amdgcn.wave.reduce.max";
42+
case clang::AMDGPU::BI__builtin_amdgcn_wave_reduce_max_u32:
43+
case clang::AMDGPU::BI__builtin_amdgcn_wave_reduce_max_u64:
44+
return "amdgcn.wave.reduce.umax";
45+
case clang::AMDGPU::BI__builtin_amdgcn_wave_reduce_and_b32:
46+
case clang::AMDGPU::BI__builtin_amdgcn_wave_reduce_and_b64:
47+
return "amdgcn.wave.reduce.and";
48+
case clang::AMDGPU::BI__builtin_amdgcn_wave_reduce_or_b32:
49+
case clang::AMDGPU::BI__builtin_amdgcn_wave_reduce_or_b64:
50+
return "amdgcn.wave.reduce.or";
51+
case clang::AMDGPU::BI__builtin_amdgcn_wave_reduce_xor_b32:
52+
case clang::AMDGPU::BI__builtin_amdgcn_wave_reduce_xor_b64:
53+
return "amdgcn.wave.reduce.xor";
54+
}
55+
}
56+
2357
mlir::Value CIRGenFunction::emitAMDGPUBuiltinExpr(unsigned builtinId,
2458
const CallExpr *expr) {
2559
switch (builtinId) {
@@ -41,7 +75,13 @@ mlir::Value CIRGenFunction::emitAMDGPUBuiltinExpr(unsigned builtinId,
4175
case AMDGPU::BI__builtin_amdgcn_wave_reduce_and_b64:
4276
case AMDGPU::BI__builtin_amdgcn_wave_reduce_or_b64:
4377
case AMDGPU::BI__builtin_amdgcn_wave_reduce_xor_b64: {
44-
llvm_unreachable("wave_reduce_* NYI");
78+
llvm::StringRef intrinsicName = getIntrinsicNameforWaveReduction(builtinId);
79+
mlir::Value Value = emitScalarExpr(expr->getArg(0));
80+
mlir::Value Strategy = emitScalarExpr(expr->getArg(1));
81+
return LLVMIntrinsicCallOp::create(builder, getLoc(expr->getExprLoc()),
82+
builder.getStringAttr(intrinsicName),
83+
Value.getType(), {Value, Strategy})
84+
.getResult();
4585
}
4686
case AMDGPU::BI__builtin_amdgcn_div_scale:
4787
case AMDGPU::BI__builtin_amdgcn_div_scalef: {
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
#include "../Inputs/cuda.h"
2+
3+
// REQUIRES: amdgpu-registered-target
4+
// RUN: %clang_cc1 -triple amdgcn-amd-amdhsa -x hip -std=c++11 -fclangir \
5+
// RUN: -fcuda-is-device -emit-cir %s -o %t.cir
6+
// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s
7+
8+
// RUN: %clang_cc1 -triple amdgcn-amd-amdhsa -x hip -std=c++11 -fclangir \
9+
// RUN: -fcuda-is-device -emit-llvm %s -o %t.ll
10+
// RUN: FileCheck --check-prefix=LLVM --input-file=%t.ll %s
11+
12+
//===----------------------------------------------------------------------===//
13+
// Test AMDGPU built-in functions
14+
//===----------------------------------------------------------------------===//
15+
16+
// CIR-LABEL: @_Z28test_wave_reduce_add_u32_i32Pi
17+
// CIR: cir.llvm.intrinsic "amdgcn.wave.reduce.add" {{.*}} : (!u32i, !s32i) -> !u32i
18+
// LLVM: define{{.*}} void @_Z28test_wave_reduce_add_u32_i32Pii(
19+
// LLVM: call i32 @llvm.amdgcn.wave.reduce.add.i32(i32 %{{.*}}, i32 0)
20+
__device__ void test_wave_reduce_add_u32_i32(int* out, int in) {
21+
*out = __builtin_amdgcn_wave_reduce_add_u32(in, 0);
22+
}
23+
24+
// CIR-LABEL: @_Z28test_wave_reduce_add_u64_i64Pl
25+
// CIR: cir.llvm.intrinsic "amdgcn.wave.reduce.add" {{.*}} : (!u64i, !s32i) -> !u64i
26+
// LLVM: define{{.*}} void @_Z28test_wave_reduce_add_u64_i64Pll(
27+
// LLVM: call i64 @llvm.amdgcn.wave.reduce.add.i64(i64 %{{.*}}, i32 0)
28+
__device__ void test_wave_reduce_add_u64_i64(long* out, long in) {
29+
*out = __builtin_amdgcn_wave_reduce_add_u64(in, 0);
30+
}
31+
32+
// CIR-LABEL: @_Z28test_wave_reduce_sub_u32_i32Pi
33+
// CIR: cir.llvm.intrinsic "amdgcn.wave.reduce.sub" {{.*}} : (!u32i, !s32i) -> !u32i
34+
// LLVM: define{{.*}} void @_Z28test_wave_reduce_sub_u32_i32Pii(
35+
// LLVM: call i32 @llvm.amdgcn.wave.reduce.sub.i32(i32 %{{.*}}, i32 0)
36+
__device__ void test_wave_reduce_sub_u32_i32(int* out, int in) {
37+
*out = __builtin_amdgcn_wave_reduce_sub_u32(in, 0);
38+
}
39+
40+
// CIR-LABEL: @_Z28test_wave_reduce_sub_u64_i64Pl
41+
// CIR: cir.llvm.intrinsic "amdgcn.wave.reduce.sub" {{.*}} : (!u64i, !s32i) -> !u64i
42+
// LLVM: define{{.*}} void @_Z28test_wave_reduce_sub_u64_i64Pll(
43+
// LLVM: call i64 @llvm.amdgcn.wave.reduce.sub.i64(i64 %{{.*}}, i32 0)
44+
__device__ void test_wave_reduce_sub_u64_i64(long* out, long in) {
45+
*out = __builtin_amdgcn_wave_reduce_sub_u64(in, 0);
46+
}
47+
48+
// CIR-LABEL: @_Z29test_wave_reduce_min_i32_signPii
49+
// CIR: cir.llvm.intrinsic "amdgcn.wave.reduce.min" {{.*}} : (!s32i, !s32i) -> !s32i
50+
// LLVM: define{{.*}} void @_Z29test_wave_reduce_min_i32_signPii(
51+
// LLVM: call i32 @llvm.amdgcn.wave.reduce.min.i32(i32 %{{.*}}, i32 0)
52+
__device__ void test_wave_reduce_min_i32_sign(int* out, int in) {
53+
*out = __builtin_amdgcn_wave_reduce_min_i32(in, 0);
54+
}
55+
56+
// CIR-LABEL: @_Z31test_wave_reduce_min_u32_unsignPjj
57+
// CIR: cir.llvm.intrinsic "amdgcn.wave.reduce.umin" {{.*}} : (!u32i, !s32i) -> !u32i
58+
// LLVM: define{{.*}} void @_Z31test_wave_reduce_min_u32_unsignPjj(
59+
// LLVM: call i32 @llvm.amdgcn.wave.reduce.umin.i32(i32 %{{.*}}, i32 0)
60+
__device__ void test_wave_reduce_min_u32_unsign(unsigned int* out, unsigned int in) {
61+
*out = __builtin_amdgcn_wave_reduce_min_u32(in, 0);
62+
}
63+
64+
// CIR-LABEL: @_Z29test_wave_reduce_min_i64_signPll
65+
// CIR: cir.llvm.intrinsic "amdgcn.wave.reduce.min" {{.*}} : (!s64i, !s32i) -> !s64i
66+
// LLVM: define{{.*}} void @_Z29test_wave_reduce_min_i64_signPll(
67+
// LLVM: call i64 @llvm.amdgcn.wave.reduce.min.i64(i64 %{{.*}}, i32 0)
68+
__device__ void test_wave_reduce_min_i64_sign(long* out, long in) {
69+
*out = __builtin_amdgcn_wave_reduce_min_i64(in, 0);
70+
}
71+
72+
// CIR-LABEL: @_Z31test_wave_reduce_min_u64_unsignPmm
73+
// CIR: cir.llvm.intrinsic "amdgcn.wave.reduce.umin" {{.*}} : (!u64i, !s32i) -> !u64i
74+
// LLVM: define{{.*}} void @_Z31test_wave_reduce_min_u64_unsignPmm(
75+
// LLVM: call i64 @llvm.amdgcn.wave.reduce.umin.i64(i64 %{{.*}}, i32 0)
76+
__device__ void test_wave_reduce_min_u64_unsign(unsigned long* out, unsigned long in) {
77+
*out = __builtin_amdgcn_wave_reduce_min_u64(in, 0);
78+
}
79+
80+
// CIR-LABEL: @_Z29test_wave_reduce_max_i32_signPii
81+
// CIR: cir.llvm.intrinsic "amdgcn.wave.reduce.max" {{.*}} : (!s32i, !s32i) -> !s32i
82+
// LLVM: define{{.*}} void @_Z29test_wave_reduce_max_i32_signPii(
83+
// LLVM: call i32 @llvm.amdgcn.wave.reduce.max.i32(i32 %{{.*}}, i32 0)
84+
__device__ void test_wave_reduce_max_i32_sign(int* out, int in) {
85+
*out = __builtin_amdgcn_wave_reduce_max_i32(in, 0);
86+
}
87+
88+
// CIR-LABEL: @_Z31test_wave_reduce_max_u32_unsignPjj
89+
// CIR: cir.llvm.intrinsic "amdgcn.wave.reduce.umax" {{.*}} : (!u32i, !s32i) -> !u32i
90+
// LLVM: define{{.*}} void @_Z31test_wave_reduce_max_u32_unsignPjj(
91+
// LLVM: call i32 @llvm.amdgcn.wave.reduce.umax.i32(i32 %{{.*}}, i32 0)
92+
__device__ void test_wave_reduce_max_u32_unsign(unsigned int* out, unsigned int in) {
93+
*out = __builtin_amdgcn_wave_reduce_max_u32(in, 0);
94+
}
95+
96+
// CIR-LABEL: @_Z29test_wave_reduce_max_i64_signPll
97+
// CIR: cir.llvm.intrinsic "amdgcn.wave.reduce.max" {{.*}} : (!s64i, !s32i) -> !s64i
98+
// LLVM: define{{.*}} void @_Z29test_wave_reduce_max_i64_signPll(
99+
// LLVM: call i64 @llvm.amdgcn.wave.reduce.max.i64(i64 %{{.*}}, i32 0)
100+
__device__ void test_wave_reduce_max_i64_sign(long* out, long in) {
101+
*out = __builtin_amdgcn_wave_reduce_max_i64(in, 0);
102+
}
103+
104+
// CIR-LABEL: @_Z31test_wave_reduce_max_u64_unsignPmm
105+
// CIR: cir.llvm.intrinsic "amdgcn.wave.reduce.umax" {{.*}} : (!u64i, !s32i) -> !u64i
106+
// LLVM: define{{.*}} void @_Z31test_wave_reduce_max_u64_unsignPmm(
107+
// LLVM: call i64 @llvm.amdgcn.wave.reduce.umax.i64(i64 %{{.*}}, i32 0)
108+
__device__ void test_wave_reduce_max_u64_unsign(unsigned long* out, unsigned long in) {
109+
*out = __builtin_amdgcn_wave_reduce_max_u64(in, 0);
110+
}
111+
112+
// CIR-LABEL: @_Z28test_wave_reduce_and_b32_i32Pii
113+
// CIR: cir.llvm.intrinsic "amdgcn.wave.reduce.and" {{.*}} : (!s32i, !s32i) -> !s32i
114+
// LLVM: define{{.*}} void @_Z28test_wave_reduce_and_b32_i32Pii(
115+
// LLVM: call i32 @llvm.amdgcn.wave.reduce.and.i32(i32 %{{.*}}, i32 0)
116+
__device__ void test_wave_reduce_and_b32_i32(int* out, int in) {
117+
*out = __builtin_amdgcn_wave_reduce_and_b32(in, 0);
118+
}
119+
120+
// CIR-LABEL: @_Z28test_wave_reduce_and_b64_i64Pll
121+
// CIR: cir.llvm.intrinsic "amdgcn.wave.reduce.and" {{.*}} : (!s64i, !s32i) -> !s64i
122+
// LLVM: define{{.*}} void @_Z28test_wave_reduce_and_b64_i64Pll(
123+
// LLVM: call i64 @llvm.amdgcn.wave.reduce.and.i64(i64 %{{.*}}, i32 0)
124+
__device__ void test_wave_reduce_and_b64_i64(long* out, long in) {
125+
*out = __builtin_amdgcn_wave_reduce_and_b64(in, 0);
126+
}
127+
128+
// CIR-LABEL: @_Z27test_wave_reduce_or_b32_i32Pii
129+
// CIR: cir.llvm.intrinsic "amdgcn.wave.reduce.or" {{.*}} : (!s32i, !s32i) -> !s32i
130+
// LLVM: define{{.*}} void @_Z27test_wave_reduce_or_b32_i32Pii(
131+
// LLVM: call i32 @llvm.amdgcn.wave.reduce.or.i32(i32 %{{.*}}, i32 0)
132+
__device__ void test_wave_reduce_or_b32_i32(int* out, int in) {
133+
*out = __builtin_amdgcn_wave_reduce_or_b32(in, 0);
134+
}
135+
136+
// CIR-LABEL: @_Z27test_wave_reduce_or_b64_i64Pll
137+
// CIR: cir.llvm.intrinsic "amdgcn.wave.reduce.or" {{.*}} : (!s64i, !s32i) -> !s64i
138+
// LLVM: define{{.*}} void @_Z27test_wave_reduce_or_b64_i64Pll(
139+
// LLVM: call i64 @llvm.amdgcn.wave.reduce.or.i64(i64 %{{.*}}, i32 0)
140+
__device__ void test_wave_reduce_or_b64_i64(long* out, long in) {
141+
*out = __builtin_amdgcn_wave_reduce_or_b64(in, 0);
142+
}
143+
144+
// CIR-LABEL: @_Z28test_wave_reduce_xor_b32_i32Pii
145+
// CIR: cir.llvm.intrinsic "amdgcn.wave.reduce.xor" {{.*}} : (!s32i, !s32i) -> !s32i
146+
// LLVM: define{{.*}} void @_Z28test_wave_reduce_xor_b32_i32Pii(
147+
// LLVM: call i32 @llvm.amdgcn.wave.reduce.xor.i32(i32 %{{.*}}, i32 0)
148+
__device__ void test_wave_reduce_xor_b32_i32(int* out, int in) {
149+
*out = __builtin_amdgcn_wave_reduce_xor_b32(in, 0);
150+
}
151+
152+
// CIR-LABEL: @_Z28test_wave_reduce_xor_b64_i64Pll
153+
// CIR: cir.llvm.intrinsic "amdgcn.wave.reduce.xor" {{.*}} : (!s64i, !s32i) -> !s64i
154+
// LLVM: define{{.*}} void @_Z28test_wave_reduce_xor_b64_i64Pll(
155+
// LLVM: call i64 @llvm.amdgcn.wave.reduce.xor.i64(i64 %{{.*}}, i32 0)
156+
__device__ void test_wave_reduce_xor_b64_i64(long* out, long in) {
157+
*out = __builtin_amdgcn_wave_reduce_xor_b64(in, 0);
158+
}
159+
160+
// CIR-LABEL: @_Z38test_wave_reduce_add_u32_iterative_i32Pii
161+
// CIR: cir.const #cir.int<1> : !s32i
162+
// CIR: cir.llvm.intrinsic "amdgcn.wave.reduce.add" {{.*}} : (!u32i, !s32i) -> !u32i
163+
// LLVM: define{{.*}} void @_Z38test_wave_reduce_add_u32_iterative_i32Pii(
164+
// LLVM: call i32 @llvm.amdgcn.wave.reduce.add.i32(i32 %{{.*}}, i32 1)
165+
__device__ void test_wave_reduce_add_u32_iterative_i32(int* out, int in) {
166+
*out = __builtin_amdgcn_wave_reduce_add_u32(in, 1);
167+
}
168+
169+
// CIR-LABEL: @_Z32test_wave_reduce_add_u32_dpp_i32Pii
170+
// CIR: cir.const #cir.int<2> : !s32i
171+
// CIR: cir.llvm.intrinsic "amdgcn.wave.reduce.add" {{.*}} : (!u32i, !s32i) -> !u32i
172+
// LLVM: define{{.*}} void @_Z32test_wave_reduce_add_u32_dpp_i32Pii(
173+
// LLVM: call i32 @llvm.amdgcn.wave.reduce.add.i32(i32 %{{.*}}, i32 2)
174+
__device__ void test_wave_reduce_add_u32_dpp_i32(int* out, int in) {
175+
*out = __builtin_amdgcn_wave_reduce_add_u32(in, 2);
176+
}

0 commit comments

Comments
 (0)