Spaces:
Running
Running
File size: 1,348 Bytes
6b5f4d7 b4bcb5e 6b5f4d7 2741cd0 b4bcb5e 2741cd0 b4bcb5e 2741cd0 6b5f4d7 3d893a8 f6799fb 2741cd0 f6799fb 3d893a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import os
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from db_utils import get_schema, execute_sql
# Load the tokenizer and the vLLM engine once, at import time, so every call
# made by this process reuses the same engine instead of reloading weights.
_MODEL_ID = "Snowflake/Arctic-Text2SQL-R1-7B"

# Placeholders; both are rebound below, and any loading failure re-raises,
# so a successful import always leaves them populated.
model = tokenizer = None

try:
    tokenizer = AutoTokenizer.from_pretrained(
        _MODEL_ID,
        cache_dir="/tmp/cache/huggingface",
        trust_remote_code=True,
    )
    # Engine settings (see vLLM engine-argument docs):
    #   gpu_memory_utilization - fraction of GPU memory vLLM may reserve;
    #   max_model_len          - context-length cap (prompt + generated tokens);
    #   max_num_seqs           - at most one sequence scheduled at a time;
    #   enforce_eager          - skip CUDA-graph capture, run eager PyTorch.
    # NOTE(review): a prompt built from a large schema may exceed the
    # 1024-token context limit; confirm against the real database size.
    model = LLM(
        model=_MODEL_ID,
        dtype="float16",
        gpu_memory_utilization=0.75,
        max_model_len=1024,
        max_num_seqs=1,
        enforce_eager=True,
        trust_remote_code=True,
    )
except Exception as exc:
    # Surface the cause in the app logs, then fail the import loudly.
    print(f"Error loading model at startup: {exc}")
    raise
def _extract_sql(text):
    """Return the bare SQL statement contained in raw model output.

    Strips surrounding whitespace and, when present, a Markdown code fence
    (``` or ```sql) that instruction-tuned models often wrap around code even
    when asked for the query only. Text without a fence is returned stripped.
    """
    sql = text.strip()
    if sql.startswith("```"):
        body = sql[3:]
        # Drop an optional "sql" language tag directly after the opening fence,
        # but only when it is a whole word (followed by whitespace or nothing).
        if body[:3].lower() == "sql" and not body[3:4].strip():
            body = body[3:]
        # Keep only what precedes the closing fence, if the model emitted one.
        body, _, _ = body.partition("```")
        sql = body.strip()
    return sql


def text_to_sql(nl_query):
    """Translate a natural-language question into SQL and execute it.

    The current database schema from ``get_schema()`` is embedded in the
    prompt. The module-level vLLM ``model`` then generates the query with
    greedy decoding (temperature 0). Generation stops at the first blank line
    or after 128 new tokens. The cleaned-up SQL is executed with
    ``execute_sql``.

    Args:
        nl_query: The user's question in natural language.

    Returns:
        A ``(sql, results)`` tuple: the generated SQL string (with any
        Markdown code fence removed) and whatever ``execute_sql`` returns.

    Raises:
        ValueError: If ``nl_query`` is ``None``, empty, or whitespace-only.
        Exception: Any error from schema lookup, generation, or SQL execution
            is printed and re-raised unchanged.
    """
    try:
        if nl_query is None or not str(nl_query).strip():
            raise ValueError("nl_query must be a non-empty question")

        schema = get_schema()
        # NOTE(review): the module-level `tokenizer` is loaded but not used; if
        # this model expects a chat template, `tokenizer.apply_chat_template`
        # may yield better SQL. Confirm against the model card.
        prompt = (
            "### Task\n"
            f"Generate a SQL query to answer the following natural language question: {nl_query}\n"
            "### Database Schema\n"
            f"{schema}\n"
            "### Response Format\n"
            "Output only the SQL query.\n"
        )
        sampling_params = SamplingParams(
            temperature=0,
            max_tokens=128,
            stop=["\n\n"],
        )
        outputs = model.generate([prompt], sampling_params)
        # Models frequently answer with ```sql ... ``` even when told not to;
        # executing the fence verbatim would be a guaranteed SQL syntax error.
        sql = _extract_sql(outputs[0].outputs[0].text)
        # NOTE(review): this SQL is derived from user input and executed
        # verbatim. Unless execute_sql uses a read-only connection, a crafted
        # question could yield destructive statements (DROP/DELETE/UPDATE);
        # consider a read-only connection or a SELECT-only allowlist.
        results = execute_sql(sql)
        return sql, results
    except Exception as e:
        print(f"Error in text_to_sql: {e}")
        raise
|