Create streamlit_app.py
streamlit_app.py  ADDED  +648 -0
@@ -0,0 +1,648 @@
import streamlit as st
import pandas as pd
import json
import os
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
from pathlib import Path
import glob
import requests
from io import StringIO
import zipfile
import tempfile
import shutil

# Set page config
st.set_page_config(
    page_title="Attention Analysis Results Explorer",
    page_icon="🔍",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Custom CSS for better styling
st.markdown("""
<style>
    .main-header {
        font-size: 2.5rem;
        font-weight: bold;
        color: #1f77b4;
        text-align: center;
        margin-bottom: 2rem;
    }
    .section-header {
        font-size: 1.5rem;
        font-weight: bold;
        color: #ff7f0e;
        margin-top: 2rem;
        margin-bottom: 1rem;
    }
    .metric-container {
        background-color: #f0f2f6;
        padding: 1rem;
        border-radius: 0.5rem;
        margin: 0.5rem 0;
    }
    .stSelectbox > div > div {
        background-color: white;
    }
</style>
""", unsafe_allow_html=True)

class AttentionResultsExplorer:
    def __init__(self, github_repo="ACMCMC/attention", use_cache=True):
        self.github_repo = github_repo
        self.use_cache = use_cache
        self.cache_dir = Path(tempfile.gettempdir()) / "attention_results_cache"
        self.base_path = self.cache_dir

        # Initialize cache directory
        if not self.cache_dir.exists():
            self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Download and cache data if needed
        if not self._cache_exists() or not use_cache:
            self._download_repository()

        self.languages = self._get_available_languages()
        self.relation_types = None

    def _cache_exists(self):
        """Check if cached data exists"""
        return (self.cache_dir / "results_en").exists()

    def _download_repository(self):
        """Download repository data from GitHub"""
        st.info("🔄 Downloading results data from GitHub... This may take a moment.")

        # GitHub API to get the repository contents
        api_url = f"https://api.github.com/repos/{self.github_repo}/contents"

        try:
            # Get list of result directories
            response = requests.get(api_url)
            response.raise_for_status()
            contents = response.json()

            result_dirs = [item['name'] for item in contents
                           if item['type'] == 'dir' and item['name'].startswith('results_')]

            st.write(f"Found {len(result_dirs)} result directories: {', '.join(result_dirs)}")

            # Download each result directory
            progress_bar = st.progress(0)
            for i, result_dir in enumerate(result_dirs):
                st.write(f"Downloading {result_dir}...")
                self._download_directory(result_dir)
                progress_bar.progress((i + 1) / len(result_dirs))

            st.success("✅ Download completed!")

        except Exception as e:
            st.error(f"❌ Error downloading repository: {str(e)}")
            st.error("Please check the repository URL and your internet connection.")
            raise

    def _download_directory(self, dir_name, path=""):
        """Recursively download a directory from GitHub"""
        url = f"https://api.github.com/repos/{self.github_repo}/contents/{path}{dir_name}"

        try:
            response = requests.get(url)
            response.raise_for_status()
            contents = response.json()

            local_dir = self.cache_dir / path / dir_name
            local_dir.mkdir(parents=True, exist_ok=True)

            for item in contents:
                if item['type'] == 'file':
                    self._download_file(item, local_dir)
                elif item['type'] == 'dir':
                    self._download_directory(item['name'], f"{path}{dir_name}/")

        except Exception as e:
            st.warning(f"Could not download {dir_name}: {str(e)}")

    def _download_file(self, file_info, local_dir):
        """Download a single file from GitHub"""
        try:
            # Download file content
            response = requests.get(file_info['download_url'])
            response.raise_for_status()

            # Save to local cache
            local_file = local_dir / file_info['name']

            # Handle different file types
            if file_info['name'].endswith(('.csv', '.json')):
                with open(local_file, 'w', encoding='utf-8') as f:
                    f.write(response.text)
            else:  # Binary files like PDFs
                with open(local_file, 'wb') as f:
                    f.write(response.content)

        except Exception as e:
            st.warning(f"Could not download file {file_info['name']}: {str(e)}")

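    # Note: the downloader above relies on only a few fields of the GitHub
    # "contents" API response (GET /repos/{owner}/{repo}/contents/{path}):
    # each item is a JSON object with at least 'name', 'type' ('file' or 'dir')
    # and, for files, 'download_url'. Unauthenticated requests are subject to
    # GitHub's rate limits, so repeated refreshes may fail with HTTP 403.
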
    def _get_available_languages(self):
        """Get all available language directories"""
        if not self.base_path.exists():
            return []
        result_dirs = [d.name for d in self.base_path.iterdir()
                       if d.is_dir() and d.name.startswith("results_")]
        languages = [d.replace("results_", "") for d in result_dirs]
        return sorted(languages)

    def _get_experimental_configs(self, language):
        """Get all experimental configurations for a language"""
        lang_dir = self.base_path / f"results_{language}"
        if not lang_dir.exists():
            return []
        configs = [d.name for d in lang_dir.iterdir() if d.is_dir()]
        return sorted(configs)

    def _get_models(self, language, config):
        """Get all models for a language and configuration"""
        config_dir = self.base_path / f"results_{language}" / config
        if not config_dir.exists():
            return []
        models = [d.name for d in config_dir.iterdir() if d.is_dir()]
        return sorted(models)

    def _parse_config_name(self, config_name):
        """Parse configuration name into readable format"""
        parts = config_name.split('+')
        config_dict = {}
        for part in parts:
            if '_' in part:
                key, value = part.split('_', 1)
                config_dict[key.replace('_', ' ').title()] = value
        return config_dict

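    # Illustration only (hypothetical directory name): a config folder called
    # "min_words_5+accept_bidirectional_true" would be parsed by
    # _parse_config_name into {"Min": "words_5", "Accept": "bidirectional_true"},
    # i.e. each '+'-separated part is split once, on its first underscore.
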
    def _load_metadata(self, language, config, model):
        """Load metadata for a specific combination"""
        metadata_path = self.base_path / f"results_{language}" / config / model / "metadata" / "metadata.json"
        if metadata_path.exists():
            with open(metadata_path, 'r') as f:
                return json.load(f)
        return None

    def _load_uas_scores(self, language, config, model):
        """Load UAS scores data"""
        uas_dir = self.base_path / f"results_{language}" / config / model / "uas_scores"
        if not uas_dir.exists():
            return {}

        uas_data = {}
        csv_files = list(uas_dir.glob("uas_*.csv"))

        if csv_files:
            progress_bar = st.progress(0)
            status_text = st.empty()

            for i, csv_file in enumerate(csv_files):
                relation = csv_file.stem.replace("uas_", "")
                status_text.text(f"Loading UAS data: {relation}")

                try:
                    df = pd.read_csv(csv_file, index_col=0)
                    uas_data[relation] = df
                except Exception as e:
                    st.warning(f"Could not load {csv_file.name}: {e}")

                progress_bar.progress((i + 1) / len(csv_files))

            progress_bar.empty()
            status_text.empty()

        return uas_data

    def _load_head_matching(self, language, config, model):
        """Load head matching data"""
        heads_dir = self.base_path / f"results_{language}" / config / model / "number_of_heads_matching"
        if not heads_dir.exists():
            return {}

        heads_data = {}
        csv_files = list(heads_dir.glob("heads_matching_*.csv"))

        if csv_files:
            progress_bar = st.progress(0)
            status_text = st.empty()

            for i, csv_file in enumerate(csv_files):
                relation = csv_file.stem.replace("heads_matching_", "").replace(f"_{model}", "")
                status_text.text(f"Loading head matching data: {relation}")

                try:
                    df = pd.read_csv(csv_file, index_col=0)
                    heads_data[relation] = df
                except Exception as e:
                    st.warning(f"Could not load {csv_file.name}: {e}")

                progress_bar.progress((i + 1) / len(csv_files))

            progress_bar.empty()
            status_text.empty()

        return heads_data

    def _load_variability(self, language, config, model):
        """Load variability data"""
        var_path = self.base_path / f"results_{language}" / config / model / "variability" / "variability_list.csv"
        if var_path.exists():
            try:
                return pd.read_csv(var_path, index_col=0)
            except Exception as e:
                st.warning(f"Could not load variability data: {e}")
        return None

    def _get_available_figures(self, language, config, model):
        """Get all available figure files"""
        figures_dir = self.base_path / f"results_{language}" / config / model / "figures"
        if not figures_dir.exists():
            return []
        return list(figures_dir.glob("*.pdf"))

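# Expected cache layout, as implied by the loaders above (derived from the
# code paths, not from separate documentation):
#   results_{language}/{config}/{model}/metadata/metadata.json
#   results_{language}/{config}/{model}/uas_scores/uas_{relation}.csv
#   results_{language}/{config}/{model}/number_of_heads_matching/heads_matching_{relation}_{model}.csv
#   results_{language}/{config}/{model}/variability/variability_list.csv
#   results_{language}/{config}/{model}/figures/*.pdf
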
def main():
    # Title
    st.markdown('<div class="main-header">🔍 Attention Analysis Results Explorer</div>', unsafe_allow_html=True)

    # Sidebar for navigation
    st.sidebar.title("🔧 Configuration")

    # Cache management section
    st.sidebar.markdown("### 📁 Data Management")

    # Initialize explorer
    use_cache = st.sidebar.checkbox("Use cached data", value=True,
                                    help="Use previously downloaded data if available")

    if st.sidebar.button("🔄 Refresh Data", help="Download fresh data from GitHub"):
        # Clear cache and re-download
        cache_dir = Path(tempfile.gettempdir()) / "attention_results_cache"
        if cache_dir.exists():
            shutil.rmtree(cache_dir)
        st.rerun()

    # Show cache status
    cache_dir = Path(tempfile.gettempdir()) / "attention_results_cache"
    if cache_dir.exists():
        st.sidebar.success("✅ Data cached locally")
    else:
        st.sidebar.info("📥 Will download data from GitHub")

    st.sidebar.markdown("---")

    # Initialize explorer with error handling
    try:
        explorer = AttentionResultsExplorer(use_cache=use_cache)
    except Exception as e:
        st.error(f"❌ Failed to initialize data explorer: {str(e)}")
        st.error("Please check your internet connection and try again.")
        return

    # Check if any languages are available
    if not explorer.languages:
        st.error("❌ No result data found. Please check the GitHub repository.")
        return

    # Language selection
    selected_language = st.sidebar.selectbox(
        "Select Language",
        options=explorer.languages,
        help="Choose the language dataset to explore"
    )

    # Get configurations for selected language
    configs = explorer._get_experimental_configs(selected_language)
    if not configs:
        st.error(f"No configurations found for language: {selected_language}")
        return

    # Configuration selection
    selected_config = st.sidebar.selectbox(
        "Select Experimental Configuration",
        options=configs,
        help="Choose the experimental configuration"
    )

    # Parse and display configuration details
    config_details = explorer._parse_config_name(selected_config)
    st.sidebar.markdown("**Configuration Details:**")
    for key, value in config_details.items():
        st.sidebar.markdown(f"- **{key}**: {value}")

    # Get models for selected language and config
    models = explorer._get_models(selected_language, selected_config)
    if not models:
        st.error(f"No models found for {selected_language}/{selected_config}")
        return

    # Model selection
    selected_model = st.sidebar.selectbox(
        "Select Model",
        options=models,
        help="Choose the model to analyze"
    )

    # Main content area
    tab1, tab2, tab3, tab4, tab5 = st.tabs([
        "📊 Overview",
        "🎯 UAS Scores",
        "🧠 Head Matching",
        "📈 Variability",
        "🖼️ Figures"
    ])

    # Tab 1: Overview
    with tab1:
        st.markdown('<div class="section-header">Experiment Overview</div>', unsafe_allow_html=True)

        # Load metadata
        metadata = explorer._load_metadata(selected_language, selected_config, selected_model)
        if metadata:
            col1, col2, col3, col4 = st.columns(4)
            with col1:
                st.metric("Total Samples", metadata.get('total_number', 'N/A'))
            with col2:
                st.metric("Processed Correctly", metadata.get('number_processed_correctly', 'N/A'))
            with col3:
                st.metric("Errors", metadata.get('number_errored', 'N/A'))
            with col4:
                success_rate = (metadata.get('number_processed_correctly', 0) /
                                metadata.get('total_number', 1)) * 100 if metadata.get('total_number') else 0
                st.metric("Success Rate", f"{success_rate:.1f}%")

            st.markdown(f"**Random Seed:** {metadata.get('random_seed', 'N/A')}")

            if metadata.get('errored_phrases'):
                st.markdown("**Errored Phrase IDs:**")
                st.write(metadata['errored_phrases'])
        else:
            st.warning("No metadata available for this configuration.")

        # Quick stats about available data
        st.markdown('<div class="section-header">Available Data</div>', unsafe_allow_html=True)

        uas_data = explorer._load_uas_scores(selected_language, selected_config, selected_model)
        heads_data = explorer._load_head_matching(selected_language, selected_config, selected_model)
        variability_data = explorer._load_variability(selected_language, selected_config, selected_model)
        figures = explorer._get_available_figures(selected_language, selected_config, selected_model)

        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.metric("UAS Relations", len(uas_data))
        with col2:
            st.metric("Head Matching Relations", len(heads_data))
        with col3:
            st.metric("Variability Data", "✓" if variability_data is not None else "✗")
        with col4:
            st.metric("Figure Files", len(figures))

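    # Tabs 2-4 below all work on the same shape of data: a Layer × Head matrix
    # read with pd.read_csv(..., index_col=0) and rendered with px.imshow, so
    # rows are layers and columns are attention heads.
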
    # Tab 2: UAS Scores
    with tab2:
        st.markdown('<div class="section-header">UAS (Unlabeled Attachment Score) Analysis</div>', unsafe_allow_html=True)

        uas_data = explorer._load_uas_scores(selected_language, selected_config, selected_model)

        if uas_data:
            # Relation selection
            selected_relation = st.selectbox(
                "Select Dependency Relation",
                options=list(uas_data.keys()),
                help="Choose a dependency relation to visualize UAS scores"
            )

            if selected_relation and selected_relation in uas_data:
                df = uas_data[selected_relation]

                # Display the data table
                st.markdown("**UAS Scores Matrix (Layer × Head)**")
                st.dataframe(df, use_container_width=True)

                # Create heatmap
                fig = px.imshow(
                    df.values,
                    x=[f"Head {i}" for i in df.columns],
                    y=[f"Layer {i}" for i in df.index],
                    color_continuous_scale="Viridis",
                    title=f"UAS Scores Heatmap - {selected_relation}",
                    labels=dict(color="UAS Score")
                )
                fig.update_layout(height=600)
                st.plotly_chart(fig, use_container_width=True)

                # Statistics
                st.markdown("**Statistics**")
                col1, col2, col3, col4 = st.columns(4)
                with col1:
                    st.metric("Max Score", f"{df.values.max():.4f}")
                with col2:
                    st.metric("Min Score", f"{df.values.min():.4f}")
                with col3:
                    st.metric("Mean Score", f"{df.values.mean():.4f}")
                with col4:
                    st.metric("Std Dev", f"{df.values.std():.4f}")
        else:
            st.warning("No UAS score data available for this configuration.")

    # Tab 3: Head Matching
    with tab3:
        st.markdown('<div class="section-header">Attention Head Matching Analysis</div>', unsafe_allow_html=True)

        heads_data = explorer._load_head_matching(selected_language, selected_config, selected_model)

        if heads_data:
            # Relation selection
            selected_relation = st.selectbox(
                "Select Dependency Relation",
                options=list(heads_data.keys()),
                help="Choose a dependency relation to visualize head matching patterns",
                key="heads_relation"
            )

            if selected_relation and selected_relation in heads_data:
                df = heads_data[selected_relation]

                # Display the data table
                st.markdown("**Head Matching Counts Matrix (Layer × Head)**")
                st.dataframe(df, use_container_width=True)

                # Create heatmap
                fig = px.imshow(
                    df.values,
                    x=[f"Head {i}" for i in df.columns],
                    y=[f"Layer {i}" for i in df.index],
                    color_continuous_scale="Blues",
                    title=f"Head Matching Counts - {selected_relation}",
                    labels=dict(color="Match Count")
                )
                fig.update_layout(height=600)
                st.plotly_chart(fig, use_container_width=True)

                # Create bar chart of total matches per layer
                layer_totals = df.sum(axis=1)
                fig_bar = px.bar(
                    x=layer_totals.index,
                    y=layer_totals.values,
                    title=f"Total Matches per Layer - {selected_relation}",
                    labels={"x": "Layer", "y": "Total Matches"}
                )
                fig_bar.update_layout(height=400)
                st.plotly_chart(fig_bar, use_container_width=True)

                # Statistics
                st.markdown("**Statistics**")
                col1, col2, col3, col4 = st.columns(4)
                with col1:
                    st.metric("Total Matches", int(df.values.sum()))
                with col2:
                    st.metric("Max per Cell", int(df.values.max()))
                with col3:
                    best_layer = layer_totals.idxmax()
                    st.metric("Best Layer", f"Layer {best_layer}")
                with col4:
                    best_head_idx = np.unravel_index(df.values.argmax(), df.values.shape)
                    st.metric("Best Head", f"L{best_head_idx[0]}-H{best_head_idx[1]}")
        else:
            st.warning("No head matching data available for this configuration.")

    # Tab 4: Variability
    with tab4:
        st.markdown('<div class="section-header">Attention Variability Analysis</div>', unsafe_allow_html=True)

        variability_data = explorer._load_variability(selected_language, selected_config, selected_model)

        if variability_data is not None:
            # Display the data table
            st.markdown("**Variability Matrix (Layer × Head)**")
            st.dataframe(variability_data, use_container_width=True)

            # Create heatmap
            fig = px.imshow(
                variability_data.values,
                x=[f"Head {i}" for i in variability_data.columns],
                y=[f"Layer {i}" for i in variability_data.index],
                color_continuous_scale="Reds",
                title="Attention Variability Heatmap",
                labels=dict(color="Variability Score")
            )
            fig.update_layout(height=600)
            st.plotly_chart(fig, use_container_width=True)

            # Create line plot for variability trends
            fig_line = go.Figure()
            for col in variability_data.columns:
                fig_line.add_trace(go.Scatter(
                    x=variability_data.index,
                    y=variability_data[col],
                    mode='lines+markers',
                    name=f'Head {col}',
                    line=dict(width=2)
                ))

            fig_line.update_layout(
                title="Variability Trends Across Layers",
                xaxis_title="Layer",
                yaxis_title="Variability Score",
                height=500
            )
            st.plotly_chart(fig_line, use_container_width=True)

            # Statistics
            st.markdown("**Statistics**")
            col1, col2, col3, col4 = st.columns(4)
            with col1:
                st.metric("Max Variability", f"{variability_data.values.max():.4f}")
            with col2:
                st.metric("Min Variability", f"{variability_data.values.min():.4f}")
            with col3:
                st.metric("Mean Variability", f"{variability_data.values.mean():.4f}")
            with col4:
                most_variable_idx = np.unravel_index(variability_data.values.argmax(), variability_data.values.shape)
                st.metric("Most Variable", f"L{most_variable_idx[0]}-H{most_variable_idx[1]}")
        else:
            st.warning("No variability data available for this configuration.")

    # Tab 5: Figures
    with tab5:
        st.markdown('<div class="section-header">Generated Figures</div>', unsafe_allow_html=True)

        figures = explorer._get_available_figures(selected_language, selected_config, selected_model)

        if figures:
            st.markdown(f"**Available Figures: {len(figures)}**")

            # Group figures by relation type
            figure_groups = {}
            for fig_path in figures:
                # Extract relation from filename
                filename = fig_path.stem
                relation = filename.replace("heads_matching_", "").replace(f"_{selected_model}", "")
                if relation not in figure_groups:
                    figure_groups[relation] = []
                figure_groups[relation].append(fig_path)

            # Select relation to view
            selected_fig_relation = st.selectbox(
                "Select Relation for Figure View",
                options=list(figure_groups.keys()),
                help="Choose a dependency relation to view its figure"
            )

            if selected_fig_relation and selected_fig_relation in figure_groups:
                fig_path = figure_groups[selected_fig_relation][0]

                st.markdown(f"**Figure: {fig_path.name}**")
                st.markdown(f"**Path:** `{fig_path}`")

                # Note about PDF viewing
                st.info(
                    "📄 PDF figures are available in the results directory. "
                    "Due to Streamlit limitations, PDF files cannot be displayed directly in the browser. "
                    "You can download or view them locally."
                )

                # Provide download link
                try:
                    with open(fig_path, "rb") as file:
                        st.download_button(
                            label=f"📥 Download {fig_path.name}",
                            data=file.read(),
                            file_name=fig_path.name,
                            mime="application/pdf"
                        )
                except Exception as e:
                    st.error(f"Could not load figure: {e}")

            # List all available figures
            st.markdown("**All Available Figures:**")
            for relation, paths in figure_groups.items():
                with st.expander(f"📊 {relation} ({len(paths)} files)"):
                    for path in paths:
                        st.markdown(f"- `{path.name}`")
        else:
            st.warning("No figures available for this configuration.")

    # Footer
    st.markdown("---")

    # Data source information
    col1, col2 = st.columns([2, 1])
    with col1:
        st.markdown(
            "🔬 **Attention Analysis Results Explorer** | "
            f"Currently viewing: {selected_language.upper()} - {selected_model} | "
            "Built with Streamlit"
        )
    with col2:
        st.markdown(
            f"📊 **Data Source**: [GitHub Repository](https://github.com/{explorer.github_repo})"
        )

if __name__ == "__main__":
    main()
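For readers who want to poke at the data format outside the app, below is a minimal, self-contained sketch of the Layer × Head matrix convention that the UAS, head-matching and variability tabs visualize. Everything in it is synthetic and illustrative: the file name, the 12×12 shape and the random values are placeholders, not results from the repository. It only mirrors the round trip the app performs, a CSV whose first column is the layer index, loaded with pandas' index_col=0 and rendered with plotly's px.imshow. The app itself is started the usual Streamlit way, e.g. streamlit run streamlit_app.py, with streamlit, pandas, numpy, plotly and requests installed.

# layer_head_matrix_demo.py (illustrative only, not part of streamlit_app.py)
import numpy as np
import pandas as pd
import plotly.express as px

# Synthetic stand-in for one of the uas_<relation>.csv files: rows are layers,
# columns are attention heads, values are made-up scores in [0, 1].
rng = np.random.default_rng(0)
scores = pd.DataFrame(rng.uniform(0.0, 1.0, size=(12, 12)),
                      index=range(12), columns=range(12))

# Round-trip through CSV the same way the app's loaders do (index_col=0).
scores.to_csv("uas_example.csv")
df = pd.read_csv("uas_example.csv", index_col=0)

# Heatmap in the same style as the UAS tab.
fig = px.imshow(
    df.values,
    x=[f"Head {i}" for i in df.columns],
    y=[f"Layer {i}" for i in df.index],
    color_continuous_scale="Viridis",
    labels=dict(color="UAS Score"),
    title="Synthetic Layer x Head matrix",
)
fig.show()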