Skip to content

Commit 4d494f1

Browse files
authored
Upgrade TPU image to Python 3.11.Tpupy311 (#1493)
Python 3.10 is entering its [last year of support](https://devguide.python.org/versions/) before end-of-life and many packages, including [NumPy](https://devguide.python.org/versions/), have dropped support for it altogether. Included in this change: * Upgrade the TPU docker image to derive from `python:3.11`. * Upgrade `tensorflow` to 2.20.0. * Upgrade `jax` to >=0.5.2. For a compatible dep closure, this installs `jax` 0.7.2 re: `tensorflow-tpu` dep on`libtpu`. * Upgrade `torch` (and ecosystem) to 2.8.0. Of note, there is no wheel with a `+libtpu` label. * Remove unneeded environment variable. Tested: Locally, by invoking `./tpu/build`: > <img width="1116" height="297" alt="9jyLVT6hAPjKaCh" src="https://github.com/user-attachments/assets/e1c8e37e-5e65-4f43-807f-4ffca8d6b6ac" /> Also invoked other back-end testing.
1 parent 0eb38ee commit 4d494f1

File tree

3 files changed

+9
-10
lines changed

3 files changed

+9
-10
lines changed

tpu/Dockerfile

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ RUN sed -i '/from tensorflow_hub import uncompressed_module_resolver/a from tens
5555
RUN sed -i '/_install_default_resolvers()/a \ \ registry.resolver.add_implementation(kaggle_module_resolver.KaggleFileResolver())' /usr/local/lib/${PYTHON_VERSION_PATH}/site-packages/tensorflow_hub/config.py
5656

5757
# Set these env vars so that they don't produce errs calling the metadata server to load them:
58-
ENV TPU_ACCELERATOR_TYPE=v3-8
5958
ENV TPU_PROCESS_ADDRESSES=local
6059

6160
# Metadata

tpu/config.txt

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
BASE_IMAGE=python:3.10
2-
PYTHON_WHEEL_VERSION=cp310
3-
PYTHON_VERSION_PATH=python3.10
4-
TENSORFLOW_VERSION=2.18.0
1+
BASE_IMAGE=python:3.11
2+
PYTHON_WHEEL_VERSION=cp311
3+
PYTHON_VERSION_PATH=python3.11
4+
TENSORFLOW_VERSION=2.20.0
55
# gsutil ls gs://pytorch-xla-releases/wheels/tpuvm/* | grep libtpu | grep torch_xla | grep -v -E ".*rc[0-9].*" | sed 's/.*torch_xla-\(.*\)+libtpu.*/\1/' | sort -rV
66
# Supports nightly
7-
TORCH_VERSION=2.6.0
7+
TORCH_VERSION=2.8.0
88
# https://github.com/pytorch/audio supports nightly
9-
TORCHAUDIO_VERSION=2.6.0
9+
TORCHAUDIO_VERSION=2.8.0
1010
# https://github.com/pytorch/vision supports nightly
11-
TORCHVISION_VERSION=0.21.0
11+
TORCHVISION_VERSION=0.23.0
1212
TORCH_LINUX_WHEEL_VERSION=manylinux_2_28_x86_64

tpu/requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@ tensorflow-io
88
tensorflow-probability
99
# Torch packages
1010
torch==${TORCH_VERSION}
11-
https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TORCH_VERSION}+libtpu-${PYTHON_WHEEL_VERSION}-${PYTHON_WHEEL_VERSION}-${TORCH_LINUX_WHEEL_VERSION}.whl
11+
https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TORCH_VERSION}-${PYTHON_WHEEL_VERSION}-${PYTHON_WHEEL_VERSION}-${TORCH_LINUX_WHEEL_VERSION}.whl
1212
torchaudio==${TORCHAUDIO_VERSION}
1313
torchvision==${TORCHVISION_VERSION}
1414
# Jax packages
15-
jax[tpu]>=0.4.34
15+
jax[tpu]>=0.5.2
1616
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
1717
distrax
1818
flax

0 commit comments

Comments
 (0)