File size: 3,786 Bytes
4986fe4
 
ef215d3
7c60ac5
378f2c3
8834a20
7c60ac5
378f2c3
4986fe4
 
 
378f2c3
 
4986fe4
 
 
 
 
 
 
 
378f2c3
 
 
 
 
7c60ac5
7ef5d89
7c60ac5
 
 
 
 
 
 
 
4986fe4
7c60ac5
4986fe4
7c60ac5
 
 
 
 
 
 
 
 
 
 
 
7ef5d89
7c60ac5
 
8834a20
 
 
 
 
 
 
 
7ef5d89
7c60ac5
 
8834a20
7c60ac5
045bd95
ef215d3
 
378f2c3
045bd95
378f2c3
045bd95
378f2c3
9d35223
 
045bd95
8834a20
 
 
 
 
 
 
 
 
 
 
045bd95
 
 
378f2c3
4116ca1
7ef5d89
 
 
 
 
045bd95
6a84e5c
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse, HTMLResponse
from pydantic import BaseModel
import httpx
from functools import lru_cache

# Load environment variables from .env file
load_dotenv()

app = FastAPI()

# Get API keys and secret endpoint from environment variables.
# API_KEYS is a comma-separated list. Entries are whitespace-stripped and empty
# entries are dropped: a trailing/double comma would otherwise add "" to the
# list, letting a request with an empty Authorization header pass the key check,
# and a key written as " k2" could never match a (whitespace-trimmed) header.
api_keys_str = os.getenv('API_KEYS')
valid_api_keys = [key.strip() for key in (api_keys_str or '').split(',') if key.strip()]
secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')

# Fail fast at import time if the upstream endpoint is not configured. This runs
# outside of any request, so a plain RuntimeError is raised; HTTPException is
# only meaningful inside a request handler.
if not secret_api_endpoint:
    raise RuntimeError("Secret API endpoint is not configured in environment variables.")

class Payload(BaseModel):
    """Request body accepted by ``POST /v1/chat/completions``.

    The validated model is converted with ``.dict()`` and sent as the JSON body
    of the upstream request.
    """

    # NOTE(review): only these two fields are declared; with pydantic's default
    # model config, other client-supplied keys (e.g. temperature, max_tokens)
    # are presumably ignored and therefore not forwarded upstream — confirm
    # whether that is intended.
    model: str  # model identifier, forwarded as-is
    messages: list  # chat messages; the element structure is not validated here

# Static landing page served at "/". Kept as a module-level constant so the
# markup is built once rather than on every request.
_ROOT_PAGE_HTML = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>Loki.AI API</title>
        <style>
            body { font-family: Arial, sans-serif; text-align: center; margin-top: 50px; background-color: #121212; color: white; }
            h1 { color: #4CAF50; }
            a { color: #BB86FC; text-decoration: none; }
            a:hover { text-decoration: underline; }
        </style>
    </head>
    <body>
        <h1>Welcome to Loki.AI API!</h1>
        <p>Created by Parth Sadaria</p>
        <p>Check out the GitHub for more projects:</p>
        <a href="https://github.com/ParthSadaria" target="_blank">github.com/ParthSadaria</a>
    </body>
    </html>
    """


@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the static HTML landing page."""
    return HTMLResponse(_ROOT_PAGE_HTML)

# Upstream /models response, memoized for the life of the process.
# _MODELS_NOT_FETCHED marks "no successful fetch yet", so that even a JSON
# ``null`` body is cached once fetched.
_MODELS_NOT_FETCHED = object()
_models_cache = _MODELS_NOT_FETCHED


async def get_cached_models():
    """Fetch the upstream model list once and reuse it on later calls.

    ``functools.lru_cache`` is deliberately not used: applied to an
    ``async def`` it memoizes the *coroutine object*, and awaiting that same
    coroutine a second time raises ``RuntimeError: cannot reuse already
    awaited coroutine``. The decoded JSON result is cached instead. Failed
    fetches are not cached, so the next call retries.

    Returns:
        The decoded JSON body of ``GET {secret_api_endpoint}/api/v1/models``.

    Raises:
        HTTPException: 500 if the upstream request cannot be sent or the
            upstream responds with an HTTP error status.
    """
    global _models_cache
    if _models_cache is not _MODELS_NOT_FETCHED:
        return _models_cache

    async with httpx.AsyncClient() as client:
        try:
            response = await client.get(f"{secret_api_endpoint}/api/v1/models", timeout=3)
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            # str(e) embeds the upstream URL; report only the status code so the
            # secret endpoint is not revealed to API clients.
            raise HTTPException(
                status_code=500,
                detail=f"Request failed: upstream returned HTTP {e.response.status_code}",
            ) from e
        except httpx.RequestError as e:
            raise HTTPException(status_code=500, detail=f"Request failed: {e}") from e
        models = response.json()

    _models_cache = models
    return models


def _clear_models_cache():
    """Forget the cached model list so the next call fetches it again."""
    global _models_cache
    _models_cache = _MODELS_NOT_FETCHED


# Preserve the ``cache_clear()`` hook the previous lru_cache wrapper exposed.
get_cached_models.cache_clear = _clear_models_cache

@app.get("/models")
async def get_models():
    """Return the model list obtained from ``get_cached_models``."""
    models = await get_cached_models()
    return models

@app.post("/v1/chat/completions")
async def get_completion(payload: Payload, request: Request):
    """Proxy a chat-completion request to the upstream endpoint as a stream.

    The caller's ``Authorization`` header must exactly match one of the
    configured API keys. The validated payload is forwarded with
    ``"stream": True`` and the upstream body is relayed in 512-byte chunks.

    The upstream request is opened *before* the ``StreamingResponse`` is
    returned. Once streaming has begun, the status line and headers are already
    sent, so an ``HTTPException`` raised inside the body generator can no
    longer become an error response. Opening early lets connection failures and
    upstream HTTP error statuses surface as a proper 500 response.

    Raises:
        HTTPException: 403 for a missing or unknown API key; 500 if the
            upstream request cannot be sent or returns an HTTP error status.
    """
    api_key = request.headers.get("Authorization")

    # Validate API key (exact string match; a missing header yields None,
    # which never matches).
    if api_key not in valid_api_keys:
        raise HTTPException(status_code=403, detail="Forbidden: Invalid API key. Join dsc.gg/chadgang and DM @mr_leaderyt on Discord for a free API key :)")

    # Always request a streamed response from the upstream.
    payload_dict = {**payload.dict(), "stream": True}

    # The client must outlive this handler (the body generator closes it), so
    # it is not managed with ``async with`` here.
    client = httpx.AsyncClient(timeout=10)
    try:
        upstream = await client.send(
            client.build_request("POST", secret_api_endpoint, json=payload_dict),
            stream=True,
        )
    except httpx.RequestError as e:
        await client.aclose()
        raise HTTPException(status_code=500, detail=f"Streaming failed: {e}") from e
    except Exception:
        await client.aclose()
        raise

    try:
        upstream.raise_for_status()
    except httpx.HTTPStatusError as e:
        status = e.response.status_code
        await upstream.aclose()
        await client.aclose()
        # str(e) embeds the upstream URL; report only the status code so the
        # secret endpoint is not revealed to API clients.
        raise HTTPException(status_code=500, detail=f"Streaming failed: upstream returned HTTP {status}") from e

    # Define an asynchronous generator to stream the response
    async def stream_generator():
        try:
            async for chunk in upstream.aiter_bytes(chunk_size=512):  # Smaller chunks for faster response
                if chunk:
                    yield chunk
        finally:
            # Runs on normal completion, on an error while reading, and when the
            # generator is closed early, so the connection is always released.
            await upstream.aclose()
            await client.aclose()

    # Return the streaming response
    return StreamingResponse(stream_generator(), media_type="application/json")

# Print the exposed routes once the application starts.
@app.on_event("startup")
async def startup_event():
    """Print a short listing of the HTTP routes this service exposes."""
    exposed_routes = (
        "GET /",
        "GET /models",
        "POST /v1/chat/completions",
    )
    print("API endpoints:")
    for route in exposed_routes:
        print(route)

# Script entry point: serve the ``app`` object above with Uvicorn.
if __name__ == "__main__":
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)