# texttosql/app.py
import os
# Fix for Hugging Face Spaces cache permission issue: these cache paths must be set before transformers is imported.
os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers'
os.environ['HF_HOME'] = '/tmp/huggingface'
from fastapi import FastAPI, Header
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from typing import Optional, Dict, Annotated
import torch
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load lightweight T5 model
MODEL_NAME = "mrm8488/t5-small-finetuned-wikiSQL"
logger.info(f"Loading model: {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
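
# Added as a sketch (not in the original upload): eval() disables dropout, the usual mode for inference-only serving.
model.eval()

# Illustrative local smoke test (commented out; the prompt text and its "translate English to SQL:"
# prefix are example values, not something this file defines):
#   ids = tokenizer.encode("translate English to SQL: how many users signed up in 2023?",
#                          return_tensors="pt", truncation=True)
#   print(tokenizer.decode(model.generate(ids, max_length=128)[0], skip_special_tokens=True))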
app = FastAPI(title="Lightweight Text-to-SQL API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

class RequestModel(BaseModel):
    entity_urn: str
    prompt: str


class ResponseModel(BaseModel):
    message: str
    result: str
    action_type: str
    entity_urn: str
    metadata: Optional[Dict] = None
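
# Illustrative payloads (the values below are made-up examples, not defined elsewhere in this app):
#   request:  {"entity_urn": "urn:li:dataset:example", "prompt": "how many rows are in table users"}
#   response: {"message": "success", "result": "<generated SQL string>", "action_type": "text_to_sql",
#              "entity_urn": "urn:li:dataset:example", "metadata": {"input_tokens": 12}}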
@app.get("/")
async def root():
return {"status": "Text-to-SQL API running", "docs": "/docs"}
@app.post("/generate", response_model=ResponseModel)
async def generate_sql(request: RequestModel, x_api_key: Annotated[str, Header()]):
    # NOTE: x_api_key is required as a header but is never validated anywhere in this file.
    try:
        logger.info(f"Prompt: {request.prompt}")
        input_ids = tokenizer.encode(request.prompt, return_tensors="pt", truncation=True)
        with torch.no_grad():  # inference only; no gradients needed
            outputs = model.generate(input_ids, max_length=128)
        decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return ResponseModel(
            message="success",
            result=decoded,
            action_type="text_to_sql",
            entity_urn=request.entity_urn,
            metadata={"input_tokens": len(input_ids[0])},
        )
    except Exception as e:
        logger.exception("SQL generation failed")
        return ResponseModel(
            message="failure",
            result=str(e),
            action_type="text_to_sql",
            entity_urn=request.entity_urn,
        )