acadiaway commited on
Commit
f6799fb
·
1 Parent(s): d94839d

Dockerfile pipeline.py requirements.txt optized parallelism

Browse files
Files changed (3) hide show
  1. Dockerfile +1 -0
  2. pipeline.py +11 -4
  3. requirements.txt +2 -1
Dockerfile CHANGED
@@ -22,6 +22,7 @@ COPY app.py pipeline.py db_utils.py ./
22
 
23
  ENV HF_HOME=/cache/huggingface
24
  ENV PORT=8501
 
25
 
26
  EXPOSE 8501
27
 
 
22
 
23
  ENV HF_HOME=/cache/huggingface
24
  ENV PORT=8501
25
+ ENV OMP_NUM_THREADS=8 # Set to match 8vCPUs
26
 
27
  EXPOSE 8501
28
 
pipeline.py CHANGED
@@ -15,12 +15,15 @@ def init_model():
15
  model = LLM(
16
  model="Snowflake/Arctic-Text2SQL-R1-7B",
17
  dtype="float16",
18
- gpu_memory_utilization=0.9,
19
- max_model_len=4096
 
 
 
20
  )
21
  except Exception as e:
22
  print(f"Error loading model: {e}")
23
- exit(1)
24
 
25
  def text_to_sql(nl_query):
26
  try:
@@ -35,7 +38,11 @@ Generate a SQL query to answer the following natural language question: {nl_quer
35
  ### Response Format
36
  Output only the SQL query.
37
  """
38
- sampling_params = SamplingParams(temperature=0, max_tokens=512)
 
 
 
 
39
  outputs = model.generate([prompt], sampling_params)
40
  sql = outputs[0].outputs[0].text.strip()
41
  results = execute_sql(sql)
 
15
  model = LLM(
16
  model="Snowflake/Arctic-Text2SQL-R1-7B",
17
  dtype="float16",
18
+ gpu_memory_utilization=0.75, # Balanced for 30GB VRAM
19
+ max_model_len=1024, # Reduced for speed
20
+ max_num_seqs=1, # Single query
21
+ enforce_eager=True, # Avoid graph compilation
22
+ trust_remote_code=True # Model compatibility
23
  )
24
  except Exception as e:
25
  print(f"Error loading model: {e}")
26
+ raise
27
 
28
  def text_to_sql(nl_query):
29
  try:
 
38
  ### Response Format
39
  Output only the SQL query.
40
  """
41
+ sampling_params = SamplingParams(
42
+ temperature=0, # Deterministic
43
+ max_tokens=128, # Short queries
44
+ stop=["\n\n"] # Stop at query end
45
+ )
46
  outputs = model.generate([prompt], sampling_params)
47
  sql = outputs[0].outputs[0].text.strip()
48
  results = execute_sql(sql)
requirements.txt CHANGED
@@ -4,4 +4,5 @@ psycopg2-binary==2.9.10
4
  sqlalchemy==2.0.43
5
  python-dotenv==1.1.1
6
  vllm==0.10.1
7
- streamlit==1.39.0
 
 
4
  sqlalchemy==2.0.43
5
  python-dotenv==1.1.1
6
  vllm==0.10.1
7
+ streamlit==1.39.0
8
+ torch==2.8.0