Update main.py
main.py CHANGED

@@ -3,7 +3,7 @@ from dotenv import load_dotenv
 from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import StreamingResponse, HTMLResponse
 from pydantic import BaseModel
-import requests
+import httpx
 from functools import lru_cache

 # Load environment variables from .env file
@@ -11,9 +11,6 @@ load_dotenv()

 app = FastAPI()

-# Create a session for reusing the HTTP connection
-session = requests.Session()
-
 # Get API keys and secret endpoint from environment variables
 api_keys_str = os.getenv('API_KEYS')
 valid_api_keys = api_keys_str.split(',') if api_keys_str else []
@@ -55,17 +52,18 @@ async def root():

 # Cache function with lru_cache
 @lru_cache(maxsize=1)
-def get_cached_models():
-    ...
+async def get_cached_models():
+    async with httpx.AsyncClient() as client:
+        try:
+            response = await client.get(f"{secret_api_endpoint}/api/v1/models", timeout=3)
+            response.raise_for_status()
+            return response.json()
+        except httpx.RequestError as e:
+            raise HTTPException(status_code=500, detail=f"Request failed: {e}")

 @app.get("/models")
 async def get_models():
-    return get_cached_models()
+    return await get_cached_models()

 @app.post("/v1/chat/completions")
 async def get_completion(payload: Payload, request: Request):
@@ -78,16 +76,17 @@ async def get_completion(payload: Payload, request: Request):

     # Prepare the payload for streaming
     payload_dict = {**payload.dict(), "stream": True}

-    # Define ...
-    def stream_generator():
-        ...
+    # Define an asynchronous generator to stream the response
+    async def stream_generator():
+        async with httpx.AsyncClient() as client:
+            try:
+                async with client.stream("POST", secret_api_endpoint, json=payload_dict, timeout=10) as response:
+                    response.raise_for_status()
+                    async for chunk in response.aiter_bytes(chunk_size=512):  # Smaller chunks for faster response
+                        if chunk:
+                            yield chunk
+            except httpx.RequestError as e:
+                raise HTTPException(status_code=500, detail=f"Streaming failed: {e}")

     # Return the streaming response
     return StreamingResponse(stream_generator(), media_type="application/json")
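A note on the caching hunk: @lru_cache around an async def caches the coroutine object, not its result. The first request to /models awaits that coroutine successfully, but every cache hit afterwards returns the same, already-awaited coroutine, and awaiting it again raises RuntimeError. A minimal sketch of caching the decoded JSON instead, assuming secret_api_endpoint is defined as in the file above; the cache variable is illustrative, and the except clause is widened to httpx.HTTPError so the raise_for_status failure is caught as well:

import httpx
from fastapi import HTTPException

_models_cache = None  # holds the decoded /models payload after the first successful fetch

async def get_cached_models():
    global _models_cache
    if _models_cache is None:
        async with httpx.AsyncClient() as client:
            try:
                response = await client.get(f"{secret_api_endpoint}/api/v1/models", timeout=3)
                response.raise_for_status()
                _models_cache = response.json()  # cache the data, not the coroutine
            except httpx.HTTPError as e:  # covers transport errors and non-2xx statuses
                raise HTTPException(status_code=500, detail=f"Request failed: {e}")
    return _models_cache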
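Separately, the commit removes the shared requests.Session but opens a fresh httpx.AsyncClient on every call, so connection pooling is lost. A sketch of how a single pooled client could be kept for the app's lifetime via FastAPI's lifespan hook; the names here are illustrative, not part of the commit:

from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI, Request

@asynccontextmanager
async def lifespan(app: FastAPI):
    # One pooled client for the whole app, closed cleanly on shutdown.
    app.state.http = httpx.AsyncClient(timeout=10)
    yield
    await app.state.http.aclose()

app = FastAPI(lifespan=lifespan)

@app.get("/models")
async def get_models(request: Request):
    # Reuses pooled connections instead of opening a client per request.
    response = await request.app.state.http.get(f"{secret_api_endpoint}/api/v1/models")
    response.raise_for_status()
    return response.json()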
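One more caveat on the streaming hunk: by the time stream_generator raises, StreamingResponse has normally already sent the 200 status line, so the HTTPException cannot turn the response into a 500. A hedged sketch of an in-band alternative, reusing payload_dict and secret_api_endpoint from the handler above; the error-chunk format is illustrative:

import httpx

async def stream_generator():
    async with httpx.AsyncClient() as client:
        try:
            async with client.stream("POST", secret_api_endpoint, json=payload_dict, timeout=10) as response:
                response.raise_for_status()
                async for chunk in response.aiter_bytes(chunk_size=512):
                    if chunk:
                        yield chunk
        except httpx.HTTPError as e:
            # The status code is already on the wire; report the failure
            # in-band as a final JSON chunk and end the stream.
            yield f'{{"error": "upstream request failed: {e}"}}'.encode()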