Spaces:
Running
Running
import os | |
import re | |
import json | |
import datetime | |
import time | |
import asyncio | |
import logging | |
from pathlib import Path | |
from functools import lru_cache | |
from typing import Optional, Dict, List, Any, Generator, Set | |
from concurrent.futures import ThreadPoolExecutor | |
# Third-party libraries (ensure these are in requirements.txt) | |
from dotenv import load_dotenv | |
from fastapi import FastAPI, HTTPException, Request, Depends, Security, Response | |
from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse | |
from fastapi.security import APIKeyHeader | |
from pydantic import BaseModel | |
import httpx | |
import uvloop # Use uvloop for performance | |
from fastapi.middleware.gzip import GZipMiddleware | |
from starlette.middleware.cors import CORSMiddleware | |
import cloudscraper # For bypassing Cloudflare, potentially unreliable | |
import requests # For synchronous requests like in /dynamo | |
# HF Space Note: Ensure usage_tracker.py is in your repository | |
# Usage tracking is optional: when usage_tracker.py cannot be imported, a
# no-op tracker with the same three methods is installed instead so the rest
# of the module can call `usage_tracker` unconditionally.
try:
    from usage_tracker import UsageTracker
    usage_tracker = UsageTracker()
except ImportError:
    print("Warning: usage_tracker.py not found. Usage tracking will be disabled.")

    class DummyUsageTracker:
        """No-op replacement used when the real UsageTracker is unavailable."""

        def record_request(self, *args, **kwargs):
            """Accept any arguments and record nothing."""
            return None

        def get_usage_summary(self, *args, **kwargs):
            """Return an empty summary."""
            return {}

        def save_data(self, *args, **kwargs):
            """Accept any arguments and persist nothing."""
            return None

    usage_tracker = DummyUsageTracker()
# --- Configuration & Setup --- | |
# HF Space Note: uvloop can improve performance in I/O bound tasks common in web apps. | |
# Install uvloop's event loop policy so asyncio loops created afterwards use uvloop.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

# Shared thread pool for blocking (synchronous) calls made from async handlers.
# HF Space Note: tune max_workers to the CPU available on your Space tier.
executor = ThreadPoolExecutor(max_workers=8)

# Load variables from a local .env file (useful for local development);
# os.getenv also sees variables already present in the process environment.
load_dotenv()

# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# API key security: the key is read from the "Authorization" header.
# auto_error=False lets the dependency receive None instead of failing early.
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)

# --- FastAPI App Initialization ---
app = FastAPI(
    title="LokiAI API",
    description="API Proxy for various AI models with usage tracking and streaming.",
    version="1.0.0",
)

# Middleware
# Compress responses of at least 1000 bytes.
app.add_middleware(GZipMiddleware, minimum_size=1000)
# Allow cross-origin requests (useful for web playgrounds).
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended or restrict allow_origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- Environment Variables & Model Config --- | |
# Cache environment variables | |
def get_env_vars() -> Dict[str, Any]:
    """Read the service configuration from environment variables.

    Returns:
        A dict with these keys:
        - 'api_keys': set of accepted API keys parsed from the comma-separated
          API_KEYS variable. Surrounding whitespace is stripped and empty
          entries are dropped, so "a, b,," yields {"a", "b"}.
        - 'secret_api_endpoint', 'secret_api_endpoint_2', 'secret_api_endpoint_3',
          'secret_api_endpoint_5', 'mistral_key', 'new_img_endpoint': raw values
          of the matching variables, or None when unset.
        - 'secret_api_endpoint_4', 'mistral_api': raw values with built-in defaults.
        - 'hf_space_url': HF_SPACE_URL with any trailing slash removed, so that
          callers can safely append paths such as "/playground".
    """
    # HF Space Note: Set these as Secrets in your Hugging Face Space settings.
    raw_keys = os.getenv('API_KEYS', '')
    # Strip each entry: "key1, key2" must not produce the key " key2".
    api_keys = {key.strip() for key in raw_keys.split(',') if key.strip()}

    hf_space_url = os.getenv('HF_SPACE_URL', 'https://your-space-name.hf.space')

    return {
        'api_keys': api_keys,  # set for O(1) membership checks
        'secret_api_endpoint': os.getenv('SECRET_API_ENDPOINT'),
        'secret_api_endpoint_2': os.getenv('SECRET_API_ENDPOINT_2'),
        'secret_api_endpoint_3': os.getenv('SECRET_API_ENDPOINT_3'),  # Search endpoint
        'secret_api_endpoint_4': os.getenv('SECRET_API_ENDPOINT_4', "https://text.pollinations.ai/openai"),  # Pollinations
        'secret_api_endpoint_5': os.getenv('SECRET_API_ENDPOINT_5'),  # Claude 3 endpoint
        'mistral_api': os.getenv('MISTRAL_API', "https://api.mistral.ai"),
        'mistral_key': os.getenv('MISTRAL_KEY'),
        'new_img_endpoint': os.getenv('NEW_IMG'),  # Image generation endpoint
        # HF Space Note: Set this! Used for Referer/Origin checks.
        'hf_space_url': hf_space_url.rstrip('/'),
    }
# Model sets for fast lookups | |
# HF Space Note: Consider moving these large sets to a separate config file (e.g., config.py or models_config.json) | |
# for better organization if they grow larger. | |
# Models routed to the Mistral API.
mistral_models: Set[str] = {
    "mistral-large-latest",
    "mistral-small-latest",
    "mistral-saba-latest",
    "mistral-moderation-latest",
    "ministral-3b-latest",
    "ministral-8b-latest",
    "open-mistral-nemo",
    "pixtral-large-latest",
    "codestral-latest",
}

# Models routed to the Pollinations endpoint (SECRET_API_ENDPOINT_4).
pollinations_models: Set[str] = {
    "openai", "openai-large", "openai-xlarge", "openai-reasoning", "openai-audio",
    "deepseek", "deepseek-r1", "deepseek-reasoner",
    "llama", "llamalight", "llama-scaleway",
    "gemini", "gemini-thinking",
    "phi", "phi-mini",
    "qwen-coder", "mistral", "searchgpt", "claude-hybridspace", "hormoz",
}

# Models routed to the alternate endpoint (SECRET_API_ENDPOINT_2).
alternate_models: Set[str] = {
    "gpt-4o", "o3-mini-low", "command-a",
    "deepseek-v3", "deepseek-r1-uncensored",
    "llama-3.1-8b-instruct", "llama-3.1-sonar-small-128k-online",
    "hermes-3-llama-3.2-3b",
    "tinyswallow1.5b", "andy-3.5", "creitin-r1", "fluffy.1-chat",
    "plutotext-1-text", "plutogpt-3.5-turbo",
    "claude-3-7-sonnet-20250219",
}

# Models routed to the "Claude 3" endpoint (SECRET_API_ENDPOINT_5).
claude_3_models: Set[str] = {
    "claude-3-7-sonnet", "claude-3-7-sonnet-thinking",
    "claude 3.5 haiku", "claude 3.5 sonnet",
    "o3-mini-medium", "o3-mini-high",
    "grok-3", "grok-3-thinking", "grok 2",
}

# Model names accepted for image generation.
supported_image_models: Set[str] = {
    "Flux Pro Ultra", "Flux Pro Ultra Raw", "Flux Pro", "Flux Dev",
    "Flux Schnell", "Flux Realism",
    "stable-diffusion-3-large-turbo", "stable-diffusion-ultra",
    "sdxl-lightning-4step", "dall-e-3", "grok-2-aurora",
}
# --- Pydantic Models --- | |
class Message(BaseModel):
    """One chat message: the speaker's role plus the message content."""

    role: str
    # Typed as Any so the content may be a plain string or a structured
    # value (e.g. a list) for multimodal models.
    content: Any
class Payload(BaseModel):
    """Request body for the chat completions endpoint (OpenAI-style fields)."""

    model: str
    messages: List[Message]
    # When True the response is delivered as a server-sent event stream.
    stream: bool = False
    # Optional sampling/limit parameters; None means "not supplied".
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    # ... add other OpenAI-compatible parameters here as needed
class ImageGenerationPayload(BaseModel):
    """Request body for image generation."""

    model: str
    prompt: str
    # Image dimensions as a "WIDTHxHEIGHT" string; defaults to 1024x1024.
    size: Optional[str] = "1024x1024"
    # Number of images to generate (OpenAI names this field 'n').
    n: Optional[int] = 1
    # HF Space Note: Ensure these parameter names match the target NEW_IMG endpoint API.
# --- Global State & Clients --- | |
# Maintenance switch: request handlers refuse work while this is False.
server_status: bool = True
# Model identifiers known to the service; starts empty (loaded at startup).
available_model_ids: List[str] = []
# HF Space Note: Reusable HTTP client with connection pooling is crucial for performance. | |
# Adjust limits based on expected load and HF Space resources. | |
@lru_cache(maxsize=1)
def get_async_client() -> httpx.AsyncClient:
    """Return the process-wide shared httpx.AsyncClient.

    The client is created on the first call and the same instance is returned
    afterwards, so all callers share one connection pool. Previously every
    call built a new client (and pool) that was never closed, which defeated
    keep-alive reuse and leaked connections.

    Callers must NOT close the returned client; it is shared.

    Returns:
        An AsyncClient configured with redirects enabled, a 10s connect
        timeout and 30s for the other timeout phases, and a pool capped at
        100 connections (20 kept alive).
    """
    # 30.0 is the default applied to each timeout phase (read, write, pool);
    # only the connect phase is overridden to 10s. It is not a total deadline.
    timeout = httpx.Timeout(30.0, connect=10.0)
    limits = httpx.Limits(max_keepalive_connections=20, max_connections=100)
    return httpx.AsyncClient(timeout=timeout, limits=limits, follow_redirects=True)
# HF Space Note: cloudscraper pool. Be mindful of potential rate limits or blocks. | |
# Consider alternatives if this becomes unreliable. | |
# Lazily-filled pool of cloudscraper sessions shared by all search requests.
scraper_pool: List[cloudscraper.CloudScraper] = []
MAX_SCRAPERS = 10  # Number of scraper instances created on first use


def get_scraper() -> cloudscraper.CloudScraper:
    """Return one scraper from the pool, creating the pool on first use.

    The instance is chosen from the current monotonic clock (in milliseconds)
    modulo MAX_SCRAPERS, so selection is time-based rather than a strict
    rotation.
    """
    if len(scraper_pool) == 0:
        logger.info(f"Initializing {MAX_SCRAPERS} cloudscraper instances...")
        # HF Space Note: Scraper creation can be slow, so all are built in one go.
        scraper_pool.extend(cloudscraper.create_scraper() for _ in range(MAX_SCRAPERS))
        logger.info("Cloudscraper pool initialized.")

    slot = int(time.monotonic() * 1000) % MAX_SCRAPERS
    return scraper_pool[slot]
# --- Security & Authentication --- | |
async def verify_api_key(
    request: Request,
    api_key: Optional[str] = Security(api_key_header)
) -> bool:
    """Authenticate a request by its Authorization header.

    Requests whose Referer starts with the configured Space URL followed by
    "/playground" or "/image-playground" are accepted without a key.
    NOTE(review): the Referer header is supplied by the client, so this bypass
    is a convenience rather than a security boundary — confirm it is acceptable.

    Args:
        request: Incoming request; only its "referer" header is read.
        api_key: Raw Authorization header value, optionally "Bearer "-prefixed.

    Returns:
        True when the request is accepted.

    Raises:
        HTTPException: 403 when the key is missing or not recognised,
            500 when no keys are configured on the server.
    """
    settings = get_env_vars()
    valid_api_keys = settings.get('api_keys', set())
    hf_space_url = settings.get('hf_space_url', '')

    # Playground pages hosted on the Space itself may call the API keyless.
    referer = request.headers.get("referer", "")
    if hf_space_url:
        trusted_prefixes = (f"{hf_space_url}/playground", f"{hf_space_url}/image-playground")
        if referer.startswith(trusted_prefixes):
            logger.debug(f"API Key check bypassed for referer: {referer}")
            return True

    if not api_key:
        logger.warning("API Key missing.")
        raise HTTPException(status_code=403, detail="Not authenticated: No API key provided")

    # Drop the 'Bearer ' scheme prefix (7 characters) when present.
    token = api_key[7:] if api_key.startswith('Bearer ') else api_key

    if not valid_api_keys:
        logger.error("API keys are not configured on the server (API_KEYS secret missing?).")
        raise HTTPException(status_code=500, detail="Server configuration error: API keys not set")

    if token not in valid_api_keys:
        # Only the first four characters are logged, never the whole key.
        logger.warning(f"Invalid API key received: {token[:4]}...")
        raise HTTPException(status_code=403, detail="Not authenticated: Invalid API key")

    logger.debug("API Key verified successfully.")
    return True
# --- Model & File Loading --- | |
def load_models_data() -> List[Dict]:
    """Load the model catalogue from models.json next to this file.

    The file is decoded as UTF-8 explicitly, so the result does not depend on
    the host locale. Any OS-level read failure (missing file, permissions,
    etc.) or malformed JSON is logged and reported as an empty list instead
    of propagating.

    Returns:
        The parsed JSON content, or [] when the file is missing, unreadable
        or not valid JSON.
    """
    # HF Space Note: Ensure models.json is in the root of your HF Space repo.
    models_file = Path(__file__).parent / 'models.json'
    if not models_file.is_file():
        logger.error("models.json not found!")
        return []
    try:
        with open(models_file, 'r', encoding='utf-8') as f:
            return json.load(f)
    except (OSError, json.JSONDecodeError) as e:
        # OSError covers FileNotFoundError (file removed after the check
        # above) as well as PermissionError and similar read failures.
        logger.error(f"Error loading models.json: {e}")
        return []
async def get_models() -> List[Dict]:
    """Return the model catalogue, failing with HTTP 500 when it is empty.

    Raises:
        HTTPException: 500 when load_models_data() yields nothing.
    """
    catalogue = load_models_data()
    if catalogue:
        return catalogue
    raise HTTPException(status_code=500, detail="Error loading available models")
# --- Static File Serving --- | |
# HF Space Note: Cache frequently accessed static files in memory. | |
def read_static_file(file_path: str) -> Optional[str]:
    """Read a UTF-8 text file located relative to this module's directory.

    Args:
        file_path: Path relative to the directory containing this file.

    Returns:
        The file content, or None when the file does not exist or cannot be
        read (the reason is logged).
    """
    target = Path(__file__).parent / file_path
    if not target.is_file():
        logger.warning(f"Static file not found: {file_path}")
        return None

    try:
        return target.read_text(encoding="utf-8")
    except Exception as e:
        logger.error(f"Error reading static file {file_path}: {e}")
        return None
async def serve_static_html(file_path: str) -> HTMLResponse:
    """Wrap a static HTML file in an HTMLResponse.

    Returns a 404 response with a short error heading when the file cannot
    be read.
    """
    page = read_static_file(file_path)
    if page is not None:
        return HTMLResponse(content=page)
    return HTMLResponse(content=f"<h1>Error: {file_path} not found</h1>", status_code=404)
# --- API Endpoints --- | |
# Basic Routes & Static Files | |
async def favicon():
    """Serve favicon.ico from this module's directory, or raise 404 if absent."""
    icon_file = Path(__file__).parent / "favicon.ico"
    if not icon_file.is_file():
        raise HTTPException(status_code=404, detail="favicon.ico not found")
    return FileResponse(icon_file, media_type="image/vnd.microsoft.icon")
async def banner():
    """Serve banner.jpg from this module's directory, or raise 404 if absent."""
    image_file = Path(__file__).parent / "banner.jpg"
    if not image_file.is_file():
        raise HTTPException(status_code=404, detail="banner.jpg not found")
    # The file is always sent with a JPEG media type.
    return FileResponse(image_file, media_type="image/jpeg")
async def ping():
    """Health check: always answers with a fixed "pong" payload."""
    reply = {"message": "pong"}
    return reply
async def root():
    """Return index.html as the landing page (404 response when missing)."""
    landing_page = await serve_static_html("index.html")
    return landing_page
async def script_js():
    """Serve script.js as JavaScript; a 404 with a JS comment body when missing."""
    source = read_static_file("script.js")
    if source is not None:
        return Response(content=source, media_type="application/javascript")
    return Response(content="/* script.js not found */", status_code=404, media_type="application/javascript")
async def style_css():
    """Serve style.css as CSS; a 404 with a CSS comment body when missing."""
    stylesheet = read_static_file("style.css")
    if stylesheet is not None:
        return Response(content=stylesheet, media_type="text/css")
    return Response(content="/* style.css not found */", status_code=404, media_type="text/css")
async def playground():
    """Return playground.html, the chat playground page (404 response when missing)."""
    page = await serve_static_html("playground.html")
    return page
async def image_playground():
    """Return image-playground.html, the image playground page (404 response when missing)."""
    page = await serve_static_html("image-playground.html")
    return page
# Dynamic Page Example | |
async def dynamic_ai_page(request: Request):
    """Ask a chat model for an HTML page and return it to the caller.

    The handler posts a prompt (containing the caller's User-Agent and IP) to
    this service's own `/chat/completions` URL, built from HF_SPACE_URL, and
    returns the generated text as HTML after removing a wrapping Markdown
    code fence if present.

    Raises:
        HTTPException: 500 when HF_SPACE_URL is empty or on any unexpected
            failure; 502 when the self-call returns an error status.
    """
    env_vars = get_env_vars()
    hf_space_url = env_vars.get('hf_space_url', '')
    if not hf_space_url:
        raise HTTPException(status_code=500, detail="HF_SPACE_URL environment variable not set.")

    user_agent = request.headers.get('user-agent', 'Unknown')
    if request.client:
        client_ip = request.client.host
    else:
        client_ip = "Unknown"
    # Only the raw IP is available; a real location would need a GeoIP lookup.
    location = f"IP: {client_ip}"

    prompt = f"""
Generate a cool, dynamic HTML page for a user with the following details:
- App Name: "LokiAI"
- User-Agent: {user_agent}
- Location Info: {location}
- Style: Cyberpunk aesthetic, minimalist layout, maybe some retro touches.
- Content: Include a heading, a short motivational or witty message, and perhaps a subtle animation. Use inline CSS for styling within a <style> tag.
- Output: Provide ONLY the raw HTML code, starting with <!DOCTYPE html>. Do not wrap it in backticks or add explanations.
"""

    payload = {
        "model": "mistral-small-latest",  # Or another capable model
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 1000,
        "temperature": 0.7
    }

    # Authenticate the self-call with the first configured key, or a
    # placeholder when no keys are configured.
    configured_keys = env_vars['api_keys']
    bearer_token = list(configured_keys)[0] if configured_keys else 'dummy-key'
    headers = {"Authorization": f"Bearer {bearer_token}"}

    try:
        client = get_async_client()
        api_url = f"{hf_space_url}/chat/completions"  # Call own endpoint
        response = await client.post(api_url, json=payload, headers=headers)
        response.raise_for_status()  # Non-2xx becomes HTTPStatusError

        data = response.json()
        first_choice = data.get('choices', [{}])[0]
        html_content = first_choice.get('message', {}).get('content', '')

        # Strip a leading ```html fence and a trailing ``` fence, if any.
        html_content = re.sub(r"^```html\s*", "", html_content, flags=re.IGNORECASE)
        html_content = re.sub(r"\s*```$", "", html_content)

        looks_like_html = html_content.strip().lower().startswith("<!doctype html")
        if not looks_like_html:
            # The content is still returned; this is only a diagnostic.
            logger.warning("Dynamo page generation might be incomplete or malformed.")

        return HTMLResponse(content=html_content)
    except httpx.HTTPStatusError as e:
        logger.error(f"Error calling self API for /dynamo: {e.response.status_code} - {e.response.text}")
        raise HTTPException(status_code=502, detail=f"Failed to generate dynamic content: Upstream API error {e.response.status_code}")
    except Exception as e:
        logger.error(f"Unexpected error in /dynamo: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to generate dynamic content due to an internal error.")
# Vetra Example (Fetching from GitHub) | |
# HF Space Note: Ensure outbound requests to raw.githubusercontent.com are allowed. | |
# Base URL of the raw files in the Vetra repository (main branch).
GITHUB_BASE = "https://raw.githubusercontent.com/Parthsadaria/Vetra/main"
# Component kind -> file name within the Vetra repository.
VETRA_FILES = {
    "html": "index.html",
    "css": "style.css",
    "js": "script.js",
}
async def get_github_file(filename: str) -> Optional[str]:
    """Download one file from the Vetra GitHub repository.

    Args:
        filename: File name appended to GITHUB_BASE.

    Returns:
        The response body as text, or None when the request fails or the
        server answers with an error status (both cases are logged).
    """
    url = f"{GITHUB_BASE}/{filename}"
    try:
        client = get_async_client()
        reply = await client.get(url)
        reply.raise_for_status()
        return reply.text
    except httpx.HTTPStatusError as e:
        # The server answered, but with a 4xx/5xx status.
        logger.error(f"GitHub file {url} returned status {e.response.status_code}")
        return None
    except httpx.RequestError as e:
        # The request itself failed (connection, timeout, ...).
        logger.error(f"Error fetching GitHub file {url}: {e}")
        return None
async def serve_vetra():
    """Assemble the Vetra page from its HTML, CSS and JS files on GitHub.

    The three files are fetched concurrently. The CSS is inlined before the
    first `</head>` and the JS before the first `</body>`; a placeholder
    comment is inlined for whichever of the two failed to load. A missing
    HTML file yields a 502 error page.
    """
    logger.info("Fetching Vetra files from GitHub...")

    # Fetch all three components concurrently.
    html, css, js = await asyncio.gather(
        *(get_github_file(VETRA_FILES[kind]) for kind in ("html", "css", "js"))
    )

    if not html:
        logger.error("Failed to fetch Vetra index.html")
        return HTMLResponse(content="<h1>Error: Could not load Vetra application (HTML missing)</h1>", status_code=502)

    style_tag = "<style>" + (css or '/* CSS failed to load */') + "</style>"
    script_tag = "<script>" + (js or '// JS failed to load') + "</script>"

    # Only the first occurrence of each closing tag is touched.
    final_html = (
        html.replace("</head>", f"{style_tag}\n</head>", 1)
            .replace("</body>", f"{script_tag}\n</body>", 1)
    )

    logger.info("Successfully served Vetra application.")
    return HTMLResponse(content=final_html)
# Model Info Endpoint | |
async def return_models():
    """Return the model list produced by get_models().

    The list comes from models.json only; models that exist solely in this
    module's model sets are not added. get_models() raises HTTP 500 when the
    list cannot be loaded.
    """
    models = await get_models()
    return models
# Search Endpoint (using cloudscraper) | |
# HF Space Note: This uses cloudscraper which might be blocked or require updates. | |
# Consider replacing with a more stable search API if possible. | |
# Strong references to running producer tasks. The event loop only keeps weak
# references to tasks, so an otherwise unreferenced task may be collected
# before it finishes.
_search_tasks: Set[asyncio.Task] = set()


def _search_sse_chunk(source: Dict[str, Any], delta: Dict[str, Any], finish_reason: Optional[str]) -> str:
    """Format one `chat.completion.chunk` SSE event for the searchgpt model.

    `id` and `created` are copied from the upstream chunk `source`; `created`
    falls back to the current time when upstream omits it.
    """
    chunk = {
        "id": source.get("id"),
        "object": "chat.completion.chunk",
        "created": source.get("created", int(time.time())),
        "model": "searchgpt",
        "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}]
    }
    return f"data: {json.dumps(chunk)}\n\n"


async def generate_search_async(query: str, systemprompt: Optional[str] = None) -> asyncio.Queue:
    """Start a background search and return the queue its results arrive on.

    Queue items are dicts: {"data": <SSE event string>, "text": <content>}
    for each streamed piece, or {"error": <message>} on failure. A final None
    always marks the end of the stream.

    Fixes relative to the previous implementation:
    - `loop.run_in_executor` does not accept keyword arguments, so the
      keyword-style `scraper.post(...)` call raised TypeError on every
      search; the call is now made inside a plain function.
    - The blocking `iter_lines()` loop ran on the event loop thread and
      stalled all other requests while streaming; the whole blocking
      request/stream now runs in the executor thread and hands items to the
      queue with `call_soon_threadsafe`.
    - Lines are decoded when the HTTP library yields bytes (this happens
      when the response declares no encoding).
    - The upstream response is always closed.
    - The producer task is kept referenced until it completes.

    Args:
        query: The user's search query.
        systemprompt: Optional system message; a default assistant prompt is
            used when omitted or empty.

    Returns:
        The asyncio.Queue that the background producer fills.
    """
    queue: asyncio.Queue = asyncio.Queue()
    env_vars = get_env_vars()
    search_endpoint = env_vars.get('secret_api_endpoint_3')

    async def _fetch_search_data():
        if not search_endpoint:
            await queue.put({"error": "Search API endpoint (SECRET_API_ENDPOINT_3) not configured"})
            await queue.put(None)  # Signal end
            return

        loop = asyncio.get_running_loop()

        def _emit(item: Dict[str, Any]) -> None:
            # asyncio.Queue is not thread-safe: schedule the put on the loop.
            loop.call_soon_threadsafe(queue.put_nowait, item)

        def _stream_in_thread(scraper) -> None:
            """Blocking part: POST to the backend and forward its SSE lines."""
            system_message = systemprompt or "You are a helpful search assistant."
            payload = {
                "model": "searchgpt",  # Assuming the endpoint expects this model name
                "messages": [
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": query},
                ],
                "stream": True  # Explicitly request streaming from backend
            }
            headers = {"User-Agent": "Mozilla/5.0"}  # Standard user agent

            response = scraper.post(search_endpoint, json=payload, headers=headers, stream=True)
            try:
                response.raise_for_status()
                for line in response.iter_lines(decode_unicode=True):
                    # Without a declared encoding the library yields bytes
                    # even with decode_unicode=True.
                    if isinstance(line, bytes):
                        line = line.decode("utf-8", errors="replace")
                    if not line.startswith("data: "):
                        continue
                    try:
                        data_str = line[6:]
                        if data_str.strip() == "[DONE]":  # OpenAI-style end marker
                            break
                        json_data = json.loads(data_str)
                        # Assuming OpenAI compatible streaming format
                        choice = json_data.get("choices", [{}])[0]
                        content = choice.get("delta", {}).get("content")
                        if content:
                            _emit({
                                "data": _search_sse_chunk(json_data, {"content": content}, None),
                                "text": content,
                            })
                        finish_reason = choice.get("finish_reason")
                        if finish_reason:
                            _emit({
                                "data": _search_sse_chunk(json_data, {}, finish_reason),
                                "text": "",
                            })
                            break  # Stop processing after finish reason
                    except json.JSONDecodeError:
                        logger.warning(f"Failed to decode JSON from search stream: {line}")
                        continue
                    except Exception as e:
                        logger.error(f"Error processing search stream chunk: {e}", exc_info=True)
                        _emit({"error": f"Error processing stream: {e}"})
                        break  # Stop on processing error
            finally:
                response.close()

        try:
            # Pick the scraper on the loop thread; only the blocking network
            # work is moved to the worker thread.
            scraper = get_scraper()
            # Items emitted by the worker are scheduled on the loop before the
            # executor future completes, so they are queued ahead of the
            # end-of-stream marker put in `finally` below.
            await loop.run_in_executor(executor, _stream_in_thread, scraper)
        except requests.exceptions.RequestException as e:
            logger.error(f"Search request failed: {e}")
            await queue.put({"error": f"Search request failed: {e}"})
        except Exception as e:
            logger.error(f"Unexpected error during search: {e}", exc_info=True)
            await queue.put({"error": f"An unexpected error occurred during search: {e}"})
        finally:
            await queue.put(None)  # Signal completion

    task = asyncio.create_task(_fetch_search_data())
    _search_tasks.add(task)
    task.add_done_callback(_search_tasks.discard)
    return queue
async def search_gpt(q: str, stream: bool = True, systemprompt: Optional[str] = None):
    """
    Run a search through the backend search model.

    With `stream` true (the default) the result is a server-sent event stream;
    pass `stream=false` to receive one OpenAI-style completion object instead.

    Raises:
        HTTPException: 400 when `q` is empty; 502 in non-streaming mode when
            the search reports an error.
    """
    if not q:
        raise HTTPException(status_code=400, detail="Query parameter 'q' is required")

    usage_tracker.record_request(endpoint="/searchgpt")
    queue = await generate_search_async(q, systemprompt=systemprompt)

    if not stream:
        # Non-streaming: drain the queue and join the text pieces.
        pieces = []
        while True:
            item = await queue.get()
            if item is None:
                break
            if "error" in item:
                logger.error(f"Search non-stream error: {item['error']}")
                raise HTTPException(status_code=502, detail=f"Search failed: {item['error']}")
            pieces.append(item.get("text", ""))
        full_response_text = "".join(pieces)

        # Mimic OpenAI non-streaming response structure
        return JSONResponse(content={
            "id": f"search-{int(time.time())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "searchgpt",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": full_response_text,
                },
                "finish_reason": "stop",
            }],
            "usage": {  # Token usage is unknown here
                "prompt_tokens": None,
                "completion_tokens": None,
                "total_tokens": None,
            }
        })

    async def stream_generator():
        """Relay queued SSE events until the end marker or the first error."""
        while True:
            item = await queue.get()
            if item is None:  # End of stream signal
                return
            if "error" in item:
                # Details stay in the server log; the client gets a generic event.
                logger.error(f"Search stream error: {item['error']}")
                error_event = {"error": {"message": "Search failed.", "code": 500}}
                yield f"data: {json.dumps(error_event)}\n\n"
                return
            if "data" in item:
                yield item["data"]

    sse_headers = {
        "Content-Type": "text/event-stream",
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no"  # Disable proxy buffering so events flush promptly
    }
    return StreamingResponse(
        stream_generator(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
# Main Chat Completions Proxy | |
async def get_completion( | |
payload: Payload, | |
request: Request, | |
authenticated: bool = Depends(verify_api_key) # Apply authentication | |
): | |
""" | |
Proxies chat completion requests to the appropriate backend API based on the model. | |
Supports streaming (SSE). | |
""" | |
if not server_status: | |
raise HTTPException(status_code=503, detail="Server is under maintenance.") | |
model_to_use = payload.model or "gpt-4o-mini" # Default model | |
# HF Space Note: Check against models loaded at startup. | |
if available_model_ids and model_to_use not in available_model_ids: | |
logger.warning(f"Requested model '{model_to_use}' not in available list.") | |
# Check if it's a known category even if not explicitly in models.json | |
known_categories = mistral_models | pollinations_models | alternate_models | claude_3_models | |
if model_to_use not in known_categories: | |
raise HTTPException( | |
status_code=400, | |
detail=f"Model '{model_to_use}' is not available or recognized. Check /models." | |
) | |
else: | |
logger.info(f"Allowing known category model '{model_to_use}' despite not being in models.json.") | |
# Log request asynchronously | |
asyncio.create_task(log_request(request, model_to_use)) | |
usage_tracker.record_request(model=model_to_use, endpoint="/chat/completions") | |
# Prepare payload for the target API | |
payload_dict = payload.dict(exclude_none=True) # Exclude None values | |
payload_dict["model"] = model_to_use # Ensure model is set | |
env_vars = get_env_vars() | |
hf_space_url = env_vars.get('hf_space_url', '') # Needed for Referer/Origin | |
# Determine target endpoint and headers | |
endpoint = None | |
custom_headers = {} | |
if model_to_use in mistral_models: | |
endpoint = env_vars.get('mistral_api') | |
api_key = env_vars.get('mistral_key') | |
if not endpoint or not api_key: | |
raise HTTPException(status_code=500, detail="Mistral API endpoint or key not configured.") | |
custom_headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json", "Accept": "application/json"} | |
# Mistral specific adjustments if needed | |
# payload_dict.pop('system', None) # Example: if Mistral doesn't use 'system' role | |
elif model_to_use in pollinations_models: | |
endpoint = env_vars.get('secret_api_endpoint_4') | |
if not endpoint: | |
raise HTTPException(status_code=500, detail="Pollinations API endpoint (SECRET_API_ENDPOINT_4) not configured.") | |
# Pollinations might need specific headers? Add them here. | |
custom_headers = {"Content-Type": "application/json"} | |
elif model_to_use in alternate_models: | |
endpoint = env_vars.get('secret_api_endpoint_2') | |
if not endpoint: | |
raise HTTPException(status_code=500, detail="Alternate API endpoint (SECRET_API_ENDPOINT_2) not configured.") | |
custom_headers = {"Content-Type": "application/json"} | |
elif model_to_use in claude_3_models: | |
endpoint = env_vars.get('secret_api_endpoint_5') | |
if not endpoint: | |
raise HTTPException(status_code=500, detail="Claude 3 API endpoint (SECRET_API_ENDPOINT_5) not configured.") | |
custom_headers = {"Content-Type": "application/json"} | |
# Claude specific headers (like anthropic-version) might be needed | |
# custom_headers["anthropic-version"] = "2023-06-01" | |
else: # Default endpoint | |
endpoint = env_vars.get('secret_api_endpoint') | |
if not endpoint: | |
raise HTTPException(status_code=500, detail="Default API endpoint (SECRET_API_ENDPOINT) not configured.") | |
# Default endpoint might need Origin/Referer | |
if hf_space_url: | |
custom_headers = { | |
"Origin": hf_space_url, | |
"Referer": hf_space_url, | |
"Content-Type": "application/json" | |
} | |
else: | |
custom_headers = {"Content-Type": "application/json"} | |
target_url = f"{endpoint.rstrip('/')}/v1/chat/completions" # Assume OpenAI compatible path | |
logger.info(f"Proxying request for model '{model_to_use}' to endpoint: {endpoint}") | |
client = get_async_client() | |
async def stream_generator(): | |
"""Generator for streaming the response.""" | |
nonlocal target_url # Allow modification if needed | |
try: | |
async with client.stream("POST", target_url, json=payload_dict, headers=custom_headers) as response: | |
# Check for initial errors before streaming | |
if response.status_code >= 400: | |
error_body = await response.aread() | |
logger.error(f"Upstream API error: {response.status_code} - {error_body.decode()}") | |
# Try to parse error detail from upstream | |
detail = f"Upstream API error: {response.status_code}" | |
try: | |
error_json = json.loads(error_body) | |
detail = error_json.get('error', {}).get('message', detail) | |
except json.JSONDecodeError: | |
pass | |
# Send error as SSE event | |
error_event = {"error": {"message": detail, "code": response.status_code}} | |
yield f"data: {json.dumps(error_event)}\n\n" | |
return # Stop generation | |
# Stream the response line by line | |
async for line in response.aiter_lines(): | |
if line: | |
# Pass through the data directly | |
yield line + "\n" | |
# Ensure stream is properly closed, yield [DONE] if backend doesn't | |
# Some backends might not send [DONE], uncomment if needed | |
# yield "data: [DONE]\n\n" | |
except httpx.TimeoutException: | |
logger.error(f"Request to {target_url} timed out.") | |
error_event = {"error": {"message": "Request timed out", "code": 504}} | |
yield f"data: {json.dumps(error_event)}\n\n" | |
except httpx.RequestError as e: | |
logger.error(f"Failed to connect to upstream API {target_url}: {e}") | |
error_event = {"error": {"message": f"Upstream connection error: {e}", "code": 502}} | |
yield f"data: {json.dumps(error_event)}\n\n" | |
except Exception as e: | |
logger.error(f"An unexpected error occurred during streaming proxy: {e}", exc_info=True) | |
error_event = {"error": {"message": f"Internal server error: {e}", "code": 500}} | |
yield f"data: {json.dumps(error_event)}\n\n" | |
if payload.stream: | |
return StreamingResponse( | |
stream_generator(), | |
media_type="text/event-stream", | |
headers={ | |
"Content-Type": "text/event-stream", | |
"Cache-Control": "no-cache", | |
"Connection": "keep-alive", | |
"X-Accel-Buffering": "no" # Essential for HF Spaces proxying SSE | |
} | |
) | |
else: | |
# Handle non-streaming request by collecting the streamed chunks | |
full_response_content = "" | |
final_json_response = None | |
async for line in stream_generator(): | |
if line.startswith("data: "): | |
data_str = line[6:].strip() | |
if data_str == "[DONE]": | |
break | |
try: | |
chunk = json.loads(data_str) | |
# Check for error chunk | |
if "error" in chunk: | |
logger.error(f"Received error during non-stream collection: {chunk['error']}") | |
raise HTTPException(status_code=chunk['error'].get('code', 502), detail=chunk['error'].get('message', 'Upstream API error')) | |
# Accumulate content from delta | |
delta = chunk.get("choices", [{}])[0].get("delta", {}) | |
content = delta.get("content") | |
if content: | |
full_response_content += content | |
# Store the last chunk structure to reconstruct the final response | |
# We assume the last chunk contains necessary info like id, model, etc. | |
# but we overwrite the choices/message part. | |
final_json_response = chunk # Keep the structure | |
# Check for finish reason | |
finish_reason = chunk.get("choices", [{}])[0].get("finish_reason") | |
if finish_reason: | |
break # Stop collecting | |
except json.JSONDecodeError: | |
logger.warning(f"Could not decode JSON chunk in non-stream mode: {data_str}") | |
except Exception as e: | |
logger.error(f"Error processing chunk in non-stream mode: {e}") | |
raise HTTPException(status_code=500, detail="Error processing response stream.") | |
if final_json_response is None: | |
# Handle cases where no valid data chunks were received | |
logger.error("No valid response chunks received for non-streaming request.") | |
raise HTTPException(status_code=502, detail="Failed to get valid response from upstream API.") | |
# Reconstruct OpenAI-like non-streaming response | |
final_response_obj = { | |
"id": final_json_response.get("id", f"chatcmpl-{int(time.time())}"), | |
"object": "chat.completion", | |
"created": final_json_response.get("created", int(time.time())), | |
"model": model_to_use, # Use the requested model | |
"choices": [{ | |
"index": 0, | |
"message": { | |
"role": "assistant", | |
"content": full_response_content, | |
}, | |
"finish_reason": final_json_response.get("choices", [{}])[0].get("finish_reason", "stop"), # Get finish reason from last chunk | |
}], | |
"usage": { # Token usage might be in the last chunk for some APIs, otherwise unknown | |
"prompt_tokens": None, | |
"completion_tokens": None, | |
"total_tokens": None, | |
} | |
} | |
# Attempt to extract usage if present in the (potentially non-standard) final chunk | |
usage_data = final_json_response.get("usage") | |
if isinstance(usage_data, dict): | |
final_response_obj["usage"].update(usage_data) | |
return JSONResponse(content=final_response_obj) | |
# Image Generation Endpoint | |
async def create_image(
    payload: ImageGenerationPayload,
    authenticated: bool = Depends(verify_api_key)
):
    """
    Generates images based on a text prompt using the configured backend.

    The request is validated against ``supported_image_models``, recorded with
    the usage tracker, and forwarded as JSON to the URL stored under
    ``new_img_endpoint`` in the environment configuration. The backend's JSON
    body is returned unchanged.

    Args:
        payload: Image request; ``model``, ``prompt``, ``n`` and ``size`` are read.
        authenticated: Result of the ``verify_api_key`` dependency (not read here).

    Raises:
        HTTPException: 503 in maintenance mode, 400 for an unsupported model,
            500 when the endpoint is not configured or on an unexpected error,
            504 on timeout, 502 on connection errors, and the upstream status
            code when the backend answers with an HTTP error.
    """
    if not server_status:
        raise HTTPException(status_code=503, detail="Server is under maintenance.")
    if payload.model not in supported_image_models:
        raise HTTPException(
            status_code=400,
            detail=f"Model '{payload.model}' is not supported for image generation. Supported: {', '.join(supported_image_models)}"
        )
    usage_tracker.record_request(model=payload.model, endpoint="/images/generations")
    env_vars = get_env_vars()
    target_api_url = env_vars.get('new_img_endpoint')
    if not target_api_url:
        raise HTTPException(status_code=500, detail="Image generation endpoint (NEW_IMG) not configured.")
    # Prepare payload for the target API (adjust keys if needed)
    # HF Space Note: Ensure the keys match the actual API expected by NEW_IMG endpoint.
    # Assuming it's OpenAI compatible here.
    api_payload = {
        "model": payload.model,
        "prompt": payload.prompt,
        "n": payload.n,
        "size": payload.size
    }
    # Remove None values the target API might not like
    api_payload = {k: v for k, v in api_payload.items() if v is not None}
    logger.info(f"Requesting image generation for model '{payload.model}' from {target_api_url}")
    client = get_async_client()
    try:
        # HF Space Note: Image generation can take time, use a longer timeout if needed.
        response = await client.post(target_api_url, json=api_payload, timeout=120.0)  # 2 min timeout
        response.raise_for_status()  # Raise HTTP errors
        # Return the exact response from the backend
        return JSONResponse(content=response.json())
    except httpx.TimeoutException as e:
        logger.error(f"Image generation request to {target_api_url} timed out.")
        raise HTTPException(status_code=504, detail="Image generation request timed out.") from e
    except httpx.HTTPStatusError as e:
        logger.error(f"Image generation API error: {e.response.status_code} - {e.response.text}")
        detail = f"Image generation failed: Upstream API error {e.response.status_code}"
        try:
            err_json = e.response.json()
        except ValueError:
            # Body is not JSON (json.JSONDecodeError is a ValueError); keep the generic detail.
            err_json = None
        # Only dig into the body when it has the expected shape. An error body
        # that is a list, or whose "error" is a plain string, must not raise
        # AttributeError from inside this handler.
        if isinstance(err_json, dict):
            error_field = err_json.get('error')
            if isinstance(error_field, dict):
                detail = error_field.get('message', detail)
            elif isinstance(error_field, str) and error_field:
                detail = error_field
        raise HTTPException(status_code=e.response.status_code, detail=detail) from e
    except httpx.RequestError as e:
        logger.error(f"Error connecting to image generation service {target_api_url}: {e}")
        raise HTTPException(status_code=502, detail=f"Error connecting to image generation service: {e}") from e
    except Exception as e:
        logger.error(f"Unexpected error during image generation: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during image generation: {e}") from e
# --- Utility & Admin Endpoints --- | |
async def log_request(request: Request, model: Optional[str] = None):
    """Logs basic request information asynchronously.

    Writes one INFO line containing a UTC timestamp, a bucketed hash of the
    client address, the HTTP method, the URL path and, when given, the model.

    Args:
        request: Incoming request; ``client``, ``method`` and ``url.path`` are read.
        model: Optional model name appended to the log line.
    """
    # Local import so this block does not depend on the file-level import list.
    import hashlib

    # HF Space Note: Avoid logging sensitive info like full IP or headers unless necessary.
    client_host = request.client.host if request.client else "unknown"
    # The built-in hash() of a str is salted per interpreter process, so the
    # same client would get a different bucket in every worker and after every
    # restart. A SHA-256 based bucket keeps the 0-9999 range but is stable.
    digest = hashlib.sha256(client_host.encode("utf-8")).digest()
    ip_hash = int.from_bytes(digest[:8], "big") % 10000
    timestamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z")
    log_message = f"Timestamp: {timestamp}, IP Hash: {ip_hash}, Method: {request.method}, Path: {request.url.path}"
    if model:
        log_message += f", Model: {model}"
    logger.info(log_message)
async def get_usage(days: int = 7):
    """Retrieves aggregated usage statistics.

    Args:
        days: Size of the reporting window passed to the usage tracker.

    Returns:
        Whatever ``usage_tracker.get_usage_summary(days)`` produces.

    Raises:
        HTTPException: 400 when ``days`` is not positive, 500 when building
            the summary fails.
    """
    # Guard clause: a zero or negative window is a client error.
    if days <= 0:
        raise HTTPException(status_code=400, detail="Number of days must be positive.")
    try:
        # The summary is computed on the shared thread pool so the event loop
        # stays free while it runs.
        pending_summary = asyncio.get_running_loop().run_in_executor(
            executor, usage_tracker.get_usage_summary, days
        )
        return await pending_summary
    except Exception as exc:
        logger.error(f"Error retrieving usage statistics: {exc}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to retrieve usage statistics.")
# HF Space Note: Generating HTML dynamically can be resource-intensive. | |
# Consider caching the generated HTML or serving a static page updated periodically. | |
def generate_usage_html(usage_data: Dict) -> str:
    """Generates an HTML report from usage data.

    Every value taken from ``usage_data`` is HTML-escaped before it is placed
    in the page, so names containing markup are rendered as text.

    Args:
        usage_data: Summary mapping. The keys read are ``models``,
            ``api_endpoints``, ``recent_daily_usage``, ``total_requests`` and
            ``days_analyzed``; each may be missing (or ``None`` for the three
            section mappings), in which case a placeholder row or default is used.

    Returns:
        The complete HTML document as a string.
    """
    # Local import so this block does not depend on the file-level import list.
    from html import escape

    def _cell(value: object) -> str:
        """Render one value as HTML-safe table-cell text."""
        return escape(str(value))

    def _stats_rows(stats: Dict, empty_message: str) -> str:
        """Build the four-column rows shared by the model and endpoint tables."""
        if not stats:
            return f"<tr><td colspan='4'>{empty_message}</td></tr>"
        return "\n".join(
            "<tr>"
            f"<td>{_cell(name)}</td>"
            f"<td>{_cell(entry.get('total_requests', 'N/A'))}</td>"
            f"<td>{_cell(entry.get('first_used', 'N/A'))}</td>"
            f"<td>{_cell(entry.get('last_used', 'N/A'))}</td>"
            "</tr>"
            for name, entry in stats.items()
        )

    # "or {}" also covers sections that are present but set to None.
    models_usage = usage_data.get('models') or {}
    endpoints_usage = usage_data.get('api_endpoints') or {}
    daily_usage = usage_data.get('recent_daily_usage') or {}
    total_requests = _cell(usage_data.get('total_requests', 0))
    days_analyzed = _cell(usage_data.get('days_analyzed', 7))

    model_usage_rows = _stats_rows(models_usage, "No model usage data")
    api_usage_rows = _stats_rows(endpoints_usage, "No API endpoint usage data")

    # One row per (date, entity) pair; dates whose mapping is empty yield no rows.
    daily_rows = [
        f"<tr><td>{_cell(date)}</td><td>{_cell(entity)}</td><td>{_cell(count)}</td></tr>"
        for date, date_data in daily_usage.items()
        for entity, count in date_data.items()
    ]
    if daily_rows:
        daily_usage_rows = "\n".join(daily_rows)
    else:
        daily_usage_rows = "<tr><td colspan='3'>No daily usage data</td></tr>"

    # HF Space Note: Using f-string for large HTML is okay, but consider template engines (Jinja2)
    # for more complex pages. Literal CSS braces are doubled because this is an f-string.
    html_content = f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Lokiai AI - Usage Statistics</title>
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600&display=swap" rel="stylesheet">
<style>
:root {{
--bg-dark: #0f1011; --bg-darker: #070708; --text-primary: #e6e6e6;
--text-secondary: #8c8c8c; --border-color: #2c2c2c; --accent-color: #3a6ee0;
--accent-hover: #4a7ef0;
}}
body {{ font-family: 'Inter', sans-serif; background-color: var(--bg-dark); color: var(--text-primary); max-width: 1200px; margin: 0 auto; padding: 40px 20px; line-height: 1.6; }}
.logo {{ display: flex; align-items: center; justify-content: center; margin-bottom: 30px; }}
.logo h1 {{ font-weight: 600; font-size: 2.5em; color: var(--text-primary); margin-left: 15px; }}
.logo img {{ width: 60px; height: 60px; border-radius: 10px; }}
.container {{ background-color: var(--bg-darker); border-radius: 12px; padding: 30px; box-shadow: 0 15px 40px rgba(0,0,0,0.3); border: 1px solid var(--border-color); }}
h2, h3 {{ color: var(--text-primary); border-bottom: 2px solid var(--border-color); padding-bottom: 10px; font-weight: 500; }}
.total-requests {{ background-color: var(--accent-color); color: white; text-align: center; padding: 15px; border-radius: 8px; margin-bottom: 30px; font-weight: 600; letter-spacing: -0.5px; }}
table {{ width: 100%; border-collapse: separate; border-spacing: 0; margin-bottom: 30px; background-color: var(--bg-dark); border-radius: 8px; overflow: hidden; }}
th, td {{ border: 1px solid var(--border-color); padding: 12px; text-align: left; transition: background-color 0.3s ease; }}
th {{ background-color: #1e1e1e; color: var(--text-primary); font-weight: 600; text-transform: uppercase; font-size: 0.9em; }}
tr:nth-child(even) {{ background-color: rgba(255,255,255,0.05); }}
tr:hover {{ background-color: rgba(62,100,255,0.1); }}
@media (max-width: 768px) {{ .container {{ padding: 15px; }} table {{ font-size: 0.9em; }} }}
</style>
</head>
<body>
<div class="container">
<div class="logo">
<img src="" alt="Lokai AI Logo">
<h1>Lokiai AI Usage</h1>
</div>
<div class="total-requests">
Total API Requests Recorded: {total_requests}
</div>
<h2>Model Usage</h2>
<table>
<thead><tr><th>Model</th><th>Total Requests</th><th>First Used</th><th>Last Used</th></tr></thead>
<tbody>{model_usage_rows}</tbody>
</table>
<h2>API Endpoint Usage</h2>
<table>
<thead><tr><th>Endpoint</th><th>Total Requests</th><th>First Used</th><th>Last Used</th></tr></thead>
<tbody>{api_usage_rows}</tbody>
</table>
<h2>Daily Usage (Last {days_analyzed} Days)</h2>
<table>
<thead><tr><th>Date</th><th>Entity (Model/Endpoint)</th><th>Requests</th></tr></thead>
<tbody>{daily_usage_rows}</tbody>
</table>
</div>
</body>
</html>
"""
    return html_content
# HF Space Note: Caching the generated HTML page can save resources.
# Invalidate cache periodically or when usage data changes significantly.
# "content" is the last rendered page (None until one exists); "timestamp" is
# the clock reading stored together with it.
usage_html_cache = dict(content=None, timestamp=0)
CACHE_DURATION = 5 * 60  # Cache usage page for 5 minutes (expressed in seconds)
async def usage_page():
    """Serves an HTML page showing usage statistics.

    A cached page younger than ``CACHE_DURATION`` seconds is returned as-is.
    Otherwise a 7-day summary is fetched and rendered on the thread pool and
    stored in the cache. If rendering fails, the previous page is served when
    one exists; with no cached page a 500 ``HTTPException`` is raised.
    """
    now = time.monotonic()
    cached_html = usage_html_cache["content"]
    is_fresh = bool(cached_html) and (now - usage_html_cache["timestamp"] < CACHE_DURATION)
    if is_fresh:
        logger.info("Serving cached usage page.")
        return HTMLResponse(content=cached_html)

    logger.info("Generating fresh usage page.")
    try:
        # Both the summary lookup and the HTML rendering run off the event loop.
        loop = asyncio.get_running_loop()
        summary = await loop.run_in_executor(executor, usage_tracker.get_usage_summary, 7)  # Get data for 7 days
        rendered = await loop.run_in_executor(executor, generate_usage_html, summary)
        # Refresh the cache; the timestamp is the one taken on entry.
        usage_html_cache["content"] = rendered
        usage_html_cache["timestamp"] = now
        return HTMLResponse(content=rendered)
    except Exception as exc:
        logger.error(f"Failed to generate usage page: {exc}", exc_info=True)
        # Re-read the cache here: it may have been filled while this call was awaiting.
        stale_html = usage_html_cache["content"]
        if not stale_html:
            raise HTTPException(status_code=500, detail="Failed to generate usage statistics page.")
        logger.warning("Serving stale usage page due to generation error.")
        return HTMLResponse(content=stale_html)
# Meme Endpoint | |
async def get_meme():
    """Fetches a random meme and streams the image.

    Asks ``https://meme-api.com/gimme`` for a meme, reads the ``url`` field of
    the JSON reply and relays the image bytes from that URL to the caller.

    Returns:
        StreamingResponse carrying the image bytes, with the upstream
        ``content-type`` (``image/png`` when the header is absent).

    Raises:
        HTTPException: 502 when the API returns no usable URL, on an upstream
            HTTP error or on a network error; 500 on any other failure.
    """
    # HF Space Note: Ensure meme-api.com is accessible from the HF Space network.
    client = get_async_client()
    meme_api_url = "https://meme-api.com/gimme"
    try:
        logger.info("Fetching meme info...")
        response = await client.get(meme_api_url)
        response.raise_for_status()
        response_data = response.json()
        meme_url = response_data.get("url")
        if not meme_url or not isinstance(meme_url, str):
            logger.error(f"Invalid meme URL received from API: {meme_url}")
            raise HTTPException(status_code=502, detail="Failed to get valid meme URL from API.")
        logger.info(f"Fetching meme image: {meme_url}")
        # Open the image stream manually instead of with `async with client.stream(...)`:
        # returning from inside that context manager would close the upstream
        # response before the StreamingResponse body is ever iterated.
        image_request = client.build_request("GET", meme_url)
        image_response = await client.send(image_request, stream=True)
        try:
            image_response.raise_for_status()  # Check if image URL is valid
        except Exception:
            # Nothing will consume this stream, so release the connection now.
            await image_response.aclose()
            raise
        # Get content type, default to image/png
        media_type = image_response.headers.get("content-type", "image/png")
        if not media_type.startswith("image/"):
            # Non-image content is still relayed; only a warning is logged.
            logger.warning(f"Unexpected content type '{media_type}' for meme URL: {meme_url}")

        async def relay_image_bytes():
            """Yield the upstream bytes and close the upstream response afterwards."""
            try:
                async for chunk in image_response.aiter_bytes():
                    yield chunk
            finally:
                await image_response.aclose()

        # Stream the image content directly
        return StreamingResponse(
            relay_image_bytes(),
            media_type=media_type,
            headers={'Cache-Control': 'no-cache'}  # Don't cache the meme itself heavily
        )
    except HTTPException:
        # Deliberate HTTP errors raised above must keep their own status code
        # instead of being rewritten to 500 by the generic handler below.
        raise
    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP error fetching meme ({e.request.url}): {e.response.status_code}")
        raise HTTPException(status_code=502, detail=f"Failed to fetch meme (HTTP {e.response.status_code})") from e
    except httpx.RequestError as e:
        logger.error(f"Network error fetching meme ({e.request.url}): {e}")
        raise HTTPException(status_code=502, detail="Failed to fetch meme (Network Error)") from e
    except Exception as e:
        logger.error(f"Unexpected error fetching meme: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to retrieve meme due to an internal error.") from e
# Health Check Endpoint | |
async def health_check():
    """Provides a health check status, including missing critical configurations.

    A configuration value counts as missing when it is ``None`` or an empty
    string/bytes/list/tuple/set/frozenset/dict. The Mistral settings are only
    required when at least one available model id is a Mistral model.

    Returns:
        JSONResponse with status 200 when nothing is missing and the server is
        online, otherwise 503. The body lists the missing variable names.
    """
    env_vars = get_env_vars()
    # Define critical vars needed for core functionality
    critical_vars = [
        'api_keys', 'secret_api_endpoint', 'secret_api_endpoint_2',
        'secret_api_endpoint_3', 'secret_api_endpoint_4', 'secret_api_endpoint_5',
        'new_img_endpoint', 'hf_space_url'
    ]
    # Conditionally critical vars
    if any(model in mistral_models for model in available_model_ids):
        critical_vars.extend(['mistral_api', 'mistral_key'])

    # Container/text types whose empty value means "not configured". Numbers
    # and booleans are deliberately excluded so that 0/False are not reported.
    emptiable_types = (str, bytes, list, tuple, set, frozenset, dict)
    missing_critical_vars = []
    for var_name in critical_vars:
        value = env_vars.get(var_name)
        if value is None or (isinstance(value, emptiable_types) and not value):
            missing_critical_vars.append(var_name)

    is_healthy = not missing_critical_vars and server_status
    status_code = 200 if is_healthy else 503  # Service Unavailable if unhealthy
    health_status = {
        "status": "healthy" if is_healthy else "unhealthy",
        "server_mode": "online" if server_status else "maintenance",
        "missing_critical_env_vars": missing_critical_vars,
        # The trailing emoji were mis-decoded in the previous text; restored here.
        "details": "All critical configurations seem okay. Ready to roll! 🚀" if is_healthy else "Service issues detected. Check missing env vars or server status. 🛠️"
    }
    return JSONResponse(content=health_status, status_code=status_code)
# --- Startup and Shutdown Events --- | |
async def startup_event():
    """Tasks to run when the application starts.

    Rebuilds the global ``available_model_ids`` from models.json plus the
    predefined model sets, initializes the scraper pool on the thread pool,
    and runs the health check, logging a warning for any configuration it
    reports as missing.
    """
    global available_model_ids
    logger.info("Application startup sequence initiated...")
    # Load models from JSON
    models_from_file = load_models_data()
    model_ids_from_file = {model['id'] for model in models_from_file if 'id' in model}
    # Combine models from file and predefined sets
    predefined_model_sets = mistral_models | pollinations_models | alternate_models | claude_3_models
    all_model_ids = model_ids_from_file.union(predefined_model_sets)
    available_model_ids = sorted(all_model_ids)  # Keep as sorted list
    logger.info(f"Loaded {len(model_ids_from_file)} models from models.json.")
    logger.info(f"Total {len(available_model_ids)} unique models available.")
    # Initialize scraper pool (can take time)
    # HF Space Note: Run potentially blocking I/O in executor during startup
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(executor, get_scraper)  # This initializes the pool
    # Validate critical environment variables and log warnings
    logger.info("Checking critical environment variables (Secrets)...")
    health_response = await health_check()
    # health_check() only builds a response; it does not log. Decode its JSON
    # body here so that missing configuration actually shows up in the logs.
    try:
        health_report = json.loads(health_response.body)
    except (AttributeError, TypeError, ValueError):
        health_report = {}
    if not isinstance(health_report, dict):
        health_report = {}
    missing_vars = health_report.get("missing_critical_env_vars") or []
    if missing_vars:
        logger.warning(
            f"Missing critical environment variables: {', '.join(str(name) for name in missing_vars)}"
        )
    elif health_report.get("status") == "unhealthy":
        logger.warning(f"Health check reports unhealthy state: {health_report.get('server_mode')}")
    # Pre-connecting the async client is optional; httpx opens connections on demand.
    logger.info("Startup complete. Server is ready to accept requests.")
async def shutdown_event():
    """Tasks to run when the application shuts down.

    Closes the shared httpx client, shuts down the thread pool, clears the
    scraper pool and saves usage data, in that order. Each step is isolated:
    a failure is logged and the remaining steps still run, so an error while
    closing a resource cannot prevent usage data from being saved.
    """
    logger.info("Application shutdown sequence initiated...")
    # Close the httpx client gracefully
    try:
        client = get_async_client()
        await client.aclose()
        logger.info("HTTP client closed.")
    except Exception as e:
        logger.error(f"Failed to close HTTP client during shutdown: {e}")
    # Shutdown the thread pool executor (waits for running tasks to finish)
    try:
        executor.shutdown(wait=True)
        logger.info("Thread pool executor shut down.")
    except Exception as e:
        logger.error(f"Failed to shut down thread pool executor: {e}")
    # Clear scraper pool (optional, resources will be reclaimed anyway)
    try:
        scraper_pool.clear()
        logger.info("Scraper pool cleared.")
    except Exception as e:
        logger.error(f"Failed to clear scraper pool during shutdown: {e}")
    # Persist usage data
    # HF Space Note: Ensure file system is writable if saving locally.
    # Consider using HF Datasets or external DB for persistent storage.
    try:
        logger.info("Saving usage data...")
        usage_tracker.save_data()
        logger.info("Usage data saved.")
    except Exception as e:
        logger.error(f"Failed to save usage data during shutdown: {e}")
    logger.info("Shutdown complete.")
# --- Main Execution Block --- | |
# HF Space Note: This block is mainly for local testing. | |
# HF Spaces usually run the app using `uvicorn main:app --host 0.0.0.0 --port 7860` (or similar) | |
# defined in the README metadata or a Procfile. | |
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the file is run directly.
    import uvicorn

    logger.info("Starting server locally with uvicorn...")
    # HF Space Note: Port 7860 is the default for HF Spaces. Host 0.0.0.0 is required.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
    )