# syntax=docker/dockerfile:1
# The directive above must stay on the first line: it selects the BuildKit
# Dockerfile frontend, which the heredoc-style RUN later in this file requires.

# ---------- Base image: CUDA 12.4.1 + cuDNN runtime on Ubuntu 22.04 ----------
FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04
# ---------- System packages ----------
# Python 3.11 toolchain plus audio/media libraries and fetch tools.
# update + install + list cleanup happen in one layer so the apt index is
# never stale and never shipped in the image. DEBIAN_FRONTEND is set inline
# (not via ENV) so package configuration cannot block on an interactive prompt
# while leaving the runtime environment untouched.
RUN apt-get update \
 && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
      ca-certificates \
      curl \
      ffmpeg \
      git \
      libsndfile1 \
      python3-pip \
      python3.11 \
      python3.11-distutils \
      python3.11-venv \
 && rm -rf /var/lib/apt/lists/*
# ---------- uv (Python package/project manager) ----------
# The installer is downloaded to a file and then executed, instead of being
# piped into sh: with the default /bin/sh a failed download in a pipeline is
# masked by the exit status of the last command. The installer is
# non-interactive, so it needs no extra flags.
# NOTE(review): the installer URL is unversioned, so the uv version can change
# between builds; consider pinning via https://astral.sh/uv/<version>/install.sh.
RUN curl -LsSf https://astral.sh/uv/install.sh -o /tmp/uv-install.sh \
 && sh /tmp/uv-install.sh \
 && rm -f /tmp/uv-install.sh

# Make the uv binary (installed under root's ~/.local/bin) available to later steps.
ENV PATH="/root/.local/bin:${PATH}"
# ---------- CUDA shared libraries ----------
# Prepend the CUDA library directory to the dynamic linker search path,
# keeping any value that is already set.
ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:${LD_LIBRARY_PATH}
# ---------- GPU memory / backend settings ----------
# TF_FORCE_GPU_ALLOW_GROWTH=true      TensorFlow grows its GPU allocation on
#                                     demand instead of reserving it up front.
# XLA_PYTHON_CLIENT_PREALLOCATE=false JAX allocates GPU memory as needed
#                                     rather than preallocating a large pool.
# JAX_PLATFORMS=cuda,cpu              JAX initializes the CUDA and CPU
#                                     backends; the first entry is the default.
ENV TF_FORCE_GPU_ALLOW_GROWTH=true \
    XLA_PYTHON_CLIENT_PREALLOCATE=false \
    JAX_PLATFORMS=cuda,cpu
# ---------- Python dependencies ----------
WORKDIR /opt/app

# uv creates the project environment in <project>/.venv by default. The
# sanity-check step and the CMD in this file both use /opt/venv/bin/..., so
# point uv at that location explicitly; without this those paths do not exist.
ENV UV_PROJECT_ENVIRONMENT=/opt/venv

# Only the manifest is copied at this point, so the dependency layer below is
# reused from cache until pyproject.toml changes.
COPY pyproject.toml ./

# No uv.lock is copied into the image, so the lockfile is resolved here at
# build time and then installed from with --frozen (no re-resolution).
# NOTE(review): resolving at build time means two builds can pick different
# dependency versions; committing uv.lock and copying it in would fix that.
# The uv cache is removed in the same layer so it does not end up in the image.
RUN uv lock \
 && uv sync --frozen --python=/usr/bin/python3.11 --no-dev \
 && uv cache clean
# ---------- Build-time sanity check ----------
# Imports jax and jaxlib from the project environment and prints their
# versions; an import failure makes this step (and the build) fail.
# The CUDA plugin lookup is best-effort: any error is printed, not raised.
# importlib.metadata is imported explicitly because a bare `import importlib`
# does not guarantee that the `metadata` submodule is loaded.
RUN /opt/venv/bin/python - <<'PY'
import importlib.metadata

import jax
import jaxlib

print("JAX:", jax.__version__)
print("JAXLIB:", jaxlib.__version__)
try:
    print("CUDA plugin:", importlib.metadata.version("jax-cuda12-plugin"))
except Exception as e:
    print("CUDA plugin:", "not found?", e)
PY
# ---------- Application code ----------
# Copied after the dependency layers so source edits do not invalidate them.
COPY app.py utils.py jam_worker.py ./

# Documents the port uvicorn listens on; it does not publish the port.
EXPOSE 7860

# NOTE(review): no USER instruction is set in this file, so the server runs as
# the base image's default user (presumably root) - confirm whether a
# dedicated non-root user is feasible for this deployment.
# Exec form keeps uvicorn as PID 1 so it receives stop signals directly.
CMD ["/opt/venv/bin/uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]