ParthSadaria committed
Commit 4bcb2c2 · verified · 1 Parent(s): 4af9da3

Update main.py

Files changed (1)
  1. main.py +208 -377
main.py CHANGED
@@ -19,12 +19,13 @@ from concurrent.futures import ThreadPoolExecutor
  import uvloop
  from fastapi.middleware.gzip import GZipMiddleware
  from starlette.middleware.cors import CORSMiddleware

  # Enable uvloop for faster event loop
  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

  # Thread pool for CPU-bound operations
- executor = ThreadPoolExecutor(max_workers=8)

  # Load environment variables once at startup
  load_dotenv()
@@ -59,7 +60,6 @@ def get_env_vars():
  'secret_api_endpoint_4': "https://text.pollinations.ai/openai",
  'mistral_api': "https://api.mistral.ai",
  'mistral_key': os.getenv('MISTRAL_KEY'),
- 'image_endpoint': os.getenv("IMAGE_ENDPOINT"),
  'endpoint_origin': os.getenv('ENDPOINT_ORIGIN')
  }

@@ -128,12 +128,12 @@ available_model_ids: List[str] = []
  def get_async_client():
  return httpx.AsyncClient(
  timeout=60.0,
- limits=httpx.Limits(max_keepalive_connections=20, max_connections=100)
  )

  # Create a cloudscraper pool
  scraper_pool = []
- MAX_SCRAPERS = 10

  def get_scraper():
  if not scraper_pool:
@@ -197,83 +197,86 @@ async def get_models():
  raise HTTPException(status_code=500, detail="Error loading available models")
  return models_data

- # Searcher function with optimized streaming - moved to a separate thread
  async def generate_search_async(query: str, systemprompt: Optional[str] = None, stream: bool = True):
- loop = asyncio.get_running_loop()

- def _generate_search():
- headers = {"User-Agent": ""}
-
- # Use the provided system prompt, or default to "Be Helpful and Friendly"
- system_message = systemprompt or "Be Helpful and Friendly"
-
- # Create the prompt history with the user query and system message
- prompt = [
- {"role": "user", "content": query},
- ]
-
- prompt.insert(0, {"content": system_message, "role": "system"})
-
- # Prepare the payload for the API request
- payload = {
- "is_vscode_extension": True,
- "message_history": prompt,
- "requested_model": "Claude 3.7 Sonnet",
- "user_input": prompt[-1]["content"],
- }
-
- # Get endpoint from environment
- secret_api_endpoint_3 = get_env_vars()['secret_api_endpoint_3']
- if not secret_api_endpoint_3:
- raise ValueError("Search API endpoint not configured")
-
- # Send the request to the chat endpoint using a scraper from the pool
- response = get_scraper().post(
- secret_api_endpoint_3,
- headers=headers,
- json=payload,
- stream=True
- )
-
- result = []
- streaming_text = ""
-
- # Process the streaming response
- for value in response.iter_lines(decode_unicode=True):
- if value.startswith("data: "):
- try:
- json_modified_value = json.loads(value[6:])
- content = json_modified_value.get("choices", [{}])[0].get("delta", {}).get("content", "")
-
- if content.strip(): # Only process non-empty content
- cleaned_response = {
- "created": json_modified_value.get("created"),
- "id": json_modified_value.get("id"),
- "model": "searchgpt",
- "object": "chat.completion",
- "choices": [
- {
- "message": {
- "content": content
  }
- }
- ]
- }
-
- if stream:
- result.append(f"data: {json.dumps(cleaned_response)}\n\n")
-
- streaming_text += content
- except json.JSONDecodeError:
- continue
-
- if not stream:
- result.append(streaming_text)

- return result

- # Run in thread pool to avoid blocking the event loop
- return await loop.run_in_executor(executor, _generate_search)

  # Cache for frequently accessed static files
  @lru_cache(maxsize=10)
@@ -314,7 +317,7 @@ async def playground():
  async def return_models():
  return await get_models()

- # Search routes
  @app.get("/searchgpt")
  async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optional[str] = None):
  if not q:
@@ -322,22 +325,44 @@ async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optional[str] = None):

  usage_tracker.record_request(endpoint="/searchgpt")

- result = await generate_search_async(q, systemprompt=systemprompt, stream=stream)

  if stream:
  async def stream_generator():
- for chunk in result:
- yield chunk

  return StreamingResponse(
  stream_generator(),
  media_type="text/event-stream"
  )
  else:
- # For non-streaming, return the collected text
- return JSONResponse(content={"response": result[0] if result else ""})

- # Chat completion endpoint
  @app.post("/chat/completions")
  @app.post("/api/v1/chat/completions")
  async def get_completion(payload: Payload, request: Request, authenticated: bool = Depends(verify_api_key)):
@@ -364,7 +389,10 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
  # Prepare payload
  payload_dict = payload.dict()
  payload_dict["model"] = model_to_use
-

  # Get environment variables
  env_vars = get_env_vars()

@@ -384,35 +412,13 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
  endpoint = env_vars['secret_api_endpoint']
  custom_headers = {}

- print(f"Using endpoint: {endpoint}")
-
- # Create a new scraper for each request to avoid potential blocking
- scraper = cloudscraper.create_scraper(browser={
- 'browser': 'chrome',
- 'platform': 'windows',
- 'mobile': False
- })

- # Set a timeout for the entire request handling
- TIMEOUT_SECONDS = 20
-
- async def stream_generator_with_timeout(payload_dict):
  try:
- # Create a thread-safe event for cancellation
- cancel_event = threading.Event()
-
- def request_with_timeout():
- try:
- # Send POST request with the correct headers and timeout
- response = scraper.post(
- f"{endpoint}/v1/chat/completions",
- json=payload_dict,
- headers=custom_headers,
- stream=True,
- timeout=TIMEOUT_SECONDS
- )
-
- # Handle response errors
  if response.status_code >= 400:
  error_messages = {
  422: "Unprocessable entity. Check your payload.",
@@ -421,68 +427,42 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
  404: "The requested resource was not found.",
  }
  detail = error_messages.get(response.status_code, f"Error code: {response.status_code}")
- return {"error": detail, "status_code": response.status_code}

- result = []
-
- # Process the streaming response with timeout checks
- for line in response.iter_lines():
- # Check for cancellation
- if cancel_event.is_set():
- break
-
  if line:
- decoded = line.decode('utf-8') + "\n"
- result.append(decoded)
-
- return {"lines": result}
-
- except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e:
- return {"error": "Request timed out or connection failed", "status_code": 504}
- except Exception as e:
- return {"error": str(e), "status_code": 500}
-
- # Execute request in a ThreadPoolExecutor with a timeout
- loop = asyncio.get_running_loop()
- with ThreadPoolExecutor() as pool:
- response_future = loop.run_in_executor(pool, request_with_timeout)
-
- try:
- # Wait for response with a timeout
- response_data = await asyncio.wait_for(response_future, timeout=TIMEOUT_SECONDS)
-
- # If there was an error, raise an HTTPException
- if "error" in response_data:
- raise HTTPException(
- status_code=response_data.get("status_code", 500),
- detail=response_data["error"]
- )
-
- # Stream the response lines
- for line in response_data.get("lines", []):
- yield line
-
- except asyncio.TimeoutError:
- # Cancel the ongoing request
- cancel_event.set()
- raise HTTPException(status_code=504, detail="Request timed out after 20 seconds")
-
  except Exception as e:
  if isinstance(e, HTTPException):
  raise e
- # Use a generic error message that doesn't expose internal details
- raise HTTPException(status_code=500, detail=f"An error occurred while processing your request: {str(e)}")

- # Return streaming response with proper timeout handling
- try:
  return StreamingResponse(
- stream_generator_with_timeout(payload_dict),
- media_type="application/json"
  )
- except Exception as e:
- if isinstance(e, HTTPException):
- raise e
- raise HTTPException(status_code=500, detail=f"Failed to initialize streaming response: {str(e)}")

  # Asynchronous logging function
  async def log_request(request, model):
  # Get minimal data for logging
@@ -490,160 +470,6 @@ async def log_request(request, model):
  ip_hash = hash(request.client.host) % 10000 # Hash the IP for privacy
  print(f"Time: {current_time}, IP Hash: {ip_hash}, Model: {model}")

- # Image generation endpoint - optimized to use connection pool
- @app.api_route("/images/generations", methods=["GET", "POST"])
- async def generate_image(
- prompt: Optional[str] = None,
- model: str = "flux",
- seed: Optional[int] = None,
- width: Optional[int] = None,
- height: Optional[int] = None,
- nologo: Optional[bool] = True,
- private: Optional[bool] = None,
- enhance: Optional[bool] = None,
- request: Request = None,
- authenticated: bool = Depends(verify_api_key)
- ):
- # Validate the image endpoint
- image_endpoint = get_env_vars()['image_endpoint']
- if not image_endpoint:
- raise HTTPException(status_code=500, detail="Image endpoint not configured in environment variables.")
-
- usage_tracker.record_request(endpoint="/images/generations")
-
- # Handle GET and POST prompts
- if request.method == "POST":
- try:
- body = await request.json()
- prompt = body.get("prompt", "").strip()
- if not prompt:
- raise HTTPException(status_code=400, detail="Prompt cannot be empty")
- except Exception:
- raise HTTPException(status_code=400, detail="Invalid JSON payload")
- elif request.method == "GET":
- if not prompt or not prompt.strip():
- raise HTTPException(status_code=400, detail="Prompt cannot be empty")
- prompt = prompt.strip()
-
- # Sanitize and encode the prompt
- encoded_prompt = httpx.QueryParams({'prompt': prompt}).get('prompt')
-
- # Construct the URL with the encoded prompt
- base_url = image_endpoint.rstrip('/')
- url = f"{base_url}/{encoded_prompt}"
-
- # Prepare query parameters with validation
- params = {}
- if model and isinstance(model, str):
- params['model'] = model
- if seed is not None and isinstance(seed, int):
- params['seed'] = seed
- if width is not None and isinstance(width, int) and 64 <= width <= 2048:
- params['width'] = width
- if height is not None and isinstance(height, int) and 64 <= height <= 2048:
- params['height'] = height
- if nologo is not None:
- params['nologo'] = str(nologo).lower()
- if private is not None:
- params['private'] = str(private).lower()
- if enhance is not None:
- params['enhance'] = str(enhance).lower()
-
- try:
- # Use the shared httpx client for connection pooling
- client = get_async_client()
- response = await client.get(url, params=params, follow_redirects=True)
-
- # Check for various error conditions
- if response.status_code != 200:
- error_messages = {
- 404: "Image generation service not found",
- 400: "Invalid parameters provided to image service",
- 429: "Too many requests to image service",
- }
- detail = error_messages.get(response.status_code, f"Image generation failed with status code {response.status_code}")
- raise HTTPException(status_code=response.status_code, detail=detail)
-
- # Verify content type
- content_type = response.headers.get('content-type', '')
- if not content_type.startswith('image/'):
- raise HTTPException(
- status_code=500,
- detail="Unexpected content type received from image service"
- )
-
- # Use larger chunks for streaming for better performance
- async def stream_with_larger_chunks():
- chunks = []
- size = 0
- async for chunk in response.aiter_bytes(chunk_size=16384): # Use 16KB chunks
- chunks.append(chunk)
- size += len(chunk)
-
- if size >= 65536: # Yield every 64KB
- yield b''.join(chunks)
- chunks = []
- size = 0
-
- if chunks:
- yield b''.join(chunks)
-
- return StreamingResponse(
- stream_with_larger_chunks(),
- media_type=content_type,
- headers={
- 'Cache-Control': 'no-cache, no-store, must-revalidate',
- 'Pragma': 'no-cache',
- 'Expires': '0'
- }
- )
-
- except httpx.TimeoutException:
- raise HTTPException(status_code=504, detail="Image generation request timed out")
- except httpx.RequestError:
- raise HTTPException(status_code=500, detail="Failed to contact image service")
- except Exception:
- raise HTTPException(status_code=500, detail="Unexpected error during image generation")
-
- # Meme endpoint with optimized networking
- @app.get("/meme")
- async def get_meme():
- try:
- # Use the shared client for connection pooling
- client = get_async_client()
- response = await client.get("https://meme-api.com/gimme")
- response_data = response.json()
-
- meme_url = response_data.get("url")
- if not meme_url:
- raise HTTPException(status_code=404, detail="No meme found")
-
- image_response = await client.get(meme_url, follow_redirects=True)
-
- # Use larger chunks for streaming
- async def stream_with_larger_chunks():
- chunks = []
- size = 0
- async for chunk in image_response.aiter_bytes(chunk_size=16384):
- chunks.append(chunk)
- size += len(chunk)
-
- if size >= 65536:
- yield b''.join(chunks)
- chunks = []
- size = 0
-
- if chunks:
- yield b''.join(chunks)
-
- return StreamingResponse(
- stream_with_larger_chunks(),
- media_type=image_response.headers.get("content-type", "image/png"),
- headers={'Cache-Control': 'max-age=3600'} # Add caching
- )
- except Exception:
- raise HTTPException(status_code=500, detail="Failed to retrieve meme")
-
  # Cache usage statistics
  @lru_cache(maxsize=10)
  def get_usage_summary(days=7):
@@ -858,6 +684,45 @@ async def usage_page():
  html_content = get_usage_page_html()
  return HTMLResponse(content=html_content)

  # Utility function for loading model IDs - optimized to run once at startup
  def load_model_ids(json_file_path):
  try:
@@ -877,8 +742,13 @@ async def startup_event():

  # Add all pollinations models to available_model_ids
  available_model_ids.extend(list(pollinations_models))
  available_model_ids = list(set(available_model_ids)) # Remove duplicates
- print(f"Added Pollinations models. Total available models: {len(available_model_ids)}")

  # Preload scrapers
  for _ in range(MAX_SCRAPERS):
@@ -900,8 +770,6 @@ async def startup_event():
  missing_vars.append('MISTRAL_API')
  if not env_vars['mistral_key'] and any(model in mistral_models for model in available_model_ids):
  missing_vars.append('MISTRAL_KEY')
- if not env_vars['image_endpoint']:
- missing_vars.append('IMAGE_ENDPOINT')

  if missing_vars:
  print(f"WARNING: The following environment variables are missing: {', '.join(missing_vars)}")
@@ -923,8 +791,7 @@ async def shutdown_event():

  print("Server shutdown complete!")

- # Server maintenance endpoint
-
  # Health check endpoint
  @app.get("/health")
  async def health_check():
@@ -937,61 +804,25 @@ async def health_check():
  missing_critical_vars.append('API_KEYS')
  if not env_vars['secret_api_endpoint']:
  missing_critical_vars.append('SECRET_API_ENDPOINT')

- # Check if models are loaded
- models_loaded = len(available_model_ids) > 0
-
- status = "healthy"
- if missing_critical_vars or not models_loaded:
- status = "degraded"
-
- return {
- "status": status,
- "timestamp": datetime.datetime.utcnow().isoformat(),
- "uptime": time.time() - usage_tracker.start_time,
- "models_loaded": models_loaded,
- "model_count": len(available_model_ids),
- "issues": {
- "missing_env_vars": missing_critical_vars,
- "models_available": models_loaded
- }
  }

- # Error handlers
- @app.exception_handler(HTTPException)
- async def http_exception_handler(request, exc):
- """Format HTTP exceptions in a consistent way"""
- return JSONResponse(
- status_code=exc.status_code,
- content={"error": exc.detail}
- )
-
- @app.exception_handler(Exception)
- async def general_exception_handler(request, exc):
- """Handle unexpected exceptions gracefully"""
- # Log the error for debugging
- print(f"Unexpected error: {str(exc)}")
-
- return JSONResponse(
- status_code=500,
- content={"error": "An unexpected error occurred. Please try again later."}
- )
-
- # Static files endpoint for serving CSS, JS, etc.
-
- # Documentation
-
- # Run the server when executed directly
  if __name__ == "__main__":
  import uvicorn
-
- port = int(os.getenv("PORT", 7860))
-
- print(f"Starting Lokiai AI server on port {port}")
- uvicorn.run(
- "main:app",
- host="0.0.0.0",
- port=port,
- reload=False,
- log_level="info"
- )
 
  import uvloop
  from fastapi.middleware.gzip import GZipMiddleware
  from starlette.middleware.cors import CORSMiddleware
+ import contextlib

  # Enable uvloop for faster event loop
  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

  # Thread pool for CPU-bound operations
+ executor = ThreadPoolExecutor(max_workers=16) # Increased thread count for better parallelism

  # Load environment variables once at startup
  load_dotenv()
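A minimal sketch of how a module-level ThreadPoolExecutor like the one above is typically used from async code; process_blocking is a hypothetical stand-in for any blocking call, not a function from main.py:

    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    executor = ThreadPoolExecutor(max_workers=16)

    def process_blocking(data: str) -> str:
        # Placeholder for CPU-bound or otherwise blocking work
        return data.upper()

    async def handler(data: str) -> str:
        # Offload the blocking call so the event loop keeps serving other requests
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(executor, process_blocking, data)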
 
  'secret_api_endpoint_4': "https://text.pollinations.ai/openai",
  'mistral_api': "https://api.mistral.ai",
  'mistral_key': os.getenv('MISTRAL_KEY'),
  'endpoint_origin': os.getenv('ENDPOINT_ORIGIN')
  }

  def get_async_client():
  return httpx.AsyncClient(
  timeout=60.0,
+ limits=httpx.Limits(max_keepalive_connections=50, max_connections=200) # Increased limits
  )

  # Create a cloudscraper pool
  scraper_pool = []
+ MAX_SCRAPERS = 20 # Increased pool size

  def get_scraper():
  if not scraper_pool:
 
  raise HTTPException(status_code=500, detail="Error loading available models")
  return models_data

+ # Enhanced async streaming - now with real-time SSE support
  async def generate_search_async(query: str, systemprompt: Optional[str] = None, stream: bool = True):
+ # Create a streaming response channel using asyncio.Queue
+ queue = asyncio.Queue()

+ async def _fetch_search_data():
+ try:
+ headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
+
+ # Use the provided system prompt, or default to "Be Helpful and Friendly"
+ system_message = systemprompt or "Be Helpful and Friendly"
+
+ # Create the prompt history
+ prompt = [
+ {"role": "user", "content": query},
+ ]
+
+ prompt.insert(0, {"content": system_message, "role": "system"})
+
+ # Prepare the payload for the API request
+ payload = {
+ "is_vscode_extension": True,
+ "message_history": prompt,
+ "requested_model": "Claude 3.7 Sonnet",
+ "user_input": prompt[-1]["content"],
+ }
+
+ # Get endpoint from environment
+ secret_api_endpoint_3 = get_env_vars()['secret_api_endpoint_3']
+ if not secret_api_endpoint_3:
+ await queue.put({"error": "Search API endpoint not configured"})
+ return
+
+ # Use AsyncClient for better performance
+ async with httpx.AsyncClient(timeout=30.0) as client:
+ async with client.stream("POST", secret_api_endpoint_3, json=payload, headers=headers) as response:
+ if response.status_code != 200:
+ await queue.put({"error": f"Search API returned status code {response.status_code}"})
+ return
+
+ # Process the streaming response in real-time
+ buffer = ""
+ async for line in response.aiter_lines():
+ if line.startswith("data: "):
+ try:
+ json_data = json.loads(line[6:])
+ content = json_data.get("choices", [{}])[0].get("delta", {}).get("content", "")
+
+ if content.strip():
+ cleaned_response = {
+ "created": json_data.get("created"),
+ "id": json_data.get("id"),
+ "model": "searchgpt",
+ "object": "chat.completion",
+ "choices": [
+ {
+ "message": {
+ "content": content
+ }
+ }
+ ]
  }
+
+ # Send to queue immediately for streaming
+ await queue.put({"data": f"data: {json.dumps(cleaned_response)}\n\n", "text": content})
+ except json.JSONDecodeError:
+ continue

+ # Signal completion
+ await queue.put(None)
+
+ except Exception as e:
+ await queue.put({"error": str(e)})
+ await queue.put(None)
+
+ # Start the fetch process
+ asyncio.create_task(_fetch_search_data())

+ # Return the queue for consumption
+ return queue

  # Cache for frequently accessed static files
  @lru_cache(maxsize=10)
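For context, a compact standalone sketch of the producer/consumer pattern the new generate_search_async relies on: a background task pushes items onto an asyncio.Queue and a consumer drains it until it sees the None sentinel. All names here are illustrative, not taken from main.py:

    import asyncio

    async def producer(queue: asyncio.Queue) -> None:
        for i in range(3):
            await queue.put(f"chunk {i}")  # push work as it becomes available
        await queue.put(None)  # sentinel: tells the consumer to stop

    async def consumer() -> None:
        queue: asyncio.Queue = asyncio.Queue()
        asyncio.create_task(producer(queue))  # start the producer in the background
        while True:
            item = await queue.get()
            if item is None:
                break
            print(item)

    asyncio.run(consumer())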
 
  async def return_models():
  return await get_models()

+ # Search routes with enhanced real-time streaming
  @app.get("/searchgpt")
  async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optional[str] = None):
  if not q:

  usage_tracker.record_request(endpoint="/searchgpt")

+ queue = await generate_search_async(q, systemprompt=systemprompt, stream=True)

  if stream:
  async def stream_generator():
+ collected_text = ""
+ while True:
+ item = await queue.get()
+ if item is None:
+ break
+
+ if "error" in item:
+ yield f"data: {json.dumps({'error': item['error']})}\n\n"
+ break
+
+ if "data" in item:
+ yield item["data"]
+ collected_text += item.get("text", "")

  return StreamingResponse(
  stream_generator(),
  media_type="text/event-stream"
  )
  else:
+ # For non-streaming, collect all text and return at once
+ collected_text = ""
+ while True:
+ item = await queue.get()
+ if item is None:
+ break
+
+ if "error" in item:
+ raise HTTPException(status_code=500, detail=item["error"])
+
+ collected_text += item.get("text", "")
+
+ return JSONResponse(content={"response": collected_text})
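A hedged usage sketch for this route. The base URL http://localhost:7860 is an assumption taken from the uvicorn port at the end of main.py; the route itself takes q, stream, and systemprompt query parameters as shown in the diff:

    import httpx

    # Illustrative client for GET /searchgpt with streaming enabled.
    with httpx.Client(timeout=60.0) as client:
        with client.stream(
            "GET",
            "http://localhost:7860/searchgpt",
            params={"q": "What is FastAPI?", "stream": "true"},
        ) as response:
            for line in response.iter_lines():
                if line.startswith("data: "):
                    print(line[6:])  # each line carries a JSON chunk shaped like cleaned_response above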

+ # Enhanced streaming with direct SSE pass-through for real-time responses
  @app.post("/chat/completions")
  @app.post("/api/v1/chat/completions")
  async def get_completion(payload: Payload, request: Request, authenticated: bool = Depends(verify_api_key)):

  # Prepare payload
  payload_dict = payload.dict()
  payload_dict["model"] = model_to_use
+
+ # Ensure stream is True for real-time streaming (can be overridden by client)
+ stream_enabled = payload_dict.get("stream", True)
+
  # Get environment variables
  env_vars = get_env_vars()

  endpoint = env_vars['secret_api_endpoint']
  custom_headers = {}

+ print(f"Using endpoint: {endpoint} for model: {model_to_use}")

+ # Improved real-time streaming handler
+ async def real_time_stream_generator():
  try:
+ async with httpx.AsyncClient(timeout=60.0) as client:
+ async with client.stream("POST", f"{endpoint}/v1/chat/completions", json=payload_dict, headers=custom_headers) as response:
  if response.status_code >= 400:
  error_messages = {
  422: "Unprocessable entity. Check your payload.",
  404: "The requested resource was not found.",
  }
  detail = error_messages.get(response.status_code, f"Error code: {response.status_code}")
+ raise HTTPException(status_code=response.status_code, detail=detail)

+ # Stream the response in real-time with minimal buffering
+ async for line in response.aiter_lines():
  if line:
+ # Yield immediately for faster streaming
+ yield line + "\n"
+ except httpx.TimeoutException:
+ raise HTTPException(status_code=504, detail="Request timed out")
+ except httpx.RequestError as e:
+ raise HTTPException(status_code=502, detail=f"Failed to connect to upstream API: {str(e)}")
  except Exception as e:
  if isinstance(e, HTTPException):
  raise e
+ raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

+ # Return streaming response with proper headers
+ if stream_enabled:
  return StreamingResponse(
+ real_time_stream_generator(),
+ media_type="text/event-stream",
+ headers={
+ "Content-Type": "text/event-stream",
+ "Cache-Control": "no-cache",
+ "Connection": "keep-alive",
+ "X-Accel-Buffering": "no" # Disable proxy buffering for Nginx
+ }
  )
+ else:
+ # For non-streaming requests, collect the entire response
+ response_content = []
+ async for chunk in real_time_stream_generator():
+ response_content.append(chunk)
+
+ return JSONResponse(content=json.loads(''.join(response_content)))
+
  # Asynchronous logging function
  async def log_request(request, model):
  # Get minimal data for logging
  ip_hash = hash(request.client.host) % 10000 # Hash the IP for privacy
  print(f"Time: {current_time}, IP Hash: {ip_hash}, Model: {model}")

  # Cache usage statistics
  @lru_cache(maxsize=10)
  def get_usage_summary(days=7):
 
  html_content = get_usage_page_html()
  return HTMLResponse(content=html_content)

+ # Meme endpoint with optimized networking
+ @app.get("/meme")
+ async def get_meme():
+ try:
+ # Use the shared client for connection pooling
+ client = get_async_client()
+ response = await client.get("https://meme-api.com/gimme")
+ response_data = response.json()
+
+ meme_url = response_data.get("url")
+ if not meme_url:
+ raise HTTPException(status_code=404, detail="No meme found")
+
+ image_response = await client.get(meme_url, follow_redirects=True)
+
+ # Use larger chunks for streaming
+ async def stream_with_larger_chunks():
+ chunks = []
+ size = 0
+ async for chunk in image_response.aiter_bytes(chunk_size=16384):
+ chunks.append(chunk)
+ size += len(chunk)
+
+ if size >= 65536:
+ yield b''.join(chunks)
+ chunks = []
+ size = 0
+
+ if chunks:
+ yield b''.join(chunks)
+
+ return StreamingResponse(
+ stream_with_larger_chunks(),
+ media_type=image_response.headers.get("content-type", "image/png"),
+ headers={'Cache-Control': 'max-age=3600'} # Add caching
+ )
+ except Exception:
+ raise HTTPException(status_code=500, detail="Failed to retrieve meme")
+
  # Utility function for loading model IDs - optimized to run once at startup
  def load_model_ids(json_file_path):
  try:
 

  # Add all pollinations models to available_model_ids
  available_model_ids.extend(list(pollinations_models))
+ # Add alternate models to available_model_ids
+ available_model_ids.extend(list(alternate_models))
+ # Add mistral models to available_model_ids
+ available_model_ids.extend(list(mistral_models))
+
  available_model_ids = list(set(available_model_ids)) # Remove duplicates
+ print(f"Total available models: {len(available_model_ids)}")

  # Preload scrapers
  for _ in range(MAX_SCRAPERS):

  missing_vars.append('MISTRAL_API')
  if not env_vars['mistral_key'] and any(model in mistral_models for model in available_model_ids):
  missing_vars.append('MISTRAL_KEY')

  if missing_vars:
  print(f"WARNING: The following environment variables are missing: {', '.join(missing_vars)}")

  print("Server shutdown complete!")

+ # Health check endpoint
  # Health check endpoint
  @app.get("/health")
  async def health_check():
 
  missing_critical_vars.append('API_KEYS')
  if not env_vars['secret_api_endpoint']:
  missing_critical_vars.append('SECRET_API_ENDPOINT')
+ if not env_vars['secret_api_endpoint_2']:
+ missing_critical_vars.append('SECRET_API_ENDPOINT_2')
+ if not env_vars['secret_api_endpoint_3']:
+ missing_critical_vars.append('SECRET_API_ENDPOINT_3')
+ if not env_vars['secret_api_endpoint_4']:
+ missing_critical_vars.append('SECRET_API_ENDPOINT_4')
+ if not env_vars['mistral_api']:
+ missing_critical_vars.append('MISTRAL_API')
+ if not env_vars['mistral_key']:
+ missing_critical_vars.append('MISTRAL_KEY')

+ health_status = {
+ "status": "healthy" if not missing_critical_vars else "unhealthy",
+ "missing_env_vars": missing_critical_vars,
+ "server_status": server_status,
+ "message": "Everything's lit! 🚀" if not missing_critical_vars else "Uh oh, some env vars are missing. 😬"
  }
+ return JSONResponse(content=health_status)

  if __name__ == "__main__":
  import uvicorn
+ uvicorn.run(app, host="0.0.0.0", port=7860)
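A quick sketch for exercising the /health route once the server is up; the localhost base URL assumes the default port 7860 set in the uvicorn.run call above, and the printed keys match the health_status dict built in health_check:

    import httpx

    # Probe the health endpoint and inspect the JSON it returns.
    response = httpx.get("http://localhost:7860/health", timeout=10.0)
    info = response.json()
    print(info["status"], info["missing_env_vars"])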