ParthSadaria commited on
Commit
a7b3e7b
·
verified ·
1 Parent(s): 160be92

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +66 -228
main.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
  import re
3
  from dotenv import load_dotenv
4
- from fastapi import FastAPI, HTTPException, Request, Depends, Security, APIRouter
5
  from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse
6
  from fastapi.security import APIKeyHeader
7
  from pydantic import BaseModel
@@ -12,7 +12,7 @@ import json
12
  import datetime
13
  import time
14
  import threading
15
- from typing import Optional, Dict, List, Any, Generator, Set # Import Set
16
  import asyncio
17
  from starlette.status import HTTP_403_FORBIDDEN
18
  import cloudscraper
@@ -66,10 +66,7 @@ def get_env_vars():
66
  }
67
 
68
  # Configuration for models - use sets for faster lookups
69
- # IMPORTANT: These will be updated in memory. For persistence,
70
- # you would need to save these changes to a file (like models.json)
71
- # or a database and reload them on startup.
72
- mistral_models: Set[str] = {
73
  "mistral-large-latest",
74
  "pixtral-large-latest",
75
  "mistral-moderation-latest",
@@ -81,7 +78,7 @@ mistral_models: Set[str] = {
81
  "codestral-latest"
82
  }
83
 
84
- pollinations_models: Set[str] = {
85
  "openai",
86
  "openai-large",
87
  "openai-xlarge",
@@ -103,7 +100,7 @@ pollinations_models: Set[str] = {
103
  "openai-audio",
104
  "llama-scaleway"
105
  }
106
- alternate_models: Set[str] = {
107
  "o1",
108
  "llama-4-scout",
109
  "o4-mini",
@@ -117,7 +114,7 @@ alternate_models: Set[str] = {
117
  "o3"
118
  }
119
 
120
- claude_3_models: Set[str] = { # Models for the new endpoint
121
  "claude-3-7-sonnet",
122
  "claude-3-7-sonnet-thinking",
123
  "claude 3.5 haiku",
@@ -131,7 +128,7 @@ claude_3_models: Set[str] = { # Models for the new endpoint
131
  }
132
 
133
  # Supported image generation models
134
- supported_image_models: Set[str] = {
135
  "Flux Pro Ultra",
136
  "grok-2-aurora",
137
  "Flux Pro",
@@ -161,14 +158,10 @@ class ImageGenerationPayload(BaseModel):
161
  number: int
162
 
163
 
164
- # Pydantic model for updating models via admin endpoint
165
- class UpdateModelsPayload(BaseModel):
166
- provider: str # e.g., "mistral", "pollinations", "alternate", "claude_3", "image"
167
- models: List[str] # The new list of model IDs for the provider
168
 
169
  # Server status global variable
170
  server_status = True
171
- available_model_ids: List[str] = [] # This will be updated based on the sets
172
 
173
  # Create a reusable httpx client pool with connection pooling
174
  @lru_cache(maxsize=1)
@@ -197,20 +190,20 @@ async def verify_api_key(
197
  ) -> bool:
198
  # Allow bypass if the referer is from /playground or /image-playground
199
  referer = request.headers.get("referer", "")
200
- if referer.startswith(("https://parthsadaria-lokiai.hf.space/playground",
201
  "https://parthsadaria-lokiai.hf.space/image-playground")):
202
  return True
203
-
204
  if not api_key:
205
  raise HTTPException(
206
  status_code=HTTP_403_FORBIDDEN,
207
  detail="No API key provided"
208
  )
209
-
210
  # Only clean if needed
211
  if api_key.startswith('Bearer '):
212
  api_key = api_key[7:] # Remove 'Bearer ' prefix
213
-
214
  # Get API keys from environment
215
  valid_api_keys = get_env_vars().get('api_keys', [])
216
  if not valid_api_keys or valid_api_keys == ['']:
@@ -218,14 +211,14 @@ async def verify_api_key(
218
  status_code=HTTP_403_FORBIDDEN,
219
  detail="API keys not configured on server"
220
  )
221
-
222
  # Fast check with set operation
223
  if api_key not in set(valid_api_keys):
224
  raise HTTPException(
225
  status_code=HTTP_403_FORBIDDEN,
226
  detail="Invalid API key"
227
  )
228
-
229
  return True
230
 
231
  # Pre-load and cache models.json
@@ -241,27 +234,10 @@ def load_models_data():
241
 
242
  # Async wrapper for models data
243
  async def get_models():
244
- # Combine models from all active sets for the /models endpoint
245
- all_models = list(mistral_models) + \
246
- list(pollinations_models) + \
247
- list(alternate_models) + \
248
- list(claude_3_models) + \
249
- list(supported_image_models) # Include image models
250
-
251
- # Fetch additional models from models.json if it exists and add them
252
- models_from_file = load_models_data()
253
- if models_from_file:
254
- # Assuming models.json contains a list of dicts with 'id'
255
- all_models.extend([model.get('id') for model in models_from_file if model.get('id')])
256
-
257
- # Remove duplicates and sort for a consistent list
258
- unique_models = sorted(list(set(all_models)))
259
-
260
- # Format as a list of dictionaries for compatibility with existing /models endpoint
261
- formatted_models = [{"id": model_id, "name": model_id} for model_id in unique_models]
262
-
263
- return formatted_models
264
-
265
 
266
  # Enhanced async streaming - now with real-time SSE support
267
  async def generate_search_async(query: str, systemprompt: Optional[str] = None, stream: bool = True):
@@ -391,47 +367,47 @@ async def dynamic_ai_page(request: Request):
391
  user_agent = request.headers.get('user-agent', 'Unknown User')
392
  client_ip = request.client.host
393
  location = f"IP: {client_ip}"
394
-
395
  prompt = f"""
396
- Generate a dynamic HTML page for a user with the following details: with name "LOKI.AI"
397
  - User-Agent: {user_agent}
398
  - Location: {location}
399
  - Style: Cyberpunk, minimalist, or retro
400
-
401
  Make sure the HTML is clean and includes a heading, also have cool animations a motivational message, and a cool background.
402
  Wrap the generated HTML in triple backticks (```).
403
  """
404
-
405
  payload = {
406
  "model": "mistral-small-latest",
407
  "messages": [{"role": "user", "content": prompt}]
408
  }
409
-
410
  headers = {
411
  "Authorization": "Bearer playground"
412
  }
413
-
414
- response = requests.post("[https://parthsadaria-lokiai.hf.space/chat/completions](https://parthsadaria-lokiai.hf.space/chat/completions)", json=payload, headers=headers)
415
  data = response.json()
416
-
417
  # Extract HTML from ``` blocks
418
  html_content = re.search(r"```(.*?)```", data['choices'][0]['message']['content'], re.DOTALL)
419
  if html_content:
420
  html_content = html_content.group(1).strip()
421
-
422
  # Remove the first word
423
  if html_content:
424
  html_content = ' '.join(html_content.split(' ')[1:])
425
-
426
  return HTMLResponse(content=html_content)
427
-
428
  @app.get("/playground", response_class=HTMLResponse)
429
  async def playground():
430
  html_content = read_html_file("playground.html")
431
  if html_content is None:
432
  return HTMLResponse(content="<h1>playground.html not found</h1>", status_code=404)
433
  return HTMLResponse(content=html_content)
434
-
435
  @app.get("/image-playground", response_class=HTMLResponse)
436
  async def playground():
437
  html_content = read_html_file("image-playground.html")
@@ -548,18 +524,12 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
548
  model_to_use = payload.model or "gpt-4o-mini"
549
 
550
  # Validate model availability - fast lookup with set
551
- # Check if the model is in any of the currently active sets
552
- if model_to_use not in mistral_models and \
553
- model_to_use not in pollinations_models and \
554
- model_to_use not in alternate_models and \
555
- model_to_use not in claude_3_models and \
556
- model_to_use not in supported_image_models: # Also check image models
557
  raise HTTPException(
558
  status_code=400,
559
  detail=f"Model '{model_to_use}' is not available. Check /models for the available model list."
560
  )
561
 
562
-
563
  # Log request without blocking
564
  asyncio.create_task(log_request(request, model_to_use))
565
  usage_tracker.record_request(model=model_to_use, endpoint="/chat/completions")
@@ -589,11 +559,7 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
589
  elif model_to_use in claude_3_models: # Use the new endpoint
590
  endpoint = env_vars['secret_api_endpoint_5']
591
  custom_headers = {}
592
- # Add check for image models here if they use a different endpoint than /images/generations
593
- # elif model_to_use in supported_image_models:
594
- # endpoint = env_vars['YOUR_IMAGE_COMPLETIONS_ENDPOINT'] # Define a new env var if needed
595
- # custom_headers = {}
596
- else: # Default endpoint
597
  endpoint = env_vars['secret_api_endpoint']
598
  custom_headers = {
599
  "Origin": header_url,
@@ -607,36 +573,15 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
607
  async def real_time_stream_generator():
608
  try:
609
  async with httpx.AsyncClient(timeout=60.0) as client:
610
- # Adjust the endpoint path based on the provider if necessary
611
- # For example, Mistral uses /v1/chat/completions, Pollinations might use something else
612
- # Based on your existing code, it seems most use /v1/chat/completions,
613
- # but this is a point to verify with the actual provider APIs.
614
- api_path = "/v1/chat/completions"
615
- if model_to_use in mistral_models:
616
- api_path = "/v1/chat/completions" # Or the correct path for Mistral
617
- elif model_to_use in pollinations_models:
618
- api_path = "/v1/chat/completions" # Or the correct path for Pollinations
619
- elif model_to_use in alternate_models:
620
- api_path = "/v1/chat/completions" # Or the correct path for Alternate
621
- elif model_to_use in claude_3_models:
622
- api_path = "/v1/chat/completions" # Or the correct path for Claude 3
623
-
624
- async with client.stream("POST", f"{endpoint}{api_path}", json=payload_dict, headers=custom_headers) as response:
625
  if response.status_code >= 400:
626
  error_messages = {
627
  422: "Unprocessable entity. Check your payload.",
628
  400: "Bad request. Verify input data.",
629
  403: "Forbidden. You do not have access to this resource.",
630
  404: "The requested resource was not found.",
631
- 500: "Internal Server Error from upstream API."
632
  }
633
- detail = error_messages.get(response.status_code, f"Error code: {response.status_code} from upstream API.")
634
- try:
635
- # Attempt to get more detail from the upstream response body
636
- error_body = await response.aread()
637
- detail += f" Upstream response: {error_body.decode()}"
638
- except Exception:
639
- pass # Ignore errors reading the body
640
  raise HTTPException(status_code=response.status_code, detail=detail)
641
 
642
  # Stream the response in real-time with minimal buffering
@@ -671,32 +616,8 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
671
  async for chunk in real_time_stream_generator():
672
  response_content.append(chunk)
673
 
674
- # The collected content might be multiple SSE data chunks.
675
- # For non-streaming, we expect a single JSON object.
676
- # This part might need adjustment based on the *actual* non-streaming
677
- # response format of the upstream APIs. Assuming it's a single JSON:
678
- try:
679
- # Attempt to parse the full collected content as JSON
680
- # This assumes the upstream non-streaming response is a single JSON blob
681
- # without the 'data: ' prefix and multiple lines.
682
- full_response_text = "".join(response_content)
683
- # Remove potential 'data: ' prefixes if they exist even in non-stream
684
- full_response_text = re.sub(r'^data: ', '', full_response_text, flags=re.MULTILINE)
685
- # Remove empty lines
686
- full_response_text = "\n".join([line for line in full_response_text.splitlines() if line.strip()])
687
-
688
- # If the upstream API sends multiple JSON objects even in non-stream,
689
- # you might need to process them differently, e.g., concatenate content.
690
- # For now, assume a single JSON object is expected.
691
- json_response = json.loads(full_response_text)
692
- return JSONResponse(content=json_response)
693
- except json.JSONDecodeError:
694
- # If parsing fails, return the raw collected content and a server error
695
- print(f"Warning: Failed to decode JSON for non-streaming response. Raw content: {response_content}")
696
- raise HTTPException(status_code=500, detail="Failed to parse upstream API JSON response.")
697
- except Exception as e:
698
- print(f"Warning: Unexpected error processing non-streaming response: {e}")
699
- raise HTTPException(status_code=500, detail=f"An error occurred processing non-streaming response: {str(e)}")
700
 
701
 
702
  # New image generation endpoint
@@ -716,7 +637,7 @@ async def create_image(payload: ImageGenerationPayload, authenticated: bool = De
716
  if payload.model not in supported_image_models:
717
  raise HTTPException(
718
  status_code=400,
719
- detail=f"Model '{payload.model}' is not supported for image generation. Supported models are: {list(supported_image_models)}"
720
  )
721
 
722
  # Log the request
@@ -730,16 +651,12 @@ async def create_image(payload: ImageGenerationPayload, authenticated: bool = De
730
  "number": payload.number
731
  }
732
 
733
- # Target API endpoint for image generation
734
- target_api_url = os.getenv('NEW_IMG') # Ensure this env var is set
735
-
736
- if not target_api_url:
737
- raise HTTPException(status_code=500, detail="Image generation API endpoint (NEW_IMG) is not configured.")
738
-
739
 
740
  try:
741
  # Use a timeout for the image generation request
742
- async with httpx.AsyncClient(timeout=120.0) as client: # Increased timeout for image generation
743
  response = await client.post(target_api_url, json=api_payload)
744
 
745
  if response.status_code != 200:
@@ -757,6 +674,7 @@ async def create_image(payload: ImageGenerationPayload, authenticated: bool = De
757
  raise HTTPException(status_code=500, detail=f"An unexpected error occurred during image generation: {e}")
758
 
759
 
 
760
  # Asynchronous logging function
761
  async def log_request(request, model):
762
  # Get minimal data for logging
@@ -920,7 +838,7 @@ def generate_usage_html(usage_data):
920
  <body>
921
  <div class="container">
922
  <div class="logo">
923
- <img src="data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iMjAwIiBoZWlnaHQ9IjIwMCIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj48cGF0aCBkPSJNMTAwIDM1TDUwIDkwaDEwMHoiIGZpbGw9IiMzYTZlZTAiLz4+PGNpcmNsZSBjeD0iMTAwIiBjeT0iMTQwIiByPSIzMCIgZmlsbD0iIzNhNmVlMCIvPjwvc3ZnPg==" alt="Lokai AI Logo">
924
  <h1>Lokiai AI</h1>
925
  </div>
926
 
@@ -1018,34 +936,33 @@ async def get_meme():
1018
  raise HTTPException(status_code=500, detail="Failed to retrieve meme")
1019
 
1020
  # Utility function for loading model IDs - optimized to run once at startup
1021
- # This function is now less critical as model availability is checked against sets
1022
- # However, it's still used to initially populate available_model_ids from models.json
1023
- def load_model_ids_from_file(json_file_path):
1024
  try:
1025
  with open(json_file_path, 'r') as f:
1026
  models_data = json.load(f)
1027
- # Extract 'id' from each model object
1028
  return [model['id'] for model in models_data if 'id' in model]
1029
  except Exception as e:
1030
- print(f"Error loading model IDs from {json_file_path}: {str(e)}")
1031
  return []
1032
 
1033
  @app.on_event("startup")
1034
  async def startup_event():
1035
  global available_model_ids
1036
- # Load initial models from models.json
1037
- available_model_ids = load_model_ids_from_file("models.json")
1038
 
1039
- # Add models from hardcoded sets to the available_model_ids list for the /models endpoint
1040
- # Note: The actual model availability check in /chat/completions uses the sets directly.
1041
  available_model_ids.extend(list(pollinations_models))
 
1042
  available_model_ids.extend(list(alternate_models))
 
1043
  available_model_ids.extend(list(mistral_models))
 
1044
  available_model_ids.extend(list(claude_3_models))
1045
- available_model_ids.extend(list(supported_image_models)) # Add image models
1046
 
1047
  available_model_ids = list(set(available_model_ids)) # Remove duplicates
1048
- print(f"Initial available models for /models endpoint: {len(available_model_ids)}")
1049
 
1050
  # Preload scrapers
1051
  for _ in range(MAX_SCRAPERS):
@@ -1067,14 +984,10 @@ async def startup_event():
1067
  missing_vars.append('SECRET_API_ENDPOINT_4')
1068
  if not env_vars['secret_api_endpoint_5']: # Check the new endpoint
1069
  missing_vars.append('SECRET_API_ENDPOINT_5')
1070
- # Check Mistral keys only if Mistral models are defined in the sets
1071
- if mistral_models and (not env_vars.get('mistral_api') or not env_vars.get('mistral_key')):
1072
- if not env_vars.get('mistral_api'): missing_vars.append('MISTRAL_API')
1073
- if not env_vars.get('mistral_key'): missing_vars.append('MISTRAL_KEY')
1074
- # Check image endpoint only if image models are defined in the sets
1075
- if supported_image_models and not os.getenv('NEW_IMG'):
1076
- missing_vars.append('NEW_IMG')
1077
-
1078
 
1079
  if missing_vars:
1080
  print(f"WARNING: The following environment variables are missing: {', '.join(missing_vars)}")
@@ -1096,6 +1009,7 @@ async def shutdown_event():
1096
 
1097
  print("Server shutdown complete!")
1098
 
 
1099
  # Health check endpoint
1100
  @app.get("/health")
1101
  async def health_check():
@@ -1104,25 +1018,22 @@ async def health_check():
1104
  missing_critical_vars = []
1105
 
1106
  # Check critical environment variables
1107
- if not env_vars.get('api_keys') or env_vars['api_keys'] == ['']:
1108
  missing_critical_vars.append('API_KEYS')
1109
- if not env_vars.get('secret_api_endpoint'):
1110
  missing_critical_vars.append('SECRET_API_ENDPOINT')
1111
- if not env_vars.get('secret_api_endpoint_2'):
1112
  missing_critical_vars.append('SECRET_API_ENDPOINT_2')
1113
- if not env_vars.get('secret_api_endpoint_3'):
1114
  missing_critical_vars.append('SECRET_API_ENDPOINT_3')
1115
- if not env_vars.get('secret_api_endpoint_4'):
1116
  missing_critical_vars.append('SECRET_API_ENDPOINT_4')
1117
- if not env_vars.get('secret_api_endpoint_5'): # Check the new endpoint
1118
  missing_critical_vars.append('SECRET_API_ENDPOINT_5')
1119
- if not env_vars.get('mistral_api'):
1120
  missing_critical_vars.append('MISTRAL_API')
1121
- if not env_vars.get('mistral_key'):
1122
  missing_critical_vars.append('MISTRAL_KEY')
1123
- if not os.getenv('NEW_IMG'):
1124
- missing_critical_vars.append('NEW_IMG')
1125
-
1126
 
1127
  health_status = {
1128
  "status": "healthy" if not missing_critical_vars else "unhealthy",
@@ -1132,79 +1043,6 @@ async def health_check():
1132
  }
1133
  return JSONResponse(content=health_status)
1134
 
1135
-
1136
- # --- Admin Endpoints ---
1137
- # Create a separate APIRouter for admin endpoints for better organization
1138
- admin_router = APIRouter(prefix="/admin", tags=["Admin"])
1139
-
1140
- @admin_router.post("/update_models", dependencies=[Depends(verify_api_key)])
1141
- async def update_provider_models(payload: UpdateModelsPayload):
1142
- """
1143
- Updates the list of available models for a specific provider.
1144
- Requires API key authentication.
1145
- """
1146
- global mistral_models, pollinations_models, alternate_models, claude_3_models, supported_image_models, available_model_ids
1147
-
1148
- provider = payload.provider.lower()
1149
- new_models_list = payload.models
1150
-
1151
- # Map provider names to the corresponding global sets
1152
- provider_model_sets = {
1153
- "mistral": mistral_models,
1154
- "pollinations": pollinations_models,
1155
- "alternate": alternate_models,
1156
- "claude_3": claude_3_models,
1157
- "image": supported_image_models # Use "image" as the provider name for image models
1158
- }
1159
-
1160
- if provider not in provider_model_sets:
1161
- raise HTTPException(status_code=400, detail=f"Unknown provider: {provider}. Valid providers are: {list(provider_model_sets.keys())}")
1162
-
1163
- # Update the models set for the specified provider
1164
- # Using set() ensures uniqueness and efficient lookup
1165
- provider_model_sets[provider].clear() # Clear existing models
1166
- provider_model_sets[provider].update(new_models_list) # Add new models
1167
-
1168
- # Rebuild the overall available_model_ids list for the /models endpoint
1169
- # This is important so the /models endpoint reflects the changes
1170
- available_model_ids = list(mistral_models) + \
1171
- list(pollinations_models) + \
1172
- list(alternate_models) + \
1173
- list(claude_3_models) + \
1174
- list(supported_image_models)
1175
- available_model_ids = list(set(available_model_ids)) # Remove duplicates
1176
-
1177
- print(f"Updated models for provider '{provider}'. New models: {list(provider_model_sets[provider])}")
1178
- print(f"Total available models for /models endpoint: {len(available_model_ids)}")
1179
-
1180
-
1181
- # TODO: Implement persistence (e.g., save to models.json)
1182
- # For example: save_models_to_file("models.json", provider_model_sets)
1183
- # And modify startup_event to load from this file.
1184
-
1185
- return {"message": f"Models updated successfully for provider: {provider}", "new_models_count": len(new_models_list)}
1186
-
1187
- @admin_router.get("/view_models", dependencies=[Depends(verify_api_key)])
1188
- async def view_current_models():
1189
- """
1190
- Returns the currently active model sets for all providers.
1191
- Requires API key authentication.
1192
- """
1193
- # Return the current state of the model sets
1194
- return {
1195
- "mistral": list(mistral_models),
1196
- "pollinations": list(pollinations_models),
1197
- "alternate": list(alternate_models),
1198
- "claude_3": list(claude_3_models),
1199
- "image": list(supported_image_models)
1200
- }
1201
-
1202
- # Add the admin router to the main app
1203
- app.include_router(admin_router)
1204
-
1205
-
1206
  if __name__ == "__main__":
1207
  import uvicorn
1208
- # Note: For production, consider using a process manager like Gunicorn
1209
- # with multiple workers for better performance and reliability.
1210
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import os
2
  import re
3
  from dotenv import load_dotenv
4
+ from fastapi import FastAPI, HTTPException, Request, Depends, Security
5
  from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse
6
  from fastapi.security import APIKeyHeader
7
  from pydantic import BaseModel
 
12
  import datetime
13
  import time
14
  import threading
15
+ from typing import Optional, Dict, List, Any, Generator
16
  import asyncio
17
  from starlette.status import HTTP_403_FORBIDDEN
18
  import cloudscraper
 
66
  }
67
 
68
  # Configuration for models - use sets for faster lookups
69
+ mistral_models = {
 
 
 
70
  "mistral-large-latest",
71
  "pixtral-large-latest",
72
  "mistral-moderation-latest",
 
78
  "codestral-latest"
79
  }
80
 
81
+ pollinations_models = {
82
  "openai",
83
  "openai-large",
84
  "openai-xlarge",
 
100
  "openai-audio",
101
  "llama-scaleway"
102
  }
103
+ alternate_models = {
104
  "o1",
105
  "llama-4-scout",
106
  "o4-mini",
 
114
  "o3"
115
  }
116
 
117
+ claude_3_models = { # Models for the new endpoint
118
  "claude-3-7-sonnet",
119
  "claude-3-7-sonnet-thinking",
120
  "claude 3.5 haiku",
 
128
  }
129
 
130
  # Supported image generation models
131
+ supported_image_models = {
132
  "Flux Pro Ultra",
133
  "grok-2-aurora",
134
  "Flux Pro",
 
158
  number: int
159
 
160
 
 
 
 
 
161
 
162
  # Server status global variable
163
  server_status = True
164
+ available_model_ids: List[str] = []
165
 
166
  # Create a reusable httpx client pool with connection pooling
167
  @lru_cache(maxsize=1)
 
190
  ) -> bool:
191
  # Allow bypass if the referer is from /playground or /image-playground
192
  referer = request.headers.get("referer", "")
193
+ if referer.startswith(("https://parthsadaria-lokiai.hf.space/playground",
194
  "https://parthsadaria-lokiai.hf.space/image-playground")):
195
  return True
196
+
197
  if not api_key:
198
  raise HTTPException(
199
  status_code=HTTP_403_FORBIDDEN,
200
  detail="No API key provided"
201
  )
202
+
203
  # Only clean if needed
204
  if api_key.startswith('Bearer '):
205
  api_key = api_key[7:] # Remove 'Bearer ' prefix
206
+
207
  # Get API keys from environment
208
  valid_api_keys = get_env_vars().get('api_keys', [])
209
  if not valid_api_keys or valid_api_keys == ['']:
 
211
  status_code=HTTP_403_FORBIDDEN,
212
  detail="API keys not configured on server"
213
  )
214
+
215
  # Fast check with set operation
216
  if api_key not in set(valid_api_keys):
217
  raise HTTPException(
218
  status_code=HTTP_403_FORBIDDEN,
219
  detail="Invalid API key"
220
  )
221
+
222
  return True
223
 
224
  # Pre-load and cache models.json
 
234
 
235
  # Async wrapper for models data
236
  async def get_models():
237
+ models_data = load_models_data()
238
+ if not models_data:
239
+ raise HTTPException(status_code=500, detail="Error loading available models")
240
+ return models_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
 
242
  # Enhanced async streaming - now with real-time SSE support
243
  async def generate_search_async(query: str, systemprompt: Optional[str] = None, stream: bool = True):
 
367
  user_agent = request.headers.get('user-agent', 'Unknown User')
368
  client_ip = request.client.host
369
  location = f"IP: {client_ip}"
370
+
371
  prompt = f"""
372
+ Generate a dynamic HTML page for a user with the following details: with name "LOKI.AI"
373
  - User-Agent: {user_agent}
374
  - Location: {location}
375
  - Style: Cyberpunk, minimalist, or retro
376
+
377
  Make sure the HTML is clean and includes a heading, also have cool animations a motivational message, and a cool background.
378
  Wrap the generated HTML in triple backticks (```).
379
  """
380
+
381
  payload = {
382
  "model": "mistral-small-latest",
383
  "messages": [{"role": "user", "content": prompt}]
384
  }
385
+
386
  headers = {
387
  "Authorization": "Bearer playground"
388
  }
389
+
390
+ response = requests.post("https://parthsadaria-lokiai.hf.space/chat/completions", json=payload, headers=headers)
391
  data = response.json()
392
+
393
  # Extract HTML from ``` blocks
394
  html_content = re.search(r"```(.*?)```", data['choices'][0]['message']['content'], re.DOTALL)
395
  if html_content:
396
  html_content = html_content.group(1).strip()
397
+
398
  # Remove the first word
399
  if html_content:
400
  html_content = ' '.join(html_content.split(' ')[1:])
401
+
402
  return HTMLResponse(content=html_content)
403
+
404
  @app.get("/playground", response_class=HTMLResponse)
405
  async def playground():
406
  html_content = read_html_file("playground.html")
407
  if html_content is None:
408
  return HTMLResponse(content="<h1>playground.html not found</h1>", status_code=404)
409
  return HTMLResponse(content=html_content)
410
+
411
  @app.get("/image-playground", response_class=HTMLResponse)
412
  async def playground():
413
  html_content = read_html_file("image-playground.html")
 
524
  model_to_use = payload.model or "gpt-4o-mini"
525
 
526
  # Validate model availability - fast lookup with set
527
+ if available_model_ids and model_to_use not in set(available_model_ids):
 
 
 
 
 
528
  raise HTTPException(
529
  status_code=400,
530
  detail=f"Model '{model_to_use}' is not available. Check /models for the available model list."
531
  )
532
 
 
533
  # Log request without blocking
534
  asyncio.create_task(log_request(request, model_to_use))
535
  usage_tracker.record_request(model=model_to_use, endpoint="/chat/completions")
 
559
  elif model_to_use in claude_3_models: # Use the new endpoint
560
  endpoint = env_vars['secret_api_endpoint_5']
561
  custom_headers = {}
562
+ else:
 
 
 
 
563
  endpoint = env_vars['secret_api_endpoint']
564
  custom_headers = {
565
  "Origin": header_url,
 
573
  async def real_time_stream_generator():
574
  try:
575
  async with httpx.AsyncClient(timeout=60.0) as client:
576
+ async with client.stream("POST", f"{endpoint}/v1/chat/completions", json=payload_dict, headers=custom_headers) as response:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
577
  if response.status_code >= 400:
578
  error_messages = {
579
  422: "Unprocessable entity. Check your payload.",
580
  400: "Bad request. Verify input data.",
581
  403: "Forbidden. You do not have access to this resource.",
582
  404: "The requested resource was not found.",
 
583
  }
584
+ detail = error_messages.get(response.status_code, f"Error code: {response.status_code}")
 
 
 
 
 
 
585
  raise HTTPException(status_code=response.status_code, detail=detail)
586
 
587
  # Stream the response in real-time with minimal buffering
 
616
  async for chunk in real_time_stream_generator():
617
  response_content.append(chunk)
618
 
619
+ return JSONResponse(content=json.loads(''.join(response_content)))
620
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
621
 
622
 
623
  # New image generation endpoint
 
637
  if payload.model not in supported_image_models:
638
  raise HTTPException(
639
  status_code=400,
640
+ detail=f"Model '{payload.model}' is not supported for image generation. Supported models are: {supported_image_models}"
641
  )
642
 
643
  # Log the request
 
651
  "number": payload.number
652
  }
653
 
654
+ # Target API endpoint
655
+ target_api_url = os.getenv('NEW_IMG')
 
 
 
 
656
 
657
  try:
658
  # Use a timeout for the image generation request
659
+ async with httpx.AsyncClient(timeout=60.0) as client:
660
  response = await client.post(target_api_url, json=api_payload)
661
 
662
  if response.status_code != 200:
 
674
  raise HTTPException(status_code=500, detail=f"An unexpected error occurred during image generation: {e}")
675
 
676
 
677
+
678
  # Asynchronous logging function
679
  async def log_request(request, model):
680
  # Get minimal data for logging
 
838
  <body>
839
  <div class="container">
840
  <div class="logo">
841
+ <img src="data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iMjAwIiBoZWlnaHQ9IjIwMCIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj48cGF0aCBkPSJNMTAwIDM1TDUwIDkwaDEwMHoiIGZpbGw9IiMzYTZlZTAiLz48Y2lyY2xlIGN4PSIxMDAiIGN5PSIxNDAiIHI9IjMwIiBmaWxsPSIjM2E2ZWUwIi8+PC9zdmc+" alt="Lokai AI Logo">
842
  <h1>Lokiai AI</h1>
843
  </div>
844
 
 
936
  raise HTTPException(status_code=500, detail="Failed to retrieve meme")
937
 
938
  # Utility function for loading model IDs - optimized to run once at startup
939
+ def load_model_ids(json_file_path):
 
 
940
  try:
941
  with open(json_file_path, 'r') as f:
942
  models_data = json.load(f)
943
+ # Extract 'id' from each model object and use a set for fast lookups
944
  return [model['id'] for model in models_data if 'id' in model]
945
  except Exception as e:
946
+ print(f"Error loading model IDs: {str(e)}")
947
  return []
948
 
949
  @app.on_event("startup")
950
  async def startup_event():
951
  global available_model_ids
952
+ available_model_ids = load_model_ids("models.json")
953
+ print(f"Loaded {len(available_model_ids)} model IDs")
954
 
955
+ # Add all pollinations models to available_model_ids
 
956
  available_model_ids.extend(list(pollinations_models))
957
+ # Add alternate models to available_model_ids
958
  available_model_ids.extend(list(alternate_models))
959
+ # Add mistral models to available_model_ids
960
  available_model_ids.extend(list(mistral_models))
961
+ # Add claude models
962
  available_model_ids.extend(list(claude_3_models))
 
963
 
964
  available_model_ids = list(set(available_model_ids)) # Remove duplicates
965
+ print(f"Total available models: {len(available_model_ids)}")
966
 
967
  # Preload scrapers
968
  for _ in range(MAX_SCRAPERS):
 
984
  missing_vars.append('SECRET_API_ENDPOINT_4')
985
  if not env_vars['secret_api_endpoint_5']: # Check the new endpoint
986
  missing_vars.append('SECRET_API_ENDPOINT_5')
987
+ if not env_vars['mistral_api'] and any(model in mistral_models for model in available_model_ids):
988
+ missing_vars.append('MISTRAL_API')
989
+ if not env_vars['mistral_key'] and any(model in mistral_models for model in available_model_ids):
990
+ missing_vars.append('MISTRAL_KEY')
 
 
 
 
991
 
992
  if missing_vars:
993
  print(f"WARNING: The following environment variables are missing: {', '.join(missing_vars)}")
 
1009
 
1010
  print("Server shutdown complete!")
1011
 
1012
+ # Health check endpoint
1013
  # Health check endpoint
1014
  @app.get("/health")
1015
  async def health_check():
 
1018
  missing_critical_vars = []
1019
 
1020
  # Check critical environment variables
1021
+ if not env_vars['api_keys'] or env_vars['api_keys'] == ['']:
1022
  missing_critical_vars.append('API_KEYS')
1023
+ if not env_vars['secret_api_endpoint']:
1024
  missing_critical_vars.append('SECRET_API_ENDPOINT')
1025
+ if not env_vars['secret_api_endpoint_2']:
1026
  missing_critical_vars.append('SECRET_API_ENDPOINT_2')
1027
+ if not env_vars['secret_api_endpoint_3']:
1028
  missing_critical_vars.append('SECRET_API_ENDPOINT_3')
1029
+ if not env_vars['secret_api_endpoint_4']:
1030
  missing_critical_vars.append('SECRET_API_ENDPOINT_4')
1031
+ if not env_vars['secret_api_endpoint_5']: # Check the new endpoint
1032
  missing_critical_vars.append('SECRET_API_ENDPOINT_5')
1033
+ if not env_vars['mistral_api']:
1034
  missing_critical_vars.append('MISTRAL_API')
1035
+ if not env_vars['mistral_key']:
1036
  missing_critical_vars.append('MISTRAL_KEY')
 
 
 
1037
 
1038
  health_status = {
1039
  "status": "healthy" if not missing_critical_vars else "unhealthy",
 
1043
  }
1044
  return JSONResponse(content=health_status)
1045
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1046
  if __name__ == "__main__":
1047
  import uvicorn
1048
+ uvicorn.run(app, host="0.0.0.0", port=7860)