damoojeje committed
Commit df365ca · verified · 1 Parent(s): 0113010

Update app.py

Files changed (1)
  1. app.py +146 -125
app.py CHANGED
@@ -1,156 +1,177 @@
- # app.py
- # SmartManuals-AI: Hugging Face Space version
-
- import os, json, fitz, nltk, chromadb, io
- import torch
  from tqdm import tqdm
- from PIL import Image
  from docx import Document
- from sentence_transformers import SentenceTransformer, util
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
- from nltk.tokenize import sent_tokenize
  import pytesseract
  import gradio as gr

- # ----------------------
- # Configuration
- # ----------------------
  MANUALS_FOLDER = "./Manuals"
- CHUNKS_JSONL = "chunks.jsonl"
  CHROMA_PATH = "./chroma_store"
  COLLECTION_NAME = "manual_chunks"
  CHUNK_SIZE = 750
  CHUNK_OVERLAP = 100
- MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
- HF_TOKEN = os.getenv("HF_TOKEN")
-
- # ----------------------
- # Ensure punkt is downloaded
- # ----------------------
- nltk.download("punkt")
-
- # ----------------------
- # Utilities
- # ----------------------
- def extract_text_from_pdf(path):
-     doc = fitz.open(path)
-     text = ""
-     for page in doc:
-         t = page.get_text()
-         if not t.strip():
-             pix = page.get_pixmap(dpi=300)
-             img = Image.open(io.BytesIO(pix.tobytes("png")))
-             t = pytesseract.image_to_string(img)
-         text += t + "\n"
-     return text
-
- def extract_text_from_docx(path):
-     doc = Document(path)
-     return "\n".join(p.text for p in doc.paragraphs if p.text.strip())

  def clean(text):
-     return "\n".join([line.strip() for line in text.splitlines() if line.strip()])

  def split_sentences(text):
-     return sent_tokenize(text)

- def chunk_sentences(sentences, max_tokens=CHUNK_SIZE, overlap=CHUNK_OVERLAP):
-     chunks, chunk, count = [], [], 0
-     for s in sentences:
-         words = s.split()
-         if count + len(words) > max_tokens:
              chunks.append(" ".join(chunk))
-             chunk = chunk[-overlap:] if overlap > 0 else []
-             count = sum(len(x.split()) for x in chunk)
-         chunk.append(s)
-         count += len(words)
      if chunk:
          chunks.append(" ".join(chunk))
      return chunks

- def get_metadata(filename):
      name = filename.lower()
-     return {
-         "source_file": filename,
-         "doc_type": "service manual" if "sm" in name else "owner manual" if "om" in name else "unknown",
-         "model": "se3hd" if "se3hd" in name else "unknown"
-     }
-
- # ----------------------
- # Embedding
- # ----------------------
  def embed_all():
-     embedder = SentenceTransformer("all-MiniLM-L6-v2")
      client = chromadb.PersistentClient(path=CHROMA_PATH)
-     try:
          client.delete_collection(COLLECTION_NAME)
-     except:
-         pass
      collection = client.create_collection(COLLECTION_NAME)
-     chunks, metadatas, ids = [], [], []
-     files = os.listdir(MANUALS_FOLDER)
-     idx = 0
-     for file in tqdm(files):
-         path = os.path.join(MANUALS_FOLDER, file)
-         text = extract_text_from_pdf(path) if file.endswith(".pdf") else extract_text_from_docx(path)
-         meta = get_metadata(file)
-         sents = split_sentences(clean(text))
-         for i, chunk in enumerate(chunk_sentences(sents)):
-             chunks.append(chunk)
-             ids.append(f"{file}::chunk_{i}")
-             metadatas.append(meta)
-             if len(chunks) >= 16:
-                 emb = embedder.encode(chunks).tolist()
-                 collection.add(documents=chunks, ids=ids, metadatas=metadatas, embeddings=emb)
-                 chunks, ids, metadatas = [], [], []
-     if chunks:
-         emb = embedder.encode(chunks).tolist()
-         collection.add(documents=chunks, ids=ids, metadatas=metadatas, embeddings=emb)
-     return collection, embedder

- # ----------------------
- # Model setup
- # ----------------------
- def load_model():
-     device = 0 if torch.cuda.is_available() else -1
-     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
-     model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, token=HF_TOKEN)
-     return pipeline("text-generation", model=model, tokenizer=tokenizer, device=device, max_new_tokens=512)
-
- # ----------------------
- # RAG Pipeline
- # ----------------------
- def answer_query(question):
-     results = db.query(query_texts=[question], n_results=5)
-     context = "\n\n".join(results["documents"][0])
-     prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
- You are a helpful assistant. Use the provided context to answer questions. If you don't know, say 'I don't know.'
- <context>
- {context}
- </context>
- <|start_header_id|>user<|end_header_id|>
- {question}<|start_header_id|>assistant<|end_header_id|>"""
-     return llm(prompt)[0]["generated_text"].split("<|start_header_id|>assistant<|end_header_id|>")[-1].strip()
-
- # ----------------------
- # UI
- # ----------------------
- with gr.Blocks() as demo:
-     status = gr.Textbox(label="Status", value="Embedding manuals... Please wait.", interactive=False)
-     question = gr.Textbox(label="Ask a Question")
-     submit = gr.Button("🔍 Ask")
-     answer = gr.Textbox(label="Answer", lines=8)
-
-     def handle_query(q):
-         return answer_query(q)
-
-     submit.click(fn=handle_query, inputs=question, outputs=answer)
-
- # ----------------------
- # Startup
- # ----------------------
- status_text = "Embedding manuals and loading model..."
  db, embedder = embed_all()
- llm = load_model()
- status_text = "Ready!"
  demo.launch()
 
+ import os
+ import json
+ import fitz  # PyMuPDF
+ import re
  from tqdm import tqdm
  from docx import Document
+ from PIL import Image
  import pytesseract
+ import io
+ import torch
+ import chromadb
+ from sentence_transformers import SentenceTransformer, util
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
  import gradio as gr

+ # ---------------------------
+ # 📁 Configuration
+ # ---------------------------
  MANUALS_FOLDER = "./Manuals"
  CHROMA_PATH = "./chroma_store"
  COLLECTION_NAME = "manual_chunks"
  CHUNK_SIZE = 750
  CHUNK_OVERLAP = 100
+ MAX_CONTEXT_CHUNKS = 3
+ HF_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
+ HF_TOKEN = os.environ.get("HF_TOKEN")

+ # ---------------------------
+ # 🧹 Helpers
+ # ---------------------------
  def clean(text):
+     lines = text.splitlines()
+     return "\n".join(line.strip() for line in lines if line.strip())

  def split_sentences(text):
+     return re.split(r'(?<=[.!?])\s+', text.strip())

+ def chunk_sentences(sentences, max_len=CHUNK_SIZE, overlap=CHUNK_OVERLAP):
+     chunks, chunk, length = [], [], 0
+     for sent in sentences:
+         tokens = len(sent.split())
+         if length + tokens > max_len and chunk:
              chunks.append(" ".join(chunk))
+             chunk = chunk[-overlap:] if overlap else []
+             length = sum(len(s.split()) for s in chunk)
+         chunk.append(sent)
+         length += tokens
      if chunk:
          chunks.append(" ".join(chunk))
      return chunks

+ def extract_text_from_pdf(path):
+     doc = fitz.open(path)
+     full_text = []
+     for page in doc:
+         text = page.get_text().strip()
+         if not text:
+             try:
+                 pix = page.get_pixmap(dpi=300)
+                 img_data = pix.tobytes("png")
+                 img = Image.open(io.BytesIO(img_data))
+                 text = pytesseract.image_to_string(img).strip()
+             except Exception:
+                 text = ""
+         full_text.append(text)
+     return "\n".join(full_text)
+
+ def extract_text_from_docx(path):
+     doc = Document(path)
+     return "\n".join([para.text for para in doc.paragraphs if para.text.strip()])
+
+ def extract_metadata(filename):
      name = filename.lower()
+     model = next((m for m in ["se3hd", "se3", "se4", "symbio", "explore", "integrity x", "integrity sl", "everest", "engage", "inspire", "discover", "95t", "95x", "95c", "95r", "97c"] if m in name), "unknown")
+     if "om" in name or "owner" in name:
+         doc_type = "owner manual"
+     elif "sm" in name or "service" in name:
+         doc_type = "service manual"
+     elif "assembly" in name:
+         doc_type = "assembly instructions"
+     elif "alert" in name:
+         doc_type = "installer alert"
+     elif "parts" in name:
+         doc_type = "parts manual"
+     elif "bulletin" in name:
+         doc_type = "service bulletin"
+     else:
+         doc_type = "unknown"
+     return model, doc_type
+
+ # ---------------------------
+ # 🚀 Build ChromaDB at Startup
+ # ---------------------------
  def embed_all():
      client = chromadb.PersistentClient(path=CHROMA_PATH)
+     if COLLECTION_NAME in [c.name for c in client.list_collections()]:
          client.delete_collection(COLLECTION_NAME)
      collection = client.create_collection(COLLECTION_NAME)
+
+     embedder = SentenceTransformer("all-MiniLM-L6-v2")
+     records = []
+
+     for fname in os.listdir(MANUALS_FOLDER):
+         path = os.path.join(MANUALS_FOLDER, fname)
+         if not fname.lower().endswith((".pdf", ".docx")):
+             continue
+         text = extract_text_from_pdf(path) if fname.endswith(".pdf") else extract_text_from_docx(path)
+         sents = split_sentences(clean(text))
+         chunks = chunk_sentences(sents)
+         model, doc_type = extract_metadata(fname)
+         for i, chunk in enumerate(chunks):
+             records.append({
+                 "id": f"{fname}::chunk_{i+1}",
+                 "text": chunk,
+                 "metadata": {"source_file": fname, "model": model, "doc_type": doc_type}
+             })
+
+     for i in range(0, len(records), 16):
+         batch = records[i:i+16]
+         texts = [r["text"] for r in batch]
+         ids = [r["id"] for r in batch]
+         metas = [r["metadata"] for r in batch]
+         embeddings = embedder.encode(texts).tolist()
+         collection.add(documents=texts, ids=ids, metadatas=metas, embeddings=embeddings)
+
+     return collection, embedder
+
+ # ---------------------------
+ # 💬 Load HF Model
+ # ---------------------------
+ llm_pipe = None
+ if HF_TOKEN:
+     tokenizer = AutoTokenizer.from_pretrained(HF_MODEL, token=HF_TOKEN)
+     model = AutoModelForCausalLM.from_pretrained(HF_MODEL, token=HF_TOKEN, torch_dtype=torch.float32)
+     llm_pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=-1)
+
+ # ---------------------------
+ # 🔎 RAG Function
+ # ---------------------------
+ def run_query(question):
+     if not question.strip():
+         return "Please enter a question."
+     if not db or not embedder:
+         return "Chroma or embedder not ready."
+
+     q_embed = embedder.encode(question).tolist()
+     res = db.query(query_embeddings=[q_embed], n_results=MAX_CONTEXT_CHUNKS)
+     contexts = res["documents"][0]
+     prompt = """
+ You are a technical assistant.
+ Answer only using the context below.
+ Say 'I don't know' if not found.
+
+ """
+     context_text = "\n\n".join(contexts)
+     final_prompt = prompt + f"Context:\n{context_text}\n\nQuestion: {question}\nAnswer:"
+     if llm_pipe:
+         result = llm_pipe(final_prompt, max_new_tokens=300)[0]['generated_text']
+         return result.split("Answer:")[-1].strip()
+     return "Model not loaded."
+
+ # ---------------------------
+ # 🧠 Init embeddings once
+ # ---------------------------
  db, embedder = embed_all()
+
+ # ---------------------------
+ # 🎛️ Gradio Interface
+ # ---------------------------
+ with gr.Blocks() as demo:
+     gr.Markdown("# 🤖 SmartManuals-AI: Ask Technical Questions about Your Manuals")
+     question = gr.Textbox(placeholder="e.g. How do I reset the treadmill console?", label="Enter Question")
+     submit = gr.Button("Get Answer")
+     output = gr.Textbox(label="Answer")
+     submit.click(fn=run_query, inputs=question, outputs=output)
+
  demo.launch()
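
A minimal sketch for exercising the committed retrieval path outside the Gradio UI (illustrative only, not part of the commit; it assumes embed_all() has already populated ./chroma_store):

import chromadb
from sentence_transformers import SentenceTransformer

# Open the store the Space builds at startup.
client = chromadb.PersistentClient(path="./chroma_store")
collection = client.get_collection("manual_chunks")
embedder = SentenceTransformer("all-MiniLM-L6-v2")

# Mirror run_query(): embed the question, then query by embedding
# (this commit switches retrieval from query_texts to query_embeddings).
question = "How do I reset the treadmill console?"
res = collection.query(query_embeddings=[embedder.encode(question).tolist()], n_results=3)
for doc, meta in zip(res["documents"][0], res["metadatas"][0]):
    print(meta["source_file"], meta["doc_type"], "->", doc[:80])

Querying by precomputed embedding keeps retrieval consistent with the vectors written by embed_all(), rather than relying on whatever embedding function the Chroma collection would apply to raw query text.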