damoojeje committed
Commit 7b8bb00 · verified · 1 Parent(s): a857b53

Update app.py

Files changed (1):
  1. app.py (+48, -73)
app.py CHANGED

@@ -31,7 +31,6 @@ CHUNK_OVERLAP = 100
 MAX_CONTEXT_CHUNKS = 3
 MODEL_ID = "ibm-granite/granite-vision-3.2-2b"
 
-# Device selection
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # ---------------- Text Helpers ----------------
@@ -86,74 +85,59 @@ def extract_docx_text(path):
 
 # ---------------- Embedding ----------------
 def embed_all():
-    try:
-        embedder = SentenceTransformer("all-MiniLM-L6-v2")
-        embedder.eval()
-    except Exception as e:
-        print("❌ Failed to load SentenceTransformer:", e)
-        return None, None
+    embedder = SentenceTransformer("all-MiniLM-L6-v2")
+    embedder.eval()
+    client = chromadb.PersistentClient(path=CHROMA_PATH)
 
     try:
-        client = chromadb.PersistentClient(path=CHROMA_PATH)
         client.delete_collection(COLLECTION_NAME)
-        collection = client.get_or_create_collection(COLLECTION_NAME)
-    except Exception as e:
-        print("❌ Failed to initialize ChromaDB:", e)
-        return None, None
+    except:
+        pass
+    collection = client.get_or_create_collection(COLLECTION_NAME)
 
     docs, ids, metas = [], [], []
     print("📄 Processing manuals...")
 
-    try:
-        for fname in os.listdir(MANUALS_DIR):
-            fpath = os.path.join(MANUALS_DIR, fname)
-            if fname.lower().endswith(".pdf"):
-                pages = extract_pdf_text(fpath)
-            elif fname.lower().endswith(".docx"):
-                pages = extract_docx_text(fpath)
-            else:
-                continue
-
-            for path, page, text in pages:
-                for i, chunk in enumerate(split_chunks(split_sentences(text))):
-                    chunk_id = f"{fname}::{page}::{i}"
-                    docs.append(chunk)
-                    ids.append(chunk_id)
-                    metas.append({"source": fname, "page": page})
-
-            if len(docs) >= 16:
-                embs = embedder.encode(docs).tolist()
-                collection.add(documents=docs, ids=ids, metadatas=metas, embeddings=embs)
-                docs, ids, metas = [], [], []
-
-        if docs:
-            embs = embedder.encode(docs).tolist()
-            collection.add(documents=docs, ids=ids, metadatas=metas, embeddings=embs)
-
-        print(f"✅ Embedded {len(ids)} chunks.")
-        return collection, embedder
-
-    except Exception as e:
-        print("❌ Error during embedding:", e)
-        return None, None
+    for fname in os.listdir(MANUALS_DIR):
+        fpath = os.path.join(MANUALS_DIR, fname)
+        if fname.lower().endswith(".pdf"):
+            pages = extract_pdf_text(fpath)
+        elif fname.lower().endswith(".docx"):
+            pages = extract_docx_text(fpath)
+        else:
+            continue
+
+        for path, page, text in pages:
+            for i, chunk in enumerate(split_chunks(split_sentences(text))):
+                chunk_id = f"{fname}::{page}::{i}"
+                docs.append(chunk)
+                ids.append(chunk_id)
+                metas.append({"source": fname, "page": page})
+
+        if len(docs) >= 16:
+            embs = embedder.encode(docs).tolist()
+            collection.add(documents=docs, ids=ids, metadatas=metas, embeddings=embs)
+            docs, ids, metas = [], [], []
+
+    if docs:
+        embs = embedder.encode(docs).tolist()
+        collection.add(documents=docs, ids=ids, metadatas=metas, embeddings=embs)
+
+    print(f"✅ Embedded {len(ids)} chunks.")
+    return collection, embedder
 
 # ---------------- Model Setup ----------------
 def load_model():
-    try:
-        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
-        model = AutoModelForCausalLM.from_pretrained(
-            MODEL_ID,
-            token=HF_TOKEN,
-            device_map="auto" if torch.cuda.is_available() else None,
-            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
-        ).to(device)
-        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
-        return pipe, tokenizer
-    except Exception as e:
-        print("❌ Failed to load model:", e)
-        return None, None
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_ID,
+        token=HF_TOKEN,
+        device_map="auto" if torch.cuda.is_available() else None,
+        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
+    ).to(device)
+    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
+    return pipe, tokenizer
 
-# ---------------- QA Logic ----------------
 def ask_model(question, context, pipe, tokenizer):
     prompt = f"""Use only the following context to answer. If uncertain, say \"I don't know.\"
 
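For reference, the batch-embed-and-flush pattern that the rewritten embed_all() relies on can be exercised in isolation. Below is a minimal sketch, assuming sentence-transformers and chromadb are installed; the chunk texts, the ./chroma_demo path, and the demo_chunks collection name are hypothetical stand-ins for the MANUALS_DIR pipeline in this commit:

# Standalone sketch of embed_all()'s batching pattern (data is hypothetical).
import chromadb
from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer("all-MiniLM-L6-v2")
client = chromadb.PersistentClient(path="./chroma_demo")     # hypothetical path
collection = client.get_or_create_collection("demo_chunks")  # hypothetical name

chunks = [f"example manual sentence {i}" for i in range(40)]  # stand-in chunks

docs, ids = [], []
for i, chunk in enumerate(chunks):
    docs.append(chunk)
    ids.append(f"demo::{i}")
    if len(docs) >= 16:                        # same batch size as the commit
        embs = embedder.encode(docs).tolist()  # encode one batch at a time
        collection.add(documents=docs, ids=ids, embeddings=embs)
        docs, ids = [], []

if docs:  # flush the final partial batch
    embs = embedder.encode(docs).tolist()
    collection.add(documents=docs, ids=ids, embeddings=embs)

print(collection.count(), "chunks stored")

The load_model() half of the same hunk follows the usual transformers recipe: bfloat16 plus device_map="auto" when CUDA is available, float32 on CPU, with the pipeline device index (0 or -1) mirroring the same check.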
@@ -166,10 +150,12 @@ A:"""
     output = pipe(prompt, max_new_tokens=512)[0]["generated_text"]
     return output.split("A:")[-1].strip()
 
+# ---------------- Query ----------------
 def get_answer(question):
     if not all([embedder, db, model_pipe, model_tokenizer]):
-        return "❌ System not initialized. Check logs or try restarting the app."
+        return "⚠️ The system is still initializing or failed to load. Please try again later."
     try:
+        query_emb = embedder.encode(question, convert_to_tensor=True)
         results = db.query(query_texts=[question], n_results=MAX_CONTEXT_CHUNKS)
         context = "\n\n".join(results["documents"][0])
         return ask_model(question, context, model_pipe, model_tokenizer)
@@ -184,31 +170,20 @@ with gr.Blocks() as demo:
     question = gr.Textbox(label="Ask your question")
     ask = gr.Button("Ask")
     answer = gr.Textbox(label="Answer", lines=8)
-    status = gr.Markdown(visible=False)
+    ask.click(fn=get_answer, inputs=question, outputs=answer)
 
-    def wrapped_get_answer(q):
-        ans = get_answer(q)
-        return ans, ""  # hide status after success
-
-    ask.click(fn=wrapped_get_answer, inputs=question, outputs=[answer, status])
-
-    # Show status on startup error
-    if not all([embedder, db, model_pipe, model_tokenizer]):
-        status.visible = True
-        status.value = "⚠️ Initialization failed. Check logs or your HF_TOKEN."
+# ---------------- Startup ----------------
+embedder = db = model_pipe = model_tokenizer = None
 
-# Embed + Load Model at Startup
 try:
     db, embedder = embed_all()
 except Exception as e:
     print("❌ Embedding failed:", e)
-    db, embedder = None, None
 
 try:
     model_pipe, model_tokenizer = load_model()
 except Exception as e:
     print("❌ Model loading failed:", e)
-    model_pipe, model_tokenizer = None, None
 
 if __name__ == "__main__":
     demo.launch()
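Taken together, the last two hunks implement a fail-soft startup: the four globals default to None via chained assignment, each initialization step gets its own try/except, and get_answer() declines to answer until all four names are truthy. The same guard pattern in isolation, with hypothetical loader names:

# Sketch of the None-guard startup pattern (load_a/load_b are hypothetical).
resource_a = resource_b = None  # names exist even if a loader raises below

def load_a():
    raise RuntimeError("simulated startup failure")

def load_b():
    return "ready"

try:
    resource_a = load_a()
except Exception as e:
    print("❌ load_a failed:", e)  # log and keep the process alive

try:
    resource_b = load_b()
except Exception as e:
    print("❌ load_b failed:", e)

def handle(_request):
    if not all([resource_a, resource_b]):  # mirrors get_answer()'s guard
        return "⚠️ The system is still initializing or failed to load."
    return "ok"

print(handle(None))  # prints the warning, since resource_a stayed None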
 