|
|
|
|
|
import requests |
|
import logging |
|
from fastapi import APIRouter, HTTPException |
|
from pydantic import BaseModel |
|
from typing import Optional |
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
logger = logging.getLogger(__name__) |
|
|
|
router = APIRouter() |
|
|
|
|
|
MODEL_URL = "https://api-inference.huggingface.co/models/CLASSIFIED-HEX/X" |
|
|
|
|
|
class PromptInput(BaseModel):
    """Request body for the /generate endpoint: a prompt plus sampling knobs.

    Field defaults mirror common Hugging Face text-generation parameters;
    all knobs are optional so callers may send just a prompt.
    """

    # The text the model should continue.
    prompt: str
    # Upper bound on newly generated tokens (forwarded as "max_new_tokens").
    max_tokens: Optional[int] = 250
    # Sampling temperature; higher values produce more random output.
    temperature: Optional[float] = 0.7
    # Nucleus-sampling cutoff: sample from the smallest set of tokens whose
    # cumulative probability exceeds this value.
    top_p: Optional[float] = 0.95
    # Restrict sampling to the k most likely tokens.
    top_k: Optional[int] = 50
    # Values > 1.0 discourage the model from repeating itself.
    repetition_penalty: Optional[float] = 1.2
    # When True, strip the echoed prompt from the front of the model output.
    trim_output: Optional[bool] = False
|
|
|
|
|
@router.get("/")
async def root():
    """Liveness endpoint confirming the backend service is running.

    Returns:
        dict: A static status message.
    """
    # Fix: the original message ended in a mojibake character ("π"),
    # presumably a garbled emoji from an encoding mishap — restored here.
    return {"message": "AI text generation backend is running 🚀"}
|
|
|
|
|
@router.get("/ping-model")
async def ping_model():
    """Probe the Hugging Face inference endpoint with a tiny request.

    Returns:
        dict: Online status; on a non-200 reply the model's response body
        is included under "details" for debugging.

    Raises:
        HTTPException: 500 when the model endpoint cannot be reached.
    """
    try:
        # Fix: the original call had no timeout, so a stalled endpoint would
        # block this worker indefinitely.
        response = requests.post(MODEL_URL, json={"inputs": "ping test"}, timeout=10)
        if response.status_code == 200:
            return {"status": "Model is online ✅"}
        # Fix: response.json() on a plain-text error body used to raise and
        # get mis-reported as "Could not reach model"; fall back to raw text.
        try:
            details = response.json()
        except ValueError:
            details = response.text
        return {"status": "Model responded with error ❌", "details": details}
    except requests.RequestException as e:
        raise HTTPException(status_code=500, detail=f"Could not reach model: {str(e)}")
|
|
|
|
|
@router.post("/generate")
async def generate_text(input_data: PromptInput):
    """Forward a prompt to the hosted model and return the generated text.

    Args:
        input_data: Prompt plus sampling parameters (see PromptInput).

    Returns:
        dict: {"status": "success", "output": <generated text>}.

    Raises:
        HTTPException: With the model's own status code on a non-200 reply,
        or 500 when the request itself fails.
    """
    payload = {
        "inputs": input_data.prompt,
        "parameters": {
            # The HF inference API calls this "max_new_tokens".
            "max_new_tokens": input_data.max_tokens,
            "temperature": input_data.temperature,
            "top_p": input_data.top_p,
            "top_k": input_data.top_k,
            "repetition_penalty": input_data.repetition_penalty,
        },
    }

    try:
        # Lazy %-formatting so the string is only built if INFO is enabled.
        # NOTE(review): logging full prompts may leak user data — consider
        # truncating or demoting to DEBUG.
        logger.info("Sending prompt to model: %s", input_data.prompt)
        # Fix: no timeout meant a stalled endpoint could hang this worker forever.
        response = requests.post(MODEL_URL, json=payload, timeout=60)

        if response.status_code != 200:
            logger.error(f"Model error: {response.status_code} - {response.text}")
            raise HTTPException(status_code=response.status_code, detail=response.json())

        result = response.json()
        # The inference API returns either a list of generations or a single
        # dict. Fix: the list branch had no default, so a missing
        # "generated_text" key yielded None and crashed .startswith() below.
        if isinstance(result, list):
            raw_output = result[0].get("generated_text", "")
        else:
            raw_output = result.get("generated_text", "")
        raw_output = raw_output or ""

        # Optionally strip the echoed prompt from the front of the completion.
        if input_data.trim_output and raw_output.startswith(input_data.prompt):
            raw_output = raw_output[len(input_data.prompt):].lstrip()

        return {
            "status": "success",
            "output": raw_output,
        }

    except HTTPException:
        # Fix: the broad handler below used to swallow the HTTPException
        # raised for non-200 model replies and re-wrap it as a generic 500,
        # losing the real status code. Re-raise it untouched.
        raise
    except Exception as e:
        logger.exception("Text generation failed")
        raise HTTPException(status_code=500, detail=f"Text generation failed: {str(e)}")
|
|