ParthSadaria committed on
Commit
8834a20
·
verified ·
1 Parent(s): 83ca19d

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +21 -22
main.py CHANGED
@@ -3,7 +3,7 @@ from dotenv import load_dotenv
3
  from fastapi import FastAPI, HTTPException, Request
4
  from fastapi.responses import StreamingResponse, HTMLResponse
5
  from pydantic import BaseModel
6
- import requests
7
  from functools import lru_cache
8
 
9
  # Load environment variables from .env file
@@ -11,9 +11,6 @@ load_dotenv()
11
 
12
  app = FastAPI()
13
 
14
- # Create a session for reusing the HTTP connection
15
- session = requests.Session()
16
-
17
  # Get API keys and secret endpoint from environment variables
18
  api_keys_str = os.getenv('API_KEYS')
19
  valid_api_keys = api_keys_str.split(',') if api_keys_str else []
@@ -55,17 +52,18 @@ async def root():
55
 
56
  # Cache function with lru_cache
57
  @lru_cache(maxsize=1)
58
- def get_cached_models():
59
- try:
60
- response = session.get(f"{secret_api_endpoint}/api/v1/models", timeout=3)
61
- response.raise_for_status()
62
- return response.json()
63
- except requests.exceptions.RequestException as e:
64
- raise HTTPException(status_code=500, detail=f"Request failed: {e}")
 
65
 
66
  @app.get("/models")
67
  async def get_models():
68
- return get_cached_models()
69
 
70
  @app.post("/v1/chat/completions")
71
  async def get_completion(payload: Payload, request: Request):
@@ -78,16 +76,17 @@ async def get_completion(payload: Payload, request: Request):
78
  # Prepare the payload for streaming
79
  payload_dict = {**payload.dict(), "stream": True}
80
 
81
- # Define a generator to stream the response
82
- def stream_generator():
83
- try:
84
- with session.post(secret_api_endpoint, json=payload_dict, stream=True, timeout=15) as response:
85
- response.raise_for_status()
86
- for chunk in response.iter_content(chunk_size=1024):
87
- if chunk: # Only yield non-empty chunks
88
- yield chunk
89
- except requests.exceptions.RequestException as e:
90
- raise HTTPException(status_code=500, detail=f"Streaming failed: {e}")
 
91
 
92
  # Return the streaming response
93
  return StreamingResponse(stream_generator(), media_type="application/json")
 
3
  from fastapi import FastAPI, HTTPException, Request
4
  from fastapi.responses import StreamingResponse, HTMLResponse
5
  from pydantic import BaseModel
6
+ import httpx
7
  from functools import lru_cache
8
 
9
  # Load environment variables from .env file
 
11
 
12
  app = FastAPI()
13
 
 
 
 
14
  # Get API keys and secret endpoint from environment variables
15
  api_keys_str = os.getenv('API_KEYS')
16
  valid_api_keys = api_keys_str.split(',') if api_keys_str else []
 
52
 
53
# Cache the model list manually. NOTE: functools.lru_cache must NOT be used
# on an async def -- it would cache the coroutine object itself, so every
# call after the first would await an already-awaited coroutine and raise
# RuntimeError.
_models_cache = None

async def get_cached_models():
    """Fetch the upstream model list, caching the first successful result.

    Returns:
        The JSON-decoded body of GET {secret_api_endpoint}/api/v1/models.

    Raises:
        HTTPException: 500 if the upstream request fails (connect/timeout)
            or replies with a 4xx/5xx status.
    """
    global _models_cache
    if _models_cache is None:
        async with httpx.AsyncClient() as client:
            try:
                response = await client.get(f"{secret_api_endpoint}/api/v1/models", timeout=3)
                # raise_for_status() raises httpx.HTTPStatusError; catch
                # httpx.HTTPError (common base of RequestError and
                # HTTPStatusError) so upstream error statuses are handled too.
                response.raise_for_status()
                _models_cache = response.json()
            except httpx.HTTPError as e:
                raise HTTPException(status_code=500, detail=f"Request failed: {e}")
    return _models_cache
63
 
64
@app.get("/models")
async def get_models():
    """Expose the cached upstream model list at GET /models."""
    models = await get_cached_models()
    return models
67
 
68
  @app.post("/v1/chat/completions")
69
  async def get_completion(payload: Payload, request: Request):
 
76
  # Prepare the payload for streaming
77
  payload_dict = {**payload.dict(), "stream": True}
78
 
79
+ # Define an asynchronous generator to stream the response
80
+ async def stream_generator():
81
+ async with httpx.AsyncClient() as client:
82
+ try:
83
+ async with client.stream("POST", secret_api_endpoint, json=payload_dict, timeout=10) as response:
84
+ response.raise_for_status()
85
+ async for chunk in response.aiter_bytes(chunk_size=512): # Smaller chunks for faster response
86
+ if chunk:
87
+ yield chunk
88
+ except httpx.RequestError as e:
89
+ raise HTTPException(status_code=500, detail=f"Streaming failed: {e}")
90
 
91
  # Return the streaming response
92
  return StreamingResponse(stream_generator(), media_type="application/json")