import streamlit as st
import pandas as pd
import json
import os
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
from pathlib import Path
import glob
import requests
from io import StringIO
import zipfile
import tempfile
import shutil
import time
from datetime import datetime, timezone

# Set page config
st.set_page_config(
    page_title="Attention Analysis Results Explorer",
    page_icon="🔍",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Custom CSS for better styling
st.markdown(""" """, unsafe_allow_html=True)


class AttentionResultsExplorer:
    def __init__(self, github_repo="ACMCMC/attention", use_cache=True):
        self.github_repo = github_repo
        self.use_cache = use_cache
        self.cache_dir = Path(tempfile.gettempdir()) / "attention_results_cache"
        self.base_path = self.cache_dir

        # Initialize cache directory
        if not self.cache_dir.exists():
            self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Get available languages from GitHub without downloading
        self.available_languages = self._get_available_languages_from_github()
        self.relation_types = None

    def _get_available_languages_from_github(self):
        """Get available languages from GitHub API without downloading"""
        api_url = f"https://api.github.com/repos/{self.github_repo}/contents"
        response = self._make_github_request(api_url, "available languages")
        if response is None:
            # Rate limit hit or other error, fallback to local cache
            return self._get_available_languages_local()
        try:
            contents = response.json()
            result_dirs = [item['name'] for item in contents
                           if item['type'] == 'dir' and item['name'].startswith('results_')]
            languages = [d.replace("results_", "") for d in result_dirs]
            return sorted(languages)
        except Exception as e:
            st.warning(f"Could not parse language list from GitHub: {str(e)}")
            # Fallback to local cache if available
            return self._get_available_languages_local()

    def _get_available_languages_local(self):
        """Get available languages from local cache"""
        if not self.base_path.exists():
            return []
        result_dirs = [d.name for d in self.base_path.iterdir()
                       if d.is_dir() and d.name.startswith("results_")]
        languages = [d.replace("results_", "") for d in result_dirs]
        return sorted(languages)

    def _ensure_specific_data_downloaded(self, language, config, model):
        """Download specific files for a language/config/model combination if not cached"""
        base_path = f"results_{language}/{config}/{model}"
        local_path = self.base_path / f"results_{language}" / config / model

        # Check if we already have this specific combination cached
        if local_path.exists() and self.use_cache:
            # Quick check if essential files exist
            metadata_path = local_path / "metadata" / "metadata.json"
            if metadata_path.exists():
                return  # Already have the data

        with st.spinner(f"📥 Downloading data for {language.upper()}/{config}/{model}..."):
            try:
                self._download_specific_model_data(language, config, model)
                st.success(f"✅ Downloaded {language.upper()}/{model} data!")
            except Exception as e:
                st.error(f"❌ Failed to download specific data: {str(e)}")
                raise

    def _download_specific_model_data(self, language, config, model):
        """Download only the specific model data needed"""
        base_remote_path = f"results_{language}/{config}/{model}"

        # List of essential directories to download for a model
        essential_dirs = ["metadata", "uas_scores", "number_of_heads_matching", "variability", "figures"]
        for dir_name in essential_dirs:
            remote_path = f"{base_remote_path}/{dir_name}"
            try:
                self._download_directory_targeted(dir_name, remote_path, language, config, model)
except Exception as e: st.warning(f"Could not download {dir_name} for {model}: {str(e)}") def _download_directory_targeted(self, dir_name, remote_path, language, config, model): """Download a specific directory for a model""" api_url = f"https://api.github.com/repos/{self.github_repo}/contents/{remote_path}" response = self._make_github_request(api_url, f"directory {dir_name}", silent_404=True) if response is None: return # Rate limit, 404, or other error try: contents = response.json() # Create local directory local_dir = self.base_path / f"results_{language}" / config / model / dir_name local_dir.mkdir(parents=True, exist_ok=True) # Download all files in this directory for item in contents: if item['type'] == 'file': self._download_file(item, local_dir) except Exception as e: st.warning(f"Could not download directory {dir_name}: {str(e)}") def _get_available_configs_from_github(self, language): """Get available configurations for a language from GitHub""" api_url = f"https://api.github.com/repos/{self.github_repo}/contents/results_{language}" response = self._make_github_request(api_url, f"configurations for {language}") if response is None: return [] try: contents = response.json() configs = [item['name'] for item in contents if item['type'] == 'dir'] return sorted(configs) except Exception as e: st.warning(f"Could not parse configurations for {language}: {str(e)}") return [] def _discover_config_parameters(self, language=None): """Dynamically discover configuration parameters from available configs For performance optimization, we only inspect the first language since configurations are consistent across all languages and models. """ try: # Use first available language if none specified (optimization) if language is None: if not self.available_languages: return {} language = self.available_languages[0] available_configs = self._get_experimental_configs(language) if not available_configs: return {} # Parse all configurations to extract unique parameters all_params = set() param_values = {} for config in available_configs: params = self._parse_config_params(config) for param, value in params.items(): all_params.add(param) if param not in param_values: param_values[param] = set() param_values[param].add(value) # Convert sets to sorted lists for consistent UI return {param: sorted(list(values)) for param, values in param_values.items()} except Exception as e: st.warning(f"Could not discover configuration parameters: {str(e)}") return {} def _build_config_from_params(self, param_dict): """Build configuration string from parameter dictionary""" config_parts = [] for param, value in sorted(param_dict.items()): config_parts.append(f"{param}_{value}") return "+".join(config_parts) def _find_best_matching_config(self, language, target_params): """Find the configuration that best matches the target parameters""" available_configs = self._get_experimental_configs(language) best_match = None best_score = -1 for config in available_configs: config_params = self._parse_config_params(config) # Calculate match score score = 0 total_params = len(target_params) for param, target_value in target_params.items(): if param in config_params and config_params[param] == target_value: score += 1 # Prefer configs with exact parameter count if len(config_params) == total_params: score += 0.5 if score > best_score: best_score = score best_match = config return best_match, best_score == len(target_params) def _download_repository(self): """Download repository data from GitHub""" st.info("🔄 Downloading results data 
        # GitHub API to get the repository contents
        api_url = f"https://api.github.com/repos/{self.github_repo}/contents"
        try:
            # Get list of result directories
            response = requests.get(api_url)
            response.raise_for_status()
            contents = response.json()
            result_dirs = [item['name'] for item in contents
                           if item['type'] == 'dir' and item['name'].startswith('results_')]
            st.write(f"Found {len(result_dirs)} result directories: {', '.join(result_dirs)}")

            # Download each result directory
            progress_bar = st.progress(0)
            for i, result_dir in enumerate(result_dirs):
                st.write(f"Downloading {result_dir}...")
                self._download_directory(result_dir)
                progress_bar.progress((i + 1) / len(result_dirs))
            st.success("✅ Download completed!")
        except Exception as e:
            st.error(f"❌ Error downloading repository: {str(e)}")
            st.error("Please check the repository URL and your internet connection.")
            raise

    def _parse_config_params(self, config_name):
        """Parse configuration parameters into a dictionary"""
        parts = config_name.split('+')
        params = {}
        for part in parts:
            if '_' in part:
                key_parts = part.split('_')
                if len(key_parts) >= 2:
                    key = '_'.join(key_parts[:-1])
                    value = key_parts[-1]
                    params[key] = value == 'True'
        return params

    def _download_directory(self, dir_name, path=""):
        """Recursively download a directory from GitHub"""
        url = f"https://api.github.com/repos/{self.github_repo}/contents/{path}{dir_name}"
        try:
            response = requests.get(url)
            response.raise_for_status()
            contents = response.json()

            local_dir = self.cache_dir / path / dir_name
            local_dir.mkdir(parents=True, exist_ok=True)

            for item in contents:
                if item['type'] == 'file':
                    self._download_file(item, local_dir)
                elif item['type'] == 'dir':
                    self._download_directory(item['name'], f"{path}{dir_name}/")
        except Exception as e:
            st.warning(f"Could not download {dir_name}: {str(e)}")

    def _download_file(self, file_info, local_dir):
        """Download a single file from GitHub"""
        try:
            # Use the rate limit handling for file downloads too
            file_response = self._make_github_request(file_info['download_url'], f"file {file_info['name']}")
            if file_response is None:
                return  # Rate limit or other error

            # Save to local cache
            local_file = local_dir / file_info['name']

            # Handle different file types
            if file_info['name'].endswith(('.csv', '.json')):
                with open(local_file, 'w', encoding='utf-8') as f:
                    f.write(file_response.text)
            else:
                # Binary files like PDFs
                with open(local_file, 'wb') as f:
                    f.write(file_response.content)
        except Exception as e:
            st.warning(f"Could not download file {file_info['name']}: {str(e)}")

    def _get_available_languages(self):
        """Get all available language directories"""
        return self.available_languages

    def _get_experimental_configs(self, language):
        """Get all experimental configurations for a language from GitHub API"""
        api_url = f"https://api.github.com/repos/{self.github_repo}/contents/results_{language}"
        response = self._make_github_request(api_url, f"experimental configs for {language}")
        if response is not None:
            try:
                contents = response.json()
                configs = [item['name'] for item in contents if item['type'] == 'dir']
                return sorted(configs)
            except Exception as e:
                st.warning(f"Could not parse experimental configs for {language}: {str(e)}")
        # Fallback to local cache if available
        lang_dir = self.base_path / f"results_{language}"
        if lang_dir.exists():
            configs = [d.name for d in lang_dir.iterdir() if d.is_dir()]
            return sorted(configs)
        return []

    def _find_matching_config(self, language, target_params):
        """Find the first matching configuration from target parameters"""
        return self._find_best_matching_config(language, target_params)
parameters""" return self._find_best_matching_config(language, target_params) def _get_models(self, language, config): """Get all models for a language and configuration from GitHub API""" api_url = f"https://api.github.com/repos/{self.github_repo}/contents/results_{language}/{config}" response = self._make_github_request(api_url, f"models for {language}/{config}") if response is not None: try: contents = response.json() models = [item['name'] for item in contents if item['type'] == 'dir'] return sorted(models) except Exception as e: st.warning(f"Could not parse models for {language}/{config}: {str(e)}") # Fallback to local cache if available config_dir = self.base_path / f"results_{language}" / config if config_dir.exists(): models = [d.name for d in config_dir.iterdir() if d.is_dir()] return sorted(models) return [] def _parse_config_name(self, config_name): """Parse configuration name into readable format""" parts = config_name.split('+') config_dict = {} for part in parts: if '_' in part: key, value = part.split('_', 1) config_dict[key.replace('_', ' ').title()] = value return config_dict def _load_metadata(self, language, config, model): """Load metadata for a specific combination""" # Ensure we have the specific data downloaded self._ensure_specific_data_downloaded(language, config, model) metadata_path = self.base_path / f"results_{language}" / config / model / "metadata" / "metadata.json" if metadata_path.exists(): with open(metadata_path, 'r') as f: return json.load(f) return None def _load_uas_scores(self, language, config, model): """Load UAS scores data""" # Ensure we have the specific data downloaded self._ensure_specific_data_downloaded(language, config, model) uas_dir = self.base_path / f"results_{language}" / config / model / "uas_scores" if not uas_dir.exists(): return {} uas_data = {} csv_files = list(uas_dir.glob("uas_*.csv")) if csv_files: with st.spinner("Loading UAS scores data..."): progress_bar = st.progress(0) status_text = st.empty() for i, csv_file in enumerate(csv_files): relation = csv_file.stem.replace("uas_", "") status_text.text(f"Loading UAS data: {relation}") try: df = pd.read_csv(csv_file, index_col=0) uas_data[relation] = df except Exception as e: st.warning(f"Could not load {csv_file.name}: {e}") progress_bar.progress((i + 1) / len(csv_files)) time.sleep(0.01) # Small delay for smoother progress progress_bar.empty() status_text.empty() return uas_data def _load_head_matching(self, language, config, model): """Load head matching data""" # Ensure we have the specific data downloaded self._ensure_specific_data_downloaded(language, config, model) heads_dir = self.base_path / f"results_{language}" / config / model / "number_of_heads_matching" if not heads_dir.exists(): return {} heads_data = {} csv_files = list(heads_dir.glob("heads_matching_*.csv")) if csv_files: with st.spinner("Loading head matching data..."): progress_bar = st.progress(0) status_text = st.empty() for i, csv_file in enumerate(csv_files): relation = csv_file.stem.replace("heads_matching_", "").replace(f"_{model}", "") status_text.text(f"Loading head matching data: {relation}") try: df = pd.read_csv(csv_file, index_col=0) heads_data[relation] = df except Exception as e: st.warning(f"Could not load {csv_file.name}: {e}") progress_bar.progress((i + 1) / len(csv_files)) time.sleep(0.01) # Small delay for smoother progress progress_bar.empty() status_text.empty() return heads_data def _load_variability(self, language, config, model): """Load variability data""" # Ensure we have the specific 
        self._ensure_specific_data_downloaded(language, config, model)
        var_path = self.base_path / f"results_{language}" / config / model / "variability" / "variability_list.csv"
        if var_path.exists():
            try:
                return pd.read_csv(var_path, index_col=0)
            except Exception as e:
                st.warning(f"Could not load variability data: {e}")
        return None

    def _get_available_figures(self, language, config, model):
        """Get all available figure files"""
        # Ensure we have the specific data downloaded
        self._ensure_specific_data_downloaded(language, config, model)
        figures_dir = self.base_path / f"results_{language}" / config / model / "figures"
        if not figures_dir.exists():
            return []
        return list(figures_dir.glob("*.pdf"))

    def _handle_rate_limit_error(self, response):
        """Handle GitHub API rate limit errors with detailed user feedback"""
        if response.status_code in (403, 429):
            # Check if it's a rate limit error
            if 'rate limit' in response.text.lower() or 'api rate limit' in response.text.lower():
                # Extract rate limit information from headers
                remaining = response.headers.get('x-ratelimit-remaining', 'unknown')
                reset_timestamp = response.headers.get('x-ratelimit-reset')
                limit = response.headers.get('x-ratelimit-limit', 'unknown')

                # Calculate reset time
                reset_time_str = "unknown"
                if reset_timestamp:
                    try:
                        reset_time = datetime.fromtimestamp(int(reset_timestamp), tz=timezone.utc)
                        reset_time_str = reset_time.strftime("%Y-%m-%d %H:%M:%S UTC")
                        # Calculate time until reset
                        now = datetime.now(timezone.utc)
                        time_until_reset = reset_time - now
                        minutes_until_reset = int(time_until_reset.total_seconds() / 60)
                        if minutes_until_reset > 0:
                            reset_time_str += f" (in {minutes_until_reset} minutes)"
                    except (ValueError, TypeError):
                        pass

                # Display comprehensive rate limit information
                st.error("🚫 **GitHub API Rate Limit Exceeded**")
                with st.expander("📊 Rate Limit Details", expanded=True):
                    col1, col2 = st.columns(2)
                    with col1:
                        st.metric("Requests Remaining", remaining)
                        st.metric("Rate Limit", limit)
                    with col2:
                        st.metric("Reset Time", reset_time_str)
                        if reset_timestamp:
                            try:
                                reset_time = datetime.fromtimestamp(int(reset_timestamp), tz=timezone.utc)
                                now = datetime.now(timezone.utc)
                                time_until_reset = reset_time - now
                                if time_until_reset.total_seconds() > 0:
                                    st.metric("Time Until Reset", f"{int(time_until_reset.total_seconds() / 60)} minutes")
                            except (ValueError, TypeError):
                                pass
                return True  # Indicates rate limit error was handled
        return False  # Not a rate limit error

    def _make_github_request(self, url, description="GitHub API request", silent_404=False):
        """Make a GitHub API request with rate limit handling"""
        try:
            # Add GitHub token if available
            headers = {}
            github_token = os.environ.get('GITHUB_TOKEN')
            if github_token:
                headers['Authorization'] = f'token {github_token}'

            response = requests.get(url, headers=headers)

            # Check for rate limit before raising for status
            if self._handle_rate_limit_error(response):
                return None  # Rate limit handled, return None

            # Handle 404 errors silently if requested (for optional directories)
            if response.status_code == 404 and silent_404:
                return None

            response.raise_for_status()
            return response
        except requests.exceptions.RequestException as e:
            if hasattr(e, 'response') and e.response is not None:
                # Handle 404 silently if requested
                if e.response.status_code == 404 and silent_404:
                    return None
                if not self._handle_rate_limit_error(e.response):
                    st.warning(f"Request failed for {description}: {str(e)}")
            else:
                st.warning(f"Network error for {description}: {str(e)}")
            return None


def main():
    # Title
    st.markdown('<h1>🔍 Attention Analysis Results Explorer</h1>', unsafe_allow_html=True)
    # Sidebar for navigation
    st.sidebar.title("🔧 Configuration")

    # Cache management section
    st.sidebar.markdown("### 📁 Data Management")

    # Initialize explorer
    use_cache = st.sidebar.checkbox("Use cached data", value=True, help="Use previously downloaded data if available")

    if st.sidebar.button("🔄 Clear Cache", help="Clear all cached data"):
        # Clear cache and re-download
        cache_dir = Path(tempfile.gettempdir()) / "attention_results_cache"
        if cache_dir.exists():
            shutil.rmtree(cache_dir)
        st.sidebar.success("✅ Cache cleared!")
        st.rerun()

    # Show cache status
    cache_dir = Path(tempfile.gettempdir()) / "attention_results_cache"
    if cache_dir.exists():
        # Get more detailed cache information
        cached_items = []
        for lang_dir in cache_dir.iterdir():
            if lang_dir.is_dir() and lang_dir.name.startswith("results_"):
                lang = lang_dir.name.replace("results_", "")
                configs = [d.name for d in lang_dir.iterdir() if d.is_dir()]
                if configs:
                    models_count = 0
                    for config_dir in lang_dir.iterdir():
                        if config_dir.is_dir():
                            models = [d.name for d in config_dir.iterdir() if d.is_dir()]
                            models_count += len(models)
                    cached_items.append(f"{lang} ({len(configs)} configs, {models_count} models)")
        if cached_items:
            st.sidebar.success("✅ **Cached Data:**")
            for item in cached_items[:3]:  # Show first 3
                st.sidebar.text(f"• {item}")
            if len(cached_items) > 3:
                st.sidebar.text(f"... and {len(cached_items) - 3} more")
        else:
            st.sidebar.info("📥 Cache exists but empty")
    else:
        st.sidebar.info("📥 No cached data")

    st.sidebar.markdown("---")

    # Initialize explorer with error handling
    try:
        with st.spinner("🔄 Initializing attention analysis explorer..."):
            explorer = AttentionResultsExplorer(use_cache=use_cache)
    except Exception as e:
        st.error(f"❌ Failed to initialize data explorer: {str(e)}")
        st.error("Please check your internet connection and try again.")
        # Show some debugging information
        with st.expander("🔍 Debugging Information"):
            st.code(f"Error details: {str(e)}")
            st.markdown("**Possible solutions:**")
            st.markdown("- Check your internet connection")
            st.markdown("- Try clearing the cache")
            st.markdown("- Wait a moment and refresh the page")
        return

    # Check if any languages are available
    if not explorer.available_languages:
        st.error("❌ No result data found. Please check the GitHub repository.")
        st.markdown("**Expected repository structure:**")
        st.markdown("- Repository should contain `results_*` directories")
        st.markdown("- Each directory should contain experimental configurations")
        return

    # Show success message
    st.sidebar.success(f"✅ Found {len(explorer.available_languages)} languages: {', '.join(explorer.available_languages)}")

    # Language selection
    selected_language = st.sidebar.selectbox(
        "Select Language",
        options=explorer.available_languages,
        help="Choose the language dataset to explore"
    )

    st.sidebar.markdown("---")

    # Configuration selection with dynamic discovery
    st.sidebar.markdown("### ⚙️ Experimental Configuration")

    # Discover available configuration parameters (optimized to use first language only)
    with st.spinner("🔍 Discovering configuration options..."):
        config_parameters = explorer._discover_config_parameters()

    if not config_parameters:
        st.sidebar.error("❌ Could not discover configuration parameters")
        st.stop()

    # Show discovered parameters
    st.sidebar.success(f"✅ Found {len(config_parameters)} configuration parameters")
    st.sidebar.info("💡 Configuration options are consistent across all languages - using optimized discovery")

    # Create UI elements for each discovered parameter
    selected_params = {}
    for param_name, possible_values in config_parameters.items():
        # Clean up parameter name for display
        display_name = param_name.replace('_', ' ').title()
        if len(possible_values) == 2 and set(possible_values) == {True, False}:
            # Boolean parameter - use checkbox
            default_value = False  # Default to False for boolean params
            selected_params[param_name] = st.sidebar.checkbox(
                display_name,
                value=default_value,
                help=f"Parameter: {param_name}"
            )
        else:
            # Multi-value parameter - use selectbox
            selected_params[param_name] = st.sidebar.selectbox(
                display_name,
                options=possible_values,
                help=f"Parameter: {param_name}"
            )

    # Find the best matching configuration
    selected_config, config_exists = explorer._find_matching_config(selected_language, selected_params)

    # Show current configuration
    st.sidebar.markdown("**Selected Parameters:**")
    for param, value in selected_params.items():
        emoji = "✅" if value else "❌" if isinstance(value, bool) else "🔹"
        st.sidebar.text(f"{emoji} {param}: {value}")

    st.sidebar.markdown("**Matched Configuration:**")
    st.sidebar.code(selected_config if selected_config else "No match found", language="text")

    # Show configuration status
    if config_exists:
        st.sidebar.success("✅ Exact configuration match found!")
    else:
        st.sidebar.warning("⚠️ Using best available match")

    st.sidebar.markdown("---")

    # Get models for selected language and config
    if not selected_config:
        st.error("❌ No valid configuration found")
        st.info("Please try different parameter combinations.")
        st.stop()

    models = explorer._get_models(selected_language, selected_config)
    if not models:
        st.warning(f"❌ No models found for {selected_language}/{selected_config}")
        st.info("This configuration may not exist for the selected language. Try adjusting the configuration parameters above.")
        st.stop()

    # Model selection
    selected_model = st.sidebar.selectbox(
        "Select Model",
        options=models,
        help="Choose the model to analyze"
    )

    # Main content area
    tab1, tab2, tab3, tab4, tab5 = st.tabs([
        "📊 Overview",
        "🎯 UAS Scores",
        "🧠 Head Matching",
        "📈 Variability",
        "🖼️ Figures"
    ])

    # Tab 1: Overview
    with tab1:
        st.markdown('<h2>Experiment Overview</h2>', unsafe_allow_html=True)

        # Show current configuration in a friendly format
        st.markdown("### 🔧 Current Configuration")
        config_params = explorer._parse_config_params(selected_config)

        col1, col2 = st.columns(2)
        with col1:
            st.markdown("**Configuration Parameters:**")
            for param, value in config_params.items():
                emoji = "✅" if value else "❌" if isinstance(value, bool) else "🔹"
                readable_param = param.replace('_', ' ').title()
                st.markdown(f"{emoji} **{readable_param}**: {value}")
        with col2:
            st.markdown("**Selected Parameters vs Actual:**")
            for param in selected_params:
                selected_val = selected_params[param]
                actual_val = config_params.get(param, "N/A")
                match_emoji = "✅" if selected_val == actual_val else "⚠️"
                st.markdown(f"{match_emoji} **{param}**: {selected_val} → {actual_val}")

        st.markdown("**Raw Configuration String:**")
        st.code(selected_config, language="text")

        st.markdown("---")

        # Load metadata
        metadata = explorer._load_metadata(selected_language, selected_config, selected_model)
        if metadata:
            st.markdown("### 📊 Experiment Statistics")
            col1, col2, col3, col4 = st.columns(4)
            with col1:
                st.metric("Total Samples", metadata.get('total_number', 'N/A'))
            with col2:
                st.metric("Processed Correctly", metadata.get('number_processed_correctly', 'N/A'))
            with col3:
                st.metric("Errors", metadata.get('number_errored', 'N/A'))
            with col4:
                success_rate = (metadata.get('number_processed_correctly', 0) / metadata.get('total_number', 1)) * 100 if metadata.get('total_number') else 0
                st.metric("Success Rate", f"{success_rate:.1f}%")
            if metadata.get('random_seed'):
                st.markdown(f"**Random Seed:** {metadata.get('random_seed')}")
            if metadata.get('errored_phrases'):
                with st.expander("🔍 View Errored Phrase IDs"):
                    st.write(metadata['errored_phrases'])
        else:
            st.warning("No metadata available for this configuration.")

        # Quick stats about available data
        st.markdown("---")
        st.markdown('<h3>Available Data Summary</h3>', unsafe_allow_html=True)

        # Show loading message since we're now loading on-demand
        with st.spinner("Loading data summary..."):
            uas_data = explorer._load_uas_scores(selected_language, selected_config, selected_model)
            heads_data = explorer._load_head_matching(selected_language, selected_config, selected_model)
            variability_data = explorer._load_variability(selected_language, selected_config, selected_model)
            figures = explorer._get_available_figures(selected_language, selected_config, selected_model)

        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.metric("UAS Relations", len(uas_data))
        with col2:
            st.metric("Head Matching Relations", len(heads_data))
        with col3:
            st.metric("Variability Data", "✓" if variability_data is not None else "✗")
        with col4:
            st.metric("Figure Files", len(figures))

        # Show what was just downloaded
        if uas_data or heads_data or variability_data is not None or figures:
            st.success(f"✅ Successfully loaded data for {selected_language.upper()}/{selected_model}")
        else:
            st.warning("⚠️ No data files found for this configuration")

    # Tab 2: UAS Scores
    with tab2:
        st.markdown('<h2>UAS (Unlabeled Attachment Score) Analysis</h2>', unsafe_allow_html=True)
        uas_data = explorer._load_uas_scores(selected_language, selected_config, selected_model)
        if uas_data:
            # Relation selection
            selected_relation = st.selectbox(
                "Select Dependency Relation",
                options=list(uas_data.keys()),
                help="Choose a dependency relation to visualize UAS scores"
            )
            if selected_relation and selected_relation in uas_data:
                df = uas_data[selected_relation]

                # Display the data table
                st.markdown("**UAS Scores Matrix (Layer × Head)**")
                st.dataframe(df, use_container_width=True)

                # Create heatmap
                fig = px.imshow(
                    df.values,
                    x=[f"Head {i}" for i in df.columns],
                    y=[f"Layer {i}" for i in df.index],
                    color_continuous_scale="Viridis",
                    title=f"UAS Scores Heatmap - {selected_relation}",
                    labels=dict(color="UAS Score")
                )
                fig.update_layout(height=600)
                st.plotly_chart(fig, use_container_width=True)

                # Statistics
                st.markdown("**Statistics**")
                col1, col2, col3, col4 = st.columns(4)
                with col1:
                    st.metric("Max Score", f"{df.values.max():.4f}")
                with col2:
                    st.metric("Min Score", f"{df.values.min():.4f}")
                with col3:
                    st.metric("Mean Score", f"{df.values.mean():.4f}")
                with col4:
                    st.metric("Std Dev", f"{df.values.std():.4f}")
        else:
            st.warning("No UAS score data available for this configuration.")

    # Tab 3: Head Matching
    with tab3:
        st.markdown('<h2>Attention Head Matching Analysis</h2>', unsafe_allow_html=True)
        heads_data = explorer._load_head_matching(selected_language, selected_config, selected_model)
        if heads_data:
            # Relation selection
            selected_relation = st.selectbox(
                "Select Dependency Relation",
                options=list(heads_data.keys()),
                help="Choose a dependency relation to visualize head matching patterns",
                key="heads_relation"
            )
            if selected_relation and selected_relation in heads_data:
                df = heads_data[selected_relation]

                # Display the data table
                st.markdown("**Head Matching Counts Matrix (Layer × Head)**")
                st.dataframe(df, use_container_width=True)

                # Create heatmap
                fig = px.imshow(
                    df.values,
                    x=[f"Head {i}" for i in df.columns],
                    y=[f"Layer {i}" for i in df.index],
                    color_continuous_scale="Blues",
                    title=f"Head Matching Counts - {selected_relation}",
                    labels=dict(color="Match Count")
                )
                fig.update_layout(height=600)
                st.plotly_chart(fig, use_container_width=True)

                # Create bar chart of total matches per layer
                layer_totals = df.sum(axis=1)
                fig_bar = px.bar(
                    x=layer_totals.index,
                    y=layer_totals.values,
                    title=f"Total Matches per Layer - {selected_relation}",
                    labels={"x": "Layer", "y": "Total Matches"}
                )
                fig_bar.update_layout(height=400)
                st.plotly_chart(fig_bar, use_container_width=True)

                # Statistics
                st.markdown("**Statistics**")
                col1, col2, col3, col4 = st.columns(4)
                with col1:
                    st.metric("Total Matches", int(df.values.sum()))
                with col2:
                    st.metric("Max per Cell", int(df.values.max()))
                with col3:
                    best_layer = layer_totals.idxmax()
                    st.metric("Best Layer", f"Layer {best_layer}")
                with col4:
                    best_head_idx = np.unravel_index(df.values.argmax(), df.values.shape)
                    st.metric("Best Head", f"L{best_head_idx[0]}-H{best_head_idx[1]}")
        else:
            st.warning("No head matching data available for this configuration.")

    # Tab 4: Variability
    with tab4:
        st.markdown('<h2>Attention Variability Analysis</h2>', unsafe_allow_html=True)
        variability_data = explorer._load_variability(selected_language, selected_config, selected_model)
        if variability_data is not None:
            # Display the data table
            st.markdown("**Variability Matrix (Layer × Head)**")
            st.dataframe(variability_data, use_container_width=True)

            # Create heatmap
            fig = px.imshow(
                variability_data.values,
                x=[f"Head {i}" for i in variability_data.columns],
                y=[f"Layer {i}" for i in variability_data.index],
                color_continuous_scale="Reds",
                title="Attention Variability Heatmap",
                labels=dict(color="Variability Score")
            )
            fig.update_layout(height=600)
            st.plotly_chart(fig, use_container_width=True)

            # Create line plot for variability trends
            fig_line = go.Figure()
            for col in variability_data.columns:
                fig_line.add_trace(go.Scatter(
                    x=variability_data.index,
                    y=variability_data[col],
                    mode='lines+markers',
                    name=f'Head {col}',
                    line=dict(width=2)
                ))
            fig_line.update_layout(
                title="Variability Trends Across Layers",
                xaxis_title="Layer",
                yaxis_title="Variability Score",
                height=500
            )
            st.plotly_chart(fig_line, use_container_width=True)

            # Statistics
            st.markdown("**Statistics**")
            col1, col2, col3, col4 = st.columns(4)
            with col1:
                st.metric("Max Variability", f"{variability_data.values.max():.4f}")
            with col2:
                st.metric("Min Variability", f"{variability_data.values.min():.4f}")
            with col3:
                st.metric("Mean Variability", f"{variability_data.values.mean():.4f}")
            with col4:
                most_variable_idx = np.unravel_index(variability_data.values.argmax(), variability_data.values.shape)
                st.metric("Most Variable", f"L{most_variable_idx[0]}-H{most_variable_idx[1]}")
        else:
            st.warning("No variability data available for this configuration.")

    # Tab 5: Figures
    with tab5:
        st.markdown('<h2>Generated Figures</h2>', unsafe_allow_html=True)
        figures = explorer._get_available_figures(selected_language, selected_config, selected_model)
        if figures:
            st.markdown(f"**Available Figures: {len(figures)}**")

            # Group figures by relation type
            figure_groups = {}
            for fig_path in figures:
                # Extract relation from filename
                filename = fig_path.stem
                relation = filename.replace("heads_matching_", "").replace(f"_{selected_model}", "")
                if relation not in figure_groups:
                    figure_groups[relation] = []
                figure_groups[relation].append(fig_path)

            # Select relation to view
            selected_fig_relation = st.selectbox(
                "Select Relation for Figure View",
                options=list(figure_groups.keys()),
                help="Choose a dependency relation to view its figure"
            )
            if selected_fig_relation and selected_fig_relation in figure_groups:
                fig_path = figure_groups[selected_fig_relation][0]
                st.markdown(f"**Figure: {fig_path.name}**")
                st.markdown(f"**Path:** `{fig_path}`")

                # Note about PDF viewing
                st.info(
                    "📄 PDF figures are available in the results directory. "
                    "Due to Streamlit limitations, PDF files cannot be displayed directly in the browser. "
                    "You can download or view them locally."
                )

                # Provide download link
                try:
                    with open(fig_path, "rb") as file:
                        st.download_button(
                            label=f"📥 Download {fig_path.name}",
                            data=file.read(),
                            file_name=fig_path.name,
                            mime="application/pdf"
                        )
                except Exception as e:
                    st.error(f"Could not load figure: {e}")

            # List all available figures
            st.markdown("**All Available Figures:**")
            for relation, paths in figure_groups.items():
                with st.expander(f"📊 {relation} ({len(paths)} files)"):
                    for path in paths:
                        st.markdown(f"- `{path.name}`")
        else:
            st.warning("No figures available for this configuration.")

    # Footer
    st.markdown("---")

    # Data source information
    col1, col2 = st.columns([2, 1])
    with col1:
        st.markdown(
            "🔬 **Attention Analysis Results Explorer** | "
            f"Currently viewing: {selected_language.upper()} - {selected_model} | "
            "Built with Streamlit"
        )
    with col2:
        st.markdown(
            f"📊 **Data Source**: [GitHub Repository](https://github.com/{explorer.github_repo})"
        )


if __name__ == "__main__":
    main()
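# Usage sketch (the filename app.py is an assumption, not fixed anywhere in this script):
#   streamlit run app.py
# Optionally export a GITHUB_TOKEN environment variable beforehand; _make_github_request()
# sends it as an Authorization header, which raises the GitHub API rate limit for the
# on-demand downloads performed by AttentionResultsExplorer.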