import torch
import json
import os
import faiss
import numpy as np
from pptx import Presentation
from fastapi import FastAPI, UploadFile, File
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from sentence_transformers import SentenceTransformer
from io import BytesIO
import gradio as gr
from gradio import mount_gradio_app

# ---------------------------- #
# CONFIGURATION
# ---------------------------- #
MODEL_NAME = "./models/facebook-opt-1.3b"
SUMMARIZATION_MODEL = "./models/bart-large-cnn"
EMBEDDING_MODEL = "./models/all-MiniLM-L6-v2"
DATA_DIRECTORY = "./dataset/"

# ---------------------------- #
# FUNCTION TO LOAD JSON FILES
# ---------------------------- #
def load_text_from_json(directory):
    text_data = set()  # Use a set to remove duplicates
    for filename in os.listdir(directory):
        if filename.endswith(".json"):
            with open(os.path.join(directory, filename), "r", encoding="utf-8") as file:
                data = json.load(file)
                for entry in data.get("data", []):
                    question = entry.get("question", "").strip()
                    answer = entry.get("answer", "").strip()
                    if question and answer:
                        text_data.add(f"Q: {question} A: {answer}")
    return list(text_data)

# ---------------------------- #
# FUNCTION TO LOAD POWERPOINT FILES
# ---------------------------- #
def extract_text_from_pptx(file_path):
    prs = Presentation(file_path)
    text_data = []
    for slide in prs.slides:
        for shape in slide.shapes:
            if hasattr(shape, "text"):
                text_data.append(shape.text.strip())
    return " ".join(text_data)

def load_text_from_pptx(directory):
    text_data = set()
    for filename in os.listdir(directory):
        if filename.endswith(".pptx"):
            pptx_text = extract_text_from_pptx(os.path.join(directory, filename))
            text_data.add(pptx_text)
    return list(text_data)

# ---------------------------- #
# LOAD ALL TEXT DATA
# ---------------------------- #
all_text = load_text_from_json(DATA_DIRECTORY) + load_text_from_pptx(DATA_DIRECTORY)

# ---------------------------- #
# CHUNK DATA PROPERLY
# ---------------------------- #
CHUNK_SIZE = 500  # approximate chunk length in characters

chunks = set()
for text in all_text:
    sentences = text.split(". ")
    temp_chunk = ""
    for sentence in sentences:
        if len(temp_chunk) + len(sentence) < CHUNK_SIZE:
            temp_chunk += sentence + ". "
        else:
            chunks.add(temp_chunk.strip())  # Store the completed chunk
            temp_chunk = sentence + ". "    # Start a new chunk with the current sentence
    if temp_chunk:
        chunks.add(temp_chunk.strip())  # Store the last chunk
chunks = list(chunks)  # Convert to a list after deduplication

# ---------------------------- #
# EMBEDDING MODEL & FAISS VECTOR SEARCH
# ---------------------------- #
embedder = SentenceTransformer(EMBEDDING_MODEL, local_files_only=True)
chunk_embeddings = embedder.encode(chunks, convert_to_numpy=True)

# FAISS index over L2 distance; row i corresponds to chunks[i]
index = faiss.IndexFlatL2(chunk_embeddings.shape[1])
index.add(chunk_embeddings)
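
# Quick sanity check of the freshly built index. An illustrative sketch only,
# left commented out so it never runs in production; the query string is made up:
#   _q = embedder.encode(["example question"], convert_to_numpy=True)
#   _dists, _ids = index.search(_q, 1)
#   print(chunks[_ids[0][0]])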
" if temp_chunk: chunks.add(temp_chunk.strip()) # Store last chunk chunks = list(chunks) # Convert to list after deduplication # ---------------------------- # # EMBEDDING MODEL & FAISS VECTOR SEARCH # ---------------------------- # embedder = SentenceTransformer(EMBEDDING_MODEL, local_files_only=True) chunk_embeddings = embedder.encode(chunks, convert_to_numpy=True) # FAISS index index = faiss.IndexFlatL2(chunk_embeddings.shape[1]) index.add(chunk_embeddings) # ---------------------------- # # LOAD LLM MODEL # ---------------------------- # tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, trust_remote_code=True, torch_dtype=torch.float32, device_map="cpu" ) # Summarization pipeline summarizer = pipeline("summarization", model=SUMMARIZATION_MODEL) # ---------------------------- # # FASTAPI SETUP # ---------------------------- # app = FastAPI() def retrieve_relevant_text(question, top_k=3): question_embedding = embedder.encode([question], convert_to_numpy=True) _, idxs = index.search(question_embedding, top_k) retrieved_texts = [chunks[idx] for idx in idxs[0]] # Filter out chunks that contain the same question filtered_chunks = [text for text in retrieved_texts if question.lower() not in text.lower()] unique_texts = list(set(filtered_chunks)) context_text = " ".join(unique_texts) if len(context_text) > 1000: context_text = summarizer(context_text, max_length=150, min_length=50, do_sample=False)[0]["summary_text"] return context_text @app.post("/upload/") async def upload_file(file: UploadFile = File(...)): global chunks, index, chunk_embeddings filename = file.filename content = await file.read() new_texts = [] try: # -------------------- # # Process .json files # -------------------- # if filename.endswith(".json"): data = json.loads(content) for entry in data.get("data", []): question = entry.get("question", "").strip() answer = entry.get("answer", "").strip() if question and answer: new_texts.append(f"Q: {question} A: {answer}") # -------------------- # # Process .pptx files # -------------------- # elif filename.endswith(".pptx"): prs = Presentation(BytesIO(content)) ppt_text = [] for slide in prs.slides: for shape in slide.shapes: if hasattr(shape, "text"): ppt_text.append(shape.text.strip()) new_texts.append(" ".join(ppt_text)) else: return {"error": "Unsupported file type. Use .json or .pptx"} # -------------------- # # Chunk and embed # -------------------- # new_chunks = set() for text in new_texts: sentences = text.split(". ") temp = "" for s in sentences: if len(temp) + len(s) < CHUNK_SIZE: temp += s + ". " else: new_chunks.add(temp.strip()) temp = s + ". 
" if temp: new_chunks.add(temp.strip()) # Remove existing chunks (dedup) new_chunks = list(set(new_chunks) - set(chunks)) if not new_chunks: return {"message": "No new unique chunks to add."} # Encode and update FAISS new_embeddings = embedder.encode(new_chunks, convert_to_numpy=True) index.add(new_embeddings) chunks.extend(new_chunks) return { "status": "success", "new_chunks_added": len(new_chunks), "total_chunks": len(chunks) } except Exception as e: return {"error": str(e)} @app.get("/faq/") def faq(question: str): """Answer user queries using retrieved knowledge.""" retrieved_text = retrieve_relevant_text(question) prompt = ( f"{retrieved_text.strip()}\n\n" f"Answer the following question based only on the above context:\n" f"{question.strip()}\n\n" f"Answer:" ) inputs = tokenizer(prompt, return_tensors="pt").to("cpu") with torch.no_grad(): output = model.generate( **inputs, max_length=512, repetition_penalty=1.3, no_repeat_ngram_size=4, temperature=0.7, do_sample=False, pad_token_id=tokenizer.eos_token_id, ) raw_answer = tokenizer.decode(output[0], skip_special_tokens=True) # ---------------------------- # # POST-PROCESSING CLEANUP # ---------------------------- # cleaned_answer = raw_answer # Remove the prompt (everything before final 'Answer:' keyword) if "Answer:" in cleaned_answer: cleaned_answer = cleaned_answer.split("Answer:")[-1] # Remove repeated question (case-insensitive) question_lower = question.strip().lower() cleaned_answer = cleaned_answer.strip() if cleaned_answer.lower().startswith(question_lower): cleaned_answer = cleaned_answer[len(question):].strip() # Final touch: remove context/prompt tokens if they leaked for token in ["Context:", "Question:", "Answer:"]: cleaned_answer = cleaned_answer.replace(token, "").strip() return {"answer": cleaned_answer} # --------- Gradio UI --------- # def gradio_upload(file): if file is None: return "No file selected." try: import requests base_url = os.getenv("HF_SPACE_URL", "http://localhost:7860") # file is a NamedString — open it by its name with open(file.name, "rb") as f: files = {"file": (os.path.basename(file.name), f)} response = requests.post(f"{base_url}/upload/", files=files) if response.status_code == 200: return "✅ Data successfully uploaded and indexed!" else: return f"❌ Failed: {response.text}" except Exception as e: return f"❌ Error: {str(e)}" gr_app = gr.Interface( fn=gradio_upload, inputs=gr.File(label="Upload .txt or .json file"), outputs="text", title="Upload Knowledge", ) # Mount Gradio at /ui app = mount_gradio_app(app, gr_app, path="/ui")