sparrow-ml / api.py
katanaml's picture
Added sparrow key support
dca3ac6
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from engine import run_from_api_engine
from ingest import run_from_api_ingest
import uvicorn
import warnings
from typing import Annotated
import json
import argparse
from dotenv import load_dotenv
import box
import yaml
import os
from rich import print
warnings.filterwarnings("ignore", category=DeprecationWarning)
# Load environment variables from .env file
load_dotenv()
# Import config vars
with open('config.yml', 'r', encoding='utf8') as ymlfile:
cfg = box.Box(yaml.safe_load(ymlfile))
# add asyncio to the pipeline
app = FastAPI(openapi_url="/api/v1/sparrow-llm/openapi.json", docs_url="/api/v1/sparrow-llm/docs")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
allow_credentials=True
)
@app.get("/")
def root():
return {"message": "Sparrow LLM API"}
@app.post("/api/v1/sparrow-llm/inference", tags=["LLM Inference"])
async def inference(
fields: Annotated[str, Form()],
agent: Annotated[str, Form()],
types: Annotated[str, Form()] = None,
keywords: Annotated[str, Form()] = None,
index_name: Annotated[str, Form()] = None,
options: Annotated[str, Form()] = None,
group_by_rows: Annotated[bool, Form()] = True,
update_targets: Annotated[bool, Form()] = True,
debug: Annotated[bool, Form()] = False,
sparrow_key: Annotated[str, Form()] = None,
file: UploadFile = File(None)
):
protected_access = cfg.PROTECTED_ACCESS
if protected_access:
# Retrieve all environment variables that start with 'SPARROW_KEY_'
sparrow_keys = {key: value for key, value in os.environ.items() if key.startswith('SPARROW_KEY_')}
# Check if the provided sparrow_key matches any of the environment variables
if sparrow_key not in sparrow_keys.values():
raise HTTPException(status_code=403, detail="Protected access. Agent not allowed.")
query = 'retrieve ' + fields
query_types = types
query_inputs_arr = [param.strip() for param in fields.split(',')] if query_types else []
query_types_arr = [param.strip() for param in query_types.split(',')] if query_types else []
keywords_arr = [param.strip() for param in keywords.split(',')] if keywords is not None else None
options_arr = [param.strip() for param in options.split(',')] if options is not None else None
if not query_types:
query = fields
try:
answer = await run_from_api_engine(agent, query_inputs_arr, query_types_arr, keywords_arr, query, index_name,
options_arr, file, group_by_rows, update_targets, debug)
except ValueError as e:
raise HTTPException(status_code=418, detail=str(e))
try:
if isinstance(answer, (str, bytes, bytearray)):
answer = json.loads(answer)
except json.JSONDecodeError as e:
raise HTTPException(status_code=418, detail=answer)
if debug:
print(f"\nJSON response:\n")
print(answer)
return {"message": answer}
@app.post("/api/v1/sparrow-llm/ingest", tags=["LLM Ingest"])
async def ingest(
agent: Annotated[str, Form()],
index_name: Annotated[str, Form()],
sparrow_key: Annotated[str, Form()] = None,
file: UploadFile = File()
):
protected_access = cfg.PROTECTED_ACCESS
if protected_access:
# Retrieve all environment variables that start with 'SPARROW_KEY_'
sparrow_keys = {key: value for key, value in os.environ.items() if key.startswith('SPARROW_KEY_')}
# Check if the provided sparrow_key matches any of the environment variables
if sparrow_key not in sparrow_keys.values():
raise HTTPException(status_code=403, detail="Protected access. Agent not allowed.")
try:
answer = await run_from_api_ingest(agent, index_name, file, False)
except ValueError as e:
raise HTTPException(status_code=418, detail=str(e))
if isinstance(answer, (str, bytes, bytearray)):
answer = json.loads(answer)
return {"message": answer}
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run FastAPI App")
parser.add_argument("-p", "--port", type=int, default=8000, help="Port to run the FastAPI app on")
args = parser.parse_args()
uvicorn.run("api:app", host="0.0.0.0", port=args.port, reload=True)
# run the app with: python api.py --port 8000
# go to http://127.0.0.1:8000/api/v1/sparrow-llm/docs to see the Swagger UI