File size: 3,374 Bytes
6a03bb0
 
 
 
 
 
 
 
 
 
 
 
 
 
0273b85
 
6a03bb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286c8c5
6a03bb0
 
 
 
 
 
 
 
 
 
 
 
286c8c5
6a03bb0
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import json
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List
from transformers import pipeline
from sentence_transformers import SentenceTransformer
import faiss
import gradio as gr
from gradio import mount_gradio_app

# ------------------- Config ------------------- #
DATA_PATH = "./dataset/pentagon_core.json"
EMBEDDING_MODEL = "./models/all-MiniLM-L6-v2"
QA_MODEL = "./models/bart-large-cnn"
# Opt into GPU inference via the USE_CUDA=1 environment variable.
DEVICE = "cuda" if os.environ.get("USE_CUDA") == "1" else "cpu"

# ------------------- Load Models ------------------- #
embedder = SentenceTransformer(EMBEDDING_MODEL)
# transformers pipeline expects a CUDA device index, or -1 for CPU.
qa_model = pipeline("text2text-generation", model=QA_MODEL, device=0 if DEVICE == "cuda" else -1)

# ------------------- Load Dataset + Index ------------------- #
if os.path.exists(DATA_PATH):
    with open(DATA_PATH, "r", encoding="utf-8") as f:
        knowledge_base = json.load(f)
else:
    knowledge_base = []

# Build the index from the model's known embedding width instead of
# embeddings.shape[1], so an empty knowledge base no longer crashes at startup.
index = faiss.IndexFlatL2(embedder.get_sentence_embedding_dimension())
if knowledge_base:
    texts = [item["content"] for item in knowledge_base]
    # Encode straight to numpy (same as the /upload/ path) rather than
    # tensor -> .cpu().detach().numpy().
    index.add(embedder.encode(texts, convert_to_numpy=True))

# ------------------- FastAPI App ------------------- #
app = FastAPI()
# Open CORS so browser clients (e.g. the Gradio UI at /ui) can call the API.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# development-only setting — restrict origins before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # For development
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --------- Upload Endpoint --------- #
class UploadData(BaseModel):
    """Request body for POST /upload/."""

    # Raw text to append to the knowledge base and index.
    content: str

@app.post("/upload/")
def upload_knowledge(data: UploadData):
    """Append new content to the knowledge base, persist it to disk, and
    add its embedding to the FAISS index.

    Returns a confirmation message dict.
    """
    global knowledge_base, index

    knowledge_base.append({"content": data.content})
    # On a fresh deployment ./dataset/ may not exist yet — create it before
    # writing, otherwise open(..., "w") raises FileNotFoundError.
    os.makedirs(os.path.dirname(DATA_PATH) or ".", exist_ok=True)
    with open(DATA_PATH, "w", encoding="utf-8") as f:
        json.dump(knowledge_base, f, indent=2, ensure_ascii=False)

    # Same numpy encode path as the startup index build.
    new_embedding = embedder.encode([data.content], convert_to_numpy=True)
    index.add(new_embedding)

    return {"message": "Data uploaded and indexed."}

# --------- Ask Endpoint --------- #
@app.get("/ask/")
def ask(question: str, top_k: int = 3):
    """Answer `question` using the `top_k` nearest knowledge-base passages.

    Returns a dict with the question, the retrieved context, and the
    generated answer.
    """
    # Nothing indexed yet — searching would return only -1 placeholders.
    if index.ntotal == 0:
        return {"question": question, "context_used": "", "answer": ""}

    question_embedding = embedder.encode([question], convert_to_numpy=True)
    # Never request more neighbours than the index holds: FAISS pads missing
    # results with index -1, which would silently read knowledge_base[-1].
    distances, indices = index.search(question_embedding, min(top_k, index.ntotal))
    # Filter -1 defensively in case of padding.
    context = " ".join(knowledge_base[i]["content"] for i in indices[0] if i >= 0)

    prompt = (
        f"Context: {context}\n\n"
        f"Answer the following question based only on the above context:\n"
        f"{question}\n\nAnswer:"
    )
    # do_sample=False keeps generation deterministic for a given prompt.
    output = qa_model(prompt, max_length=256, do_sample=False)[0]["generated_text"]

    return {
        "question": question,
        "context_used": context,
        "answer": output.strip()
    }

# --------- Gradio UI --------- #
def gradio_upload(file):
    """Read an uploaded file and POST its text content to the /upload/ API.

    Returns a human-readable status string for the Gradio output box.
    """
    if file is None:
        return "No file selected."

    try:
        # gr.File yields a path string or a tempfile-like object (NamedString
        # has .name = path). The original str(file) uploaded the *path*, not
        # the file's contents — read the actual text instead.
        path = file if isinstance(file, str) else getattr(file, "name", str(file))
        with open(path, "r", encoding="utf-8") as fh:
            content = fh.read()

        import requests

        base_url = os.getenv("HF_SPACE_URL", "http://localhost:7860")
        response = requests.post(f"{base_url}/upload/", json={"content": content})

        if response.status_code == 200:
            return "✅ Data successfully uploaded and indexed!"
        else:
            return f"❌ Failed: {response.text}"
    except Exception as e:
        return f"❌ Error: {str(e)}"


# Simple one-widget UI: file picker in, status string out.
gr_app = gr.Interface(
    fn=gradio_upload,
    inputs=gr.File(label="Upload .txt or .json file"),
    outputs="text",
    title="Upload Knowledge",
)

# Mount Gradio at /ui
# mount_gradio_app returns the combined ASGI app serving both the API routes
# and the Gradio UI under the /ui path.
app = mount_gradio_app(app, gr_app, path="/ui")