gemini_nl2sql / pipeline.py
acadiaway's picture
Simplify Dockerfile, use /tmp/cache/huggingface, preload model in pipeline.py
b4bcb5e
raw
history blame
1.35 kB
import os
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from db_utils import get_schema, execute_sql
# Load the tokenizer and the vLLM engine once, when this module is imported,
# so the model is already resident by the time text_to_sql() is first called.
_MODEL_ID = "Snowflake/Arctic-Text2SQL-R1-7B"
_HF_CACHE_DIR = "/tmp/cache/huggingface"

# Bind both names up front so they exist at module level while loading runs.
model = None
tokenizer = None

try:
    # Only the tokenizer is given the cache directory; the LLM constructor
    # below is not passed one.
    # NOTE(review): `tokenizer` is not used anywhere in the visible code --
    # confirm whether another module imports it before removing.
    tokenizer = AutoTokenizer.from_pretrained(
        _MODEL_ID,
        cache_dir=_HF_CACHE_DIR,
        trust_remote_code=True,
    )
    # NOTE(review): max_model_len=1024 presumably caps prompt + completion
    # tokens, so a large schema in the prompt may not fit -- confirm against
    # the vLLM documentation and the real schema size.
    model = LLM(
        model=_MODEL_ID,
        dtype="float16",
        gpu_memory_utilization=0.75,
        max_model_len=1024,
        max_num_seqs=1,
        enforce_eager=True,
        trust_remote_code=True,
    )
except Exception as load_error:
    # Report the failure, then re-raise so the import itself fails instead of
    # leaving `model` as None for later callers to trip over.
    print(f"Error loading model at startup: {load_error}")
    raise
def text_to_sql(nl_query):
    """Translate a natural-language question into SQL and execute it.

    The database schema returned by ``get_schema()`` is embedded in a prompt
    together with the question, the module-level ``model`` generates a
    completion, and the stripped completion is passed to ``execute_sql()``.

    Args:
        nl_query: The question to answer, as a non-blank string.

    Returns:
        A ``(sql, results)`` tuple: the generated SQL text and whatever
        ``execute_sql(sql)`` returned for it.

    Raises:
        TypeError: If ``nl_query`` is not a string.
        ValueError: If ``nl_query`` is empty or only whitespace.
        RuntimeError: If the model produced an empty completion.
        Exception: Any error from schema lookup, generation or execution is
            printed and re-raised unchanged.
    """
    # Reject unusable input before spending a generation on it; otherwise
    # None would be formatted into the prompt as the literal text "None".
    if not isinstance(nl_query, str):
        raise TypeError(
            f"nl_query must be a string, got {type(nl_query).__name__}"
        )
    if not nl_query.strip():
        raise ValueError("nl_query must be a non-empty question")

    try:
        schema = get_schema()
        # Built from explicit line fragments so the prompt text does not
        # depend on how this function happens to be indented.
        prompt = (
            "### Task\n"
            "Generate a SQL query to answer the following natural language "
            f"question: {nl_query}\n"
            "### Database Schema\n"
            f"{schema}\n"
            "### Response Format\n"
            "Output only the SQL query.\n"
        )
        sampling_params = SamplingParams(
            temperature=0,
            max_tokens=128,
            stop=["\n\n"],
        )
        outputs = model.generate([prompt], sampling_params)
        sql = outputs[0].outputs[0].text.strip()
        # The "\n\n" stop sequence can end generation before any SQL is
        # produced; fail clearly rather than executing an empty statement.
        if not sql:
            raise RuntimeError("Model returned an empty SQL query")
        # NOTE(review): the generated SQL is derived from user-supplied text
        # and is executed as-is. Confirm that execute_sql() runs with
        # read-only / least-privilege credentials or validates the statement.
        results = execute_sql(sql)
        return sql, results
    except Exception as e:
        print(f"Error in text_to_sql: {e}")
        raise