Spaces:
Running
Running
Dockerfile pipeline.py requirements.txt optimized parallelism
Browse files- Dockerfile +1 -0
- pipeline.py +11 -4
- requirements.txt +2 -1
Dockerfile
CHANGED
@@ -22,6 +22,7 @@ COPY app.py pipeline.py db_utils.py ./
|
|
22 |
|
23 |
ENV HF_HOME=/cache/huggingface
|
24 |
ENV PORT=8501
|
|
|
25 |
|
26 |
EXPOSE 8501
|
27 |
|
|
|
22 |
|
23 |
ENV HF_HOME=/cache/huggingface
|
24 |
ENV PORT=8501
|
25 |
+
ENV OMP_NUM_THREADS=8 # Set to match 8 vCPUs
|
26 |
|
27 |
EXPOSE 8501
|
28 |
|
pipeline.py
CHANGED
@@ -15,12 +15,15 @@ def init_model():
|
|
15 |
model = LLM(
|
16 |
model="Snowflake/Arctic-Text2SQL-R1-7B",
|
17 |
dtype="float16",
|
18 |
-
gpu_memory_utilization=0.
|
19 |
-
max_model_len=
|
|
|
|
|
|
|
20 |
)
|
21 |
except Exception as e:
|
22 |
print(f"Error loading model: {e}")
|
23 |
-
|
24 |
|
25 |
def text_to_sql(nl_query):
|
26 |
try:
|
@@ -35,7 +38,11 @@ Generate a SQL query to answer the following natural language question: {nl_quer
|
|
35 |
### Response Format
|
36 |
Output only the SQL query.
|
37 |
"""
|
38 |
-
sampling_params = SamplingParams(
|
|
|
|
|
|
|
|
|
39 |
outputs = model.generate([prompt], sampling_params)
|
40 |
sql = outputs[0].outputs[0].text.strip()
|
41 |
results = execute_sql(sql)
|
|
|
15 |
model = LLM(
|
16 |
model="Snowflake/Arctic-Text2SQL-R1-7B",
|
17 |
dtype="float16",
|
18 |
+
gpu_memory_utilization=0.75, # Balanced for 30GB VRAM
|
19 |
+
max_model_len=1024, # Reduced for speed
|
20 |
+
max_num_seqs=1, # Single query
|
21 |
+
enforce_eager=True, # Avoid graph compilation
|
22 |
+
trust_remote_code=True # Model compatibility
|
23 |
)
|
24 |
except Exception as e:
|
25 |
print(f"Error loading model: {e}")
|
26 |
+
raise
|
27 |
|
28 |
def text_to_sql(nl_query):
|
29 |
try:
|
|
|
38 |
### Response Format
|
39 |
Output only the SQL query.
|
40 |
"""
|
41 |
+
sampling_params = SamplingParams(
|
42 |
+
temperature=0, # Deterministic
|
43 |
+
max_tokens=128, # Short queries
|
44 |
+
stop=["\n\n"] # Stop at query end
|
45 |
+
)
|
46 |
outputs = model.generate([prompt], sampling_params)
|
47 |
sql = outputs[0].outputs[0].text.strip()
|
48 |
results = execute_sql(sql)
|
requirements.txt
CHANGED
@@ -4,4 +4,5 @@ psycopg2-binary==2.9.10
|
|
4 |
sqlalchemy==2.0.43
|
5 |
python-dotenv==1.1.1
|
6 |
vllm==0.10.1
|
7 |
-
streamlit==1.39.0
|
|
|
|
4 |
sqlalchemy==2.0.43
|
5 |
python-dotenv==1.1.1
|
6 |
vllm==0.10.1
|
7 |
+
streamlit==1.39.0
|
8 |
+
torch==2.8.0
|