ParthSadaria committed
Commit 3b8f2de · verified · 1 Parent(s): 0a0ab04

Update main.py

Files changed (1)
    main.py +81 -181
main.py CHANGED
@@ -2,7 +2,7 @@ import os
 import re
 from dotenv import load_dotenv
 from fastapi import FastAPI, HTTPException, Request, Depends, Security, Query
-from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse, PlainTextResponse
+from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse, PlainTextResponse
 from fastapi.security import APIKeyHeader
 from pydantic import BaseModel
 import httpx
@@ -22,25 +22,20 @@ from fastapi.middleware.gzip import GZipMiddleware
 from starlette.middleware.cors import CORSMiddleware
 import contextlib
 import requests
-# Enable uvloop for faster event loop
+
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

-# Thread pool for CPU-bound operations
-executor = ThreadPoolExecutor(max_workers=16) # Increased thread count for better parallelism
+executor = ThreadPoolExecutor(max_workers=16)

-# Load environment variables once at startup
 load_dotenv()

-# API key security scheme
 api_key_header = APIKeyHeader(name="Authorization", auto_error=False)

-# Initialize usage tracker
 from usage_tracker import UsageTracker
 usage_tracker = UsageTracker()

 app = FastAPI()

-# Add middleware for compression and CORS
 app.add_middleware(GZipMiddleware, minimum_size=1000)
 app.add_middleware(
     CORSMiddleware,
@@ -50,7 +45,6 @@ app.add_middleware(
     allow_headers=["*"],
 )

-# Environment variables (cached)
 @lru_cache(maxsize=1)
 def get_env_vars():
     return {
@@ -59,13 +53,14 @@ def get_env_vars():
         'secret_api_endpoint_2': os.getenv('SECRET_API_ENDPOINT_2'),
         'secret_api_endpoint_3': os.getenv('SECRET_API_ENDPOINT_3'),
         'secret_api_endpoint_4': "https://text.pollinations.ai/openai",
-        'secret_api_endpoint_5': os.getenv('SECRET_API_ENDPOINT_5'), # Added new endpoint
+        'secret_api_endpoint_5': os.getenv('SECRET_API_ENDPOINT_5'),
+        'secret_api_endpoint_6': os.getenv('SECRET_API_ENDPOINT_6'), # New endpoint for Gemini
         'mistral_api': "https://api.mistral.ai",
         'mistral_key': os.getenv('MISTRAL_KEY'),
+        'gemini_key': os.getenv('GEMINI_KEY'), # Gemini API Key
         'endpoint_origin': os.getenv('ENDPOINT_ORIGIN')
     }

-# Configuration for models - use sets for faster lookups
 mistral_models = {
     "mistral-large-latest",
     "pixtral-large-latest",
@@ -115,7 +110,7 @@ alternate_models = {
     "o3"
 }

-claude_3_models = { # Models for the new endpoint
+claude_3_models = {
     "claude-3-7-sonnet",
     "claude-3-7-sonnet-thinking",
     "claude 3.5 haiku",
@@ -128,7 +123,19 @@ claude_3_models = { # Models for the new endpoint
     "grok 2"
 }

-# Supported image generation models
+gemini_models = {
+    "gemini-1.5-pro",
+    "gemini-1.5-flash",
+    "gemini-2.0-flash-lite-preview",
+    "gemini-2.0-flash",
+    "gemini-2.0-flash-thinking", # aka Reasoning
+    "gemini-2.0-flash-preview-image-generation",
+    "gemini-2.5-flash",
+    "gemini-2.5-pro-exp",
+    "gemini-exp-1206"
+}
+
+
 supported_image_models = {
     "Flux Pro Ultra",
     "grok-2-aurora",
@@ -143,86 +150,70 @@ supported_image_models = {
     "sdxl-lightning-4step"
 }

-
-# Request payload model
 class Payload(BaseModel):
     model: str
     messages: list
     stream: bool = False

-
-# Image generation payload model
 class ImageGenerationPayload(BaseModel):
     model: str
     prompt: str
     size: int
     number: int

-
-
-# Server status global variable
 server_status = True
 available_model_ids: List[str] = []

-# Create a reusable httpx client pool with connection pooling
 @lru_cache(maxsize=1)
 def get_async_client():
     return httpx.AsyncClient(
         timeout=60.0,
-        limits=httpx.Limits(max_keepalive_connections=50, max_connections=200) # Increased limits
+        limits=httpx.Limits(max_keepalive_connections=50, max_connections=200)
     )

-# Create a cloudscraper pool
 scraper_pool = []
-MAX_SCRAPERS = 20 # Increased pool size
-
+MAX_SCRAPERS = 20

 def get_scraper():
     if not scraper_pool:
         for _ in range(MAX_SCRAPERS):
             scraper_pool.append(cloudscraper.create_scraper())

-    return scraper_pool[int(time.time() * 1000) % MAX_SCRAPERS] # Simple round-robin
+    return scraper_pool[int(time.time() * 1000) % MAX_SCRAPERS]

-# API key validation - optimized to avoid string operations when possible
 async def verify_api_key(
     request: Request,
     api_key: str = Security(api_key_header)
 ) -> bool:
-    # Allow bypass if the referer is from /playground or /image-playground
     referer = request.headers.get("referer", "")
-    if referer.startswith(("https://parthsadaria-lokiai.hf.space/playground",
+    if referer.startswith(("https://parthsadaria-lokiai.hf.space/playground",
                            "https://parthsadaria-lokiai.hf.space/image-playground")):
         return True
-
+
     if not api_key:
         raise HTTPException(
             status_code=HTTP_403_FORBIDDEN,
             detail="No API key provided"
         )
-
-    # Only clean if needed
+
     if api_key.startswith('Bearer '):
-        api_key = api_key[7:] # Remove 'Bearer ' prefix
-
-    # Get API keys from environment
+        api_key = api_key[7:]
+
     valid_api_keys = get_env_vars().get('api_keys', [])
     if not valid_api_keys or valid_api_keys == ['']:
         raise HTTPException(
             status_code=HTTP_403_FORBIDDEN,
             detail="API keys not configured on server"
         )
-
-    # Fast check with set operation
+
     if api_key not in set(valid_api_keys):
         raise HTTPException(
             status_code=HTTP_403_FORBIDDEN,
             detail="Invalid API key"
         )
-
+
     return True

-# Pre-load and cache models.json
 @lru_cache(maxsize=1)
 def load_models_data():
     try:
@@ -233,61 +224,44 @@ def load_models_data():
         print(f"Error loading models.json: {str(e)}")
         return []

-# Async wrapper for models data
 async def get_models():
     models_data = load_models_data()
     if not models_data:
         raise HTTPException(status_code=500, detail="Error loading available models")
     return models_data

-# Enhanced async streaming - now with real-time SSE support
 async def generate_search_async(query: str, systemprompt: Optional[str] = None, stream: bool = True):
-    # Create a streaming response channel using asyncio.Queue
     queue = asyncio.Queue()

     async def _fetch_search_data():
         try:
             headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
-
-            # Use the provided system prompt, or default to "Be Helpful and Friendly"
             system_message = systemprompt or "Be Helpful and Friendly"
-
-            # Create the prompt history
-            prompt = [
-                {"role": "user", "content": query},
-            ]
-
+            prompt = [{"role": "user", "content": query}]
             prompt.insert(0, {"content": system_message, "role": "system"})
-
-            # Prepare the payload for the API request
             payload = {
                 "is_vscode_extension": True,
                 "message_history": prompt,
                 "requested_model": "searchgpt",
                 "user_input": prompt[-1]["content"],
             }
-
-            # Get endpoint from environment
             secret_api_endpoint_3 = get_env_vars()['secret_api_endpoint_3']
             if not secret_api_endpoint_3:
                 await queue.put({"error": "Search API endpoint not configured"})
                 return

-            # Use AsyncClient for better performance
             async with httpx.AsyncClient(timeout=30.0) as client:
                 async with client.stream("POST", secret_api_endpoint_3, json=payload, headers=headers) as response:
                     if response.status_code != 200:
                         await queue.put({"error": f"Search API returned status code {response.status_code}"})
                         return

-                    # Process the streaming response in real-time
                     buffer = ""
                     async for line in response.aiter_lines():
                         if line.startswith("data: "):
                             try:
                                 json_data = json.loads(line[6:])
                                 content = json_data.get("choices", [{}])[0].get("delta", {}).get("content", "")
-
                                 if content.strip():
                                     cleaned_response = {
                                         "created": json_data.get("created"),
@@ -302,26 +276,17 @@ async def generate_search_async(query: str, systemprompt: Optional[str] = None,
                                             }
                                         ]
                                     }
-
-                                    # Send to queue immediately for streaming
                                     await queue.put({"data": f"data: {json.dumps(cleaned_response)}\n\n", "text": content})
                             except json.JSONDecodeError:
                                 continue
-
-            # Signal completion
             await queue.put(None)
-
         except Exception as e:
             await queue.put({"error": str(e)})
             await queue.put(None)

-    # Start the fetch process
     asyncio.create_task(_fetch_search_data())
-
-    # Return the queue for consumption
     return queue

-# Cache for frequently accessed static files
 @lru_cache(maxsize=10)
 def read_html_file(file_path):
     try:
@@ -330,16 +295,15 @@ def read_html_file(file_path):
     except FileNotFoundError:
         return None

-# Basic routes
 @app.get("/favicon.ico")
 async def favicon():
     favicon_path = Path(__file__).parent / "favicon.ico"
     return FileResponse(favicon_path, media_type="image/x-icon")

 @app.get("/banner.jpg")
-async def favicon():
-    favicon_path = Path(__file__).parent / "banner.jpg"
-    return FileResponse(favicon_path, media_type="image/x-icon")
+async def banner():
+    banner_path = Path(__file__).parent / "banner.jpg"
+    return FileResponse(banner_path, media_type="image/jpeg")

 @app.get("/ping")
 async def ping():
@@ -351,64 +315,61 @@ async def root():
     if html_content is None:
         return HTMLResponse(content="<h1>File not found</h1>", status_code=404)
     return HTMLResponse(content=html_content)
+
 @app.get("/script.js", response_class=HTMLResponse)
-async def root():
+async def script():
     html_content = read_html_file("script.js")
     if html_content is None:
         return HTMLResponse(content="<h1>File not found</h1>", status_code=404)
     return HTMLResponse(content=html_content)
+
 @app.get("/style.css", response_class=HTMLResponse)
-async def root():
+async def style():
     html_content = read_html_file("style.css")
     if html_content is None:
         return HTMLResponse(content="<h1>File not found</h1>", status_code=404)
     return HTMLResponse(content=html_content)
+
 @app.get("/dynamo", response_class=HTMLResponse)
 async def dynamic_ai_page(request: Request):
     user_agent = request.headers.get('user-agent', 'Unknown User')
     client_ip = request.client.host
     location = f"IP: {client_ip}"
-
+
     prompt = f"""
-    Generate a dynamic HTML page for a user with the following details: with name "LOKI.AI"
+    Generate a dynamic HTML page for a user with the following details: with name "LOKI.AI"
     - User-Agent: {user_agent}
     - Location: {location}
    - Style: Cyberpunk, minimalist, or retro
-
+
     Make sure the HTML is clean and includes a heading, also have cool animations a motivational message, and a cool background.
     Wrap the generated HTML in triple backticks (```).
     """
-
+
     payload = {
         "model": "mistral-small-latest",
         "messages": [{"role": "user", "content": prompt}]
     }
-
+
     headers = {
         "Authorization": "Bearer playground"
     }
-
-    response = requests.post("https://parthsadaria-lokiai.hf.space/chat/completions", json=payload, headers=headers)
+
+    response = requests.post("https://parthsadaria-lokiai.hf.space/chat/completions", json=payload, headers=headers)
     data = response.json()
-
-    # Extract HTML from ``` blocks
+
     html_content = re.search(r"```(.*?)```", data['choices'][0]['message']['content'], re.DOTALL)
     if html_content:
         html_content = html_content.group(1).strip()
-
-    # Remove the first word
+
     if html_content:
         html_content = ' '.join(html_content.split(' ')[1:])
-
-    return HTMLResponse(content=html_content)

-
-######################################
+    return HTMLResponse(content=html_content)

 @app.get("/scraper", response_class=PlainTextResponse)
 def scrape_site(url: str = Query(..., description="URL to scrape")):
     try:
-        # Try cloudscraper first
         scraper = cloudscraper.create_scraper()
         response = scraper.get(url)
         if response.status_code == 200 and len(response.text.strip()) > 0:
@@ -417,28 +378,21 @@ def scrape_site(url: str = Query(..., description="URL to scrape")):
         print(f"Cloudscraper failed: {e}")
     return "Cloudscraper failed."

-
-#######################################
-
 @app.get("/playground", response_class=HTMLResponse)
 async def playground():
     html_content = read_html_file("playground.html")
     if html_content is None:
         return HTMLResponse(content="<h1>playground.html not found</h1>", status_code=404)
     return HTMLResponse(content=html_content)
-
+
 @app.get("/image-playground", response_class=HTMLResponse)
-async def playground():
+async def image_playground():
     html_content = read_html_file("image-playground.html")
     if html_content is None:
         return HTMLResponse(content="<h1>image-playground.html not found</h1>", status_code=404)
     return HTMLResponse(content=html_content)

-
-
-
-# VETRA
-GITHUB_BASE = "https://raw.githubusercontent.com/Parthsadaria/Vetra/main"
+GITHUB_BASE = "https://raw.githubusercontent.com/Parthsadaria/Vetra/main"

 FILES = {
     "html": "index.html",
@@ -471,17 +425,11 @@ async def serve_vetra():

     return HTMLResponse(content=final_html)

-
-
-
-
-# Model routes
 @app.get("/api/v1/models")
 @app.get("/models")
 async def return_models():
     return await get_models()

-# Search routes with enhanced real-time streaming
 @app.get("/searchgpt")
 async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optional[str] = None):
     if not q:
@@ -512,7 +460,6 @@ async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optio
             media_type="text/event-stream"
         )
     else:
-        # For non-streaming, collect all text and return at once
         collected_text = ""
         while True:
             item = await queue.get()
@@ -526,14 +473,10 @@ async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optio

         return JSONResponse(content={"response": collected_text})

-
-
-# Enhanced streaming with direct SSE pass-through for real-time responses
 header_url = os.getenv('HEADER_URL')
 @app.post("/chat/completions")
 @app.post("/api/v1/chat/completions")
 async def get_completion(payload: Payload, request: Request, authenticated: bool = Depends(verify_api_key)):
-    # Check server status
     if not server_status:
         return JSONResponse(
             status_code=503,
@@ -542,28 +485,22 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool

     model_to_use = payload.model or "gpt-4o-mini"

-    # Validate model availability - fast lookup with set
     if available_model_ids and model_to_use not in set(available_model_ids):
         raise HTTPException(
             status_code=400,
             detail=f"Model '{model_to_use}' is not available. Check /models for the available model list."
         )

-    # Log request without blocking
     asyncio.create_task(log_request(request, model_to_use))
     usage_tracker.record_request(model=model_to_use, endpoint="/chat/completions")

-    # Prepare payload
     payload_dict = payload.dict()
     payload_dict["model"] = model_to_use

-    # Ensure stream is True for real-time streaming (can be overridden by client)
     stream_enabled = payload_dict.get("stream", True)

-    # Get environment variables
     env_vars = get_env_vars()

-    # Select the appropriate endpoint (fast lookup with sets)
     if model_to_use in mistral_models:
         endpoint = env_vars['mistral_api']
         custom_headers = {
@@ -575,9 +512,18 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
     elif model_to_use in alternate_models:
         endpoint = env_vars['secret_api_endpoint_2']
         custom_headers = {}
-    elif model_to_use in claude_3_models: # Use the new endpoint
+    elif model_to_use in claude_3_models:
         endpoint = env_vars['secret_api_endpoint_5']
         custom_headers = {}
+    elif model_to_use in gemini_models: # Handle Gemini models
+        endpoint = env_vars['secret_api_endpoint_6']
+        if not endpoint:
+            raise HTTPException(status_code=500, detail="Gemini API endpoint not configured")
+        if not env_vars['gemini_key']:
+            raise HTTPException(status_code=500, detail="GEMINI_KEY not configured")
+        custom_headers = {
+            "Authorization": f"Bearer {env_vars['gemini_key']}"
+        }
     else:
         endpoint = env_vars['secret_api_endpoint']
         custom_headers = {
@@ -588,7 +534,6 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool

     print(f"Using endpoint: {endpoint} for model: {model_to_use}")

-    # Improved real-time streaming handler
     async def real_time_stream_generator():
         try:
             async with httpx.AsyncClient(timeout=60.0) as client:
@@ -603,10 +548,8 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
                         detail = error_messages.get(response.status_code, f"Error code: {response.status_code}")
                         raise HTTPException(status_code=response.status_code, detail=detail)

-                    # Stream the response in real-time with minimal buffering
                     async for line in response.aiter_lines():
                         if line:
-                            # Yield immediately for faster streaming
                             yield line + "\n"
         except httpx.TimeoutException:
             raise HTTPException(status_code=504, detail="Request timed out")
@@ -617,7 +560,6 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
                 raise e
             raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

-    # Return streaming response with proper headers
     if stream_enabled:
         return StreamingResponse(
             real_time_stream_generator(),
@@ -626,43 +568,31 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
                 "Content-Type": "text/event-stream",
                 "Cache-Control": "no-cache",
                 "Connection": "keep-alive",
-                "X-Accel-Buffering": "no" # Disable proxy buffering for Nginx
+                "X-Accel-Buffering": "no"
             }
         )
     else:
-        # For non-streaming requests, collect the entire response
         response_content = []
         async for chunk in real_time_stream_generator():
             response_content.append(chunk)
-
         return JSONResponse(content=json.loads(''.join(response_content)))

-
-
-# New image generation endpoint
 @app.post("/images/generations")
 async def create_image(payload: ImageGenerationPayload, authenticated: bool = Depends(verify_api_key)):
-    """
-    Endpoint for generating images based on a text prompt.
-    """
-    # Check server status
     if not server_status:
         return JSONResponse(
             status_code=503,
             content={"message": "Server is under maintenance. Please try again later."}
         )

-    # Validate model
     if payload.model not in supported_image_models:
         raise HTTPException(
             status_code=400,
-            detail=f"Model '{payload.model}' is not supported for image generation. Supported models are: {supported_image_models}"
+            detail=f"Model '{payload.model}' is not supported for image generation. Supported models are: {supported_image_models}"
         )

-    # Log the request
     usage_tracker.record_request(model=payload.model, endpoint="/images/generations")

-    # Prepare the payload for the external API
     api_payload = {
         "model": payload.model,
         "prompt": payload.prompt,
@@ -670,11 +600,9 @@ async def create_image(payload: ImageGenerationPayload, authenticated: bool = De
         "number": payload.number
     }

-    # Target API endpoint
     target_api_url = os.getenv('NEW_IMG')

     try:
-        # Use a timeout for the image generation request
         async with httpx.AsyncClient(timeout=60.0) as client:
             response = await client.post(target_api_url, json=api_payload)

@@ -682,7 +610,6 @@ async def create_image(payload: ImageGenerationPayload, authenticated: bool = De
                 error_detail = response.json().get("detail", f"Image generation failed with status code: {response.status_code}")
                 raise HTTPException(status_code=response.status_code, detail=error_detail)

-            # Return the response from the external API
             return JSONResponse(content=response.json())

     except httpx.TimeoutException:
@@ -692,28 +619,20 @@ async def create_image(payload: ImageGenerationPayload, authenticated: bool = De
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"An unexpected error occurred during image generation: {e}")

-
-
-# Asynchronous logging function
 async def log_request(request, model):
-    # Get minimal data for logging
     current_time = (datetime.datetime.utcnow() + datetime.timedelta(hours=5, minutes=30)).strftime("%Y-%m-%d %I:%M:%S %p")
-    ip_hash = hash(request.client.host) % 10000 # Hash the IP for privacy
+    ip_hash = hash(request.client.host) % 10000
     print(f"Time: {current_time}, IP Hash: {ip_hash}, Model: {model}")

-# Cache usage statistics
 @lru_cache(maxsize=10)
 def get_usage_summary(days=7):
     return usage_tracker.get_usage_summary(days)

 @app.get("/usage")
 async def get_usage(days: int = 7):
-    """Retrieve usage statistics"""
     return get_usage_summary(days)

-# Generate HTML for usage page
 def generate_usage_html(usage_data):
-    # Model Usage Table Rows
     model_usage_rows = "\n".join([
         f"""
         <tr>
@@ -725,7 +644,6 @@ def generate_usage_html(usage_data):
         """ for model, model_data in usage_data['models'].items()
     ])

-    # API Endpoint Usage Table Rows
     api_usage_rows = "\n".join([
         f"""
         <tr>
@@ -737,7 +655,6 @@ def generate_usage_html(usage_data):
         """ for endpoint, endpoint_data in usage_data['api_endpoints'].items()
     ])

-    # Daily Usage Table Rows
     daily_usage_rows = "\n".join([
         "\n".join([
             f"""
@@ -756,7 +673,7 @@ def generate_usage_html(usage_data):
     <head>
         <meta charset="UTF-8">
         <title>Lokiai AI - Usage Statistics</title>
-        <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600&display=swap" rel="stylesheet">
+        <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600&display=swap" rel="stylesheet">
         <style>
             :root {{
                 --bg-dark: #0f1011;
@@ -902,7 +819,6 @@ def generate_usage_html(usage_data):
     """
     return html_content

-# Cache the usage page HTML
 @lru_cache(maxsize=1)
 def get_usage_page_html():
     usage_data = get_usage_summary()
@@ -910,18 +826,14 @@ def get_usage_page_html():

 @app.get("/usage/page", response_class=HTMLResponse)
 async def usage_page():
-    """Serve an HTML page showing usage statistics"""
-    # Use cached HTML if available, regenerate if not
     html_content = get_usage_page_html()
     return HTMLResponse(content=html_content)

-# Meme endpoint with optimized networking
 @app.get("/meme")
 async def get_meme():
     try:
-        # Use the shared client for connection pooling
         client = get_async_client()
-        response = await client.get("https://meme-api.com/gimme")
+        response = await client.get("https://meme-api.com/gimme")
         response_data = response.json()

         meme_url = response_data.get("url")
@@ -930,36 +842,31 @@ async def get_meme():

         image_response = await client.get(meme_url, follow_redirects=True)

-        # Use larger chunks for streaming
         async def stream_with_larger_chunks():
             chunks = []
             size = 0
             async for chunk in image_response.aiter_bytes(chunk_size=16384):
                 chunks.append(chunk)
                 size += len(chunk)
-
                 if size >= 65536:
                     yield b''.join(chunks)
                     chunks = []
                     size = 0
-
             if chunks:
                 yield b''.join(chunks)

         return StreamingResponse(
             stream_with_larger_chunks(),
             media_type=image_response.headers.get("content-type", "image/png"),
-            headers={'Cache-Control': 'max-age=3600'} # Add caching
+            headers={'Cache-Control': 'max-age=3600'}
         )
     except Exception:
         raise HTTPException(status_code=500, detail="Failed to retrieve meme")

-# Utility function for loading model IDs - optimized to run once at startup
 def load_model_ids(json_file_path):
     try:
         with open(json_file_path, 'r') as f:
             models_data = json.load(f)
-            # Extract 'id' from each model object and use a set for fast lookups
             return [model['id'] for model in models_data if 'id' in model]
     except Exception as e:
         print(f"Error loading model IDs: {str(e)}")
@@ -971,23 +878,18 @@ async def startup_event():
     available_model_ids = load_model_ids("models.json")
     print(f"Loaded {len(available_model_ids)} model IDs")

-    # Add all pollinations models to available_model_ids
     available_model_ids.extend(list(pollinations_models))
-    # Add alternate models to available_model_ids
     available_model_ids.extend(list(alternate_models))
-    # Add mistral models to available_model_ids
     available_model_ids.extend(list(mistral_models))
-    # Add claude models
     available_model_ids.extend(list(claude_3_models))
+    available_model_ids.extend(list(gemini_models)) # Add Gemini models

-    available_model_ids = list(set(available_model_ids)) # Remove duplicates
+    available_model_ids = list(set(available_model_ids))
     print(f"Total available models: {len(available_model_ids)}")

-    # Preload scrapers
     for _ in range(MAX_SCRAPERS):
         scraper_pool.append(cloudscraper.create_scraper())

-    # Validate critical environment variables
     env_vars = get_env_vars()
     missing_vars = []

@@ -1001,12 +903,16 @@ async def startup_event():
         missing_vars.append('SECRET_API_ENDPOINT_3')
     if not env_vars['secret_api_endpoint_4']:
         missing_vars.append('SECRET_API_ENDPOINT_4')
-    if not env_vars['secret_api_endpoint_5']: # Check the new endpoint
+    if not env_vars['secret_api_endpoint_5']:
         missing_vars.append('SECRET_API_ENDPOINT_5')
+    if not env_vars['secret_api_endpoint_6']: # Check the new endpoint
+        missing_vars.append('SECRET_API_ENDPOINT_6')
     if not env_vars['mistral_api'] and any(model in mistral_models for model in available_model_ids):
         missing_vars.append('MISTRAL_API')
     if not env_vars['mistral_key'] and any(model in mistral_models for model in available_model_ids):
         missing_vars.append('MISTRAL_KEY')
+    if not env_vars['gemini_key'] and any(model in gemini_models for model in available_model_ids): # Check Gemini key
+        missing_vars.append('GEMINI_KEY')

     if missing_vars:
         print(f"WARNING: The following environment variables are missing: {', '.join(missing_vars)}")
@@ -1016,27 +922,17 @@

 @app.on_event("shutdown")
 async def shutdown_event():
-    # Close the httpx client
     client = get_async_client()
     await client.aclose()
-
-    # Clear scraper pool
     scraper_pool.clear()
-
-    # Persist usage data
     usage_tracker.save_data()
-
     print("Server shutdown complete!")

-# Health check endpoint
-# Health check endpoint
 @app.get("/health")
 async def health_check():
-    """Health check endpoint for monitoring"""
     env_vars = get_env_vars()
     missing_critical_vars = []

-    # Check critical environment variables
     if not env_vars['api_keys'] or env_vars['api_keys'] == ['']:
         missing_critical_vars.append('API_KEYS')
     if not env_vars['secret_api_endpoint']:
@@ -1047,12 +943,16 @@ async def health_check():
         missing_critical_vars.append('SECRET_API_ENDPOINT_3')
     if not env_vars['secret_api_endpoint_4']:
         missing_critical_vars.append('SECRET_API_ENDPOINT_4')
-    if not env_vars['secret_api_endpoint_5']: # Check the new endpoint
+    if not env_vars['secret_api_endpoint_5']:
         missing_critical_vars.append('SECRET_API_ENDPOINT_5')
+    if not env_vars['secret_api_endpoint_6']: # Check the new endpoint
+        missing_critical_vars.append('SECRET_API_ENDPOINT_6')
     if not env_vars['mistral_api']:
         missing_critical_vars.append('MISTRAL_API')
     if not env_vars['mistral_key']:
         missing_critical_vars.append('MISTRAL_KEY')
+    if not env_vars['gemini_key']: # Check Gemini key
+        missing_critical_vars.append('GEMINI_KEY')

     health_status = {
         "status": "healthy" if not missing_critical_vars else "unhealthy",
@@ -1064,4 +964,4 @@ async def health_check():

 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    uvicorn.run(app, host="0.0.0.0", port=7860)
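
For reference, a minimal client sketch against the Gemini routing this commit adds. It is illustrative only: it assumes a local deployment on port 7860 (matching the uvicorn.run call above), that SECRET_API_ENDPOINT_6 and GEMINI_KEY are configured on the server, and that "your-api-key" is a placeholder for one of the keys in the server's API_KEYS variable. The model id "gemini-2.0-flash" is taken from the gemini_models set introduced in this commit.

    import requests

    BASE_URL = "http://localhost:7860"  # assumed local deployment; see uvicorn.run above

    payload = {
        "model": "gemini-2.0-flash",  # any id from the new gemini_models set
        "messages": [{"role": "user", "content": "Say hello"}],
        "stream": False,  # Payload.stream defaults to False
    }

    # Placeholder key; must match an entry in the server's API_KEYS.
    headers = {"Authorization": "Bearer your-api-key"}

    resp = requests.post(f"{BASE_URL}/chat/completions", json=payload, headers=headers, timeout=60)
    resp.raise_for_status()

    # Assumes the upstream Gemini endpoint returns an OpenAI-style completion body.
    print(resp.json()["choices"][0]["message"]["content"])

Per the diff, a request whose model id is in gemini_models is forwarded to SECRET_API_ENDPOINT_6 with a Bearer GEMINI_KEY header; ids outside the configured sets fall through to the default SECRET_API_ENDPOINT.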