Update src/streamlit_app.py
src/streamlit_app.py  CHANGED  (+221 −35)
@@ -1,40 +1,226 @@
-import altair as alt
-import numpy as np
-import pandas as pd
 import streamlit as st
 
-"""
-# Welcome to Streamlit!
 
-Edit `/src/streamlit_app.py` to customize this app to your heart's desire :heart:.
-If you have any questions, checkout our [documentation](https://docs.streamlit.io/) and [community
-forums](https://discuss.streamlit.io).
 
-In the meantime, below is an example of what you can do with just a few lines of code:
-"""
 
-num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
-num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
-indices = np.linspace(0, 1, num_points)
-theta = 2 * np.pi * num_turns * indices
-radius = indices
-
-x = radius * np.cos(theta)
-y = radius * np.sin(theta)
-
-df = pd.DataFrame({
-    "x": x,
-    "y": y,
-    "idx": indices,
-    "rand": np.random.randn(num_points),
-})
-
-st.altair_chart(alt.Chart(df, height=700, width=700)
-    .mark_point(filled=True)
-    .encode(
-        x=alt.X("x", axis=None),
-        y=alt.Y("y", axis=None),
-        color=alt.Color("idx", legend=None, scale=alt.Scale()),
-        size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-    ))
+#!/usr/bin/env python
+"""
+Streamlit application for Question Answering system.
+Optimized for deployment on Hugging Face Spaces.
+"""
+
 import streamlit as st
+import os
+import time
+import torch
+import pandas as pd
+import matplotlib.pyplot as plt
+from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
+import json
 
+# Page configuration
+st.set_page_config(
+    page_title="Question Answering System",
+    page_icon="❓",
+    layout="wide"
+)
 
+# Constants
+MODELS = {
+    "ELECTRA-small": "mrm8488/electra-small-finetuned-squadv1",
+    "ALBERT-base-v2": "twmkn9/albert-base-v2-squad2",
+    "DistilBERT-base": "distilbert-base-cased-distilled-squad"
+}
 
+# Cache for loaded models
+@st.cache_resource
+def load_model(model_name):
+    """Load model and tokenizer with caching"""
+    try:
+        model_path = MODELS[model_name]
+        qa_pipeline = pipeline("question-answering", model=model_path)
+        return qa_pipeline
+    except Exception as e:
+        st.error(f"Error loading model {model_name}: {e}")
+        return None
+
+def answer_question(qa_pipeline, question, context):
+    """
+    Answer a question given a context using the QA pipeline
+    """
+    if not question or not context:
+        return None, 0, 0
+
+    # Measure inference time
+    start_time = time.time()
+
+    # Run model
+    result = qa_pipeline(question=question, context=context)
+
+    # Calculate inference time
+    inference_time = time.time() - start_time
+
+    return result["answer"], result["score"], inference_time
+
+def highlight_answer(context, answer):
+    """Highlight the answer in the context with HTML"""
+    if not answer or not context:
+        return context
+
+    # Find the answer in the context (case insensitive)
+    lower_context = context.lower()
+    lower_answer = answer.lower()
+
+    if lower_answer in lower_context:
+        start_idx = lower_context.find(lower_answer)
+        end_idx = start_idx + len(lower_answer)
+
+        # Create HTML with highlighted answer
+        highlighted = (
+            context[:start_idx] +
+            f'<span style="background-color: #ffdd99; font-weight: bold;">{context[start_idx:end_idx]}</span>' +
+            context[end_idx:]
+        )
+        return highlighted
+
+    return context
+
+def generate_comparison_chart(results_df):
+    """Generate a comparison chart for model results"""
+    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
+
+    # Sort models by score
+    results_df = results_df.sort_values('score', ascending=False)
+
+    # Plot scores
+    models = results_df['model_name']
+    scores = results_df['score']
+    ax1.barh(models, scores, color='skyblue')
+    ax1.set_xlabel('Confidence Score')
+    ax1.set_title('Model Confidence Scores')
+    ax1.grid(axis='x', linestyle='--', alpha=0.7)
+
+    # Plot inference times
+    inference_times = results_df['inference_time'].astype(float)
+    ax2.barh(models, inference_times, color='salmon')
+    ax2.set_xlabel('Inference Time (seconds)')
+    ax2.set_title('Model Inference Times')
+    ax2.grid(axis='x', linestyle='--', alpha=0.7)
+
+    plt.tight_layout()
+    return fig
+
+def main():
+    # Title and description
+    st.title("Question Answering System")
+    st.markdown("""
+    This application answers questions based on the provided context using transformer-based models
+    fine-tuned on the SQuAD dataset. Enter a context paragraph and ask questions about it.
+    """)
+
+    # Initialize session state for storing results
+    if 'comparison_results' not in st.session_state:
+        st.session_state.comparison_results = None
+
+    # Layout
+    col1, col2 = st.columns([3, 1])
+
+    with col1:
+        # Context input
+        context = st.text_area(
+            "Context",
+            "The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse (\"Norman\" comes from \"Norseman\") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries.",
+            height=200
+        )
+
+        # Question input
+        question = st.text_input("Question", "In what country is Normandy located?")
+
+        # Add a separator
+        st.markdown("---")
+
+        # Results section
+        st.subheader("Results")
+
+        if st.button("Compare All Models"):
+            progress_bar = st.progress(0)
+            results = []
+
+            # Process each model
+            for i, model_name in enumerate(MODELS.keys()):
+                status_text = st.empty()
+                status_text.text(f"Processing with {model_name}...")
+
+                # Load model
+                qa_pipeline = load_model(model_name)
+                if qa_pipeline is not None:
+                    # Get answer
+                    answer, score, inference_time = answer_question(qa_pipeline, question, context)
+
+                    # Store results
+                    results.append({
+                        "model_name": model_name,
+                        "answer": answer,
+                        "score": score,
+                        "inference_time": inference_time
+                    })
+
+                # Update progress
+                progress_bar.progress((i + 1) / len(MODELS))
+
+            # Display results in a table
+            if results:
+                results_df = pd.DataFrame(results)
+                display_df = results_df.copy()
+                display_df["inference_time"] = display_df["inference_time"].apply(lambda x: f"{x:.4f} s")
+                display_df["score"] = display_df["score"].apply(lambda x: f"{x:.4f}")
+                st.table(display_df)
+
+                # Save results to session state for comparison chart
+                st.session_state.comparison_results = results_df
+
+                # Show comparison chart
+                st.subheader("Model Comparison")
+                comparison_chart = generate_comparison_chart(results_df)
+                st.pyplot(comparison_chart)
+
+    with col2:
+        # Model selection
+        st.subheader("Available Models")
+
+        selected_model = st.selectbox(
+            "Select a model",
+            list(MODELS.keys()),
+            key="model_selector"
+        )
+
+        # Load selected model and answer
+        if st.button("Answer Question"):
+            with st.spinner(f"Loading {selected_model}..."):
+                qa_pipeline = load_model(selected_model)
+
+            if qa_pipeline is not None:
+                with st.spinner("Generating answer..."):
+                    answer, score, inference_time = answer_question(qa_pipeline, question, context)
+
+                st.success("Answer generated!")
+                st.markdown(f"**Model:** {selected_model}")
+                st.markdown(f"**Answer:** {answer}")
+                st.markdown(f"**Confidence:** {score:.4f}")
+                st.markdown(f"**Inference Time:** {inference_time:.4f} seconds")
+
+                # Highlight answer in context
+                st.subheader("Answer in Context")
+                highlighted_context = highlight_answer(context, answer)
+                st.markdown(highlighted_context, unsafe_allow_html=True)
+
+        # Advanced options
+        with st.expander("Model Information"):
+            st.markdown("""
+            **ELECTRA-small**
+            A smaller, efficient model with good performance and speed.
+
+            **ALBERT-base-v2**
+            Parameter-efficient model with strong performance.
+
+            **DistilBERT-base**
+            Distilled BERT model that's faster while maintaining accuracy.
+            """)
 
+if __name__ == "__main__":
+    main()
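For reference, the `load_model` and `answer_question` helpers added here are thin wrappers around the Hugging Face `question-answering` pipeline. A minimal standalone sketch of the same call path, assuming `transformers` and `torch` are installed and using the DistilBERT checkpoint from the app's MODELS dict:

import time
from transformers import pipeline

# One of the checkpoints from the app's MODELS dict
qa = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")

context = "The Normans were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France."
question = "In what country is Normandy located?"

start = time.time()
result = qa(question=question, context=context)
elapsed = time.time() - start

# The pipeline returns the answer text, a confidence score, and
# character offsets of the answer span within the context.
print(result["answer"])                # e.g. "France"
print(f"{result['score']:.4f}")        # confidence in [0, 1]
print(result["start"], result["end"])  # span offsets into `context`
print(f"{elapsed:.4f} s")              # wall-clock inference time

Note that the result dict also carries `start`/`end` character offsets into the context, which `answer_question` currently discards.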
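Relatedly, `highlight_answer` re-locates the answer by lowercase substring search, which can mark the wrong occurrence if the answer text also appears earlier in the context. A hypothetical offset-based variant (`highlight_answer_by_offset` is not part of this commit) that slices on the pipeline's `start`/`end` instead:

def highlight_answer_by_offset(context: str, start: int, end: int) -> str:
    """Highlight the answer span located by the pipeline's character offsets."""
    # Guard against missing context or out-of-range offsets.
    if not context or not (0 <= start < end <= len(context)):
        return context
    return (
        context[:start]
        + f'<span style="background-color: #ffdd99; font-weight: bold;">{context[start:end]}</span>'
        + context[end:]
    )

To use it, `answer_question` would also need to return `result["start"]` and `result["end"]` alongside the answer and score.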