Skip to content

Commit 221ec49

Browse files
committed
Re-add the libtpu pin to make torch and jax work together again...
http://b/436838265
1 parent 0898ca4 commit 221ec49

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

tpu/Dockerfile

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,16 @@ RUN export PATH="${HOME}/.local/bin:${PATH}" && \
5858
/tmp/clean-layer.sh
5959
ENV PATH="~/.local/bin:${PATH}"
6060

61+
# We install a libtpu version compatible with both jax 0.7.2 and torch 2.8.0.
62+
# Why? tunix latest -> flax 0.12 -> jax 0.7.2 -> libtpu 0.0.23. However, that
63+
# libtpu causes pjrt api errors for torch 2.8.0. screenshot/5heUtdyaJ4MmR3D
64+
# https://github.com/pytorch/xla/blob/d517649bdef6ab0519c30c704bde8779c8216502/setup.py#L111
65+
# https://github.com/jax-ml/jax/blob/3489529b38d1f11d1e5caf4540775aadd5f2cdda/setup.py#L26
66+
RUN export PATH="${HOME}/.local/bin:${PATH}" && \
67+
uv pip install --system --force-reinstall libtpu==0.0.17 && \
68+
uv cache clean && \
69+
/tmp/clean-layer.sh
70+
6171
# Kaggle Model Hub patches:
6272
ADD patches/kaggle_module_resolver.py /usr/local/lib/${PYTHON_VERSION_PATH}/site-packages/tensorflow_hub/kaggle_module_resolver.py
6373
RUN sed -i '/from tensorflow_hub import uncompressed_module_resolver/a from tensorflow_hub import kaggle_module_resolver' /usr/local/lib/${PYTHON_VERSION_PATH}/site-packages/tensorflow_hub/config.py

0 commit comments

Comments
 (0)