# app.py
# SmartManuals-AI: Hugging Face Space version
import os, json, fitz, nltk, chromadb, io
import torch
from tqdm import tqdm
from PIL import Image
from docx import Document
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from nltk.tokenize import sent_tokenize
import pytesseract
import gradio as gr
# ----------------------
# Configuration
# ----------------------
MANUALS_FOLDER = "./Manuals"
CHUNKS_JSONL = "chunks.jsonl"
CHROMA_PATH = "./chroma_store"
COLLECTION_NAME = "manual_chunks"
CHUNK_SIZE = 750
CHUNK_OVERLAP = 100
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
HF_TOKEN = os.getenv("HF_TOKEN")
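# NOTE: meta-llama/Llama-3.1-8B-Instruct is a gated checkpoint, so HF_TOKEN must be
# set (e.g., as a Space secret). CHUNK_SIZE and CHUNK_OVERLAP are measured in
# whitespace-separated words, not model tokens.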
# ----------------------
# Ensure punkt is downloaded
# ----------------------
nltk.download("punkt")
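# sent_tokenize needs the "punkt" data; newer NLTK releases may also expect the
# "punkt_tab" resource, which can be downloaded the same way if required.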
# ----------------------
# Utilities
# ----------------------
def extract_text_from_pdf(path):
    doc = fitz.open(path)
    text = ""
    for page in doc:
        t = page.get_text()
        if not t.strip():
            # Page has no extractable text layer; rasterize it and fall back to OCR.
            pix = page.get_pixmap(dpi=300)
            img = Image.open(io.BytesIO(pix.tobytes("png")))
            t = pytesseract.image_to_string(img)
        text += t + "\n"
    return text
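# Note: the OCR fallback above relies on the Tesseract binary being available in
# the environment (on a Hugging Face Space, typically installed via packages.txt).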
def extract_text_from_docx(path):
    doc = Document(path)
    return "\n".join(p.text for p in doc.paragraphs if p.text.strip())
def clean(text):
    return "\n".join([line.strip() for line in text.splitlines() if line.strip()])

def split_sentences(text):
    return sent_tokenize(text)
def chunk_sentences(sentences, max_tokens=CHUNK_SIZE, overlap=CHUNK_OVERLAP):
    # Greedily pack sentences into chunks of at most `max_tokens` words,
    # carrying roughly `overlap` words of trailing context into the next chunk.
    chunks, chunk, count = [], [], 0
    for s in sentences:
        words = s.split()
        if chunk and count + len(words) > max_tokens:
            chunks.append(" ".join(chunk))
            # Keep trailing sentences until the overlap word budget is spent.
            carried, carried_count = [], 0
            for prev in reversed(chunk):
                if carried_count + len(prev.split()) > overlap:
                    break
                carried.insert(0, prev)
                carried_count += len(prev.split())
            chunk, count = carried, carried_count
        chunk.append(s)
        count += len(words)
    if chunk:
        chunks.append(" ".join(chunk))
    return chunks
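# Document metadata is inferred from simple filename heuristics: "sm"/"om" mark
# service vs. owner manuals, and "se3hd" marks the covered equipment model.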
def get_metadata(filename):
    name = filename.lower()
    return {
        "source_file": filename,
        "doc_type": "service manual" if "sm" in name else "owner manual" if "om" in name else "unknown",
        "model": "se3hd" if "se3hd" in name else "unknown"
    }
# ----------------------
# Embedding
# ----------------------
def embed_all():
    embedder = SentenceTransformer("all-MiniLM-L6-v2")
    client = chromadb.PersistentClient(path=CHROMA_PATH)
    try:
        client.delete_collection(COLLECTION_NAME)
    except Exception:
        pass  # Collection did not exist yet.
    collection = client.create_collection(COLLECTION_NAME)
    chunks, metadatas, ids = [], [], []
    files = os.listdir(MANUALS_FOLDER)
    for file in tqdm(files):
        path = os.path.join(MANUALS_FOLDER, file)
        lower = file.lower()
        if lower.endswith(".pdf"):
            text = extract_text_from_pdf(path)
        elif lower.endswith(".docx"):
            text = extract_text_from_docx(path)
        else:
            continue  # Skip unsupported file types.
        meta = get_metadata(file)
        sents = split_sentences(clean(text))
        for i, chunk in enumerate(chunk_sentences(sents)):
            chunks.append(chunk)
            ids.append(f"{file}::chunk_{i}")
            metadatas.append(meta)
            # Flush to Chroma in small batches to keep memory usage bounded.
            if len(chunks) >= 16:
                emb = embedder.encode(chunks).tolist()
                collection.add(documents=chunks, ids=ids, metadatas=metadatas, embeddings=emb)
                chunks, ids, metadatas = [], [], []
    if chunks:
        emb = embedder.encode(chunks).tolist()
        collection.add(documents=chunks, ids=ids, metadatas=metadatas, embeddings=emb)
    return collection, embedder
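# embed_all returns both the collection and the sentence-transformer so the same
# encoder can be reused to embed questions at query time.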
# ----------------------
# Model setup
# ----------------------
def load_model():
    device = 0 if torch.cuda.is_available() else -1
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
    # Use half precision on GPU to keep the 8B model within typical memory limits.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, token=HF_TOKEN, torch_dtype=dtype)
    return pipeline("text-generation", model=model, tokenizer=tokenizer, device=device, max_new_tokens=512)
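# Note: the 8B model needs roughly 16 GB of GPU memory in fp16; a smaller
# instruct model can be substituted via MODEL_NAME on constrained hardware.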
# ----------------------
# RAG Pipeline
# ----------------------
def answer_query(question):
    # Embed the question with the same model used at indexing time and
    # retrieve the 5 most similar chunks from Chroma.
    query_emb = embedder.encode([question]).tolist()
    results = db.query(query_embeddings=query_emb, n_results=5)
    context = "\n\n".join(results["documents"][0])
    prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant. Use the provided context to answer questions. If you don't know, say 'I don't know.'

<context>
{context}
</context><|eot_id|><|start_header_id|>user<|end_header_id|>

{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""
    # The pipeline returns prompt + completion; keep only the assistant's reply.
    return llm(prompt)[0]["generated_text"].split("<|start_header_id|>assistant<|end_header_id|>")[-1].strip()
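# Note: tokenizer.apply_chat_template with role dictionaries is an equivalent,
# less error-prone way to build the Llama 3 prompt used above.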
# ----------------------
# UI
# ----------------------
with gr.Blocks() as demo:
    status = gr.Textbox(label="Status", value="Ready. Ask a question about the indexed manuals.", interactive=False)
    question = gr.Textbox(label="Ask a Question")
    submit = gr.Button("🔍 Ask")
    answer = gr.Textbox(label="Answer", lines=8)

    submit.click(fn=answer_query, inputs=question, outputs=answer)
# ----------------------
# Startup
# ----------------------
# Index the manuals and load the LLM once at import time, before the UI is served.
db, embedder = embed_all()
llm = load_model()
demo.launch()