damoojeje committed
Commit 98c93fa · verified · 1 Parent(s): 32e9a12

Update app.py

Files changed (1):
  app.py (+15, -20)
app.py CHANGED
@@ -11,7 +11,7 @@ import docx2txt
 from PIL import Image
 from io import BytesIO
 from tqdm import tqdm
-from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoProcessor, AutoModelForVision2Seq
 from sentence_transformers import SentenceTransformer, util
 from nltk.tokenize import sent_tokenize
@@ -116,7 +116,7 @@ def embed_all():
         ids.append(chunk_id)
         metas.append({"source": fname, "page": page})
 
-        if len(docs) >= 16:
+        if len(docs) >= 32:  # Increased batch size for efficiency
             embs = embedder.encode(docs).tolist()
             collection.add(documents=docs, ids=ids, metadatas=metas, embeddings=embs)
             docs, ids, metas = [], [], []
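
Note on the hunk above: the only change is the flush threshold, from 16 to 32 chunks per Chroma write. As a minimal standalone sketch of the accumulate-and-flush pattern (iter_chunks is a hypothetical stand-in for the app's file/page loop; embedder and collection are the SentenceTransformer and Chroma collection already in scope in embed_all):

    BATCH = 32
    docs, ids, metas = [], [], []
    for chunk_id, text, meta in iter_chunks():  # hypothetical chunk iterator
        docs.append(text)
        ids.append(chunk_id)
        metas.append(meta)
        if len(docs) >= BATCH:  # flush one full batch to the vector store
            embs = embedder.encode(docs).tolist()
            collection.add(documents=docs, ids=ids, metadatas=metas, embeddings=embs)
            docs, ids, metas = [], [], []
    if docs:  # flush the final partial batch so trailing chunks aren't lost
        embs = embedder.encode(docs).tolist()
        collection.add(documents=docs, ids=ids, metadatas=metas, embeddings=embs)

With a threshold of 32, up to 31 chunks can remain buffered when the loop ends; the diff doesn't show the tail of embed_all, so the final flush may already be handled there.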
@@ -135,20 +135,14 @@ def embed_all():
 # ---------------- Model Setup ----------------
 def load_model():
     try:
-        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
-        model = AutoModelForCausalLM.from_pretrained(
-            MODEL_ID,
-            token=HF_TOKEN,
-            device_map="auto" if torch.cuda.is_available() else None,
-            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
-        ).to(device)
-        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
-        return pipe, tokenizer
+        processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)
+        model = AutoModelForVision2Seq.from_pretrained(MODEL_ID, token=HF_TOKEN).to(device)
+        return model, processor
     except Exception as e:
         print("Model loading failed:", e)
         return None, None
 
-def ask_model(question, context, pipe, tokenizer):
+def ask_model(question, context, model, processor):
     prompt = f"""Use only the following context to answer. If uncertain, say \"I don't know.\"
 
 <context>
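
Note on the hunk above: the rewritten load_model is much leaner but silently drops the old dtype and device_map logic, and AutoModelForVision2Seq targets image-and-text models, so here the processor is effectively being used as a plain tokenizer. If the old GPU placement is still wanted, from_pretrained accepts the same keyword arguments on the new class; a sketch, not part of the commit (MODEL_ID, HF_TOKEN, and device are globals app.py already defines):

    processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)
    model = AutoModelForVision2Seq.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    ).to(device)
    return model, processor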
@@ -157,18 +151,19 @@ def ask_model(question, context, pipe, tokenizer):
 
 Q: {question}
 A:"""
-    output = pipe(prompt, max_new_tokens=512)[0]["generated_text"]
-    return output.split("A:")[-1].strip()
+    inputs = processor(prompt, return_tensors="pt").to(device)
+    output = model.generate(**inputs)
+    return processor.decode(output[0], skip_special_tokens=True)
 
 # ---------------- Query ----------------
 def get_answer(question):
-    if not embedder or not db or not model_pipe:
+    if not embedder or not db or not model:
         return "System not ready. Try again after initialization."
     try:
         query_emb = embedder.encode(question, convert_to_tensor=True)
         results = db.query(query_texts=[question], n_results=MAX_CONTEXT_CHUNKS)
         context = "\n\n".join(results["documents"][0])
-        return ask_model(question, context, model_pipe, model_tokenizer)
+        return ask_model(question, context, model, processor)
     except Exception as e:
         print("Query error:", e)
         return f"Error: {e}"
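
Two things worth flagging in the hunk above. First, the old pipeline call passed max_new_tokens=512, while the new model.generate(**inputs) sets no length budget, so generation falls back to the model's GenerationConfig defaults and answers may be cut short; restoring the old budget is a one-argument change:

    output = model.generate(**inputs, max_new_tokens=512)

Second, processor.decode(output[0], ...) decodes the whole returned sequence, which for a decoder-only model includes the prompt (the old code kept only the text after the final "A:"). Unrelated to this commit but visible in the context lines: query_emb in get_answer is computed and never used, since db.query embeds the query text itself.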
@@ -184,8 +179,8 @@ with gr.Blocks() as demo:
 
 # Startup Initialization
 embedder = None
-model_pipe = None
-model_tokenizer = None
+model = None
+processor = None
 
 try:
     db, embedder = embed_all()
@@ -193,10 +188,10 @@ except Exception as e:
     print("❌ Embedding failed:", e)
 
 try:
-    model_pipe, model_tokenizer = load_model()
+    model, processor = load_model()
 except Exception as e:
     print("❌ Model load failed:", e)
 
 # Launch
 if __name__ == "__main__":
-    demo.launch(share=True)
+    demo.launch(share=False)
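
Note on the final hunk: share=True asks Gradio to open a temporary public *.gradio.live tunnel, while share=False keeps the app reachable only at the local server URL; on a hosted Space the app is served publicly either way, so the flag mainly matters for local runs.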