damoojeje committed on
Commit a857b53 · verified · 1 Parent(s): ad0baa1

Update app.py

Files changed (1)
  1. app.py +96 -76
app.py CHANGED
@@ -8,19 +8,20 @@ import torch
 import nltk
 import traceback
 import docx2txt
-import logging
 from PIL import Image
 from io import BytesIO
 from tqdm import tqdm
 from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 from sentence_transformers import SentenceTransformer, util
-from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktTrainer
+from nltk.tokenize import sent_tokenize
 
-# ---------------- Logger Setup ----------------
-logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
-logger = logging.getLogger("SmartManuals")
+# Ensure punkt is downloaded
+try:
+    nltk.data.find("tokenizers/punkt")
+except LookupError:
+    nltk.download("punkt")
 
-# ---------------- Config ----------------
+# Configuration
 HF_TOKEN = os.getenv("HF_TOKEN")
 MANUALS_DIR = "Manuals"
 CHROMA_PATH = "chroma_store"
@@ -30,25 +31,18 @@ CHUNK_OVERLAP = 100
 MAX_CONTEXT_CHUNKS = 3
 MODEL_ID = "ibm-granite/granite-vision-3.2-2b"
 
+# Device selection
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-# ---------------- Sentence Tokenizer (Persistent) ----------------
-try:
-    nltk.data.find("tokenizers/punkt")
-except LookupError:
-    nltk.download("punkt")
-
-tokenizer_punkt = PunktSentenceTokenizer()
-
 # ---------------- Text Helpers ----------------
 def clean(text):
     return "\n".join([line.strip() for line in text.splitlines() if line.strip()])
 
 def split_sentences(text):
     try:
-        return tokenizer_punkt.tokenize(text)
-    except Exception as e:
-        logger.warning("Tokenizer fallback: simple split. Reason: %s", e)
+        return sent_tokenize(text)
+    except:
+        print("⚠️ Tokenizer fallback: simple split.")
         return text.split(". ")
 
 def split_chunks(sentences, max_tokens=CHUNK_SIZE, overlap=CHUNK_OVERLAP):
@@ -80,71 +74,86 @@ def extract_pdf_text(path):
             text = pytesseract.image_to_string(img)
             chunks.append((path, i + 1, clean(text)))
     except Exception as e:
-        logger.error("PDF read error [%s]: %s", path, e)
+        print("❌ PDF read error:", path, e)
     return chunks
 
 def extract_docx_text(path):
     try:
         return [(path, 1, clean(docx2txt.process(path)))]
     except Exception as e:
-        logger.error("DOCX read error [%s]: %s", path, e)
+        print("❌ DOCX read error:", path, e)
         return []
 
 # ---------------- Embedding ----------------
 def embed_all():
-    embedder = SentenceTransformer("all-MiniLM-L6-v2")
-    embedder.eval()
-    client = chromadb.PersistentClient(path=CHROMA_PATH)
+    try:
+        embedder = SentenceTransformer("all-MiniLM-L6-v2")
+        embedder.eval()
+    except Exception as e:
+        print("❌ Failed to load SentenceTransformer:", e)
+        return None, None
 
     try:
+        client = chromadb.PersistentClient(path=CHROMA_PATH)
         client.delete_collection(COLLECTION_NAME)
-    except:
-        pass
-    collection = client.get_or_create_collection(COLLECTION_NAME)
+        collection = client.get_or_create_collection(COLLECTION_NAME)
+    except Exception as e:
+        print("❌ Failed to initialize ChromaDB:", e)
+        return None, None
 
     docs, ids, metas = [], [], []
-    logger.info("📄 Processing manuals...")
-
-    for fname in os.listdir(MANUALS_DIR):
-        fpath = os.path.join(MANUALS_DIR, fname)
-        if fname.lower().endswith(".pdf"):
-            pages = extract_pdf_text(fpath)
-        elif fname.lower().endswith(".docx"):
-            pages = extract_docx_text(fpath)
-        else:
-            continue
-
-        for path, page, text in pages:
-            for i, chunk in enumerate(split_chunks(split_sentences(text))):
-                chunk_id = f"{fname}::{page}::{i}"
-                docs.append(chunk)
-                ids.append(chunk_id)
-                metas.append({"source": fname, "page": page})
-
-                if len(docs) >= 16:
-                    embs = embedder.encode(docs).tolist()
-                    collection.add(documents=docs, ids=ids, metadatas=metas, embeddings=embs)
-                    docs, ids, metas = [], [], []
-
-    if docs:
-        embs = embedder.encode(docs).tolist()
-        collection.add(documents=docs, ids=ids, metadatas=metas, embeddings=embs)
-
-    logger.info("✅ Embedded %d chunks.", len(ids))
-    return collection, embedder
+    print("📄 Processing manuals...")
+
+    try:
+        for fname in os.listdir(MANUALS_DIR):
+            fpath = os.path.join(MANUALS_DIR, fname)
+            if fname.lower().endswith(".pdf"):
+                pages = extract_pdf_text(fpath)
+            elif fname.lower().endswith(".docx"):
+                pages = extract_docx_text(fpath)
+            else:
+                continue
+
+            for path, page, text in pages:
+                for i, chunk in enumerate(split_chunks(split_sentences(text))):
+                    chunk_id = f"{fname}::{page}::{i}"
+                    docs.append(chunk)
+                    ids.append(chunk_id)
+                    metas.append({"source": fname, "page": page})
+
+                    if len(docs) >= 16:
+                        embs = embedder.encode(docs).tolist()
+                        collection.add(documents=docs, ids=ids, metadatas=metas, embeddings=embs)
+                        docs, ids, metas = [], [], []
+
+        if docs:
+            embs = embedder.encode(docs).tolist()
+            collection.add(documents=docs, ids=ids, metadatas=metas, embeddings=embs)
+
+        print(f"✅ Embedded {len(ids)} chunks.")
+        return collection, embedder
+
+    except Exception as e:
+        print("❌ Error during embedding:", e)
+        return None, None
 
 # ---------------- Model Setup ----------------
 def load_model():
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_ID,
-        token=HF_TOKEN,
-        device_map="auto" if torch.cuda.is_available() else None,
-        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
-    ).to(device)
-    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
-    return pipe, tokenizer
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
+        model = AutoModelForCausalLM.from_pretrained(
+            MODEL_ID,
+            token=HF_TOKEN,
+            device_map="auto" if torch.cuda.is_available() else None,
+            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
+        ).to(device)
+        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
+        return pipe, tokenizer
+    except Exception as e:
+        print("❌ Failed to load model:", e)
+        return None, None
 
+# ---------------- QA Logic ----------------
 def ask_model(question, context, pipe, tokenizer):
     prompt = f"""Use only the following context to answer. If uncertain, say \"I don't know.\"
 
@@ -157,37 +166,48 @@ A:"""
     output = pipe(prompt, max_new_tokens=512)[0]["generated_text"]
     return output.split("A:")[-1].strip()
 
-# ---------------- Query ----------------
 def get_answer(question):
+    if not all([embedder, db, model_pipe, model_tokenizer]):
+        return "❌ System not initialized. Check logs or try restarting the app."
     try:
-        query_emb = embedder.encode(question, convert_to_tensor=True)
         results = db.query(query_texts=[question], n_results=MAX_CONTEXT_CHUNKS)
         context = "\n\n".join(results["documents"][0])
-        source_info = "\n\n".join([
-            f"📄 Source: {m.get('source', 'N/A')} (Page {m.get('page', 'N/A')})" for m in results["metadatas"][0]
-        ])
-        answer = ask_model(question, context, model_pipe, model_tokenizer)
-        return f"{answer}\n\n---\n{source_info}"
+        return ask_model(question, context, model_pipe, model_tokenizer)
     except Exception as e:
-        logger.error(" Query error: %s", e)
+        print("❌ Query error:", e)
         return f"Error: {e}"
 
 # ---------------- UI ----------------
 with gr.Blocks() as demo:
-    gr.Markdown("## 🤖 SmartManuals-AI (Granite 3.2-2B)")
+    gr.Markdown("## 🤖 SmartManuals-AI (Granite 3.2-2B)")
     with gr.Row():
         question = gr.Textbox(label="Ask your question")
         ask = gr.Button("Ask")
-    answer = gr.Textbox(label="Answer", lines=10)
-    ask.click(fn=get_answer, inputs=question, outputs=answer)
+    answer = gr.Textbox(label="Answer", lines=8)
+    status = gr.Markdown(visible=False)
+
+    def wrapped_get_answer(q):
+        ans = get_answer(q)
+        return ans, ""  # hide status after success
+
+    ask.click(fn=wrapped_get_answer, inputs=question, outputs=[answer, status])
+
+    # Show status on startup error
+    if not all([embedder, db, model_pipe, model_tokenizer]):
+        status.visible = True
+        status.value = "⚠️ Initialization failed. Check logs or your HF_TOKEN."
 
 # Embed + Load Model at Startup
 try:
     db, embedder = embed_all()
-    model_pipe, model_tokenizer = load_model()
 except Exception as e:
-    logger.exception(" Startup failure: %s", e)
+    print("❌ Embedding failed:", e)
     db, embedder = None, None
+
+try:
+    model_pipe, model_tokenizer = load_model()
+except Exception as e:
+    print("❌ Model loading failed:", e)
     model_pipe, model_tokenizer = None, None
 
 if __name__ == "__main__":
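For a quick sanity check of the new sentence splitting, the pattern in split_sentences can be exercised on its own. A minimal sketch, not part of the commit, with a made-up sample string:

import nltk
from nltk.tokenize import sent_tokenize

# Mirror the startup check in app.py: fetch punkt once if it is missing
try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt")

def split_sentences(text):
    # Same shape as the updated app.py: NLTK first, naive period split as fallback
    try:
        return sent_tokenize(text)
    except Exception:
        print("⚠️ Tokenizer fallback: simple split.")
        return text.split(". ")

# Made-up manual text for the check
print(split_sentences("Press the power button. Wait for the LED. See page 12 if it blinks."))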
 
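The reworked get_answer also changes how the question is embedded: instead of calling embedder.encode on the query, it passes query_texts straight to Chroma, which embeds it with the collection's embedding function. A self-contained sketch of that query pattern, assuming an in-memory client and made-up documents (app.py uses PersistentClient(path="chroma_store") and real manual chunks):

import chromadb

client = chromadb.Client()  # in-memory; fine for a quick check
collection = client.get_or_create_collection("demo_manuals")  # hypothetical name

# Stand-ins for the embedded manual chunks, using the same id scheme as embed_all
collection.add(
    documents=[
        "Hold the reset button for five seconds.",
        "The belt tension screw is under the rear cover.",
    ],
    ids=["manual.pdf::1::0", "manual.pdf::2::0"],
    metadatas=[{"source": "manual.pdf", "page": 1}, {"source": "manual.pdf", "page": 2}],
)

# Mirrors db.query(query_texts=[question], n_results=MAX_CONTEXT_CHUNKS) in get_answer
results = collection.query(query_texts=["How do I reset the console?"], n_results=2)
context = "\n\n".join(results["documents"][0])
print(context)

Worth noting: embed_all stores explicit SentenceTransformer vectors, while query_texts embeds the question with Chroma's default embedding function (an ONNX build of all-MiniLM-L6-v2). The two happen to line up here, but swapping either model would quietly skew retrieval.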