# lokiai / main.py
# Snapshot of the Hugging Face Space "lokiai" by ParthSadaria
# (commit f97c315, "Update main.py", verified, 17.1 kB).
import os
import json
import datetime
import asyncio
import re
from functools import lru_cache
from pathlib import Path
from typing import List, Dict, Any, Tuple, Optional
from urllib.parse import urlparse

import httpx
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request, Depends, Security, Query, APIRouter
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse, FileResponse, PlainTextResponse
from fastapi.security import APIKeyHeader
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import AliasChoices, BaseModel, Field
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.gzip import GZipMiddleware
from starlette.status import HTTP_403_FORBIDDEN, HTTP_503_SERVICE_UNAVAILABLE

# Use cloudscraper for specific endpoints that need it
try:
    import cloudscraper
except ImportError:
    cloudscraper = None

from usage_tracker import UsageTracker
# --- Initial Setup ---
# Load variables from a local .env file into the process environment.
load_dotenv()

# uvloop is an optional accelerator: install its event-loop policy when the
# package is importable, otherwise keep asyncio's default loop.
try:
    import uvloop
except ImportError:
    pass
else:
    asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
# --- Configuration Management using Pydantic ---
class Settings(BaseSettings):
    """Manages all application settings and environment variables in one place.

    Values are read from the process environment and from a UTF-8 ``.env``
    file. Each field is populated from the environment variable that has the
    same name as the field (matched case-insensitively), e.g. ``api_keys``
    from ``API_KEYS``.

    pydantic-settings ignores the legacy ``Field(env=...)`` keyword, so the one
    field whose variable name differs from the field name (``new_img_api`` /
    ``NEW_IMG``) declares it through ``validation_alias`` instead.

    Note: list-typed values such as ``API_KEYS`` are decoded from JSON by
    pydantic-settings (e.g. ``API_KEYS='["key-a", "key-b"]'``).
    """

    model_config = SettingsConfigDict(
        env_file='.env',
        env_file_encoding='utf-8',
        # Keep ``Settings(new_img_api=...)`` working even though the field
        # declares environment aliases.
        populate_by_name=True,
    )

    # Keys accepted in the ``Authorization: Bearer <key>`` header.
    api_keys: List[str]

    # Endpoints for various model providers
    secret_api_endpoint: str
    secret_api_endpoint_2: str
    secret_api_endpoint_3: str
    secret_api_endpoint_4: str = "https://text.pollinations.ai/openai"
    secret_api_endpoint_5: str
    secret_api_endpoint_6: str

    # Specific provider keys and APIs
    mistral_api: str = "https://api.mistral.ai"
    mistral_key: str
    gemini_key: str
    # NEW_IMG is the intended variable; NEW_IMG_API is still accepted because
    # that is the name the field resolved to while ``env=`` was being ignored.
    new_img_api: str = Field(..., validation_alias=AliasChoices("NEW_IMG", "NEW_IMG_API"))
    endpoint_origin: Optional[str] = None
    header_url: Optional[str] = None
@lru_cache()
def get_settings() -> "Settings":
    """Build the application settings once and reuse the cached instance."""
    loaded = Settings()
    return loaded
# --- Pydantic Models for Payloads ---
class ChatPayload(BaseModel):
    """Request body accepted by the chat-completions endpoint."""

    # Model identifier; also selects the upstream provider.
    model: str
    # Conversation messages, forwarded upstream as-is.
    messages: List[Dict[str, Any]]
    # Streaming flag forwarded to the upstream provider; off by default.
    stream: bool = False
class ImageGenerationPayload(BaseModel):
    """Request body accepted by the image-generation endpoint."""

    # Image model name; the endpoint requires it to be in MODEL_SETS["image"].
    model: str
    # Text prompt forwarded to the image service.
    prompt: str
    # Forwarded to the image service unchanged.
    size: int
    # Forwarded to the image service unchanged.
    number: int
# --- Global Objects & State ---
# The FastAPI application; routers and middleware are attached further down.
app = FastAPI(
title="LokiAI API",
version="2.5.0",
description="A robust and scalable API proxy for various AI models, now fully rewritten.",
)
# Shared tracker that the request handlers call to record each request.
usage_tracker = UsageTracker()
# auto_error=False: a missing Authorization header is passed to get_api_key as
# a falsy value, so that dependency decides how to reject (or bypass) it.
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
# Maintenance switch read by the chat, image and health handlers. Nothing in
# this file sets it to False.
server_status = {"online": True}
# --- Model & API Configuration ---
# Model names grouped by the upstream provider that serves them.
MODEL_SETS = {
    "mistral": {"mistral-large-latest", "codestral-latest", "mistral-small-latest"},
    "pollinations": {"openai", "gemini", "phi", "llama"},
    "alternate": {"o1", "grok-3", "sonar-pro"},
    "claude": {"claude-3-7-sonnet", "claude 3.5 sonnet", "o3-mini-medium"},
    "gemini": {"gemini-1.5-pro", "gemini-1.5-flash", "gemini-2.0-flash"},
    "image": {"Flux Pro Ultra", "dall-e-3", "stable-diffusion-3-large-turbo"},
}


def get_api_details(model_name: str, settings: "Settings") -> Tuple[str, Dict, str]:
    """Resolve the upstream target for ``model_name``.

    Args:
        model_name: Model identifier taken from the request payload.
        settings: Loaded application settings.

    Returns:
        A ``(base_url, headers, path)`` tuple. Models that are not listed in
        any ``MODEL_SETS`` family fall back to ``secret_api_endpoint``, with
        ``Origin``/``Referer`` headers when ``header_url`` is configured.
    """
    chat_path = "/v1/chat/completions"

    # Families are checked in this order. Each entry is built lazily so only
    # the settings of the matching provider are read.
    routes = (
        ("mistral", lambda: (
            settings.mistral_api,
            {"Authorization": f"Bearer {settings.mistral_key}"},
            chat_path,
        )),
        ("gemini", lambda: (
            settings.secret_api_endpoint_6,
            {"Authorization": f"Bearer {settings.gemini_key}"},
            "/chat/completions",
        )),
        ("pollinations", lambda: (settings.secret_api_endpoint_4, {}, chat_path)),
        ("claude", lambda: (settings.secret_api_endpoint_5, {}, chat_path)),
        ("alternate", lambda: (settings.secret_api_endpoint_2, {}, chat_path)),
        ("image", lambda: (settings.new_img_api, {}, "")),
    )
    for family, build in routes:
        if model_name in MODEL_SETS[family]:
            return build()

    # Default case
    fallback_headers = {}
    if settings.header_url:
        fallback_headers = {
            "Origin": settings.header_url,
            "Referer": settings.header_url,
        }
    return settings.secret_api_endpoint, fallback_headers, chat_path
# --- Dependencies & Security ---
async def get_api_key(request: Request, api_key: str = Security(api_key_header)):
    """Validates the API key, allowing specific referers to bypass.

    Args:
        request: Incoming request; only its ``Referer`` header is inspected.
        api_key: Raw ``Authorization`` header value (falsy when absent).

    Returns:
        The bare API key, or ``"hf_space_bypass"`` when the request's referer
        host is the project's own Hugging Face Space.

    Raises:
        HTTPException: 403 when the header is missing, is not of the form
            ``Bearer <key>``, or carries a key that is not configured.
    """
    referer = request.headers.get("referer", "")
    if referer:
        # Compare the parsed hostname exactly. A substring test would also
        # accept e.g. "https://evil.example/?parthsadaria-lokiai.hf.space".
        try:
            referer_host = urlparse(referer).hostname
        except ValueError:
            referer_host = None
        if referer_host == "parthsadaria-lokiai.hf.space":
            # NOTE(review): Referer is client-controlled, so this bypass is a
            # convenience for the hosted pages, not an authentication boundary.
            return "hf_space_bypass"

    settings = get_settings()
    if not api_key or not api_key.startswith("Bearer "):
        raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Invalid authorization format.")

    key = api_key.split(" ")[1]
    # An empty token ("Bearer ") must never authenticate, even if an empty
    # string slipped into the configured key list.
    if not key or key not in settings.api_keys:
        raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Invalid API key.")
    return key
@lru_cache()
def get_http_client() -> httpx.AsyncClient:
    """Create the shared async HTTP client on first use and cache it."""
    pool_limits = httpx.Limits(max_connections=200)
    return httpx.AsyncClient(timeout=60.0, limits=pool_limits)
# --- API Routers ---
# One router per functional area; the tags group the routes in the generated
# API docs. The chat and image routers are each mounted twice further down
# (with and without the "/api/v1" prefix).
chat_router = APIRouter(tags=["AI Models"])
image_router = APIRouter(tags=["AI Models"])
usage_router = APIRouter(tags=["Server Administration"])
utility_router = APIRouter(tags=["Utilities & Pages"])
# --- Chat Completions Router ---
@chat_router.post("/chat/completions")
async def chat_completions(
    payload: ChatPayload,
    request: Request,
    api_key: str = Depends(get_api_key),
    client: httpx.AsyncClient = Depends(get_http_client)
):
    """Proxy a chat-completion request to the provider that serves the model.

    The upstream body is relayed chunk by chunk. Because the HTTP status line
    has already been sent when the upstream call starts, upstream failures are
    reported as a JSON ``{"error": ...}`` object in the response body.

    Args:
        payload: Validated chat request.
        request: Incoming request, passed to the usage tracker.
        api_key: Result of the authentication dependency (unused otherwise).
        client: Shared HTTP client.

    Raises:
        HTTPException: 503 while the server is flagged as offline.
    """
    if not server_status["online"]:
        raise HTTPException(status_code=HTTP_503_SERVICE_UNAVAILABLE, detail="Server under maintenance.")

    settings = get_settings()
    usage_tracker.record_request(request, payload.model, "/chat/completions")
    endpoint, headers, path = get_api_details(payload.model, settings)
    upstream_url = f"{endpoint}{path}"
    upstream_body = payload.model_dump()

    async def stream_generator():
        try:
            async with client.stream("POST", upstream_url, json=upstream_body, headers=headers) as response:
                if not response.is_success:
                    # The body of a streamed response must be read explicitly,
                    # and while the stream is still open; touching
                    # ``response.text`` on an unread stream raises
                    # httpx.ResponseNotRead.
                    error_text = (await response.aread()).decode("utf-8", errors="replace")
                    print(f"Upstream error: {response.status_code} - {error_text}")
                    yield json.dumps({"error": {"code": 502, "message": "Bad Gateway: Upstream service error."}}).encode()
                    return
                async for chunk in response.aiter_bytes():
                    yield chunk
        except Exception as e:
            print(f"Streaming error: {e}")
            yield json.dumps({"error": {"code": 500, "message": "An internal error occurred."}}).encode()

    # Only streamed completions are server-sent events; a non-streaming
    # completion is a single JSON document.
    media_type = "text/event-stream" if payload.stream else "application/json"
    return StreamingResponse(stream_generator(), media_type=media_type)
# --- Image Generation Router ---
@image_router.post("/images/generations")
async def images_generations(
    payload: ImageGenerationPayload,
    request: Request,
    api_key: str = Depends(get_api_key),
    client: httpx.AsyncClient = Depends(get_http_client)
):
    """Forward an image-generation request to the configured image service.

    Args:
        payload: Validated image request.
        request: Incoming request, passed to the usage tracker.
        api_key: Result of the authentication dependency (unused otherwise).
        client: Shared HTTP client.

    Returns:
        The upstream JSON body wrapped in a ``JSONResponse``.

    Raises:
        HTTPException: 503 while offline; 400 for an unsupported model; the
            upstream status code when the service answers with an error; 502
            when the service is unreachable or replies with non-JSON content.
    """
    if not server_status["online"]:
        raise HTTPException(status_code=HTTP_503_SERVICE_UNAVAILABLE, detail="Server under maintenance.")
    if payload.model not in MODEL_SETS["image"]:
        raise HTTPException(status_code=400, detail=f"Image model '{payload.model}' not supported.")

    settings = get_settings()
    usage_tracker.record_request(request, payload.model, "/images/generations")
    endpoint, headers, _ = get_api_details(payload.model, settings)

    try:
        response = await client.post(endpoint, json=payload.model_dump(), headers=headers)
        response.raise_for_status()
    except httpx.HTTPStatusError as e:
        # Error bodies are not guaranteed to be JSON objects (HTML error pages,
        # plain text, JSON arrays); fall back to a generic message for those.
        detail = "Upstream error"
        try:
            error_body = e.response.json()
        except ValueError:
            error_body = None
        if isinstance(error_body, dict):
            detail = error_body.get("detail", detail)
        raise HTTPException(status_code=e.response.status_code, detail=detail) from e
    except httpx.RequestError as e:
        raise HTTPException(status_code=502, detail=f"Failed to connect to image service: {e}") from e

    try:
        return JSONResponse(content=response.json())
    except ValueError as e:
        raise HTTPException(status_code=502, detail="Image service returned a non-JSON response.") from e
# --- Usage & Health Router ---
@usage_router.get("/usage", response_class=HTMLResponse)
async def get_usage_dashboard(days: int = Query(7, ge=1, le=30)):
    """Render the usage dashboard for the last ``days`` days (1 to 30)."""
    summary = usage_tracker.get_usage_summary(days=days)
    # The renderer lives in usage_dashboard_generator.py, a module this file
    # writes to disk at import time when it is missing, hence the local import.
    from usage_dashboard_generator import generate_usage_html
    page = generate_usage_html(summary)
    return HTMLResponse(content=page)
@usage_router.get("/health")
async def health_check():
    """Report whether the server is flagged online, plus the app version."""
    if server_status["online"]:
        state = "healthy"
    else:
        state = "unhealthy"
    return {"status": state, "version": app.version}
@usage_router.get("/models")
async def get_models():
    """Return the parsed contents of ``models.json`` stored next to this file.

    Raises:
        HTTPException: 500 when the file is missing, unreadable, or does not
            contain valid JSON.
    """
    models_path = Path(__file__).parent / 'models.json'
    try:
        # Read as UTF-8 explicitly so the result does not depend on the
        # platform's default encoding.
        with open(models_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except (OSError, ValueError) as exc:
        # OSError: missing/unreadable file. ValueError: malformed JSON or
        # undecodable bytes (JSONDecodeError / UnicodeDecodeError).
        raise HTTPException(status_code=500, detail="models.json not found or is invalid.") from exc
# --- Utility & Pages Router ---
@lru_cache(maxsize=10)
def read_static_file(file_path):
    """Return the UTF-8 text of a file located next to this module.

    Results are cached (up to 10 distinct paths), including the ``None``
    returned when the file does not exist.
    """
    target = Path(__file__).parent / file_path
    try:
        return target.read_text(encoding="utf-8")
    except FileNotFoundError:
        return None
@utility_router.get("/", response_class=HTMLResponse)
async def root_page():
    """Serve index.html, or a placeholder heading when it is missing or empty."""
    page = read_static_file("index.html")
    if not page:
        page = "<h1>Not Found</h1>"
    return HTMLResponse(content=page)
@utility_router.get("/playground", response_class=HTMLResponse)
async def playground_page():
    """Serve playground.html, or a placeholder heading when it is missing or empty."""
    page = read_static_file("playground.html")
    if not page:
        page = "<h1>Not Found</h1>"
    return HTMLResponse(content=page)
@utility_router.get("/image-playground", response_class=HTMLResponse)
async def image_playground_page():
    """Serve image-playground.html, or a placeholder heading when it is missing or empty."""
    page = read_static_file("image-playground.html")
    if not page:
        page = "<h1>Not Found</h1>"
    return HTMLResponse(content=page)
@utility_router.get("/scraper", response_class=PlainTextResponse)
async def scrape_url(url: str = Query(..., description="URL to scrape")):
    """Fetch ``url`` with cloudscraper and return the response body as text.

    Raises:
        HTTPException: 501 when cloudscraper is not installed; 400 when the
            URL is not an http(s) URL; 500 when the fetch fails.
    """
    if not cloudscraper:
        raise HTTPException(status_code=501, detail="Scraper library not installed.")

    # Reject anything that is not plain http(s) before handing it to the
    # scraper.
    # NOTE(review): this route has no API-key dependency and will fetch any
    # http(s) URL, including internal addresses (SSRF) - consider requiring
    # get_api_key and/or blocking private address ranges.
    try:
        scheme = urlparse(url).scheme
    except ValueError:
        scheme = ""
    if scheme not in ("http", "https"):
        raise HTTPException(status_code=400, detail="Only http and https URLs can be scraped.")

    try:
        scraper = cloudscraper.create_scraper()
        # cloudscraper is synchronous: run it in the default thread pool so a
        # slow site does not stall the event loop, and bound the wait so a
        # hung connection cannot pin a worker thread forever.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(None, lambda: scraper.get(url, timeout=30))
        response.raise_for_status()
        return PlainTextResponse(content=response.text)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to scrape URL: {e}")
# --- Main Application Setup ---
# Compress responses of at least 1000 bytes.
app.add_middleware(GZipMiddleware, minimum_size=1000)
# NOTE(review): every origin is allowed together with credentials. That is the
# most permissive CORS setup possible - confirm it is intentional, or list the
# front-end origins explicitly.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Include all the organized routers
# The chat and image routers are mounted twice: under "/api/v1" and at the
# root, so both URL forms reach the same handlers.
app.include_router(chat_router, prefix="/api/v1")
app.include_router(chat_router) # For legacy /chat/completions
app.include_router(image_router, prefix="/api/v1")
app.include_router(image_router) # For legacy /images/generations
app.include_router(usage_router)
app.include_router(utility_router)
@app.on_event("startup")
async def startup_event():
    """Warm the cached settings and HTTP client, then print a start banner.

    A settings failure is reported on stdout but does not stop the startup.
    """
    # Pre-load settings and client to catch config errors early
    try:
        get_settings()
    except Exception as e:
        print(f"FATAL: Could not load settings from environment variables. Error: {e}")
        # In a real app, you might want to exit here
    get_http_client()

    banner = (
        "--- LokiAI Server Started ---",
        f"Version: {app.version}",
        "Usage tracking is active and will save data periodically.",
    )
    for line in banner:
        print(line)
@app.on_event("shutdown")
async def shutdown_event():
    """Close the shared HTTP client, then save the usage data."""
    await get_http_client().aclose()
    usage_tracker.save_data()
    print("--- LokiAI Server Shutdown Complete ---")
# Helper for usage dashboard - in a real project, this would be in its own file.
# It is written to disk here, at import time, so the example is self-contained.
# The file is only created when it does not exist yet: an existing
# usage_dashboard_generator.py is never overwritten, so delete a stale copy to
# pick up changes to this template.
#
# The generated module escapes every value it interpolates. Model names come
# straight from request payloads, so rendering them raw would let a caller
# inject markup or script into the dashboard.
if not (Path(__file__).parent / "usage_dashboard_generator.py").exists():
    with open(Path(__file__).parent / "usage_dashboard_generator.py", "w", encoding="utf-8") as f:
        f.write('''
import datetime
import html
import json


def _js(value) -> str:
    """Serialize value as JSON that is safe inside an inline <script> block."""
    # "<" is emitted as a unicode escape so a value containing "</script>"
    # cannot terminate the script element. chr(92) is a backslash.
    return json.dumps(value).replace("<", chr(92) + "u003c")


def _cell(value) -> str:
    """Escape value for use as HTML text content."""
    return html.escape(str(value))


def generate_usage_html(usage_data: dict) -> str:
    model_labels = _js(list(usage_data['model_usage'].keys()))
    model_values = _js(list(usage_data['model_usage'].values()))
    daily_labels = _js(list(usage_data['daily_usage'].keys()))
    daily_values = _js([v['requests'] for v in usage_data['daily_usage'].values()])
    recent_requests_rows = "".join([
        f"""<tr>
<td>{datetime.datetime.fromisoformat(req['timestamp']).strftime('%Y-%m-%d %H:%M:%S')}</td>
<td>{_cell(req['model'])}</td>
<td>{_cell(req['endpoint'])}</td>
<td>{_cell(req['ip_address'])}</td>
</tr>""" for req in usage_data['recent_requests']
    ])
    return f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>LokiAI - Usage Statistics</title>
<script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700&display=swap" rel="stylesheet">
<style>
body {{ font-family: 'Inter', sans-serif; background-color: #0B0F19; color: #E0E0E0; margin: 0; padding: 20px; }}
.container {{ max-width: 1400px; margin: auto; }}
h1, h2 {{ color: #FFFFFF; }}
.header {{ text-align: center; margin-bottom: 40px; }}
.header h1 {{ font-size: 3em; font-weight: 700; }}
.stats-grid {{ display: grid; grid-template-columns: repeat(auto-fit, minmax(280px, 1fr)); gap: 20px; margin-bottom: 40px; }}
.chart-grid {{ display: grid; grid-template-columns: 1fr 1fr; gap: 20px; margin-bottom: 40px; }}
.stat-card, .chart-container, .table-container {{ background: #1A2035; padding: 25px; border-radius: 12px; border: 1px solid #2A3045; }}
.stat-card h3 {{ margin-top: 0; color: #8E95A9; font-size: 1em; font-weight: 600; text-transform: uppercase; }}
.stat-card .value {{ font-size: 2.5em; font-weight: 700; color: #FFFFFF; }}
table {{ width: 100%; border-collapse: collapse; }}
th, td {{ padding: 14px; text-align: left; border-bottom: 1px solid #2A3045; }}
th {{ background-color: #2A3045; font-weight: 600; }}
@media (max-width: 768px) {{ .chart-grid {{ grid-template-columns: 1fr; }} }}
</style>
</head>
<body>
<div class="container">
<div class="header"><h1>LokiAI Usage Dashboard</h1></div>
<div class="stats-grid">
<div class="stat-card"><h3>Total Requests</h3><p class="value">{_cell(usage_data['total_requests'])}</p></div>
<div class="stat-card"><h3>Unique IPs (All Time)</h3><p class="value">{_cell(usage_data['unique_ip_count'])}</p></div>
<div class="stat-card"><h3>Models Used (Last 7 Days)</h3><p class="value">{len(usage_data['model_usage'])}</p></div>
</div>
<div class="chart-grid">
<div class="chart-container"><canvas id="dailyUsageChart"></canvas></div>
<div class="chart-container"><canvas id="modelUsageChart"></canvas></div>
</div>
<div class="table-container">
<h2>Recent Requests</h2>
<table>
<thead><tr><th>Timestamp (UTC)</th><th>Model</th><th>Endpoint</th><th>IP Address</th></tr></thead>
<tbody>{recent_requests_rows}</tbody>
</table>
</div>
</div>
<script>
const chartOptions = (ticksColor, gridColor) => ({{
plugins: {{ legend: {{ labels: {{ color: ticksColor }} }} }},
scales: {{
y: {{ ticks: {{ color: ticksColor }}, grid: {{ color: gridColor }} }},
x: {{ ticks: {{ color: ticksColor }}, grid: {{ color: 'transparent' }} }}
}}
}});
new Chart(document.getElementById('dailyUsageChart'), {{
type: 'line',
data: {{ labels: {daily_labels}, datasets: [{{ label: 'Requests per Day', data: {daily_values}, borderColor: '#3a6ee0', tension: 0.1, backgroundColor: 'rgba(58, 110, 224, 0.2)', fill: true }}] }},
options: chartOptions('#E0E0E0', '#2A3045')
}});
new Chart(document.getElementById('modelUsageChart'), {{
type: 'doughnut',
data: {{ labels: {model_labels}, datasets: [{{ label: 'Model Usage', data: {model_values}, backgroundColor: ['#3A6EE0', '#E94F37', '#44AF69', '#F4D35E', '#A06CD5'] }}] }},
options: {{ plugins: {{ legend: {{ position: 'right', labels: {{ color: '#E0E0E0' }} }} }} }}
}});
</script>
</body>
</html>
"""
''')
# Script entry point: serve the app with uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )