Spaces:
Running
Running
import os | |
import re | |
from dotenv import load_dotenv | |
from fastapi import FastAPI, HTTPException, Request, Depends, Security, Query | |
from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse, PlainTextResponse | |
from fastapi.security import APIKeyHeader | |
from pydantic import BaseModel | |
import httpx | |
from functools import lru_cache | |
from pathlib import Path | |
import json | |
import datetime | |
import time | |
import threading | |
from typing import Optional, Dict, List, Any, Generator | |
import asyncio | |
from starlette.status import HTTP_403_FORBIDDEN | |
import cloudscraper | |
from concurrent.futures import ThreadPoolExecutor | |
import uvloop | |
from fastapi.middleware.gzip import GZipMiddleware | |
from starlette.middleware.cors import CORSMiddleware | |
import contextlib | |
import requests | |
# Every event loop asyncio creates from here on uses uvloop's implementation.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

# Shared worker pool for blocking calls that must stay off the event loop
# (scrape_site runs cloudscraper requests on it).
executor = ThreadPoolExecutor(max_workers=16)

# Populate os.environ from a local .env file, if one exists.
load_dotenv()

# Reads the raw "Authorization" header. auto_error=False means a missing header is passed
# through as None so that verify_api_key can decide how to reject the request.
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)

from usage_tracker import UsageTracker  # project-local module

usage_tracker = UsageTracker()

app = FastAPI()

# Compress response bodies of at least 1000 bytes.
app.add_middleware(GZipMiddleware, minimum_size=1000)

# NOTE(review): wildcard origins combined with allow_credentials=True makes Starlette
# echo back any requesting Origin for credentialed requests. Consider an explicit
# origin list if cookies are ever used.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@lru_cache(maxsize=1)
def get_env_vars():
    """
    Load the service configuration from environment variables, memoized after first use.

    The dict is built once and cached with ``lru_cache``. Later changes to ``os.environ``
    are not seen until ``get_env_vars.cache_clear()`` is called. Every caller shares the
    same dict, so it must be treated as read-only.

    Returns:
        dict: configuration values keyed by name. ``api_keys`` is the comma-separated
        ``API_KEYS`` value with surrounding whitespace stripped from each entry and blank
        entries removed. An unset ``API_KEYS`` therefore yields ``[]``. The other values are
        raw strings, or None when unset. ``secret_api_endpoint_4`` and ``mistral_api`` have
        built-in defaults.
    """
    raw_api_keys = os.getenv('API_KEYS', '')
    return {
        # "key1, key2" style configuration must not produce keys with leading spaces.
        'api_keys': [key.strip() for key in raw_api_keys.split(',') if key.strip()],
        'secret_api_endpoint': os.getenv('SECRET_API_ENDPOINT'),
        'secret_api_endpoint_2': os.getenv('SECRET_API_ENDPOINT_2'),
        'secret_api_endpoint_3': os.getenv('SECRET_API_ENDPOINT_3'),
        'secret_api_endpoint_4': os.getenv('SECRET_API_ENDPOINT_4', "https://text.pollinations.ai/openai"),
        'secret_api_endpoint_5': os.getenv('SECRET_API_ENDPOINT_5'),
        'secret_api_endpoint_6': os.getenv('SECRET_API_ENDPOINT_6'),  # Gemini endpoint
        'mistral_api': os.getenv('MISTRAL_API', "https://api.mistral.ai"),
        'mistral_key': os.getenv('MISTRAL_KEY'),
        'gemini_key': os.getenv('GEMINI_KEY'),  # Gemini API key
        'endpoint_origin': os.getenv('ENDPOINT_ORIGIN'),
        'new_img': os.getenv('NEW_IMG'),  # image generation API URL
    }
# Model-name sets that get_completion uses to choose an upstream provider.
# They are tested in this order: mistral, pollinations, alternate, claude_3, gemini.
# Names in none of them fall through to SECRET_API_ENDPOINT.

# Served by the Mistral API (MISTRAL_API + MISTRAL_KEY).
mistral_models = {
    "mistral-large-latest",
    "pixtral-large-latest",
    "mistral-moderation-latest",
    "ministral-3b-latest",
    "ministral-8b-latest",
    "open-mistral-nemo",
    "mistral-small-latest",
    "mistral-saba-latest",
    "codestral-latest",
}

# Served by SECRET_API_ENDPOINT_4 (defaults to the Pollinations OpenAI-compatible URL).
pollinations_models = {
    "openai",
    "openai-large",
    "openai-fast",
    "openai-xlarge",
    "openai-reasoning",
    "qwen-coder",
    "llama",
    "mistral",
    "searchgpt",
    "deepseek",
    "claude-hybridspace",
    "deepseek-r1",
    "deepseek-reasoner",
    "llamalight",
    "gemini",
    "gemini-thinking",
    "hormoz",
    "phi",
    "phi-mini",
    "openai-audio",
    "llama-scaleway",
}

# Served by SECRET_API_ENDPOINT_2.
alternate_models = {
    "o1",
    "llama-4-scout",
    "o4-mini",
    "sonar",
    "sonar-pro",
    "sonar-reasoning",
    "sonar-reasoning-pro",
    "grok-3",
    "grok-3-fast",
    "r1-1776",
    "o3",
}

# Served by SECRET_API_ENDPOINT_5.
# NOTE(review): "grok-3" is also in alternate_models. That set is checked first, so this
# entry never affects routing.
claude_3_models = {
    "claude-3-7-sonnet",
    "claude-3-7-sonnet-thinking",
    "claude 3.5 haiku",
    "claude 3.5 sonnet",
    "o3-mini-medium",
    "o3-mini-high",
    "grok-3",
    "grok-3-thinking",
    "grok 2",
}

# Served by SECRET_API_ENDPOINT_6 with GEMINI_KEY.
gemini_models = {
    "gemini-1.5-pro",
    "gemini-1.5-flash",
    "gemini-2.0-flash-lite-preview",
    "gemini-2.0-flash",
    "gemini-2.0-flash-thinking",  # aka Reasoning
    "gemini-2.0-flash-preview-image-generation",
    "gemini-2.5-flash",
    "gemini-2.5-pro-exp",
    "gemini-exp-1206",
}

# Models accepted by create_image (forwarded to NEW_IMG).
supported_image_models = {
    "Flux Pro Ultra",
    "grok-2-aurora",
    "Flux Pro",
    "Flux Pro Ultra Raw",
    "Flux Dev",
    "Flux Schnell",
    "stable-diffusion-3-large-turbo",
    "Flux Realism",
    "stable-diffusion-ultra",
    "dall-e-3",
    "sdxl-lightning-4step",
}
class Payload(BaseModel):
    """Request body accepted by the chat-completion proxy."""

    # Model identifier; get_completion uses it to choose the upstream service.
    model: str
    # Conversation messages, forwarded upstream as-is. Their structure is not validated here.
    messages: list
    # When True the upstream reply is relayed as an event stream; otherwise as one JSON body.
    stream: bool = False
class ImageGenerationPayload(BaseModel):
    """Request body accepted by the image-generation proxy."""

    # Must be one of supported_image_models.
    model: str
    # Text description of the image to generate.
    prompt: str
    # Requested dimensions, forwarded verbatim. NOTE(review): assumes every model supports it.
    size: str = "1024x1024"
    # Number of images; sent upstream under the key "n".
    number: int = 1
# Maintenance switch: while False, get_completion and create_image answer 503.
server_status: bool = True
# Model IDs that get_completion accepts. While the list is empty the check is skipped.
# Presumably populated elsewhere at startup — TODO confirm.
available_model_ids: List[str] = []
def get_async_client():
    """
    Build a new ``httpx.AsyncClient``.

    The client uses a 60 s timeout, allows up to 200 connections, and keeps up to 50 of
    them alive. A fresh client is created on every call (nothing is memoized), so the
    caller owns it and is responsible for closing it,
    e.g. ``async with get_async_client() as client: ...``.
    """
    pool_limits = httpx.Limits(max_keepalive_connections=50, max_connections=200)
    return httpx.AsyncClient(timeout=60.0, limits=pool_limits)
scraper_pool = []  # filled with MAX_SCRAPERS cloudscraper sessions on first use
MAX_SCRAPERS = 20
# get_scraper() runs on executor threads (scrape_site calls it inside run_in_executor).
# This lock stops two first callers from both filling the pool, and keeps the
# round-robin cursor consistent.
_scraper_lock = threading.Lock()
_scraper_cursor = 0


def get_scraper():
    """
    Return the next cloudscraper session from the shared pool, in round-robin order.

    On the first call the pool is created with MAX_SCRAPERS sessions. Safe to call from
    several threads at once.
    """
    global _scraper_cursor
    with _scraper_lock:
        if not scraper_pool:
            scraper_pool.extend(cloudscraper.create_scraper() for _ in range(MAX_SCRAPERS))
        scraper = scraper_pool[_scraper_cursor % len(scraper_pool)]
        _scraper_cursor = (_scraper_cursor + 1) % len(scraper_pool)
    return scraper
async def verify_api_key(
    request: Request,
    api_key: str = Security(api_key_header)
) -> bool:
    """
    FastAPI dependency that authorizes a request.

    Requests whose Referer starts with the playground or image-playground page of the
    parthsadaria-lokiai Hugging Face Space are let through without a key. Every other
    request must send an Authorization header whose value matches one of the
    comma-separated keys in API_KEYS. The value may carry a "Bearer " prefix. Keys are
    compared after stripping surrounding whitespace, and blank API_KEYS entries are
    ignored.

    Returns:
        True when the request is authorized.

    Raises:
        HTTPException(403): no key was sent, no keys are configured, or the key is unknown.
    """
    import hmac  # constant-time comparison for secret values

    # SECURITY(review): Referer is supplied by the client. Any caller can skip the key
    # check by sending one of these values. Kept because the playground pages rely on it.
    referer = request.headers.get("referer", "")
    if referer.startswith(("https://parthsadaria-lokiai.hf.space/playground",
                           "https://parthsadaria-lokiai.hf.space/image-playground")):
        return True
    if not api_key:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="No API key provided"
        )
    if api_key.startswith('Bearer '):
        api_key = api_key[7:]
    api_key = api_key.strip()

    # Accept "key1, key2" style configuration; an unset API_KEYS yields no keys at all.
    valid_api_keys = {
        key.strip() for key in (get_env_vars().get('api_keys') or []) if key and key.strip()
    }
    if not valid_api_keys:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="API keys not configured on server"
        )
    supplied = api_key.encode("utf-8")
    if not any(hmac.compare_digest(supplied, key.encode("utf-8")) for key in valid_api_keys):
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="Invalid API key"
        )
    return True
def load_models_data():
    """
    Read and parse ``models.json`` from the directory that contains this module.

    The file is re-read on every call; nothing is cached.

    Returns:
        The decoded JSON content. Returns ``[]`` (after printing the error) when the file
        is missing, unreadable (permissions, a directory, other OS errors), not valid
        UTF-8, or not valid JSON.
    """
    file_path = Path(__file__).parent / 'models.json'
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except (OSError, ValueError) as e:
        # OSError covers FileNotFoundError/PermissionError/IsADirectoryError.
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Error loading models.json: {str(e)}")
        return []
async def get_models():
    """
    Return the model list parsed from models.json.

    Raises:
        HTTPException(500): models.json could not be loaded or is empty.
    """
    models_data = load_models_data()
    if models_data:
        return models_data
    raise HTTPException(status_code=500, detail="Error loading available models")
# Strong references to running search producer tasks. The event loop keeps only weak
# references, so an otherwise unreferenced task could be garbage-collected mid-run.
_search_tasks = set()


def _search_delta_content(json_data) -> str:
    """Return choices[0].delta.content from one decoded stream chunk, or '' if absent/null."""
    if not isinstance(json_data, dict):
        return ""
    choices = json_data.get("choices") or [{}]
    first_choice = choices[0] if isinstance(choices, list) else {}
    delta = first_choice.get("delta") if isinstance(first_choice, dict) else None
    content = delta.get("content") if isinstance(delta, dict) else None
    return content if isinstance(content, str) else ""


async def generate_search_async(query: str, systemprompt: Optional[str] = None, stream: bool = True):
    """
    Start a "searchgpt" completion for *query* and return the queue its results go to.

    A background task POSTs to SECRET_API_ENDPOINT_3 and puts these items on the returned
    ``asyncio.Queue``:

    * ``{"data": <SSE line>, "text": <content>}`` for each chunk with non-blank content;
    * ``{"error": <message>}`` if configuration, the HTTP call or JSON decoding fails;
    * ``None`` exactly once, as the last item, in every case.

    Args:
        query: User question sent as the single user message.
        systemprompt: System message; defaults to "Be Helpful and Friendly".
        stream: Accepted for API compatibility but unused. Results are always queued
            incrementally.
    """
    queue = asyncio.Queue()

    async def _fetch_search_data():
        """Stream the search API response into ``queue``; always finish with ``None``."""
        try:
            headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
            system_message = systemprompt or "Be Helpful and Friendly"
            prompt = [
                {"content": system_message, "role": "system"},
                {"role": "user", "content": query},
            ]
            payload = {
                "is_vscode_extension": True,
                "message_history": prompt,
                "requested_model": "searchgpt",
                "user_input": query,
            }
            secret_api_endpoint_3 = get_env_vars()['secret_api_endpoint_3']
            if not secret_api_endpoint_3:
                await queue.put({"error": "Search API endpoint not configured"})
                return
            async with httpx.AsyncClient(timeout=30.0) as client:
                async with client.stream("POST", secret_api_endpoint_3, json=payload, headers=headers) as response:
                    if response.status_code != 200:
                        # A streamed body must be read before .text is available.
                        await response.aread()
                        await queue.put({"error": f"Search API returned status code {response.status_code}: {response.text}"})
                        return
                    async for line in response.aiter_lines():
                        if not line.startswith("data: "):
                            continue
                        raw = line[6:]
                        if raw.strip() == "[DONE]":
                            break  # normal end-of-stream marker
                        try:
                            json_data = json.loads(raw)
                        except json.JSONDecodeError:
                            print(f"Warning: Could not decode JSON from search API stream: {line}")
                            await queue.put({"error": f"Invalid JSON from search API: {line}"})
                            break  # stop processing on bad JSON
                        content = _search_delta_content(json_data)
                        if content.strip():
                            cleaned_response = {
                                "created": json_data.get("created"),
                                "id": json_data.get("id"),
                                "model": "searchgpt",
                                "object": "chat.completion",
                                "choices": [
                                    {
                                        "message": {
                                            "content": content
                                        }
                                    }
                                ]
                            }
                            await queue.put({"data": f"data: {json.dumps(cleaned_response)}\n\n", "text": content})
        except Exception as e:
            print(f"Error in _fetch_search_data: {e}")
            await queue.put({"error": str(e)})
        finally:
            # Consumers loop until they see None, so it must be sent on every exit path.
            await queue.put(None)

    task = asyncio.create_task(_fetch_search_data())
    _search_tasks.add(task)
    task.add_done_callback(_search_tasks.discard)
    return queue
def read_html_file(file_path):
    """
    Return the text of *file_path* decoded as UTF-8, or None if the file does not exist.

    Despite the name, this is used for any static text asset (HTML, JS, CSS). Relative
    paths resolve against the process working directory. The file is re-read on every
    call; nothing is cached.
    """
    try:
        # Explicit UTF-8 so the result does not depend on the platform's locale encoding.
        with open(file_path, "r", encoding="utf-8") as file:
            return file.read()
    except FileNotFoundError:
        return None
# Static file routes for basic web assets | |
async def favicon():
    """Serve favicon.ico from the directory containing this module."""
    icon_file = Path(__file__).parent.joinpath("favicon.ico")
    return FileResponse(icon_file, media_type="image/x-icon")
async def banner():
    """Serve banner.jpg from the directory containing this module."""
    image_file = Path(__file__).parent.joinpath("banner.jpg")
    return FileResponse(image_file, media_type="image/jpeg")
async def ping():
    """Health check that always answers with a fixed "pong" payload."""
    reply = {"message": "pong"}
    reply["response_time"] = "0.000000 seconds"  # static placeholder, not a measured latency
    return reply
async def root():
    """
    Serve index.html, resolved relative to the working directory.

    Raises:
        HTTPException(404): the file does not exist.
    """
    page = read_html_file("index.html")
    if page is not None:
        return HTMLResponse(content=page)
    raise HTTPException(status_code=404, detail="index.html not found")
async def script():
    """
    Serve script.js (resolved relative to the working directory) as JavaScript.

    Raises:
        HTTPException(404): the file does not exist.
    """
    js_content = read_html_file("script.js")
    if js_content is None:
        raise HTTPException(status_code=404, detail="script.js not found")
    # A JavaScript MIME type is required. Browsers refuse to run scripts labelled
    # text/html when X-Content-Type-Options: nosniff is set, and always for module scripts.
    return PlainTextResponse(content=js_content, media_type="text/javascript")
async def style():
    """
    Serve style.css (resolved relative to the working directory) as text/css.

    Raises:
        HTTPException(404): the file does not exist.
    """
    css_content = read_html_file("style.css")
    if css_content is None:
        raise HTTPException(status_code=404, detail="style.css not found")
    # In standards mode, browsers ignore stylesheets whose Content-Type is not text/css.
    return PlainTextResponse(content=css_content, media_type="text/css")
async def dynamic_ai_page(request: Request):
    """
    Generate an HTML page with an AI model, personalised with the caller's User-Agent and IP.

    The prompt is POSTed to this server's own /chat/completions endpoint on
    localhost:7860 (model "mistral-small-latest") with the hard-coded bearer token
    "playground". The request carries no Referer, so verify_api_key accepts it only if
    "playground" is one of the configured API_KEYS. The HTML is taken from the first
    fenced code block in the reply, or from the whole reply when there is none.

    NOTE(review): the model output is influenced by the client-supplied User-Agent and is
    served unescaped as HTML; it is only returned to the same caller.

    Raises:
        HTTPException: 500 when no HTML is produced or an unexpected error occurs. When
            the internal endpoint answers with an error status, that status is passed on.
    """
    user_agent = request.headers.get('user-agent', 'Unknown User')
    client_ip = request.client.host if request.client else "Unknown IP"
    location = f"IP: {client_ip}"
    prompt = f"""
Generate a dynamic HTML page for a user with the following details: with name "LOKI.AI"
- User-Agent: {user_agent}
- Location: {location}
- Style: Cyberpunk, minimalist, or retro
Make sure the HTML is clean and includes a heading, also have cool animations a motivational message, and a cool background.
Wrap the generated HTML in triple backticks (```).
"""
    payload = {
        "model": "mistral-small-latest",
        "messages": [{"role": "user", "content": prompt}]
    }
    headers = {
        "Authorization": "Bearer playground"  # use a dedicated internal token if available
    }
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                "http://localhost:7860/chat/completions",  # this server's own API
                json=payload,
                headers=headers,
                timeout=30.0
            )
            response.raise_for_status()
        data = response.json()
        html_content = None
        if isinstance(data, dict) and data.get('choices'):
            message_content = data['choices'][0].get('message', {}).get('content') or ''
            # Prefer the content of the first ``` fenced block (optionally tagged html).
            match = re.search(r"```(?:html)?(.*?)```", message_content, re.DOTALL)
            html_content = match.group(1).strip() if match else message_content.strip()
        if not html_content:
            raise HTTPException(status_code=500, detail="Failed to generate HTML content from AI.")
        return HTMLResponse(content=html_content)
    except HTTPException:
        # Our own error above: propagate it unchanged instead of re-wrapping it below.
        raise
    except httpx.RequestError as e:
        print(f"HTTPX Request Error in /dynamo: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to connect to internal AI service: {e}")
    except httpx.HTTPStatusError as e:
        print(f"HTTPX Status Error in /dynamo: {e.response.status_code} - {e.response.text}")
        raise HTTPException(status_code=e.response.status_code, detail=f"Internal AI service responded with error: {e.response.text}")
    except Exception as e:
        print(f"An unexpected error occurred in /dynamo: {e}")
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}")
async def scrape_site(url: str = Query(..., description="URL to scrape")):
    """
    Fetch *url* through a pooled cloudscraper session and return the body as plain text.

    cloudscraper is synchronous, so the request runs on the shared ``executor`` thread
    pool to keep the event loop free. Each request has a 30-second timeout so that a stalled
    site cannot hold a pool thread indefinitely.

    SECURITY(review): this endpoint has no API-key dependency and fetches arbitrary URLs
    from the server (SSRF: internal hosts and cloud metadata addresses are reachable).
    Consider requiring verify_api_key and rejecting private address ranges.

    Raises:
        HTTPException(400): *url* is not an absolute http(s) URL.
        HTTPException(500): the fetch failed or returned an empty body.
    """
    from urllib.parse import urlparse

    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https") or not parsed.netloc:
        raise HTTPException(status_code=400, detail="Only absolute http(s) URLs can be scraped.")

    loop = asyncio.get_running_loop()
    try:
        response_text = await loop.run_in_executor(
            executor,
            lambda: get_scraper().get(url, timeout=30).text
        )
    except Exception as e:
        print(f"Cloudscraper failed: {e}")
        raise HTTPException(status_code=500, detail=f"Cloudscraper failed: {e}")
    # Raised outside the try so it is not re-wrapped as "Cloudscraper failed".
    if not response_text or not response_text.strip():
        raise HTTPException(status_code=500, detail="Scraping returned empty content.")
    return PlainTextResponse(response_text)
async def playground():
    """
    Serve playground.html, resolved relative to the working directory.

    Raises:
        HTTPException(404): the file does not exist.
    """
    page = read_html_file("playground.html")
    if page is not None:
        return HTMLResponse(content=page)
    raise HTTPException(status_code=404, detail="playground.html not found")
async def image_playground():
    """
    Serve image-playground.html, resolved relative to the working directory.

    Raises:
        HTTPException(404): the file does not exist.
    """
    page = read_html_file("image-playground.html")
    if page is not None:
        return HTMLResponse(content=page)
    raise HTTPException(status_code=404, detail="image-playground.html not found")
# Raw-content base URL of the Vetra repository (main branch) that serve_vetra pulls from.
# Must be a bare URL: get_github_file appends "/<filename>" to it.
GITHUB_BASE = "https://raw.githubusercontent.com/Parthsadaria/Vetra/main"
# Asset kind -> file name inside GITHUB_BASE.
FILES = {
    "html": "index.html",
    "css": "style.css",
    "js": "script.js"
}
async def get_github_file(filename: str) -> Optional[str]:
    """
    Download *filename* from GITHUB_BASE, following redirects.

    Returns:
        The response text. Returns None (after printing the reason) when GitHub answers
        with a 4xx/5xx status or the request cannot be completed.
    """
    target = f"{GITHUB_BASE}/{filename}"
    async with httpx.AsyncClient() as http:
        try:
            reply = await http.get(target, follow_redirects=True)
            reply.raise_for_status()
        except httpx.HTTPStatusError as err:
            print(f"Error fetching {filename} from GitHub: {err.response.status_code} - {err.response.text}")
            return None
        except httpx.RequestError as err:
            print(f"Request error fetching {filename} from GitHub: {err}")
            return None
        else:
            return reply.text
async def serve_vetra():
    """
    Serve the Vetra page as one self-contained HTML document.

    index.html, style.css and script.js are fetched from GitHub concurrently. The CSS is
    inlined before every ``</head>`` and the JS before every ``</body>``. When CSS or JS is
    unavailable, a placeholder comment is inserted instead.

    Raises:
        HTTPException(404): index.html could not be fetched (or is empty).
    """
    # The three downloads are independent, so run them in parallel instead of one by one.
    html, css, js = await asyncio.gather(
        get_github_file(FILES["html"]),
        get_github_file(FILES["css"]),
        get_github_file(FILES["js"]),
    )
    if not html:
        raise HTTPException(status_code=404, detail="index.html not found on GitHub")
    final_html = html.replace(
        "</head>",
        f"<style>{css or '/* CSS not found */'}</style></head>"
    ).replace(
        "</body>",
        f"<script>{js or '// JS not found'}</script></body>"
    )
    return HTMLResponse(content=final_html)
async def search_gpt(q: str, request: Request, stream: Optional[bool] = False, systemprompt: Optional[str] = None):
    """
    Answer *q* with the search-backed "searchgpt" model.

    Usage is recorded under the "/searchgpt" endpoint. With ``stream`` the chunks are
    relayed as server-sent events; a failure is sent as a final ``{"error": ...}`` event.
    Without ``stream`` the text of all chunks is concatenated and returned as
    ``{"response": ...}``.

    Raises:
        HTTPException(400): *q* is empty.
        HTTPException(500): non-streaming mode only, when the search producer reports an error.
    """
    if not q:
        raise HTTPException(status_code=400, detail="Query parameter 'q' is required")
    usage_tracker.record_request(request=request, model="searchgpt", endpoint="/searchgpt")
    queue = await generate_search_async(q, systemprompt=systemprompt, stream=True)

    if not stream:
        pieces = []
        while True:
            item = await queue.get()
            if item is None:
                break
            if "error" in item:
                raise HTTPException(status_code=500, detail=item["error"])
            pieces.append(item.get("text", ""))
        return JSONResponse(content={"response": "".join(pieces)})

    async def stream_generator():
        """Relay queued SSE chunks until the producer finishes or reports an error."""
        while True:
            item = await queue.get()
            if item is None:
                return
            if "error" in item:
                # Report the failure in-band so the client can handle it gracefully.
                yield f"data: {json.dumps({'error': item['error']})}\n\n"
                return
            if "data" in item:
                yield item["data"]

    return StreamingResponse(
        stream_generator(),
        media_type="text/event-stream"
    )
# Sent as Origin and Referer to the default (fallback) upstream in get_completion.
# None when HEADER_URL is not set in the environment / .env.
header_url = os.environ.get('HEADER_URL')
# Readable prefixes for common upstream error statuses; other codes get a generic message.
_UPSTREAM_ERROR_MESSAGES = {
    400: "Bad request. Verify input data.",
    401: "Unauthorized. Invalid API key for upstream service.",
    403: "Forbidden. You do not have access to this resource on upstream.",
    404: "The requested resource was not found on upstream.",
    422: "Unprocessable entity. Check your payload for upstream API.",
    500: "Internal server error from upstream API."
}


def _resolve_upstream(model_name: str, env_vars: Dict[str, Any]):
    """
    Choose the upstream for *model_name* and return ``(endpoint, url_path, headers)``.

    ``endpoint`` may be None or empty when the matching env var is unset; the caller
    reports that. Headers whose value is None (e.g. HEADER_URL unset) are dropped,
    because httpx accepts only str/bytes header values.

    Raises:
        HTTPException(500): a Gemini model was requested but its endpoint or key is missing.
    """
    url_path = "/v1/chat/completions"  # default path for OpenAI-compatible APIs
    if model_name in mistral_models:
        endpoint = env_vars['mistral_api']
        headers = {"Authorization": f"Bearer {env_vars['mistral_key']}"}
    elif model_name in pollinations_models:
        endpoint, headers = env_vars['secret_api_endpoint_4'], {}
    elif model_name in alternate_models:
        endpoint, headers = env_vars['secret_api_endpoint_2'], {}
    elif model_name in claude_3_models:
        endpoint, headers = env_vars['secret_api_endpoint_5'], {}
    elif model_name in gemini_models:
        endpoint = env_vars['secret_api_endpoint_6']
        if not endpoint:
            raise HTTPException(status_code=500, detail="Gemini API endpoint (SECRET_API_ENDPOINT_6) not configured.")
        if not env_vars['gemini_key']:
            raise HTTPException(status_code=500, detail="GEMINI_KEY not configured for Gemini models.")
        headers = {"Authorization": f"Bearer {env_vars['gemini_key']}"}
        url_path = "/chat/completions"  # Gemini endpoint uses a different path
    else:
        # Fallback for every other model: generic OpenAI-compatible endpoint.
        endpoint = env_vars['secret_api_endpoint']
        headers = {"Origin": header_url, "Priority": "u=1, i", "Referer": header_url}
    return endpoint, url_path, {name: value for name, value in headers.items() if value is not None}


def _upstream_error_detail(status_code: int, body: bytes) -> str:
    """Build the HTTPException detail for an upstream error status from its raw *body*."""
    detail_message = _UPSTREAM_ERROR_MESSAGES.get(status_code, f"Upstream error code: {status_code}")
    body_text = body.decode('utf-8', errors='ignore')
    try:
        error_json = json.loads(body_text)
    except json.JSONDecodeError:
        error_json = None
    if isinstance(error_json, dict):
        error_field = error_json.get('error')
        if isinstance(error_field, dict) and 'message' in error_field:
            return f"{detail_message} - Upstream detail: {error_field['message']}"
        if 'detail' in error_json:
            return f"{detail_message} - Upstream detail: {error_json['detail']}"
    return f"{detail_message} - Upstream raw: {body_text[:200]}..."  # truncated for readability


async def _open_upstream(client, url: str, payload_dict: Dict[str, Any], headers: Dict[str, str]):
    """
    Send the upstream request in streaming mode and return the open response.

    Error statuses are detected here, before any bytes are sent to our own client, so that
    they become proper HTTP errors instead of a broken event stream.

    Raises:
        HTTPException: 504 timeout, 502 connection failure, 500 unexpected error, or the
            upstream status (>= 400) with a descriptive detail.
    """
    try:
        upstream_request = client.build_request("POST", url, json=payload_dict, headers=headers)
        response = await client.send(upstream_request, stream=True)
    except httpx.TimeoutException as e:
        raise HTTPException(status_code=504, detail="Request to upstream AI service timed out.") from e
    except httpx.RequestError as e:
        raise HTTPException(status_code=502, detail=f"Failed to connect to upstream AI service: {str(e)}") from e
    except Exception as e:
        print(f"An unexpected error occurred during chat completion proxy: {e}")
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}") from e
    if response.status_code >= 400:
        try:
            error_body = await response.aread()
        except httpx.HTTPError:
            error_body = b""
        finally:
            await response.aclose()
        raise HTTPException(status_code=response.status_code,
                            detail=_upstream_error_detail(response.status_code, error_body))
    return response


def _parse_non_stream_body(body_text: str) -> Dict[str, Any]:
    """
    Turn a complete upstream body into the JSON object returned to non-streaming callers.

    A plain JSON object (the usual reply when "stream" is false) is returned unchanged.
    Otherwise the body is treated as an event stream. Its "data: " lines are decoded; the
    last chunk is returned if it has "choices", and otherwise all chunks are returned under
    "response_parts". Returns {} when nothing usable is found.
    """
    stripped = body_text.strip()
    if stripped and not stripped.startswith("data:"):
        try:
            whole = json.loads(stripped)
        except json.JSONDecodeError:
            whole = None
        if isinstance(whole, dict) and whole:
            return whole
    parsed_data = []
    for line in body_text.splitlines():
        if not line.startswith("data: "):
            continue
        chunk = line[6:]
        if chunk.strip() == "[DONE]":
            continue  # end-of-stream marker, not JSON
        try:
            parsed_data.append(json.loads(chunk))
        except json.JSONDecodeError:
            print(f"Warning: Could not decode JSON line in non-streaming response: {line}")
    if not parsed_data:
        return {}
    last_chunk = parsed_data[-1]
    if isinstance(last_chunk, dict) and 'choices' in last_chunk:
        return last_chunk
    return {"response_parts": parsed_data}


async def get_completion(payload: Payload, request: Request, authenticated: bool = Depends(verify_api_key)):
    """
    Proxy a chat-completion request to the upstream that serves ``payload.model``.

    The model defaults to "gpt-4o-mini" and must be in ``available_model_ids`` when that
    list is non-empty. Usage is recorded before the upstream call. With
    ``payload.stream`` the upstream lines are relayed verbatim as text/event-stream. A
    failure in the middle of the stream is reported as a final ``{"error": ...}`` event,
    because headers have already been sent by then. Otherwise the full upstream body is
    returned as JSON (see ``_parse_non_stream_body``).

    Raises:
        HTTPException: 503 during maintenance; 400 unknown model; 500 missing configuration
            or no usable body; 502/504 connection problems; upstream error statuses.
    """
    if not server_status:
        raise HTTPException(
            status_code=503,
            detail="Server is under maintenance. Please try again later."
        )
    model_to_use = payload.model or "gpt-4o-mini"  # default model
    if available_model_ids and model_to_use not in available_model_ids:
        raise HTTPException(
            status_code=400,
            detail=f"Model '{model_to_use}' is not available. Check /models for the available model list."
        )
    # Record usage before making the external API call.
    usage_tracker.record_request(request=request, model=model_to_use, endpoint="/chat/completions")

    payload_dict = payload.dict()
    payload_dict["model"] = model_to_use  # forward the resolved model name
    endpoint, target_url_path, custom_headers = _resolve_upstream(model_to_use, get_env_vars())
    if not endpoint:
        raise HTTPException(status_code=500, detail=f"No API endpoint configured for model: {model_to_use}")
    upstream_url = f"{endpoint}{target_url_path}"
    print(f"Proxying request for model '{model_to_use}' to endpoint: {upstream_url}")

    client = httpx.AsyncClient(timeout=60.0)
    try:
        upstream = await _open_upstream(client, upstream_url, payload_dict, custom_headers)
    except BaseException:
        await client.aclose()
        raise

    if payload.stream:
        async def real_time_stream_generator():
            """Relay upstream lines (including blank SSE separators), then release the connection."""
            try:
                async for line in upstream.aiter_lines():
                    # Blank lines delimit SSE events, so they are forwarded too.
                    yield line + "\n"
            except httpx.HTTPError as e:
                print(f"Upstream stream interrupted for model '{model_to_use}': {e}")
                error_message = f"Upstream stream interrupted: {e}"
                yield f"data: {json.dumps({'error': error_message})}\n\n"
            finally:
                await upstream.aclose()
                await client.aclose()

        return StreamingResponse(
            real_time_stream_generator(),
            media_type="text/event-stream",
            headers={
                "Content-Type": "text/event-stream",
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no"  # disable proxy buffering for SSE
            }
        )

    try:
        await upstream.aread()
        body_text = upstream.text
    except httpx.TimeoutException as e:
        raise HTTPException(status_code=504, detail="Request to upstream AI service timed out.") from e
    except httpx.RequestError as e:
        raise HTTPException(status_code=502, detail=f"Failed to connect to upstream AI service: {str(e)}") from e
    finally:
        await upstream.aclose()
        await client.aclose()

    final_json_response = _parse_non_stream_body(body_text)
    if not final_json_response:
        raise HTTPException(status_code=500, detail="No valid JSON response received from upstream API for non-streaming request.")
    return JSONResponse(content=final_json_response)
async def create_image(payload: ImageGenerationPayload, request: Request, authenticated: bool = Depends(verify_api_key)):
    """
    Proxy an image-generation request to the service configured in NEW_IMG.

    The request is rejected during maintenance or when the model is not in
    supported_image_models. Otherwise usage is recorded, model/prompt/size/n are forwarded,
    and the upstream JSON is returned unchanged.

    Raises:
        HTTPException: 503 maintenance; 400 unsupported model; 500 NEW_IMG not configured
            or an unexpected failure; 504 timeout; 502 connection error; the upstream
            status code when it answers with an error.
    """
    if not server_status:
        raise HTTPException(
            status_code=503,
            detail="Server is under maintenance. Please try again later."
        )
    if payload.model not in supported_image_models:
        raise HTTPException(
            status_code=400,
            # sorted(): set iteration order varies between processes (string hash seed).
            detail=f"Model '{payload.model}' is not supported for image generation. Supported models are: {', '.join(sorted(supported_image_models))}"
        )
    usage_tracker.record_request(request=request, model=payload.model, endpoint="/images/generations")
    api_payload = {
        "model": payload.model,
        "prompt": payload.prompt,
        "size": payload.size,
        "n": payload.number  # upstream expects the count under "n"
    }
    target_api_url = get_env_vars().get('new_img')
    if not target_api_url:
        raise HTTPException(status_code=500, detail="Image generation API endpoint (NEW_IMG) not configured.")
    try:
        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(target_api_url, json=api_payload)
            response.raise_for_status()
            return JSONResponse(content=response.json())
    except httpx.TimeoutException:
        raise HTTPException(status_code=504, detail="Image generation request timed out.")
    except httpx.RequestError as e:
        raise HTTPException(status_code=502, detail=f"Error connecting to image generation service: {e}")
    except httpx.HTTPStatusError as e:
        error_detail = f"Image generation failed with status code: {e.response.status_code}"
        try:
            error_body = e.response.json()
        except ValueError:
            # Non-JSON error page (HTML, plain text): keep the generic message.
            error_body = None
        if isinstance(error_body, dict):
            error_detail = error_body.get("detail", error_detail)
        raise HTTPException(status_code=e.response.status_code, detail=error_detail)
    except Exception as e:
        print(f"An unexpected error occurred during image generation: {e}")
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during image generation: {e}")
async def get_usage_json(days: int = 7):
    """Return ``usage_tracker.get_usage_summary`` for the last *days* days (7 by default)."""
    summary = usage_tracker.get_usage_summary(days)
    return summary
def generate_usage_html(usage_data: Dict[str, Any]): | |
""" | |
Generates an HTML page to display usage statistics. | |
Includes tables for model, API endpoint usage, daily usage, and recent requests. | |
Also includes placeholders for Chart.js to render graphs. | |
""" | |
# Prepare data for Chart.js | |
# Model Usage Chart Data | |
model_labels = list(usage_data['model_usage_period'].keys()) | |
model_counts = list(usage_data['model_usage_period'].values()) | |
# Endpoint Usage Chart Data | |
endpoint_labels = list(usage_data['endpoint_usage_period'].keys()) | |
endpoint_counts = list(usage_data['endpoint_usage_period'].values()) | |
# Daily Usage Chart Data | |
daily_dates = list(usage_data['daily_usage_period'].keys()) | |
daily_requests = [data['requests'] for data in usage_data['daily_usage_period'].values()] | |
daily_unique_ips = [data['unique_ips_count'] for data in usage_data['daily_usage_period'].values()] | |
# Format table rows for HTML | |
model_usage_all_time_rows = "\n".join([ | |
f""" | |
<tr> | |
<td>{model}</td> | |
<td>{stats['total_requests']}</td> | |
<td>{datetime.datetime.fromisoformat(stats['first_used']).strftime("%Y-%m-%d %H:%M")}</td> | |
<td>{datetime.datetime.fromisoformat(stats['last_used']).strftime("%Y-%m-%d %H:%M")}</td> | |
</tr> | |
""" for model, stats in usage_data['all_time_model_usage'].items() | |
]) | |
api_usage_all_time_rows = "\n".join([ | |
f""" | |
<tr> | |
<td>{endpoint}</td> | |
<td>{stats['total_requests']}</td> | |
<td>{datetime.datetime.fromisoformat(stats['first_used']).strftime("%Y-%m-%d %H:%M")}</td> | |
<td>{datetime.datetime.fromisoformat(stats['last_used']).strftime("%Y-%m-%d %H:%M")}</td> | |
</tr> | |
""" for endpoint, stats in usage_data['all_time_endpoint_usage'].items() | |
]) | |
daily_usage_table_rows = "\n".join([ | |
f""" | |
<tr> | |
<td>{date}</td> | |
<td>{data['requests']}</td> | |
<td>{data['unique_ips_count']}</td> | |
</tr> | |
""" for date, data in usage_data['daily_usage_period'].items() | |
]) | |
recent_requests_rows = "\n".join([ | |
f""" | |
<tr> | |
<td>{datetime.datetime.fromisoformat(req['timestamp']).strftime("%Y-%m-%d %H:%M:%S")}</td> | |
<td>{req['model']}</td> | |
<td>{req['endpoint']}</td> | |
<td>{req['ip_address']}</td> | |
<td>{req['user_agent']}</td> | |
</tr> | |
""" for req in usage_data['recent_requests'] | |
]) | |
html_content = f""" | |
<!DOCTYPE html> | |
<html lang="en"> | |
<head> | |
<meta charset="UTF-8"> | |
<meta name="viewport" content="width=device-width, initial-scale=1.0"> | |
<title>Lokiai AI - Usage Statistics</title> | |
<link href="[https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600;700&display=swap](https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600;700&display=swap)" rel="stylesheet"> | |
<script src="[https://cdn.jsdelivr.net/npm/chart.js](https://cdn.jsdelivr.net/npm/chart.js)"></script> | |
<style> | |
:root {{ | |
--bg-dark: #0f1011; | |
--bg-darker: #070708; | |
--text-primary: #e6e6e6; | |
--text-secondary: #8c8c8c; | |
--border-color: #2c2c2c; | |
--accent-color: #3a6ee0; | |
--accent-hover: #4a7ef0; | |
--chart-bg-light: rgba(58, 110, 224, 0.2); | |
--chart-border-light: #3a6ee0; | |
}} | |
body {{ | |
font-family: 'Inter', sans-serif; | |
background-color: var(--bg-dark); | |
color: var(--text-primary); | |
max-width: 1200px; | |
margin: 0 auto; | |
padding: 40px 20px; | |
line-height: 1.6; | |
}} | |
.logo {{ | |
display: flex; | |
align-items: center; | |
justify-content: center; | |
margin-bottom: 30px; | |
}} | |
.logo h1 {{ | |
font-weight: 700; | |
font-size: 2.8em; | |
color: var(--text-primary); | |
margin-left: 15px; | |
}} | |
.logo img {{ | |
width: 70px; | |
height: 70px; | |
border-radius: 12px; | |
box-shadow: 0 5px 15px rgba(0,0,0,0.2); | |
}} | |
.container {{ | |
background-color: var(--bg-darker); | |
border-radius: 16px; | |
padding: 30px; | |
box-shadow: 0 20px 50px rgba(0,0,0,0.4); | |
border: 1px solid var(--border-color); | |
}} | |
h2, h3 {{ | |
color: var(--text-primary); | |
border-bottom: 2px solid var(--border-color); | |
padding-bottom: 12px; | |
margin-top: 40px; | |
margin-bottom: 25px; | |
font-weight: 600; | |
font-size: 1.8em; | |
}} | |
.summary-grid {{ | |
display: grid; | |
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); | |
gap: 20px; | |
margin-bottom: 30px; | |
}} | |
.summary-card {{ | |
background-color: var(--bg-dark); | |
border-radius: 10px; | |
padding: 20px; | |
text-align: center; | |
border: 1px solid var(--border-color); | |
box-shadow: 0 8px 20px rgba(0,0,0,0.2); | |
transition: transform 0.2s ease-in-out; | |
}} | |
.summary-card:hover {{ | |
transform: translateY(-5px); | |
}} | |
.summary-card h3 {{ | |
margin-top: 0; | |
font-size: 1.1em; | |
color: var(--text-secondary); | |
border-bottom: none; | |
padding-bottom: 0; | |
margin-bottom: 10px; | |
}} | |
.summary-card p {{ | |
font-size: 2.2em; | |
font-weight: 700; | |
color: var(--accent-color); | |
margin: 0; | |
}} | |
table {{ | |
width: 100%; | |
border-collapse: separate; | |
border-spacing: 0; | |
margin-bottom: 40px; | |
background-color: var(--bg-dark); | |
border-radius: 10px; | |
overflow: hidden; | |
box-shadow: 0 8px 20px rgba(0,0,0,0.2); | |
}} | |
th, td {{ | |
border: 1px solid var(--border-color); | |
padding: 15px; | |
text-align: left; | |
transition: background-color 0.3s ease; | |
}} | |
th {{ | |
background-color: #1a1a1a; | |
color: var(--text-primary); | |
font-weight: 600; | |
text-transform: uppercase; | |
font-size: 0.95em; | |
}} | |
tr:nth-child(even) {{ | |
background-color: rgba(255,255,255,0.03); | |
}} | |
tr:hover {{ | |
background-color: rgba(62,100,255,0.1); | |
}} | |
.chart-container {{ | |
background-color: var(--bg-dark); | |
border-radius: 10px; | |
padding: 20px; | |
margin-bottom: 40px; | |
border: 1px solid var(--border-color); | |
box-shadow: 0 8px 20px rgba(0,0,0,0.2); | |
max-height: 400px; /* Limit chart height */ | |
position: relative; /* For responsive canvas */ | |
}} | |
canvas {{ | |
max-width: 100% !important; | |
height: auto !important; | |
}} | |
@media (max-width: 768px) {{ | |
body {{ | |
padding: 20px 10px; | |
}} | |
.container {{ | |
padding: 20px; | |
}} | |
.logo h1 {{ | |
font-size: 2em; | |
}} | |
.summary-card p {{ | |
font-size: 1.8em; | |
}} | |
h2, h3 {{ | |
font-size: 1.5em; | |
}} | |
table {{ | |
font-size: 0.85em; | |
}} | |
th, td {{ | |
padding: 10px; | |
}} | |
}} | |
</style> | |
</head> | |
<body> | |
<div class="container"> | |
<div class="logo"> | |
<img src="" alt="Lokiai AI Logo"> | |
<h1>Lokiai AI Usage</h1> | |
</div> | |
<div class="summary-grid"> | |
<div class="summary-card"> | |
<h3>Total Requests (All Time)</h3> | |
<p>{usage_data['total_requests']}</p> | |
</div> | |
<div class="summary-card"> | |
<h3>Unique IPs (All Time)</h3> | |
<p>{usage_data['unique_ips_total_count']}</p> | |
</div> | |
<div class="summary-card"> | |
<h3>Models Used (Last {days} Days)</h3> | |
<p>{len(usage_data['model_usage_period'])}</p> | |
</div> | |
<div class="summary-card"> | |
<h3>Endpoints Used (Last {days} Days)</h3> | |
<p>{len(usage_data['endpoint_usage_period'])}</p> | |
</div> | |
</div> | |
<h2>Daily Usage (Last {days} Days)</h2> | |
<div class="chart-container"> | |
<canvas id="dailyRequestsChart"></canvas> | |
</div> | |
<table> | |
<thead> | |
<tr> | |
<th>Date</th> | |
<th>Requests</th> | |
<th>Unique IPs</th> | |
</tr> | |
</thead> | |
<tbody> | |
{daily_usage_table_rows} | |
</tbody> | |
</table> | |
<h2>Model Usage (Last {days} Days)</h2> | |
<div class="chart-container"> | |
<canvas id="modelUsageChart"></canvas> | |
</div> | |
<h3>Model Usage (All Time Details)</h3> | |
<table> | |
<thead> | |
<tr> | |
<th>Model</th> | |
<th>Total Requests</th> | |
<th>First Used</th> | |
<th>Last Used</th> | |
</tr> | |
</thead> | |
<tbody> | |
{model_usage_all_time_rows} | |
</tbody> | |
</table> | |
<h2>API Endpoint Usage (Last {days} Days)</h2> | |
<div class="chart-container"> | |
<canvas id="endpointUsageChart"></canvas> | |
</div> | |
<h3>API Endpoint Usage (All Time Details)</h3> | |
<table> | |
<thead> | |
<tr> | |
<th>Endpoint</th> | |
<th>Total Requests</th> | |
<th>First Used</th> | |
<th>Last Used</th> | |
</tr> | |
</thead> | |
<tbody> | |
{api_usage_all_time_rows} | |
</tbody> | |
</table> | |
<h2>Recent Requests (Last 20)</h2> | |
<table> | |
<thead> | |
<tr> | |
<th>Timestamp</th> | |
<th>Model</th> | |
<th>Endpoint</th> | |
<th>IP Address</th> | |
<th>User Agent</th> | |
</tr> | |
</thead> | |
<tbody> | |
{recent_requests_rows} | |
</tbody> | |
</table> | |
</div> | |
<script> | |
// Chart.js data and rendering logic | |
const modelLabels = {json.dumps(model_labels)}; | |
const modelCounts = {json.dumps(model_counts)}; | |
const endpointLabels = {json.dumps(endpoint_labels)}; | |
const endpointCounts = {json.dumps(endpoint_counts)}; | |
const dailyDates = {json.dumps(daily_dates)}; | |
const dailyRequests = {json.dumps(daily_requests)}; | |
const dailyUniqueIps = {json.dumps(daily_unique_ips)}; | |
// Model Usage Chart (Bar Chart) | |
new Chart(document.getElementById('modelUsageChart'), {{ | |
type: 'bar', | |
data: {{ | |
labels: modelLabels, | |
datasets: [{{ | |
label: 'Requests', | |
data: modelCounts, | |
backgroundColor: 'var(--chart-bg-light)', | |
borderColor: 'var(--chart-border-light)', | |
borderWidth: 1, | |
borderRadius: 5, | |
}}] | |
}}, | |
options: {{ | |
responsive: true, | |
maintainAspectRatio: false, | |
plugins: {{ | |
legend: {{ | |
labels: {{ | |
color: 'var(--text-primary)' | |
}} | |
}}, | |
title: {{ | |
display: true, | |
text: 'Model Usage', | |
color: 'var(--text-primary)' | |
}} | |
}}, | |
scales: {{ | |
x: {{ | |
ticks: {{ | |
color: 'var(--text-secondary)' | |
}}, | |
grid: {{ | |
color: 'var(--border-color)' | |
}} | |
}}, | |
y: {{ | |
beginAtZero: true, | |
ticks: {{ | |
color: 'var(--text-secondary)' | |
}}, | |
grid: {{ | |
color: 'var(--border-color)' | |
}} | |
}} | |
}} | |
}} | |
}}); | |
// Endpoint Usage Chart (Doughnut Chart) | |
new Chart(document.getElementById('endpointUsageChart'), {{ | |
type: 'doughnut', | |
data: {{ | |
labels: endpointLabels, | |
datasets: [{{ | |
label: 'Requests', | |
data: endpointCounts, | |
backgroundColor: [ | |
'#3a6ee0', '#5b8bff', '#8dc4ff', '#b3d8ff', '#d0e8ff', | |
'#FF6384', '#36A2EB', '#FFCE56', '#4BC0C0', '#9966FF' | |
], | |
hoverOffset: 4 | |
}}] | |
}}, | |
options: {{ | |
responsive: true, | |
maintainAspectRatio: false, | |
plugins: {{ | |
legend: {{ | |
position: 'right', | |
labels: {{ | |
color: 'var(--text-primary)' | |
}} | |
}}, | |
title: {{ | |
display: true, | |
text: 'API Endpoint Usage', | |
color: 'var(--text-primary)' | |
}} | |
}} | |
}} | |
}}); | |
// Daily Requests Chart (Line Chart) | |
new Chart(document.getElementById('dailyRequestsChart'), {{ | |
type: 'line', | |
data: {{ | |
labels: dailyDates, | |
datasets: [ | |
{{ | |
label: 'Total Requests', | |
data: dailyRequests, | |
borderColor: 'var(--accent-color)', | |
backgroundColor: 'rgba(58, 110, 224, 0.1)', | |
fill: true, | |
tension: 0.3 | |
}}, | |
{{ | |
label: 'Unique IPs', | |
data: dailyUniqueIps, | |
borderColor: '#FFCE56', // A distinct color for unique IPs | |
backgroundColor: 'rgba(255, 206, 86, 0.1)', | |
fill: true, | |
tension: 0.3 | |
}} | |
] | |
}}, | |
options: {{ | |
responsive: true, | |
maintainAspectRatio: false, | |
plugins: {{ | |
legend: {{ | |
labels: {{ | |
color: 'var(--text-primary)' | |
}} | |
}}, | |
title: {{ | |
display: true, | |
text: 'Daily Requests and Unique IPs', | |
color: 'var(--text-primary)' | |
}} | |
}}, | |
scales: {{ | |
x: {{ | |
ticks: {{ | |
color: 'var(--text-secondary)' | |
}}, | |
grid: {{ | |
color: 'var(--border-color)' | |
}} | |
}}, | |
y: {{ | |
beginAtZero: true, | |
ticks: {{ | |
color: 'var(--text-secondary)' | |
}}, | |
grid: {{ | |
color: 'var(--border-color)' | |
}} | |
}} | |
}} | |
}} | |
}}); | |
</script> | |
</body> | |
</html> | |
""" | |
return html_content | |
async def usage_page(days: int = 7):
    """
    Render the usage-statistics dashboard as an HTML page.

    Args:
        days: Reporting window in days, forwarded to
            ``usage_tracker.get_usage_summary`` to build the summary data.

    Returns:
        HTMLResponse containing the page produced by ``generate_usage_html``.

    NOTE(review): ``days`` is only passed to the usage tracker, not to
    ``generate_usage_html``; confirm the "Last N Days" headings on the page
    reflect the requested window.
    """
    summary = usage_tracker.get_usage_summary(days=days)
    page = generate_usage_html(summary)
    return HTMLResponse(content=page)
async def get_meme():
    """
    Fetch a random meme from meme-api.com and stream the image to the client.

    The meme API returns JSON whose ``url`` field points at an image. That image
    is downloaded (following redirects) and re-emitted as a StreamingResponse,
    with the upstream content type (falling back to ``image/png``) and a one-hour
    ``Cache-Control`` header.

    Raises:
        HTTPException(404): the meme API response contained no ``url``.
        HTTPException(<upstream status>): the meme API or the image host replied
            with an error status.
        HTTPException(502): the upstream service could not be reached.
        HTTPException(500): any other unexpected failure while fetching.
    """
    try:
        client = get_async_client()
        # Plain URL: the previous literal was a Markdown link "[url](url)",
        # which is not a valid request target.
        response = await client.get("https://meme-api.com/gimme")
        response.raise_for_status()  # Raise an exception for bad status codes
        response_data = response.json()
        meme_url = response_data.get("url")
        if not meme_url:
            raise HTTPException(status_code=404, detail="No meme URL found in response.")
        image_response = await client.get(meme_url, follow_redirects=True)
        image_response.raise_for_status()
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 404 above) must reach the client as-is
        # instead of being converted into a 500 by the generic handler below.
        raise
    except httpx.HTTPStatusError as e:
        print(f"Error fetching meme from upstream: {e.response.status_code} - {e.response.text}")
        raise HTTPException(status_code=e.response.status_code, detail=f"Failed to fetch meme: {e.response.text}")
    except httpx.RequestError as e:
        print(f"Request error fetching meme: {e}")
        raise HTTPException(status_code=502, detail=f"Could not connect to meme service: {e}")
    except Exception as e:
        print(f"An unexpected error occurred while getting meme: {e}")
        raise HTTPException(status_code=500, detail="Failed to retrieve meme due to an unexpected error.")

    async def stream_with_larger_chunks():
        """Yield the image bytes, batching reads into ~128 KB pieces."""
        chunks = []
        size = 0
        chunk_size = 65536  # 64 KB per read
        async for chunk in image_response.aiter_bytes(chunk_size=chunk_size):
            chunks.append(chunk)
            size += len(chunk)
            if size >= chunk_size * 2:  # Flush once roughly two reads have accumulated
                yield b''.join(chunks)
                chunks = []
                size = 0
        if chunks:  # Yield whatever is left over
            yield b''.join(chunks)

    return StreamingResponse(
        stream_with_larger_chunks(),
        media_type=image_response.headers.get("content-type", "image/png"),  # Fallback to png
        headers={'Cache-Control': 'max-age=3600'}  # Cache memes for 1 hour
    )
def load_model_ids(json_file_path: str) -> List[str]:
    """
    Load model IDs from a JSON file.

    The file is expected to hold a JSON array of objects. The ``id`` of every
    object that has one is returned, in file order. Entries that are not
    objects, or objects without an ``id`` key, are skipped. Previously, a stray
    non-object entry such as ``"video-model"`` aborted the whole load.

    Args:
        json_file_path: Path to the JSON file (e.g. ``models.json``).

    Returns:
        The list of model IDs. Returns an empty list if the file cannot be read
        or parsed, or if its top level is not a JSON array. Problems are
        reported with ``print`` rather than raised, so callers such as startup
        can continue.
    """
    try:
        with open(json_file_path, 'r', encoding='utf-8') as f:
            models_data = json.load(f)
    except Exception as e:
        # Deliberately best-effort: a missing or broken models file must not
        # prevent the server from starting.
        print(f"Error loading model IDs from {json_file_path}: {str(e)}")
        return []
    if not isinstance(models_data, list):
        print(
            f"Error loading model IDs from {json_file_path}: "
            f"expected a JSON array, got {type(models_data).__name__}"
        )
        return []
    return [model['id'] for model in models_data if isinstance(model, dict) and 'id' in model]
async def startup_event():
    """
    Initialise server state when the application starts.

    Steps:
    - Build the global ``available_model_ids`` from ``models.json`` plus the
      hard-coded provider model lists. Duplicates are removed and first-seen
      order is kept.
    - Fill ``scraper_pool`` with ``MAX_SCRAPERS`` cloudscraper instances.
    - Print a warning listing environment variables that are missing for the
      providers in use. Missing variables do not stop startup.
    """
    global available_model_ids
    # Models from the local models.json file come first.
    model_ids = load_model_ids("models.json")
    print(f"Loaded {len(model_ids)} model IDs from models.json")
    # Then the hard-coded model lists for each provider (Gemini included explicitly).
    for provider_models in (
        pollinations_models,
        alternate_models,
        mistral_models,
        claude_3_models,
        gemini_models,
    ):
        model_ids.extend(provider_models)
    # dict.fromkeys de-duplicates while preserving insertion order; the former
    # list(set(...)) produced a hash-seed-dependent order on every restart.
    available_model_ids = list(dict.fromkeys(model_ids))
    print(f"Total unique available models after merging: {len(available_model_ids)}")

    # Initialize scraper pool
    for _ in range(MAX_SCRAPERS):
        scraper_pool.append(cloudscraper.create_scraper())
    print(f"Initialized Cloudscraper pool with {MAX_SCRAPERS} instances.")

    # Environment variable check for critical services.
    # Build the lookup set once so each provider check is a single disjointness test.
    available_set = set(available_model_ids)

    def uses_any(models) -> bool:
        """Return True if at least one of ``models`` is an available model ID."""
        return not available_set.isdisjoint(models)

    env_vars = get_env_vars()
    missing_vars = []
    if not env_vars['api_keys'] or env_vars['api_keys'] == ['']:
        missing_vars.append('API_KEYS')
    if not env_vars['secret_api_endpoint']:
        missing_vars.append('SECRET_API_ENDPOINT')
    if not env_vars['secret_api_endpoint_2']:
        missing_vars.append('SECRET_API_ENDPOINT_2')
    if not env_vars['secret_api_endpoint_3']:
        missing_vars.append('SECRET_API_ENDPOINT_3')
    if not env_vars['secret_api_endpoint_4'] and uses_any(pollinations_models):
        missing_vars.append('SECRET_API_ENDPOINT_4 (Pollinations.ai)')
    if not env_vars['secret_api_endpoint_5'] and uses_any(claude_3_models):
        missing_vars.append('SECRET_API_ENDPOINT_5 (Claude 3.x)')
    if not env_vars['secret_api_endpoint_6'] and uses_any(gemini_models):
        missing_vars.append('SECRET_API_ENDPOINT_6 (Gemini)')
    if not env_vars['mistral_api'] and uses_any(mistral_models):
        missing_vars.append('MISTRAL_API')
    if not env_vars['mistral_key'] and uses_any(mistral_models):
        missing_vars.append('MISTRAL_KEY')
    if not env_vars['gemini_key'] and uses_any(gemini_models):
        missing_vars.append('GEMINI_KEY')
    if not env_vars['new_img'] and len(supported_image_models) > 0:
        missing_vars.append('NEW_IMG (Image Generation)')
    if missing_vars:
        print(f"WARNING: The following critical environment variables are missing or empty: {', '.join(missing_vars)}")
        print("Some server functionality (e.g., specific AI models, image generation) may be limited or unavailable.")
    else:
        print("All critical environment variables appear to be configured.")
    print("Server started successfully!")
async def shutdown_event():
    """
    Release resources and persist state when the application shuts down.

    - Close the shared HTTPX async client returned by ``get_async_client``.
    - Clear the cloudscraper pool.
    - Save usage data via ``usage_tracker.save_data()``.

    Pool clearing and the usage save run in a ``finally`` block, so an error
    while closing the HTTP client cannot lose the collected usage statistics.
    That error still propagates after the data is saved.
    """
    try:
        client = get_async_client()
        await client.aclose()  # Ensure the httpx client connection pool is closed
    finally:
        scraper_pool.clear()  # Clear the scraper pool
        usage_tracker.save_data()  # Persist usage data on shutdown
    print("Server shutdown complete!")
async def health_check():
    """
    Report server health as a JSON response.

    Each critical environment variable is checked. Provider-specific variables
    only count when the matching provider has at least one model in
    ``available_model_ids``. ``NEW_IMG`` only counts when
    ``supported_image_models`` is non-empty.

    Returns:
        JSONResponse with these fields:
        - ``status``: ``"healthy"`` when nothing is missing, else ``"unhealthy"``.
        - ``missing_env_vars``: labels of the missing variables, in check order.
        - ``server_status``: the module-level ``server_status`` value.
        - ``message``: a short human-readable summary.
    """
    env_vars = get_env_vars()

    def serves_any(models):
        # True when at least one loaded model ID belongs to `models`.
        return any(model in models for model in available_model_ids)

    # (env key, label reported when missing, relevance predicate).
    # A predicate of None means the variable is always required.
    # The predicate is evaluated only when the variable is actually empty.
    requirements = (
        ('secret_api_endpoint', 'SECRET_API_ENDPOINT', None),
        ('secret_api_endpoint_2', 'SECRET_API_ENDPOINT_2', None),
        ('secret_api_endpoint_3', 'SECRET_API_ENDPOINT_3', None),
        ('secret_api_endpoint_4', 'SECRET_API_ENDPOINT_4 (Pollinations.ai)', lambda: serves_any(pollinations_models)),
        ('secret_api_endpoint_5', 'SECRET_API_ENDPOINT_5 (Claude 3.x)', lambda: serves_any(claude_3_models)),
        ('secret_api_endpoint_6', 'SECRET_API_ENDPOINT_6 (Gemini)', lambda: serves_any(gemini_models)),
        ('mistral_api', 'MISTRAL_API', lambda: serves_any(mistral_models)),
        ('mistral_key', 'MISTRAL_KEY', lambda: serves_any(mistral_models)),
        ('gemini_key', 'GEMINI_KEY', lambda: serves_any(gemini_models)),
        ('new_img', 'NEW_IMG (Image Generation)', lambda: len(supported_image_models) > 0),
    )

    # API_KEYS counts as missing when unset or when it splits to a single empty string.
    keys = env_vars['api_keys']
    absent = ['API_KEYS'] if (not keys or keys == ['']) else []

    for env_key, label, is_relevant in requirements:
        if env_vars[env_key]:
            continue
        if is_relevant is None or is_relevant():
            absent.append(label)

    all_present = not absent
    return JSONResponse(content={
        "status": "healthy" if all_present else "unhealthy",
        "missing_env_vars": absent,
        "server_status": server_status,  # Module-level server status flag
        "message": "Everything's lit! π" if all_present else "Uh oh, some env vars are missing. π¬",
    })
if __name__ == "__main__":
    import uvicorn

    # Serve the app on all interfaces, port 7860, when this file is run directly.
    # uvicorn drives the ASGI lifespan itself, so any startup/shutdown handlers
    # registered on `app` run automatically.
    # NOTE(review): no registration of startup_event/shutdown_event is visible in
    # this part of the file; confirm they are attached to `app`.
    listen_host = "0.0.0.0"
    listen_port = 7860
    uvicorn.run(app, host=listen_host, port=listen_port)