# (Web-page residue from the hosting site: "Spaces: Running" status banner; not part of the program.)
import asyncio
import datetime
import json
import os
import re
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse

import httpx
import uvicorn
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Query, Request, Security
from fastapi.responses import (
    FileResponse,
    HTMLResponse,
    JSONResponse,
    PlainTextResponse,
    StreamingResponse,
)
from fastapi.security import APIKeyHeader
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.gzip import GZipMiddleware
from starlette.status import HTTP_403_FORBIDDEN, HTTP_503_SERVICE_UNAVAILABLE

# Use cloudscraper for specific endpoints that need it
try:
    import cloudscraper
except ImportError:
    cloudscraper = None

from usage_tracker import UsageTracker
# --- Initial Setup ---
# Load variables from a .env file (if present) into the process environment.
load_dotenv()
# Use uvloop for better performance if available
try:
    import uvloop
    asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
    # uvloop is optional; without it asyncio's default event loop policy stays in place.
    pass
# --- Configuration Management using Pydantic --- | |
class Settings(BaseSettings):
    """Manages all application settings and environment variables in one place.

    Every environment-backed field names its variable through
    ``validation_alias``. pydantic-settings (the pydantic v2 settings package
    imported by this file) does not honour the pydantic-v1 ``Field(env=...)``
    keyword, so a field whose variable name differs from its attribute name
    (``new_img_api`` is read from ``NEW_IMG``) was never populated with it.

    Fields declared with ``...`` are required: constructing ``Settings``
    raises a validation error when the variable is missing.
    """

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        # Keep Settings(api_keys=[...]) working by field name as well as by alias.
        "populate_by_name": True,
        # Keys in .env that belong to other components must not abort start-up.
        "extra": "ignore",
    }

    # Accepted client API keys (a list value, e.g. JSON in the environment).
    api_keys: List[str] = Field(..., validation_alias="API_KEYS")

    # Endpoints for various model providers
    secret_api_endpoint: str = Field(..., validation_alias="SECRET_API_ENDPOINT")
    secret_api_endpoint_2: str = Field(..., validation_alias="SECRET_API_ENDPOINT_2")
    secret_api_endpoint_3: str = Field(..., validation_alias="SECRET_API_ENDPOINT_3")
    secret_api_endpoint_4: str = "https://text.pollinations.ai/openai"
    secret_api_endpoint_5: str = Field(..., validation_alias="SECRET_API_ENDPOINT_5")
    secret_api_endpoint_6: str = Field(..., validation_alias="SECRET_API_ENDPOINT_6")

    # Specific provider keys and APIs
    mistral_api: str = "https://api.mistral.ai"
    mistral_key: str = Field(..., validation_alias="MISTRAL_KEY")
    gemini_key: str = Field(..., validation_alias="GEMINI_KEY")
    new_img_api: str = Field(..., validation_alias="NEW_IMG")
    endpoint_origin: Optional[str] = Field(None, validation_alias="ENDPOINT_ORIGIN")
    header_url: Optional[str] = Field(None, validation_alias="HEADER_URL")
@lru_cache(maxsize=1)
def get_settings() -> Settings:
    """Return the process-wide ``Settings`` instance.

    The result is cached so the environment and ``.env`` file are parsed and
    validated once, instead of on every request that calls this function.
    If construction raises, nothing is cached and the next call retries.
    """
    return Settings()
# --- Pydantic Models for Payloads --- | |
class ChatPayload(BaseModel):
    """Request body accepted by the chat completions handler."""
    # Model name; get_api_details() uses it to choose the upstream endpoint.
    model: str
    # Conversation messages, forwarded upstream as received (no per-message schema is enforced here).
    messages: List[Dict[str, Any]]
    # Streaming flag forwarded upstream; defaults to False.
    stream: bool = False
class ImageGenerationPayload(BaseModel):
    """Request body accepted by the image generation handler."""
    # Image model name; the handler rejects names outside MODEL_SETS["image"].
    model: str
    # Text prompt forwarded to the image service.
    prompt: str
    # NOTE(review): presumably the requested image size — unit/meaning is not shown in this file; confirm.
    size: int
    # NOTE(review): presumably how many images to generate — confirm against the image service.
    number: int
# --- Global Objects & State ---
# The FastAPI application; middleware and routers are attached to it further down.
app = FastAPI(
    title="LokiAI API",
    version="2.5.0",
    description="A robust and scalable API proxy for various AI models, now fully rewritten.",
)
# Records each proxied request; read by the usage dashboard and saved on shutdown.
usage_tracker = UsageTracker()
# Extracts the raw "Authorization" header. With auto_error=False a missing header
# reaches get_api_key as None instead of being rejected automatically.
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
# Maintenance switch: the chat and image handlers answer 503 while "online" is False.
server_status = {"online": True}
# --- Model & API Configuration ---
# Model names grouped by upstream provider. get_api_details() tests membership in
# these sets to pick an endpoint; a name found in none of them is sent to the
# default endpoint. The "image" set doubles as the allow-list for image generation.
MODEL_SETS = {
    "mistral": {"mistral-large-latest", "codestral-latest", "mistral-small-latest"},
    "pollinations": {"openai", "gemini", "phi", "llama"},
    "alternate": {"o1", "grok-3", "sonar-pro"},
    "claude": {"claude-3-7-sonnet", "claude 3.5 sonnet", "o3-mini-medium"},
    "gemini": {"gemini-1.5-pro", "gemini-1.5-flash", "gemini-2.0-flash"},
    "image": {"Flux Pro Ultra", "dall-e-3", "stable-diffusion-3-large-turbo"},
}
def get_api_details(model_name: str, settings: Settings) -> Tuple[str, Dict, str]:
    """Resolve where a request for ``model_name`` has to be sent.

    Returns:
        A ``(base_url, extra_headers, path)`` triple. Models listed in
        ``MODEL_SETS`` go to their provider's endpoint; any other name goes to
        ``settings.secret_api_endpoint``, with ``Origin``/``Referer`` headers
        when ``settings.header_url`` is set.
    """
    chat_path = "/v1/chat/completions"

    def bearer(token: str) -> Dict[str, str]:
        return {"Authorization": f"Bearer {token}"}

    # Checked top to bottom. Each entry builds its result lazily, so only the
    # matched provider's settings are read.
    routing = (
        ("mistral", lambda: (settings.mistral_api, bearer(settings.mistral_key), chat_path)),
        ("gemini", lambda: (settings.secret_api_endpoint_6, bearer(settings.gemini_key), "/chat/completions")),
        ("pollinations", lambda: (settings.secret_api_endpoint_4, {}, chat_path)),
        ("claude", lambda: (settings.secret_api_endpoint_5, {}, chat_path)),
        ("alternate", lambda: (settings.secret_api_endpoint_2, {}, chat_path)),
        ("image", lambda: (settings.new_img_api, {}, "")),
    )
    for group, build in routing:
        if model_name in MODEL_SETS[group]:
            return build()

    # Default case
    fallback_headers = {}
    if settings.header_url:
        fallback_headers = {"Origin": settings.header_url, "Referer": settings.header_url}
    return settings.secret_api_endpoint, fallback_headers, chat_path
# --- Dependencies & Security --- | |
async def get_api_key(request: Request, api_key: str = Security(api_key_header)):
    """Validates the API key, allowing specific referers to bypass.

    Args:
        request: Incoming request; only its ``Referer`` header is inspected.
        api_key: Raw ``Authorization`` header value, expected as ``Bearer <key>``.

    Returns:
        The validated key, or ``"hf_space_bypass"`` for requests whose Referer
        host is the hosted front end.

    Raises:
        HTTPException: 403 when the header is missing or malformed, or the key
            is not one of ``settings.api_keys``.
    """
    # SECURITY NOTE(review): Referer is chosen by the client, so this bypass is
    # a convenience for the hosted front end, not real authentication. The host
    # is compared exactly (not as a substring) so a URL that merely contains
    # the name somewhere in its path or query no longer qualifies.
    referer = request.headers.get("referer", "")
    if referer:
        try:
            referer_host = urlparse(referer).hostname
        except ValueError:
            referer_host = None
        if referer_host == "parthsadaria-lokiai.hf.space":
            return "hf_space_bypass"

    scheme_prefix = "Bearer "
    if not api_key or not api_key.startswith(scheme_prefix):
        raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Invalid authorization format.")
    # Take everything after the scheme instead of split(" ")[1], which yielded
    # "" for a doubled space and silently dropped anything after a second space.
    key = api_key[len(scheme_prefix):].strip()
    if not key or key not in get_settings().api_keys:
        raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Invalid API key.")
    return key
@lru_cache(maxsize=1)
def get_http_client() -> httpx.AsyncClient:
    """Return the shared ``httpx.AsyncClient`` used for all upstream calls.

    The client is created on first use and then reused. Without the cache every
    request built a fresh client (and connection pool) that nothing closed, and
    the shutdown hook closed a brand-new client instead of the one in use.
    """
    return httpx.AsyncClient(timeout=60.0, limits=httpx.Limits(max_connections=200))
# --- API Routers ---
# One router per functional area. The chat and image routers are mounted on the
# app twice further down (under /api/v1 and at the root); the others once.
chat_router = APIRouter(tags=["AI Models"])
image_router = APIRouter(tags=["AI Models"])
usage_router = APIRouter(tags=["Server Administration"])
utility_router = APIRouter(tags=["Utilities & Pages"])
# --- Chat Completions Router --- | |
@chat_router.post("/chat/completions")
async def chat_completions(
    payload: ChatPayload,
    request: Request,
    api_key: str = Depends(get_api_key),
    client: httpx.AsyncClient = Depends(get_http_client),
):
    """Proxy a chat completion request to the upstream chosen by ``payload.model``.

    The upstream body is always relayed as a byte stream with media type
    ``text/event-stream``, whatever ``payload.stream`` says. Because the
    response has already started when an upstream failure is detected, failures
    are reported as a JSON error object inside the stream rather than through
    the HTTP status code.

    Raises:
        HTTPException: 503 while the server is flagged offline.
    """
    # Registered here because nothing else attaches this handler to chat_router;
    # the router is mounted both at the root and under /api/v1.
    if not server_status["online"]:
        raise HTTPException(status_code=HTTP_503_SERVICE_UNAVAILABLE, detail="Server under maintenance.")
    settings = get_settings()
    usage_tracker.record_request(request, payload.model, "/chat/completions")
    endpoint, headers, path = get_api_details(payload.model, settings)

    async def stream_generator():
        try:
            async with client.stream("POST", f"{endpoint}{path}", json=payload.dict(), headers=headers) as response:
                if not 200 <= response.status_code < 300:
                    # A streamed body has to be read while the stream is still
                    # open. Otherwise `e.response.text` below raises
                    # httpx.ResponseNotRead from inside the except handler and
                    # the client never receives the 502 payload.
                    await response.aread()
                response.raise_for_status()
                async for chunk in response.aiter_bytes():
                    yield chunk
        except httpx.HTTPStatusError as e:
            print(f"Upstream error: {e.response.status_code} - {e.response.text}")
            yield json.dumps({"error": {"code": 502, "message": "Bad Gateway: Upstream service error."}}).encode()
        except Exception as e:
            print(f"Streaming error: {e}")
            yield json.dumps({"error": {"code": 500, "message": "An internal error occurred."}}).encode()

    return StreamingResponse(stream_generator(), media_type="text/event-stream")
# --- Image Generation Router --- | |
@image_router.post("/images/generations")
async def images_generations(
    payload: ImageGenerationPayload,
    request: Request,
    api_key: str = Depends(get_api_key),
    client: httpx.AsyncClient = Depends(get_http_client),
):
    """Forward an image generation request and return the upstream JSON.

    Raises:
        HTTPException: 503 while the server is flagged offline; 400 for a model
            outside ``MODEL_SETS["image"]``; the upstream status code when the
            image service answers with an error; 502 when the service cannot be
            reached or answers with a body that is not JSON.
    """
    # Registered here because nothing else attaches this handler to image_router;
    # the router is mounted both at the root and under /api/v1.
    if not server_status["online"]:
        raise HTTPException(status_code=HTTP_503_SERVICE_UNAVAILABLE, detail="Server under maintenance.")
    if payload.model not in MODEL_SETS["image"]:
        raise HTTPException(status_code=400, detail=f"Image model '{payload.model}' not supported.")
    settings = get_settings()
    usage_tracker.record_request(request, payload.model, "/images/generations")
    endpoint, headers, _ = get_api_details(payload.model, settings)
    try:
        response = await client.post(endpoint, json=payload.dict(), headers=headers)
        response.raise_for_status()
        return JSONResponse(content=response.json())
    except httpx.HTTPStatusError as e:
        # The error body is not guaranteed to be JSON, nor a JSON object;
        # calling .json().get(...) unguarded turned such replies into a crash.
        detail = "Upstream error"
        try:
            error_body = e.response.json()
        except ValueError:
            error_body = None
        if isinstance(error_body, dict):
            detail = error_body.get("detail", detail)
        raise HTTPException(status_code=e.response.status_code, detail=detail) from e
    except httpx.RequestError as e:
        raise HTTPException(status_code=502, detail=f"Failed to connect to image service: {e}") from e
    except ValueError as e:
        # A 2xx reply whose body could not be decoded as JSON.
        raise HTTPException(status_code=502, detail="Image service returned a non-JSON response.") from e
# --- Usage & Health Router --- | |
async def get_usage_dashboard(days: int = Query(7, ge=1, le=30)):
    """Render the usage tracker's summary as an HTML dashboard.

    Args:
        days: Passed to ``usage_tracker.get_usage_summary``; accepted range is
            1 to 30, default 7.
    """
    # NOTE(review): no route decorator is visible for this handler in this
    # file view; confirm it is attached to usage_router.
    usage_summary = usage_tracker.get_usage_summary(days=days)
    # Imported at call time: usage_dashboard_generator.py may only come into
    # existence when the bootstrap code near the end of this module writes it.
    from usage_dashboard_generator import generate_usage_html

    page = generate_usage_html(usage_summary)
    return HTMLResponse(content=page)
async def health_check():
    """Report whether the server is flagged online, plus the app version."""
    # NOTE(review): no route decorator is visible for this handler in this
    # file view; confirm it is attached to a router.
    status_label = "unhealthy"
    if server_status["online"]:
        status_label = "healthy"
    return {"status": status_label, "version": app.version}
async def get_models():
    """Return the parsed contents of ``models.json`` stored next to this module.

    Raises:
        HTTPException: 500 when the file is missing, unreadable, or not valid
            UTF-8 JSON.
    """
    # NOTE(review): no route decorator is visible for this handler in this
    # file view; confirm it is attached to a router.
    models_path = Path(__file__).parent / "models.json"
    try:
        # Explicit UTF-8: the platform default encoding is locale dependent.
        with open(models_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except (OSError, ValueError) as err:
        # OSError covers a missing/unreadable file; ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError. Other exceptions are
        # programming errors and are left to surface unchanged.
        raise HTTPException(status_code=500, detail="models.json not found or is invalid.") from err
# --- Utility & Pages Router --- | |
def read_static_file(file_path):
    """Return the UTF-8 text of ``file_path``, or ``None`` if it does not exist.

    ``file_path`` is resolved against the directory that contains this module.
    """
    target = Path(__file__).parent / file_path
    try:
        return target.read_text(encoding="utf-8")
    except FileNotFoundError:
        return None
async def root_page():
    """Serve ``index.html``; answer 404 with a placeholder when it is unavailable."""
    # NOTE(review): no route decorator is visible for this handler in this
    # file view; confirm it is attached to utility_router.
    page = read_static_file("index.html")
    if not page:
        # The placeholder used to go out with 200 OK, so clients and monitors
        # could not tell a missing page from a served one.
        return HTMLResponse(content="<h1>Not Found</h1>", status_code=404)
    return HTMLResponse(content=page)
async def playground_page():
    """Serve ``playground.html``; answer 404 with a placeholder when it is unavailable."""
    # NOTE(review): no route decorator is visible for this handler in this
    # file view; confirm it is attached to utility_router.
    page = read_static_file("playground.html")
    if not page:
        # The placeholder used to go out with 200 OK, so clients and monitors
        # could not tell a missing page from a served one.
        return HTMLResponse(content="<h1>Not Found</h1>", status_code=404)
    return HTMLResponse(content=page)
async def image_playground_page():
    """Serve ``image-playground.html``; answer 404 with a placeholder when it is unavailable."""
    # NOTE(review): no route decorator is visible for this handler in this
    # file view; confirm it is attached to utility_router.
    page = read_static_file("image-playground.html")
    if not page:
        # The placeholder used to go out with 200 OK, so clients and monitors
        # could not tell a missing page from a served one.
        return HTMLResponse(content="<h1>Not Found</h1>", status_code=404)
    return HTMLResponse(content=page)
async def scrape_url(url: str = Query(..., description="URL to scrape")):
    """Fetch ``url`` with cloudscraper and return the response body as plain text.

    Raises:
        HTTPException: 501 when cloudscraper is not installed; 400 when ``url``
            is not an absolute http(s) URL; 500 when the fetch fails.
    """
    # NOTE(review): no route decorator is visible for this handler in this
    # file view; confirm it is attached to utility_router.
    # SECURITY NOTE(review): this fetches a caller-supplied URL from the server
    # (SSRF surface) and takes no API key. Only the scheme is validated here;
    # consider requiring authentication and blocking private address ranges.
    if not cloudscraper:
        raise HTTPException(status_code=501, detail="Scraper library not installed.")

    try:
        target = urlparse(url)
    except ValueError:
        target = None
    if target is None or target.scheme not in ("http", "https") or not target.netloc:
        raise HTTPException(status_code=400, detail="Only absolute http(s) URLs can be scraped.")

    def _fetch() -> str:
        scraper = cloudscraper.create_scraper()
        # Bounded wait: without a timeout a stalled site would hold the worker forever.
        response = scraper.get(url, timeout=30)
        response.raise_for_status()
        return response.text

    try:
        # The scraper call is blocking; run it in the default thread pool so
        # the event loop keeps serving other requests meanwhile.
        text = await asyncio.get_running_loop().run_in_executor(None, _fetch)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to scrape URL: {e}") from e
    return PlainTextResponse(content=text)
# --- Main Application Setup ---
# Gzip-compress responses of at least 1000 bytes.
app.add_middleware(GZipMiddleware, minimum_size=1000)
# NOTE(review): every origin, method and header is allowed together with
# credentials; confirm this level of openness is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Include all the organized routers
app.include_router(chat_router, prefix="/api/v1")
app.include_router(chat_router)  # For legacy /chat/completions
app.include_router(image_router, prefix="/api/v1")
app.include_router(image_router)  # For legacy /images/generations
app.include_router(usage_router)
app.include_router(utility_router)
@app.on_event("startup")
async def startup_event():
    """Validate configuration and print a start-up banner.

    Registered as a startup hook; as a bare coroutine function it was never
    invoked. A settings failure is reported but deliberately does not stop the
    server (see the inline comment below).
    """
    # Pre-load settings and client to catch config errors early
    try:
        get_settings()
    except Exception as e:
        print(f"FATAL: Could not load settings from environment variables. Error: {e}")
        # In a real app, you might want to exit here
    get_http_client()
    print("--- LokiAI Server Started ---")
    print(f"Version: {app.version}")
    print("Usage tracking is active and will save data periodically.")
@app.on_event("shutdown")
async def shutdown_event():
    """Close the HTTP client and persist usage data when the server stops.

    Registered as a shutdown hook; as a bare coroutine function it was never
    invoked, so usage data was not saved on exit.
    """
    try:
        client = get_http_client()
        await client.aclose()
    finally:
        # Persist usage data even if closing the client raised.
        usage_tracker.save_data()
    print("--- LokiAI Server Shutdown Complete ---")
# Helper for usage dashboard - in a real project, this would be in its own file.
# It is written out here, once, so the example stays self-contained: the module
# is only created when usage_dashboard_generator.py does not exist yet, and an
# existing file is never overwritten.
#
# The generated module escapes every request-derived value. The model name and
# client IP are supplied by callers, so interpolating them raw into the table
# or the inline <script> allowed stored XSS on the dashboard.
_DASHBOARD_MODULE_PATH = Path(__file__).parent / "usage_dashboard_generator.py"
if not _DASHBOARD_MODULE_PATH.exists():
    # Raw string: the generated source contains a backslash escape that must
    # reach the file unchanged.
    with open(_DASHBOARD_MODULE_PATH, "w", encoding="utf-8") as f:
        f.write(r'''
import datetime
import html
import json


def _script_json(value) -> str:
    """Serialize value as JSON that is safe to embed inside an inline <script>."""
    # Escaping "<" keeps a value such as "</script>" from ending the script block.
    return json.dumps(value).replace("<", "\\u003c")


def _cell(value) -> str:
    """Return value as HTML-escaped text for a table cell."""
    return html.escape(str(value))


def generate_usage_html(usage_data: dict) -> str:
    """Render the usage summary dictionary as a standalone HTML dashboard page."""
    model_labels = _script_json(list(usage_data['model_usage'].keys()))
    model_values = _script_json(list(usage_data['model_usage'].values()))
    daily_labels = _script_json(list(usage_data['daily_usage'].keys()))
    daily_values = _script_json([v['requests'] for v in usage_data['daily_usage'].values()])
    recent_requests_rows = "".join([
        f"""<tr>
            <td>{datetime.datetime.fromisoformat(req['timestamp']).strftime('%Y-%m-%d %H:%M:%S')}</td>
            <td>{_cell(req['model'])}</td>
            <td>{_cell(req['endpoint'])}</td>
            <td>{_cell(req['ip_address'])}</td>
        </tr>""" for req in usage_data['recent_requests']
    ])
    return f"""
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>LokiAI - Usage Statistics</title>
    <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
    <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700&display=swap" rel="stylesheet">
    <style>
        body {{ font-family: 'Inter', sans-serif; background-color: #0B0F19; color: #E0E0E0; margin: 0; padding: 20px; }}
        .container {{ max-width: 1400px; margin: auto; }}
        h1, h2 {{ color: #FFFFFF; }}
        .header {{ text-align: center; margin-bottom: 40px; }}
        .header h1 {{ font-size: 3em; font-weight: 700; }}
        .stats-grid {{ display: grid; grid-template-columns: repeat(auto-fit, minmax(280px, 1fr)); gap: 20px; margin-bottom: 40px; }}
        .chart-grid {{ display: grid; grid-template-columns: 1fr 1fr; gap: 20px; margin-bottom: 40px; }}
        .stat-card, .chart-container, .table-container {{ background: #1A2035; padding: 25px; border-radius: 12px; border: 1px solid #2A3045; }}
        .stat-card h3 {{ margin-top: 0; color: #8E95A9; font-size: 1em; font-weight: 600; text-transform: uppercase; }}
        .stat-card .value {{ font-size: 2.5em; font-weight: 700; color: #FFFFFF; }}
        table {{ width: 100%; border-collapse: collapse; }}
        th, td {{ padding: 14px; text-align: left; border-bottom: 1px solid #2A3045; }}
        th {{ background-color: #2A3045; font-weight: 600; }}
        @media (max-width: 768px) {{ .chart-grid {{ grid-template-columns: 1fr; }} }}
    </style>
</head>
<body>
    <div class="container">
        <div class="header"><h1>LokiAI Usage Dashboard</h1></div>
        <div class="stats-grid">
            <div class="stat-card"><h3>Total Requests</h3><p class="value">{usage_data['total_requests']}</p></div>
            <div class="stat-card"><h3>Unique IPs (All Time)</h3><p class="value">{usage_data['unique_ip_count']}</p></div>
            <div class="stat-card"><h3>Models Used (Last 7 Days)</h3><p class="value">{len(usage_data['model_usage'])}</p></div>
        </div>
        <div class="chart-grid">
            <div class="chart-container"><canvas id="dailyUsageChart"></canvas></div>
            <div class="chart-container"><canvas id="modelUsageChart"></canvas></div>
        </div>
        <div class="table-container">
            <h2>Recent Requests</h2>
            <table>
                <thead><tr><th>Timestamp (UTC)</th><th>Model</th><th>Endpoint</th><th>IP Address</th></tr></thead>
                <tbody>{recent_requests_rows}</tbody>
            </table>
        </div>
    </div>
    <script>
        const chartOptions = (ticksColor, gridColor) => ({{
            plugins: {{ legend: {{ labels: {{ color: ticksColor }} }} }},
            scales: {{
                y: {{ ticks: {{ color: ticksColor }}, grid: {{ color: gridColor }} }},
                x: {{ ticks: {{ color: ticksColor }}, grid: {{ color: 'transparent' }} }}
            }}
        }});
        new Chart(document.getElementById('dailyUsageChart'), {{
            type: 'line',
            data: {{ labels: {daily_labels}, datasets: [{{ label: 'Requests per Day', data: {daily_values}, borderColor: '#3a6ee0', tension: 0.1, backgroundColor: 'rgba(58, 110, 224, 0.2)', fill: true }}] }},
            options: chartOptions('#E0E0E0', '#2A3045')
        }});
        new Chart(document.getElementById('modelUsageChart'), {{
            type: 'doughnut',
            data: {{ labels: {model_labels}, datasets: [{{ label: 'Model Usage', data: {model_values}, backgroundColor: ['#3A6EE0', '#E94F37', '#44AF69', '#F4D35E', '#A06CD5'] }}] }},
            options: {{ plugins: {{ legend: {{ position: 'right', labels: {{ color: '#E0E0E0' }} }} }} }}
        }});
    </script>
</body>
</html>
"""
''')
# Script entry point: serve the app with uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)