"""Text-to-SQL service module.

Loads the Snowflake Arctic-Text2SQL model once at import time and exposes
text_to_sql() to translate a natural-language question into a SQL query and
execute it against the application database.
"""

import os

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

from db_utils import get_schema, execute_sql

# Model and tokenizer are initialized once at module import so every call to
# text_to_sql() reuses the same loaded weights.
model = None
tokenizer = None

try:
    tokenizer = AutoTokenizer.from_pretrained(
        "Snowflake/Arctic-Text2SQL-R1-7B",
        cache_dir="/tmp/cache/huggingface",
        trust_remote_code=True,
    )
    model = LLM(
        model="Snowflake/Arctic-Text2SQL-R1-7B",
        dtype="float16",
        gpu_memory_utilization=0.75,  # leave GPU headroom for other processes
        max_model_len=1024,
        max_num_seqs=1,  # serve one request at a time
        enforce_eager=True,  # skip CUDA graph capture for faster startup
        trust_remote_code=True,
    )
except Exception as e:
    # Fail fast: the service is useless without the model, so surface the
    # startup failure instead of limping along with model=None.
    print(f"Error loading model at startup: {e}")
    raise


def text_to_sql(nl_query):
    """Translate a natural-language question into SQL, execute it, return both.

    Args:
        nl_query: The user's question in plain language (untrusted input).

    Returns:
        A ``(sql, results)`` tuple: the generated SQL string and whatever
        ``db_utils.execute_sql`` returns for it.

    Raises:
        Exception: re-raised after logging if schema lookup, generation,
            or execution fails.
    """
    try:
        schema = get_schema()
        prompt = f"""### Task
Generate a SQL query to answer the following natural language question:
{nl_query}

### Database Schema
{schema}

### Response Format
Output only the SQL query.
"""
        # Greedy decoding (temperature=0) for deterministic SQL; stop at the
        # first blank line so the model does not continue past the query.
        sampling_params = SamplingParams(
            temperature=0,
            max_tokens=128,
            stop=["\n\n"],
        )
        outputs = model.generate([prompt], sampling_params)
        sql = outputs[0].outputs[0].text.strip()
        # SECURITY NOTE(review): the model's raw output is executed verbatim
        # against the database. execute_sql should use a read-only,
        # least-privilege connection (or validate/whitelist statements),
        # since nl_query — and therefore the generated SQL — is untrusted.
        results = execute_sql(sql)
        return sql, results
    except Exception as e:
        print(f"Error in text_to_sql: {e}")
        raise