# NOTE: "Spaces: Running / Running" — Hugging Face Space page header captured
# during extraction; not part of the program.
import os
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from db_utils import get_schema, execute_sql
# Model and tokenizer are initialized lazily (on first request) so that
# importing this module stays cheap and GPU memory is only claimed when needed.
model = None      # vLLM engine; created by init_model() on first use
tokenizer = None  # HF tokenizer; created by init_model() on first use
def init_model():
    """Lazily load the Arctic-Text2SQL model and tokenizer (once per process).

    Populates the module-level ``model`` and ``tokenizer`` globals. Safe to
    call repeatedly: it is a no-op when both are already initialized.

    Raises:
        Exception: re-raised after logging if loading the tokenizer or the
            vLLM engine fails (e.g. download error, insufficient VRAM).
    """
    global model, tokenizer
    if model is None or tokenizer is None:
        try:
            tokenizer = AutoTokenizer.from_pretrained("Snowflake/Arctic-Text2SQL-R1-7B")
            model = LLM(
                model="Snowflake/Arctic-Text2SQL-R1-7B",
                dtype="float16",
                gpu_memory_utilization=0.75,  # balanced for 30GB VRAM
                max_model_len=1024,           # reduced context for speed
                max_num_seqs=1,               # single query at a time
                enforce_eager=True,           # skip CUDA graph compilation
                trust_remote_code=True        # model compatibility
            )
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
def text_to_sql(nl_query):
    """Translate a natural-language question into SQL and execute it.

    Args:
        nl_query: The user's question in plain English.

    Returns:
        A ``(sql, results)`` tuple: the generated SQL string and whatever
        ``execute_sql`` returns for it.

    Raises:
        Exception: re-raised after logging if model init, generation, or
            SQL execution fails.
    """
    try:
        init_model()
        schema = get_schema()
        prompt = f"""### Task
Generate a SQL query to answer the following natural language question: {nl_query}
### Database Schema
{schema}
### Response Format
Output only the SQL query.
"""
        sampling_params = SamplingParams(
            temperature=0,    # deterministic decoding
            max_tokens=128,   # generated queries are short
            stop=["\n\n"]     # stop at the blank line after the query
        )
        outputs = model.generate([prompt], sampling_params)
        sql = outputs[0].outputs[0].text.strip()
        # Models frequently wrap SQL in markdown fences despite the prompt;
        # strip ```sql ... ``` so execute_sql receives bare SQL.
        if sql.startswith("```"):
            sql = sql.strip("`").strip()
            if sql.lower().startswith("sql"):
                sql = sql[3:].lstrip()
        # SECURITY: this executes LLM-generated SQL verbatim. Ensure
        # execute_sql runs against a read-only / sandboxed connection.
        results = execute_sql(sql)
        return sql, results
    except Exception as e:
        print(f"Error in text_to_sql: {e}")
        raise