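"""Gradio app for side-by-side comparison of VLM benchmark samples.

Loads per-model JSON result files for four datasets (AITW, Where2Place,
MONDAY, RoboVQA) and displays the same question/task as evaluated by
GPT-4o, OpenAI o1, Gemini 2.5 Pro, and Qwen 2.5 VL.
"""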
import os
import json
import glob
import gradio as gr
from collections import defaultdict

# --- Configuration ---
# Base path where all dataset folders are located
BASE_DATA_DIRECTORY = "./"

# Names of the VLMs and their corresponding keys used in file names
VLM_MODELS = {
    "GPT-4o": "4o",
    "OpenAI o1": "o1",
    "Gemini 2.5 Pro": "gemini",
    "Qwen 2.5 VL": "qwen"
}

# Configuration for each dataset
DATASET_CONFIG = {
    "AITW": {
        "display_name": "AITW",
        "base_dir": os.path.join(BASE_DATA_DIRECTORY, ""),  # Base dir is the root for AITW
        "json_patterns": ["aitw_{model_key}_dataset.json", "aitw_{model_key}_dataset1.json"],
        "data_is_nested": True,  # The JSON is a dict of episodes, which contain steps
    },
    "Where2Place": {
        "display_name": "Where2Place",
        "base_dir": os.path.join(BASE_DATA_DIRECTORY, "where2place"),
        "json_patterns": ["where2place_mcq_{model_key}.json"],
    },
    "MONDAY": {
        "display_name": "MONDAY",
        "base_dir": os.path.join(BASE_DATA_DIRECTORY, "Monday"),
        "json_patterns": ["monday_mcq_test_{model_key}.json", "monday_mcq_test_unseen_os_{model_key}.json"],
    },
    "RoboVQA": {
        "display_name": "RoboVQA",
        "base_dir": os.path.join(BASE_DATA_DIRECTORY, "robovqa"),
        "json_patterns": ["robovqa_final_dataset_{model_key}.json"],
    }
}
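
# Note: each entry in "json_patterns" is a template; load_data_for_dataset() below
# substitutes the model key, e.g. "aitw_{model_key}_dataset.json" becomes
# "aitw_4o_dataset.json" for GPT-4o.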


# --- Data Loading and Processing ---
def load_data_for_dataset(dataset_key):
    """
    Loads and structures data for a given dataset from its JSON files.
    Returns a dictionary where keys are unique sample IDs and values are
    dictionaries mapping VLM model keys to their specific data for that sample,
    e.g. {'episode_123:step_0': {'4o': {...}, 'o1': {...}}, ...}
    """
    if dataset_key not in DATASET_CONFIG:
        return {}
    config = DATASET_CONFIG[dataset_key]
    unified_data = defaultdict(dict)
    print(f"Loading data for dataset: {dataset_key}")
    for display_name, model_key in VLM_MODELS.items():
        all_entries = []
        for pattern in config["json_patterns"]:
            # Construct the full file path pattern
            full_pattern = os.path.join(config["base_dir"], pattern.format(model_key=model_key))
            # Find all matching files
            json_files = glob.glob(full_pattern)
            for file_path in json_files:
                print(f"  - Reading file: {file_path}")
                try:
                    with open(file_path, 'r', encoding='utf-8') as f:
                        data = json.load(f)
                    if isinstance(data, list):
                        all_entries.extend(data)
                    elif isinstance(data, dict):
                        # Handle AITW's nested structure
                        if config.get("data_is_nested"):
                            for episode_id, episode_data in data.items():
                                for step in episode_data.get("steps", []):
                                    # Add episode context to each step
                                    step_with_context = step.copy()
                                    step_with_context['episode_id'] = episode_id
                                    step_with_context['episode_goal'] = episode_data.get('episode_goal')
                                    all_entries.append(step_with_context)
                        else:
                            # A dict for a flat dataset is unexpected; warn instead of dropping silently.
                            print(f"  - WARNING: Unexpected dict structure in {file_path}; skipping.")
                except FileNotFoundError:
                    print(f"  - WARNING: File not found: {file_path}")
                except json.JSONDecodeError:
                    print(f"  - WARNING: Could not decode JSON from: {file_path}")
        # Process loaded entries and add to the unified dictionary
        for i, entry in enumerate(all_entries):
            sample_id = None
            if dataset_key == "AITW":
                sample_id = f"{entry.get('episode_id', 'unknown_ep')}:{entry.get('step_id', 'unknown_step')}"
            elif dataset_key == "Where2Place":
                sample_id = f"q_{entry.get('question_id', i)}"
            elif dataset_key == "MONDAY":
                sample_id = f"{entry.get('episode_id', 'unknown_ep')}:{entry.get('step_id', i)}"
            elif dataset_key == "RoboVQA":
                sample_id = f"{entry.get('episode_id', i)}"
            if sample_id:
                unified_data[sample_id][model_key] = entry

    # Sort sample IDs for consistent ordering in the dropdown
    sorted_unified_data = {k: unified_data[k] for k in sorted(unified_data.keys())}
    print(f"Finished loading. Found {len(sorted_unified_data)} unique samples.")
    return sorted_unified_data


def format_mcq_options(options, correct_index):
    """Formats MCQ options into a Markdown string, highlighting the correct one."""
    if not isinstance(options, list):
        return "Options not available."
    lines = []
    for i, option in enumerate(options):
        # Highlight the option whose 0-based index matches the stored answer index.
        is_correct = (correct_index is not None and i == correct_index)
        prefix = "✅ **" if is_correct else ""
        suffix = "**" if is_correct else ""
        lines.append(f"- {prefix}{option}{suffix}")
    return "\n".join(lines)
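
# Illustrative example (option strings are made up): format_mcq_options(["Tap", "Scroll"], 0)
# returns "- ✅ **Tap**\n- Scroll", rendering the correct option in bold with a check mark.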


# --- Gradio UI Application ---
with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 95% !important;}") as demo:
    gr.Markdown("# VLM Comparative Benchmark Visualizer")
    gr.Markdown("Select a dataset to load evaluation samples. The interface will display the same question/task evaluated across four different VLMs.")

    # --- State Management ---
    all_data_state = gr.State({})

    # --- UI Components ---
    with gr.Row():
        dataset_selector = gr.Dropdown(
            choices=list(DATASET_CONFIG.keys()),
            label="1. Select a Dataset",
            value="AITW"  # Default value
        )
        sample_selector = gr.Dropdown(
            label="2. Select a Sample / Episode Step",
            interactive=True,
            # Choices will be populated dynamically
        )

    shared_info_display = gr.Markdown(visible=False)  # For goal, common question, etc.

    with gr.Row(equal_height=False):
        vlm_outputs = []
        for vlm_display_name, vlm_key in VLM_MODELS.items():
            with gr.Column(scale=1):
                # gr.Group visually groups each model's column (a nested gr.Blocks
                # is not supported inside another Blocks context).
                with gr.Group():
                    gr.Markdown(f"### {vlm_display_name}")
                    media_display = gr.Image(label="Media", type="filepath", interactive=False, height=400)
                    info_display = gr.Markdown()
                vlm_outputs.append((media_display, info_display))

    # --- UI Update Logic ---
    def handle_dataset_selection(dataset_key):
        """
        Triggered when a new dataset is selected.
        Loads all data for that dataset and populates the sample selector.
        """
        print(f"UI: Dataset selection changed to '{dataset_key}'")
        if not dataset_key:
            return {
                all_data_state: {},
                sample_selector: gr.update(choices=[], value=None),
            }
        data = load_data_for_dataset(dataset_key)
        sample_ids = list(data.keys())
        first_sample = sample_ids[0] if sample_ids else None
        return {
            all_data_state: data,
            sample_selector: gr.update(choices=sample_ids, value=first_sample, visible=True),
        }

    def handle_sample_selection(dataset_key, sample_id, all_data):
        """
        Triggered when a new sample is selected.
        Updates the four columns with the data for that sample.
        """
        print(f"UI: Sample selection changed to '{sample_id}'")
        if not sample_id or not all_data:
            # Create empty updates for all components if there's no selection
            updates = [gr.update(visible=False)] + [gr.update(value=None, visible=False)] * len(vlm_outputs) * 2
            return dict(zip([shared_info_display] + [item for sublist in vlm_outputs for item in sublist], updates))

        sample_data_for_all_vlms = all_data.get(sample_id, {})

        # --- 1. Update Shared Information Display ---
        shared_md_parts = []
        # Use data from the first VLM that actually has this sample, so the shared
        # panel is populated even if the first configured model is missing it.
        first_vlm_data = next(iter(sample_data_for_all_vlms.values()), {})
        if dataset_key == "AITW":
            shared_md_parts.append(f"**Goal:** `{first_vlm_data.get('episode_goal', 'N/A')}`")
            shared_md_parts.append(f"**Question:** *{first_vlm_data.get('questions', {}).get('question', 'N/A')}*")
        elif dataset_key == "MONDAY":
            shared_md_parts.append(f"**Goal:** `{first_vlm_data.get('goal', 'N/A')}`")
            shared_md_parts.append(f"**OS:** {first_vlm_data.get('os', 'N/A')}")
        elif dataset_key == "RoboVQA":
            shared_md_parts.append(f"**Task Type:** {first_vlm_data.get('task_type', 'N/A')}")
        # Where2Place has its question per-VLM, so no shared info is needed.
        shared_info_update = gr.update(value="\n\n".join(shared_md_parts), visible=bool(shared_md_parts))

        # --- 2. Update Each VLM Column ---
        column_updates = []
        config = DATASET_CONFIG[dataset_key]
        for vlm_display_name, vlm_key in VLM_MODELS.items():
            vlm_data = sample_data_for_all_vlms.get(vlm_key)
            if not vlm_data:
                column_updates.extend([gr.update(value=None, visible=True), gr.update(value="*Data not found for this sample.*", visible=True)])
                continue

            # Find image/media path
            media_path = None
            if dataset_key == "AITW": media_path = vlm_data.get('screenshot_path')
            elif dataset_key == "Where2Place": media_path = vlm_data.get('marked_image_path')
            elif dataset_key == "MONDAY": media_path = vlm_data.get('screenshot_path')
            elif dataset_key == "RoboVQA": media_path = vlm_data.get('media_path')

            # Construct absolute path if relative
            absolute_media_path = None
            if media_path:
                # The AITW paths are absolute, others are relative.
                if os.path.isabs(media_path):
                    absolute_media_path = media_path
                else:
                    absolute_media_path = os.path.join(config['base_dir'], media_path)

            # Build the markdown content for the info box
            md_content = []
            if dataset_key == "AITW":
                md_content.append(f"**Action History:**\n```\n{vlm_data.get('action_history', 'None')}\n```")
                options = vlm_data.get('questions', {}).get('options')
                answer_idx = vlm_data.get('questions', {}).get('correct_answer_index')
                md_content.append(format_mcq_options(options, answer_idx))
            elif dataset_key == "Where2Place":
                md_content.append(f"**Question:** *{vlm_data.get('question', 'N/A')}*")
                options = vlm_data.get('options')
                answer_idx = vlm_data.get('answer')
                md_content.append(format_mcq_options(options, answer_idx))
            elif dataset_key == "MONDAY":
                md_content.append(f"**Question:** *{vlm_data.get('current_question', 'N/A')}*")
                md_content.append(f"**Action History:**\n```\n{vlm_data.get('action_history', 'None')}\n```")
                options = vlm_data.get('options')
                answer_idx = vlm_data.get('answer')
                md_content.append(format_mcq_options(options, answer_idx))
            elif dataset_key == "RoboVQA":
                md_content.append(f"**Question:** *{vlm_data.get('question', 'N/A')}*")
                options = vlm_data.get('options')
                answer_idx = vlm_data.get('answer')
                md_content.append(format_mcq_options(options, answer_idx))

            image_update = gr.update(value=absolute_media_path if absolute_media_path and os.path.exists(absolute_media_path) else None, visible=True)
            # visible=True re-shows the Markdown if a previous empty selection hid it
            info_update = gr.update(value="\n\n".join(md_content), visible=True)
            column_updates.extend([image_update, info_update])

        # Combine all updates into a single dictionary to return
        output_components = [shared_info_display] + [item for sublist in vlm_outputs for item in sublist]
        return dict(zip(output_components, [shared_info_update] + column_updates))

    # --- Event Listeners ---
    # When the app loads, trigger the dataset selection change to load the default dataset
    demo.load(
        fn=handle_dataset_selection,
        inputs=[dataset_selector],
        outputs=[all_data_state, sample_selector]
    )
    # When the dataset is changed by the user
    dataset_selector.change(
        fn=handle_dataset_selection,
        inputs=[dataset_selector],
        outputs=[all_data_state, sample_selector]
    )
    # When a new sample is selected, trigger the main display update.
    # This also fires automatically after a dataset change repopulates the sample dropdown.
    sample_selector.change(
        fn=handle_sample_selection,
        inputs=[dataset_selector, sample_selector, all_data_state],
        outputs=[shared_info_display] + [item for sublist in vlm_outputs for item in sublist]
    )

if __name__ == "__main__":
    demo.launch(share=True, debug=True, allowed_paths=["/n/fs/vision-mix/ag9604/visualizer/"])