File size: 4,669 Bytes
42cd5f6
 
 
 
 
 
 
 
 
 
dca3ac6
 
42cd5f6
 
 
 
 
 
 
 
 
 
dca3ac6
 
 
42cd5f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dca3ac6
42cd5f6
 
dca3ac6
 
 
 
 
 
 
 
 
 
42cd5f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dca3ac6
42cd5f6
 
dca3ac6
 
 
 
 
 
 
 
 
42cd5f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from engine import run_from_api_engine
from ingest import run_from_api_ingest
import uvicorn
import warnings
from typing import Annotated
import json
import argparse
from dotenv import load_dotenv
import box
import yaml
import os
from rich import print


# Silence every DeprecationWarning for the whole process.
warnings.filterwarnings("ignore", category=DeprecationWarning)


# Load environment variables from .env file.
# The endpoints below look up SPARROW_KEY_* entries in os.environ.
load_dotenv()

# Import config vars.
# 'config.yml' is a relative path, so it is resolved against the current
# working directory. The parsed YAML is wrapped in a Box so settings are read
# as attributes (e.g. cfg.PROTECTED_ACCESS in the endpoints).
with open('config.yml', 'r', encoding='utf8') as ymlfile:
    cfg = box.Box(yaml.safe_load(ymlfile))

# TODO (original author's note): add asyncio to the pipeline

# The OpenAPI schema and the interactive docs are both served under the
# /api/v1/sparrow-llm/ prefix instead of FastAPI's default locations.
app = FastAPI(openapi_url="/api/v1/sparrow-llm/openapi.json", docs_url="/api/v1/sparrow-llm/docs")

# CORS: every origin, method and header is accepted, with credentials enabled.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive -- confirm this is intended outside of local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True
)


@app.get("/")
def root():
    """Return a static banner that identifies the service."""
    banner = {"message": "Sparrow LLM API"}
    return banner


@app.post("/api/v1/sparrow-llm/inference", tags=["LLM Inference"])
async def inference(
        fields: Annotated[str, Form()],
        agent: Annotated[str, Form()],
        types: Annotated[str, Form()] = None,
        keywords: Annotated[str, Form()] = None,
        index_name: Annotated[str, Form()] = None,
        options: Annotated[str, Form()] = None,
        group_by_rows: Annotated[bool, Form()] = True,
        update_targets: Annotated[bool, Form()] = True,
        debug: Annotated[bool, Form()] = False,
        sparrow_key: Annotated[str, Form()] = None,
        file: UploadFile = File(None)
        ):
    """Run an inference request through the engine and return its answer.

    All parameters arrive as multipart form fields.

    Args:
        fields: Comma-separated field names when ``types`` is given; otherwise
            it is forwarded to the engine unchanged as the query text.
        agent: Agent identifier, forwarded to ``run_from_api_engine``.
        types: Optional comma-separated list of types. When present, the query
            becomes ``'retrieve ' + fields`` and both ``fields`` and ``types``
            are split into lists for the engine.
        keywords: Optional comma-separated keywords (``None`` is forwarded
            as-is when omitted).
        index_name: Optional index name, forwarded to the engine.
        options: Optional comma-separated options (``None`` is forwarded as-is
            when omitted).
        group_by_rows: Forwarded to the engine.
        update_targets: Forwarded to the engine.
        debug: Forwarded to the engine; also prints the parsed answer here.
        sparrow_key: API key, checked only when ``cfg.PROTECTED_ACCESS`` is
            truthy.
        file: Optional uploaded document, forwarded to the engine.

    Returns:
        ``{"message": answer}`` where ``answer`` is the engine result, decoded
        from JSON when the engine returned text or bytes.

    Raises:
        HTTPException: 403 when protected access is enabled and the key does
            not match any ``SPARROW_KEY_*`` environment variable; 418 when the
            engine raises ``ValueError`` or returns text that is not valid
            JSON.
    """
    if cfg.PROTECTED_ACCESS:
        # Local import keeps this block self-contained.
        from secrets import compare_digest

        # Every environment variable named SPARROW_KEY_* holds one accepted key.
        accepted_keys = [value for name, value in os.environ.items() if name.startswith('SPARROW_KEY_')]

        # Compare as bytes in constant time so response timing does not leak
        # how much of a key matched. 'surrogateescape' round-trips any
        # undecodable bytes that os.environ may contain.
        authorized = False
        if sparrow_key is not None:
            supplied = sparrow_key.encode('utf-8', errors='surrogateescape')
            authorized = any(
                compare_digest(supplied, accepted.encode('utf-8', errors='surrogateescape'))
                for accepted in accepted_keys
            )
        if not authorized:
            raise HTTPException(status_code=403, detail="Protected access. Agent not allowed.")

    # With `types`, the request is a structured "retrieve" query over the
    # listed fields; without it, `fields` is passed through as the raw query.
    if types:
        query = 'retrieve ' + fields
        query_inputs_arr = [param.strip() for param in fields.split(',')]
        query_types_arr = [param.strip() for param in types.split(',')]
    else:
        query = fields
        query_inputs_arr = []
        query_types_arr = []

    keywords_arr = [param.strip() for param in keywords.split(',')] if keywords is not None else None
    options_arr = [param.strip() for param in options.split(',')] if options is not None else None

    try:
        answer = await run_from_api_engine(agent, query_inputs_arr, query_types_arr, keywords_arr, query, index_name,
                                           options_arr, file, group_by_rows, update_targets, debug)
    except ValueError as e:
        raise HTTPException(status_code=418, detail=str(e)) from e

    if isinstance(answer, (str, bytes, bytearray)):
        try:
            answer = json.loads(answer)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            # The error detail must be JSON-serializable, so bytes-like
            # answers are decoded leniently before being echoed back.
            if isinstance(answer, str):
                detail = answer
            else:
                detail = bytes(answer).decode('utf-8', errors='replace')
            raise HTTPException(status_code=418, detail=detail) from e

    if debug:
        print("\nJSON response:\n")
        print(answer)

    return {"message": answer}


@app.post("/api/v1/sparrow-llm/ingest", tags=["LLM Ingest"])
async def ingest(
        agent: Annotated[str, Form()],
        index_name: Annotated[str, Form()],
        sparrow_key: Annotated[str, Form()] = None,
        file: UploadFile = File()
        ):
    """Ingest an uploaded file through the ingest pipeline and return its answer.

    All parameters arrive as multipart form fields.

    Args:
        agent: Agent identifier, forwarded to ``run_from_api_ingest``.
        index_name: Index name, forwarded to ``run_from_api_ingest``.
        sparrow_key: API key, checked only when ``cfg.PROTECTED_ACCESS`` is
            truthy.
        file: Uploaded document (required), forwarded to the pipeline.

    Returns:
        ``{"message": answer}`` where ``answer`` is the pipeline result,
        decoded from JSON when the pipeline returned text or bytes.

    Raises:
        HTTPException: 403 when protected access is enabled and the key does
            not match any ``SPARROW_KEY_*`` environment variable; 418 when the
            pipeline raises ``ValueError`` or returns text that is not valid
            JSON.
    """
    if cfg.PROTECTED_ACCESS:
        # Local import keeps this block self-contained.
        from secrets import compare_digest

        # Every environment variable named SPARROW_KEY_* holds one accepted key.
        accepted_keys = [value for name, value in os.environ.items() if name.startswith('SPARROW_KEY_')]

        # Compare as bytes in constant time so response timing does not leak
        # how much of a key matched. 'surrogateescape' round-trips any
        # undecodable bytes that os.environ may contain.
        authorized = False
        if sparrow_key is not None:
            supplied = sparrow_key.encode('utf-8', errors='surrogateescape')
            authorized = any(
                compare_digest(supplied, accepted.encode('utf-8', errors='surrogateescape'))
                for accepted in accepted_keys
            )
        if not authorized:
            raise HTTPException(status_code=403, detail="Protected access. Agent not allowed.")

    try:
        # NOTE(review): the trailing False is presumably a debug flag -- confirm
        # against run_from_api_ingest's signature.
        answer = await run_from_api_ingest(agent, index_name, file, False)
    except ValueError as e:
        raise HTTPException(status_code=418, detail=str(e)) from e

    if isinstance(answer, (str, bytes, bytearray)):
        try:
            answer = json.loads(answer)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            # Report malformed pipeline output the same way the inference
            # endpoint does (418 with the raw text) instead of an unhandled
            # 500. The detail must be JSON-serializable, so bytes-like answers
            # are decoded leniently first.
            if isinstance(answer, str):
                detail = answer
            else:
                detail = bytes(answer).decode('utf-8', errors='replace')
            raise HTTPException(status_code=418, detail=detail) from e

    return {"message": answer}


if __name__ == "__main__":
    # Script entry point: take the listening port from the command line and
    # serve the app on all interfaces with auto-reload enabled.
    cli = argparse.ArgumentParser(description="Run FastAPI App")
    cli.add_argument(
        "-p",
        "--port",
        type=int,
        default=8000,
        help="Port to run the FastAPI app on",
    )
    listen_port = cli.parse_args().port

    uvicorn.run(
        "api:app",
        host="0.0.0.0",
        port=listen_port,
        reload=True,
    )

# run the app with: python api.py --port 8000
# go to http://127.0.0.1:8000/api/v1/sparrow-llm/docs to see the Swagger UI