acadiaway committed on
Commit
2741cd0
·
1 Parent(s): 3893c8b

Preload model in Dockerfile and pipeline.py to fix PermissionError

Browse files
Files changed (2) hide show
  1. Dockerfile +6 -5
  2. pipeline.py +22 -23
Dockerfile CHANGED
@@ -20,12 +20,13 @@ RUN pip install --no-cache-dir -r requirements.txt
20
 
21
  COPY app.py pipeline.py db_utils.py ./
22
 
23
- # Set up cache directory permissions and clear stale locks
24
- RUN mkdir -p /cache/huggingface && \
25
- chmod -R 777 /cache/huggingface && \
26
- rm -f /cache/huggingface/*.lock /cache/huggingface/*/*.lock
 
27
 
28
- ENV HF_HOME=/cache/huggingface
29
  ENV PORT=8501
30
  ENV OMP_NUM_THREADS=8
31
 
 
20
 
21
  COPY app.py pipeline.py db_utils.py ./
22
 
23
+ # Set up cache directory and preload model
24
+ RUN mkdir -p /app/cache/huggingface && \
25
+ chmod -R 777 /app/cache/huggingface && \
26
+ python -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('Snowflake/Arctic-Text2SQL-R1-7B', cache_dir='/app/cache/huggingface')" && \
27
+ python -c "from vllm import LLM; LLM(model='Snowflake/Arctic-Text2SQL-R1-7B', dtype='float16', gpu_memory_utilization=0.75, max_model_len=1024, max_num_seqs=1, enforce_eager=True, trust_remote_code=True, cache_dir='/app/cache/huggingface')"
28
 
29
+ ENV HF_HOME=/app/cache/huggingface
30
  ENV PORT=8501
31
  ENV OMP_NUM_THREADS=8
32
 
pipeline.py CHANGED
@@ -3,31 +3,30 @@ from transformers import AutoTokenizer
3
  from vllm import LLM, SamplingParams
4
  from db_utils import get_schema, execute_sql
5
 
6
- # Initialize model lazily
7
  model = None
8
  tokenizer = None
9
-
10
- def init_model():
11
- global model, tokenizer
12
- if model is None or tokenizer is None:
13
- try:
14
- tokenizer = AutoTokenizer.from_pretrained("Snowflake/Arctic-Text2SQL-R1-7B")
15
- model = LLM(
16
- model="Snowflake/Arctic-Text2SQL-R1-7B",
17
- dtype="float16",
18
- gpu_memory_utilization=0.75, # Balanced for 30GB VRAM
19
- max_model_len=1024, # Reduced for speed
20
- max_num_seqs=1, # Single query
21
- enforce_eager=True, # Avoid graph compilation
22
- trust_remote_code=True # Model compatibility
23
- )
24
- except Exception as e:
25
- print(f"Error loading model: {e}")
26
- raise
27
 
28
  def text_to_sql(nl_query):
29
  try:
30
- init_model()
31
  schema = get_schema()
32
  prompt = f"""### Task
33
  Generate a SQL query to answer the following natural language question: {nl_query}
@@ -39,9 +38,9 @@ Generate a SQL query to answer the following natural language question: {nl_quer
39
  Output only the SQL query.
40
  """
41
  sampling_params = SamplingParams(
42
- temperature=0, # Deterministic
43
- max_tokens=128, # Short queries
44
- stop=["\n\n"] # Stop at query end
45
  )
46
  outputs = model.generate([prompt], sampling_params)
47
  sql = outputs[0].outputs[0].text.strip()
 
3
  from vllm import LLM, SamplingParams
4
  from db_utils import get_schema, execute_sql
5
 
6
+ # Initialize model at startup to avoid lazy loading
7
  model = None
8
  tokenizer = None
9
+ try:
10
+ tokenizer = AutoTokenizer.from_pretrained(
11
+ "Snowflake/Arctic-Text2SQL-R1-7B",
12
+ cache_dir="/app/cache/huggingface"
13
+ )
14
+ model = LLM(
15
+ model="Snowflake/Arctic-Text2SQL-R1-7B",
16
+ dtype="float16",
17
+ gpu_memory_utilization=0.75,
18
+ max_model_len=1024,
19
+ max_num_seqs=1,
20
+ enforce_eager=True,
21
+ trust_remote_code=True,
22
+ cache_dir="/app/cache/huggingface"
23
+ )
24
+ except Exception as e:
25
+ print(f"Error loading model at startup: {e}")
26
+ raise
27
 
28
  def text_to_sql(nl_query):
29
  try:
 
30
  schema = get_schema()
31
  prompt = f"""### Task
32
  Generate a SQL query to answer the following natural language question: {nl_query}
 
38
  Output only the SQL query.
39
  """
40
  sampling_params = SamplingParams(
41
+ temperature=0,
42
+ max_tokens=128,
43
+ stop=["\n\n"]
44
  )
45
  outputs = model.generate([prompt], sampling_params)
46
  sql = outputs[0].outputs[0].text.strip()