# --- Scraped Hugging Face page chrome (not part of the module) ---
# Model: CLASSIFIED-HEX/X
# Tags: Transformers, code, Not-For-All-Audiences, legal, anything, medical, biology
# File: X / handler.py — commit "Update handler.py" (ffb427f, verified) by CLASSIFIED-HEX
# Converted to comments so the file parses as Python.
# handler.py
import requests
import logging
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from typing import Optional
# Setup logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
router = APIRouter()
# Your Hugging Face model URL – must be public
MODEL_URL = "https://api-inference.huggingface.co/models/CLASSIFIED-HEX/X"
# Input model
class PromptInput(BaseModel):
prompt: str
max_tokens: Optional[int] = 250
temperature: Optional[float] = 0.7
top_p: Optional[float] = 0.95
top_k: Optional[int] = 50
repetition_penalty: Optional[float] = 1.2
trim_output: Optional[bool] = False # New feature to remove prompt from result
# Root health check
@router.get("/")
async def root():
return {"message": "AI text generation backend is running πŸš€"}
# Ping model check
@router.get("/ping-model")
async def ping_model():
try:
response = requests.post(MODEL_URL, json={"inputs": "ping test"})
if response.status_code == 200:
return {"status": "Model is online βœ…"}
else:
return {"status": "Model responded with error ❌", "details": response.json()}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Could not reach model: {str(e)}")
# Main generation route
@router.post("/generate")
async def generate_text(input_data: PromptInput):
payload = {
"inputs": input_data.prompt,
"parameters": {
"max_new_tokens": input_data.max_tokens,
"temperature": input_data.temperature,
"top_p": input_data.top_p,
"top_k": input_data.top_k,
"repetition_penalty": input_data.repetition_penalty
}
}
try:
logger.info(f"Sending prompt to model: {input_data.prompt}")
response = requests.post(MODEL_URL, json=payload)
if response.status_code != 200:
logger.error(f"Model error: {response.status_code} - {response.text}")
raise HTTPException(status_code=response.status_code, detail=response.json())
result = response.json()
raw_output = result[0].get("generated_text") if isinstance(result, list) else result.get("generated_text", "")
# Optionally trim prompt from beginning
if input_data.trim_output and raw_output.startswith(input_data.prompt):
raw_output = raw_output[len(input_data.prompt):].lstrip()
return {
"status": "success",
"output": raw_output
}
except Exception as e:
logger.exception("Text generation failed")
raise HTTPException(status_code=500, detail=f"Text generation failed: {str(e)}")