import torch
import json
import os
import faiss
import numpy as np
from pptx import Presentation
from fastapi import FastAPI, UploadFile, File
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from sentence_transformers import SentenceTransformer
from io import BytesIO
import gradio as gr
from gradio import mount_gradio_app


# ---------------------------- #
# CONFIGURATION
# ---------------------------- #
MODEL_NAME = "./models/facebook-opt-1.3b"
SUMMARIZATION_MODEL = "./models/bart-large-cnn"
EMBEDDING_MODEL = "./models/all-MiniLM-L6-v2"
DATA_DIRECTORY = "./dataset/"
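# NOTE: the paths above are local snapshots, so the three models must be
# downloaded ahead of time (e.g. with `huggingface-cli download` or
# huggingface_hub.snapshot_download) into ./models/ before the app starts.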

# ---------------------------- #
# FUNCTION TO LOAD JSON FILES
# ---------------------------- #
def load_text_from_json(directory):
    text_data = set()  # Use set to remove duplicates

    for filename in os.listdir(directory):
        if filename.endswith(".json"):
            with open(os.path.join(directory, filename), "r", encoding="utf-8") as file:
                data = json.load(file)
                for entry in data.get("data", []):
                    question = entry.get("question", "").strip()
                    answer = entry.get("answer", "").strip()
                    if question and answer:
                        text_data.add(f"Q: {question} A: {answer}")

    return list(text_data)
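
# Expected JSON layout, inferred from the loader above:
# {
#   "data": [
#     {"question": "What is ...?", "answer": "..."}
#   ]
# }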

# ---------------------------- #
# FUNCTION TO LOAD POWERPOINT FILES
# ---------------------------- #
def extract_text_from_pptx(file_path):
    prs = Presentation(file_path)
    text_data = []

    for slide in prs.slides:
        for shape in slide.shapes:
            if hasattr(shape, "text"):
                text_data.append(shape.text.strip())

    return " ".join(text_data)

def load_text_from_pptx(directory):
    text_data = set()

    for filename in os.listdir(directory):
        if filename.endswith(".pptx"):
            pptx_text = extract_text_from_pptx(os.path.join(directory, filename))
            text_data.add(pptx_text)

    return list(text_data)

# ---------------------------- #
# LOAD ALL TEXT DATA
# ---------------------------- #
all_text = load_text_from_json(DATA_DIRECTORY) + load_text_from_pptx(DATA_DIRECTORY)

# ---------------------------- #
# CHUNK DATA PROPERLY
# ---------------------------- #
CHUNK_SIZE = 500
chunks = set()

for text in all_text:
    sentences = text.split(". ")
    temp_chunk = ""

    for sentence in sentences:
        if len(temp_chunk) + len(sentence) < CHUNK_SIZE:
            temp_chunk += sentence + ". "
        else:
            chunks.add(temp_chunk.strip())  # Store chunk
            temp_chunk = sentence + ". "

    if temp_chunk:
        chunks.add(temp_chunk.strip())  # Store last chunk

chunks = list(chunks)  # Convert to list after deduplication
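
# Splitting on ". " is a deliberately simple sentence heuristic; it will
# mis-split abbreviations and decimals. A proper sentence tokenizer
# (e.g. nltk.sent_tokenize) would be a more robust drop-in here.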

# ---------------------------- #
# EMBEDDING MODEL & FAISS VECTOR SEARCH
# ---------------------------- #
embedder = SentenceTransformer(EMBEDDING_MODEL, local_files_only=True)
chunk_embeddings = embedder.encode(chunks, convert_to_numpy=True)

# FAISS index
index = faiss.IndexFlatL2(chunk_embeddings.shape[1])
index.add(chunk_embeddings)
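
# IndexFlatL2 does exact nearest-neighbor search by Euclidean distance.
# For cosine similarity, L2-normalize the embeddings and use faiss.IndexFlatIP.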

# ---------------------------- #
# LOAD LLM MODEL
# ---------------------------- #
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, trust_remote_code=True, torch_dtype=torch.float32, device_map="cpu"
)

# Summarization pipeline
summarizer = pipeline("summarization", model=SUMMARIZATION_MODEL)
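
# Caveat: bart-large-cnn accepts roughly 1024 input tokens, so very long
# contexts may be truncated by the pipeline before summarization.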

# ---------------------------- #
# FASTAPI SETUP
# ---------------------------- #
app = FastAPI()

def retrieve_relevant_text(question, top_k=3):
    question_embedding = embedder.encode([question], convert_to_numpy=True)
    _, idxs = index.search(question_embedding, top_k)

    retrieved_texts = [chunks[idx] for idx in idxs[0]]
    
    # Filter out chunks that merely echo the question, then deduplicate
    # while preserving the FAISS ranking order (set() would scramble it)
    filtered_chunks = [text for text in retrieved_texts if question.lower() not in text.lower()]
    unique_texts = list(dict.fromkeys(filtered_chunks))

    context_text = " ".join(unique_texts)
    if len(context_text) > 1000:
        context_text = summarizer(context_text, max_length=150, min_length=50, do_sample=False)[0]["summary_text"]

    return context_text


@app.post("/upload/")
async def upload_file(file: UploadFile = File(...)):
    global chunks, index, chunk_embeddings

    filename = file.filename
    content = await file.read()
    new_texts = []

    try:
        # -------------------- #
        # Process .json files
        # -------------------- #
        if filename.endswith(".json"):
            data = json.loads(content)
            for entry in data.get("data", []):
                question = entry.get("question", "").strip()
                answer = entry.get("answer", "").strip()
                if question and answer:
                    new_texts.append(f"Q: {question} A: {answer}")

        # -------------------- #
        # Process .pptx files
        # -------------------- #
        elif filename.endswith(".pptx"):
            prs = Presentation(BytesIO(content))
            ppt_text = []
            for slide in prs.slides:
                for shape in slide.shapes:
                    if hasattr(shape, "text"):
                        ppt_text.append(shape.text.strip())
            new_texts.append(" ".join(ppt_text))

        else:
            return {"error": "Unsupported file type. Use .json or .pptx"}

        # -------------------- #
        # Chunk and embed
        # -------------------- #
        new_chunks = set()
        for text in new_texts:
            sentences = text.split(". ")
            temp = ""
            for s in sentences:
                if len(temp) + len(s) < CHUNK_SIZE:
                    temp += s + ". "
                else:
                    new_chunks.add(temp.strip())
                    temp = s + ". "
            if temp:
                new_chunks.add(temp.strip())

        # Remove existing chunks (dedup)
        new_chunks = list(set(new_chunks) - set(chunks))

        if not new_chunks:
            return {"message": "No new unique chunks to add."}

        # Encode new chunks and update the FAISS index plus the module-level state
        new_embeddings = embedder.encode(new_chunks, convert_to_numpy=True)
        index.add(new_embeddings)
        chunks.extend(new_chunks)
        chunk_embeddings = np.vstack([chunk_embeddings, new_embeddings])  # keep the global array in sync

        return {
            "status": "success",
            "new_chunks_added": len(new_chunks),
            "total_chunks": len(chunks)
        }

    except Exception as e:
        return {"error": str(e)}
    

@app.get("/faq/")
def faq(question: str):
    """Answer user queries using retrieved knowledge."""
    retrieved_text = retrieve_relevant_text(question)

    prompt = (
        f"{retrieved_text.strip()}\n\n"
        f"Answer the following question based only on the above context:\n"
        f"{question.strip()}\n\n"
        f"Answer:"
    )

    inputs = tokenizer(prompt, return_tensors="pt").to("cpu")

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=256,  # cap generated tokens; max_length would also count the prompt
            repetition_penalty=1.3,
            no_repeat_ngram_size=4,
            do_sample=False,  # greedy decoding; temperature only applies when do_sample=True
            pad_token_id=tokenizer.eos_token_id,
        )

    raw_answer = tokenizer.decode(output[0], skip_special_tokens=True)

    # ---------------------------- #
    # POST-PROCESSING CLEANUP
    # ---------------------------- #
    cleaned_answer = raw_answer

    # Remove the prompt (everything before final 'Answer:' keyword)
    if "Answer:" in cleaned_answer:
        cleaned_answer = cleaned_answer.split("Answer:")[-1]

    # Remove repeated question (case-insensitive); slice by the stripped
    # length so stray whitespace in the query does not skew the cut
    question_stripped = question.strip()
    cleaned_answer = cleaned_answer.strip()
    if cleaned_answer.lower().startswith(question_stripped.lower()):
        cleaned_answer = cleaned_answer[len(question_stripped):].strip()

    # Final touch: remove context/prompt tokens if they leaked
    for token in ["Context:", "Question:", "Answer:"]:
        cleaned_answer = cleaned_answer.replace(token, "").strip()

    return {"answer": cleaned_answer}


# --------- Gradio UI --------- #
def gradio_upload(file):
    if file is None:
        return "No file selected."

    try:
        import requests

        base_url = os.getenv("HF_SPACE_URL", "http://localhost:7860")

        # Depending on the Gradio version, `file` is either a tempfile-like
        # object exposing .name or a plain filepath string
        file_path = file.name if hasattr(file, "name") else str(file)
        with open(file_path, "rb") as f:
            files = {"file": (os.path.basename(file_path), f)}
            response = requests.post(f"{base_url}/upload/", files=files)

        if response.status_code == 200:
            return "βœ… Data successfully uploaded and indexed!"
        else:
            return f"❌ Failed: {response.text}"

    except Exception as e:
        return f"❌ Error: {str(e)}"


gr_app = gr.Interface(
    fn=gradio_upload,
    inputs=gr.File(label="Upload .json or .pptx file"),
    outputs="text",
    title="Upload Knowledge",
)

# Mount Gradio at /ui
app = mount_gradio_app(app, gr_app, path="/ui")
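
# To serve the API plus the Gradio UI locally (module name assumed to be app.py):
#   uvicorn app:app --host 0.0.0.0 --port 7860
# The upload UI is then available at http://localhost:7860/ui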