from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import httpx
import logging
from typing import Optional, Dict, Any

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Llama3-Papalia Inference API & UI",
    description="API for interacting with the Llama3-Papalia model",
    version="1.0.0"
)

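# Allow cross-origin requests from any origin so the UI can call the API
# regardless of where it is served from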
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

templates = Jinja2Templates(directory="templates")

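# Request/response schemas for the /generate endpoint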
class QueryRequest(BaseModel):
    prompt: str
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 500

class QueryResponse(BaseModel):
    response: str
    model: str = "llama3.2:1b-papalia"

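# Ollama's local HTTP API (11434 is its default port)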
OLLAMA_API_URL = "http://localhost:11434/api/generate"
OLLAMA_BASE_URL = "http://localhost:11434"

async def check_ollama_status() -> Dict[str, Any]:
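    """Check that the Ollama server responds and that the model can generate."""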
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(OLLAMA_BASE_URL)
            if response.status_code != 200:
                return {"status": "error", "message": "Ollama no responde", "code": response.status_code}

            # Verify that the model is available by generating a single token.
            # Ollama expects sampling limits under "options" ("num_predict" is
            # its name for the max-token limit), and "stream": False returns a
            # single JSON object instead of a stream of chunks.
            model_response = await client.post(
                OLLAMA_BASE_URL + "/api/generate",
                json={
                    "model": "llama3.2:1b-papalia",
                    "prompt": "test",
                    "stream": False,
                    "options": {"num_predict": 1}
                },
                timeout=5.0
            )
            
            if model_response.status_code != 200:
                return {"status": "error", "message": "Modelo no disponible", "code": model_response.status_code}
                
            return {"status": "ok", "message": "Servicio funcionando correctamente"}
    except Exception as e:
        return {"status": "error", "message": str(e)}

@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
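    """Render the HTML UI along with the current Ollama status."""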
    status = await check_ollama_status()
    return templates.TemplateResponse(
        "index.html",
        {
            "request": request,
            "title": "Llama3-Papalia Inference",
            "status": status
        }
    )

@app.post("/generate")
async def generate_response(query: QueryRequest):
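    """Forward the prompt to Ollama and return the generated text."""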
    logger.info(f"Recibida solicitud: {query.prompt[:50]}...")
    
    try:
        async with httpx.AsyncClient(timeout=60.0) as client:
            status = await check_ollama_status()
            if status["status"] != "ok":
                raise HTTPException(
                    status_code=503,
                    detail=status["message"]
                )

            response = await client.post(
                OLLAMA_API_URL,
                json={
                    "model": "llama3.2:1b-papalia",
                    "prompt": query.prompt,
                    "stream": False,
                    # Sampling parameters go under "options" in Ollama's API;
                    # "num_predict" is its name for the max-token limit
                    "options": {
                        "temperature": query.temperature,
                        "num_predict": query.max_tokens
                    }
                },
                timeout=60.0
            )
            
            if response.status_code != 200:
                raise HTTPException(
                    status_code=response.status_code,
                    detail=f"Error del modelo: {response.text}"
                )

            result = response.json()
            logger.info("Respuesta generada exitosamente")
            return {"response": result.get("response", ""), "model": "llama3.2:1b-papalia"}
            
    except httpx.TimeoutException:
        logger.error("Request to Ollama timed out")
        raise HTTPException(
            status_code=504,
            detail="The request to the model timed out"
        )
    except HTTPException:
        # Re-raise HTTP errors (such as the 503 above) instead of masking them as 500
        raise
    except Exception as e:
        logger.error(f"Error: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=str(e)
        )

@app.get("/health")
async def health_check():
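    """Report whether Ollama and the model are reachable."""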
    status = await check_ollama_status()
    if status["status"] == "ok":
        return {"status": "healthy", "message": status["message"]}
    return {"status": "unhealthy", "error": status["message"]}