File tree Expand file tree Collapse file tree 1 file changed +1
-10
lines changed
Expand file tree Collapse file tree 1 file changed +1
-10
lines changed Original file line number Diff line number Diff line change @@ -80,16 +80,7 @@ RUN pip uninstall -y lightgbm && \
8080 /tmp/clean-layer.sh
8181
8282# Install JAX
83- # b/154150582#comment9: JAX 0.1.63 with jaxlib 0.1.43 is causing the GPU tests to hang.
84- ENV JAX_VERSION=0.1.62
85- ENV JAXLIB_VERSION=0.1.41
86- ENV JAX_PYTHON_VERSION=cp37
87- ENV JAX_CUDA_VERSION=cuda$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION
88- ENV JAX_PLATFORM=linux_x86_64
89- ENV JAX_BASE_URL="https://storage.googleapis.com/jax-releases"
90-
91- RUN pip install $JAX_BASE_URL/$JAX_CUDA_VERSION/jaxlib-$JAXLIB_VERSION-$JAX_PYTHON_VERSION-none-$JAX_PLATFORM.whl && \
92- pip install jax==$JAX_VERSION && \
83+ RUN pip install jax==0.2.6 jaxlib==0.1.57+cuda$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION -f https://storage.googleapis.com/jax-releases/jax_releases.html && \
9384 /tmp/clean-layer.sh
9485
9586# Reinstall packages with a separate version for GPU support.
You can’t perform that action at this time.
0 commit comments