# lokiai / main.py
# Author: ParthSadaria
# Last change: "Update main.py" (commit dbffb4e, verified)
import asyncio
import contextlib
import datetime
import hashlib
import itertools
import json
import os
import re
import secrets
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
from pathlib import Path
from typing import Optional, Dict, List, Any, Generator

import cloudscraper
import httpx
import requests
import uvloop
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request, Depends, Security, Query
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse, PlainTextResponse
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
from starlette.middleware.cors import CORSMiddleware
from starlette.status import HTTP_403_FORBIDDEN
# Run asyncio on uvloop's event loop implementation.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
# Shared thread pool with 16 workers.
executor = ThreadPoolExecutor(max_workers=16)
# Load variables from a .env file into os.environ; get_env_vars() reads them later.
load_dotenv()
# Reads the client key from the "Authorization" header. auto_error=False makes a
# missing header arrive as None so verify_api_key can raise its own 403 message.
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
# Project-local usage statistics store (record_request / get_usage_summary /
# save_data are called on it elsewhere in this module).
from usage_tracker import UsageTracker
usage_tracker = UsageTracker()
app = FastAPI()
# Gzip-compress responses whose body is at least 1000 bytes.
app.add_middleware(GZipMiddleware, minimum_size=1000)
# NOTE(review): wildcard origins together with allow_credentials=True — confirm
# this is intended; credentialed cross-origin access is normally limited to
# known origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@lru_cache(maxsize=1)
def get_env_vars():
    """Read the server configuration from environment variables, once per process.

    The result is memoised with ``lru_cache``; call ``get_env_vars.cache_clear()``
    after changing the environment to force a re-read.

    Returns:
        dict: ``api_keys`` is the list of accepted client keys (whitespace
        trimmed, blank entries dropped, so it is ``[]`` when ``API_KEYS`` is
        unset). Other entries are upstream endpoints and credentials, ``None``
        when the corresponding variable is unset. ``secret_api_endpoint_4`` and
        ``mistral_api`` are fixed URLs.
    """
    raw_keys = os.getenv('API_KEYS', '')
    return {
        # "k1, k2" must give ["k1", "k2"], not ["k1", " k2"]; a trailing comma
        # must never turn the empty string into an accepted key.
        'api_keys': [key.strip() for key in raw_keys.split(',') if key.strip()],
        'secret_api_endpoint': os.getenv('SECRET_API_ENDPOINT'),
        'secret_api_endpoint_2': os.getenv('SECRET_API_ENDPOINT_2'),
        'secret_api_endpoint_3': os.getenv('SECRET_API_ENDPOINT_3'),
        'secret_api_endpoint_4': "https://text.pollinations.ai/openai",
        'secret_api_endpoint_5': os.getenv('SECRET_API_ENDPOINT_5'),
        'secret_api_endpoint_6': os.getenv('SECRET_API_ENDPOINT_6'),  # Gemini-compatible endpoint
        'mistral_api': "https://api.mistral.ai",
        'mistral_key': os.getenv('MISTRAL_KEY'),
        'gemini_key': os.getenv('GEMINI_KEY'),
        'endpoint_origin': os.getenv('ENDPOINT_ORIGIN')
    }
# Model catalogs. get_completion() routes a model to the first catalog that
# contains it, in this order: mistral, pollinations, alternate, claude_3, gemini;
# anything else goes to the default SECRET_API_ENDPOINT.

# Served directly by api.mistral.ai (requires MISTRAL_KEY).
mistral_models = {
    "mistral-large-latest",
    "pixtral-large-latest",
    "mistral-moderation-latest",
    "ministral-3b-latest",
    "ministral-8b-latest",
    "open-mistral-nemo",
    "mistral-small-latest",
    "mistral-saba-latest",
    "codestral-latest"
}
# Served by text.pollinations.ai (SECRET_API_ENDPOINT_4).
pollinations_models = {
    "openai",
    "openai-large",
    "openai-fast",
    "openai-xlarge",
    "openai-reasoning",
    "qwen-coder",
    "llama",
    "mistral",
    "searchgpt",
    "deepseek",
    "claude-hybridspace",
    "deepseek-r1",
    "deepseek-reasoner",
    "llamalight",
    "gemini",
    "gemini-thinking",
    "hormoz",
    "phi",
    "phi-mini",
    "openai-audio",
    "llama-scaleway"
}
# Served by SECRET_API_ENDPOINT_2.
alternate_models = {
    "o1",
    "llama-4-scout",
    "o4-mini",
    "sonar",
    "sonar-pro",
    "sonar-reasoning",
    "sonar-reasoning-pro",
    "grok-3",
    "grok-3-fast",
    "r1-1776",
    "o3"
}
# Served by SECRET_API_ENDPOINT_5.
claude_3_models = {
    "claude-3-7-sonnet",
    "claude-3-7-sonnet-thinking",
    "claude 3.5 haiku",
    "claude 3.5 sonnet",
    "o3-mini-medium",
    "o3-mini-high",
    # NOTE(review): "grok-3" is also in alternate_models, which is checked first,
    # so requests for it never reach this provider.
    "grok-3",
    "grok-3-thinking",
    "grok 2"
}
# Served by SECRET_API_ENDPOINT_6 (requires GEMINI_KEY).
gemini_models = {
    "gemini-1.5-pro",
    "gemini-1.5-flash",
    "gemini-2.0-flash-lite-preview",
    "gemini-2.0-flash",
    "gemini-2.0-flash-thinking",  # aka Reasoning
    "gemini-2.0-flash-preview-image-generation",
    "gemini-2.5-flash",
    "gemini-2.5-pro-exp",
    "gemini-exp-1206"
}
# Models accepted by /images/generations.
supported_image_models = {
    "Flux Pro Ultra",
    "grok-2-aurora",
    "Flux Pro",
    "Flux Pro Ultra Raw",
    "Flux Dev",
    "Flux Schnell",
    "stable-diffusion-3-large-turbo",
    "Flux Realism",
    "stable-diffusion-ultra",
    "dall-e-3",
    "sdxl-lightning-4step"
}
class Payload(BaseModel):
    """Request body for the chat-completion endpoints."""
    model: str  # model id; get_completion uses it to pick the upstream provider
    messages: list  # chat messages, forwarded to the upstream as part of payload.dict()
    stream: bool = False  # True: the upstream reply is relayed as text/event-stream
class ImageGenerationPayload(BaseModel):
    """Request body for /images/generations; every field is forwarded to the NEW_IMG service."""
    model: str  # must be one of supported_image_models
    prompt: str  # text description of the image
    size: int  # forwarded as-is; its meaning is defined by the image service — TODO confirm
    number: int  # presumably the number of images to generate — confirm against the NEW_IMG service
# Maintenance switch: when False, the completion and image endpoints answer 503.
server_status = True
# Allow-list of model ids accepted by /chat/completions, filled by the startup
# hook. While it is empty, model validation is skipped.
available_model_ids: List[str] = []
@lru_cache(maxsize=1)
def get_async_client():
    """Return the shared ``httpx.AsyncClient``.

    ``lru_cache(maxsize=1)`` on a zero-argument function makes this a lazily
    created, per-process singleton: the first call builds the client and later
    calls return the same object.
    """
    pool_limits = httpx.Limits(max_connections=200, max_keepalive_connections=50)
    return httpx.AsyncClient(limits=pool_limits, timeout=60.0)
# Reusable cloudscraper sessions, filled lazily by get_scraper() or at startup.
scraper_pool = []
MAX_SCRAPERS = 20
# Round-robin cursor. next() on an itertools.count is effectively atomic in
# CPython, so concurrent callers get distinct positions.
_scraper_cursor = itertools.count()


def get_scraper():
    """Return a cloudscraper session from the shared pool, handed out round-robin.

    Fills the pool with MAX_SCRAPERS sessions on first use if nothing has filled
    it yet. The index is taken modulo the pool's actual length, so a pool of any
    size is safe. A time-based pick would hand the same session to every caller
    within one millisecond, and could index past a smaller pool.
    """
    if not scraper_pool:
        scraper_pool.extend(cloudscraper.create_scraper() for _ in range(MAX_SCRAPERS))
    return scraper_pool[next(_scraper_cursor) % len(scraper_pool)]
async def verify_api_key(
    request: Request,
    api_key: str = Security(api_key_header)
) -> bool:
    """FastAPI dependency that authorizes a request.

    Requests whose Referer starts with one of the bundled playground URLs are
    accepted without a key. Otherwise the ``Authorization`` header must carry
    one of the configured keys, optionally prefixed with ``"Bearer "``.

    Returns:
        True when the request is authorized.

    Raises:
        HTTPException: 403 when no (or an empty) key is supplied, when the server
            has no keys configured, or when the key is not a configured key.
    """
    referer = request.headers.get("referer", "")
    # SECURITY NOTE(review): Referer is client-controlled, so any caller can skip
    # authentication by sending a playground Referer. This is kept because the
    # playground pages rely on it; a server-issued session token would be safer.
    if referer.startswith(("https://parthsadaria-lokiai.hf.space/playground",
                           "https://parthsadaria-lokiai.hf.space/image-playground")):
        return True
    if not api_key:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="No API key provided"
        )
    if api_key.startswith('Bearer '):
        api_key = api_key[7:]
    api_key = api_key.strip()
    if not api_key:
        # e.g. "Authorization: Bearer " — must never match a blank configured key.
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="No API key provided"
        )
    valid_api_keys = [
        key.strip() for key in get_env_vars().get('api_keys', []) if key and key.strip()
    ]
    if not valid_api_keys:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="API keys not configured on server"
        )
    candidate = api_key.encode("utf-8")
    # compare_digest takes time independent of where the values differ, so
    # response timing does not leak how much of a key prefix was guessed.
    if not any(secrets.compare_digest(candidate, key.encode("utf-8")) for key in valid_api_keys):
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="Invalid API key"
        )
    return True
@lru_cache(maxsize=1)
def load_models_data():
    """Load and cache the parsed contents of ``models.json`` next to this module.

    The file is decoded as UTF-8 explicitly, rather than with the
    locale-dependent default. Any read or parse problem (missing file,
    permission error, bad encoding, invalid JSON) is printed, and ``[]`` is
    returned.

    Note: the result, including ``[]`` after a failure, is cached until
    ``load_models_data.cache_clear()`` is called.
    """
    file_path = Path(__file__).parent / 'models.json'
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSON and Unicode decode errors
        print(f"Error loading models.json: {str(e)}")
        return []
async def get_models():
    """Return the model catalogue loaded from models.json.

    Raises:
        HTTPException: 500 when the catalogue is empty or could not be loaded.
    """
    catalogue = load_models_data()
    if catalogue:
        return catalogue
    raise HTTPException(status_code=500, detail="Error loading available models")
# Strong references to in-flight search tasks. The event loop keeps only weak
# references to tasks, so an unreferenced task may be garbage-collected mid-run.
_search_tasks = set()


async def generate_search_async(query: str, systemprompt: Optional[str] = None, stream: bool = True):
    """Start a background SearchGPT request and return a queue of its results.

    Args:
        query: the user's search question.
        systemprompt: optional system prompt (default "Be Helpful and Friendly").
        stream: accepted for compatibility but unused; the upstream request is
            always streamed.

    Returns:
        An ``asyncio.Queue`` that receives, in order:

        - ``{"data": <SSE line>, "text": <content>}`` per non-empty content chunk,
        - ``{"error": <message>}`` if anything fails,
        - ``None`` as the end-of-results sentinel, on every path.
    """
    queue = asyncio.Queue()

    def _chunk_from_line(line):
        """Turn one upstream SSE line into a queue item, or None to skip it."""
        if not line.startswith("data: "):
            return None
        try:
            json_data = json.loads(line[6:])
        except json.JSONDecodeError:
            return None
        if not isinstance(json_data, dict):
            return None
        # Fields may be absent or explicitly null, so fall back at each level.
        choices = json_data.get("choices") or [{}]
        first = choices[0] if isinstance(choices[0], dict) else {}
        delta = first.get("delta") or {}
        content = delta.get("content") or ""
        if not content.strip():
            return None
        cleaned_response = {
            "created": json_data.get("created"),
            "id": json_data.get("id"),
            "model": "searchgpt",
            "object": "chat.completion",
            "choices": [
                {
                    "message": {
                        "content": content
                    }
                }
            ]
        }
        return {"data": f"data: {json.dumps(cleaned_response)}\n\n", "text": content}

    async def _fetch_search_data():
        try:
            secret_api_endpoint_3 = get_env_vars()['secret_api_endpoint_3']
            if not secret_api_endpoint_3:
                await queue.put({"error": "Search API endpoint not configured"})
                return
            headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
            message_history = [
                {"content": systemprompt or "Be Helpful and Friendly", "role": "system"},
                {"role": "user", "content": query},
            ]
            payload = {
                "is_vscode_extension": True,
                "message_history": message_history,
                "requested_model": "searchgpt",
                "user_input": query,
            }
            async with httpx.AsyncClient(timeout=30.0) as client:
                async with client.stream("POST", secret_api_endpoint_3, json=payload, headers=headers) as response:
                    if response.status_code != 200:
                        await queue.put({"error": f"Search API returned status code {response.status_code}"})
                        return
                    async for line in response.aiter_lines():
                        item = _chunk_from_line(line)
                        if item is not None:
                            await queue.put(item)
        except Exception as e:
            await queue.put({"error": str(e)})
        finally:
            # Always terminate the stream so consumers never wait forever.
            await queue.put(None)

    task = asyncio.create_task(_fetch_search_data())
    _search_tasks.add(task)
    task.add_done_callback(_search_tasks.discard)
    return queue
@lru_cache(maxsize=10)
def read_html_file(file_path):
    """Return the UTF-8 text of ``file_path``, or None if the file does not exist.

    The encoding is explicit, so non-ASCII pages do not depend on the server
    locale. Results are cached per path (up to 10 paths), including the None
    for a missing file, so edits are not picked up until the cache is cleared.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as file:
            return file.read()
    except FileNotFoundError:
        return None
@app.get("/favicon.ico")
async def favicon():
    """Serve favicon.ico from the directory that contains this module."""
    return FileResponse(
        Path(__file__).parent.joinpath("favicon.ico"),
        media_type="image/x-icon",
    )
@app.get("/banner.jpg")
async def banner():
    """Serve banner.jpg from the directory that contains this module."""
    image_file = Path(__file__).parent.joinpath("banner.jpg")
    return FileResponse(image_file, media_type="image/jpeg")
@app.get("/ping")
async def ping():
    """Liveness probe; always answers with a fixed "pong" payload."""
    reply = {"message": "pong"}
    reply["response_time"] = "0.000000 seconds"
    return reply
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve index.html, or a small 404 page when it is missing."""
    page = read_html_file("index.html")
    if page is not None:
        return HTMLResponse(content=page)
    return HTMLResponse(content="<h1>File not found</h1>", status_code=404)
@app.get("/script.js", response_class=PlainTextResponse)
async def script():
    """Serve script.js with a JavaScript content type.

    Serving it as text/html makes browsers that enforce strict MIME checking
    (e.g. with ``X-Content-Type-Options: nosniff``) refuse to execute it.
    A small HTML 404 page is returned when the file is missing.
    """
    source = read_html_file("script.js")
    if source is None:
        return HTMLResponse(content="<h1>File not found</h1>", status_code=404)
    return PlainTextResponse(content=source, media_type="text/javascript")
@app.get("/style.css", response_class=PlainTextResponse)
async def style():
    """Serve style.css with a ``text/css`` content type.

    Browsers refuse to apply a stylesheet delivered as text/html in standards
    mode, so the previous HTMLResponse broke the page styling. A small HTML 404
    page is returned when the file is missing.
    """
    stylesheet = read_html_file("style.css")
    if stylesheet is None:
        return HTMLResponse(content="<h1>File not found</h1>", status_code=404)
    return PlainTextResponse(content=stylesheet, media_type="text/css")
# The page is generated by calling this deployment's own completion endpoint.
_DYNAMO_COMPLETIONS_URL = "https://parthsadaria-lokiai.hf.space/chat/completions"
# First fenced block in the model reply. An optional language tag right after
# the opening fence (e.g. ```html) is skipped instead of kept in the output.
_FENCED_BLOCK_RE = re.compile(r"```[ \t]*[\w+-]*[ \t]*\r?\n?(.*?)```", re.DOTALL)


@app.get("/dynamo", response_class=HTMLResponse)
async def dynamic_ai_page(request: Request):
    """Ask mistral-small-latest for a personalised HTML page and serve it.

    The completion call is made with ``httpx.AsyncClient`` and awaited. A
    blocking ``requests.post`` back to this same server from inside the event
    loop can never be answered by a single worker, so it hangs.

    Raises:
        HTTPException: 502 when the completion call fails or the reply has no
            fenced HTML block.
    """
    # SECURITY NOTE(review): the User-Agent is client-controlled and is pasted
    # into the prompt, and the model's HTML is served verbatim on this origin.
    user_agent = request.headers.get('user-agent', 'Unknown User')
    client_ip = request.client.host if request.client else "unknown"
    location = f"IP: {client_ip}"
    prompt = f"""
Generate a dynamic HTML page for a user with the following details: with name "LOKI.AI"
- User-Agent: {user_agent}
- Location: {location}
- Style: Cyberpunk, minimalist, or retro
Make sure the HTML is clean and includes a heading, also have cool animations a motivational message, and a cool background.
Wrap the generated HTML in triple backticks (```).
"""
    payload = {
        "model": "mistral-small-latest",
        "messages": [{"role": "user", "content": prompt}]
    }
    # NOTE(review): "playground" must be one of the configured API_KEYS for this call to be authorized.
    headers = {
        "Authorization": "Bearer playground"
    }
    try:
        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(_DYNAMO_COMPLETIONS_URL, json=payload, headers=headers)
        response.raise_for_status()
        reply = response.json()['choices'][0]['message']['content']
    except (httpx.HTTPError, ValueError, KeyError, IndexError, TypeError) as exc:
        raise HTTPException(status_code=502, detail=f"Page generation failed: {exc}")
    match = _FENCED_BLOCK_RE.search(reply or "")
    html_content = match.group(1).strip() if match else ""
    if not html_content:
        raise HTTPException(status_code=502, detail="Model reply did not contain an HTML block")
    return HTMLResponse(content=html_content)
@app.get("/scraper", response_class=PlainTextResponse)
def scrape_site(url: str = Query(..., description="URL to scrape")):
    """Fetch ``url`` through cloudscraper and return the body as plain text.

    This is a plain ``def``, so FastAPI runs it in its threadpool and the
    blocking fetch does not stall the event loop. The fetch is bounded by a
    30 s timeout; without one a slow or hanging site could pin a worker thread
    forever.

    Returns:
        The page text for a 200 response with a non-blank body, otherwise the
        literal string "Cloudscraper failed." (the reason is printed).
    """
    # SECURITY NOTE(review): unauthenticated and fetches any caller-supplied URL
    # (SSRF risk: internal hosts, cloud metadata endpoints). Consider requiring
    # verify_api_key and an allow-list of hosts.
    try:
        scraper = cloudscraper.create_scraper()
        response = scraper.get(url, timeout=30)
        if response.status_code == 200 and len(response.text.strip()) > 0:
            return response.text
        print(f"Cloudscraper got status {response.status_code} or an empty body for {url}")
    except Exception as e:
        print(f"Cloudscraper failed: {e}")
    return "Cloudscraper failed."
@app.get("/playground", response_class=HTMLResponse)
async def playground():
    """Serve the chat playground page, or a 404 page when playground.html is missing."""
    page = read_html_file("playground.html")
    if page is not None:
        return HTMLResponse(content=page)
    return HTMLResponse(content="<h1>playground.html not found</h1>", status_code=404)
@app.get("/image-playground", response_class=HTMLResponse)
async def image_playground():
    """Serve the image playground page, or a 404 page when image-playground.html is missing."""
    page = read_html_file("image-playground.html")
    if page is not None:
        return HTMLResponse(content=page)
    return HTMLResponse(content="<h1>image-playground.html not found</h1>", status_code=404)
# Raw-content base URL of the Vetra repository served at /vetra. It must be a
# plain URL: the previous markdown-link form ("[url](url)") produced invalid
# request URLs for every file.
GITHUB_BASE = "https://raw.githubusercontent.com/Parthsadaria/Vetra/main"
# Files of the Vetra front end, keyed by role.
FILES = {
    "html": "index.html",
    "css": "style.css",
    "js": "script.js"
}
async def get_github_file(filename: str) -> Optional[str]:
    """Fetch ``filename`` from the Vetra repository on raw.githubusercontent.com.

    Returns:
        The file text, or None when GitHub answers with a non-200 status or the
        request fails (network error, timeout). Callers already treat None as
        "not found", so failures degrade gracefully instead of becoming a 500.
    """
    url = f"{GITHUB_BASE}/{filename}"
    try:
        async with httpx.AsyncClient() as client:
            res = await client.get(url)
    except httpx.HTTPError as exc:
        print(f"Failed to fetch {url}: {exc}")
        return None
    return res.text if res.status_code == 200 else None
@app.get("/vetra", response_class=HTMLResponse)
async def serve_vetra():
    """Serve the Vetra page assembled from its GitHub-hosted HTML, CSS and JS.

    The CSS is inlined before ``</head>`` and the JS before ``</body>``; a
    placeholder comment is used when either is unavailable. Returns a 404 page
    when the HTML itself cannot be fetched.
    """
    # The three downloads are independent, so fetch them concurrently.
    html, css, js = await asyncio.gather(
        get_github_file(FILES["html"]),
        get_github_file(FILES["css"]),
        get_github_file(FILES["js"]),
    )
    if not html:
        return HTMLResponse(content="<h1>index.html not found on GitHub</h1>", status_code=404)
    final_html = html.replace(
        "</head>",
        f"<style>{css or '/* CSS not found */'}</style></head>"
    ).replace(
        "</body>",
        f"<script>{js or '// JS not found'}</script></body>"
    )
    return HTMLResponse(content=final_html)
@app.get("/api/v1/models")
@app.get("/models")
async def return_models():
    """List the available models (the contents of models.json)."""
    models = await get_models()
    return models
@app.get("/searchgpt")
async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optional[str] = None):
    """Answer ``q`` using the SearchGPT upstream.

    Args:
        q: the search query (must be non-empty).
        stream: when true, relay results as server-sent events; otherwise
            return ``{"response": <full text>}``.
        systemprompt: optional system prompt forwarded upstream.

    Raises:
        HTTPException: 400 for an empty query; 500 (non-streaming mode only)
            when the upstream reports an error.
    """
    if not q:
        raise HTTPException(status_code=400, detail="Query parameter 'q' is required")
    usage_tracker.record_request(endpoint="/searchgpt")
    queue = await generate_search_async(q, systemprompt=systemprompt, stream=True)

    if stream:
        async def stream_generator():
            while True:
                item = await queue.get()
                if item is None:
                    break
                if "error" in item:
                    # The streaming status is already committed, so report the error in-band.
                    yield f"data: {json.dumps({'error': item['error']})}\n\n"
                    break
                if "data" in item:
                    yield item["data"]

        return StreamingResponse(
            stream_generator(),
            media_type="text/event-stream"
        )

    parts = []
    while True:
        item = await queue.get()
        if item is None:
            break
        if "error" in item:
            raise HTTPException(status_code=500, detail=item["error"])
        parts.append(item.get("text", ""))
    return JSONResponse(content={"response": "".join(parts)})
# Sent as the Origin and Referer headers to the default upstream
# (SECRET_API_ENDPOINT); None when HEADER_URL is unset.
header_url = os.getenv('HEADER_URL')
# Friendly explanations for common upstream HTTP error statuses; any other
# status is reported as "Error code: <status>".
_UPSTREAM_ERROR_MESSAGES = {
    422: "Unprocessable entity. Check your payload.",
    400: "Bad request. Verify input data.",
    403: "Forbidden. You do not have access to this resource.",
    404: "The requested resource was not found.",
}


def _resolve_upstream(model: str, env_vars: Dict[str, Any]):
    """Choose the upstream base URL, request path and extra headers for ``model``.

    Catalogs are checked in a fixed order (mistral, pollinations, alternate,
    claude_3, gemini, then the default endpoint). A model listed in more than
    one catalog is routed to the first match.

    Returns:
        ``(endpoint, path, headers)``.

    Raises:
        HTTPException: 500 when the selected provider's endpoint or key is not configured.
    """
    path = "/v1/chat/completions"
    if model in mistral_models:
        if not env_vars['mistral_key']:
            raise HTTPException(status_code=500, detail="MISTRAL_KEY not configured")
        endpoint = env_vars['mistral_api']
        headers = {"Authorization": f"Bearer {env_vars['mistral_key']}"}
    elif model in pollinations_models:
        endpoint, headers = env_vars['secret_api_endpoint_4'], {}
    elif model in alternate_models:
        endpoint, headers = env_vars['secret_api_endpoint_2'], {}
    elif model in claude_3_models:
        endpoint, headers = env_vars['secret_api_endpoint_5'], {}
    elif model in gemini_models:
        endpoint = env_vars['secret_api_endpoint_6']
        if not endpoint:
            raise HTTPException(status_code=500, detail="Gemini API endpoint not configured")
        if not env_vars['gemini_key']:
            raise HTTPException(status_code=500, detail="GEMINI_KEY not configured")
        headers = {"Authorization": f"Bearer {env_vars['gemini_key']}"}
        path = "/chat/completions"  # the Gemini endpoint uses /chat/completions (no /v1)
    else:
        endpoint = env_vars['secret_api_endpoint']
        headers = {"Origin": header_url, "Priority": "u=1, i", "Referer": header_url}
    if not endpoint:
        # Fail clearly instead of requesting "None/v1/chat/completions".
        raise HTTPException(status_code=500, detail=f"Upstream endpoint for model '{model}' is not configured")
    # httpx rejects None header values, which HEADER_URL yields when unset.
    headers = {name: value for name, value in headers.items() if value is not None}
    return endpoint, path, headers


async def _open_upstream_stream(url: str, body: Dict[str, Any], headers: Dict[str, str]):
    """POST ``body`` to ``url`` and return ``(client, response)`` with the body still unread.

    The upstream status is checked here, before any bytes are relayed. An
    upstream error therefore becomes a real HTTP error response instead of a
    200 event stream that breaks off. On success the caller owns both objects
    and must close them (``_relay_lines`` does).

    Raises:
        HTTPException: 504 on timeout, 502 when the upstream is unreachable,
            500 for other request errors, or the upstream's own status (>= 400).
    """
    client = httpx.AsyncClient(timeout=60.0)
    try:
        response = await client.send(
            client.build_request("POST", url, json=body, headers=headers),
            stream=True,
        )
    except httpx.TimeoutException:
        await client.aclose()
        raise HTTPException(status_code=504, detail="Request timed out")
    except httpx.RequestError as e:
        await client.aclose()
        raise HTTPException(status_code=502, detail=f"Failed to connect to upstream API: {str(e)}")
    except Exception as e:
        await client.aclose()
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
    if response.status_code >= 400:
        await response.aclose()
        await client.aclose()
        detail = _UPSTREAM_ERROR_MESSAGES.get(response.status_code, f"Error code: {response.status_code}")
        raise HTTPException(status_code=response.status_code, detail=detail)
    return client, response


async def _relay_lines(client, response):
    """Yield each non-empty upstream line plus a newline, then close ``response`` and ``client``."""
    try:
        async for line in response.aiter_lines():
            if line:
                yield line + "\n"
    except httpx.TimeoutException:
        raise HTTPException(status_code=504, detail="Request timed out")
    except httpx.RequestError as e:
        raise HTTPException(status_code=502, detail=f"Failed to connect to upstream API: {str(e)}")
    finally:
        # Also runs if the generator is closed early (e.g. the client disconnects).
        await response.aclose()
        await client.aclose()


@app.post("/chat/completions")
@app.post("/api/v1/chat/completions")
async def get_completion(payload: Payload, request: Request, authenticated: bool = Depends(verify_api_key)):
    """OpenAI-compatible chat completion proxy.

    Validates the model against ``available_model_ids`` (when that list is
    populated) and records usage. It then forwards the request to the provider
    chosen by ``_resolve_upstream`` and relays the reply: as
    ``text/event-stream`` when ``payload.stream`` is true, otherwise as the
    upstream's JSON body.

    Raises:
        HTTPException: 400 for an unknown model; 500/502/504 or the upstream's
            status for configuration or upstream failures.
    """
    if not server_status:
        return JSONResponse(
            status_code=503,
            content={"message": "Server is under maintenance. Please try again later."}
        )
    model_to_use = payload.model or "gpt-4o-mini"
    if available_model_ids and model_to_use not in available_model_ids:
        raise HTTPException(
            status_code=400,
            detail=f"Model '{model_to_use}' is not available. Check /models for the available model list."
        )
    asyncio.create_task(log_request(request, model_to_use))
    usage_tracker.record_request(model=model_to_use, endpoint="/chat/completions")
    payload_dict = payload.dict()
    payload_dict["model"] = model_to_use
    stream_enabled = payload_dict.get("stream", True)
    endpoint, target_url_path, custom_headers = _resolve_upstream(model_to_use, get_env_vars())
    print(f"Using endpoint: {endpoint} with path: {target_url_path} for model: {model_to_use}")
    client, upstream = await _open_upstream_stream(f"{endpoint}{target_url_path}", payload_dict, custom_headers)
    if stream_enabled:
        return StreamingResponse(
            _relay_lines(client, upstream),
            media_type="text/event-stream",
            headers={
                "Content-Type": "text/event-stream",
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no"
            }
        )
    body = "".join([chunk async for chunk in _relay_lines(client, upstream)])
    try:
        return JSONResponse(content=json.loads(body))
    except json.JSONDecodeError:
        raise HTTPException(status_code=502, detail="Upstream returned a non-JSON response")
@app.post("/images/generations")
async def create_image(payload: ImageGenerationPayload, authenticated: bool = Depends(verify_api_key)):
    """Forward an image-generation request to the service at ``NEW_IMG``.

    Returns:
        The service's JSON response, or a 503 JSON message during maintenance.

    Raises:
        HTTPException: 400 for an unsupported model. 500 when NEW_IMG is unset
            or on unexpected errors. 504 on timeout, 502 when the service is
            unreachable or returns invalid JSON. Otherwise the service's own
            error status, with its ``detail`` when it provides one.
    """
    if not server_status:
        return JSONResponse(
            status_code=503,
            content={"message": "Server is under maintenance. Please try again later."}
        )
    if payload.model not in supported_image_models:
        raise HTTPException(
            status_code=400,
            detail=f"Model '{payload.model}' is not supported for image generation. Supported models are: {supported_image_models}"
        )
    usage_tracker.record_request(model=payload.model, endpoint="/images/generations")
    api_payload = {
        "model": payload.model,
        "prompt": payload.prompt,
        "size": payload.size,
        "number": payload.number
    }
    target_api_url = os.getenv('NEW_IMG')
    if not target_api_url:
        raise HTTPException(status_code=500, detail="Image generation endpoint (NEW_IMG) not configured")
    try:
        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(target_api_url, json=api_payload)
    except httpx.TimeoutException:
        raise HTTPException(status_code=504, detail="Image generation request timed out.")
    except httpx.RequestError as e:
        raise HTTPException(status_code=502, detail=f"Error connecting to image generation service: {e}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during image generation: {e}")
    # Handled outside the try above so the intended status is not swallowed by
    # the generic handler and re-reported as a 500.
    if response.status_code != 200:
        fallback = f"Image generation failed with status code: {response.status_code}"
        try:
            error_body = response.json()
        except ValueError:  # error pages are often not JSON
            error_body = None
        error_detail = error_body.get("detail", fallback) if isinstance(error_body, dict) else fallback
        raise HTTPException(status_code=response.status_code, detail=error_detail)
    try:
        return JSONResponse(content=response.json())
    except ValueError:
        raise HTTPException(status_code=502, detail="Image generation service returned invalid JSON.")
# India Standard Time (UTC+05:30), used for the log timestamps.
_IST = datetime.timezone(datetime.timedelta(hours=5, minutes=30))


async def log_request(request, model):
    """Print a one-line log record for an incoming completion request.

    The client IP itself is not printed. A 4-digit bucket derived from its
    SHA-256 digest is printed instead, which stays stable across restarts
    (builtin ``hash()`` of a str is salted per process). It is a coarse bucket,
    not a privacy guarantee.
    """
    current_time = datetime.datetime.now(_IST).strftime("%Y-%m-%d %I:%M:%S %p")
    host = request.client.host if request.client else "unknown"
    ip_hash = int(hashlib.sha256(host.encode("utf-8")).hexdigest(), 16) % 10000
    print(f"Time: {current_time}, IP Hash: {ip_hash}, Model: {model}")
# Seconds a usage summary may be served from cache before it is recomputed.
USAGE_CACHE_TTL_SECONDS = 10.0
# days -> (time.monotonic() timestamp, summary)
_usage_summary_cache: Dict[int, Any] = {}


def get_usage_summary(days=7):
    """Return ``usage_tracker.get_usage_summary(days)``, cached for a short TTL.

    A summary is reused for ``USAGE_CACHE_TTL_SECONDS``. A cache that never
    expires (plain ``lru_cache``) would keep serving the first summary computed
    for each ``days`` value for the life of the process.
    """
    now = time.monotonic()
    cached = _usage_summary_cache.get(days)
    if cached is not None and now - cached[0] < USAGE_CACHE_TTL_SECONDS:
        return cached[1]
    summary = usage_tracker.get_usage_summary(days)
    if len(_usage_summary_cache) >= 10:
        # ``days`` comes from the query string; keep the cache bounded.
        _usage_summary_cache.clear()
    _usage_summary_cache[days] = (now, summary)
    return summary
@app.get("/usage")
async def get_usage(days: int = 7):
    """Return the usage summary for a ``days`` window, as computed by get_usage_summary()."""
    summary = get_usage_summary(days)
    return summary
def generate_usage_html(usage_data):
    """Render the usage-statistics dashboard as a standalone HTML page.

    Args:
        usage_data: summary dict read via the keys ``total_requests``,
            ``models`` and ``api_endpoints`` (name -> dict with
            ``total_requests``, ``first_used``, ``last_used``), and
            ``recent_daily_usage`` (date -> {entity: count}).

    Returns:
        The complete HTML document. Every value taken from ``usage_data`` is
        HTML-escaped. Model names may come from client requests, so rendering
        them raw would allow stored XSS on this page.
    """
    from html import escape

    def cell(value):
        """Escape an arbitrary value for safe inclusion in HTML."""
        return escape(str(value), quote=True)

    model_usage_rows = "\n".join(
        f"""
        <tr>
            <td>{cell(model)}</td>
            <td>{cell(model_data['total_requests'])}</td>
            <td>{cell(model_data['first_used'])}</td>
            <td>{cell(model_data['last_used'])}</td>
        </tr>
        """
        for model, model_data in usage_data['models'].items()
    )
    api_usage_rows = "\n".join(
        f"""
        <tr>
            <td>{cell(endpoint)}</td>
            <td>{cell(endpoint_data['total_requests'])}</td>
            <td>{cell(endpoint_data['first_used'])}</td>
            <td>{cell(endpoint_data['last_used'])}</td>
        </tr>
        """
        for endpoint, endpoint_data in usage_data['api_endpoints'].items()
    )
    daily_usage_rows = "\n".join(
        "\n".join(
            f"""
            <tr>
                <td>{cell(date)}</td>
                <td>{cell(entity)}</td>
                <td>{cell(requests)}</td>
            </tr>
            """
            for entity, requests in date_data.items()
        )
        for date, date_data in usage_data['recent_daily_usage'].items()
    )
    html_content = f"""
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <title>Lokiai AI - Usage Statistics</title>
        <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600&display=swap" rel="stylesheet">
        <style>
            :root {{
                --bg-dark: #0f1011;
                --bg-darker: #070708;
                --text-primary: #e6e6e6;
                --text-secondary: #8c8c8c;
                --border-color: #2c2c2c;
                --accent-color: #3a6ee0;
                --accent-hover: #4a7ef0;
            }}
            body {{
                font-family: 'Inter', sans-serif;
                background-color: var(--bg-dark);
                color: var(--text-primary);
                max-width: 1200px;
                margin: 0 auto;
                padding: 40px 20px;
                line-height: 1.6;
            }}
            .logo {{
                display: flex;
                align-items: center;
                justify-content: center;
                margin-bottom: 30px;
            }}
            .logo h1 {{
                font-weight: 600;
                font-size: 2.5em;
                color: var(--text-primary);
                margin-left: 15px;
            }}
            .logo img {{
                width: 60px;
                height: 60px;
                border-radius: 10px;
            }}
            .container {{
                background-color: var(--bg-darker);
                border-radius: 12px;
                padding: 30px;
                box-shadow: 0 15px 40px rgba(0,0,0,0.3);
                border: 1px solid var(--border-color);
            }}
            h2, h3 {{
                color: var(--text-primary);
                border-bottom: 2px solid var(--border-color);
                padding-bottom: 10px;
                font-weight: 500;
            }}
            .total-requests {{
                background-color: var(--accent-color);
                color: white;
                text-align: center;
                padding: 15px;
                border-radius: 8px;
                margin-bottom: 30px;
                font-weight: 600;
                letter-spacing: -0.5px;
            }}
            table {{
                width: 100%;
                border-collapse: separate;
                border-spacing: 0;
                margin-bottom: 30px;
                background-color: var(--bg-dark);
                border-radius: 8px;
                overflow: hidden;
            }}
            th, td {{
                border: 1px solid var(--border-color);
                padding: 12px;
                text-align: left;
                transition: background-color 0.3s ease;
            }}
            th {{
                background-color: #1e1e1e;
                color: var(--text-primary);
                font-weight: 600;
                text-transform: uppercase;
                font-size: 0.9em;
            }}
            tr:nth-child(even) {{
                background-color: rgba(255,255,255,0.05);
            }}
            tr:hover {{
                background-color: rgba(62,100,255,0.1);
            }}
            @media (max-width: 768px) {{
                .container {{
                    padding: 15px;
                }}
                table {{
                    font-size: 0.9em;
                }}
            }}
        </style>
    </head>
    <body>
        <div class="container">
            <div class="logo">
                <img src="data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iMjAwIiBoZWlnaHQ9IjIwMCIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj48cGF0aCBkPSJNMTAwIDM1TDUwIDkwaDEwMHoiIGZpbGw9IiMzYTZlZTAiLz48Y2lyY2xlIGN4PSIxMDAiIGN5PSIxNDAiIHI9IjMwIiBmaWxsPSIjM2E2ZWUwIi8+PC9zdmc+" alt="Lokiai AI Logo">
                <h1>Lokiai AI</h1>
            </div>
            <div class="total-requests">
                Total API Requests: {cell(usage_data['total_requests'])}
            </div>
            <h2>Model Usage</h2>
            <table>
                <tr>
                    <th>Model</th>
                    <th>Total Requests</th>
                    <th>First Used</th>
                    <th>Last Used</th>
                </tr>
                {model_usage_rows}
            </table>
            <h2>API Endpoint Usage</h2>
            <table>
                <tr>
                    <th>Endpoint</th>
                    <th>Total Requests</th>
                    <th>First Used</th>
                    <th>Last Used</th>
                </tr>
                {api_usage_rows}
            </table>
            <h2>Daily Usage (Last 7 Days)</h2>
            <table>
                <tr>
                    <th>Date</th>
                    <th>Entity</th>
                    <th>Requests</th>
                </tr>
                {daily_usage_rows}
            </table>
        </div>
    </body>
    </html>
    """
    return html_content
def get_usage_page_html():
    """Render the usage dashboard from the current 7-day usage summary.

    Not memoised. An ``lru_cache(maxsize=1)`` on this zero-argument function
    froze the page at the first render for the whole process lifetime. Any
    caching belongs in get_usage_summary().
    """
    usage_data = get_usage_summary()
    return generate_usage_html(usage_data)
@app.get("/usage/page", response_class=HTMLResponse)
async def usage_page():
    """Serve the HTML usage dashboard produced by get_usage_page_html()."""
    return HTMLResponse(content=get_usage_page_html())
@app.get("/meme")
async def get_meme():
    """Proxy a random meme image from meme-api.com.

    Returns:
        A StreamingResponse with the image bytes, the upstream content type
        (default image/png) and a one-hour Cache-Control.

    Raises:
        HTTPException: 404 when the meme API response has no ``url``; 500 when
            either request fails or returns a non-success status.
    """
    try:
        client = get_async_client()
        response = await client.get("https://meme-api.com/gimme")
        response.raise_for_status()
        meme_url = response.json().get("url")
        if not meme_url:
            raise HTTPException(status_code=404, detail="No meme found")
        image_response = await client.get(meme_url, follow_redirects=True)
        image_response.raise_for_status()
    except HTTPException:
        # Let the deliberate 404 above through instead of masking it as a 500.
        raise
    except Exception:
        raise HTTPException(status_code=500, detail="Failed to retrieve meme")

    async def stream_with_larger_chunks():
        # client.get() has already read the whole body; this only re-chunks it
        # into pieces of at least 64 KiB (plus a final remainder) for the response.
        chunks = []
        size = 0
        async for chunk in image_response.aiter_bytes(chunk_size=16384):
            chunks.append(chunk)
            size += len(chunk)
            if size >= 65536:
                yield b''.join(chunks)
                chunks = []
                size = 0
        if chunks:
            yield b''.join(chunks)

    return StreamingResponse(
        stream_with_larger_chunks(),
        media_type=image_response.headers.get("content-type", "image/png"),
        headers={'Cache-Control': 'max-age=3600'}
    )
def load_model_ids(json_file_path):
    """Return the ``id`` of every model object in a models JSON file.

    Args:
        json_file_path: path (str or Path) to a UTF-8 JSON file containing a
            list of model objects.

    Returns:
        The ids, in file order, of entries that are objects with an ``id`` key.
        Returns ``[]`` when the file cannot be read or parsed, or is not a list
        (the reason is printed).
    """
    try:
        with open(json_file_path, 'r', encoding='utf-8') as f:
            models_data = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSON and Unicode decode errors
        print(f"Error loading model IDs: {str(e)}")
        return []
    if not isinstance(models_data, list):
        print("Error loading model IDs: expected a JSON list of model objects")
        return []
    # Only dict entries count. For a string entry, `'id' in entry` is a substring
    # test, and `entry['id']` would then raise and discard the whole file.
    return [model['id'] for model in models_data if isinstance(model, dict) and 'id' in model]
@app.on_event("startup")
async def startup_event():
    """Initialise process-wide state when the server starts.

    Steps:

    - Build the model allow-list: ids from models.json plus every provider catalog.
    - Top up the cloudscraper pool to MAX_SCRAPERS.
    - Print a warning listing any missing environment variables.
    """
    global available_model_ids
    # Resolve models.json next to this module, as load_models_data() does, so
    # the allow-list does not depend on the process working directory.
    model_ids = load_model_ids(Path(__file__).parent / "models.json")
    print(f"Loaded {len(model_ids)} model IDs")
    for catalog in (pollinations_models, alternate_models, mistral_models, claude_3_models, gemini_models):
        model_ids.extend(catalog)
    available_model_ids = list(set(model_ids))
    print(f"Total available models: {len(available_model_ids)}")

    # get_scraper() may already have filled the pool; only add what is missing
    # instead of blindly appending another MAX_SCRAPERS sessions.
    while len(scraper_pool) < MAX_SCRAPERS:
        scraper_pool.append(cloudscraper.create_scraper())

    env_vars = get_env_vars()
    uses_mistral = any(model in mistral_models for model in available_model_ids)
    uses_gemini = any(model in gemini_models for model in available_model_ids)
    checks = (
        ('API_KEYS', not env_vars['api_keys'] or env_vars['api_keys'] == ['']),
        ('SECRET_API_ENDPOINT', not env_vars['secret_api_endpoint']),
        ('SECRET_API_ENDPOINT_2', not env_vars['secret_api_endpoint_2']),
        ('SECRET_API_ENDPOINT_3', not env_vars['secret_api_endpoint_3']),
        ('SECRET_API_ENDPOINT_4', not env_vars['secret_api_endpoint_4']),
        ('SECRET_API_ENDPOINT_5', not env_vars['secret_api_endpoint_5']),
        ('SECRET_API_ENDPOINT_6', not env_vars['secret_api_endpoint_6']),
        ('MISTRAL_API', not env_vars['mistral_api'] and uses_mistral),
        ('MISTRAL_KEY', not env_vars['mistral_key'] and uses_mistral),
        ('GEMINI_KEY', not env_vars['gemini_key'] and uses_gemini),
    )
    missing_vars = [name for name, is_missing in checks if is_missing]
    if missing_vars:
        print(f"WARNING: The following environment variables are missing: {', '.join(missing_vars)}")
        print("Some functionality may be limited.")
    print("Server started successfully!")
@app.on_event("shutdown")
async def shutdown_event():
    """Release shared resources and persist usage statistics on shutdown.

    Usage data is saved in a ``finally`` so a failure while closing the HTTP
    client or the thread pool cannot lose it.
    """
    try:
        await get_async_client().aclose()
        # Drop the cached (now closed) client so any later caller gets a fresh one.
        get_async_client.cache_clear()
        executor.shutdown(wait=False)
        scraper_pool.clear()
    finally:
        usage_tracker.save_data()
    print("Server shutdown complete!")
@app.get("/health")
async def health_check():
    """Report whether every required environment variable is configured.

    Always answers HTTP 200; the verdict is in the body. ``status`` is
    "healthy" or "unhealthy", and ``missing_env_vars`` lists the unset
    variables in a fixed order. The message emoji were previously stored as
    mojibake (UTF-8 bytes decoded with a legacy code page) and are restored here.
    """
    env_vars = get_env_vars()
    missing_critical_vars = []
    if not env_vars['api_keys'] or env_vars['api_keys'] == ['']:
        missing_critical_vars.append('API_KEYS')
    required = (
        ('SECRET_API_ENDPOINT', 'secret_api_endpoint'),
        ('SECRET_API_ENDPOINT_2', 'secret_api_endpoint_2'),
        ('SECRET_API_ENDPOINT_3', 'secret_api_endpoint_3'),
        ('SECRET_API_ENDPOINT_4', 'secret_api_endpoint_4'),
        ('SECRET_API_ENDPOINT_5', 'secret_api_endpoint_5'),
        ('SECRET_API_ENDPOINT_6', 'secret_api_endpoint_6'),
        ('MISTRAL_API', 'mistral_api'),
        ('MISTRAL_KEY', 'mistral_key'),
        ('GEMINI_KEY', 'gemini_key'),
    )
    missing_critical_vars.extend(name for name, key in required if not env_vars[key])
    healthy = not missing_critical_vars
    health_status = {
        "status": "healthy" if healthy else "unhealthy",
        "missing_env_vars": missing_critical_vars,
        "server_status": server_status,
        "message": "Everything's lit! \U0001F680" if healthy else "Uh oh, some env vars are missing. \U0001F62C"
    }
    return JSONResponse(content=health_status)
if __name__ == "__main__":
    # Running the module directly starts uvicorn on all interfaces, port 7860.
    from uvicorn import run as serve

    serve(app, port=7860, host="0.0.0.0")