Skip to content

Commit dddfa05

Browse files
committed
Fix missing defaults for GPU target arch'es (#293)
Move defaults for these up to ci_build python We also add some error catching logic in the python script to clean up weirdness coming from the build scripts. Also clean up some formatting and typos (cherry picked from commit ef1d561) (cherry picked from commit ba2e112)
1 parent 0074fed commit dddfa05

File tree

4 files changed

+144
-6
lines changed

4 files changed

+144
-6
lines changed

build/rocm/ci_build

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ import argparse
2424
import os
2525
import subprocess
2626
import sys
27+
from typing import List
28+
29+
30+
DEFAULT_GPU_DEVICE_TARGETS = "gfx906,gfx908,gfx90a,gfx942,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201"
2731

2832

2933
def image_by_name(name):
@@ -40,7 +44,11 @@ def dist_wheels(
4044
rocm_build_job="",
4145
rocm_build_num="",
4246
compiler="gcc",
47+
gpu_device_targets : List[str] = None,
4348
):
49+
if not gpu_device_targets:
50+
gpu_device_targets = DEFAULT_GPU_DEVICE_TARGETS.split(",")
51+
4452
if xla_path:
4553
xla_path = os.path.abspath(xla_path)
4654

@@ -63,6 +71,7 @@ def dist_wheels(
6371
"--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
6472
"--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
6573
"--tag=%s" % image,
74+
"--build-arg=GPU_DEVICE_TARGETS=%s" % " ".join(gpu_device_targets),
6675
".",
6776
]
6877

@@ -85,6 +94,8 @@ def dist_wheels(
8594
pyver_string,
8695
"--compiler",
8796
compiler,
97+
"--gpu-device-targets",
98+
",".join(gpu_device_targets),
8899
]
89100

90101
if xla_path:
@@ -121,6 +132,7 @@ def dist_wheels(
121132
]
122133
)
123134

135+
LOG.info("Running: %s", cmd)
124136
_ = subprocess.run(cmd, check=True)
125137

126138

@@ -158,10 +170,14 @@ def dist_docker(
158170
tag="rocm/jax-dev",
159171
dockerfile=None,
160172
keep_image=True,
173+
gpu_device_targets : List[str] = None,
161174
):
162175
if not dockerfile:
163176
dockerfile = "build/rocm/Dockerfile.ms"
164177

178+
if not gpu_device_targets:
179+
gpu_device_targets = DEFAULT_GPU_DEVICE_TARGETS.split(",")
180+
165181
python_version = python_versions[0]
166182

167183
md = _fetch_jax_metadata(xla_path)
@@ -174,6 +190,7 @@ def dist_docker(
174190
"--target",
175191
"rt_build",
176192
"--build-arg=ROCM_VERSION=%s" % rocm_version,
193+
"--build-arg=GPU_DEVICE_TARGETS=%s" % " ".join(gpu_device_targets),
177194
"--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
178195
"--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
179196
"--build-arg=BASE_DOCKER=%s" % base_docker,
@@ -238,6 +255,37 @@ def test(image_name):
238255
subprocess.check_call(cmd)
239256

240257

258+
def parse_gpu_targets(targets_string):
259+
# catch case where targets_string was empty.
260+
# None should already be caught by argparse, but
261+
# it doesn't hurt to check twice
262+
if not targets_string:
263+
targets_string = DEFAULT_GPU_DEVICE_TARGETS
264+
265+
if "," in targets_string:
266+
targets = targets_string.split(",")
267+
elif " " in targets_string:
268+
targets = targets_string.split(" ")
269+
else:
270+
targets = targets_string
271+
272+
res = []
273+
# cleanup and validation
274+
for t in targets:
275+
if not t:
276+
continue
277+
278+
if not t.startswith("gfx"):
279+
raise ValueError("Invalid GPU architecture target: %r" % t)
280+
281+
res.append(t.strip())
282+
283+
if not res:
284+
raise ValueError("GPU_DEVICE_TARGETS cannot be empty")
285+
286+
return res
287+
288+
241289
def parse_args():
242290
p = argparse.ArgumentParser()
243291
p.add_argument(
@@ -249,7 +297,7 @@ def parse_args():
249297
p.add_argument(
250298
"--python-versions",
251299
type=lambda x: x.split(","),
252-
default="3.12",
300+
default=["3.12"],
253301
help="Comma separated list of CPython versions to build wheels for",
254302
)
255303

@@ -281,6 +329,11 @@ def parse_args():
281329
choices=["gcc", "clang"],
282330
help="Compiler backend to use when compiling jax/jaxlib",
283331
)
332+
p.add_argument(
333+
"--gpu-device-targets",
334+
default=DEFAULT_GPU_DEVICE_TARGETS,
335+
help="List of AMDGPU device targets passed from job",
336+
)
284337

285338
subp = p.add_subparsers(dest="action", required=True)
286339

@@ -299,6 +352,7 @@ def parse_args():
299352

300353
def main():
301354
args = parse_args()
355+
gpu_device_targets = parse_gpu_targets(args.gpu_device_targets)
302356

303357
if args.action == "dist_wheels":
304358
dist_wheels(
@@ -308,6 +362,7 @@ def main():
308362
args.rocm_build_job,
309363
args.rocm_build_num,
310364
compiler=args.compiler,
365+
gpu_device_targets=gpu_device_targets,
311366
)
312367

313368
elif args.action == "test":
@@ -321,6 +376,7 @@ def main():
321376
args.rocm_build_job,
322377
args.rocm_build_num,
323378
compiler=args.compiler,
379+
gpu_device_targets=gpu_device_targets,
324380
)
325381
dist_docker(
326382
args.rocm_version,
@@ -332,6 +388,7 @@ def main():
332388
tag=args.image_tag,
333389
dockerfile=args.dockerfile,
334390
keep_image=args.keep_image,
391+
gpu_device_targets=gpu_device_targets,
335392
)
336393

337394

build/rocm/ci_build.sh

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ ROCM_BUILD_NUM=""
5151
BASE_DOCKER="ubuntu:22.04"
5252
CUSTOM_INSTALL=""
5353
JAX_USE_CLANG=""
54+
GPU_DEVICE_TARGETS=""
5455
POSITIONAL_ARGS=()
5556

5657
RUNTIME_FLAG=0
@@ -98,6 +99,18 @@ while [[ $# -gt 0 ]]; do
9899
JAX_USE_CLANG="$2"
99100
shift 2
100101
;;
102+
--gpu_device_targets)
103+
if [[ "$2" == "--custom_install" ]]; then
104+
GPU_DEVICE_TARGETS=""
105+
shift 2
106+
elif [[ -n "$2" ]]; then
107+
GPU_DEVICE_TARGETS="$2"
108+
shift 2
109+
else
110+
GPU_DEVICE_TARGETS=""
111+
shift 1
112+
fi
113+
;;
101114
*)
102115
POSITIONAL_ARGS+=("$1")
103116
shift
@@ -164,6 +177,7 @@ fi
164177
--rocm-build-job=$ROCM_BUILD_JOB \
165178
--rocm-build-num=$ROCM_BUILD_NUM \
166179
--compiler=$JAX_COMPILER \
180+
--gpu-device-targets="${GPU_DEVICE_TARGETS}" \
167181
dist_docker \
168182
--dockerfile $DOCKERFILE_PATH \
169183
--image-tag $DOCKER_IMG_NAME

build/rocm/test_ci_build.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright 2024 The JAX Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import unittest
18+
19+
import importlib.util
20+
import importlib.machinery
21+
22+
23+
def load_ci_build():
24+
spec = importlib.util.spec_from_loader(
25+
"ci_build", importlib.machinery.SourceFileLoader("ci_build", "./ci_build")
26+
)
27+
mod = importlib.util.module_from_spec(spec)
28+
spec.loader.exec_module(mod)
29+
return mod
30+
31+
32+
ci_build = load_ci_build()
33+
34+
35+
class CIBuildTestCase(unittest.TestCase):
36+
def test_parse_gpu_targets(self):
37+
targets = ["gfx908", "gfx940", "gfx1201"]
38+
39+
r = ci_build.parse_gpu_targets(" ".join(targets))
40+
self.assertEqual(r, targets)
41+
42+
r = ci_build.parse_gpu_targets(",".join(targets))
43+
self.assertEqual(r, targets)
44+
45+
def test_parse_gpu_targets_empty_string(self):
46+
expected = ci_build.DEFAULT_GPU_DEVICE_TARGETS.split(",")
47+
r = ci_build.parse_gpu_targets("")
48+
self.assertEqual(r, expected)
49+
50+
self.assertRaises(ValueError, ci_build.parse_gpu_targets, " ")
51+
52+
def test_parse_gpu_targets_invalid_arch(self):
53+
targets = ["gfx908", "gfx940", "--oops", "/jax"]
54+
self.assertRaises(ValueError, ci_build.parse_gpu_targets, " ".join(targets))
55+
56+
57+
if __name__ == "__main__":
58+
unittest.main()

build/rocm/tools/build_wheels.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,15 @@
3030
import subprocess
3131
import shutil
3232
import sys
33+
from typing import List
3334

3435

3536
LOG = logging.getLogger(__name__)
3637

3738

38-
GPU_DEVICE_TARGETS = "gfx906 gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"
39+
DEFAULT_GPU_DEVICE_TARGETS = (
40+
"gfx906,gfx908,gfx90a,gfx942,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201"
41+
)
3942

4043

4144
def build_rocm_path(rocm_version_str):
@@ -46,11 +49,11 @@ def build_rocm_path(rocm_version_str):
4649
return os.path.realpath("/opt/rocm")
4750

4851

49-
def update_rocm_targets(rocm_path, targets):
52+
def update_rocm_targets(rocm_path: str, targets: List[str]):
5053
target_fp = os.path.join(rocm_path, "bin/target.lst")
5154
version_fp = os.path.join(rocm_path, ".info/version")
5255
with open(target_fp, "w") as fd:
53-
fd.write("%s\n" % targets)
56+
fd.write("%s\n" % " ".join(targets))
5457

5558
# mimic touch
5659
open(version_fp, "a").close()
@@ -250,7 +253,7 @@ def parse_args():
250253
)
251254
p.add_argument(
252255
"--python-versions",
253-
default=["3.10.19,3.12"],
256+
default="3.10.19,3.12",
254257
help="Comma separated CPython versions that wheels will be built and output for",
255258
)
256259
p.add_argument(
@@ -265,6 +268,11 @@ def parse_args():
265268
default="gcc",
266269
help="Compiler backend to use when compiling jax/jaxlib",
267270
)
271+
p.add_argument(
272+
"--gpu-device-targets",
273+
default=DEFAULT_GPU_DEVICE_TARGETS,
274+
help="Comma separated list of GPU device targets passed from job",
275+
)
268276

269277
p.add_argument("jax_path", help="Directory where JAX source directory is located")
270278

@@ -285,6 +293,7 @@ def find_wheels(path):
285293
def main():
286294
args = parse_args()
287295
python_versions = args.python_versions.split(",")
296+
gpu_device_targets = args.gpu_device_targets.split(",")
288297

289298
print("ROCM_VERSION=%s" % args.rocm_version)
290299
print("PYTHON_VERSIONS=%r" % python_versions)
@@ -294,7 +303,7 @@ def main():
294303

295304
rocm_path = build_rocm_path(args.rocm_version)
296305

297-
update_rocm_targets(rocm_path, GPU_DEVICE_TARGETS)
306+
update_rocm_targets(rocm_path, gpu_device_targets)
298307

299308
for py in python_versions:
300309
build_jaxlib_wheel(args.jax_path, rocm_path, py, args.xla_path, args.compiler)

0 commit comments

Comments
 (0)