Skip to content

Commit 88fe0da

Browse files
author
jax authors
committed
Merge pull request #18078 from ROCmSoftwarePlatform:rocm-jax-triton
PiperOrigin-RevId: 574546618
2 parents 9435a0a + b4b97cd commit 88fe0da

File tree

9 files changed

+277
-86
lines changed

9 files changed

+277
-86
lines changed

jax/_src/pallas/triton/lowering.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1496,7 +1496,11 @@ def pallas_call_lowering(
14961496
**compiler_params
14971497
)
14981498
num_warps = compiler_params.get("num_warps", 4)
1499-
num_stages = compiler_params.get("num_stages", 3)
1499+
if ctx.module_context.platform == 'rocm':
1500+
num_stages = compiler_params.get("num_stages", 1)
1501+
else:
1502+
num_stages = compiler_params.get("num_stages", 3)
1503+
15001504
if debug:
15011505
print(jaxpr)
15021506
print(grid_mapping)
@@ -1509,6 +1513,9 @@ def pallas_call_lowering(
15091513
num_stages,
15101514
debug=debug,
15111515
)
1516+
#Triton returns a tuple for ROCm. We just want file path to be passed
1517+
if ctx.module_context.platform == 'rocm':
1518+
compilation_result.ptx = compilation_result.ptx[1]
15121519

15131520
if debug:
15141521
compilation_result.lowering_result.module.dump()
@@ -1562,4 +1569,4 @@ def pallas_call_lowering(
15621569
).results
15631570

15641571

1565-
mlir.register_lowering(pallas_call_p, pallas_call_lowering, platform="cuda")
1572+
mlir.register_lowering(pallas_call_p, pallas_call_lowering, platform="gpu")

jaxlib/cuda/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,7 @@ cc_library(
401401
"@xla//xla/service:custom_call_status",
402402
"@xla//xla/stream_executor/gpu:asm_compiler",
403403
"@tsl//tsl/cuda:cudart",
404+
"@tsl//tsl/platform:env",
404405
"@com_google_absl//absl/base:core_headers",
405406
"@com_google_absl//absl/cleanup",
406407
"@com_google_absl//absl/container:flat_hash_map",

jaxlib/gpu/triton.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
#include "jaxlib/gpu/vendor.h"
2121
#include "jaxlib/kernel_nanobind_helpers.h"
2222

23-
#define CUDA_RETURN_IF_ERROR(expr) JAX_RETURN_IF_ERROR(JAX_AS_STATUS(expr))
23+
#define GPU_RETURN_IF_ERROR(expr) JAX_RETURN_IF_ERROR(JAX_AS_STATUS(expr))
2424

2525
namespace nb = nanobind;
2626

@@ -124,11 +124,11 @@ NB_MODULE(_triton, m) {
124124
m.def("get_compute_capability",
125125
ValueOrThrowWrapper([](int device) -> absl::StatusOr<int> {
126126
int major, minor;
127-
CUDA_RETURN_IF_ERROR(cuInit(device));
128-
CUDA_RETURN_IF_ERROR(cuDeviceGetAttribute(
129-
&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device));
130-
CUDA_RETURN_IF_ERROR(cuDeviceGetAttribute(
131-
&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device));
127+
GPU_RETURN_IF_ERROR(gpuInit(device));
128+
GPU_RETURN_IF_ERROR(gpuDeviceGetAttribute(
129+
&major, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device));
130+
GPU_RETURN_IF_ERROR(gpuDeviceGetAttribute(
131+
&minor, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device));
132132
return major * 10 + minor;
133133
}));
134134

jaxlib/gpu/triton_kernels.cc

Lines changed: 82 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,22 @@
2525
#include "jaxlib/gpu/vendor.h"
2626
#include "xla/service/custom_call_status.h"
2727
#include "xla/stream_executor/gpu/asm_compiler.h"
28+
#include "tsl/platform/env.h"
2829

29-
#define CUDA_RETURN_IF_ERROR(expr) JAX_RETURN_IF_ERROR(JAX_AS_STATUS(expr))
30+
#define GPU_RETURN_IF_ERROR(expr) JAX_RETURN_IF_ERROR(JAX_AS_STATUS(expr))
3031

3132

3233
namespace jax::JAX_GPU_NAMESPACE {
3334
namespace {
3435

35-
constexpr uint32_t kNumThreadsPerWarp = 32;
3636
constexpr float kBenchmarkTimeMillis = 10.;
3737

38-
struct CuModuleDeleter {
39-
void operator()(CUmodule module) { cuModuleUnload(module); }
38+
struct gpuModuleDeleter {
39+
void operator()(gpuModule_t module) { gpuModuleUnload(module); }
4040
};
4141

42-
using OwnedCUmodule =
43-
std::unique_ptr<std::remove_pointer_t<CUmodule>, CuModuleDeleter>;
42+
using OwnedGPUmodule =
43+
std::unique_ptr<std::remove_pointer_t<gpuModule_t>, gpuModuleDeleter>;
4444

4545
absl::StatusOr<ModuleImage*> GetModuleImage(std::string kernel_name,
4646
uint32_t shared_mem_bytes,
@@ -58,13 +58,21 @@ absl::StatusOr<ModuleImage*> GetModuleImage(std::string kernel_name,
5858
auto it = module_images.find(key);
5959
if (it != module_images.end()) return it->second.get();
6060

61+
#ifdef JAX_GPU_HIP //For HIP/ROCM just read the hsaco file
62+
std::string result_blob;
63+
std::string fname{ptx};
64+
TF_RETURN_IF_ERROR(
65+
tsl::ReadFileToString(tsl::Env::Default(), fname, &result_blob));
66+
std::vector<uint8_t> module_image(result_blob.begin(), result_blob.end());
67+
#else
6168
// TODO(cjfj): Support `TRITON_PTXAS_PATH` environment variable?
6269
int cc_major = compute_capability / 10;
6370
int cc_minor = compute_capability % 10;
6471
JAX_ASSIGN_OR_RETURN(
6572
std::vector<uint8_t> module_image,
6673
stream_executor::CompileGpuAsm(cc_major, cc_minor, ptx.data(),
6774
stream_executor::GpuAsmOpts{}));
75+
#endif
6876

6977
auto [it2, success] = module_images.insert(
7078
{std::move(key),
@@ -74,27 +82,27 @@ absl::StatusOr<ModuleImage*> GetModuleImage(std::string kernel_name,
7482
return it2->second.get();
7583
}
7684

77-
absl::StatusOr<float> Benchmark(CUstream stream, KernelCall& kernel_call,
85+
absl::StatusOr<float> Benchmark(gpuStream_t stream, KernelCall& kernel_call,
7886
void** buffers, int num_iterations) {
79-
CUevent start, stop;
80-
CUDA_RETURN_IF_ERROR(cuEventCreate(&start, /*Flags=*/CU_EVENT_DEFAULT));
81-
CUDA_RETURN_IF_ERROR(cuEventCreate(&stop, /*Flags=*/CU_EVENT_DEFAULT));
87+
gpuEvent_t start, stop;
88+
GPU_RETURN_IF_ERROR(gpuEventCreate(&start, /*Flags=*/GPU_EVENT_DEFAULT));
89+
GPU_RETURN_IF_ERROR(gpuEventCreate(&stop, /*Flags=*/GPU_EVENT_DEFAULT));
8290
JAX_RETURN_IF_ERROR(kernel_call.Launch(stream, buffers)); // Warm-up.
83-
CUDA_RETURN_IF_ERROR(cuEventRecord(start, stream));
91+
GPU_RETURN_IF_ERROR(gpuEventRecord(start, stream));
8492
for (int i = 0; i < num_iterations; ++i) {
8593
JAX_RETURN_IF_ERROR(kernel_call.Launch(stream, buffers));
8694
}
87-
CUDA_RETURN_IF_ERROR(cuEventRecord(stop, stream));
88-
CUDA_RETURN_IF_ERROR(cuEventSynchronize(stop));
95+
GPU_RETURN_IF_ERROR(gpuEventRecord(stop, stream));
96+
GPU_RETURN_IF_ERROR(gpuEventSynchronize(stop));
8997
float elapsed_ms;
90-
CUDA_RETURN_IF_ERROR(cuEventElapsedTime(&elapsed_ms, start, stop));
91-
CUDA_RETURN_IF_ERROR(cuEventDestroy(start));
92-
CUDA_RETURN_IF_ERROR(cuEventDestroy(stop));
98+
GPU_RETURN_IF_ERROR(gpuEventElapsedTime(&elapsed_ms, start, stop));
99+
GPU_RETURN_IF_ERROR(gpuEventDestroy(start));
100+
GPU_RETURN_IF_ERROR(gpuEventDestroy(stop));
93101
return elapsed_ms;
94102
}
95103

96104
absl::StatusOr<KernelCall*> GetKernelCall(absl::string_view opaque,
97-
CUstream stream, void** buffers) {
105+
gpuStream_t stream, void** buffers) {
98106
static absl::Mutex mutex;
99107
static auto& kernel_calls =
100108
*new absl::flat_hash_map<std::string, std::unique_ptr<KernelCall>>
@@ -147,23 +155,23 @@ class ModuleImage {
147155
module_image_(std::move(module_image)),
148156
shared_mem_bytes_(shared_mem_bytes) {}
149157

150-
absl::StatusOr<CUfunction> GetFunctionForContext(CUcontext context) {
158+
absl::StatusOr<gpuFunction_t> GetFunctionForContext(gpuContext_t context) {
151159
absl::MutexLock lock(&mutex_);
152160
auto it = functions_.find(context);
153161
if (ABSL_PREDICT_TRUE(it != functions_.end())) {
154162
return it->second;
155163
}
156164

157-
CUDA_RETURN_IF_ERROR(cuCtxPushCurrent(context));
158-
absl::Cleanup ctx_restorer = [] { cuCtxPopCurrent(nullptr); };
165+
GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context));
166+
absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); };
159167

160-
CUmodule module;
161-
CUDA_RETURN_IF_ERROR(cuModuleLoadData(&module, module_image_.data()));
162-
modules_.push_back(OwnedCUmodule(module, CuModuleDeleter()));
168+
gpuModule_t module;
169+
GPU_RETURN_IF_ERROR(gpuModuleLoadData(&module, module_image_.data()));
170+
modules_.push_back(OwnedGPUmodule(module, gpuModuleDeleter()));
163171

164-
CUfunction function;
165-
CUDA_RETURN_IF_ERROR(
166-
cuModuleGetFunction(&function, module, kernel_name_.c_str()));
172+
gpuFunction_t function;
173+
GPU_RETURN_IF_ERROR(
174+
gpuModuleGetFunction(&function, module, kernel_name_.c_str()));
167175
auto [_, success] = functions_.insert({context, function});
168176
CHECK(success);
169177

@@ -175,12 +183,12 @@ class ModuleImage {
175183
}
176184

177185
// Set up dynamic shared memory.
178-
CUdevice device;
179-
CUDA_RETURN_IF_ERROR(cuCtxGetDevice(&device));
186+
gpuDevice_t device;
187+
GPU_RETURN_IF_ERROR(gpuCtxGetDevice(&device));
180188

181189
int shared_optin;
182-
CUDA_RETURN_IF_ERROR(cuDeviceGetAttribute(
183-
&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
190+
GPU_RETURN_IF_ERROR(gpuDeviceGetAttribute(
191+
&shared_optin, GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
184192
device));
185193

186194
if (shared_mem_bytes_ > shared_optin) {
@@ -190,18 +198,22 @@ class ModuleImage {
190198
}
191199

192200
if (shared_optin > kMaxStaticSharedMemBytes) {
193-
CUDA_RETURN_IF_ERROR(
194-
cuFuncSetCacheConfig(function, CU_FUNC_CACHE_PREFER_SHARED));
201+
#ifdef JAX_GPU_CUDA
202+
GPU_RETURN_IF_ERROR(
203+
gpuFuncSetCacheConfig(function, CU_FUNC_CACHE_PREFER_SHARED));
204+
#endif
195205
int shared_total;
196-
CUDA_RETURN_IF_ERROR(cuDeviceGetAttribute(
206+
GPU_RETURN_IF_ERROR(gpuDeviceGetAttribute(
197207
&shared_total,
198-
CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device));
208+
GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device));
199209
int shared_static;
200-
CUDA_RETURN_IF_ERROR(cuFuncGetAttribute(
201-
&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, function));
202-
CUDA_RETURN_IF_ERROR(cuFuncSetAttribute(
203-
function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
210+
GPU_RETURN_IF_ERROR(gpuFuncGetAttribute(
211+
&shared_static, GPU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, function));
212+
#ifdef JAX_GPU_CUDA
213+
GPU_RETURN_IF_ERROR(cuFuncSetAttribute(
214+
function, GPU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
204215
shared_optin - shared_static));
216+
#endif
205217
}
206218
return function;
207219
}
@@ -212,8 +224,8 @@ class ModuleImage {
212224
uint32_t shared_mem_bytes_;
213225

214226
absl::Mutex mutex_;
215-
std::vector<OwnedCUmodule> modules_ ABSL_GUARDED_BY(mutex_);
216-
absl::flat_hash_map<CUcontext, CUfunction> functions_ ABSL_GUARDED_BY(mutex_);
227+
std::vector<OwnedGPUmodule> modules_ ABSL_GUARDED_BY(mutex_);
228+
absl::flat_hash_map<gpuContext_t, gpuFunction_t> functions_ ABSL_GUARDED_BY(mutex_);
217229
};
218230

219231
Kernel::Kernel(std::string kernel_name, uint32_t num_warps,
@@ -226,18 +238,25 @@ Kernel::Kernel(std::string kernel_name, uint32_t num_warps,
226238
ttir_(std::move(ttir)),
227239
compute_capability_(compute_capability) {}
228240

229-
absl::Status Kernel::Launch(CUstream stream, uint32_t grid[3], void** params) {
241+
absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3], void** params) {
230242
if (ABSL_PREDICT_FALSE(module_image_ == nullptr)) {
231243
JAX_ASSIGN_OR_RETURN(module_image_,
232244
GetModuleImage(kernel_name_, shared_mem_bytes_, ptx_,
233245
compute_capability_));
234246
}
235247

236-
CUcontext context;
237-
CUDA_RETURN_IF_ERROR(cuStreamGetCtx(stream, &context));
238-
JAX_ASSIGN_OR_RETURN(CUfunction kernel,
248+
gpuContext_t context;
249+
#ifdef JAX_GPU_HIP
250+
int device_id = gpuGetStreamDeviceId(stream);
251+
gpuDevice_t device;
252+
GPU_RETURN_IF_ERROR(gpuDeviceGet(&device, device_id));
253+
GPU_RETURN_IF_ERROR(gpuDevicePrimaryCtxRetain(&context, device));
254+
#else //JAX_GPU_CUDA
255+
GPU_RETURN_IF_ERROR(gpuStreamGetCtx(stream, &context));
256+
#endif
257+
JAX_ASSIGN_OR_RETURN(gpuFunction_t kernel,
239258
module_image_->GetFunctionForContext(context));
240-
return JAX_AS_STATUS(cuLaunchKernel(
259+
return JAX_AS_STATUS(gpuLaunchKernel(
241260
kernel, grid[0], grid[1], grid[2], block_dim_x_,
242261
/*blockDimY=*/1, /*blockDimZ=*/1, shared_mem_bytes_, stream, params,
243262
/*extra=*/nullptr));
@@ -329,26 +348,26 @@ KernelCall::KernelCall(Kernel kernel, uint32_t grid_0, uint32_t grid_1,
329348
grid_{grid_0, grid_1, grid_2},
330349
parameters_(std::move(parameters)) {}
331350

332-
absl::Status KernelCall::Launch(CUstream stream, void** buffers) {
351+
absl::Status KernelCall::Launch(gpuStream_t stream, void** buffers) {
333352
std::vector<void*> params;
334353
params.reserve(parameters_.size());
335354
for (size_t i = 0; i < parameters_.size(); ++i) {
336355
const Parameter& param = parameters_[i];
337356
if (std::holds_alternative<Parameter::Array>(param.value)) {
338357
const auto& array = std::get<Parameter::Array>(param.value);
339358
void*& ptr = *(buffers++);
340-
auto cu_ptr = reinterpret_cast<CUdeviceptr>(ptr);
359+
auto cu_ptr = reinterpret_cast<gpuDevicePtr_t>(ptr);
341360

342361
if (ABSL_PREDICT_FALSE((array.ptr_divisibility != 0) &&
343-
(cu_ptr % array.ptr_divisibility != 0))) {
362+
((size_t)cu_ptr % array.ptr_divisibility != 0))) {
344363
return absl::InvalidArgumentError(
345-
absl::StrFormat("Parameter %zu (%p) is not divisible by %d.", i,
346-
ptr, array.ptr_divisibility));
364+
absl::StrFormat("Parameter %zu (%zu) is not divisible by %d.", i,
365+
(size_t)ptr, array.ptr_divisibility));
347366
}
348367

349368
if (array.bytes_to_zero > 0) {
350-
CUDA_RETURN_IF_ERROR(
351-
cuMemsetD8Async(cu_ptr, 0, array.bytes_to_zero, stream));
369+
GPU_RETURN_IF_ERROR(
370+
gpuMemsetD8Async(cu_ptr, 0, array.bytes_to_zero, stream));
352371
}
353372
params.push_back(&ptr);
354373
} else {
@@ -433,12 +452,12 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
433452
}
434453

435454
/*static*/ absl::StatusOr<KernelCall> AutotunedKernelCall::Autotune(
436-
AutotunedKernelCall kernel_call, CUstream stream, void** buffers) {
455+
AutotunedKernelCall kernel_call, gpuStream_t stream, void** buffers) {
437456
// Ensure a valid context for driver calls that don't take the stream.
438-
CUcontext context;
439-
CUDA_RETURN_IF_ERROR(cuStreamGetCtx(stream, &context));
440-
CUDA_RETURN_IF_ERROR(cuCtxPushCurrent(context));
441-
absl::Cleanup ctx_restorer = [] { cuCtxPopCurrent(nullptr); };
457+
//gpuContext_t context;
458+
//GPU_RETURN_IF_ERROR(gpuStreamGetCtx(stream, &context));
459+
//GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context));
460+
//absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); };
442461

443462
// If an input aliases with an output, it will get overwritten during the
444463
// kernel execution. If the kernel is called repeatedly, as we do during
@@ -448,8 +467,8 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
448467
for (auto [input_idx, output_idx, size] : kernel_call.input_output_aliases_) {
449468
if (buffers[input_idx] == buffers[output_idx]) {
450469
std::vector<uint8_t> input_copy(size);
451-
CUDA_RETURN_IF_ERROR(cuMemcpyDtoHAsync(
452-
input_copy.data(), reinterpret_cast<CUdeviceptr>(buffers[input_idx]),
470+
GPU_RETURN_IF_ERROR(gpuMemcpyDtoHAsync(
471+
input_copy.data(), reinterpret_cast<gpuDevicePtr_t>(buffers[input_idx]),
453472
size, stream));
454473
input_copies[input_idx] = std::move(input_copy);
455474
}
@@ -495,17 +514,17 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
495514

496515
// Restore aliased inputs to their original values.
497516
for (auto [input_idx, _, size] : kernel_call.input_output_aliases_) {
498-
CUDA_RETURN_IF_ERROR(
499-
cuMemcpyHtoDAsync(reinterpret_cast<CUdeviceptr>(buffers[input_idx]),
517+
GPU_RETURN_IF_ERROR(
518+
gpuMemcpyHtoDAsync(reinterpret_cast<gpuDevicePtr_t>(buffers[input_idx]),
500519
input_copies[input_idx].data(), size, stream));
501520
}
502521
// Synchronize stream to ensure copies are complete before the host copy
503522
// is deleted.
504-
CUDA_RETURN_IF_ERROR(cuStreamSynchronize(stream));
523+
GPU_RETURN_IF_ERROR(gpuStreamSynchronize(stream));
505524
return std::move(kernel_call.configs_[0].kernel_call);
506525
}
507526

508-
void TritonKernelCall(CUstream stream, void** buffers, const char* opaque,
527+
void TritonKernelCall(gpuStream_t stream, void** buffers, const char* opaque,
509528
size_t opaque_len, XlaCustomCallStatus* status) {
510529
absl::Status result = [=] {
511530
JAX_ASSIGN_OR_RETURN(

jaxlib/gpu/triton_kernels.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
namespace jax::JAX_GPU_NAMESPACE {
1818

19-
void TritonKernelCall(CUstream stream, void** buffers, const char* opaque,
19+
void TritonKernelCall(gpuStream_t stream, void** buffers, const char* opaque,
2020
size_t opaque_len, XlaCustomCallStatus* status);
2121

2222
class ModuleImage;
@@ -26,7 +26,7 @@ class Kernel {
2626
Kernel(std::string kernel_name, uint32_t num_warps, uint32_t shared_mem_bytes,
2727
std::string ptx, std::string ttir, int compute_capability);
2828

29-
absl::Status Launch(CUstream stream, uint32_t grid[3], void** params);
29+
absl::Status Launch(gpuStream_t stream, uint32_t grid[3], void** params);
3030

3131
static Kernel FromProto(const jax_triton::TritonKernel& proto);
3232
jax_triton::TritonKernel ToProto() const;
@@ -62,7 +62,7 @@ class KernelCall {
6262
KernelCall(Kernel kernel, uint32_t grid_0, uint32_t grid_1, uint32_t grid_2,
6363
std::vector<Parameter> parameters);
6464

65-
absl::Status Launch(CUstream stream, void** buffers);
65+
absl::Status Launch(gpuStream_t stream, void** buffers);
6666

6767
static absl::StatusOr<KernelCall> FromProto(
6868
const jax_triton::TritonKernelCall& proto);
@@ -86,7 +86,7 @@ class AutotunedKernelCall {
8686
std::vector<std::tuple<size_t, size_t, size_t>> input_output_aliases);
8787

8888
static absl::StatusOr<KernelCall> Autotune(AutotunedKernelCall kernel_call,
89-
CUstream stream, void** buffers);
89+
gpuStream_t stream, void** buffers);
9090

9191
static absl::StatusOr<AutotunedKernelCall> FromProto(
9292
const jax_triton::TritonAutotunedKernelCall& proto);

0 commit comments

Comments
 (0)