# Hugging Face Space: knowledge-base upload & QA service (FastAPI + Gradio)
import os
import json
from typing import List

import faiss
import gradio as gr
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from gradio import mount_gradio_app
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import pipeline
# ------------------- Config ------------------- #
DATA_PATH = "./dataset/pentagon_core.json"     # persisted knowledge-base JSON
EMBEDDING_MODEL = "./models/all-MiniLM-L6-v2"  # local sentence-transformer weights
QA_MODEL = "./models/bart-large-cnn"           # local seq2seq model weights

# Use the GPU only when explicitly requested through the USE_CUDA env var.
_use_cuda = os.environ.get("USE_CUDA") == "1"
DEVICE = "cuda" if _use_cuda else "cpu"
# ------------------- Load Models ------------------- #
# Sentence embedder for similarity search plus a seq2seq pipeline for answers.
embedder = SentenceTransformer(EMBEDDING_MODEL)
# transformers convention: CUDA device ordinal, or -1 for CPU.
_pipeline_device = 0 if DEVICE == "cuda" else -1
qa_model = pipeline("text2text-generation", model=QA_MODEL, device=_pipeline_device)
# ------------------- Load Dataset + Index ------------------- #
# Load the persisted knowledge base (a JSON list of {"content": str} items);
# start empty when no dataset file exists yet.
if os.path.exists(DATA_PATH):
    with open(DATA_PATH, "r") as f:
        knowledge_base = json.load(f)
else:
    knowledge_base = []

texts = [item["content"] for item in knowledge_base]

# The embedding dimension comes from the model itself, so the index can be
# created even when the knowledge base starts out empty (the original
# `embeddings.shape[1]` crashed on an empty encode result).
index = faiss.IndexFlatL2(embedder.get_sentence_embedding_dimension())
if texts:
    # convert_to_numpy matches the encode() calls used by the endpoints below,
    # avoiding the tensor -> cpu -> numpy round trip.
    embeddings = embedder.encode(texts, convert_to_numpy=True)
    index.add(embeddings)
# ------------------- FastAPI App ------------------- #
app = FastAPI()

# Development-only CORS policy: accept any origin, method, and header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
)
# --------- Upload Endpoint --------- #
class UploadData(BaseModel):
    """Request body for the upload endpoint: one text passage to index."""

    content: str
@app.post("/upload/")
def upload_knowledge(data: UploadData):
    """Append a passage to the knowledge base, persist it, and index it.

    Registered as POST /upload/ — the Gradio helper below posts here; without
    the route decorator the endpoint was never reachable.
    """
    global knowledge_base, index
    knowledge_base.append({"content": data.content})
    # Persist the whole knowledge base so uploads survive restarts.
    with open(DATA_PATH, "w") as f:
        json.dump(knowledge_base, f, indent=2)
    # Index incrementally via the same embedding path used for search.
    new_embedding = embedder.encode([data.content], convert_to_numpy=True)
    index.add(new_embedding)
    return {"message": "Data uploaded and indexed."}
# --------- Ask Endpoint --------- #
# NOTE(review): route path assumed from the section header — confirm callers.
@app.get("/ask/")
def ask(question: str, top_k: int = 3):
    """Answer `question` using the `top_k` most similar indexed passages.

    Returns a dict with the question, the retrieved context, and the answer.
    """
    question_embedding = embedder.encode([question], convert_to_numpy=True)
    distances, indices = index.search(question_embedding, top_k)
    # FAISS pads results with -1 when the index holds fewer than top_k
    # vectors; knowledge_base[-1] would silently wrap to the last item, so
    # keep only valid hit ids.
    hits = [i for i in indices[0] if 0 <= i < len(knowledge_base)]
    context = " ".join(knowledge_base[i]["content"] for i in hits)
    prompt = (
        f"Context: {context}\n\n"
        f"Answer the following question based only on the above context:\n"
        f"{question}\n\nAnswer:"
    )
    output = qa_model(prompt, max_length=256, do_sample=False)[0]["generated_text"]
    return {
        "question": question,
        "context_used": context,
        "answer": output.strip()
    }
# --------- Gradio UI --------- #
def gradio_upload(file):
    """Gradio handler: read the uploaded file's text and POST it to /upload/.

    Returns a human-readable status string for the UI.
    """
    if file is None:
        return "No file selected."
    try:
        # gr.File yields a temp-file *path* (NamedString is a str subclass
        # whose value is the path), so the contents must be read from disk —
        # str(file) alone would upload the path, not the text.
        path = getattr(file, "name", None) or str(file)
        with open(path, "r", encoding="utf-8") as fh:
            content = fh.read()
        import requests
        base_url = os.getenv("HF_SPACE_URL", "http://localhost:7860")
        response = requests.post(f"{base_url}/upload/", json={"content": content})
        if response.status_code == 200:
            return "β Data successfully uploaded and indexed!"
        else:
            return f"β Failed: {response.text}"
    except Exception as e:
        return f"β Error: {str(e)}"
# Build the upload UI and attach it to the FastAPI application.
gr_app = gr.Interface(
    title="Upload Knowledge",
    fn=gradio_upload,
    inputs=gr.File(label="Upload .txt or .json file"),
    outputs="text",
)

# Serve the Gradio interface from inside the FastAPI app at /ui.
app = mount_gradio_app(app, gr_app, path="/ui")