Skip to content

Commit 0898ca4

Browse files
committed
Switch to newer libtpu/tunix and cpu-based torch for tpu.
http://b/436838265
1 parent 3e031ba commit 0898ca4

File tree

2 files changed

+3
-13
lines changed

2 files changed

+3
-13
lines changed

tpu/Dockerfile

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,16 +58,6 @@ RUN export PATH="${HOME}/.local/bin:${PATH}" && \
5858
/tmp/clean-layer.sh
5959
ENV PATH="~/.local/bin:${PATH}"
6060

61-
# We install a libtpu version compatible with both jax 0.7.2 and torch 2.8.0.
62-
# Why? tunix latest -> flax 0.12 -> jax 0.7.2 -> libtpu 0.0.23. However, that
63-
# libtpu causes pjrt api errors for torch 2.8.0. screenshot/5heUtdyaJ4MmR3D
64-
# https://github.com/pytorch/xla/blob/d517649bdef6ab0519c30c704bde8779c8216502/setup.py#L111
65-
# https://github.com/jax-ml/jax/blob/3489529b38d1f11d1e5caf4540775aadd5f2cdda/setup.py#L26
66-
RUN export PATH="${HOME}/.local/bin:${PATH}" && \
67-
uv pip install --system --force-reinstall libtpu==0.0.17 && \
68-
uv cache clean && \
69-
/tmp/clean-layer.sh
70-
7161
# Kaggle Model Hub patches:
7262
ADD patches/kaggle_module_resolver.py /usr/local/lib/${PYTHON_VERSION_PATH}/site-packages/tensorflow_hub/kaggle_module_resolver.py
7363
RUN sed -i '/from tensorflow_hub import uncompressed_module_resolver/a from tensorflow_hub import kaggle_module_resolver' /usr/local/lib/${PYTHON_VERSION_PATH}/site-packages/tensorflow_hub/config.py

tpu/requirements.in

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ tensorflow-io
88
tensorflow-probability
99
tensorflow_datasets
1010
# Torch packages
11-
torch==${TORCH_VERSION}
11+
https://download.pytorch.org/whl/cpu/torch-${TORCH_VERSION}%2Bcpu-${PYTHON_WHEEL_VERSION}-${PYTHON_WHEEL_VERSION}-${TORCH_LINUX_WHEEL_VERSION}.whl
12+
https://download.pytorch.org/whl/cpu/torchaudio-${TORCHAUDIO_VERSION}%2Bcpu-${PYTHON_WHEEL_VERSION}-${PYTHON_WHEEL_VERSION}-${TORCH_LINUX_WHEEL_VERSION}.whl
13+
https://download.pytorch.org/whl/cpu/torchvision-${TORCHVISION_VERSION}%2Bcpu-${PYTHON_WHEEL_VERSION}-${PYTHON_WHEEL_VERSION}-${TORCH_LINUX_WHEEL_VERSION}.whl
1214
https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TORCH_VERSION}-${PYTHON_WHEEL_VERSION}-${PYTHON_WHEEL_VERSION}-${TORCH_LINUX_WHEEL_VERSION}.whl
13-
torchaudio==${TORCHAUDIO_VERSION}
14-
torchvision==${TORCHVISION_VERSION}
1515
# Jax packages
1616
jax[tpu]
1717
distrax

0 commit comments

Comments
 (0)