Spaces:
Running
Running
File size: 4,669 Bytes
42cd5f6 dca3ac6 42cd5f6 dca3ac6 42cd5f6 dca3ac6 42cd5f6 dca3ac6 42cd5f6 dca3ac6 42cd5f6 dca3ac6 42cd5f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from engine import run_from_api_engine
from ingest import run_from_api_ingest
import uvicorn
import warnings
from typing import Annotated
import json
import argparse
from dotenv import load_dotenv
import box
import yaml
import os
from rich import print
# Suppress DeprecationWarning messages for the whole process.
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Load environment variables from .env file
load_dotenv()

# Import config vars: parse config.yml once at import time and expose the
# result as `cfg`, wrapped in box.Box.
with open('config.yml', 'r', encoding='utf8') as ymlfile:
    _config_data = yaml.safe_load(ymlfile)
cfg = box.Box(_config_data)
# NOTE(review): the original comment here read "add asyncio to the pipeline";
# presumably a TODO -- confirm with the author.

# The OpenAPI schema and the interactive docs are both served under the
# /api/v1/sparrow-llm prefix.
app = FastAPI(
    openapi_url="/api/v1/sparrow-llm/openapi.json",
    docs_url="/api/v1/sparrow-llm/docs",
)

# NOTE(review): wildcard origins/methods/headers combined with
# allow_credentials=True is a very permissive CORS policy -- confirm that this
# is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
def root():
    """Return a static payload that identifies the service."""
    greeting = {"message": "Sparrow LLM API"}
    return greeting
def _split_csv(value):
    """Split a comma-separated string into a list of whitespace-stripped items."""
    return [item.strip() for item in value.split(',')]


@app.post("/api/v1/sparrow-llm/inference", tags=["LLM Inference"])
async def inference(
        fields: Annotated[str, Form()],
        agent: Annotated[str, Form()],
        types: Annotated[str, Form()] = None,
        keywords: Annotated[str, Form()] = None,
        index_name: Annotated[str, Form()] = None,
        options: Annotated[str, Form()] = None,
        group_by_rows: Annotated[bool, Form()] = True,
        update_targets: Annotated[bool, Form()] = True,
        debug: Annotated[bool, Form()] = False,
        sparrow_key: Annotated[str, Form()] = None,
        file: UploadFile = File(None)
):
    """Run an inference request through the engine and return its answer.

    Args:
        fields: Comma-separated field names. When ``types`` is given the query
            sent to the engine is ``'retrieve ' + fields``; otherwise
            ``fields`` is used verbatim as the query.
        agent: Forwarded unchanged to ``run_from_api_engine``.
        types: Optional comma-separated list, split and forwarded alongside
            the split ``fields``.
        keywords: Optional comma-separated list; forwarded as ``None`` when
            omitted.
        index_name: Forwarded unchanged to ``run_from_api_engine``.
        options: Optional comma-separated list; forwarded as ``None`` when
            omitted.
        group_by_rows: Forwarded unchanged to ``run_from_api_engine``.
        update_targets: Forwarded unchanged to ``run_from_api_engine``.
        debug: Forwarded to the engine; when true the decoded answer is also
            printed.
        sparrow_key: Access key, checked only when ``cfg.PROTECTED_ACCESS``
            is truthy.
        file: Optional upload forwarded to the engine.

    Returns:
        ``{"message": answer}``, where a string/bytes answer has been parsed
        as JSON.

    Raises:
        HTTPException: 403 when protected access is enabled and the key does
            not match; 418 when the engine raises ``ValueError`` or returns
            text that is not valid JSON.
    """
    if cfg.PROTECTED_ACCESS:
        # Any value stored in an environment variable named SPARROW_KEY_* is
        # accepted as a valid key.
        allowed_keys = {value for name, value in os.environ.items() if name.startswith('SPARROW_KEY_')}
        if sparrow_key not in allowed_keys:
            raise HTTPException(status_code=403, detail="Protected access. Agent not allowed.")

    if types:
        query = 'retrieve ' + fields
        query_inputs_arr = _split_csv(fields)
        query_types_arr = _split_csv(types)
    else:
        # Without types the fields string is treated as a free-form query.
        query = fields
        query_inputs_arr = []
        query_types_arr = []

    keywords_arr = _split_csv(keywords) if keywords is not None else None
    options_arr = _split_csv(options) if options is not None else None

    try:
        answer = await run_from_api_engine(agent, query_inputs_arr, query_types_arr, keywords_arr, query,
                                           index_name, options_arr, file, group_by_rows, update_targets, debug)
    except ValueError as e:
        raise HTTPException(status_code=418, detail=str(e)) from e

    if isinstance(answer, (str, bytes, bytearray)):
        try:
            answer = json.loads(answer)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            # The error detail must be JSON-serialisable, so raw bytes are
            # decoded leniently before being returned to the client.
            if isinstance(answer, str):
                detail = answer
            else:
                detail = bytes(answer).decode('utf-8', errors='replace')
            raise HTTPException(status_code=418, detail=detail) from e

    if debug:
        print("\nJSON response:\n")
        print(answer)

    return {"message": answer}
@app.post("/api/v1/sparrow-llm/ingest", tags=["LLM Ingest"])
async def ingest(
        agent: Annotated[str, Form()],
        index_name: Annotated[str, Form()],
        sparrow_key: Annotated[str, Form()] = None,
        file: UploadFile = File()
):
    """Ingest an uploaded file through ``run_from_api_ingest``.

    Args:
        agent: Forwarded unchanged to ``run_from_api_ingest``.
        index_name: Forwarded unchanged to ``run_from_api_ingest``.
        sparrow_key: Access key, checked only when ``cfg.PROTECTED_ACCESS``
            is truthy.
        file: The uploaded file to ingest.

    Returns:
        ``{"message": answer}``, where a string/bytes answer has been parsed
        as JSON.

    Raises:
        HTTPException: 403 when protected access is enabled and the key does
            not match; 418 when ingestion raises ``ValueError`` or returns
            text that is not valid JSON.
    """
    if cfg.PROTECTED_ACCESS:
        # Any value stored in an environment variable named SPARROW_KEY_* is
        # accepted as a valid key.
        allowed_keys = {value for name, value in os.environ.items() if name.startswith('SPARROW_KEY_')}
        if sparrow_key not in allowed_keys:
            raise HTTPException(status_code=403, detail="Protected access. Agent not allowed.")

    try:
        answer = await run_from_api_ingest(agent, index_name, file, False)
    except ValueError as e:
        raise HTTPException(status_code=418, detail=str(e)) from e

    if isinstance(answer, (str, bytes, bytearray)):
        try:
            answer = json.loads(answer)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            # Report non-JSON output as a 418 with the raw text as detail
            # rather than letting the decode error surface as a 500. Bytes
            # are decoded leniently so the detail stays JSON-serialisable.
            if isinstance(answer, str):
                detail = answer
            else:
                detail = bytes(answer).decode('utf-8', errors='replace')
            raise HTTPException(status_code=418, detail=detail) from e

    return {"message": answer}
if __name__ == "__main__":
    # Script entry point: read the port from the command line, then serve
    # the app with uvicorn.
    cli = argparse.ArgumentParser(description="Run FastAPI App")
    cli.add_argument(
        "-p",
        "--port",
        type=int,
        default=8000,
        help="Port to run the FastAPI app on",
    )
    port = cli.parse_args().port

    uvicorn.run("api:app", host="0.0.0.0", port=port, reload=True)

# run the app with: python api.py --port 8000
# go to http://127.0.0.1:8000/api/v1/sparrow-llm/docs to see the Swagger UI
|