import torch
import json
import os
import faiss
import numpy as np
from pptx import Presentation
from fastapi import FastAPI, UploadFile, File
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from sentence_transformers import SentenceTransformer
from io import BytesIO
import gradio as gr
from gradio import mount_gradio_app
# ---------------------------- #
# CONFIGURATION
# ---------------------------- #
MODEL_NAME = "./models/facebook-opt-1.3b"
SUMMARIZATION_MODEL = "./models/bart-large-cnn"
EMBEDDING_MODEL = "./models/all-MiniLM-L6-v2"
DATA_DIRECTORY = "./dataset/"
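
# NOTE: the three model paths above are local directories, so the weights must
# be downloaded before start-up. One way to fetch them (the hub repo ids are
# assumptions inferred from the folder names):
#   huggingface-cli download facebook/opt-1.3b --local-dir ./models/facebook-opt-1.3b
#   huggingface-cli download facebook/bart-large-cnn --local-dir ./models/bart-large-cnn
#   huggingface-cli download sentence-transformers/all-MiniLM-L6-v2 --local-dir ./models/all-MiniLM-L6-v2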
# ---------------------------- #
# FUNCTION TO LOAD JSON FILES
# ---------------------------- #
def load_text_from_json(directory):
    text_data = set()  # Use a set to remove duplicates
    for filename in os.listdir(directory):
        if filename.endswith(".json"):
            with open(os.path.join(directory, filename), "r", encoding="utf-8") as file:
                data = json.load(file)
                for entry in data.get("data", []):
                    question = entry.get("question", "").strip()
                    answer = entry.get("answer", "").strip()
                    if question and answer:
                        text_data.add(f"Q: {question} A: {answer}")
    return list(text_data)
# ---------------------------- #
# FUNCTION TO LOAD POWERPOINT FILES
# ---------------------------- #
def extract_text_from_pptx(file_path):
    prs = Presentation(file_path)
    text_data = []
    for slide in prs.slides:
        for shape in slide.shapes:
            if hasattr(shape, "text"):
                text_data.append(shape.text.strip())
    return " ".join(text_data)

def load_text_from_pptx(directory):
    text_data = set()
    for filename in os.listdir(directory):
        if filename.endswith(".pptx"):
            pptx_text = extract_text_from_pptx(os.path.join(directory, filename))
            text_data.add(pptx_text)
    return list(text_data)
# ---------------------------- #
# LOAD ALL TEXT DATA
# ---------------------------- #
all_text = load_text_from_json(DATA_DIRECTORY) + load_text_from_pptx(DATA_DIRECTORY)
# ---------------------------- #
# CHUNK DATA PROPERLY
# ---------------------------- #
CHUNK_SIZE = 500
chunks = set()
for text in all_text:
    sentences = text.split(". ")
    temp_chunk = ""
    for sentence in sentences:
        if len(temp_chunk) + len(sentence) < CHUNK_SIZE:
            temp_chunk += sentence + ". "
        else:
            if temp_chunk:  # Guard against storing an empty chunk
                chunks.add(temp_chunk.strip())  # Store completed chunk
            temp_chunk = sentence + ". "
    if temp_chunk:
        chunks.add(temp_chunk.strip())  # Store last chunk
chunks = list(chunks)  # Convert to list after deduplication
# ---------------------------- #
# EMBEDDING MODEL & FAISS VECTOR SEARCH
# ---------------------------- #
embedder = SentenceTransformer(EMBEDDING_MODEL, local_files_only=True)
chunk_embeddings = embedder.encode(chunks, convert_to_numpy=True)
# FAISS index
index = faiss.IndexFlatL2(chunk_embeddings.shape[1])
index.add(chunk_embeddings)
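# IndexFlatL2 performs exact (brute-force) L2 search, which is fine at this
# scale; for much larger corpora an approximate index such as
# faiss.IndexIVFFlat would be the usual swap.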
# ---------------------------- #
# LOAD LLM MODEL
# ---------------------------- #
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, trust_remote_code=True, torch_dtype=torch.float32, device_map="cpu"
)
# Summarization pipeline
summarizer = pipeline("summarization", model=SUMMARIZATION_MODEL)
# ---------------------------- #
# FASTAPI SETUP
# ---------------------------- #
app = FastAPI()
def retrieve_relevant_text(question, top_k=3):
    question_embedding = embedder.encode([question], convert_to_numpy=True)
    _, idxs = index.search(question_embedding, top_k)
    retrieved_texts = [chunks[idx] for idx in idxs[0]]
    # Filter out chunks that merely restate the question itself
    filtered_chunks = [text for text in retrieved_texts if question.lower() not in text.lower()]
    unique_texts = list(set(filtered_chunks))
    context_text = " ".join(unique_texts)
    if len(context_text) > 1000:
        context_text = summarizer(context_text, max_length=150, min_length=50, do_sample=False)[0]["summary_text"]
    return context_text
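# Quick sanity check once the index is built (the question is a made-up example):
#   print(retrieve_relevant_text("What is the refund policy?"))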
@app.post("/upload/")
async def upload_file(file: UploadFile = File(...)):
    global chunks, index, chunk_embeddings
    filename = file.filename
    content = await file.read()
    new_texts = []
    try:
        # -------------------- #
        # Process .json files
        # -------------------- #
        if filename.endswith(".json"):
            data = json.loads(content)
            for entry in data.get("data", []):
                question = entry.get("question", "").strip()
                answer = entry.get("answer", "").strip()
                if question and answer:
                    new_texts.append(f"Q: {question} A: {answer}")
        # -------------------- #
        # Process .pptx files
        # -------------------- #
        elif filename.endswith(".pptx"):
            prs = Presentation(BytesIO(content))
            ppt_text = []
            for slide in prs.slides:
                for shape in slide.shapes:
                    if hasattr(shape, "text"):
                        ppt_text.append(shape.text.strip())
            new_texts.append(" ".join(ppt_text))
        else:
            return {"error": "Unsupported file type. Use .json or .pptx"}
        # -------------------- #
        # Chunk and embed
        # -------------------- #
        new_chunks = set()
        for text in new_texts:
            sentences = text.split(". ")
            temp = ""
            for s in sentences:
                if len(temp) + len(s) < CHUNK_SIZE:
                    temp += s + ". "
                else:
                    if temp:  # Guard against storing an empty chunk
                        new_chunks.add(temp.strip())
                    temp = s + ". "
            if temp:
                new_chunks.add(temp.strip())
        # Drop chunks that are already indexed (dedup)
        new_chunks = list(new_chunks - set(chunks))
        if not new_chunks:
            return {"message": "No new unique chunks to add."}
        # Encode and update FAISS
        new_embeddings = embedder.encode(new_chunks, convert_to_numpy=True)
        index.add(new_embeddings)
        chunks.extend(new_chunks)
        return {
            "status": "success",
            "new_chunks_added": len(new_chunks),
            "total_chunks": len(chunks),
        }
    except Exception as e:
        return {"error": str(e)}
@app.get("/faq/")
def faq(question: str):
    """Answer user queries using retrieved knowledge."""
    retrieved_text = retrieve_relevant_text(question)
    prompt = (
        f"{retrieved_text.strip()}\n\n"
        f"Answer the following question based only on the above context:\n"
        f"{question.strip()}\n\n"
        f"Answer:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=256,  # bound the generated answer, not prompt + answer
            repetition_penalty=1.3,
            no_repeat_ngram_size=4,
            do_sample=False,  # greedy decoding; a temperature would be ignored here
            pad_token_id=tokenizer.eos_token_id,
        )
    raw_answer = tokenizer.decode(output[0], skip_special_tokens=True)
    # ---------------------------- #
    # POST-PROCESSING CLEANUP
    # ---------------------------- #
    cleaned_answer = raw_answer
    # Remove the prompt (everything before the final 'Answer:' keyword)
    if "Answer:" in cleaned_answer:
        cleaned_answer = cleaned_answer.split("Answer:")[-1]
    # Remove a repeated question (case-insensitive)
    question_stripped = question.strip()
    cleaned_answer = cleaned_answer.strip()
    if cleaned_answer.lower().startswith(question_stripped.lower()):
        cleaned_answer = cleaned_answer[len(question_stripped):].strip()
    # Final touch: remove context/prompt tokens if they leaked
    for token in ["Context:", "Question:", "Answer:"]:
        cleaned_answer = cleaned_answer.replace(token, "").strip()
    return {"answer": cleaned_answer}
# --------- Gradio UI --------- #
def gradio_upload(file):
    if file is None:
        return "No file selected."
    try:
        import requests
        base_url = os.getenv("HF_SPACE_URL", "http://localhost:7860")
        # `file` is a NamedString -- open it by its name
        with open(file.name, "rb") as f:
            files = {"file": (os.path.basename(file.name), f)}
            response = requests.post(f"{base_url}/upload/", files=files)
        if response.status_code == 200:
            return "✅ Data successfully uploaded and indexed!"
        else:
            return f"❌ Failed: {response.text}"
    except Exception as e:
        return f"❌ Error: {str(e)}"
gr_app = gr.Interface(
    fn=gradio_upload,
    inputs=gr.File(label="Upload .json or .pptx file"),
    outputs="text",
    title="Upload Knowledge",
)
# Mount Gradio at /ui
app = mount_gradio_app(app, gr_app, path="/ui")
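# To run locally (assuming this file is named app.py):
#   uvicorn app:app --host 0.0.0.0 --port 7860
# The Gradio upload UI is then served at /ui on the same port.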