Update app.py
app.py
CHANGED
(… marks code that is truncated in the original diff view)

@@ -11,7 +11,7 @@ import docx2txt
 from PIL import Image
 from io import BytesIO
 from tqdm import tqdm
-from transformers import …
+from transformers import AutoProcessor, AutoModelForVision2Seq
 from sentence_transformers import SentenceTransformer, util
 from nltk.tokenize import sent_tokenize
 
@@ -116,7 +116,7 @@ def embed_all():
         ids.append(chunk_id)
         metas.append({"source": fname, "page": page})
 
-        if len(docs) >= …:
+        if len(docs) >= 32:  # Increased batch size for efficiency
             embs = embedder.encode(docs).tolist()
             collection.add(documents=docs, ids=ids, metadatas=metas, embeddings=embs)
             docs, ids, metas = [], [], []
@@ -135,20 +135,14 @@ def embed_all():
 # ---------------- Model Setup ----------------
 def load_model():
     try:
-        …
-        model = …
-            …
-            token=HF_TOKEN,
-            device_map="auto" if torch.cuda.is_available() else None,
-            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
-        ).to(device)
-        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
-        return pipe, tokenizer
+        processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)
+        model = AutoModelForVision2Seq.from_pretrained(MODEL_ID, token=HF_TOKEN).to(device)
+        return model, processor
     except Exception as e:
         print("Model loading failed:", e)
         return None, None
 
-def ask_model(question, context, pipe, tokenizer):
+def ask_model(question, context, model, processor):
     prompt = f"""Use only the following context to answer. If uncertain, say \"I don't know.\"
 
 <context>
@@ -157,18 +151,19 @@ def ask_model(question, context, pipe, tokenizer):
 
 Q: {question}
 A:"""
-    …
-    …
+    inputs = processor(prompt, return_tensors="pt").to(device)
+    output = model.generate(**inputs)
+    return processor.decode(output[0], skip_special_tokens=True)
 
 # ---------------- Query ----------------
 def get_answer(question):
-    if not embedder or not db or not pipe:
+    if not embedder or not db or not model:
         return "System not ready. Try again after initialization."
     try:
         query_emb = embedder.encode(question, convert_to_tensor=True)
         results = db.query(query_texts=[question], n_results=MAX_CONTEXT_CHUNKS)
         context = "\n\n".join(results["documents"][0])
-        return ask_model(question, context, pipe, tokenizer)
+        return ask_model(question, context, model, processor)
     except Exception as e:
         print("Query error:", e)
         return f"Error: {e}"
@@ -184,8 +179,8 @@ with gr.Blocks() as demo:
 
 # Startup Initialization
 embedder = None
-pipe = None
-tokenizer = None
+model = None
+processor = None
 
 try:
     db, embedder = embed_all()
@@ -193,10 +188,10 @@ except Exception as e:
     print("❌ Embedding failed:", e)
 
 try:
-    pipe, tokenizer = load_model()
+    model, processor = load_model()
 except Exception as e:
     print("❌ Model load failed:", e)
 
 # Launch
 if __name__ == "__main__":
-    demo.launch(share=…)
+    demo.launch(share=False)
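Note on the embed_all change: batching at 32 means embedder.encode() runs once per batch rather than once per chunk, which is where the efficiency gain comes from. A minimal sketch of the same pattern, assuming a SentenceTransformer embedder and a Chroma collection; the helper name, the chunks iterable, and the final tail flush are illustrative assumptions, not code from this commit:

def add_in_batches(chunks, embedder, collection, batch_size=32):
    # Hypothetical helper mirroring the batched-add pattern in embed_all.
    docs, ids, metas = [], [], []
    for chunk_id, text, meta in chunks:  # chunks: iterable of (id, text, meta) tuples
        docs.append(text)
        ids.append(chunk_id)
        metas.append(meta)
        if len(docs) >= batch_size:
            embs = embedder.encode(docs).tolist()  # one encode() call per batch
            collection.add(documents=docs, ids=ids, metadatas=metas, embeddings=embs)
            docs, ids, metas = [], [], []
    if docs:  # flush the final partial batch (assumed; this step is outside the diff)
        embs = embedder.encode(docs).tolist()
        collection.add(documents=docs, ids=ids, metadatas=metas, embeddings=embs)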
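For reference, a self-contained sketch of the new load-and-generate path introduced above, assuming MODEL_ID, HF_TOKEN, and device are defined earlier in app.py as the diff implies. The placeholder model id and the max_new_tokens cap are assumptions for illustration; the commit itself calls model.generate() with defaults, so output length falls back to the model's generation config:

import os
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq

MODEL_ID = "org/some-vision2seq-model"  # hypothetical placeholder id
HF_TOKEN = os.environ.get("HF_TOKEN")
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForVision2Seq.from_pretrained(MODEL_ID, token=HF_TOKEN).to(device)

prompt = "Q: What is in the indexed documents?\nA:"
inputs = processor(text=prompt, return_tensors="pt").to(device)  # text-only, as in ask_model
output = model.generate(**inputs, max_new_tokens=128)  # the cap is an assumption
print(processor.decode(output[0], skip_special_tokens=True))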