# lokiai/main.py
import os
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse, HTMLResponse
from pydantic import BaseModel
import requests
from functools import lru_cache
# Load environment variables from .env file
load_dotenv()
app = FastAPI()
# Create a session for reusing the HTTP connection
session = requests.Session()
# Get API keys and secret endpoint from environment variables
api_keys_str = os.getenv('API_KEYS')
valid_api_keys = api_keys_str.split(',') if api_keys_str else []
secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
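# Example .env contents (illustrative values only; the real keys and endpoint
# are whatever the deployment provides):
#   API_KEYS=key-one,key-two
#   SECRET_API_ENDPOINT=https://upstream.example.com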
# Fail fast at startup if the upstream endpoint is not configured; a plain
# RuntimeError is used because HTTPException only applies inside request handlers.
if not secret_api_endpoint:
    raise RuntimeError("SECRET_API_ENDPOINT is not configured in environment variables.")
class Payload(BaseModel):
    model: str
    messages: list
@app.get("/", response_class=HTMLResponse)
async def root():
html_content = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Loki.AI API</title>
<style>
body { font-family: Arial, sans-serif; text-align: center; margin-top: 50px; background-color: #121212; color: white; }
h1 { color: #4CAF50; }
a { color: #BB86FC; text-decoration: none; }
a:hover { text-decoration: underline; }
</style>
</head>
<body>
<h1>Welcome to Loki.AI API!</h1>
<p>Created by Parth Sadaria</p>
<p>Check out the GitHub for more projects:</p>
<a href="https://github.com/ParthSadaria" target="_blank">github.com/ParthSadaria</a>
</body>
</html>
"""
return HTMLResponse(content=html_content)
# Cache the upstream model list for the lifetime of the process (lru_cache with maxsize=1)
@lru_cache(maxsize=1)
def get_cached_models():
    try:
        response = session.get(f"{secret_api_endpoint}/api/v1/models", timeout=3)
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        raise HTTPException(status_code=500, detail=f"Request failed: {e}")
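# Note: because of lru_cache, the model list fetched above is reused for the
# lifetime of the process; if a refresh is ever needed, it can be reset with
# get_cached_models.cache_clear() (standard functools behaviour).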
@app.get("/models")
async def get_models():
    return get_cached_models()
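# Example request against the model list (assuming the server is running
# locally on port 8000):
#   curl http://localhost:8000/models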
@app.post("/v1/chat/completions")
async def get_completion(payload: Payload, request: Request):
    api_key = request.headers.get("Authorization")
    # Validate the API key (the raw key is expected in the Authorization header, with no "Bearer " prefix)
    if api_key not in valid_api_keys:
        raise HTTPException(status_code=403, detail="Forbidden: Invalid API key. Join dsc.gg/chadgang and DM @mr_leaderyt on Discord for a free API key :)")
    # Prepare the payload for streaming
    payload_dict = {**payload.dict(), "stream": True}
    # Generator that relays the upstream response chunk by chunk
    def stream_generator():
        try:
            with session.post(secret_api_endpoint, json=payload_dict, stream=True, timeout=15) as response:
                response.raise_for_status()
                for chunk in response.iter_content(chunk_size=1024):
                    if chunk:  # Only yield non-empty chunks
                        yield chunk
        except requests.exceptions.RequestException as e:
            # Note: once streaming has started, the HTTP status can no longer change,
            # so this exception is only surfaced for failures before the first chunk.
            raise HTTPException(status_code=500, detail=f"Streaming failed: {e}")
    # Return the streaming response
    return StreamingResponse(stream_generator(), media_type="application/json")
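# Example request (illustrative only: "your-api-key" stands for a key listed in
# API_KEYS, and the model name depends on what the upstream endpoint serves):
#   curl -N http://localhost:8000/v1/chat/completions \
#     -H "Authorization: your-api-key" \
#     -H "Content-Type: application/json" \
#     -d '{"model": "some-model", "messages": [{"role": "user", "content": "Hello"}]}'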
# Log the API endpoints
@app.on_event("startup")
async def startup_event():
print("API endpoints:")
print("GET /")
print("GET /models")
print("POST /v1/chat/completions")
# Run the server with Uvicorn when this file is executed directly
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
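# Alternatively (assuming this file is deployed as main.py), the app can be
# served from the command line with:
#   uvicorn main:app --host 0.0.0.0 --port 8000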