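"""FastAPI + Gradio app for a retrieval-augmented FAQ bot.

Loads Q&A pairs from JSON files and slide text from PowerPoint files,
chunks and embeds the text with SentenceTransformers, indexes the
embeddings with FAISS, and answers questions with a local causal LM,
summarizing long retrieved contexts with a BART summarization pipeline.
"""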
import torch
import json
import os
import faiss
import numpy as np
import gradio as gr
from pptx import Presentation
from fastapi import FastAPI, UploadFile, File
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from sentence_transformers import SentenceTransformer
from io import BytesIO

# ---------------------------- #
# CONFIGURATION
# ---------------------------- #
MODEL_NAME = "./models/facebook-opt-1.3b"
SUMMARIZATION_MODEL = "./models/bart-large-cnn"
EMBEDDING_MODEL = "./models/all-MiniLM-L6-v2"
DATA_DIRECTORY = "./dataset/"


# ---------------------------- #
# FUNCTION TO LOAD JSON FILES
# ---------------------------- #
def load_text_from_json(directory):
    text_data = set()  # Use a set to remove duplicates
    for filename in os.listdir(directory):
        if filename.endswith(".json"):
            with open(os.path.join(directory, filename), "r", encoding="utf-8") as file:
                data = json.load(file)
                for entry in data.get("data", []):
                    question = entry.get("question", "").strip()
                    answer = entry.get("answer", "").strip()
                    if question and answer:
                        text_data.add(f"Q: {question} A: {answer}")
    return list(text_data)

# ---------------------------- #
# FUNCTION TO LOAD POWERPOINT FILES
# ---------------------------- #
def extract_text_from_pptx(file_path):
    prs = Presentation(file_path)
    text_data = []
    for slide in prs.slides:
        for shape in slide.shapes:
            if hasattr(shape, "text"):
                text_data.append(shape.text.strip())
    return " ".join(text_data)


def load_text_from_pptx(directory):
    text_data = set()
    for filename in os.listdir(directory):
        if filename.endswith(".pptx"):
            pptx_text = extract_text_from_pptx(os.path.join(directory, filename))
            text_data.add(pptx_text)
    return list(text_data)

# ---------------------------- #
# LOAD ALL TEXT DATA
# ---------------------------- #
all_text = load_text_from_json(DATA_DIRECTORY) + load_text_from_pptx(DATA_DIRECTORY)

# ---------------------------- #
# CHUNK DATA PROPERLY
# ---------------------------- #
CHUNK_SIZE = 500

chunks = set()
for text in all_text:
    sentences = text.split(". ")
    temp_chunk = ""
    for sentence in sentences:
        if len(temp_chunk) + len(sentence) < CHUNK_SIZE:
            temp_chunk += sentence + ". "
        else:
            chunks.add(temp_chunk.strip())  # Store the completed chunk
            temp_chunk = sentence + ". "
    if temp_chunk:
        chunks.add(temp_chunk.strip())  # Store the last chunk
chunks = list(chunks)  # Convert to a list after deduplication
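# Chunks are built sentence by sentence up to roughly CHUNK_SIZE characters;
# a single sentence longer than CHUNK_SIZE still ends up as its own chunk.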

# ---------------------------- #
# EMBEDDING MODEL & FAISS VECTOR SEARCH
# ---------------------------- #
embedder = SentenceTransformer(EMBEDDING_MODEL, local_files_only=True)
chunk_embeddings = embedder.encode(chunks, convert_to_numpy=True)

# FAISS index
index = faiss.IndexFlatL2(chunk_embeddings.shape[1])
index.add(chunk_embeddings)
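# IndexFlatL2 performs exact (brute-force) L2 nearest-neighbor search, which is
# adequate at this corpus size; no approximate index is needed here.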

# ---------------------------- #
# LOAD LLM MODEL
# ---------------------------- #
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, trust_remote_code=True, torch_dtype=torch.float32, device_map="cpu"
)

# Summarization pipeline
summarizer = pipeline("summarization", model=SUMMARIZATION_MODEL)

# ---------------------------- #
# FASTAPI SETUP
# ---------------------------- #
app = FastAPI()

def retrieve_relevant_text(question, top_k=3):
    question_embedding = embedder.encode([question], convert_to_numpy=True)
    _, idxs = index.search(question_embedding, top_k)
    retrieved_texts = [chunks[idx] for idx in idxs[0]]
    # Filter out chunks that contain the question itself
    filtered_chunks = [text for text in retrieved_texts if question.lower() not in text.lower()]
    unique_texts = list(set(filtered_chunks))
    context_text = " ".join(unique_texts)
    # Summarize overly long contexts before they reach the LLM prompt
    if len(context_text) > 1000:
        context_text = summarizer(context_text, max_length=150, min_length=50, do_sample=False)[0]["summary_text"]
    return context_text


@app.post("/upload/")
async def upload_file(file: UploadFile = File(...)):
    global chunks, index, chunk_embeddings
    filename = file.filename
    content = await file.read()
    new_texts = []
    try:
        # -------------------- #
        # Process .json files
        # -------------------- #
        if filename.endswith(".json"):
            data = json.loads(content)
            for entry in data.get("data", []):
                question = entry.get("question", "").strip()
                answer = entry.get("answer", "").strip()
                if question and answer:
                    new_texts.append(f"Q: {question} A: {answer}")
        # -------------------- #
        # Process .pptx files
        # -------------------- #
        elif filename.endswith(".pptx"):
            prs = Presentation(BytesIO(content))
            ppt_text = []
            for slide in prs.slides:
                for shape in slide.shapes:
                    if hasattr(shape, "text"):
                        ppt_text.append(shape.text.strip())
            new_texts.append(" ".join(ppt_text))
        else:
            return {"error": "Unsupported file type. Use .json or .pptx"}
        # -------------------- #
        # Chunk and embed
        # -------------------- #
        new_chunks = set()
        for text in new_texts:
            sentences = text.split(". ")
            temp = ""
            for s in sentences:
                if len(temp) + len(s) < CHUNK_SIZE:
                    temp += s + ". "
                else:
                    new_chunks.add(temp.strip())
                    temp = s + ". "
            if temp:
                new_chunks.add(temp.strip())
        # Remove chunks that already exist in the index (dedup)
        new_chunks = list(new_chunks - set(chunks))
        if not new_chunks:
            return {"message": "No new unique chunks to add."}
        # Encode and update FAISS
        new_embeddings = embedder.encode(new_chunks, convert_to_numpy=True)
        index.add(new_embeddings)
        chunks.extend(new_chunks)
        return {
            "status": "success",
            "new_chunks_added": len(new_chunks),
            "total_chunks": len(chunks),
        }
    except Exception as e:
        return {"error": str(e)}


@app.get("/faq")  # NOTE: the original code does not name this route; "/faq" is an assumption
def faq(question: str):
    """Answer user queries using retrieved knowledge."""
    retrieved_text = retrieve_relevant_text(question)
    prompt = (
        f"{retrieved_text.strip()}\n\n"
        f"Answer the following question based only on the above context:\n"
        f"{question.strip()}\n\n"
        f"Answer:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=200,  # bound the generated answer rather than prompt + answer
            repetition_penalty=1.3,
            no_repeat_ngram_size=4,
            do_sample=False,  # greedy decoding; temperature has no effect when sampling is off
            pad_token_id=tokenizer.eos_token_id,
        )
    raw_answer = tokenizer.decode(output[0], skip_special_tokens=True)

    # ---------------------------- #
    # POST-PROCESSING CLEANUP
    # ---------------------------- #
    cleaned_answer = raw_answer
    # Remove the prompt (everything before the final 'Answer:' keyword)
    if "Answer:" in cleaned_answer:
        cleaned_answer = cleaned_answer.split("Answer:")[-1]
    # Remove a repeated question (case-insensitive)
    question_lower = question.strip().lower()
    cleaned_answer = cleaned_answer.strip()
    if cleaned_answer.lower().startswith(question_lower):
        cleaned_answer = cleaned_answer[len(question_lower):].strip()
    # Final touch: remove context/prompt tokens if they leaked
    for token in ["Context:", "Question:", "Answer:"]:
        cleaned_answer = cleaned_answer.replace(token, "").strip()
    return {"answer": cleaned_answer}


# --------- Gradio UI --------- #
def gradio_upload(file):
    if file is None:
        return "No file selected."
    try:
        # gr.File may return a path string or a tempfile-like object with .name
        file_path = file if isinstance(file, str) else file.name
        import requests
        base_url = os.getenv("HF_SPACE_URL", "http://localhost:7860")
        # Forward the file to the /upload/ endpoint as a multipart upload,
        # matching the UploadFile parameter expected by upload_file()
        with open(file_path, "rb") as f:
            response = requests.post(
                f"{base_url}/upload/",
                files={"file": (os.path.basename(file_path), f)},
            )
        if response.status_code == 200:
            return "✅ Data successfully uploaded and indexed!"
        else:
            return f"❌ Failed: {response.text}"
    except Exception as e:
        return f"❌ Error: {str(e)}"

gr_app = gr.Interface(
    fn=gradio_upload,
    inputs=gr.File(label="Upload a .json or .pptx file"),
    outputs="text",
    title="Upload Knowledge",
)

# Mount the Gradio UI at /ui on the FastAPI app
app = gr.mount_gradio_app(app, gr_app, path="/ui")
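
# To run locally (assuming this file is saved as app.py and uvicorn is installed):
#   uvicorn app:app --host 0.0.0.0 --port 7860
# Then open http://localhost:7860/ui for the upload UI, or call the API directly, e.g.:
#   curl -X POST -F "file=@slides.pptx" http://localhost:7860/upload/
#   curl "http://localhost:7860/faq?question=What%20is%20covered%20in%20the%20slides"
# (The /faq path above is the route assumed earlier in this file.)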