# NOTE(review): removed Hugging Face Spaces status-banner text ("Spaces:" /
# "Running") that was pasted into this file from the web UI — it is page
# residue, not code.
import os
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from db_utils import get_schema, execute_sql

# Load the tokenizer and model once at import time so every request reuses
# the same instances. Failure here is fatal: log and re-raise so the hosting
# process surfaces the error instead of serving requests with model=None.
model = None
tokenizer = None
try:
    tokenizer = AutoTokenizer.from_pretrained(
        "Snowflake/Arctic-Text2SQL-R1-7B",
        cache_dir="/tmp/cache/huggingface",  # writable cache dir on Spaces
        trust_remote_code=True,
    )
    model = LLM(
        model="Snowflake/Arctic-Text2SQL-R1-7B",
        dtype="float16",
        gpu_memory_utilization=0.75,  # leave GPU headroom for CUDA context
        max_model_len=1024,           # prompt+completion budget; caps KV-cache size
        max_num_seqs=1,               # serve one request at a time
        enforce_eager=True,           # skip CUDA graph capture to save startup memory
        trust_remote_code=True,
    )
except Exception as e:
    print(f"Error loading model at startup: {e}")
    raise
def text_to_sql(nl_query):
    """Translate a natural-language question into SQL and execute it.

    Args:
        nl_query: The user's question in plain language.

    Returns:
        A ``(sql, results)`` tuple: the SQL string produced by the model
        and whatever ``db_utils.execute_sql`` returns for it.

    Raises:
        Exception: re-raised after logging if schema retrieval, generation,
        or SQL execution fails.
    """
    try:
        schema = get_schema()
        prompt = f"""### Task
Generate a SQL query to answer the following natural language question: {nl_query}
### Database Schema
{schema}
### Response Format
Output only the SQL query.
"""
        sampling_params = SamplingParams(
            temperature=0,   # greedy decoding: reproducible SQL for a given prompt
            max_tokens=128,
            stop=["\n\n"],   # cut generation at the first blank line
        )
        outputs = model.generate([prompt], sampling_params)
        sql = outputs[0].outputs[0].text.strip()
        # SECURITY NOTE(review): model-generated SQL is executed verbatim —
        # confirm execute_sql runs against a read-only / sandboxed connection.
        results = execute_sql(sql)
        return sql, results
    except Exception as e:
        print(f"Error in text_to_sql: {e}")
        raise