# Loki.AI API — FastAPI proxy server (Hugging Face Spaces page-scrape residue removed)
import os
from functools import lru_cache

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse, HTMLResponse
from pydantic import BaseModel
import requests

# Load environment variables from a local .env file (no-op if absent).
load_dotenv()

app = FastAPI()

# One shared requests.Session so connections to the upstream API are pooled
# and reused across requests instead of re-established each time.
session = requests.Session()

# API_KEYS is a comma-separated list of accepted keys, e.g. "key1,key2".
# When unset, valid_api_keys is empty and every request is rejected (403).
api_keys_str = os.getenv('API_KEYS')
valid_api_keys = api_keys_str.split(',') if api_keys_str else []

# Upstream endpoint that chat/model requests are proxied to; mandatory.
secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
if not secret_api_endpoint:
    # BUG FIX: the original raised HTTPException here, but that exception is
    # meant for request handlers — at import time it just yields a confusing
    # traceback. Fail fast with a plain RuntimeError instead.
    raise RuntimeError(
        "Secret API endpoint is not configured in environment variables (SECRET_API_ENDPOINT)."
    )
class Payload(BaseModel):
    """Request body for chat completions.

    Carries the target model name plus the OpenAI-style message list
    that is forwarded verbatim to the upstream API.
    """
    model: str
    messages: list
@app.get("/")  # BUG FIX: route decorator was missing, so GET / was never
               # registered — the startup log below lists "GET /" as intended.
async def root():
    """Serve a small static landing page at the API root.

    Returns:
        HTMLResponse with project info and a link to the author's GitHub,
        so a browser visit shows something friendlier than a 404.
    """
    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>Loki.AI API</title>
        <style>
            body { font-family: Arial, sans-serif; text-align: center; margin-top: 50px; background-color: #121212; color: white; }
            h1 { color: #4CAF50; }
            a { color: #BB86FC; text-decoration: none; }
            a:hover { text-decoration: underline; }
        </style>
    </head>
    <body>
        <h1>Welcome to Loki.AI API!</h1>
        <p>Created by Parth Sadaria</p>
        <p>Check out the GitHub for more projects:</p>
        <a href="https://github.com/ParthSadaria" target="_blank">github.com/ParthSadaria</a>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content)
# BUG FIX: the original comment promised lru_cache but the decorator was
# missing, so every /models request hit the upstream API. maxsize=1 caches
# the single (argument-free) result for the life of the process; failed
# calls raise and are therefore NOT cached, so transient errors can retry.
@lru_cache(maxsize=1)
def get_cached_models():
    """Fetch and cache the model list from the upstream API.

    Returns:
        Parsed JSON body from GET {secret_api_endpoint}/api/v1/models.

    Raises:
        HTTPException: 500 if the upstream request fails or returns an
        error status (raise_for_status).
    """
    try:
        response = session.get(f"{secret_api_endpoint}/api/v1/models", timeout=3)
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        raise HTTPException(status_code=500, detail=f"Request failed: {e}")
@app.get("/models")  # BUG FIX: route decorator was missing, so GET /models
                     # was never registered — the startup log lists it as intended.
async def get_models():
    """Return the upstream model list (served from the process-level cache)."""
    return get_cached_models()
@app.post("/v1/chat/completions")  # BUG FIX: route decorator was missing, so
                                   # the handler was never registered — the
                                   # startup log lists this path as intended.
async def get_completion(payload: Payload, request: Request):
    """Proxy a chat-completion request to the upstream API, streaming the reply.

    Args:
        payload: model + messages, forwarded upstream with "stream": True.
        request: incoming request; its Authorization header is the API key.

    Returns:
        StreamingResponse relaying the upstream body in 1 KiB chunks.

    Raises:
        HTTPException: 403 when the key is not in valid_api_keys.
    """
    # NOTE(review): the raw header value is compared directly, so a
    # "Bearer <key>" prefix would fail validation — confirm clients send
    # the bare key.
    api_key = request.headers.get("Authorization")
    if api_key not in valid_api_keys:
        raise HTTPException(status_code=403, detail="Forbidden: Invalid API key. Join dsc.gg/chadgang and DM @mr_leaderyt on Discord for a free API key :)")

    # Force streaming on the upstream call regardless of what the client sent.
    payload_dict = {**payload.dict(), "stream": True}

    def stream_generator():
        # Relay the upstream response body chunk-by-chunk.
        try:
            with session.post(secret_api_endpoint, json=payload_dict, stream=True, timeout=15) as response:
                response.raise_for_status()
                for chunk in response.iter_content(chunk_size=1024):
                    if chunk:  # skip keep-alive chunks
                        yield chunk
        except requests.exceptions.RequestException as e:
            # NOTE(review): once streaming has begun the 200 status line is
            # already sent, so this HTTPException can no longer change the
            # response code — it only aborts the stream.
            raise HTTPException(status_code=500, detail=f"Streaming failed: {e}")

    return StreamingResponse(stream_generator(), media_type="application/json")
@app.on_event("startup")  # BUG FIX: without this decorator the function was
                          # defined but never invoked by FastAPI.
async def startup_event():
    """Log the registered API endpoints once when the server starts."""
    print("API endpoints:")
    print("GET /")
    print("GET /models")
    print("POST /v1/chat/completions")
# Script entry point: launch the ASGI app under Uvicorn on all interfaces.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)