# app.py
# SmartManuals-AI: Hugging Face Space version
import io
import os

import chromadb
import fitz  # PyMuPDF
import gradio as gr
import nltk
import pytesseract
import torch
from docx import Document
from nltk.tokenize import sent_tokenize
from PIL import Image
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# ----------------------
# Configuration
# ----------------------
MANUALS_FOLDER = "./Manuals"
CHUNKS_JSONL = "chunks.jsonl"
CHROMA_PATH = "./chroma_store"
COLLECTION_NAME = "manual_chunks"
CHUNK_SIZE = 750      # approximate chunk size in words (the chunker counts whitespace-split words, not model tokens)
CHUNK_OVERLAP = 100   # approximate words carried over between consecutive chunks
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
HF_TOKEN = os.getenv("HF_TOKEN")
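# Note: the Llama 3.1 checkpoints are gated on the Hugging Face Hub, so HF_TOKEN
# must be set (e.g. as a Space secret) for an account that has accepted the license.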
# ----------------------
# Ensure NLTK sentence-tokenizer data is available
# ----------------------
nltk.download("punkt")
nltk.download("punkt_tab")  # recent NLTK versions look up punkt_tab for sent_tokenize
# ----------------------
# Utilities
# ----------------------
def extract_text_from_pdf(path):
    doc = fitz.open(path)
    text = ""
    for page in doc:
        t = page.get_text()
        if not t.strip():
            # No extractable text layer: rasterize the page and fall back to OCR
            pix = page.get_pixmap(dpi=300)
            img = Image.open(io.BytesIO(pix.tobytes("png")))
            t = pytesseract.image_to_string(img)
        text += t + "\n"
    doc.close()
    return text
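# The OCR fallback above assumes the tesseract binary is installed; on a Hugging
# Face Space that typically means listing tesseract-ocr in packages.txt in
# addition to the pytesseract entry in requirements.txt.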
def extract_text_from_docx(path):
    doc = Document(path)
    return "\n".join(p.text for p in doc.paragraphs if p.text.strip())

def clean(text):
    return "\n".join(line.strip() for line in text.splitlines() if line.strip())

def split_sentences(text):
    return sent_tokenize(text)
def chunk_sentences(sentences, max_tokens=CHUNK_SIZE, overlap=CHUNK_OVERLAP):
    chunks, chunk, count = [], [], 0
    for s in sentences:
        n_words = len(s.split())
        if chunk and count + n_words > max_tokens:
            chunks.append(" ".join(chunk))
            # Seed the next chunk with trailing sentences totalling at most
            # ~overlap words, so context that straddles a boundary is kept.
            carried, carried_words = [], 0
            for prev in reversed(chunk):
                carried_words += len(prev.split())
                if carried_words > overlap:
                    break
                carried.insert(0, prev)
            chunk = carried
            count = sum(len(x.split()) for x in chunk)
        chunk.append(s)
        count += n_words
    if chunk:
        chunks.append(" ".join(chunk))
    return chunks
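# Illustrative behavior (hypothetical values): with max_tokens=10 and overlap=4,
# chunk_sentences(["one two three four five", "six seven eight", "nine ten eleven"])
# returns ["one two three four five six seven eight",
#          "six seven eight nine ten eleven"]
# -- the first chunk closes before exceeding 10 words, and "six seven eight"
# (3 words, within the 4-word overlap budget) seeds the second chunk.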
def get_metadata(filename):
    # Crude filename heuristic: "sm"/"om" substrings flag service/owner manuals
    # (note this also matches words like "smart"), and "se3hd" flags the model.
    name = filename.lower()
    return {
        "source_file": filename,
        "doc_type": "service manual" if "sm" in name else "owner manual" if "om" in name else "unknown",
        "model": "se3hd" if "se3hd" in name else "unknown",
    }
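# Example (hypothetical filename, shown only to illustrate the heuristic):
#   get_metadata("SE3HD_SM_v2.pdf")
#   -> {"source_file": "SE3HD_SM_v2.pdf", "doc_type": "service manual", "model": "se3hd"}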
# ----------------------
# Embedding
# ----------------------
def embed_all():
    embedder = SentenceTransformer("all-MiniLM-L6-v2")
    client = chromadb.PersistentClient(path=CHROMA_PATH)
    try:
        client.delete_collection(COLLECTION_NAME)
    except Exception:
        pass  # collection did not exist yet
    collection = client.create_collection(COLLECTION_NAME)
    chunks, metadatas, ids = [], [], []
    # Only index file types we can actually parse
    files = [f for f in os.listdir(MANUALS_FOLDER) if f.lower().endswith((".pdf", ".docx"))]
    for file in tqdm(files):
        path = os.path.join(MANUALS_FOLDER, file)
        text = extract_text_from_pdf(path) if file.lower().endswith(".pdf") else extract_text_from_docx(path)
        meta = get_metadata(file)
        sents = split_sentences(clean(text))
        for i, chunk in enumerate(chunk_sentences(sents)):
            chunks.append(chunk)
            ids.append(f"{file}::chunk_{i}")
            metadatas.append(meta)
            # Flush in small batches to bound memory during encoding
            if len(chunks) >= 16:
                emb = embedder.encode(chunks).tolist()
                collection.add(documents=chunks, ids=ids, metadatas=metadatas, embeddings=emb)
                chunks, ids, metadatas = [], [], []
    if chunks:
        emb = embedder.encode(chunks).tolist()
        collection.add(documents=chunks, ids=ids, metadatas=metadatas, embeddings=emb)
    return collection, embedder
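# Note: PersistentClient writes the index under CHROMA_PATH, so a restart could
# in principle reuse it instead of re-embedding; this version keeps things simple
# and rebuilds the collection on every startup (delete_collection followed by
# create_collection) so the index stays in sync with the Manuals folder.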
# ----------------------
# Model setup
# ----------------------
def load_model():
    device = 0 if torch.cuda.is_available() else -1
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, token=HF_TOKEN)
    return pipeline("text-generation", model=model, tokenizer=tokenizer, device=device, max_new_tokens=512)
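# Note: an 8B-parameter model needs roughly 32 GB of memory in fp32 (about 16 GB
# in fp16), which will not fit on a basic CPU Space; passing
# torch_dtype=torch.float16 to from_pretrained, or swapping in a smaller instruct
# model, is a likely adjustment depending on the Space hardware.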
# ----------------------
# RAG Pipeline
# ----------------------
def answer_query(question):
    # Embed the question with the same model used at indexing time, so query
    # and document vectors live in the same space.
    q_emb = embedder.encode([question]).tolist()
    results = db.query(query_embeddings=q_emb, n_results=5)
    context = "\n\n".join(results["documents"][0])
    prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant. Use the provided context to answer questions. If you don't know, say 'I don't know.'

<context>
{context}
</context><|eot_id|><|start_header_id|>user<|end_header_id|>

{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""
    # The pipeline echoes the prompt before the completion; keep only the text
    # after the final assistant header.
    return llm(prompt)[0]["generated_text"].split("<|start_header_id|>assistant<|end_header_id|>")[-1].strip()
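# e.g. answer_query("What is the recommended belt tension?") returns the model's
# answer grounded in the five most similar manual chunks (hypothetical question,
# shown only to illustrate the call).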
# ----------------------
# UI
# ----------------------
with gr.Blocks() as demo:
    status = gr.Textbox(label="Status", value="Ready! Ask a question about the manuals.", interactive=False)
    question = gr.Textbox(label="Ask a Question")
    submit = gr.Button("🔍 Ask")
    answer = gr.Textbox(label="Answer", lines=8)
    submit.click(fn=answer_query, inputs=question, outputs=answer)
# ----------------------
# Startup
# ----------------------
# Embed the manuals and load the model at import time, before the UI is served,
# so answer_query can rely on the db, embedder, and llm globals.
db, embedder = embed_all()
llm = load_model()
demo.launch()