@@ -10,7 +10,7 @@ ARG TORCHVISION_VERSION
1010
1111{{ if eq .Accelerator "gpu" }}
1212FROM gcr.io/kaggle-images/python-lightgbm-whl:${GPU_BASE_IMAGE_NAME}-${BASE_IMAGE_TAG}-${LIGHTGBM_VERSION} AS lightgbm_whl
13- # FROM gcr.io/kaggle-images/python-torch-whl:${GPU_BASE_IMAGE_NAME}-${BASE_IMAGE_TAG}-${TORCH_VERSION} AS torch_whl
13+ FROM gcr.io/kaggle-images/python-torch-whl:${GPU_BASE_IMAGE_NAME}-${BASE_IMAGE_TAG}-${TORCH_VERSION} AS torch_whl
1414FROM ${BASE_IMAGE_REPO}/${GPU_BASE_IMAGE_NAME}:${BASE_IMAGE_TAG}
1515{{ else }}
1616FROM ${BASE_IMAGE_REPO}/${CPU_BASE_IMAGE_NAME}:${BASE_IMAGE_TAG}
@@ -111,15 +111,13 @@ RUN mamba install implicit && \
111111{{ if eq .Accelerator "gpu" }}
112112#COPY --from=torch_whl /tmp/whl/*.whl /tmp/torch/
113113RUN mamba install -c pytorch magma-cuda${CUDA_MAJOR_VERSION}${CUDA_MINOR_VERSION} && \
114- # pip install /tmp/torch/*.whl && \
114+ pip install /tmp/torch/*.whl && \
115115 # b/255757999 openmp (libomp.so) is an dependency of libtorchtext and libtorchaudio but
116- # the built from source versions don't seem to properly link it in. This forces the dep
117- # which makes sure that libomp is loaded when these libraries are loaded.
118- mamba install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 torchtext cudatoolkit=11.3 -c pytorch && \
116+ mamba install -y openmp && \
119117 #pip install patchelf && \
120118 #patchelf --add-needed libomp.so /opt/conda/lib/python3.7/site-packages/torchtext/lib/libtorchtext.so && \
121119 #patchelf --add-needed libomp.so /opt/conda/lib/python3.7/site-packages/torchaudio/lib/libtorchaudio.so && \
122- # rm -rf /tmp/torch && \
120+ rm -rf /tmp/torch && \
123121 /tmp/clean-layer.sh
124122{{ else }}
125123RUN pip install \
0 commit comments