File size: 2,233 Bytes
04676a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os

# Fix for Hugging Face Spaces cache permission issue
os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers'
os.environ['HF_HOME'] = '/tmp/huggingface'


from fastapi import FastAPI, Header
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from typing import Optional, Dict, Annotated
import torch
import logging

os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers'  # Safe path on HF Spaces

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load lightweight T5 model
MODEL_NAME = "mrm8488/t5-small-finetuned-wikiSQL"
logger.info(f"Loading model: {MODEL_NAME}")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

app = FastAPI(title="Lightweight Text-to-SQL API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

class RequestModel(BaseModel):
    entity_urn: str
    prompt: str

class ResponseModel(BaseModel):
    message: str
    result: str
    action_type: str
    entity_urn: str
    metadata: Optional[Dict] = None

@app.get("/")
async def root():
    return {"status": "Text-to-SQL API running", "docs": "/docs"}

@app.post("/generate", response_model=ResponseModel)
async def generate_sql(request: RequestModel, x_api_key: Annotated[str, Header()]):
    try:
        logger.info(f"Prompt: {request.prompt}")
        input_ids = tokenizer.encode(request.prompt, return_tensors="pt", truncation=True)
        outputs = model.generate(input_ids, max_length=128)
        decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)

        return ResponseModel(
            message="success",
            result=decoded,
            action_type="text_to_sql",
            entity_urn=request.entity_urn,
            metadata={"input_tokens": len(input_ids[0])}
        )

    except Exception as e:
        return ResponseModel(
            message="failure",
            result=str(e),
            action_type="text_to_sql",
            entity_urn=request.entity_urn
        )