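# Retrieval-augmented QA service: sentence-transformer embeddings indexed with
# FAISS retrieve context passages, a seq2seq model generates the answer, and
# FastAPI exposes the API with a Gradio upload UI mounted at /ui.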
import os
import json
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List
from transformers import pipeline
from sentence_transformers import SentenceTransformer
import faiss
import requests
import gradio as gr
from gradio import mount_gradio_app
# ------------------- Config ------------------- #
DATA_PATH = "./dataset/pentagon_core.json"
EMBEDDING_MODEL = "./models/all-MiniLM-L6-v2"
QA_MODEL = "./models/bart-large-cnn"
DEVICE = "cuda" if os.environ.get("USE_CUDA") == "1" else "cpu"
# ------------------- Load Models ------------------- #
embedder = SentenceTransformer(EMBEDDING_MODEL)
qa_model = pipeline("text2text-generation", model=QA_MODEL, device=0 if DEVICE == "cuda" else -1)
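# transformers pipelines take an integer device index: 0 selects the first
# CUDA device, -1 forces CPU execution.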
# ------------------- Load Dataset + Index ------------------- #
if os.path.exists(DATA_PATH):
    with open(DATA_PATH, "r") as f:
        knowledge_base = json.load(f)
else:
    knowledge_base = []
texts = [item["content"] for item in knowledge_base]
# Size the FAISS index from the embedder's output dimension so startup also
# works when the knowledge base starts out empty.
index = faiss.IndexFlatL2(embedder.get_sentence_embedding_dimension())
if texts:
    index.add(embedder.encode(texts, convert_to_numpy=True))
# ------------------- FastAPI App ------------------- #
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # For development
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --------- Upload Endpoint --------- #
class UploadData(BaseModel):
    content: str
@app.post("/upload/")
def upload_knowledge(data: UploadData):
    global knowledge_base, index
    knowledge_base.append({"content": data.content})
    # Persist the updated knowledge base, then add the new entry to the index.
    os.makedirs(os.path.dirname(DATA_PATH), exist_ok=True)
    with open(DATA_PATH, "w") as f:
        json.dump(knowledge_base, f, indent=2)
    new_embedding = embedder.encode([data.content], convert_to_numpy=True)
    index.add(new_embedding)
    return {"message": "Data uploaded and indexed."}
# --------- Ask Endpoint --------- #
@app.get("/ask/")
def ask(question: str, top_k: int = 3):
    question_embedding = embedder.encode([question], convert_to_numpy=True)
    distances, indices = index.search(question_embedding, top_k)
    # FAISS pads results with -1 when the index holds fewer than top_k vectors.
    context = " ".join(knowledge_base[i]["content"] for i in indices[0] if i != -1)
    prompt = (
        f"Context: {context}\n\n"
        f"Answer the following question based only on the above context:\n"
        f"{question}\n\nAnswer:"
    )
    output = qa_model(prompt, max_length=256, do_sample=False)[0]["generated_text"]
    return {
        "question": question,
        "context_used": context,
        "answer": output.strip(),
    }
# --------- Gradio UI --------- #
def gradio_upload(file):
    if file is None:
        return "No file selected."
    try:
        # gr.File hands the handler a path to the uploaded file by default;
        # read its text content rather than stringifying the path object.
        with open(file, "r", encoding="utf-8") as fh:
            content = fh.read()
        base_url = os.getenv("HF_SPACE_URL", "http://localhost:7860")
        response = requests.post(f"{base_url}/upload/", json={"content": content})
        if response.status_code == 200:
            return "✅ Data successfully uploaded and indexed!"
        return f"❌ Failed: {response.text}"
    except Exception as e:
        return f"❌ Error: {e}"
gr_app = gr.Interface(
    fn=gradio_upload,
    inputs=gr.File(label="Upload .txt or .json file"),
    outputs="text",
    title="Upload Knowledge",
)
# Mount Gradio at /ui
app = mount_gradio_app(app, gr_app, path="/ui")
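# --------- Running locally (illustrative) --------- #
# A minimal smoke test, assuming this file is saved as app.py:
#   uvicorn app:app --host 0.0.0.0 --port 7860
#   curl -X POST "http://localhost:7860/upload/" \
#        -H "Content-Type: application/json" \
#        -d '{"content": "The Pentagon has five sides and five floors above ground."}'
#   curl "http://localhost:7860/ask/?question=How+many+sides+does+the+Pentagon+have&top_k=3"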