Skip to content

Commit 7d829bd

Browse files
authored
Merge pull request #901 from Kaggle/upgrade-jax
Upgrade JAX to 0.2.6
2 parents 84a5c03 + 79dffe5 commit 7d829bd

File tree

1 file changed

+1
-10
lines changed

1 file changed

+1
-10
lines changed

gpu.Dockerfile

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -80,16 +80,7 @@ RUN pip uninstall -y lightgbm && \
8080
/tmp/clean-layer.sh
8181

8282
# Install JAX
83-
# b/154150582#comment9: JAX 0.1.63 with jaxlib 0.1.43 is causing the GPU tests to hang.
84-
ENV JAX_VERSION=0.1.62
85-
ENV JAXLIB_VERSION=0.1.41
86-
ENV JAX_PYTHON_VERSION=cp37
87-
ENV JAX_CUDA_VERSION=cuda$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION
88-
ENV JAX_PLATFORM=linux_x86_64
89-
ENV JAX_BASE_URL="https://storage.googleapis.com/jax-releases"
90-
91-
RUN pip install $JAX_BASE_URL/$JAX_CUDA_VERSION/jaxlib-$JAXLIB_VERSION-$JAX_PYTHON_VERSION-none-$JAX_PLATFORM.whl && \
92-
pip install jax==$JAX_VERSION && \
83+
RUN pip install jax==0.2.6 jaxlib==0.1.57+cuda$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION -f https://storage.googleapis.com/jax-releases/jax_releases.html && \
9384
/tmp/clean-layer.sh
9485

9586
# Reinstall packages with a separate version for GPU support.

0 commit comments

Comments
 (0)