import os
import json

import requests
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import pipeline
from sentence_transformers import SentenceTransformer
import faiss
import gradio as gr
from gradio import mount_gradio_app

# ------------------- Config ------------------- #
DATA_PATH = "./dataset/pentagon_core.json"
EMBEDDING_MODEL = "./models/all-MiniLM-L6-v2"
QA_MODEL = "./models/bart-large-cnn"
DEVICE = "cuda" if os.environ.get("USE_CUDA") == "1" else "cpu"

# ------------------- Load Models ------------------- #
embedder = SentenceTransformer(EMBEDDING_MODEL)
qa_model = pipeline(
    "text2text-generation",
    model=QA_MODEL,
    device=0 if DEVICE == "cuda" else -1,
)

# ------------------- Load Dataset + Index ------------------- #
if os.path.exists(DATA_PATH):
    with open(DATA_PATH, "r") as f:
        knowledge_base = json.load(f)
else:
    knowledge_base = []

# Build the FAISS index from the embedder's dimension so startup also works
# with an empty knowledge base (encoding an empty list has no usable shape
# to read the dimension from).
index = faiss.IndexFlatL2(embedder.get_sentence_embedding_dimension())
texts = [item["content"] for item in knowledge_base]
if texts:
    # encode(..., convert_to_numpy=True) returns float32, which FAISS expects.
    index.add(embedder.encode(texts, convert_to_numpy=True))

# ------------------- FastAPI App ------------------- #
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # For development
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# --------- Upload Endpoint --------- #
class UploadData(BaseModel):
    content: str


@app.post("/upload/")
def upload_knowledge(data: UploadData):
    global knowledge_base, index
    knowledge_base.append({"content": data.content})
    # Make sure the dataset directory exists before persisting.
    os.makedirs(os.path.dirname(DATA_PATH), exist_ok=True)
    with open(DATA_PATH, "w") as f:
        json.dump(knowledge_base, f, indent=2)
    new_embedding = embedder.encode([data.content], convert_to_numpy=True)
    index.add(new_embedding)
    return {"message": "Data uploaded and indexed."}


# --------- Ask Endpoint --------- #
@app.get("/ask/")
def ask(question: str, top_k: int = 3):
    if index.ntotal == 0:
        return {"question": question, "context_used": "", "answer": "Knowledge base is empty."}
    question_embedding = embedder.encode([question], convert_to_numpy=True)
    distances, indices = index.search(question_embedding, top_k)
    # FAISS pads the result with -1 when top_k exceeds the number of indexed items.
    context = " ".join(knowledge_base[i]["content"] for i in indices[0] if i != -1)
    prompt = (
        f"Context: {context}\n\n"
        f"Answer the following question based only on the above context:\n"
        f"{question}\n\nAnswer:"
    )
    output = qa_model(prompt, max_length=256, do_sample=False)[0]["generated_text"]
    return {
        "question": question,
        "context_used": context,
        "answer": output.strip(),
    }


# --------- Gradio UI --------- #
def gradio_upload(file):
    if file is None:
        return "No file selected."
    try:
        # gr.File hands back a filepath string (or, on older Gradio versions,
        # a tempfile wrapper exposing .name); read the file's contents rather
        # than str()-ing the object, which would only yield the temp path.
        path = file if isinstance(file, str) else file.name
        with open(path, "r", encoding="utf-8") as fh:
            content = fh.read()
        # Post back to this app's own /upload/ endpoint; the sync endpoints
        # run in a thread pool, so the self-request does not deadlock.
        base_url = os.getenv("HF_SPACE_URL", "http://localhost:7860")
        response = requests.post(f"{base_url}/upload/", json={"content": content})
        if response.status_code == 200:
            return "✅ Data successfully uploaded and indexed!"
        return f"❌ Failed: {response.text}"
    except Exception as e:
        return f"❌ Error: {str(e)}"


gr_app = gr.Interface(
    fn=gradio_upload,
    inputs=gr.File(label="Upload .txt or .json file"),
    outputs="text",
    title="Upload Knowledge",
)

# Mount Gradio at /ui
app = mount_gradio_app(app, gr_app, path="/ui")
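

# --------- Local entry point (sketch) --------- #
# The original file never starts the server itself, so the block below is an
# assumed way to run it locally with uvicorn. Port 7860 is an assumption,
# chosen only to match the HF_SPACE_URL default used by gradio_upload above.
if __name__ == "__main__":
    import uvicorn

    # Serves the FastAPI routes (/upload/, /ask/) and the Gradio UI at /ui.
    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example requests against the running server (assumed usage, not from the
# original source):
#   curl -X POST http://localhost:7860/upload/ \
#        -H "Content-Type: application/json" \
#        -d '{"content": "FAISS is a library for vector similarity search."}'
#   curl "http://localhost:7860/ask/?question=What%20is%20FAISS&top_k=3"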