2525#include " jaxlib/gpu/vendor.h"
2626#include " xla/service/custom_call_status.h"
2727#include " xla/stream_executor/gpu/asm_compiler.h"
28+ #include " tsl/platform/env.h"
2829
// Wraps a GPU driver/runtime call: the call's result is converted with
// JAX_AS_STATUS and handed to JAX_RETURN_IF_ERROR (both defined elsewhere;
// presumably this returns early from the enclosing function on a non-OK
// status — confirm against their definitions).
// The `-` line is the previous CUDA-specific name, the `+` line its
// vendor-neutral replacement; the expansion itself is unchanged.
29- #define CUDA_RETURN_IF_ERROR (expr ) JAX_RETURN_IF_ERROR(JAX_AS_STATUS(expr))
30+ #define GPU_RETURN_IF_ERROR (expr ) JAX_RETURN_IF_ERROR(JAX_AS_STATUS(expr))
3031
3132
3233namespace jax ::JAX_GPU_NAMESPACE {
3334namespace {
3435
// File-local tuning constants. The `-` line drops the hard-coded warp size
// of 32 from this file; kBenchmarkTimeMillis is kept unchanged.
// NOTE(review): kBenchmarkTimeMillis is not used in the visible hunks —
// presumably a time budget in milliseconds for autotuning; verify at its use.
35- constexpr uint32_t kNumThreadsPerWarp = 32 ;
3636constexpr float kBenchmarkTimeMillis = 10 .;
3737
// Deleter functor for a loaded GPU module handle: calling it unloads the
// module (cuModuleUnload in the `-` lines, gpuModuleUnload in the `+` lines).
// The status returned by the unload call is discarded.
38- struct CuModuleDeleter {
39- void operator ()(CUmodule module ) { cuModuleUnload (module ); }
38+ struct gpuModuleDeleter {
39+ void operator ()(gpuModule_t module ) { gpuModuleUnload (module ); }
4040};
4141
// Owning handle for a loaded module: a std::unique_ptr over the pointee type
// of the module handle, with the module deleter as its custom deleter, so the
// module is unloaded when the handle is destroyed.
// std::remove_pointer_t is applied because the handle type is used as the
// unique_ptr's pointer type (assumes the handle is a raw pointer typedef —
// TODO confirm for gpuModule_t on every backend).
42- using OwnedCUmodule =
43- std::unique_ptr<std::remove_pointer_t <CUmodule >, CuModuleDeleter >;
42+ using OwnedGPUmodule =
43+ std::unique_ptr<std::remove_pointer_t <gpuModule_t >, gpuModuleDeleter >;
4444
4545absl::StatusOr<ModuleImage*> GetModuleImage (std::string kernel_name,
4646 uint32_t shared_mem_bytes,
@@ -58,13 +58,21 @@ absl::StatusOr<ModuleImage*> GetModuleImage(std::string kernel_name,
5858 auto it = module_images.find (key);
5959 if (it != module_images.end ()) return it->second .get ();
6060
61+ #ifdef JAX_GPU_HIP // For HIP/ROCM just read the hsaco file
62+ std::string result_blob;
63+ std::string fname{ptx};
64+ TF_RETURN_IF_ERROR (
65+ tsl::ReadFileToString (tsl::Env::Default (), fname, &result_blob));
66+ std::vector<uint8_t > module_image (result_blob.begin (), result_blob.end ());
67+ #else
6168 // TODO(cjfj): Support `TRITON_PTXAS_PATH` environment variable?
6269 int cc_major = compute_capability / 10 ;
6370 int cc_minor = compute_capability % 10 ;
6471 JAX_ASSIGN_OR_RETURN (
6572 std::vector<uint8_t > module_image,
6673 stream_executor::CompileGpuAsm (cc_major, cc_minor, ptx.data (),
6774 stream_executor::GpuAsmOpts{}));
75+ #endif
6876
6977 auto [it2, success] = module_images.insert (
7078 {std::move (key),
@@ -74,27 +82,27 @@ absl::StatusOr<ModuleImage*> GetModuleImage(std::string kernel_name,
7482 return it2->second .get ();
7583}
7684
// Times `num_iterations` back-to-back launches of `kernel_call` on `stream`.
// (Diff view: `-` lines are the CUDA-only originals, `+` lines are their
// vendor-neutral gpu* replacements; the control flow is the same in both.)
//
// Sequence: create a start and a stop event, issue one untimed warm-up
// launch, record the start event, launch the kernel `num_iterations` times,
// record the stop event and wait on it, then query the elapsed time between
// the two events and destroy both.
//
// Args:
//   stream:         stream on which the launches and events are enqueued.
//   kernel_call:    the kernel call to launch.
//   buffers:        forwarded unchanged to kernel_call.Launch.
//   num_iterations: number of timed launches (the warm-up is not counted).
//
// Returns the value written by the elapsed-time query (milliseconds, per the
// variable name). This is the total over all timed launches, not a per-launch
// average. If any step fails, the error status is returned instead.
//
// NOTE(review): each *_RETURN_IF_ERROR / JAX_RETURN_IF_ERROR between event
// creation and the destroy calls at the bottom presumably returns early on
// failure; on those paths `start` and/or `stop` are never destroyed. Confirm,
// and consider a cleanup guard so the events are released on every path.
77- absl::StatusOr<float > Benchmark (CUstream stream, KernelCall& kernel_call,
85+ absl::StatusOr<float > Benchmark (gpuStream_t stream, KernelCall& kernel_call,
7886 void ** buffers, int num_iterations) {
79- CUevent start, stop;
80- CUDA_RETURN_IF_ERROR ( cuEventCreate (&start, /* Flags=*/ CU_EVENT_DEFAULT ));
81- CUDA_RETURN_IF_ERROR ( cuEventCreate (&stop, /* Flags=*/ CU_EVENT_DEFAULT ));
87+ gpuEvent_t start, stop;
88+ GPU_RETURN_IF_ERROR ( gpuEventCreate (&start, /* Flags=*/ GPU_EVENT_DEFAULT ));
89+ GPU_RETURN_IF_ERROR ( gpuEventCreate (&stop, /* Flags=*/ GPU_EVENT_DEFAULT ));
8290 JAX_RETURN_IF_ERROR (kernel_call.Launch (stream, buffers)); // Warm-up.
83- CUDA_RETURN_IF_ERROR ( cuEventRecord (start, stream));
91+ GPU_RETURN_IF_ERROR ( gpuEventRecord (start, stream));
8492 for (int i = 0 ; i < num_iterations; ++i) {
8593 JAX_RETURN_IF_ERROR (kernel_call.Launch (stream, buffers));
8694 }
87- CUDA_RETURN_IF_ERROR ( cuEventRecord (stop, stream));
88- CUDA_RETURN_IF_ERROR ( cuEventSynchronize (stop));
95+ GPU_RETURN_IF_ERROR ( gpuEventRecord (stop, stream));
96+ GPU_RETURN_IF_ERROR ( gpuEventSynchronize (stop));
8997 float elapsed_ms;
90- CUDA_RETURN_IF_ERROR ( cuEventElapsedTime (&elapsed_ms, start, stop));
91- CUDA_RETURN_IF_ERROR ( cuEventDestroy (start));
92- CUDA_RETURN_IF_ERROR ( cuEventDestroy (stop));
98+ GPU_RETURN_IF_ERROR ( gpuEventElapsedTime (&elapsed_ms, start, stop));
99+ GPU_RETURN_IF_ERROR ( gpuEventDestroy (start));
100+ GPU_RETURN_IF_ERROR ( gpuEventDestroy (stop));
93101 return elapsed_ms;
94102}
95103
96104absl::StatusOr<KernelCall*> GetKernelCall (absl::string_view opaque,
97- CUstream stream, void ** buffers) {
105+ gpuStream_t stream, void ** buffers) {
98106 static absl::Mutex mutex;
99107 static auto & kernel_calls =
100108 *new absl::flat_hash_map<std::string, std::unique_ptr<KernelCall>>
@@ -147,23 +155,23 @@ class ModuleImage {
147155 module_image_ (std::move(module_image)),
148156 shared_mem_bytes_(shared_mem_bytes) {}
149157
150- absl::StatusOr<CUfunction > GetFunctionForContext (CUcontext context) {
158+ absl::StatusOr<gpuFunction_t > GetFunctionForContext (gpuContext_t context) {
151159 absl::MutexLock lock (&mutex_);
152160 auto it = functions_.find (context);
153161 if (ABSL_PREDICT_TRUE (it != functions_.end ())) {
154162 return it->second ;
155163 }
156164
157- CUDA_RETURN_IF_ERROR ( cuCtxPushCurrent (context));
158- absl::Cleanup ctx_restorer = [] { cuCtxPopCurrent (nullptr ); };
165+ GPU_RETURN_IF_ERROR ( gpuCtxPushCurrent (context));
166+ absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent (nullptr ); };
159167
160- CUmodule module ;
161- CUDA_RETURN_IF_ERROR ( cuModuleLoadData (&module , module_image_.data ()));
162- modules_.push_back (OwnedCUmodule (module , CuModuleDeleter ()));
168+ gpuModule_t module ;
169+ GPU_RETURN_IF_ERROR ( gpuModuleLoadData (&module , module_image_.data ()));
170+ modules_.push_back (OwnedGPUmodule (module , gpuModuleDeleter ()));
163171
164- CUfunction function;
165- CUDA_RETURN_IF_ERROR (
166- cuModuleGetFunction (&function, module , kernel_name_.c_str ()));
172+ gpuFunction_t function;
173+ GPU_RETURN_IF_ERROR (
174+ gpuModuleGetFunction (&function, module , kernel_name_.c_str ()));
167175 auto [_, success] = functions_.insert ({context, function});
168176 CHECK (success);
169177
@@ -175,12 +183,12 @@ class ModuleImage {
175183 }
176184
177185 // Set up dynamic shared memory.
178- CUdevice device;
179- CUDA_RETURN_IF_ERROR ( cuCtxGetDevice (&device));
186+ gpuDevice_t device;
187+ GPU_RETURN_IF_ERROR ( gpuCtxGetDevice (&device));
180188
181189 int shared_optin;
182- CUDA_RETURN_IF_ERROR ( cuDeviceGetAttribute (
183- &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN ,
190+ GPU_RETURN_IF_ERROR ( gpuDeviceGetAttribute (
191+ &shared_optin, GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN ,
184192 device));
185193
186194 if (shared_mem_bytes_ > shared_optin) {
@@ -190,18 +198,22 @@ class ModuleImage {
190198 }
191199
192200 if (shared_optin > kMaxStaticSharedMemBytes ) {
193- CUDA_RETURN_IF_ERROR (
194- cuFuncSetCacheConfig (function, CU_FUNC_CACHE_PREFER_SHARED));
201+ #ifdef JAX_GPU_CUDA
202+ GPU_RETURN_IF_ERROR (
203+ gpuFuncSetCacheConfig (function, CU_FUNC_CACHE_PREFER_SHARED));
204+ #endif
195205 int shared_total;
196- CUDA_RETURN_IF_ERROR ( cuDeviceGetAttribute (
206+ GPU_RETURN_IF_ERROR ( gpuDeviceGetAttribute (
197207 &shared_total,
198- CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR , device));
208+ GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR , device));
199209 int shared_static;
200- CUDA_RETURN_IF_ERROR (cuFuncGetAttribute (
201- &shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, function));
202- CUDA_RETURN_IF_ERROR (cuFuncSetAttribute (
203- function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
210+ GPU_RETURN_IF_ERROR (gpuFuncGetAttribute (
211+ &shared_static, GPU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, function));
212+ #ifdef JAX_GPU_CUDA
213+ GPU_RETURN_IF_ERROR (cuFuncSetAttribute (
214+ function, GPU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
204215 shared_optin - shared_static));
216+ #endif
205217 }
206218 return function;
207219 }
@@ -212,8 +224,8 @@ class ModuleImage {
212224 uint32_t shared_mem_bytes_;
213225
214226 absl::Mutex mutex_;
215- std::vector<OwnedCUmodule > modules_ ABSL_GUARDED_BY (mutex_);
216- absl::flat_hash_map<CUcontext, CUfunction > functions_ ABSL_GUARDED_BY (mutex_);
227+ std::vector<OwnedGPUmodule > modules_ ABSL_GUARDED_BY (mutex_);
228+ absl::flat_hash_map<gpuContext_t, gpuFunction_t > functions_ ABSL_GUARDED_BY (mutex_);
217229};
218230
219231Kernel::Kernel (std::string kernel_name, uint32_t num_warps,
@@ -226,18 +238,25 @@ Kernel::Kernel(std::string kernel_name, uint32_t num_warps,
226238 ttir_(std::move(ttir)),
227239 compute_capability_(compute_capability) {}
228240
229- absl::Status Kernel::Launch (CUstream stream, uint32_t grid[3 ], void ** params) {
241+ absl::Status Kernel::Launch (gpuStream_t stream, uint32_t grid[3 ], void ** params) {
230242 if (ABSL_PREDICT_FALSE (module_image_ == nullptr )) {
231243 JAX_ASSIGN_OR_RETURN (module_image_,
232244 GetModuleImage (kernel_name_, shared_mem_bytes_, ptx_,
233245 compute_capability_));
234246 }
235247
236- CUcontext context;
237- CUDA_RETURN_IF_ERROR (cuStreamGetCtx (stream, &context));
238- JAX_ASSIGN_OR_RETURN (CUfunction kernel,
248+ gpuContext_t context;
249+ #ifdef JAX_GPU_HIP
250+ int device_id = gpuGetStreamDeviceId (stream);
251+ gpuDevice_t device;
252+ GPU_RETURN_IF_ERROR (gpuDeviceGet (&device, device_id));
253+ GPU_RETURN_IF_ERROR (gpuDevicePrimaryCtxRetain (&context, device));
254+ #else // JAX_GPU_CUDA
255+ GPU_RETURN_IF_ERROR (gpuStreamGetCtx (stream, &context));
256+ #endif
257+ JAX_ASSIGN_OR_RETURN (gpuFunction_t kernel,
239258 module_image_->GetFunctionForContext (context));
240- return JAX_AS_STATUS (cuLaunchKernel (
259+ return JAX_AS_STATUS (gpuLaunchKernel (
241260 kernel, grid[0 ], grid[1 ], grid[2 ], block_dim_x_,
242261 /* blockDimY=*/ 1 , /* blockDimZ=*/ 1 , shared_mem_bytes_, stream, params,
243262 /* extra=*/ nullptr ));
@@ -329,26 +348,26 @@ KernelCall::KernelCall(Kernel kernel, uint32_t grid_0, uint32_t grid_1,
329348 grid_{grid_0, grid_1, grid_2},
330349 parameters_ (std::move(parameters)) {}
331350
332- absl::Status KernelCall::Launch (CUstream stream, void ** buffers) {
351+ absl::Status KernelCall::Launch (gpuStream_t stream, void ** buffers) {
333352 std::vector<void *> params;
334353 params.reserve (parameters_.size ());
335354 for (size_t i = 0 ; i < parameters_.size (); ++i) {
336355 const Parameter& param = parameters_[i];
337356 if (std::holds_alternative<Parameter::Array>(param.value )) {
338357 const auto & array = std::get<Parameter::Array>(param.value );
339358 void *& ptr = *(buffers++);
340- auto cu_ptr = reinterpret_cast <CUdeviceptr >(ptr);
359+ auto cu_ptr = reinterpret_cast <gpuDevicePtr_t >(ptr);
341360
342361 if (ABSL_PREDICT_FALSE ((array.ptr_divisibility != 0 ) &&
343- (cu_ptr % array.ptr_divisibility != 0 ))) {
362+ (( size_t ) cu_ptr % array.ptr_divisibility != 0 ))) {
344363 return absl::InvalidArgumentError (
345- absl::StrFormat (" Parameter %zu (%p ) is not divisible by %d." , i,
346- ptr, array.ptr_divisibility ));
364+ absl::StrFormat (" Parameter %zu (%zu ) is not divisible by %d." , i,
365+ ( size_t ) ptr, array.ptr_divisibility ));
347366 }
348367
349368 if (array.bytes_to_zero > 0 ) {
350- CUDA_RETURN_IF_ERROR (
351- cuMemsetD8Async (cu_ptr, 0 , array.bytes_to_zero , stream));
369+ GPU_RETURN_IF_ERROR (
370+ gpuMemsetD8Async (cu_ptr, 0 , array.bytes_to_zero , stream));
352371 }
353372 params.push_back (&ptr);
354373 } else {
@@ -433,12 +452,12 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
433452}
434453
435454/* static*/ absl::StatusOr<KernelCall> AutotunedKernelCall::Autotune (
436- AutotunedKernelCall kernel_call, CUstream stream, void ** buffers) {
455+ AutotunedKernelCall kernel_call, gpuStream_t stream, void ** buffers) {
437456 // Ensure a valid context for driver calls that don't take the stream.
438- CUcontext context;
439- CUDA_RETURN_IF_ERROR ( cuStreamGetCtx (stream, &context));
440- CUDA_RETURN_IF_ERROR ( cuCtxPushCurrent (context));
441- absl::Cleanup ctx_restorer = [] { cuCtxPopCurrent (nullptr ); };
457+ // gpuContext_t context;
458+ // GPU_RETURN_IF_ERROR(gpuStreamGetCtx (stream, &context));
459+ // GPU_RETURN_IF_ERROR(gpuCtxPushCurrent (context));
460+ // absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent (nullptr); };
442461
443462 // If an input aliases with an output, it will get overwritten during the
444463 // kernel execution. If the kernel is called repeatedly, as we do during
@@ -448,8 +467,8 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
448467 for (auto [input_idx, output_idx, size] : kernel_call.input_output_aliases_ ) {
449468 if (buffers[input_idx] == buffers[output_idx]) {
450469 std::vector<uint8_t > input_copy (size);
451- CUDA_RETURN_IF_ERROR ( cuMemcpyDtoHAsync (
452- input_copy.data (), reinterpret_cast <CUdeviceptr >(buffers[input_idx]),
470+ GPU_RETURN_IF_ERROR ( gpuMemcpyDtoHAsync (
471+ input_copy.data (), reinterpret_cast <gpuDevicePtr_t >(buffers[input_idx]),
453472 size, stream));
454473 input_copies[input_idx] = std::move (input_copy);
455474 }
@@ -495,17 +514,17 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
495514
496515 // Restore aliased inputs to their original values.
497516 for (auto [input_idx, _, size] : kernel_call.input_output_aliases_ ) {
498- CUDA_RETURN_IF_ERROR (
499- cuMemcpyHtoDAsync (reinterpret_cast <CUdeviceptr >(buffers[input_idx]),
517+ GPU_RETURN_IF_ERROR (
518+ gpuMemcpyHtoDAsync (reinterpret_cast <gpuDevicePtr_t >(buffers[input_idx]),
500519 input_copies[input_idx].data (), size, stream));
501520 }
502521 // Synchronize stream to ensure copies are complete before the host copy
503522 // is deleted.
504- CUDA_RETURN_IF_ERROR ( cuStreamSynchronize (stream));
523+ GPU_RETURN_IF_ERROR ( gpuStreamSynchronize (stream));
505524 return std::move (kernel_call.configs_ [0 ].kernel_call );
506525}
507526
508- void TritonKernelCall (CUstream stream, void ** buffers, const char * opaque,
527+ void TritonKernelCall (gpuStream_t stream, void ** buffers, const char * opaque,
509528 size_t opaque_len, XlaCustomCallStatus* status) {
510529 absl::Status result = [=] {
511530 JAX_ASSIGN_OR_RETURN (
0 commit comments