File size: 4,230 Bytes
4986fe4
d38c2eb
ef215d3
7c60ac5
378f2c3
8834a20
378f2c3
d38c2eb
4986fe4
378f2c3
 
b955cc1
4986fe4
 
 
b955cc1
4986fe4
b955cc1
 
 
 
 
 
378f2c3
 
 
 
b955cc1
378f2c3
7c60ac5
7ef5d89
7c60ac5
 
 
 
 
 
 
 
4986fe4
7c60ac5
4986fe4
7c60ac5
 
 
 
 
 
e2e24f9
7c60ac5
 
 
 
 
 
7ef5d89
a68045e
8834a20
 
e2e24f9
8834a20
 
 
 
7ef5d89
ac4bad0
3109050
 
 
ac4bad0
ef215d3
 
378f2c3
b955cc1
378f2c3
045bd95
378f2c3
b955cc1
 
 
 
 
045bd95
3109050
8834a20
 
b955cc1
8834a20
114ca84
 
3109050
 
 
 
 
 
 
 
 
378f2c3
7ef5d89
 
 
 
 
045bd95
6a84e5c
 
 
8e4491b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse, HTMLResponse
from pydantic import BaseModel
import httpx

# Load variables from a local .env file (if present) into the process environment.
load_dotenv()

app = FastAPI()

# API_KEYS is a comma-separated list of accepted keys. Surrounding whitespace is
# stripped and empty entries are dropped, so "k1, k2" works as expected and a
# stray/trailing comma cannot turn an empty Authorization header into a valid key.
api_keys_str = os.getenv('API_KEYS')
valid_api_keys = [key.strip() for key in api_keys_str.split(',') if key.strip()] if api_keys_str else []

# Upstream base URLs. A trailing "/" is removed because request paths such as
# "/v1/models" are appended to these values; an unset variable stays None.
secret_api_endpoint = (os.getenv('SECRET_API_ENDPOINT') or '').rstrip('/') or None
secret_api_endpoint_2 = (os.getenv('SECRET_API_ENDPOINT_2') or '').rstrip('/') or None

# Both upstream endpoints are required. This check runs at import time, outside
# any request, so it raises RuntimeError: FastAPI's HTTPException only becomes an
# HTTP response when raised while a request is being handled.
if not secret_api_endpoint or not secret_api_endpoint_2:
    raise RuntimeError("API endpoint(s) are not configured in environment variables.")

# Chat models routed to SECRET_API_ENDPOINT_2; any model not listed here is sent
# to SECRET_API_ENDPOINT instead.
alternate_models = {
    "claude-3-haiku",
    "gpt-4o-mini",
    "llama-3.1-70b",
    "mixtral-8x7b",
}

class Payload(BaseModel):
    """Request body accepted by ``POST /chat/completions``.

    Only ``model`` and ``messages`` are required. Any other fields the client sends
    (for example ``temperature`` or ``max_tokens``) are kept on the instance, so
    ``payload.dict()`` includes them and they are forwarded upstream. Pydantic's
    default behaviour (``extra="ignore"``) would silently discard them.
    """

    # Chat model name; also selects which upstream endpoint handles the request.
    model: str
    # Conversation messages. Only the container type is checked here; the element
    # contents are passed through unvalidated.
    messages: list
    # Whether the client wants a streamed response. Optional, because clients
    # commonly omit it; previously a missing value was rejected with a 422.
    stream: bool = False

    class Config:
        # Keep unknown fields instead of dropping them (supported by pydantic v1;
        # still accepted, with a deprecation warning, by pydantic v2).
        extra = "allow"

# Static landing page for GET /. The leading/trailing newlines and indentation
# inside the literal are part of the served HTML and are kept intentionally.
_ROOT_HTML = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>Loki.AI API</title>
        <style>
            body { font-family: Arial, sans-serif; text-align: center; margin-top: 50px; background-color: #121212; color: white; }
            h1 { color: #4CAF50; }
            a { color: #BB86FC; text-decoration: none; }
            a:hover { text-decoration: underline; }
        </style>
    </head>
    <body>
        <h1>Welcome to Loki.AI API!</h1>
        <p>Created by Parth Sadaria</p>
        <p>Go to /models for more info</p>
        <p>Check out the GitHub for more projects:</p>
        <a href="https://github.com/ParthSadaria" target="_blank">github.com/ParthSadaria</a>
    </body>
    </html>
    """


@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the static HTML landing page."""
    return HTMLResponse(content=_ROOT_HTML)

async def get_models():
    """Fetch the model list from the primary upstream endpoint.

    Returns:
        The decoded JSON body of the upstream ``GET /v1/models`` response.

    Raises:
        HTTPException: with the upstream status code if upstream answers with an
            error status, or 500 if the request itself fails (connection error,
            timeout, ...).
    """
    async with httpx.AsyncClient() as client:
        try:
            response = await client.get(f"{secret_api_endpoint}/v1/models", timeout=3)
            response.raise_for_status()
        except httpx.HTTPStatusError as status_err:
            # raise_for_status() raises HTTPStatusError, which is not a subclass of
            # RequestError. Without this clause an upstream 4xx/5xx escaped as an
            # unhandled exception instead of a clean HTTP error.
            raise HTTPException(
                status_code=status_err.response.status_code,
                detail=f"HTTP error: {status_err}",
            ) from status_err
        except httpx.RequestError as e:
            raise HTTPException(status_code=500, detail=f"Request failed: {e}") from e
        return response.json()

@app.get("/models")
async def fetch_models():
    """Expose the upstream model list at ``GET /models``.

    Errors are reported by ``get_models`` as ``HTTPException``.
    """
    models = await get_models()
    return models

async def _open_upstream_stream(url, payload_dict):
    """POST *payload_dict* as JSON to *url* and return ``(client, response)`` once headers arrive.

    The response body has not been read yet. On success the caller owns both
    objects and must close the response and then the client.

    This runs *before* a ``StreamingResponse`` is built for a reason: Starlette
    sends the status line and headers before it pulls the first chunk from the
    body iterator. An ``HTTPException`` raised inside that iterator can therefore
    no longer change the status the client sees.

    Raises:
        HTTPException: with the upstream status code for an error status from
            upstream; 500 when the request fails or anything unexpected happens.
    """
    # timeout=10 applies to each httpx phase (connect, each read, write, pool).
    client = httpx.AsyncClient(timeout=10)
    response = None
    opened = False
    try:
        upstream_request = client.build_request("POST", url, json=payload_dict)
        response = await client.send(upstream_request, stream=True)
        response.raise_for_status()
        opened = True
        return client, response
    except httpx.HTTPStatusError as status_err:
        raise HTTPException(status_code=status_err.response.status_code, detail=f"HTTP error: {status_err}") from status_err
    except httpx.RequestError as req_err:
        raise HTTPException(status_code=500, detail=f"Streaming failed: {req_err}") from req_err
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}") from e
    finally:
        # On any failure (including task cancellation) nothing is handed to the
        # caller, so release the upstream connection and the client here.
        if not opened:
            if response is not None:
                await response.aclose()
            await client.aclose()


@app.post("/chat/completions")
async def get_completion(payload: Payload, request: Request):
    """Proxy a chat-completion request to the upstream endpoint chosen by ``payload.model``.

    The ``Authorization`` header value must exactly equal one of the configured API
    keys. The upstream body is relayed to the client unchanged, as
    ``text/event-stream`` when ``payload.stream`` is true and as
    ``application/json`` otherwise.

    Raises:
        HTTPException: 403 for a missing or unknown API key. For upstream failures,
            the upstream status code or 500 (see ``_open_upstream_stream``).
    """
    api_key = request.headers.get("Authorization")

    # Validate API key (the whole header value is compared; a missing header is None).
    if api_key not in valid_api_keys:
        raise HTTPException(status_code=403, detail="Forbidden: Invalid API key. Join dsc.gg/chadgang and DM @mr_leaderyt on Discord for a free API key :)")

    # Models in `alternate_models` go to the secondary endpoint, all others to the primary.
    endpoint = secret_api_endpoint_2 if payload.model in alternate_models else secret_api_endpoint

    # Serialized request body that is forwarded upstream.
    payload_dict = payload.dict()

    # Connect and check the upstream status up front, while an error can still be
    # returned to the client as a proper HTTP error response.
    client, upstream = await _open_upstream_stream(f"{endpoint}/v1/chat/completions", payload_dict)

    async def stream_generator():
        # Relay the upstream body verbatim. Blank lines must be preserved: in a
        # Server-Sent Events stream they terminate each event, so dropping them
        # makes SSE clients merge all chunks into one event. A failure mid-stream
        # cannot change the already-sent status; it propagates and ends the stream.
        try:
            async for chunk in upstream.aiter_bytes():
                yield chunk
        finally:
            await upstream.aclose()
            await client.aclose()

    media_type = "text/event-stream" if payload.stream else "application/json"
    return StreamingResponse(stream_generator(), media_type=media_type)

@app.on_event("startup")
async def startup_event():
    """Print the routes this app serves when the server starts."""
    print("API endpoints:")
    print("GET /")
    print("GET /models")
    # Must match the path registered for get_completion (@app.post("/chat/completions"));
    # this previously advertised a non-existent /v1/chat/completions route.
    print("POST /chat/completions")

if __name__ == "__main__":
    # Development entry point: serve the app with uvicorn on all interfaces, port 8000.
    from uvicorn import run

    run(app, port=8000, host="0.0.0.0")