# lokiai / main.py
# (File-viewer residue from the copy this was taken from: "ParthSadaria's picture",
#  commit 925b0de "Update main.py" (verified), raw / history / blame links, 59.9 kB.)
import os
import re
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request, Depends, Security, Query
from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse, PlainTextResponse
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
import httpx
from functools import lru_cache
from pathlib import Path
import json
import datetime
import time
import threading
from typing import Optional, Dict, List, Any, Generator
import asyncio
from starlette.status import HTTP_403_FORBIDDEN
import cloudscraper
from concurrent.futures import ThreadPoolExecutor
import uvloop
from fastapi.middleware.gzip import GZipMiddleware
from starlette.middleware.cors import CORSMiddleware
import contextlib
import requests
# Install uvloop's event loop implementation for asyncio loops created after this point.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
# Shared thread pool for blocking work (used by the /scraper endpoint's cloudscraper calls).
executor = ThreadPoolExecutor(max_workers=16)
# Copy values from a local .env file (if present) into os.environ before any getenv() lookups.
load_dotenv()
# Reads the raw "Authorization" header. auto_error=False makes a missing header arrive as
# None, so verify_api_key() can produce its own 403 messages.
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
from usage_tracker import UsageTracker
# Process-wide usage recorder shared by the endpoints below.
usage_tracker = UsageTracker()
app = FastAPI()
# Gzip responses of at least 1000 bytes for clients that accept gzip.
app.add_middleware(GZipMiddleware, minimum_size=1000)
# NOTE(review): allow_origins=["*"] together with allow_credentials=True lets any site make
# credentialed cross-origin requests (Starlette echoes the caller's Origin in that case) —
# confirm this is intended.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@lru_cache(maxsize=1)
def get_env_vars():
    """
    Read the server configuration from environment variables once and cache it.

    The values come from ``os.environ`` (``load_dotenv()`` is expected to have
    populated it from ``.env`` beforehand). Because of the ``lru_cache``, later
    environment changes are not seen until ``get_env_vars.cache_clear()``.

    Returns:
        dict: configuration values. ``api_keys`` is a list of the non-empty,
        whitespace-stripped entries of the comma-separated ``API_KEYS`` variable
        (an empty list when it is unset). Every other entry is a string, or
        ``None`` when the variable is unset and has no default.
    """
    raw_api_keys = os.getenv('API_KEYS', '')
    return {
        # Strip spaces around commas ("k1, k2") and drop empty entries produced by an
        # unset variable or a trailing comma, so '' can never act as a valid key.
        'api_keys': [key.strip() for key in raw_api_keys.split(',') if key.strip()],
        'secret_api_endpoint': os.getenv('SECRET_API_ENDPOINT'),
        'secret_api_endpoint_2': os.getenv('SECRET_API_ENDPOINT_2'),
        'secret_api_endpoint_3': os.getenv('SECRET_API_ENDPOINT_3'),
        'secret_api_endpoint_4': os.getenv('SECRET_API_ENDPOINT_4', "https://text.pollinations.ai/openai"),
        'secret_api_endpoint_5': os.getenv('SECRET_API_ENDPOINT_5'),
        'secret_api_endpoint_6': os.getenv('SECRET_API_ENDPOINT_6'),  # Gemini endpoint
        'mistral_api': os.getenv('MISTRAL_API', "https://api.mistral.ai"),
        'mistral_key': os.getenv('MISTRAL_KEY'),
        'gemini_key': os.getenv('GEMINI_KEY'),  # Gemini API key
        'endpoint_origin': os.getenv('ENDPOINT_ORIGIN'),
        'new_img': os.getenv('NEW_IMG'),  # image generation API URL
    }
# Model-name sets used to route chat requests to an upstream endpoint.
# get_completion() checks them in this order: mistral, pollinations, alternate,
# claude_3, gemini; anything else goes to the default endpoint. A name listed in
# more than one set is therefore routed by the first set that contains it.
mistral_models = {
    "mistral-large-latest", "pixtral-large-latest", "mistral-moderation-latest",
    "ministral-3b-latest", "ministral-8b-latest", "open-mistral-nemo",
    "mistral-small-latest", "mistral-saba-latest", "codestral-latest",
}
pollinations_models = {
    "openai", "openai-large", "openai-fast", "openai-xlarge", "openai-reasoning",
    "qwen-coder", "llama", "mistral", "searchgpt", "deepseek", "claude-hybridspace",
    "deepseek-r1", "deepseek-reasoner", "llamalight", "gemini", "gemini-thinking",
    "hormoz", "phi", "phi-mini", "openai-audio", "llama-scaleway",
}
# NOTE: "grok-3" is also listed in claude_3_models; since this set is checked first,
# "grok-3" requests are routed to the alternate endpoint.
alternate_models = {
    "o1", "llama-4-scout", "o4-mini", "sonar", "sonar-pro", "sonar-reasoning",
    "sonar-reasoning-pro", "grok-3", "grok-3-fast", "r1-1776", "o3",
}
# "claude 3.5 haiku" was previously listed twice; the duplicate literal is removed
# (set contents are unchanged).
claude_3_models = {
    "claude-3-7-sonnet", "claude-3-7-sonnet-thinking",
    "claude 3.5 haiku", "claude 3.5 sonnet",
    "o3-mini-medium", "o3-mini-high",
    "grok-3", "grok-3-thinking", "grok 2",
}
gemini_models = {
    "gemini-1.5-pro", "gemini-1.5-flash", "gemini-2.0-flash-lite-preview",
    "gemini-2.0-flash", "gemini-2.0-flash-thinking",  # aka Reasoning
    "gemini-2.0-flash-preview-image-generation", "gemini-2.5-flash",
    "gemini-2.5-pro-exp", "gemini-exp-1206",
}
# Models accepted by the /images/generations endpoint.
supported_image_models = {
    "Flux Pro Ultra", "grok-2-aurora", "Flux Pro", "Flux Pro Ultra Raw",
    "Flux Dev", "Flux Schnell", "stable-diffusion-3-large-turbo",
    "Flux Realism", "stable-diffusion-ultra", "dall-e-3", "sdxl-lightning-4step",
}
class Payload(BaseModel):
    """Request body accepted by the chat-completion proxy endpoints.

    Attributes:
        model: Name of the model used to choose the upstream endpoint.
        messages: Chat history, forwarded upstream unchanged.
        stream: Whether the caller wants a server-sent-event stream; False by default.

    NOTE(review): only these three fields are declared. With pydantic's default
    config, other request fields (e.g. temperature, max_tokens) are likely
    dropped before proxying — confirm whether that is intended.
    """

    model: str
    messages: list
    stream: bool = False
class ImageGenerationPayload(BaseModel):
    """Request body for the image-generation proxy endpoint.

    Attributes:
        model: Image model name (create_image checks it against supported_image_models).
        prompt: Text description of the desired image.
        size: Requested dimensions, "1024x1024" unless given; not validated here.
        number: How many images to request; forwarded upstream as ``n``.
    """

    model: str
    prompt: str
    size: str = "1024x1024"  # assumed to be accepted by every supported model
    number: int = 1
# Maintenance switch: the chat and image endpoints answer 503 while this is False.
server_status = True # Global flag for server maintenance status
# Model IDs the server accepts. While empty, get_completion skips its availability check.
# NOTE(review): nothing in this part of the file fills this list — confirm where it is populated.
available_model_ids: List[str] = [] # List of all available model IDs
@lru_cache(maxsize=1)
def get_async_client():
    """
    Return the process-wide shared ``httpx.AsyncClient`` (created on first call).

    The client uses a 60-second timeout, and its pool allows up to 200
    connections, 50 of which may be kept alive.
    """
    pool_limits = httpx.Limits(max_connections=200, max_keepalive_connections=50)
    return httpx.AsyncClient(limits=pool_limits, timeout=60.0)
# Pool of reusable cloudscraper sessions, filled on first use by get_scraper().
scraper_pool = []
MAX_SCRAPERS = 20
# get_scraper() runs on executor threads, so pool creation and the rotation
# cursor are guarded by this lock.
_scraper_lock = threading.Lock()
_scraper_cursor = 0


def get_scraper():
    """
    Return a cloudscraper session from the shared pool, rotating through it round-robin.

    On first use the pool is filled with ``MAX_SCRAPERS`` sessions. The new
    sessions are built into a local list and published in one step while the
    lock is held, so no caller ever indexes a half-filled pool. (Previously,
    concurrent executor threads could see a partly filled pool and hit an
    IndexError.)
    """
    global _scraper_cursor
    with _scraper_lock:
        if not scraper_pool:
            fresh_sessions = [cloudscraper.create_scraper() for _ in range(MAX_SCRAPERS)]
            scraper_pool.extend(fresh_sessions)
        # Index by the real pool size rather than the constant.
        _scraper_cursor = (_scraper_cursor + 1) % len(scraper_pool)
        return scraper_pool[_scraper_cursor]
async def verify_api_key(
    request: Request,
    api_key: str = Security(api_key_header)
) -> bool:
    """
    FastAPI dependency that authorizes a request.

    Requests whose Referer starts with one of the hosted playground pages are
    let through without a key. Otherwise the Authorization header must carry
    one of the keys configured in API_KEYS, either raw or as "Bearer <key>".

    Returns:
        True when the request is authorized.

    Raises:
        HTTPException(403): no (or an empty) key was supplied, no keys are
        configured on the server, or the key does not match.
    """
    import hmac  # local import: constant-time comparison of secrets

    referer = request.headers.get("referer", "")
    # SECURITY NOTE(review): the Referer header is fully client-controlled, so any
    # client can skip authentication by sending a matching Referer. Kept for
    # playground compatibility; consider a session or short-lived token instead.
    if referer.startswith(("https://parthsadaria-lokiai.hf.space/playground",
                           "https://parthsadaria-lokiai.hf.space/image-playground")):
        return True
    if not api_key:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="No API key provided"
        )
    if api_key.startswith('Bearer '):
        api_key = api_key[7:]
    api_key = api_key.strip()
    # "Bearer " with nothing after it must not match an empty configured entry.
    if not api_key:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="No API key provided"
        )
    # Normalise the configured list here too, so stray spaces or empty entries
    # (e.g. from a trailing comma in API_KEYS) never count as valid keys.
    valid_api_keys = [
        key.strip() for key in get_env_vars().get('api_keys', []) if key and key.strip()
    ]
    if not valid_api_keys:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="API keys not configured on server"
        )
    presented = api_key.encode("utf-8")
    # compare_digest avoids leaking key contents through comparison timing.
    if not any(hmac.compare_digest(presented, key.encode("utf-8")) for key in valid_api_keys):
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="Invalid API key"
        )
    return True
@lru_cache(maxsize=1)
def load_models_data():
    """
    Load the model catalogue from ``models.json`` next to this module and cache it.

    Returns:
        The parsed JSON content, or ``[]`` if the file cannot be read or parsed
        (the problem is printed). Because of the cache, a failed load keeps
        returning ``[]`` until ``load_models_data.cache_clear()`` is called.
    """
    file_path = Path(__file__).parent / 'models.json'
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except (OSError, ValueError) as e:
        # OSError covers a missing/unreadable file. ValueError covers
        # json.JSONDecodeError and UnicodeDecodeError (non-UTF-8 content).
        print(f"Error loading models.json: {str(e)}")
        return []
@app.get("/api/v1/models")
@app.get("/models")
async def get_models():
    """List the available models from models.json; 500 if the catalogue is empty or unreadable."""
    catalogue = load_models_data()
    if catalogue:
        return catalogue
    raise HTTPException(status_code=500, detail="Error loading available models")
# Strong references to in-flight search fetch tasks. The event loop keeps only weak
# references to tasks, so an unreferenced task may be garbage-collected before it
# finishes. Each task removes itself from this set when done.
_search_tasks = set()


async def generate_search_async(query: str, systemprompt: Optional[str] = None, stream: bool = True):
    """
    Start a background request to the search model and return a queue of its output.

    The request goes to SECRET_API_ENDPOINT_3 with a system prompt (default
    "Be Helpful and Friendly") followed by *query* as the user message. Each
    content chunk of the upstream server-sent-event stream is re-emitted.

    Args:
        query: The user's question.
        systemprompt: Optional system prompt overriding the default.
        stream: Accepted for API compatibility but unused; output is always
            delivered incrementally through the queue.

    Returns:
        asyncio.Queue: receives, in order, zero or more
        ``{"data": <SSE line>, "text": <content chunk>}`` items, at most one
        ``{"error": <message>}`` item, and finally ``None`` exactly once.
    """
    queue = asyncio.Queue()

    async def _fetch_search_data():
        """Fetch the upstream stream and push converted chunks onto ``queue``."""
        try:
            headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
            system_message = systemprompt or "Be Helpful and Friendly"
            prompt = [
                {"content": system_message, "role": "system"},
                {"role": "user", "content": query},
            ]
            payload = {
                "is_vscode_extension": True,
                "message_history": prompt,
                "requested_model": "searchgpt",
                "user_input": query,
            }
            secret_api_endpoint_3 = get_env_vars()['secret_api_endpoint_3']
            if not secret_api_endpoint_3:
                await queue.put({"error": "Search API endpoint not configured"})
                return
            async with httpx.AsyncClient(timeout=30.0) as client:
                async with client.stream("POST", secret_api_endpoint_3, json=payload, headers=headers) as response:
                    if response.status_code != 200:
                        # A streamed body must be read before .text is available
                        # (.text is a property, not a coroutine).
                        await response.aread()
                        await queue.put({"error": f"Search API returned status code {response.status_code}: {response.text}"})
                        return
                    async for line in response.aiter_lines():
                        if not line.startswith("data: "):
                            continue
                        data_str = line[6:].strip()
                        if data_str == "[DONE]":
                            # End-of-stream sentinel; it is not JSON.
                            break
                        try:
                            json_data = json.loads(data_str)
                        except json.JSONDecodeError:
                            print(f"Warning: Could not decode JSON from search API stream: {line}")
                            await queue.put({"error": f"Invalid JSON from search API: {line}"})
                            return
                        # Tolerate a missing/empty "choices" list or a null delta.
                        choices = json_data.get("choices") or [{}]
                        content = (choices[0].get("delta") or {}).get("content") or ""
                        # Whitespace-only chunks (e.g. "\n\n") are part of the text; keep them.
                        if not content:
                            continue
                        cleaned_response = {
                            "created": json_data.get("created"),
                            "id": json_data.get("id"),
                            "model": "searchgpt",
                            "object": "chat.completion",
                            "choices": [
                                {
                                    "message": {
                                        "content": content
                                    }
                                }
                            ]
                        }
                        await queue.put({"data": f"data: {json.dumps(cleaned_response)}\n\n", "text": content})
        except Exception as e:
            print(f"Error in _fetch_search_data: {e}")
            await queue.put({"error": str(e)})
        finally:
            # Always terminate the stream so consumers never wait forever.
            await queue.put(None)

    task = asyncio.create_task(_fetch_search_data())
    _search_tasks.add(task)
    task.add_done_callback(_search_tasks.discard)
    return queue
@lru_cache(maxsize=10)
def read_html_file(file_path):
    """
    Read a text asset (HTML/JS/CSS) as UTF-8 and cache the result.

    A relative *file_path* is tried against the current working directory
    first, then against this module's directory. This matches how the
    favicon/banner routes locate their files when the server is started from
    elsewhere.

    Returns:
        The file contents, or ``None`` if the file is not found in either
        location. The result is cached per path, including ``None``.
    """
    requested = Path(file_path)
    candidates = [requested]
    if not requested.is_absolute():
        candidates.append(Path(__file__).parent / requested)
    for candidate in candidates:
        try:
            with open(candidate, "r", encoding="utf-8") as file:
                return file.read()
        except FileNotFoundError:
            continue
    return None
# Static file routes for basic web assets


def _static_file_response(filename: str, media_type: str) -> FileResponse:
    """Return a FileResponse for *filename* in this module's directory; 404 if it is missing."""
    path = Path(__file__).parent / filename
    # FileResponse only discovers a missing file while sending (surfacing as a 500),
    # so check up front and report a proper 404.
    if not path.is_file():
        raise HTTPException(status_code=404, detail=f"{filename} not found")
    return FileResponse(path, media_type=media_type)


@app.get("/favicon.ico")
async def favicon():
    """Serve favicon.ico from the application directory."""
    return _static_file_response("favicon.ico", "image/x-icon")


@app.get("/banner.jpg")
async def banner():
    """Serve banner.jpg from the application directory."""
    return _static_file_response("banner.jpg", "image/jpeg")
@app.get("/ping")
async def ping():
    """Health check that answers immediately with a fixed "pong" payload.

    ``response_time`` is a constant string, not a measurement.
    """
    return dict(message="pong", response_time="0.000000 seconds")
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the landing page (index.html); 404 when the file is missing."""
    page = read_html_file("index.html")
    if page is not None:
        return HTMLResponse(content=page)
    raise HTTPException(status_code=404, detail="index.html not found")
@app.get("/script.js", response_class=HTMLResponse)
async def script():
    """
    Serve script.js with a JavaScript content type.

    It was previously sent as text/html, which browsers may refuse to execute
    (e.g. under X-Content-Type-Options: nosniff).

    Raises:
        HTTPException(404): script.js is missing.
    """
    js_content = read_html_file("script.js")
    if js_content is None:
        raise HTTPException(status_code=404, detail="script.js not found")
    return PlainTextResponse(content=js_content, media_type="text/javascript")
@app.get("/style.css", response_class=HTMLResponse)
async def style():
    """
    Serve style.css with a CSS content type.

    It was previously sent as text/html, which browsers reject for stylesheets
    loaded by standards-mode pages.

    Raises:
        HTTPException(404): style.css is missing.
    """
    css_content = read_html_file("style.css")
    if css_content is None:
        raise HTTPException(status_code=404, detail="style.css not found")
    return PlainTextResponse(content=css_content, media_type="text/css")
@app.get("/dynamo", response_class=HTMLResponse)
async def dynamic_ai_page(request: Request):
    """
    Ask a model (via this server's own /chat/completions) to generate an HTML page.

    The prompt includes the caller's User-Agent and IP. The reply's first
    fenced code block (or, failing that, the whole reply) is returned as HTML.

    NOTE(review): the internal call authenticates as "Bearer playground" and sends
    no Referer, so verify_api_key only accepts it if "playground" is listed in
    API_KEYS — confirm. The model-generated HTML is returned unsanitised.

    Raises:
        HTTPException: 500 on connection failures, unexpected errors, or empty
        output; the internal service's own status on an HTTP error reply.
    """
    user_agent = request.headers.get('user-agent', 'Unknown User')
    client_ip = request.client.host if request.client else "Unknown IP"
    location = f"IP: {client_ip}"
    prompt = f"""
Generate a dynamic HTML page for a user with the following details: with name "LOKI.AI"
- User-Agent: {user_agent}
- Location: {location}
- Style: Cyberpunk, minimalist, or retro
Make sure the HTML is clean and includes a heading, also have cool animations a motivational message, and a cool background.
Wrap the generated HTML in triple backticks (```).
"""
    payload = {
        "model": "mistral-small-latest",
        "messages": [{"role": "user", "content": prompt}]
    }
    headers = {
        "Authorization": "Bearer playground"  # Use a dedicated internal token if available
    }
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                "http://localhost:7860/chat/completions",  # call this server's own API
                json=payload,
                headers=headers,
                timeout=30.0
            )
            response.raise_for_status()
            data = response.json()
    except httpx.RequestError as e:
        print(f"HTTPX Request Error in /dynamo: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to connect to internal AI service: {e}") from e
    except httpx.HTTPStatusError as e:
        print(f"HTTPX Status Error in /dynamo: {e.response.status_code} - {e.response.text}")
        raise HTTPException(status_code=e.response.status_code, detail=f"Internal AI service responded with error: {e.response.text}") from e
    except Exception as e:
        print(f"An unexpected error occurred in /dynamo: {e}")
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}") from e

    # Extraction happens outside the try so the "empty content" HTTPException below is
    # no longer swallowed and re-wrapped by the generic handler.
    message_content = ""
    choices = data.get('choices') if isinstance(data, dict) else None
    if choices and isinstance(choices[0], dict):
        message_content = (choices[0].get('message') or {}).get('content') or ''
    match = re.search(r"```(?:html)?(.*?)```", message_content, re.DOTALL)
    if match:
        html_content = match.group(1).strip()
    else:
        # Fallback: without backticks, treat the whole reply as HTML.
        html_content = message_content.strip()
    if not html_content:
        raise HTTPException(status_code=500, detail="Failed to generate HTML content from AI.")
    return HTMLResponse(content=html_content)
@app.get("/scraper", response_class=PlainTextResponse)
async def scrape_site(url: str = Query(..., description="URL to scrape")):
    """
    Fetch *url* with a pooled cloudscraper session and return the body as plain text.

    cloudscraper is synchronous, so the fetch runs on the shared thread-pool
    executor to keep the event loop free. A per-request timeout prevents a stuck
    site from pinning an executor thread indefinitely.

    SECURITY NOTE(review): this endpoint has no API-key dependency and fetches
    arbitrary caller-supplied URLs. That makes it an open proxy and an SSRF vector
    (e.g. localhost or internal addresses). Only the scheme is restricted here;
    consider authentication and a host allow-list.

    Raises:
        HTTPException(400): the URL is not an absolute http(s) URL.
        HTTPException(500): the fetch failed or returned only whitespace.
    """
    from urllib.parse import urlparse

    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https") or not parsed.netloc:
        raise HTTPException(status_code=400, detail="Only absolute http(s) URLs can be scraped.")

    loop = asyncio.get_running_loop()
    try:
        response_text = await loop.run_in_executor(
            executor,
            lambda: get_scraper().get(url, timeout=60).text
        )
    except Exception as e:
        print(f"Cloudscraper failed: {e}")
        raise HTTPException(status_code=500, detail=f"Cloudscraper failed: {e}") from e
    # Checked outside the try so this HTTPException is not re-wrapped as "Cloudscraper failed".
    if not response_text or not response_text.strip():
        raise HTTPException(status_code=500, detail="Scraping returned empty content.")
    return PlainTextResponse(response_text)
@app.get("/playground", response_class=HTMLResponse)
async def playground():
    """Serve the chat playground page (playground.html); 404 when the file is missing."""
    page = read_html_file("playground.html")
    if page is not None:
        return HTMLResponse(content=page)
    raise HTTPException(status_code=404, detail="playground.html not found")
@app.get("/image-playground", response_class=HTMLResponse)
async def image_playground():
    """Serve the image playground page (image-playground.html); 404 when missing."""
    page = read_html_file("image-playground.html")
    if page is not None:
        return HTMLResponse(content=page)
    raise HTTPException(status_code=404, detail="image-playground.html not found")
# Raw-content base URL of the Vetra repository served by /vetra.
# (The literal was previously wrapped in markdown link syntax "[url](url)",
# which produced an invalid URL for every fetch.)
GITHUB_BASE = "https://raw.githubusercontent.com/Parthsadaria/Vetra/main"
# Repository paths of the page's HTML, CSS and JS parts.
FILES = {
    "html": "index.html",
    "css": "style.css",
    "js": "script.js",
}
async def get_github_file(filename: str) -> Optional[str]:
    """
    Download *filename* from the Vetra repository via its raw GitHub URL.

    Redirects are followed. HTTP error statuses and network failures are
    printed and reported as ``None`` instead of being raised.
    """
    target = f"{GITHUB_BASE}/{filename}"
    async with httpx.AsyncClient() as client:
        try:
            reply = await client.get(target, follow_redirects=True)
            reply.raise_for_status()
        except httpx.HTTPStatusError as err:
            print(f"Error fetching {filename} from GitHub: {err.response.status_code} - {err.response.text}")
        except httpx.RequestError as err:
            print(f"Request error fetching {filename} from GitHub: {err}")
        else:
            return reply.text
    return None
@app.get("/vetra", response_class=HTMLResponse)
async def serve_vetra():
    """
    Serve the Vetra page by fetching its HTML, CSS and JS from GitHub and inlining them.

    The CSS is placed in a <style> tag before </head> and the JS in a <script> tag
    before </body>. A missing CSS/JS file is replaced by a placeholder comment.

    Raises:
        HTTPException(404): the HTML file could not be fetched.
    """
    # The three downloads are independent, so run them concurrently.
    html, css, js = await asyncio.gather(
        get_github_file(FILES["html"]),
        get_github_file(FILES["css"]),
        get_github_file(FILES["js"]),
    )
    if not html:
        raise HTTPException(status_code=404, detail="index.html not found on GitHub")
    final_html = html.replace(
        "</head>",
        f"<style>{css or '/* CSS not found */'}</style></head>"
    ).replace(
        "</body>",
        f"<script>{js or '// JS not found'}</script></body>"
    )
    return HTMLResponse(content=final_html)
@app.get("/searchgpt")
async def search_gpt(q: str, request: Request, stream: Optional[bool] = False, systemprompt: Optional[str] = None):
    """
    Answer *q* with the search-backed model.

    Records usage, then either relays results as server-sent events
    (``stream=true``) or collects them into ``{"response": <text>}``.

    Raises:
        HTTPException(400): ``q`` is empty.
        HTTPException(500): non-streaming mode and the search backend reported an error.
    """
    if not q:
        raise HTTPException(status_code=400, detail="Query parameter 'q' is required")
    usage_tracker.record_request(request=request, model="searchgpt", endpoint="/searchgpt")
    queue = await generate_search_async(q, systemprompt=systemprompt, stream=True)

    if stream:
        async def stream_generator():
            """Relay queued SSE chunks; an error is sent as a final data event."""
            while True:
                item = await queue.get()
                if item is None:
                    break
                if "error" in item:
                    # Emit the error as a data event so the client can handle it gracefully.
                    yield f"data: {json.dumps({'error': item['error']})}\n\n"
                    break
                if "data" in item:
                    yield item["data"]

        return StreamingResponse(
            stream_generator(),
            media_type="text/event-stream"
        )

    # Non-streaming: collect every chunk's text and return it as one JSON object.
    text_parts = []
    while True:
        item = await queue.get()
        if item is None:
            break
        if "error" in item:
            raise HTTPException(status_code=500, detail=item["error"])
        text_parts.append(item.get("text", ""))
    return JSONResponse(content={"response": "".join(text_parts)})
# Sent as the Origin/Referer headers to the default upstream endpoint; configure
# HEADER_URL in .env. When unset (None), get_completion omits those headers.
header_url = os.getenv('HEADER_URL')

# Generic explanations for common upstream HTTP error statuses.
_UPSTREAM_ERROR_MESSAGES = {
    400: "Bad request. Verify input data.",
    401: "Unauthorized. Invalid API key for upstream service.",
    403: "Forbidden. You do not have access to this resource on upstream.",
    404: "The requested resource was not found on upstream.",
    422: "Unprocessable entity. Check your payload for upstream API.",
    500: "Internal server error from upstream API.",
}


def _upstream_error_detail(status_code: int, error_body: bytes) -> str:
    """
    Build the error detail returned to the client for an upstream error response.

    Starts from a generic message for *status_code*. If the body is JSON shaped
    like ``{"error": {"message": ...}}`` or ``{"detail": ...}``, that message is
    appended; otherwise the first 200 characters of the raw body are appended.
    """
    detail_message = _UPSTREAM_ERROR_MESSAGES.get(status_code, f"Upstream error code: {status_code}")
    try:
        error_json = json.loads(error_body.decode('utf-8'))
    except (json.JSONDecodeError, UnicodeDecodeError):
        return detail_message + f" - Upstream raw: {error_body.decode('utf-8', errors='ignore')[:200]}..."
    # Guard the shapes explicitly: a string or list "error" used to crash this parsing.
    error_obj = error_json.get('error') if isinstance(error_json, dict) else None
    if isinstance(error_obj, dict) and 'message' in error_obj:
        return detail_message + f" - Upstream detail: {error_obj['message']}"
    if isinstance(error_json, dict) and 'detail' in error_json:
        return detail_message + f" - Upstream detail: {error_json['detail']}"
    return detail_message + f" - Upstream raw: {error_body.decode('utf-8')[:200]}..."


@app.post("/chat/completions")
@app.post("/api/v1/chat/completions")
async def get_completion(payload: Payload, request: Request, authenticated: bool = Depends(verify_api_key)):
    """
    Proxy a chat completion to the upstream endpoint chosen by ``payload.model``.

    Routing:
        mistral_models → MISTRAL_API + /v1/chat/completions (Bearer MISTRAL_KEY);
        pollinations_models → SECRET_API_ENDPOINT_4; alternate_models →
        SECRET_API_ENDPOINT_2; claude_3_models → SECRET_API_ENDPOINT_5 (all with
        /v1/chat/completions); gemini_models → SECRET_API_ENDPOINT_6 +
        /chat/completions (Bearer GEMINI_KEY); anything else →
        SECRET_API_ENDPOINT + /v1/chat/completions.

    With ``stream`` true, upstream lines are relayed as server-sent events.
    Otherwise a single JSON object is returned.

    Raises:
        HTTPException: 503 in maintenance mode; 400 for an unavailable model;
        500 for missing configuration or an unusable upstream reply; 502/504 for
        connection failures/timeouts; the upstream's own status (>= 400), with
        an explanatory detail, when the upstream rejects the request.
    """
    if not server_status:
        raise HTTPException(
            status_code=503,
            detail="Server is under maintenance. Please try again later."
        )
    model_to_use = payload.model or "gpt-4o-mini"  # Default model
    if available_model_ids and model_to_use not in available_model_ids:
        raise HTTPException(
            status_code=400,
            detail=f"Model '{model_to_use}' is not available. Check /models for the available model list."
        )
    usage_tracker.record_request(request=request, model=model_to_use, endpoint="/chat/completions")
    payload_dict = payload.dict()
    payload_dict["model"] = model_to_use
    # "stream" is always present (Payload declares it with default False).
    stream_enabled = bool(payload_dict.get("stream"))

    env_vars = get_env_vars()
    target_url_path = "/v1/chat/completions"  # Default path for OpenAI-like APIs
    if model_to_use in mistral_models:
        endpoint = env_vars['mistral_api']
        if not env_vars['mistral_key']:
            raise HTTPException(status_code=500, detail="MISTRAL_KEY not configured for Mistral models.")
        custom_headers = {"Authorization": f"Bearer {env_vars['mistral_key']}"}
    elif model_to_use in pollinations_models:
        endpoint = env_vars['secret_api_endpoint_4']
        custom_headers = {}
    elif model_to_use in alternate_models:
        endpoint = env_vars['secret_api_endpoint_2']
        custom_headers = {}
    elif model_to_use in claude_3_models:
        endpoint = env_vars['secret_api_endpoint_5']
        custom_headers = {}
    elif model_to_use in gemini_models:
        endpoint = env_vars['secret_api_endpoint_6']
        if not endpoint:
            raise HTTPException(status_code=500, detail="Gemini API endpoint (SECRET_API_ENDPOINT_6) not configured.")
        if not env_vars['gemini_key']:
            raise HTTPException(status_code=500, detail="GEMINI_KEY not configured for Gemini models.")
        custom_headers = {"Authorization": f"Bearer {env_vars['gemini_key']}"}
        target_url_path = "/chat/completions"  # Gemini's specific path
    else:
        endpoint = env_vars['secret_api_endpoint']
        # httpx rejects None header values, so only send Origin/Referer when HEADER_URL is set.
        if header_url:
            custom_headers = {"Origin": header_url, "Priority": "u=1, i", "Referer": header_url}
        else:
            custom_headers = {"Priority": "u=1, i"}
    if not endpoint:
        raise HTTPException(status_code=500, detail=f"No API endpoint configured for model: {model_to_use}")

    upstream_url = f"{endpoint}{target_url_path}"
    print(f"Proxying request for model '{model_to_use}' to endpoint: {upstream_url}")

    # Open the upstream request *before* building our response. That way connection
    # failures and upstream error statuses become real HTTP errors: raising from inside a
    # StreamingResponse body is too late, because the 200 status has already been sent.
    client = httpx.AsyncClient(timeout=60.0)
    try:
        upstream = await client.send(
            client.build_request("POST", upstream_url, json=payload_dict, headers=custom_headers),
            stream=True,
        )
    except httpx.TimeoutException as e:
        await client.aclose()
        raise HTTPException(status_code=504, detail="Request to upstream AI service timed out.") from e
    except httpx.RequestError as e:
        await client.aclose()
        raise HTTPException(status_code=502, detail=f"Failed to connect to upstream AI service: {str(e)}") from e
    except Exception as e:
        await client.aclose()
        print(f"An unexpected error occurred during chat completion proxy: {e}")
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}") from e

    if upstream.status_code >= 400:
        try:
            error_body = await upstream.aread()
        except httpx.HTTPError:
            error_body = b""
        finally:
            await upstream.aclose()
            await client.aclose()
        raise HTTPException(
            status_code=upstream.status_code,
            detail=_upstream_error_detail(upstream.status_code, error_body),
        )

    async def _relay_upstream_lines():
        """Yield each non-empty upstream line, newline-terminated; always close the connection."""
        try:
            async for line in upstream.aiter_lines():
                if line:
                    yield line + "\n"
        finally:
            await upstream.aclose()
            await client.aclose()

    if stream_enabled:
        # Errors after this point (e.g. a mid-stream read timeout) can only end the stream.
        return StreamingResponse(
            _relay_upstream_lines(),
            media_type="text/event-stream",
            headers={
                "Content-Type": "text/event-stream",
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no"  # Disable proxy buffering for SSE
            }
        )

    try:
        full_response_text = "".join([line async for line in _relay_upstream_lines()])
    except httpx.TimeoutException as e:
        raise HTTPException(status_code=504, detail="Request to upstream AI service timed out.") from e
    except httpx.RequestError as e:
        raise HTTPException(status_code=502, detail=f"Failed to connect to upstream AI service: {str(e)}") from e

    # Usual case: the upstream honoured "stream": false and sent a single JSON document.
    # (This case previously always ended in a 500, because only "data: " lines were parsed.)
    try:
        whole_body = json.loads(full_response_text)
    except json.JSONDecodeError:
        whole_body = None
    if isinstance(whole_body, dict) and whole_body:
        return JSONResponse(content=whole_body)

    # Fallback: the upstream streamed anyway, so parse its SSE "data: " lines.
    parsed_data = []
    for line in full_response_text.splitlines():
        if not line.startswith("data: "):
            continue
        data_str = line[6:].strip()
        if data_str == "[DONE]":  # end-of-stream sentinel, not JSON
            continue
        try:
            parsed_data.append(json.loads(data_str))
        except json.JSONDecodeError:
            print(f"Warning: Could not decode JSON line in non-streaming response: {line}")
    # NOTE(review): for OpenAI-style streams the last chunk holds only the final delta;
    # returning it unchanged (existing behaviour) may drop content — merging deltas may be needed.
    final_json_response = {}
    if parsed_data:
        last_chunk = parsed_data[-1]
        if isinstance(last_chunk, dict) and 'choices' in last_chunk:
            final_json_response = last_chunk
        else:
            final_json_response = {"response_parts": parsed_data}
    if not final_json_response:
        raise HTTPException(status_code=500, detail="No valid JSON response received from upstream API for non-streaming request.")
    return JSONResponse(content=final_json_response)
@app.post("/images/generations")
async def create_image(payload: ImageGenerationPayload, request: Request, authenticated: bool = Depends(verify_api_key)):
    """
    Proxy an image-generation request to the NEW_IMG endpoint and return its JSON reply.

    Records usage. The request is forwarded as ``{model, prompt, size, n}``.

    Raises:
        HTTPException: 503 in maintenance mode; 400 for an unsupported model;
        500 if NEW_IMG is unset or something unexpected fails; 504 on timeout;
        502 on connection errors; the upstream's status on an HTTP error reply.
    """
    if not server_status:
        raise HTTPException(
            status_code=503,
            detail="Server is under maintenance. Please try again later."
        )
    if payload.model not in supported_image_models:
        raise HTTPException(
            status_code=400,
            detail=f"Model '{payload.model}' is not supported for image generation. Supported models are: {', '.join(supported_image_models)}"
        )
    usage_tracker.record_request(request=request, model=payload.model, endpoint="/images/generations")
    api_payload = {
        "model": payload.model,
        "prompt": payload.prompt,
        "size": payload.size,
        "n": payload.number  # upstream expects "n" for the number of images
    }
    target_api_url = get_env_vars().get('new_img')
    if not target_api_url:
        raise HTTPException(status_code=500, detail="Image generation API endpoint (NEW_IMG) not configured.")
    try:
        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(target_api_url, json=api_payload)
            response.raise_for_status()
            return JSONResponse(content=response.json())
    except httpx.TimeoutException as e:
        raise HTTPException(status_code=504, detail="Image generation request timed out.") from e
    except httpx.RequestError as e:
        raise HTTPException(status_code=502, detail=f"Error connecting to image generation service: {e}") from e
    except httpx.HTTPStatusError as e:
        status_code = e.response.status_code
        fallback_detail = f"Image generation failed with status code: {status_code}"
        # The error body may be non-JSON (e.g. an HTML error page) or not an object;
        # parsing it must not raise from inside this handler.
        try:
            error_body = e.response.json()
        except ValueError:
            error_body = None
        error_detail = error_body.get("detail", fallback_detail) if isinstance(error_body, dict) else fallback_detail
        raise HTTPException(status_code=status_code, detail=error_detail) from e
    except Exception as e:
        print(f"An unexpected error occurred during image generation: {e}")
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during image generation: {e}") from e
@app.get("/usage")
async def get_usage_json(days: int = 7):
    """Return the usage tracker's summary covering the last *days* days (7 by default) as JSON."""
    summary = usage_tracker.get_usage_summary(days)
    return summary
def generate_usage_html(usage_data: Dict[str, Any]):
"""
Generates an HTML page to display usage statistics.
Includes tables for model, API endpoint usage, daily usage, and recent requests.
Also includes placeholders for Chart.js to render graphs.
"""
# Prepare data for Chart.js
# Model Usage Chart Data
model_labels = list(usage_data['model_usage_period'].keys())
model_counts = list(usage_data['model_usage_period'].values())
# Endpoint Usage Chart Data
endpoint_labels = list(usage_data['endpoint_usage_period'].keys())
endpoint_counts = list(usage_data['endpoint_usage_period'].values())
# Daily Usage Chart Data
daily_dates = list(usage_data['daily_usage_period'].keys())
daily_requests = [data['requests'] for data in usage_data['daily_usage_period'].values()]
daily_unique_ips = [data['unique_ips_count'] for data in usage_data['daily_usage_period'].values()]
# Format table rows for HTML
model_usage_all_time_rows = "\n".join([
f"""
<tr>
<td>{model}</td>
<td>{stats['total_requests']}</td>
<td>{datetime.datetime.fromisoformat(stats['first_used']).strftime("%Y-%m-%d %H:%M")}</td>
<td>{datetime.datetime.fromisoformat(stats['last_used']).strftime("%Y-%m-%d %H:%M")}</td>
</tr>
""" for model, stats in usage_data['all_time_model_usage'].items()
])
api_usage_all_time_rows = "\n".join([
f"""
<tr>
<td>{endpoint}</td>
<td>{stats['total_requests']}</td>
<td>{datetime.datetime.fromisoformat(stats['first_used']).strftime("%Y-%m-%d %H:%M")}</td>
<td>{datetime.datetime.fromisoformat(stats['last_used']).strftime("%Y-%m-%d %H:%M")}</td>
</tr>
""" for endpoint, stats in usage_data['all_time_endpoint_usage'].items()
])
daily_usage_table_rows = "\n".join([
f"""
<tr>
<td>{date}</td>
<td>{data['requests']}</td>
<td>{data['unique_ips_count']}</td>
</tr>
""" for date, data in usage_data['daily_usage_period'].items()
])
recent_requests_rows = "\n".join([
f"""
<tr>
<td>{datetime.datetime.fromisoformat(req['timestamp']).strftime("%Y-%m-%d %H:%M:%S")}</td>
<td>{req['model']}</td>
<td>{req['endpoint']}</td>
<td>{req['ip_address']}</td>
<td>{req['user_agent']}</td>
</tr>
""" for req in usage_data['recent_requests']
])
html_content = f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Lokiai AI - Usage Statistics</title>
<link href="[https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600;700&display=swap](https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600;700&display=swap)" rel="stylesheet">
<script src="[https://cdn.jsdelivr.net/npm/chart.js](https://cdn.jsdelivr.net/npm/chart.js)"></script>
<style>
:root {{
--bg-dark: #0f1011;
--bg-darker: #070708;
--text-primary: #e6e6e6;
--text-secondary: #8c8c8c;
--border-color: #2c2c2c;
--accent-color: #3a6ee0;
--accent-hover: #4a7ef0;
--chart-bg-light: rgba(58, 110, 224, 0.2);
--chart-border-light: #3a6ee0;
}}
body {{
font-family: 'Inter', sans-serif;
background-color: var(--bg-dark);
color: var(--text-primary);
max-width: 1200px;
margin: 0 auto;
padding: 40px 20px;
line-height: 1.6;
}}
.logo {{
display: flex;
align-items: center;
justify-content: center;
margin-bottom: 30px;
}}
.logo h1 {{
font-weight: 700;
font-size: 2.8em;
color: var(--text-primary);
margin-left: 15px;
}}
.logo img {{
width: 70px;
height: 70px;
border-radius: 12px;
box-shadow: 0 5px 15px rgba(0,0,0,0.2);
}}
.container {{
background-color: var(--bg-darker);
border-radius: 16px;
padding: 30px;
box-shadow: 0 20px 50px rgba(0,0,0,0.4);
border: 1px solid var(--border-color);
}}
h2, h3 {{
color: var(--text-primary);
border-bottom: 2px solid var(--border-color);
padding-bottom: 12px;
margin-top: 40px;
margin-bottom: 25px;
font-weight: 600;
font-size: 1.8em;
}}
.summary-grid {{
display: grid;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
gap: 20px;
margin-bottom: 30px;
}}
.summary-card {{
background-color: var(--bg-dark);
border-radius: 10px;
padding: 20px;
text-align: center;
border: 1px solid var(--border-color);
box-shadow: 0 8px 20px rgba(0,0,0,0.2);
transition: transform 0.2s ease-in-out;
}}
.summary-card:hover {{
transform: translateY(-5px);
}}
.summary-card h3 {{
margin-top: 0;
font-size: 1.1em;
color: var(--text-secondary);
border-bottom: none;
padding-bottom: 0;
margin-bottom: 10px;
}}
.summary-card p {{
font-size: 2.2em;
font-weight: 700;
color: var(--accent-color);
margin: 0;
}}
table {{
width: 100%;
border-collapse: separate;
border-spacing: 0;
margin-bottom: 40px;
background-color: var(--bg-dark);
border-radius: 10px;
overflow: hidden;
box-shadow: 0 8px 20px rgba(0,0,0,0.2);
}}
th, td {{
border: 1px solid var(--border-color);
padding: 15px;
text-align: left;
transition: background-color 0.3s ease;
}}
th {{
background-color: #1a1a1a;
color: var(--text-primary);
font-weight: 600;
text-transform: uppercase;
font-size: 0.95em;
}}
tr:nth-child(even) {{
background-color: rgba(255,255,255,0.03);
}}
tr:hover {{
background-color: rgba(62,100,255,0.1);
}}
.chart-container {{
background-color: var(--bg-dark);
border-radius: 10px;
padding: 20px;
margin-bottom: 40px;
border: 1px solid var(--border-color);
box-shadow: 0 8px 20px rgba(0,0,0,0.2);
max-height: 400px; /* Limit chart height */
position: relative; /* For responsive canvas */
}}
canvas {{
max-width: 100% !important;
height: auto !important;
}}
@media (max-width: 768px) {{
body {{
padding: 20px 10px;
}}
.container {{
padding: 20px;
}}
.logo h1 {{
font-size: 2em;
}}
.summary-card p {{
font-size: 1.8em;
}}
h2, h3 {{
font-size: 1.5em;
}}
table {{
font-size: 0.85em;
}}
th, td {{
padding: 10px;
}}
}}
</style>
</head>
<body>
<div class="container">
<div class="logo">
<img src="data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iMjAwIiBoZWlnaHQ9IjIwMCIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMC9zdmciPjxwYXRoIGQ9Ik0xMDAgMzVMNTAgOTBoMTAwWiIgZmlsbD0iIzNhNmVlMCIvPjxjaXJjbGUgY3g9IjEwMCIgY3k9IjE0MCIgcj0iMzAiIGZpbGw9IiMzYTZlZTAiLz48L3N2Zz4=" alt="Lokiai AI Logo">
<h1>Lokiai AI Usage</h1>
</div>
<div class="summary-grid">
<div class="summary-card">
<h3>Total Requests (All Time)</h3>
<p>{usage_data['total_requests']}</p>
</div>
<div class="summary-card">
<h3>Unique IPs (All Time)</h3>
<p>{usage_data['unique_ips_total_count']}</p>
</div>
<div class="summary-card">
<h3>Models Used (Last {days} Days)</h3>
<p>{len(usage_data['model_usage_period'])}</p>
</div>
<div class="summary-card">
<h3>Endpoints Used (Last {days} Days)</h3>
<p>{len(usage_data['endpoint_usage_period'])}</p>
</div>
</div>
<h2>Daily Usage (Last {days} Days)</h2>
<div class="chart-container">
<canvas id="dailyRequestsChart"></canvas>
</div>
<table>
<thead>
<tr>
<th>Date</th>
<th>Requests</th>
<th>Unique IPs</th>
</tr>
</thead>
<tbody>
{daily_usage_table_rows}
</tbody>
</table>
<h2>Model Usage (Last {days} Days)</h2>
<div class="chart-container">
<canvas id="modelUsageChart"></canvas>
</div>
<h3>Model Usage (All Time Details)</h3>
<table>
<thead>
<tr>
<th>Model</th>
<th>Total Requests</th>
<th>First Used</th>
<th>Last Used</th>
</tr>
</thead>
<tbody>
{model_usage_all_time_rows}
</tbody>
</table>
<h2>API Endpoint Usage (Last {days} Days)</h2>
<div class="chart-container">
<canvas id="endpointUsageChart"></canvas>
</div>
<h3>API Endpoint Usage (All Time Details)</h3>
<table>
<thead>
<tr>
<th>Endpoint</th>
<th>Total Requests</th>
<th>First Used</th>
<th>Last Used</th>
</tr>
</thead>
<tbody>
{api_usage_all_time_rows}
</tbody>
</table>
<h2>Recent Requests (Last 20)</h2>
<table>
<thead>
<tr>
<th>Timestamp</th>
<th>Model</th>
<th>Endpoint</th>
<th>IP Address</th>
<th>User Agent</th>
</tr>
</thead>
<tbody>
{recent_requests_rows}
</tbody>
</table>
</div>
<script>
// Chart.js data and rendering logic
const modelLabels = {json.dumps(model_labels)};
const modelCounts = {json.dumps(model_counts)};
const endpointLabels = {json.dumps(endpoint_labels)};
const endpointCounts = {json.dumps(endpoint_counts)};
const dailyDates = {json.dumps(daily_dates)};
const dailyRequests = {json.dumps(daily_requests)};
const dailyUniqueIps = {json.dumps(daily_unique_ips)};
// Model Usage Chart (Bar Chart)
new Chart(document.getElementById('modelUsageChart'), {{
type: 'bar',
data: {{
labels: modelLabels,
datasets: [{{
label: 'Requests',
data: modelCounts,
backgroundColor: 'var(--chart-bg-light)',
borderColor: 'var(--chart-border-light)',
borderWidth: 1,
borderRadius: 5,
}}]
}},
options: {{
responsive: true,
maintainAspectRatio: false,
plugins: {{
legend: {{
labels: {{
color: 'var(--text-primary)'
}}
}},
title: {{
display: true,
text: 'Model Usage',
color: 'var(--text-primary)'
}}
}},
scales: {{
x: {{
ticks: {{
color: 'var(--text-secondary)'
}},
grid: {{
color: 'var(--border-color)'
}}
}},
y: {{
beginAtZero: true,
ticks: {{
color: 'var(--text-secondary)'
}},
grid: {{
color: 'var(--border-color)'
}}
}}
}}
}}
}});
// Endpoint Usage Chart (Doughnut Chart)
new Chart(document.getElementById('endpointUsageChart'), {{
type: 'doughnut',
data: {{
labels: endpointLabels,
datasets: [{{
label: 'Requests',
data: endpointCounts,
backgroundColor: [
'#3a6ee0', '#5b8bff', '#8dc4ff', '#b3d8ff', '#d0e8ff',
'#FF6384', '#36A2EB', '#FFCE56', '#4BC0C0', '#9966FF'
],
hoverOffset: 4
}}]
}},
options: {{
responsive: true,
maintainAspectRatio: false,
plugins: {{
legend: {{
position: 'right',
labels: {{
color: 'var(--text-primary)'
}}
}},
title: {{
display: true,
text: 'API Endpoint Usage',
color: 'var(--text-primary)'
}}
}}
}}
}});
// Daily Requests Chart (Line Chart)
new Chart(document.getElementById('dailyRequestsChart'), {{
type: 'line',
data: {{
labels: dailyDates,
datasets: [
{{
label: 'Total Requests',
data: dailyRequests,
borderColor: 'var(--accent-color)',
backgroundColor: 'rgba(58, 110, 224, 0.1)',
fill: true,
tension: 0.3
}},
{{
label: 'Unique IPs',
data: dailyUniqueIps,
borderColor: '#FFCE56', // A distinct color for unique IPs
backgroundColor: 'rgba(255, 206, 86, 0.1)',
fill: true,
tension: 0.3
}}
]
}},
options: {{
responsive: true,
maintainAspectRatio: false,
plugins: {{
legend: {{
labels: {{
color: 'var(--text-primary)'
}}
}},
title: {{
display: true,
text: 'Daily Requests and Unique IPs',
color: 'var(--text-primary)'
}}
}},
scales: {{
x: {{
ticks: {{
color: 'var(--text-secondary)'
}},
grid: {{
color: 'var(--border-color)'
}}
}},
y: {{
beginAtZero: true,
ticks: {{
color: 'var(--text-secondary)'
}},
grid: {{
color: 'var(--border-color)'
}}
}}
}}
}}
}});
</script>
</body>
</html>
"""
return html_content
@app.get("/usage/page", response_class=HTMLResponse)
async def usage_page(days: int = 7):
    """
    Render the usage dashboard as an HTML page.

    The ``days`` query parameter (default 7) is forwarded to
    ``usage_tracker.get_usage_summary`` to choose the reporting window; the
    resulting summary is turned into markup by ``generate_usage_html``.
    """
    summary = usage_tracker.get_usage_summary(days=days)
    page_markup = generate_usage_html(summary)
    return HTMLResponse(content=page_markup)
@app.get("/meme")
async def get_meme():
    """
    Fetch a random meme from meme-api.com and return its image.

    Queries ``https://meme-api.com/gimme`` for a meme descriptor, downloads the
    image at its ``url`` field (following redirects) and sends it back with the
    upstream content type (``image/png`` when the header is absent) and a
    one-hour ``Cache-Control`` header.

    Raises:
        HTTPException: 404 when the descriptor has no ``url``; 502 when the
            descriptor is not valid JSON or the meme service is unreachable;
            the upstream status code when either upstream request returns an
            error status; 500 for any other unexpected failure.
    """
    try:
        client = get_async_client()
        response = await client.get("https://meme-api.com/gimme")
        response.raise_for_status()  # Raise an exception for bad status codes
        try:
            response_data = response.json()
        except ValueError:
            # json.JSONDecodeError is a ValueError; an HTML error page from
            # upstream is a bad-gateway condition, not an internal error.
            raise HTTPException(status_code=502, detail="Meme service returned invalid JSON.") from None
        meme_url = response_data.get("url") if isinstance(response_data, dict) else None
        if not meme_url:
            raise HTTPException(status_code=404, detail="No meme URL found in response.")
        # A non-streaming get() reads the whole image body here, so the
        # generator below only slices bytes that are already in memory.
        image_response = await client.get(meme_url, follow_redirects=True)
        image_response.raise_for_status()

        async def stream_image_chunks():
            """Yield the downloaded image body in 128 KB pieces."""
            async for chunk in image_response.aiter_bytes(chunk_size=131072):
                yield chunk

        return StreamingResponse(
            stream_image_chunks(),
            media_type=image_response.headers.get("content-type", "image/png"),  # Fallback to png
            headers={'Cache-Control': 'max-age=3600'},  # Cache memes for 1 hour
        )
    except HTTPException:
        # Errors raised deliberately above (404 / 502) must reach the client
        # unchanged; otherwise the generic handler below rewrites them as 500s.
        raise
    except httpx.HTTPStatusError as e:
        print(f"Error fetching meme from upstream: {e.response.status_code} - {e.response.text}")
        raise HTTPException(status_code=e.response.status_code, detail=f"Failed to fetch meme: {e.response.text}")
    except httpx.RequestError as e:
        print(f"Request error fetching meme: {e}")
        raise HTTPException(status_code=502, detail=f"Could not connect to meme service: {e}")
    except Exception as e:
        print(f"An unexpected error occurred while getting meme: {e}")
        raise HTTPException(status_code=500, detail="Failed to retrieve meme due to an unexpected error.")
def load_model_ids(json_file_path: str) -> List[str]:
    """
    Load model IDs from a JSON file.

    The file is expected to contain a top-level JSON array of objects. Every
    object that has an ``id`` key contributes that value, in file order.
    Entries that are not objects, or that lack ``id``, are skipped.

    Args:
        json_file_path: Path of the JSON file; it is read as UTF-8.

    Returns:
        The list of IDs. An empty list is returned (and an error printed)
        when the file cannot be read, is not valid JSON, or its top level is
        not an array. Callers therefore never see an exception for a bad or
        missing file.
    """
    try:
        with open(json_file_path, 'r', encoding='utf-8') as f:
            models_data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Error loading model IDs from {json_file_path}: {str(e)}")
        return []
    if not isinstance(models_data, list):
        print(
            f"Error loading model IDs from {json_file_path}: "
            f"expected a JSON array, got {type(models_data).__name__}"
        )
        return []
    # The isinstance guard matters: for a string entry, `'id' in entry` is a
    # substring test (true for e.g. "valid"), and entry['id'] would then raise.
    return [model['id'] for model in models_data if isinstance(model, dict) and 'id' in model]
def _missing_critical_env_vars() -> List[str]:
    """
    Return the names of critical environment variables that are unset or empty.

    Shared by the startup warning and the /health endpoint so both always
    report the same list, in the same order.

    - API_KEYS and SECRET_API_ENDPOINT, _2 and _3 are always required.
    - Provider-specific variables are reported only when ``available_model_ids``
      contains at least one model from that provider's list.
    - NEW_IMG is reported only when ``supported_image_models`` is non-empty.
    """
    env_vars = get_env_vars()
    missing: List[str] = []

    def serves_any(provider_models) -> bool:
        # True when at least one loaded model belongs to the given provider list.
        return any(model in provider_models for model in available_model_ids)

    # os.getenv('API_KEYS', '').split(',') gives [''] when the variable is unset.
    if not env_vars['api_keys'] or env_vars['api_keys'] == ['']:
        missing.append('API_KEYS')
    for key, label in (
        ('secret_api_endpoint', 'SECRET_API_ENDPOINT'),
        ('secret_api_endpoint_2', 'SECRET_API_ENDPOINT_2'),
        ('secret_api_endpoint_3', 'SECRET_API_ENDPOINT_3'),
    ):
        if not env_vars[key]:
            missing.append(label)
    for key, label, provider_models in (
        ('secret_api_endpoint_4', 'SECRET_API_ENDPOINT_4 (Pollinations.ai)', pollinations_models),
        ('secret_api_endpoint_5', 'SECRET_API_ENDPOINT_5 (Claude 3.x)', claude_3_models),
        ('secret_api_endpoint_6', 'SECRET_API_ENDPOINT_6 (Gemini)', gemini_models),
        ('mistral_api', 'MISTRAL_API', mistral_models),
        ('mistral_key', 'MISTRAL_KEY', mistral_models),
        ('gemini_key', 'GEMINI_KEY', gemini_models),
    ):
        if not env_vars[key] and serves_any(provider_models):
            missing.append(label)
    if not env_vars['new_img'] and len(supported_image_models) > 0:
        missing.append('NEW_IMG (Image Generation)')
    return missing


@app.on_event("startup")
async def startup_event():
    """
    Actions to perform on application startup:
    - Build ``available_model_ids`` from models.json plus the hard-coded
      provider model lists, de-duplicated in first-seen order.
    - Fill ``scraper_pool`` with MAX_SCRAPERS cloudscraper instances.
    - Print a warning listing any missing critical environment variables.
    """
    global available_model_ids
    # Load models from a local models.json file first
    model_ids = load_model_ids("models.json")
    print(f"Loaded {len(model_ids)} model IDs from models.json")
    # Extend with hardcoded model lists for the various providers
    for provider_models in (pollinations_models, alternate_models, mistral_models, claude_3_models, gemini_models):
        model_ids.extend(provider_models)
    # dict.fromkeys drops duplicates while keeping first-seen order. A set's
    # iteration order for strings varies between runs (hash randomization).
    available_model_ids = list(dict.fromkeys(model_ids))
    print(f"Total unique available models after merging: {len(available_model_ids)}")
    # Initialize scraper pool
    for _ in range(MAX_SCRAPERS):
        scraper_pool.append(cloudscraper.create_scraper())
    print(f"Initialized Cloudscraper pool with {MAX_SCRAPERS} instances.")
    # Environment variable check for critical services
    missing_vars = _missing_critical_env_vars()
    if missing_vars:
        print(f"WARNING: The following critical environment variables are missing or empty: {', '.join(missing_vars)}")
        print("Some server functionality (e.g., specific AI models, image generation) may be limited or unavailable.")
    else:
        print("All critical environment variables appear to be configured.")
    print("Server started successfully!")


@app.on_event("shutdown")
async def shutdown_event():
    """
    Actions to perform on application shutdown:
    - Close the shared HTTPX client.
    - Clear the scraper pool.
    - Save usage data to disk. This runs even if the cleanup steps above
      raise, so recorded usage is not lost.
    """
    try:
        client = get_async_client()
        await client.aclose()  # Ensure the httpx client connection pool is closed
        scraper_pool.clear()
    finally:
        usage_tracker.save_data()  # Persist usage data on shutdown
    print("Server shutdown complete!")


@app.get("/health")
async def health_check():
    """
    Report server health.

    The status is "healthy" when no critical environment variable is missing
    and "unhealthy" otherwise. The response lists the missing variables and
    includes the global ``server_status`` flag.
    """
    missing_critical_vars = _missing_critical_env_vars()
    healthy = not missing_critical_vars
    health_status = {
        "status": "healthy" if healthy else "unhealthy",
        "missing_env_vars": missing_critical_vars,
        "server_status": server_status,  # Reports global server status flag
        "message": "Everything's lit! 🚀" if healthy else "Uh oh, some env vars are missing. 😬",
    }
    return JSONResponse(content=health_status)
if __name__ == "__main__":
    import uvicorn

    # Direct-execution entry point. uvicorn.run() fires the registered
    # startup/shutdown handlers itself, so models and env checks still load.
    bind_host, bind_port = "0.0.0.0", 7860
    uvicorn.run(app, host=bind_host, port=bind_port)