Update app.py
app.py
CHANGED
@@ -14,11 +14,11 @@ from sentence_transformers import SentenceTransformer
 from transformers import pipeline
 from sklearn.metrics.pairwise import cosine_similarity
 
-# Load models
+# Load models
 embed_model = SentenceTransformer('all-MiniLM-L6-v2')
 qa_pipeline = pipeline("question-answering", model="deepset/tinyroberta-squad2")
 
-# Globals
+# Globals
 all_chunks = []
 chunk_sources = []
 chunk_embeddings = None
@@ -82,11 +82,29 @@ def answer_question(question):
 
     q_emb = embed_model.encode([question], convert_to_numpy=True)
     sims = cosine_similarity(q_emb, chunk_embeddings)[0]
-    top_k_idx = sims.argsort()[::-1][:3]
 
-    context = "\n\n".join(all_chunks[i] for i in top_k_idx)
+    threshold = 0.5  # similarity threshold to filter relevant chunks
+    above_thresh_idx = [i for i, sim in enumerate(sims) if sim > threshold]
+
+    if not above_thresh_idx:
+        return "No relevant content found in the PDFs for your question."
+
+    # Sort by similarity descending
+    above_thresh_idx.sort(key=lambda i: sims[i], reverse=True)
+
+    max_context_chars = 2000
+    context_chunks = []
+    total_chars = 0
+    for i in above_thresh_idx:
+        chunk_len = len(all_chunks[i])
+        if total_chars + chunk_len > max_context_chars:
+            break
+        context_chunks.append(all_chunks[i])
+        total_chars += chunk_len
+
+    context = "\n\n".join(context_chunks)
     if not context.strip():
-        return "No relevant content found in the PDFs for your question."
+        return "No sufficient content to answer the question."
 
     try:
         result = qa_pipeline(question=question, context=context)
@@ -94,8 +112,8 @@ def answer_question(question):
     except Exception:
         return "Error generating answer from the model."
 
-    sources = "\n".join(set(chunk_sources[i] for i in top_k_idx))
-    confidence = np.mean([sims[i] for i in top_k_idx]) * 100
+    sources = "\n".join(set(chunk_sources[i] for i in above_thresh_idx[:len(context_chunks)]))
+    confidence = np.mean([sims[i] for i in above_thresh_idx[:len(context_chunks)]]) * 100
     return f"**Answer:** {answer}\n\n**Sources:**\n{sources}\n\n**Confidence:** {confidence:.2f}%"
 
 with gr.Blocks() as demo:
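
For reference, the core of this change is replacing fixed top-3 retrieval with a similarity threshold plus a character budget. Below is a minimal, self-contained sketch of that step on toy data; select_context, the default parameter values, and the toy embeddings are illustrative stand-ins mirroring the diff, not part of the app itself:

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def select_context(q_emb, chunk_embeddings, chunks,
                   threshold=0.5, max_context_chars=2000):
    # Score every chunk by cosine similarity to the query embedding.
    sims = cosine_similarity(q_emb, chunk_embeddings)[0]
    # Keep only chunks above the threshold, best match first.
    idx = sorted((i for i, s in enumerate(sims) if s > threshold),
                 key=lambda i: sims[i], reverse=True)
    # Greedily pack chunks, stopping at the first one that would
    # overflow the character budget (same break behavior as the diff).
    kept, total = [], 0
    for i in idx:
        if total + len(chunks[i]) > max_context_chars:
            break
        kept.append(i)
        total += len(chunks[i])
    return kept, sims

# Toy data: three unit-norm "chunk" embeddings; the query is chunk 0's
# own embedding, so chunk 0 scores a similarity of exactly 1.0.
rng = np.random.default_rng(0)
emb = rng.normal(size=(3, 8))
emb /= np.linalg.norm(emb, axis=1, keepdims=True)
kept, sims = select_context(emb[:1], emb, ["chunk a", "chunk b", "chunk c"])
print(kept, sims.round(2))  # chunk 0 is always kept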
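The new sources and confidence lines depend on a small invariant: the packing loop breaks at the first chunk that would overflow the budget, so the kept chunks are always a prefix of above_thresh_idx, and above_thresh_idx[:len(context_chunks)] names exactly the chunks that went into the context. Continuing the sketch above (kept and sims come from it; chunk_sources here is an illustrative stand-in for the app's global):

chunk_sources = ["a.pdf", "b.pdf", "c.pdf"]  # illustrative source labels
sources = "\n".join(set(chunk_sources[i] for i in kept))
confidence = float(np.mean([sims[i] for i in kept])) * 100
print(sources)
print(f"{confidence:.2f}%")  # mean similarity of the chunks actually used

Note that this "confidence" measures retrieval similarity, not the QA model's belief in its answer; the transformers question-answering pipeline also returns a score field in its result dict that could be surfaced instead.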