|
import os
|
|
|
|
|
|
# Point the Hugging Face caches at /tmp. These assignments must stay above the
# `transformers` import below so the library sees them when it is imported.
# NOTE(review): newer transformers releases deprecate TRANSFORMERS_CACHE in
# favour of HF_HOME / HF_HUB_CACHE; confirm which one the pinned version honours,
# since the two values here point at different directories.
os.environ.update(
    TRANSFORMERS_CACHE='/tmp/transformers',
    HF_HOME='/tmp/huggingface',
)
|
|
|
|
|
|
import logging
import threading
from typing import Annotated, Dict, Optional

import torch
from fastapi import FastAPI, Header
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|
|
|
# The Hugging Face cache env vars must be set before `transformers` is imported,
# so they are configured at the top of the file rather than here.

# Configure root logging at INFO (a no-op if the root logger already has
# handlers, e.g. when the host server configured logging first).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
MODEL_NAME = "mrm8488/t5-small-finetuned-wikiSQL"


def _load_checkpoint(name):
    """Load the tokenizer and seq2seq model for the checkpoint *name*.

    Returns a ``(tokenizer, model)`` pair; the tokenizer is loaded first.
    """
    logger.info("Loading model: %s", name)
    return (
        AutoTokenizer.from_pretrained(name),
        AutoModelForSeq2SeqLM.from_pretrained(name),
    )


# Loaded once at import time; the request handlers use these module-level objects.
tokenizer, model = _load_checkpoint(MODEL_NAME)
|
|
|
|
app = FastAPI(title="Lightweight Text-to-SQL API")

# Fully permissive CORS: browsers may call the API from any origin, with any
# method and any request header. allow_credentials is left at its default.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
|
class RequestModel(BaseModel):
    """Request body for ``POST /generate``."""

    # Caller-supplied identifier; /generate copies it unchanged into the response.
    entity_urn: str
    # Natural-language text to translate into SQL; /generate tokenizes it as-is.
    prompt: str
|
|
|
|
class ResponseModel(BaseModel):
    """Response envelope returned by ``POST /generate`` for both success and failure."""

    # "success" or "failure", as set by /generate.
    message: str
    # The generated SQL on success; the exception text on failure.
    result: str
    # Set to "text_to_sql" by /generate.
    action_type: str
    # Copied from the request's entity_urn.
    entity_urn: str
    # Optional extras; /generate sets {"input_tokens": <int>} on success and
    # leaves it as None on failure.
    metadata: Optional[Dict] = None
|
|
|
|
@app.get("/")
async def root():
    """Liveness check: report that the service is up and where the docs live."""
    return dict(status="Text-to-SQL API running", docs="/docs")
|
|
|
|
# Guards the shared tokenizer and model. Inference runs in worker threads (see
# generate_sql), and Hugging Face fast tokenizers are known to raise
# "Already borrowed" when called concurrently with truncation enabled. Holding
# this lock also keeps generations one at a time, which is how they behave
# when the work runs directly on the event loop.
_inference_lock = threading.Lock()


def _run_inference(prompt: str) -> tuple[str, int]:
    """Tokenize *prompt*, generate SQL and decode it (blocking).

    Returns:
        ``(decoded_sql, input_token_count)``. The count is the length of the
        encoded, possibly truncated, prompt.
    """
    # NOTE(review): this checkpoint's model card appears to expect inputs of the
    # form "translate English to SQL: <question>". The prompt is passed through
    # unmodified, so confirm whether callers already include that prefix.
    with _inference_lock:
        # truncation=True: over-long prompts are truncated rather than rejected.
        input_ids = tokenizer.encode(prompt, return_tensors="pt", truncation=True)
        # Cap the generated sequence at 128 tokens.
        outputs = model.generate(input_ids, max_length=128)
        decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return decoded, len(input_ids[0])


@app.post("/generate", response_model=ResponseModel)
async def generate_sql(request: RequestModel, x_api_key: Annotated[str, Header()]):
    """Translate ``request.prompt`` into SQL using the loaded seq2seq model.

    The blocking tokenize/generate/decode work is offloaded to a worker thread
    with ``run_in_threadpool``. This keeps the event loop free to serve other
    requests (e.g. ``GET /``) while a generation is in progress.

    Args:
        request: Body carrying the caller's ``entity_urn`` and the ``prompt``.
        x_api_key: Value of the required ``x-api-key`` request header.

    Returns:
        On success: a ``ResponseModel`` with ``message="success"``, the decoded
        SQL in ``result``, and ``{"input_tokens": n}`` in ``metadata``.
        If inference raises: ``message="failure"`` with the exception text in
        ``result``. Errors are reported in the body, not re-raised.
    """
    # NOTE(review): x_api_key is required (a request without the header is
    # rejected by FastAPI validation), but its value is never checked, so any
    # string is accepted. Either validate it here or confirm an upstream
    # gateway does.
    logger.info("Prompt: %s", request.prompt)
    try:
        decoded, input_tokens = await run_in_threadpool(_run_inference, request.prompt)
    except Exception as e:
        # Top-level request boundary: record the traceback server-side, then
        # return the same failure envelope clients already handle.
        logger.exception("Text-to-SQL generation failed")
        return ResponseModel(
            message="failure",
            result=str(e),
            action_type="text_to_sql",
            entity_urn=request.entity_urn,
        )

    return ResponseModel(
        message="success",
        result=decoded,
        action_type="text_to_sql",
        entity_urn=request.entity_urn,
        metadata={"input_tokens": input_tokens},
    )
|
|
|