Update main.py
main.py (CHANGED)
@@ -19,12 +19,13 @@ from concurrent.futures import ThreadPoolExecutor
 import uvloop
 from fastapi.middleware.gzip import GZipMiddleware
 from starlette.middleware.cors import CORSMiddleware
+import contextlib

 # Enable uvloop for faster event loop
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

 # Thread pool for CPU-bound operations
-executor = ThreadPoolExecutor(max_workers=
+executor = ThreadPoolExecutor(max_workers=16)  # Increased thread count for better parallelism

 # Load environment variables once at startup
 load_dotenv()
@@ -59,7 +60,6 @@ def get_env_vars():
         'secret_api_endpoint_4': "https://text.pollinations.ai/openai",
         'mistral_api': "https://api.mistral.ai",
         'mistral_key': os.getenv('MISTRAL_KEY'),
-        'image_endpoint': os.getenv("IMAGE_ENDPOINT"),
         'endpoint_origin': os.getenv('ENDPOINT_ORIGIN')
     }
@@ -128,12 +128,12 @@ available_model_ids: List[str] = []
 def get_async_client():
     return httpx.AsyncClient(
         timeout=60.0,
-        limits=httpx.Limits(max_keepalive_connections=
+        limits=httpx.Limits(max_keepalive_connections=50, max_connections=200)  # Increased limits
     )

 # Create a cloudscraper pool
 scraper_pool = []
-MAX_SCRAPERS =
+MAX_SCRAPERS = 20  # Increased pool size

 def get_scraper():
     if not scraper_pool:
@@ -197,83 +197,86 @@ async def get_models():
         raise HTTPException(status_code=500, detail="Error loading available models")
     return models_data

-#
+# Enhanced async streaming - now with real-time SSE support
 async def generate_search_async(query: str, systemprompt: Optional[str] = None, stream: bool = True):
-    …
-    def
-    …
-            result.append(f"data: {json.dumps(cleaned_response)}\n\n")
-            streaming_text += content
-        except json.JSONDecodeError:
-            continue
-
-    if not stream:
-        result.append(streaming_text)
-
-    #
-    return
+    # Create a streaming response channel using asyncio.Queue
+    queue = asyncio.Queue()
+
+    async def _fetch_search_data():
+        try:
+            headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
+
+            # Use the provided system prompt, or default to "Be Helpful and Friendly"
+            system_message = systemprompt or "Be Helpful and Friendly"
+
+            # Create the prompt history
+            prompt = [
+                {"role": "user", "content": query},
+            ]
+
+            prompt.insert(0, {"content": system_message, "role": "system"})
+
+            # Prepare the payload for the API request
+            payload = {
+                "is_vscode_extension": True,
+                "message_history": prompt,
+                "requested_model": "Claude 3.7 Sonnet",
+                "user_input": prompt[-1]["content"],
+            }
+
+            # Get endpoint from environment
+            secret_api_endpoint_3 = get_env_vars()['secret_api_endpoint_3']
+            if not secret_api_endpoint_3:
+                await queue.put({"error": "Search API endpoint not configured"})
+                return
+
+            # Use AsyncClient for better performance
+            async with httpx.AsyncClient(timeout=30.0) as client:
+                async with client.stream("POST", secret_api_endpoint_3, json=payload, headers=headers) as response:
+                    if response.status_code != 200:
+                        await queue.put({"error": f"Search API returned status code {response.status_code}"})
+                        return
+
+                    # Process the streaming response in real-time
+                    buffer = ""
+                    async for line in response.aiter_lines():
+                        if line.startswith("data: "):
+                            try:
+                                json_data = json.loads(line[6:])
+                                content = json_data.get("choices", [{}])[0].get("delta", {}).get("content", "")
+
+                                if content.strip():
+                                    cleaned_response = {
+                                        "created": json_data.get("created"),
+                                        "id": json_data.get("id"),
+                                        "model": "searchgpt",
+                                        "object": "chat.completion",
+                                        "choices": [
+                                            {
+                                                "message": {
+                                                    "content": content
+                                                }
+                                            }
+                                        ]
+                                    }
+
+                                    # Send to queue immediately for streaming
+                                    await queue.put({"data": f"data: {json.dumps(cleaned_response)}\n\n", "text": content})
+                            except json.JSONDecodeError:
+                                continue
+
+            # Signal completion
+            await queue.put(None)
+
+        except Exception as e:
+            await queue.put({"error": str(e)})
+            await queue.put(None)
+
+    # Start the fetch process
+    asyncio.create_task(_fetch_search_data())
+
+    # Return the queue for consumption
+    return queue

 # Cache for frequently accessed static files
 @lru_cache(maxsize=10)
@@ -314,7 +317,7 @@ async def playground():
 async def return_models():
     return await get_models()

-# Search routes
+# Search routes with enhanced real-time streaming
 @app.get("/searchgpt")
 async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optional[str] = None):
     if not q:
@@ -322,22 +325,44 @@ async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optional[str] = None):

     usage_tracker.record_request(endpoint="/searchgpt")

-    …
+    queue = await generate_search_async(q, systemprompt=systemprompt, stream=True)

     if stream:
         async def stream_generator():
-            …
+            collected_text = ""
+            while True:
+                item = await queue.get()
+                if item is None:
+                    break
+
+                if "error" in item:
+                    yield f"data: {json.dumps({'error': item['error']})}\n\n"
+                    break
+
+                if "data" in item:
+                    yield item["data"]
+                    collected_text += item.get("text", "")

         return StreamingResponse(
             stream_generator(),
             media_type="text/event-stream"
         )
     else:
-        # For non-streaming, return
-        …
+        # For non-streaming, collect all text and return at once
+        collected_text = ""
+        while True:
+            item = await queue.get()
+            if item is None:
+                break
+
+            if "error" in item:
+                raise HTTPException(status_code=500, detail=item["error"])
+
+            collected_text += item.get("text", "")
+
+        return JSONResponse(content={"response": collected_text})

-#
+# Enhanced streaming with direct SSE pass-through for real-time responses
 @app.post("/chat/completions")
 @app.post("/api/v1/chat/completions")
 async def get_completion(payload: Payload, request: Request, authenticated: bool = Depends(verify_api_key)):
@@ -364,7 +389,10 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool = Depends(verify_api_key)):
     # Prepare payload
     payload_dict = payload.dict()
     payload_dict["model"] = model_to_use
-    …
+
+    # Ensure stream is True for real-time streaming (can be overridden by client)
+    stream_enabled = payload_dict.get("stream", True)
+
     # Get environment variables
     env_vars = get_env_vars()
@@ -384,35 +412,13 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool = Depends(verify_api_key)):
     endpoint = env_vars['secret_api_endpoint']
     custom_headers = {}

-    print(f"Using endpoint: {endpoint}")
-
-    # Create a new scraper for each request to avoid potential blocking
-    scraper = cloudscraper.create_scraper(browser={
-        'browser': 'chrome',
-        'platform': 'windows',
-        'mobile': False
-    })
+    print(f"Using endpoint: {endpoint} for model: {model_to_use}")

-    #
-    async def stream_generator_with_timeout(payload_dict):
+    # Improved real-time streaming handler
+    async def real_time_stream_generator():
         try:
-            …
-            def request_with_timeout():
-                try:
-                    # Send POST request with the correct headers and timeout
-                    response = scraper.post(
-                        f"{endpoint}/v1/chat/completions",
-                        json=payload_dict,
-                        headers=custom_headers,
-                        stream=True,
-                        timeout=TIMEOUT_SECONDS
-                    )
-
-                    # Handle response errors
+            async with httpx.AsyncClient(timeout=60.0) as client:
+                async with client.stream("POST", f"{endpoint}/v1/chat/completions", json=payload_dict, headers=custom_headers) as response:
                     if response.status_code >= 400:
                         error_messages = {
                             422: "Unprocessable entity. Check your payload.",
@@ -421,68 +427,42 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool = Depends(verify_api_key)):
                         404: "The requested resource was not found.",
                     }
                     detail = error_messages.get(response.status_code, f"Error code: {response.status_code}")
-                    …
+                    raise HTTPException(status_code=response.status_code, detail=detail)

-                    # Process the streaming response with timeout checks
-                    for line in response.iter_lines():
-                        # Check for cancellation
-                        if cancel_event.is_set():
-                            break
-
+                    # Stream the response in real-time with minimal buffering
+                    async for line in response.aiter_lines():
                         if line:
-                            …
-                    return {"error": "Request timed out or connection failed", "status_code": 504}
-                except Exception as e:
-                    return {"error": str(e), "status_code": 500}
-
-            # Execute request in a ThreadPoolExecutor with a timeout
-            loop = asyncio.get_running_loop()
-            with ThreadPoolExecutor() as pool:
-                response_future = loop.run_in_executor(pool, request_with_timeout)
-
-                try:
-                    # Wait for response with a timeout
-                    response_data = await asyncio.wait_for(response_future, timeout=TIMEOUT_SECONDS)
-
-                    # If there was an error, raise an HTTPException
-                    if "error" in response_data:
-                        raise HTTPException(
-                            status_code=response_data.get("status_code", 500),
-                            detail=response_data["error"]
-                        )
-
-                    # Stream the response lines
-                    for line in response_data.get("lines", []):
-                        yield line
-
-                except asyncio.TimeoutError:
-                    # Cancel the ongoing request
-                    cancel_event.set()
-                    raise HTTPException(status_code=504, detail="Request timed out after 20 seconds")
-
+                            # Yield immediately for faster streaming
+                            yield line + "\n"
+        except httpx.TimeoutException:
+            raise HTTPException(status_code=504, detail="Request timed out")
+        except httpx.RequestError as e:
+            raise HTTPException(status_code=502, detail=f"Failed to connect to upstream API: {str(e)}")
         except Exception as e:
             if isinstance(e, HTTPException):
                 raise e
-            raise HTTPException(status_code=500, detail=f"An error occurred while processing your request: {str(e)}")
+            raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

-    # Return streaming response with proper
-
+    # Return streaming response with proper headers
+    if stream_enabled:
         return StreamingResponse(
-            …
-            media_type="
+            real_time_stream_generator(),
+            media_type="text/event-stream",
+            headers={
+                "Content-Type": "text/event-stream",
+                "Cache-Control": "no-cache",
+                "Connection": "keep-alive",
+                "X-Accel-Buffering": "no"  # Disable proxy buffering for Nginx
+            }
         )
+    else:
+        # For non-streaming requests, collect the entire response
+        response_content = []
+        async for chunk in real_time_stream_generator():
+            response_content.append(chunk)
+
+        return JSONResponse(content=json.loads(''.join(response_content)))

 # Asynchronous logging function
 async def log_request(request, model):
     # Get minimal data for logging
@@ -490,160 +470,6 @@ async def log_request(request, model):
     ip_hash = hash(request.client.host) % 10000  # Hash the IP for privacy
     print(f"Time: {current_time}, IP Hash: {ip_hash}, Model: {model}")

-# Image generation endpoint - optimized to use connection pool
-@app.api_route("/images/generations", methods=["GET", "POST"])
-async def generate_image(
-    prompt: Optional[str] = None,
-    model: str = "flux",
-    seed: Optional[int] = None,
-    width: Optional[int] = None,
-    height: Optional[int] = None,
-    nologo: Optional[bool] = True,
-    private: Optional[bool] = None,
-    enhance: Optional[bool] = None,
-    request: Request = None,
-    authenticated: bool = Depends(verify_api_key)
-):
-    # Validate the image endpoint
-    image_endpoint = get_env_vars()['image_endpoint']
-    if not image_endpoint:
-        raise HTTPException(status_code=500, detail="Image endpoint not configured in environment variables.")
-
-    usage_tracker.record_request(endpoint="/images/generations")
-
-    # Handle GET and POST prompts
-    if request.method == "POST":
-        try:
-            body = await request.json()
-            prompt = body.get("prompt", "").strip()
-            if not prompt:
-                raise HTTPException(status_code=400, detail="Prompt cannot be empty")
-        except Exception:
-            raise HTTPException(status_code=400, detail="Invalid JSON payload")
-    elif request.method == "GET":
-        if not prompt or not prompt.strip():
-            raise HTTPException(status_code=400, detail="Prompt cannot be empty")
-        prompt = prompt.strip()
-
-    # Sanitize and encode the prompt
-    encoded_prompt = httpx.QueryParams({'prompt': prompt}).get('prompt')
-
-    # Construct the URL with the encoded prompt
-    base_url = image_endpoint.rstrip('/')
-    url = f"{base_url}/{encoded_prompt}"
-
-    # Prepare query parameters with validation
-    params = {}
-    if model and isinstance(model, str):
-        params['model'] = model
-    if seed is not None and isinstance(seed, int):
-        params['seed'] = seed
-    if width is not None and isinstance(width, int) and 64 <= width <= 2048:
-        params['width'] = width
-    if height is not None and isinstance(height, int) and 64 <= height <= 2048:
-        params['height'] = height
-    if nologo is not None:
-        params['nologo'] = str(nologo).lower()
-    if private is not None:
-        params['private'] = str(private).lower()
-    if enhance is not None:
-        params['enhance'] = str(enhance).lower()
-
-    try:
-        # Use the shared httpx client for connection pooling
-        client = get_async_client()
-        response = await client.get(url, params=params, follow_redirects=True)
-
-        # Check for various error conditions
-        if response.status_code != 200:
-            error_messages = {
-                404: "Image generation service not found",
-                400: "Invalid parameters provided to image service",
-                429: "Too many requests to image service",
-            }
-            detail = error_messages.get(response.status_code, f"Image generation failed with status code {response.status_code}")
-            raise HTTPException(status_code=response.status_code, detail=detail)
-
-        # Verify content type
-        content_type = response.headers.get('content-type', '')
-        if not content_type.startswith('image/'):
-            raise HTTPException(
-                status_code=500,
-                detail="Unexpected content type received from image service"
-            )
-
-        # Use larger chunks for streaming for better performance
-        async def stream_with_larger_chunks():
-            chunks = []
-            size = 0
-            async for chunk in response.aiter_bytes(chunk_size=16384):  # Use 16KB chunks
-                chunks.append(chunk)
-                size += len(chunk)
-
-                if size >= 65536:  # Yield every 64KB
-                    yield b''.join(chunks)
-                    chunks = []
-                    size = 0
-
-            if chunks:
-                yield b''.join(chunks)
-
-        return StreamingResponse(
-            stream_with_larger_chunks(),
-            media_type=content_type,
-            headers={
-                'Cache-Control': 'no-cache, no-store, must-revalidate',
-                'Pragma': 'no-cache',
-                'Expires': '0'
-            }
-        )
-
-    except httpx.TimeoutException:
-        raise HTTPException(status_code=504, detail="Image generation request timed out")
-    except httpx.RequestError:
-        raise HTTPException(status_code=500, detail="Failed to contact image service")
-    except Exception:
-        raise HTTPException(status_code=500, detail="Unexpected error during image generation")
-
-# Meme endpoint with optimized networking
-@app.get("/meme")
-async def get_meme():
-    try:
-        # Use the shared client for connection pooling
-        client = get_async_client()
-        response = await client.get("https://meme-api.com/gimme")
-        response_data = response.json()
-
-        meme_url = response_data.get("url")
-        if not meme_url:
-            raise HTTPException(status_code=404, detail="No meme found")
-
-        image_response = await client.get(meme_url, follow_redirects=True)
-
-        # Use larger chunks for streaming
-        async def stream_with_larger_chunks():
-            chunks = []
-            size = 0
-            async for chunk in image_response.aiter_bytes(chunk_size=16384):
-                chunks.append(chunk)
-                size += len(chunk)
-
-                if size >= 65536:
-                    yield b''.join(chunks)
-                    chunks = []
-                    size = 0
-
-            if chunks:
-                yield b''.join(chunks)
-
-        return StreamingResponse(
-            stream_with_larger_chunks(),
-            media_type=image_response.headers.get("content-type", "image/png"),
-            headers={'Cache-Control': 'max-age=3600'}  # Add caching
-        )
-    except Exception:
-        raise HTTPException(status_code=500, detail="Failed to retrieve meme")
-
 # Cache usage statistics
 @lru_cache(maxsize=10)
 def get_usage_summary(days=7):
@@ -858,6 +684,45 @@ async def usage_page():
     html_content = get_usage_page_html()
     return HTMLResponse(content=html_content)

+# Meme endpoint with optimized networking
+@app.get("/meme")
+async def get_meme():
+    try:
+        # Use the shared client for connection pooling
+        client = get_async_client()
+        response = await client.get("https://meme-api.com/gimme")
+        response_data = response.json()
+
+        meme_url = response_data.get("url")
+        if not meme_url:
+            raise HTTPException(status_code=404, detail="No meme found")
+
+        image_response = await client.get(meme_url, follow_redirects=True)
+
+        # Use larger chunks for streaming
+        async def stream_with_larger_chunks():
+            chunks = []
+            size = 0
+            async for chunk in image_response.aiter_bytes(chunk_size=16384):
+                chunks.append(chunk)
+                size += len(chunk)
+
+                if size >= 65536:
+                    yield b''.join(chunks)
+                    chunks = []
+                    size = 0
+
+            if chunks:
+                yield b''.join(chunks)
+
+        return StreamingResponse(
+            stream_with_larger_chunks(),
+            media_type=image_response.headers.get("content-type", "image/png"),
+            headers={'Cache-Control': 'max-age=3600'}  # Add caching
+        )
+    except Exception:
+        raise HTTPException(status_code=500, detail="Failed to retrieve meme")
+
 # Utility function for loading model IDs - optimized to run once at startup
 def load_model_ids(json_file_path):
     try:
@@ -877,8 +742,13 @@ async def startup_event():

     # Add all pollinations models to available_model_ids
     available_model_ids.extend(list(pollinations_models))
+    # Add alternate models to available_model_ids
+    available_model_ids.extend(list(alternate_models))
+    # Add mistral models to available_model_ids
+    available_model_ids.extend(list(mistral_models))
+
     available_model_ids = list(set(available_model_ids))  # Remove duplicates
-    print(f"
+    print(f"Total available models: {len(available_model_ids)}")

     # Preload scrapers
     for _ in range(MAX_SCRAPERS):
@@ -900,8 +770,6 @@ async def startup_event():
         missing_vars.append('MISTRAL_API')
     if not env_vars['mistral_key'] and any(model in mistral_models for model in available_model_ids):
         missing_vars.append('MISTRAL_KEY')
-    if not env_vars['image_endpoint']:
-        missing_vars.append('IMAGE_ENDPOINT')

     if missing_vars:
         print(f"WARNING: The following environment variables are missing: {', '.join(missing_vars)}")
@@ -923,8 +791,7 @@ async def shutdown_event():

     print("Server shutdown complete!")

-#
-
+# Health check endpoint
 # Health check endpoint
 @app.get("/health")
 async def health_check():
@@ -937,61 +804,25 @@ async def health_check():
         missing_critical_vars.append('API_KEYS')
     if not env_vars['secret_api_endpoint']:
         missing_critical_vars.append('SECRET_API_ENDPOINT')
+    if not env_vars['secret_api_endpoint_2']:
+        missing_critical_vars.append('SECRET_API_ENDPOINT_2')
+    if not env_vars['secret_api_endpoint_3']:
+        missing_critical_vars.append('SECRET_API_ENDPOINT_3')
+    if not env_vars['secret_api_endpoint_4']:
+        missing_critical_vars.append('SECRET_API_ENDPOINT_4')
+    if not env_vars['mistral_api']:
+        missing_critical_vars.append('MISTRAL_API')
+    if not env_vars['mistral_key']:
+        missing_critical_vars.append('MISTRAL_KEY')

-    …
-    status = "degraded"
-
-    return {
-        "status": status,
-        "timestamp": datetime.datetime.utcnow().isoformat(),
-        "uptime": time.time() - usage_tracker.start_time,
-        "models_loaded": models_loaded,
-        "model_count": len(available_model_ids),
-        "issues": {
-            "missing_env_vars": missing_critical_vars,
-            "models_available": models_loaded
-        }
+    health_status = {
+        "status": "healthy" if not missing_critical_vars else "unhealthy",
+        "missing_env_vars": missing_critical_vars,
+        "server_status": server_status,
+        "message": "Everything's lit! 🚀" if not missing_critical_vars else "Uh oh, some env vars are missing. 😬"
     }
+    return JSONResponse(content=health_status)

-# Error handlers
-@app.exception_handler(HTTPException)
-async def http_exception_handler(request, exc):
-    """Format HTTP exceptions in a consistent way"""
-    return JSONResponse(
-        status_code=exc.status_code,
-        content={"error": exc.detail}
-    )
-
-@app.exception_handler(Exception)
-async def general_exception_handler(request, exc):
-    """Handle unexpected exceptions gracefully"""
-    # Log the error for debugging
-    print(f"Unexpected error: {str(exc)}")
-
-    return JSONResponse(
-        status_code=500,
-        content={"error": "An unexpected error occurred. Please try again later."}
-    )
-
-# Static files endpoint for serving CSS, JS, etc.
-
-# Documentation
-
-# Run the server when executed directly
 if __name__ == "__main__":
     import uvicorn
-
-    port = int(os.getenv("PORT", 7860))
-
-    print(f"Starting Lokiai AI server on port {port}")
-    uvicorn.run(
-        "main:app",
-        host="0.0.0.0",
-        port=port,
-        reload=False,
-        log_level="info"
-    )
+    uvicorn.run(app, host="0.0.0.0", port=7860)
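As a quick check of the new streaming behaviour, here is a minimal client-side sketch for consuming the /searchgpt SSE stream introduced above. It is an illustration only, not part of the commit: it assumes the server is reachable at http://localhost:7860 (the port used in this file) and that /searchgpt needs no API key, as its route signature suggests; the base URL and the query text are placeholders.

import json

import httpx  # the same HTTP client library the server itself uses


BASE_URL = "http://localhost:7860"  # hypothetical local deployment


def stream_search(query: str) -> str:
    """Consume the /searchgpt SSE stream and return the accumulated text."""
    collected = []
    params = {"q": query, "stream": "true"}
    with httpx.stream("GET", f"{BASE_URL}/searchgpt", params=params, timeout=60.0) as response:
        response.raise_for_status()
        for line in response.iter_lines():
            if not line.startswith("data: "):
                continue  # skip blank keep-alive lines
            event = json.loads(line[len("data: "):])
            if "error" in event:
                raise RuntimeError(event["error"])
            # Each event mirrors the cleaned_response shape built server-side.
            collected.append(event["choices"][0]["message"]["content"])
    return "".join(collected)


if __name__ == "__main__":
    print(stream_search("latest news about FastAPI"))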