maabedmohammed committed
Commit d3bbe45 · verified · 1 Parent(s): e5824a7

Update src/streamlit_app.py

Files changed (1):
  1. src/streamlit_app.py +221 -35

src/streamlit_app.py CHANGED
@@ -1,40 +1,226 @@
-import altair as alt
-import numpy as np
-import pandas as pd
 import streamlit as st
 
-"""
-# Welcome to Streamlit!
 
-Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
-If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
-forums](https://discuss.streamlit.io).
 
-In the meantime, below is an example of what you can do with just a few lines of code:
-"""
 
-num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
-num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
-indices = np.linspace(0, 1, num_points)
-theta = 2 * np.pi * num_turns * indices
-radius = indices
-
-x = radius * np.cos(theta)
-y = radius * np.sin(theta)
-
-df = pd.DataFrame({
-    "x": x,
-    "y": y,
-    "idx": indices,
-    "rand": np.random.randn(num_points),
-})
-
-st.altair_chart(alt.Chart(df, height=700, width=700)
-    .mark_point(filled=True)
-    .encode(
-        x=alt.X("x", axis=None),
-        y=alt.Y("y", axis=None),
-        color=alt.Color("idx", legend=None, scale=alt.Scale()),
-        size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-    ))
+#!/usr/bin/env python
+"""
+Streamlit application for a question-answering system.
+Optimized for deployment on Hugging Face Spaces.
+"""
+
 import streamlit as st
+import os
+import time
+import torch
+import pandas as pd
+import matplotlib.pyplot as plt
+from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
+import json
 
+# Page configuration
+st.set_page_config(
+    page_title="Question Answering System",
+    page_icon="❓",
+    layout="wide"
+)
 
+# Constants
+MODELS = {
+    "ELECTRA-small": "mrm8488/electra-small-finetuned-squadv1",
+    "ALBERT-base-v2": "twmkn9/albert-base-v2-squad2",
+    "DistilBERT-base": "distilbert-base-cased-distilled-squad"
+}
 
+# Cache for loaded models
+@st.cache_resource
+def load_model(model_name):
+    """Load a QA pipeline for the given model, cached across Streamlit reruns."""
+    try:
+        model_path = MODELS[model_name]
+        qa_pipeline = pipeline("question-answering", model=model_path)
+        return qa_pipeline
+    except Exception as e:
+        st.error(f"Error loading model {model_name}: {e}")
+        return None
+
+def answer_question(qa_pipeline, question, context):
+    """Answer a question about a context and report the inference time."""
+    if not question or not context:
+        return None, 0, 0
+
+    # Measure inference time
+    start_time = time.time()
+
+    # Run model
+    result = qa_pipeline(question=question, context=context)
+
+    # Calculate inference time
+    inference_time = time.time() - start_time
+
+    return result["answer"], result["score"], inference_time
+
+def highlight_answer(context, answer):
+    """Highlight the answer in the context with HTML."""
+    if not answer or not context:
+        return context
+
+    # Find the answer in the context (case-insensitive)
+    lower_context = context.lower()
+    lower_answer = answer.lower()
+
+    if lower_answer in lower_context:
+        start_idx = lower_context.find(lower_answer)
+        end_idx = start_idx + len(lower_answer)
+
+        # Create HTML with the answer span highlighted
+        highlighted = (
+            context[:start_idx] +
+            f'<span style="background-color: #ffdd99; font-weight: bold;">{context[start_idx:end_idx]}</span>' +
+            context[end_idx:]
+        )
+        return highlighted
+
+    return context
+
+def generate_comparison_chart(results_df):
+    """Generate a comparison chart for model results."""
+    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
+
+    # Sort models by score
+    results_df = results_df.sort_values('score', ascending=False)
+
+    # Plot confidence scores
+    models = results_df['model_name']
+    scores = results_df['score']
+    ax1.barh(models, scores, color='skyblue')
+    ax1.set_xlabel('Confidence Score')
+    ax1.set_title('Model Confidence Scores')
+    ax1.grid(axis='x', linestyle='--', alpha=0.7)
+
+    # Plot inference times
+    inference_times = results_df['inference_time'].astype(float)
+    ax2.barh(models, inference_times, color='salmon')
+    ax2.set_xlabel('Inference Time (seconds)')
+    ax2.set_title('Model Inference Times')
+    ax2.grid(axis='x', linestyle='--', alpha=0.7)
+
+    plt.tight_layout()
+    return fig
+
+def main():
+    # Title and description
+    st.title("Question Answering System")
+    st.markdown("""
+    This application answers questions about a provided context using transformer models
+    fine-tuned on the SQuAD dataset. Enter a context paragraph and ask questions about it.
+    """)
+
+    # Initialize session state for storing results
+    if 'comparison_results' not in st.session_state:
+        st.session_state.comparison_results = None
+
+    # Layout
+    col1, col2 = st.columns([3, 1])
+
+    with col1:
+        # Context input
+        context = st.text_area(
+            "Context",
+            "The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse (\"Norman\" comes from \"Norseman\") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries.",
+            height=200
+        )
+
+        # Question input
+        question = st.text_input("Question", "In what country is Normandy located?")
+
+        # Add a separator
+        st.markdown("---")
+
+        # Results section
+        st.subheader("Results")
+
+        if st.button("Compare All Models"):
+            progress_bar = st.progress(0)
+            status_text = st.empty()
+            results = []
+
+            # Process each model
+            for i, model_name in enumerate(MODELS.keys()):
+                status_text.text(f"Processing with {model_name}...")
+
+                # Load model
+                qa_pipeline = load_model(model_name)
+                if qa_pipeline is not None:
+                    # Get answer
+                    answer, score, inference_time = answer_question(qa_pipeline, question, context)
+
+                    # Store results
+                    results.append({
+                        "model_name": model_name,
+                        "answer": answer,
+                        "score": score,
+                        "inference_time": inference_time
+                    })
+
+                # Update progress
+                progress_bar.progress((i + 1) / len(MODELS))
+
+            # Display results in a table
+            if results:
+                results_df = pd.DataFrame(results)
+                display_df = results_df.copy()
+                display_df["inference_time"] = display_df["inference_time"].apply(lambda x: f"{x:.4f} s")
+                display_df["score"] = display_df["score"].apply(lambda x: f"{x:.4f}")
+                st.table(display_df)
+
+                # Save results to session state for the comparison chart
+                st.session_state.comparison_results = results_df
+
+                # Show comparison chart
+                st.subheader("Model Comparison")
+                comparison_chart = generate_comparison_chart(results_df)
+                st.pyplot(comparison_chart)
+
+    with col2:
+        # Model selection
+        st.subheader("Available Models")
+
+        selected_model = st.selectbox(
+            "Select a model",
+            list(MODELS.keys()),
+            key="model_selector"
+        )
+
+        # Load selected model and answer
+        if st.button("Answer Question"):
+            with st.spinner(f"Loading {selected_model}..."):
+                qa_pipeline = load_model(selected_model)
+
+            if qa_pipeline is not None:
+                with st.spinner("Generating answer..."):
+                    answer, score, inference_time = answer_question(qa_pipeline, question, context)
+
+                st.success("Answer generated!")
+                st.markdown(f"**Model:** {selected_model}")
+                st.markdown(f"**Answer:** {answer}")
+                st.markdown(f"**Confidence:** {score:.4f}")
+                st.markdown(f"**Inference Time:** {inference_time:.4f} seconds")
+
+                # Highlight answer in context
+                st.subheader("Answer in Context")
+                highlighted_context = highlight_answer(context, answer)
+                st.markdown(highlighted_context, unsafe_allow_html=True)
+
+        # Model information
+        with st.expander("Model Information"):
+            st.markdown("""
+            **ELECTRA-small**
+            A small, efficient model with a good balance of accuracy and speed.
+
+            **ALBERT-base-v2**
+            A parameter-efficient model with strong performance.
+
+            **DistilBERT-base**
+            A distilled BERT model that is faster while retaining most of BERT's accuracy.
+            """)
 
+if __name__ == "__main__":
+    main()
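
To sanity-check the three checkpoints without launching the Streamlit UI, the core calls can be exercised directly. Below is a minimal sketch, not part of this commit, assuming `transformers` and `torch` are installed and the checkpoints can be fetched from the Hugging Face Hub; it runs the same `pipeline("question-answering", ...)` call that `load_model` wraps and the same timing that `answer_question` reports:

```python
#!/usr/bin/env python
"""Standalone smoke test for the QA checkpoints used by the app.

A sketch, not part of the commit: assumes network access to the
Hugging Face Hub and `pip install transformers torch`. Model names
are copied from the MODELS dict above.
"""
import time

from transformers import pipeline

MODELS = {
    "ELECTRA-small": "mrm8488/electra-small-finetuned-squadv1",
    "ALBERT-base-v2": "twmkn9/albert-base-v2-squad2",
    "DistilBERT-base": "distilbert-base-cased-distilled-squad",
}

context = (
    "The Normans were the people who in the 10th and 11th centuries "
    "gave their name to Normandy, a region in France."
)
question = "In what country is Normandy located?"

for name, checkpoint in MODELS.items():
    # Same call that load_model() wraps with @st.cache_resource
    qa = pipeline("question-answering", model=checkpoint)

    # Same timing that answer_question() reports
    start = time.time()
    result = qa(question=question, context=context)
    elapsed = time.time() - start

    print(f"{name}: {result['answer']!r} "
          f"(score={result['score']:.4f}, {elapsed:.2f}s)")
```

Each output line gives the extracted answer, its confidence score, and wall-clock latency; exact numbers vary with hardware and `transformers` version.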
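
One caveat in `highlight_answer`: it splices the raw context into HTML, and the app renders the result with `unsafe_allow_html=True`, so a context containing `<` or `&` can break the markup. A sketch of an escaped variant, with the helper name `highlight_answer_escaped` being hypothetical and not part of the commit:

```python
import html

def highlight_answer_escaped(context: str, answer: str) -> str:
    """Like highlight_answer above, but HTML-escapes the context first.

    A sketch, not part of the commit: escaping prevents stray '<' or '&'
    in user-provided context from breaking the rendered markup.
    """
    if not answer or not context:
        return html.escape(context)

    # Case-insensitive match, mirroring the committed logic
    start = context.lower().find(answer.lower())
    if start == -1:
        return html.escape(context)
    end = start + len(answer)

    return (
        html.escape(context[:start])
        + '<span style="background-color: #ffdd99; font-weight: bold;">'
        + html.escape(context[start:end])
        + "</span>"
        + html.escape(context[end:])
    )
```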