acmc commited on
Commit
6f92421
·
verified ·
1 Parent(s): 2405ca0

Create streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +648 -0
streamlit_app.py ADDED
@@ -0,0 +1,648 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import json
4
+ import os
5
+ import plotly.express as px
6
+ import plotly.graph_objects as go
7
+ from plotly.subplots import make_subplots
8
+ import numpy as np
9
+ from pathlib import Path
10
+ import glob
11
+ import requests
12
+ from io import StringIO
13
+ import zipfile
14
+ import tempfile
15
+ import shutil
16
+
17
# Streamlit page configuration — must run before any other st.* call.
_PAGE_CONFIG = dict(
    page_title="Attention Analysis Results Explorer",
    page_icon="🔍",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.set_page_config(**_PAGE_CONFIG)

# Custom CSS injected once so headers/metrics render consistently.
_CUSTOM_CSS = """
<style>
    .main-header {
        font-size: 2.5rem;
        font-weight: bold;
        color: #1f77b4;
        text-align: center;
        margin-bottom: 2rem;
    }
    .section-header {
        font-size: 1.5rem;
        font-weight: bold;
        color: #ff7f0e;
        margin-top: 2rem;
        margin-bottom: 1rem;
    }
    .metric-container {
        background-color: #f0f2f6;
        padding: 1rem;
        border-radius: 0.5rem;
        margin: 0.5rem 0;
    }
    .stSelectbox > div > div {
        background-color: white;
    }
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
53
+
54
class AttentionResultsExplorer:
    """Discover, download, and load attention-analysis result artefacts.

    Results live in a GitHub repository under
    ``results_<lang>/<config>/<model>/`` and are mirrored into a cache
    directory under the system temp dir so the app does not re-download
    them on every run.
    """

    # Seconds to wait on each GitHub HTTP request before failing.
    # Fixed: the original requests.get() calls had no timeout and could
    # hang the app indefinitely on a stalled connection.
    REQUEST_TIMEOUT = 30

    def __init__(self, github_repo="ACMCMC/attention", use_cache=True):
        """Set up cache paths and (if needed) download the result data.

        Args:
            github_repo: ``"<owner>/<repo>"`` slug of the repository to mirror.
            use_cache: When True, reuse previously downloaded data if present.
        """
        self.github_repo = github_repo
        self.use_cache = use_cache
        self.cache_dir = Path(tempfile.gettempdir()) / "attention_results_cache"
        self.base_path = self.cache_dir

        # Initialize cache directory (idempotent).
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Download and cache data if nothing is cached or caching is disabled.
        if not self._cache_exists() or not use_cache:
            self._download_repository()

        self.languages = self._get_available_languages()
        self.relation_types = None  # kept for API compatibility; unused here

    def _cache_exists(self):
        """Return True if any results directory is already cached locally."""
        # Fixed: previously hard-coded "results_en", which reported a cache
        # miss (and forced a full re-download) for repos without an English
        # results directory.
        return any(p.is_dir() for p in self.cache_dir.glob("results_*"))

    def _download_repository(self):
        """Download all ``results_*`` directories from GitHub into the cache.

        Raises:
            Exception: re-raised after surfacing the error in the UI, so the
                caller can abort initialization.
        """
        st.info("🔄 Downloading results data from GitHub... This may take a moment.")

        # GitHub contents API for the repository root.
        api_url = f"https://api.github.com/repos/{self.github_repo}/contents"

        try:
            # Get list of result directories.
            response = requests.get(api_url, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            contents = response.json()

            result_dirs = [item['name'] for item in contents
                           if item['type'] == 'dir' and item['name'].startswith('results_')]

            st.write(f"Found {len(result_dirs)} result directories: {', '.join(result_dirs)}")

            # Download each result directory with a visible progress bar.
            progress_bar = st.progress(0)
            for i, result_dir in enumerate(result_dirs):
                st.write(f"Downloading {result_dir}...")
                self._download_directory(result_dir)
                progress_bar.progress((i + 1) / len(result_dirs))

            st.success("✅ Download completed!")

        except Exception as e:
            st.error(f"❌ Error downloading repository: {str(e)}")
            st.error("Please check the repository URL and your internet connection.")
            raise

    def _download_directory(self, dir_name, path=""):
        """Recursively mirror one repository directory into the cache.

        Args:
            dir_name: Name of the directory to fetch.
            path: Repository-relative path prefix (ends with ``/`` when set).
        """
        url = f"https://api.github.com/repos/{self.github_repo}/contents/{path}{dir_name}"

        try:
            response = requests.get(url, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            contents = response.json()

            local_dir = self.cache_dir / path / dir_name
            local_dir.mkdir(parents=True, exist_ok=True)

            for item in contents:
                if item['type'] == 'file':
                    self._download_file(item, local_dir)
                elif item['type'] == 'dir':
                    self._download_directory(item['name'], f"{path}{dir_name}/")

        except Exception as e:
            # Best-effort: one unreadable directory should not abort the
            # whole mirror operation.
            st.warning(f"Could not download {dir_name}: {str(e)}")

    def _download_file(self, file_info, local_dir):
        """Download a single file described by a GitHub contents-API entry."""
        try:
            response = requests.get(file_info['download_url'], timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()

            local_file = local_dir / file_info['name']

            # Text artefacts are written as UTF-8 text; everything else
            # (e.g. PDF figures) is written verbatim as bytes.
            if file_info['name'].endswith(('.csv', '.json')):
                with open(local_file, 'w', encoding='utf-8') as f:
                    f.write(response.text)
            else:  # Binary files like PDFs
                with open(local_file, 'wb') as f:
                    f.write(response.content)

        except Exception as e:
            st.warning(f"Could not download file {file_info['name']}: {str(e)}")

    def _get_available_languages(self):
        """Return sorted language codes derived from ``results_*`` dirs."""
        if not self.base_path.exists():
            return []
        result_dirs = [d.name for d in self.base_path.iterdir()
                       if d.is_dir() and d.name.startswith("results_")]
        return sorted(d.replace("results_", "") for d in result_dirs)

    def _get_experimental_configs(self, language):
        """Return sorted experimental configuration names for *language*."""
        lang_dir = self.base_path / f"results_{language}"
        if not lang_dir.exists():
            return []
        return sorted(d.name for d in lang_dir.iterdir() if d.is_dir())

    def _get_models(self, language, config):
        """Return sorted model names for a language/configuration pair."""
        config_dir = self.base_path / f"results_{language}" / config
        if not config_dir.exists():
            return []
        return sorted(d.name for d in config_dir.iterdir() if d.is_dir())

    def _parse_config_name(self, config_name):
        """Split a ``key_value+key_value`` config name into a display dict.

        Only the text before the first underscore of each ``+``-separated
        part is treated as the key; the remainder is the value.
        """
        config_dict = {}
        for part in config_name.split('+'):
            if '_' in part:
                key, value = part.split('_', 1)
                # key cannot contain '_' after split('_', 1), so .title()
                # alone is equivalent to the old replace-then-title dance.
                config_dict[key.title()] = value
        return config_dict

    def _load_metadata(self, language, config, model):
        """Load ``metadata.json`` for a run, or None if absent."""
        metadata_path = (self.base_path / f"results_{language}" / config / model
                         / "metadata" / "metadata.json")
        if metadata_path.exists():
            with open(metadata_path, 'r') as f:
                return json.load(f)
        return None

    def _load_relation_csvs(self, csv_dir, stem_prefix, status_label, model=None):
        """Load every ``<stem_prefix>*.csv`` in *csv_dir*, keyed by relation.

        Shared implementation for the UAS and head-matching loaders (which
        were previously near-duplicate code).  A Streamlit progress bar is
        shown while loading; unreadable files are skipped with a warning.
        When *model* is given, a trailing ``_<model>`` suffix is stripped
        from the relation name.

        Returns:
            dict mapping relation name -> pandas DataFrame.
        """
        if not csv_dir.exists():
            return {}

        data = {}
        csv_files = list(csv_dir.glob(f"{stem_prefix}*.csv"))

        if csv_files:
            progress_bar = st.progress(0)
            status_text = st.empty()

            for i, csv_file in enumerate(csv_files):
                relation = csv_file.stem.replace(stem_prefix, "")
                if model is not None:
                    relation = relation.replace(f"_{model}", "")
                status_text.text(f"{status_label}: {relation}")

                try:
                    data[relation] = pd.read_csv(csv_file, index_col=0)
                except Exception as e:
                    st.warning(f"Could not load {csv_file.name}: {e}")

                progress_bar.progress((i + 1) / len(csv_files))

            progress_bar.empty()
            status_text.empty()

        return data

    def _load_uas_scores(self, language, config, model):
        """Load per-relation UAS score matrices (layer × head)."""
        uas_dir = self.base_path / f"results_{language}" / config / model / "uas_scores"
        return self._load_relation_csvs(uas_dir, "uas_", "Loading UAS data")

    def _load_head_matching(self, language, config, model):
        """Load per-relation head-matching count matrices (layer × head)."""
        heads_dir = self.base_path / f"results_{language}" / config / model / "number_of_heads_matching"
        return self._load_relation_csvs(heads_dir, "heads_matching_",
                                        "Loading head matching data", model=model)

    def _load_variability(self, language, config, model):
        """Load the variability matrix, or None when missing/unreadable."""
        var_path = self.base_path / f"results_{language}" / config / model / "variability" / "variability_list.csv"
        if var_path.exists():
            try:
                return pd.read_csv(var_path, index_col=0)
            except Exception as e:
                st.warning(f"Could not load variability data: {e}")
        return None

    def _get_available_figures(self, language, config, model):
        """Return all generated PDF figure paths for a run."""
        figures_dir = self.base_path / f"results_{language}" / config / model / "figures"
        if not figures_dir.exists():
            return []
        return list(figures_dir.glob("*.pdf"))
269
+
270
def _render_overview_tab(metadata, uas_data, heads_data, variability_data, figures):
    """Render experiment metadata metrics plus a data-availability summary."""
    st.markdown('<div class="section-header">Experiment Overview</div>', unsafe_allow_html=True)

    if metadata:
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.metric("Total Samples", metadata.get('total_number', 'N/A'))
        with col2:
            st.metric("Processed Correctly", metadata.get('number_processed_correctly', 'N/A'))
        with col3:
            st.metric("Errors", metadata.get('number_errored', 'N/A'))
        with col4:
            success_rate = (metadata.get('number_processed_correctly', 0) /
                            metadata.get('total_number', 1)) * 100 if metadata.get('total_number') else 0
            st.metric("Success Rate", f"{success_rate:.1f}%")

        # Fixed: st.markdown() takes a single body string; the seed was
        # previously passed as the unsafe_allow_html positional argument
        # and therefore never displayed.
        st.markdown(f"**Random Seed:** {metadata.get('random_seed', 'N/A')}")

        if metadata.get('errored_phrases'):
            st.markdown("**Errored Phrase IDs:**")
            st.write(metadata['errored_phrases'])
    else:
        st.warning("No metadata available for this configuration.")

    # Quick stats about available data
    st.markdown('<div class="section-header">Available Data</div>', unsafe_allow_html=True)

    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("UAS Relations", len(uas_data))
    with col2:
        st.metric("Head Matching Relations", len(heads_data))
    with col3:
        st.metric("Variability Data", "✓" if variability_data is not None else "✗")
    with col4:
        st.metric("Figure Files", len(figures))


def _render_uas_tab(uas_data):
    """Render the UAS score table, heatmap, and summary statistics."""
    st.markdown('<div class="section-header">UAS (Unlabeled Attachment Score) Analysis</div>', unsafe_allow_html=True)

    if not uas_data:
        st.warning("No UAS score data available for this configuration.")
        return

    selected_relation = st.selectbox(
        "Select Dependency Relation",
        options=list(uas_data.keys()),
        help="Choose a dependency relation to visualize UAS scores"
    )

    if selected_relation and selected_relation in uas_data:
        df = uas_data[selected_relation]

        st.markdown("**UAS Scores Matrix (Layer × Head)**")
        st.dataframe(df, use_container_width=True)

        fig = px.imshow(
            df.values,
            x=[f"Head {i}" for i in df.columns],
            y=[f"Layer {i}" for i in df.index],
            color_continuous_scale="Viridis",
            title=f"UAS Scores Heatmap - {selected_relation}",
            labels=dict(color="UAS Score")
        )
        fig.update_layout(height=600)
        st.plotly_chart(fig, use_container_width=True)

        st.markdown("**Statistics**")
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.metric("Max Score", f"{df.values.max():.4f}")
        with col2:
            st.metric("Min Score", f"{df.values.min():.4f}")
        with col3:
            st.metric("Mean Score", f"{df.values.mean():.4f}")
        with col4:
            st.metric("Std Dev", f"{df.values.std():.4f}")


def _render_heads_tab(heads_data):
    """Render head-matching counts: table, heatmap, per-layer bars, stats."""
    st.markdown('<div class="section-header">Attention Head Matching Analysis</div>', unsafe_allow_html=True)

    if not heads_data:
        st.warning("No head matching data available for this configuration.")
        return

    selected_relation = st.selectbox(
        "Select Dependency Relation",
        options=list(heads_data.keys()),
        help="Choose a dependency relation to visualize head matching patterns",
        key="heads_relation"
    )

    if selected_relation and selected_relation in heads_data:
        df = heads_data[selected_relation]

        st.markdown("**Head Matching Counts Matrix (Layer × Head)**")
        st.dataframe(df, use_container_width=True)

        fig = px.imshow(
            df.values,
            x=[f"Head {i}" for i in df.columns],
            y=[f"Layer {i}" for i in df.index],
            color_continuous_scale="Blues",
            title=f"Head Matching Counts - {selected_relation}",
            labels=dict(color="Match Count")
        )
        fig.update_layout(height=600)
        st.plotly_chart(fig, use_container_width=True)

        # Bar chart of total matches per layer.
        layer_totals = df.sum(axis=1)
        fig_bar = px.bar(
            x=layer_totals.index,
            y=layer_totals.values,
            title=f"Total Matches per Layer - {selected_relation}",
            labels={"x": "Layer", "y": "Total Matches"}
        )
        fig_bar.update_layout(height=400)
        st.plotly_chart(fig_bar, use_container_width=True)

        st.markdown("**Statistics**")
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.metric("Total Matches", int(df.values.sum()))
        with col2:
            st.metric("Max per Cell", int(df.values.max()))
        with col3:
            best_layer = layer_totals.idxmax()
            st.metric("Best Layer", f"Layer {best_layer}")
        with col4:
            best_head_idx = np.unravel_index(df.values.argmax(), df.values.shape)
            st.metric("Best Head", f"L{best_head_idx[0]}-H{best_head_idx[1]}")


def _render_variability_tab(variability_data):
    """Render the variability heatmap, per-head trend lines, and stats."""
    st.markdown('<div class="section-header">Attention Variability Analysis</div>', unsafe_allow_html=True)

    if variability_data is None:
        st.warning("No variability data available for this configuration.")
        return

    st.markdown("**Variability Matrix (Layer × Head)**")
    st.dataframe(variability_data, use_container_width=True)

    fig = px.imshow(
        variability_data.values,
        x=[f"Head {i}" for i in variability_data.columns],
        y=[f"Layer {i}" for i in variability_data.index],
        color_continuous_scale="Reds",
        title="Attention Variability Heatmap",
        labels=dict(color="Variability Score")
    )
    fig.update_layout(height=600)
    st.plotly_chart(fig, use_container_width=True)

    # Line plot of variability trends across layers, one trace per head.
    fig_line = go.Figure()
    for col in variability_data.columns:
        fig_line.add_trace(go.Scatter(
            x=variability_data.index,
            y=variability_data[col],
            mode='lines+markers',
            name=f'Head {col}',
            line=dict(width=2)
        ))

    fig_line.update_layout(
        title="Variability Trends Across Layers",
        xaxis_title="Layer",
        yaxis_title="Variability Score",
        height=500
    )
    st.plotly_chart(fig_line, use_container_width=True)

    st.markdown("**Statistics**")
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Max Variability", f"{variability_data.values.max():.4f}")
    with col2:
        st.metric("Min Variability", f"{variability_data.values.min():.4f}")
    with col3:
        st.metric("Mean Variability", f"{variability_data.values.mean():.4f}")
    with col4:
        most_variable_idx = np.unravel_index(variability_data.values.argmax(), variability_data.values.shape)
        st.metric("Most Variable", f"L{most_variable_idx[0]}-H{most_variable_idx[1]}")


def _render_figures_tab(figures, selected_model):
    """Render the figure browser: grouped listing plus a PDF download button."""
    st.markdown('<div class="section-header">Generated Figures</div>', unsafe_allow_html=True)

    if not figures:
        st.warning("No figures available for this configuration.")
        return

    st.markdown(f"**Available Figures: {len(figures)}**")

    # Group figures by the dependency relation encoded in the filename.
    figure_groups = {}
    for fig_path in figures:
        filename = fig_path.stem
        relation = filename.replace("heads_matching_", "").replace(f"_{selected_model}", "")
        figure_groups.setdefault(relation, []).append(fig_path)

    selected_fig_relation = st.selectbox(
        "Select Relation for Figure View",
        options=list(figure_groups.keys()),
        help="Choose a dependency relation to view its figure"
    )

    if selected_fig_relation and selected_fig_relation in figure_groups:
        fig_path = figure_groups[selected_fig_relation][0]

        st.markdown(f"**Figure: {fig_path.name}**")
        st.markdown(f"**Path:** `{fig_path}`")

        # Streamlit cannot embed PDFs natively, so offer a download instead.
        st.info(
            "📄 PDF figures are available in the results directory. "
            "Due to Streamlit limitations, PDF files cannot be displayed directly in the browser. "
            "You can download or view them locally."
        )

        try:
            with open(fig_path, "rb") as file:
                st.download_button(
                    label=f"📥 Download {fig_path.name}",
                    data=file.read(),
                    file_name=fig_path.name,
                    mime="application/pdf"
                )
        except Exception as e:
            st.error(f"Could not load figure: {e}")

    st.markdown("**All Available Figures:**")
    for relation, paths in figure_groups.items():
        with st.expander(f"📊 {relation} ({len(paths)} files)"):
            for path in paths:
                st.markdown(f"- `{path.name}`")


def main():
    """Entry point: sidebar selection, data loading, and tabbed rendering."""
    # Title
    st.markdown('<div class="main-header">🔍 Attention Analysis Results Explorer</div>', unsafe_allow_html=True)

    # Sidebar for navigation
    st.sidebar.title("🔧 Configuration")

    # Cache management section
    st.sidebar.markdown("### 📁 Data Management")

    use_cache = st.sidebar.checkbox("Use cached data", value=True,
                                    help="Use previously downloaded data if available")

    if st.sidebar.button("🔄 Refresh Data", help="Download fresh data from GitHub"):
        # Clear cache and re-download on the next run.
        cache_dir = Path(tempfile.gettempdir()) / "attention_results_cache"
        if cache_dir.exists():
            shutil.rmtree(cache_dir)
        st.rerun()

    # Show cache status
    cache_dir = Path(tempfile.gettempdir()) / "attention_results_cache"
    if cache_dir.exists():
        st.sidebar.success("✅ Data cached locally")
    else:
        st.sidebar.info("📥 Will download data from GitHub")

    st.sidebar.markdown("---")

    # Initialize explorer with error handling (may trigger a download).
    try:
        explorer = AttentionResultsExplorer(use_cache=use_cache)
    except Exception as e:
        st.error(f"❌ Failed to initialize data explorer: {str(e)}")
        st.error("Please check your internet connection and try again.")
        return

    if not explorer.languages:
        st.error("❌ No result data found. Please check the GitHub repository.")
        return

    # Language selection
    selected_language = st.sidebar.selectbox(
        "Select Language",
        options=explorer.languages,
        help="Choose the language dataset to explore"
    )

    configs = explorer._get_experimental_configs(selected_language)
    if not configs:
        st.error(f"No configurations found for language: {selected_language}")
        return

    # Configuration selection
    selected_config = st.sidebar.selectbox(
        "Select Experimental Configuration",
        options=configs,
        help="Choose the experimental configuration"
    )

    # Parse and display configuration details
    config_details = explorer._parse_config_name(selected_config)
    st.sidebar.markdown("**Configuration Details:**")
    for key, value in config_details.items():
        st.sidebar.markdown(f"- **{key}**: {value}")

    models = explorer._get_models(selected_language, selected_config)
    if not models:
        st.error(f"No models found for {selected_language}/{selected_config}")
        return

    # Model selection
    selected_model = st.sidebar.selectbox(
        "Select Model",
        options=models,
        help="Choose the model to analyze"
    )

    # Load every artefact once up front.  The original reloaded the UAS
    # and head-matching CSVs (with progress bars) both in the Overview tab
    # and again in their own tabs.
    metadata = explorer._load_metadata(selected_language, selected_config, selected_model)
    uas_data = explorer._load_uas_scores(selected_language, selected_config, selected_model)
    heads_data = explorer._load_head_matching(selected_language, selected_config, selected_model)
    variability_data = explorer._load_variability(selected_language, selected_config, selected_model)
    figures = explorer._get_available_figures(selected_language, selected_config, selected_model)

    # Main content area
    tab1, tab2, tab3, tab4, tab5 = st.tabs([
        "📊 Overview",
        "🎯 UAS Scores",
        "🧠 Head Matching",
        "📈 Variability",
        "🖼️ Figures"
    ])

    with tab1:
        _render_overview_tab(metadata, uas_data, heads_data, variability_data, figures)
    with tab2:
        _render_uas_tab(uas_data)
    with tab3:
        _render_heads_tab(heads_data)
    with tab4:
        _render_variability_tab(variability_data)
    with tab5:
        _render_figures_tab(figures, selected_model)

    # Footer
    st.markdown("---")

    # Data source information
    col1, col2 = st.columns([2, 1])
    with col1:
        st.markdown(
            "🔬 **Attention Analysis Results Explorer** | "
            f"Currently viewing: {selected_language.upper()} - {selected_model} | "
            "Built with Streamlit"
        )
    with col2:
        st.markdown(
            f"📊 **Data Source**: [GitHub Repository](https://github.com/{explorer.github_repo})"
        )
646
+
647
# Run the app when executed directly (e.g. `streamlit run streamlit_app.py`).
if __name__ == "__main__":
    main()