gemini_nl2sql / pipeline.py
acadiaway's picture
Dockerfile pipeline.py requirements.txt optimized parallelism
f6799fb
raw
history blame
1.68 kB
import os
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from db_utils import get_schema, execute_sql
# Model and tokenizer are created lazily by init_model() on first call,
# so importing this module does not trigger a multi-GB weight download/load.
model = None  # vllm.LLM instance once init_model() has run
tokenizer = None  # HF AutoTokenizer; loaded but not referenced elsewhere in this chunk — TODO confirm it is needed
def init_model():
    """Lazily load the Arctic-Text2SQL model and tokenizer into module globals.

    Idempotent: returns immediately once both ``model`` and ``tokenizer``
    are populated. On a failed load, logs the error and re-raises whatever
    transformers/vllm raised.
    """
    global model, tokenizer
    if model is not None and tokenizer is not None:
        return  # already initialized

    model_name = "Snowflake/Arctic-Text2SQL-R1-7B"
    try:
        # Load into locals first so a failure partway through (e.g. the
        # tokenizer loads but the LLM OOMs) never leaves the module globals
        # half-initialized.
        tok = AutoTokenizer.from_pretrained(model_name)
        llm = LLM(
            model=model_name,
            dtype="float16",
            gpu_memory_utilization=0.75,  # Balanced for 30GB VRAM
            max_model_len=1024,           # Reduced for speed
            max_num_seqs=1,               # Single query
            enforce_eager=True,           # Avoid graph compilation
            trust_remote_code=True,       # Model compatibility
        )
    except Exception as e:
        print(f"Error loading model: {e}")
        raise

    # Publish atomically-enough for the single-threaded usage seen here.
    tokenizer = tok
    model = llm
def text_to_sql(nl_query):
    """Translate a natural-language question into SQL, execute it, and
    return ``(sql, results)``.

    Parameters:
        nl_query: the user's question in plain English.

    Returns:
        Tuple of (generated SQL string, rows returned by execute_sql).

    Raises:
        Re-raises any exception from model init, generation, or execution
        after logging it.
    """
    try:
        init_model()
        schema = get_schema()
        prompt = f"""### Task
Generate a SQL query to answer the following natural language question: {nl_query}
### Database Schema
{schema}
### Response Format
Output only the SQL query.
"""
        sampling_params = SamplingParams(
            temperature=0,    # Deterministic
            max_tokens=128,   # Short queries
            stop=["\n\n"],    # Stop at query end
        )
        outputs = model.generate([prompt], sampling_params)
        sql = outputs[0].outputs[0].text.strip()

        # Models frequently wrap SQL in markdown fences despite the prompt;
        # strip a leading ```sql / ``` fence and a trailing ``` defensively.
        if sql.startswith("```"):
            sql = sql[3:]
            if sql[:3].lower() == "sql":
                sql = sql[3:]
            if sql.endswith("```"):
                sql = sql[:-3]
            sql = sql.strip()

        # SECURITY NOTE(review): this executes LLM-generated SQL directly.
        # execute_sql should run with a read-only / least-privilege DB user;
        # confirm that is the case in db_utils.
        results = execute_sql(sql)
        return sql, results
    except Exception as e:
        print(f"Error in text_to_sql: {e}")
        raise