# Test-api / app.py — author: syedMohib44, commit 286c8c5, 3.37 kB
# NOTE(review): the original paste included Hugging Face Space file-viewer
# chrome ("raw / history / blame") above the code; it is preserved here as a
# comment so the module parses as valid Python.
import os
import json
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List
from transformers import pipeline
from sentence_transformers import SentenceTransformer
import faiss
import gradio as gr
from gradio import mount_gradio_app
# ------------------- Config ------------------- #
DATA_PATH = "./dataset/pentagon_core.json"  # persisted knowledge base: JSON list of {"content": str}
EMBEDDING_MODEL = "./models/all-MiniLM-L6-v2"  # local SentenceTransformer checkpoint path
QA_MODEL = "./models/bart-large-cnn"  # local seq2seq model used for answer generation
DEVICE = "cuda" if os.environ.get("USE_CUDA") == "1" else "cpu"  # GPU is opt-in via USE_CUDA=1
# ------------------- Load Models ------------------- #
# Both models load once at import time from local paths (no hub download).
embedder = SentenceTransformer(EMBEDDING_MODEL)
# transformers pipeline convention: device=0 is the first CUDA device, -1 is CPU.
qa_model = pipeline("text2text-generation", model=QA_MODEL, device=0 if DEVICE == "cuda" else -1)
# ------------------- Load Dataset + Index ------------------- #
# Load the persisted knowledge base (a JSON list of {"content": str} records);
# start from an empty list when the file does not exist yet.
if os.path.exists(DATA_PATH):
    with open(DATA_PATH, "r", encoding="utf-8") as f:
        knowledge_base = json.load(f)
else:
    knowledge_base = []

# Build a FAISS L2 index over the document embeddings. The index dimension is
# taken from the embedder itself so startup also works with an EMPTY knowledge
# base — the original encoded the (possibly empty) text list and read
# `embeddings.shape[1]`, which fails when there are no documents.
texts = [item["content"] for item in knowledge_base]
index = faiss.IndexFlatL2(embedder.get_sentence_embedding_dimension())
if texts:
    # convert_to_numpy gives float32 arrays directly, which is what FAISS expects.
    embeddings = embedder.encode(texts, convert_to_numpy=True)
    index.add(embeddings)
# ------------------- FastAPI App ------------------- #
app = FastAPI()
# Wide-open CORS for development.
# NOTE(review): browsers reject the "*" wildcard origin when
# allow_credentials=True — pin explicit origins before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], # For development
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --------- Upload Endpoint --------- #
class UploadData(BaseModel):
    """Request body for POST /upload/: one free-text knowledge snippet."""
    content: str
@app.post("/upload/")
def upload_knowledge(data: UploadData):
global knowledge_base, index
knowledge_base.append({"content": data.content})
with open(DATA_PATH, "w") as f:
json.dump(knowledge_base, f, indent=2)
new_embedding = embedder.encode([data.content], convert_to_numpy=True)
index.add(new_embedding)
return {"message": "Data uploaded and indexed."}
# --------- Ask Endpoint --------- #
@app.get("/ask/")
def ask(question: str, top_k: int = 3):
    """Answer `question` from the top_k most similar knowledge-base entries.

    Embeds the question, retrieves nearest neighbours from the FAISS index,
    and feeds the concatenated passages plus the question to the QA model.
    Returns the question, the retrieved context, and the generated answer.
    """
    question_embedding = embedder.encode([question], convert_to_numpy=True)
    distances, indices = index.search(question_embedding, top_k)
    # FAISS pads the result with -1 when the index holds fewer than top_k
    # vectors; the original let those wrap around to knowledge_base[-1]
    # (or crash on an empty base). Keep only valid hits.
    hits = [i for i in indices[0] if 0 <= i < len(knowledge_base)]
    context = " ".join(knowledge_base[i]["content"] for i in hits)
    prompt = (
        f"Context: {context}\n\n"
        f"Answer the following question based only on the above context:\n"
        f"{question}\n\nAnswer:"
    )
    output = qa_model(prompt, max_length=256, do_sample=False)[0]["generated_text"]
    return {
        "question": question,
        "context_used": context,
        "answer": output.strip()
    }
# --------- Gradio UI --------- #
def gradio_upload(file):
    """Gradio handler: read the uploaded file and POST its text to /upload/.

    `file` is the value produced by gr.File — a temp-file path (str/NamedString)
    or a file-like object exposing `.name`. Returns a human-readable status.
    """
    if file is None:
        return "No file selected."
    try:
        # BUG FIX: the original posted str(file) — i.e. the temp-file PATH
        # itself — so the path string, not the document text, got indexed.
        # Resolve the path, then read the actual contents.
        path = getattr(file, "name", None) or str(file)
        with open(path, "r", encoding="utf-8", errors="replace") as fh:
            content = fh.read()
        import requests
        base_url = os.getenv("HF_SPACE_URL", "http://localhost:7860")
        response = requests.post(f"{base_url}/upload/", json={"content": content})
        if response.status_code == 200:
            return "✅ Data successfully uploaded and indexed!"
        else:
            return f"❌ Failed: {response.text}"
    except Exception as e:
        return f"❌ Error: {str(e)}"
# Gradio front-end: a single file-upload form wired to gradio_upload above.
gr_app = gr.Interface(
    fn=gradio_upload,
    inputs=gr.File(label="Upload .txt or .json file"),
    outputs="text",
    title="Upload Knowledge",
)
# Mount Gradio at /ui
# The UI is served from the same FastAPI app, so /upload/, /ask/ and /ui
# share one process (and one port).
app = mount_gradio_app(app, gr_app, path="/ui")