File tree Expand file tree Collapse file tree 2 files changed +3
-13
lines changed
Expand file tree Collapse file tree 2 files changed +3
-13
lines changed Original file line number Diff line number Diff line change @@ -58,16 +58,6 @@ RUN export PATH="${HOME}/.local/bin:${PATH}" && \
5858 /tmp/clean-layer.sh
5959ENV PATH="~/.local/bin:${PATH}"
6060
61- # We install a libtpu version compatible with both jax 0.7.2 and torch 2.8.0.
62- # Why? tunix latest -> flax 0.12 -> jax 0.7.2 -> libtpu 0.0.23. However, that
63- # libtpu causes pjrt api errors for torch 2.8.0. screenshot/5heUtdyaJ4MmR3D
64- # https://github.com/pytorch/xla/blob/d517649bdef6ab0519c30c704bde8779c8216502/setup.py#L111
65- # https://github.com/jax-ml/jax/blob/3489529b38d1f11d1e5caf4540775aadd5f2cdda/setup.py#L26
66- RUN export PATH="${HOME}/.local/bin:${PATH}" && \
67- uv pip install --system --force-reinstall libtpu==0.0.17 && \
68- uv cache clean && \
69- /tmp/clean-layer.sh
70-
7161# Kaggle Model Hub patches:
7262ADD patches/kaggle_module_resolver.py /usr/local/lib/${PYTHON_VERSION_PATH}/site-packages/tensorflow_hub/kaggle_module_resolver.py
7363RUN sed -i '/from tensorflow_hub import uncompressed_module_resolver/a from tensorflow_hub import kaggle_module_resolver' /usr/local/lib/${PYTHON_VERSION_PATH}/site-packages/tensorflow_hub/config.py
Original file line number Diff line number Diff line change @@ -8,10 +8,10 @@ tensorflow-io
88tensorflow-probability
99tensorflow_datasets
1010# Torch packages
11- torch==${TORCH_VERSION}
11+ https://download.pytorch.org/whl/cpu/torch-${TORCH_VERSION}%2Bcpu-${PYTHON_WHEEL_VERSION}-${PYTHON_WHEEL_VERSION}-${TORCH_LINUX_WHEEL_VERSION}.whl
12+ https://download.pytorch.org/whl/cpu/torchaudio-${TORCHAUDIO_VERSION}%2Bcpu-${PYTHON_WHEEL_VERSION}-${PYTHON_WHEEL_VERSION}-${TORCH_LINUX_WHEEL_VERSION}.whl
13+ https://download.pytorch.org/whl/cpu/torchvision-${TORCHVISION_VERSION}%2Bcpu-${PYTHON_WHEEL_VERSION}-${PYTHON_WHEEL_VERSION}-${TORCH_LINUX_WHEEL_VERSION}.whl
1214https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TORCH_VERSION}-${PYTHON_WHEEL_VERSION}-${PYTHON_WHEEL_VERSION}-${TORCH_LINUX_WHEEL_VERSION}.whl
13- torchaudio==${TORCHAUDIO_VERSION}
14- torchvision==${TORCHVISION_VERSION}
1515# Jax packages
1616jax[tpu]
1717distrax
You can’t perform that action at this time.
0 commit comments