# NOTE(review): the original extract began with "Spaces: / Sleeping / Sleeping",
# which is Hugging Face Spaces page chrome captured by the scraper, not code.
# Standard library imports first, third-party second (PEP 8 grouping).
import logging
import os
from typing import Any, Dict, Optional

import httpx
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
# NOTE(review): HTMLResponse, JSONResponse, StaticFiles and os are imported but
# not used in the visible portion of this file — they may be referenced by
# code outside this extract (e.g. a static-files mount); kept as-is.
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel

# Root logging config plus a module-level logger, per the stdlib convention.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# FastAPI application exposing the Llama3-Papalia model behind a small API/UI.
app = FastAPI(
    title="Llama3-Papalia Inference API & UI",
    description="API para interactuar con el modelo Llama3-Papalia",
    version="1.0.0",
)

# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# rejected by browsers — the CORS spec forbids credentials with a wildcard
# origin. Confirm whether credentialed cross-origin requests are actually
# needed; if not, allow_credentials should be False.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Jinja2 templates loaded from ./templates (index.html is rendered by read_root).
templates = Jinja2Templates(directory="templates")
class QueryRequest(BaseModel):
    """Request body for a generation call against the Papalia model."""

    # The user prompt forwarded verbatim to Ollama.
    prompt: str
    # Sampling temperature forwarded to the model (higher = more random).
    temperature: Optional[float] = 0.7
    # Requested cap on generated tokens.
    max_tokens: Optional[int] = 500
class QueryResponse(BaseModel):
    """Response body: the generated text plus the model identifier."""

    # Text produced by the model.
    response: str
    # Fixed model tag this service fronts.
    model: str = "llama3.2:1b-papalia"
# Ollama daemon location. Overridable via the OLLAMA_BASE_URL environment
# variable (e.g. when the daemon runs in another container); defaults to the
# standard local port 11434, preserving the original hard-coded behavior.
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
# Generation endpoint derived from the base URL so the two can never diverge.
OLLAMA_API_URL = OLLAMA_BASE_URL + "/api/generate"
async def check_ollama_status() -> Dict[str, Any]:
    """Probe the local Ollama daemon and the llama3.2:1b-papalia model.

    Returns:
        A dict with "status" ("ok" | "error"), a human-readable "message",
        and, for HTTP-level failures, the failing "code".

    Never raises: every error (network, timeout, non-200) is reported in the
    returned dict so callers can decide how to surface it.
    """
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            # Root endpoint answers 200 when the daemon is up.
            response = await client.get(OLLAMA_BASE_URL)
            if response.status_code != 200:
                return {"status": "error", "message": "Ollama no responde", "code": response.status_code}
            # Confirm the model itself answers with a throwaway generation.
            # NOTE(review): Ollama's /api/generate expects sampling limits under
            # an "options" key (e.g. {"options": {"num_predict": 1}}); a
            # top-level "max_tokens" is likely ignored, so this probe may
            # generate more than one token — verify against the Ollama API.
            model_response = await client.post(
                OLLAMA_BASE_URL + "/api/generate",
                json={
                    "model": "llama3.2:1b-papalia",
                    "prompt": "test",
                    "max_tokens": 1,
                },
                timeout=5.0,
            )
            if model_response.status_code != 200:
                return {"status": "error", "message": "Modelo no disponible", "code": model_response.status_code}
            return {"status": "ok", "message": "Servicio funcionando correctamente"}
    except Exception as e:
        # Broad catch is intentional here: a health probe must report, not raise.
        return {"status": "error", "message": str(e)}
async def read_root(request: Request):
    """Render the landing page (templates/index.html) with live Ollama status.

    NOTE(review): no route decorator is visible in this extract; presumably
    this was registered as @app.get("/") — confirm against the original file.
    """
    status = await check_ollama_status()
    return templates.TemplateResponse(
        "index.html",
        {
            "request": request,  # required by Jinja2Templates
            "title": "Llama3-Papalia Inference",
            "status": status,
        },
    )
async def generate_response(query: QueryRequest):
    """Forward a prompt to the Ollama model and return its completion.

    Args:
        query: prompt plus sampling parameters.

    Returns:
        {"response": <generated text>, "model": "llama3.2:1b-papalia"}

    Raises:
        HTTPException: 503 when Ollama/model is unavailable, 504 on timeout,
        the upstream status code on a non-200 model reply, 500 otherwise.

    NOTE(review): no route decorator is visible in this extract; presumably
    registered as @app.post(...) — confirm against the original file.
    """
    logger.info(f"Recibida solicitud: {query.prompt[:50]}...")
    try:
        # Fail fast with 503 before paying for a 60s generation attempt.
        status = await check_ollama_status()
        if status["status"] != "ok":
            raise HTTPException(status_code=503, detail=status["message"])
        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(
                OLLAMA_API_URL,
                json={
                    "model": "llama3.2:1b-papalia",
                    "prompt": query.prompt,
                    "stream": False,  # single JSON document, not an NDJSON stream
                    # NOTE(review): Ollama expects sampling parameters under an
                    # "options" key ({"options": {"temperature": ...,
                    # "num_predict": ...}}); these top-level keys may be
                    # silently ignored — verify against the Ollama API docs.
                    "temperature": query.temperature,
                    "max_tokens": query.max_tokens,
                },
                timeout=60.0,
            )
            if response.status_code != 200:
                raise HTTPException(
                    status_code=response.status_code,
                    detail=f"Error del modelo: {response.text}",
                )
            result = response.json()
            logger.info("Respuesta generada exitosamente")
            return {"response": result.get("response", ""), "model": "llama3.2:1b-papalia"}
    except httpx.TimeoutException:
        logger.error("Timeout en la solicitud a Ollama")
        raise HTTPException(status_code=504, detail="Timeout en la solicitud al modelo")
    except HTTPException:
        # Bug fix: the generic handler below used to catch the 503 and
        # upstream-status HTTPExceptions raised above and re-wrap them as a
        # 500, destroying the intended status codes. Let them propagate.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"Error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def health_check():
    """Liveness endpoint: map check_ollama_status() onto healthy/unhealthy.

    NOTE(review): no route decorator is visible in this extract; presumably
    registered as @app.get("/health") — confirm against the original file.
    """
    status = await check_ollama_status()
    if status["status"] == "ok":
        return {"status": "healthy", "message": status["message"]}
    return {"status": "unhealthy", "error": status["message"]}