# VLM_Comparison / visualizer1.py
import os
import json
import glob
import gradio as gr
from collections import defaultdict

# --- Configuration ---
# Base path where all dataset folders are located
BASE_DATA_DIRECTORY = "./"

# Names of the VLMs and their corresponding keys used in file names
VLM_MODELS = {
    "GPT-4o": "4o",
    "OpenAI o1": "o1",
    "Gemini 2.5 Pro": "gemini",
    "Qwen 2.5 VL": "qwen",
}
# Configuration for each dataset
DATASET_CONFIG = {
    "AITW": {
        "display_name": "AITW",
        "base_dir": os.path.join(BASE_DATA_DIRECTORY, ""),  # Base dir is the root for AITW
        "json_patterns": ["aitw_{model_key}_dataset.json", "aitw_{model_key}_dataset1.json"],
        "data_is_nested": True,  # The JSON is a dict of episodes, which contain steps
    },
    "Where2Place": {
        "display_name": "Where2Place",
        "base_dir": os.path.join(BASE_DATA_DIRECTORY, "where2place"),
        "json_patterns": ["where2place_mcq_{model_key}.json"],
    },
    "MONDAY": {
        "display_name": "MONDAY",
        "base_dir": os.path.join(BASE_DATA_DIRECTORY, "Monday"),
        "json_patterns": ["monday_mcq_test_{model_key}.json", "monday_mcq_test_unseen_os_{model_key}.json"],
    },
    "RoboVQA": {
        "display_name": "RoboVQA",
        "base_dir": os.path.join(BASE_DATA_DIRECTORY, "robovqa"),
        "json_patterns": ["robovqa_final_dataset_{model_key}.json"],
    },
}
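# For example, the MONDAY config for Qwen 2.5 VL expands to the glob patterns
# "./Monday/monday_mcq_test_qwen.json" and "./Monday/monday_mcq_test_unseen_os_qwen.json".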

# --- Data Loading and Processing ---
def load_data_for_dataset(dataset_key):
    """
    Loads and structures data for a given dataset from its JSON files.

    Returns a dictionary where keys are unique sample IDs and values are
    dictionaries mapping VLM model keys to their specific data for that sample,
    e.g., {'episode_123:step_0': {'4o': {...}, 'o1': {...}}, ...}
    """
    if dataset_key not in DATASET_CONFIG:
        return {}
    config = DATASET_CONFIG[dataset_key]
    unified_data = defaultdict(dict)
    print(f"Loading data for dataset: {dataset_key}")

    for display_name, model_key in VLM_MODELS.items():
        all_entries = []
        for pattern in config["json_patterns"]:
            # Construct the full file path pattern
            full_pattern = os.path.join(config["base_dir"], pattern.format(model_key=model_key))
            # Find all matching files
            json_files = glob.glob(full_pattern)
            for file_path in json_files:
                print(f" - Reading file: {file_path}")
                try:
                    with open(file_path, 'r', encoding='utf-8') as f:
                        data = json.load(f)
                    if isinstance(data, list):
                        all_entries.extend(data)
                    elif isinstance(data, dict):
                        # Handle AITW's nested structure
                        if config.get("data_is_nested"):
                            for episode_id, episode_data in data.items():
                                for step in episode_data.get("steps", []):
                                    # Add episode context to each step
                                    step_with_context = step.copy()
                                    step_with_context['episode_id'] = episode_id
                                    step_with_context['episode_goal'] = episode_data.get('episode_goal')
                                    all_entries.append(step_with_context)
                except FileNotFoundError:
                    print(f" - WARNING: File not found: {file_path}")
                except json.JSONDecodeError:
                    print(f" - WARNING: Could not decode JSON from: {file_path}")

        # Process loaded entries and add to the unified dictionary
        for i, entry in enumerate(all_entries):
            sample_id = None
            if dataset_key == "AITW":
                sample_id = f"{entry.get('episode_id', 'unknown_ep')}:{entry.get('step_id', 'unknown_step')}"
            elif dataset_key == "Where2Place":
                sample_id = f"q_{entry.get('question_id', i)}"
            elif dataset_key == "MONDAY":
                sample_id = f"{entry.get('episode_id', 'unknown_ep')}:{entry.get('step_id', i)}"
            elif dataset_key == "RoboVQA":
                sample_id = f"{entry.get('episode_id', i)}"
            if sample_id:
                unified_data[sample_id][model_key] = entry

    # Sort sample IDs for consistent ordering in the dropdown
    sorted_unified_data = {k: unified_data[k] for k in sorted(unified_data.keys())}
    print(f"Finished loading. Found {len(sorted_unified_data)} unique samples.")
    return sorted_unified_data
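
# A minimal usage sketch (the sample ID shown is hypothetical; real IDs depend
# on the JSON files present on disk):
#   data = load_data_for_dataset("AITW")
#   data["episode_123:step_0"]["4o"]  # -> GPT-4o's entry for that episode step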

def format_mcq_options(options, correct_index):
    """Formats MCQ options into a Markdown string, highlighting the correct one."""
    if not isinstance(options, list):
        return "Options not available."
    lines = []
    for i, option in enumerate(options):
        # Highlight the option whose (0-based) position matches correct_index.
        is_correct = (i == correct_index)
        prefix = "✅ **" if is_correct else ""
        suffix = "**" if is_correct else ""
        lines.append(f"- {prefix}{option}{suffix}")
    return "\n".join(lines)

# --- Gradio UI Application ---
with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 95% !important;}") as demo:
    gr.Markdown("# VLM Comparative Benchmark Visualizer")
    gr.Markdown("Select a dataset to load evaluation samples. The interface will display the same question/task evaluated across four different VLMs.")

    # --- State Management ---
    all_data_state = gr.State({})

    # --- UI Components ---
    with gr.Row():
        dataset_selector = gr.Dropdown(
            choices=list(DATASET_CONFIG.keys()),
            label="1. Select a Dataset",
            value="AITW",  # Default value
        )
        sample_selector = gr.Dropdown(
            label="2. Select a Sample / Episode Step",
            interactive=True,
            # Choices will be populated dynamically
        )
    shared_info_display = gr.Markdown(visible=False)  # For goal, common question, etc.

    with gr.Row(equal_height=False):
        vlm_outputs = []
        for vlm_display_name, vlm_key in VLM_MODELS.items():
            with gr.Column(scale=1):
                # gr.Group boxes each model's widgets together; a nested
                # gr.Blocks is not a valid layout container inside Blocks.
                with gr.Group():
gr.Markdown(f"### {vlm_display_name}")
media_display = gr.Image(label="Media", type="filepath", interactive=False, height=400)
info_display = gr.Markdown()
vlm_outputs.append((media_display, info_display))
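
    # Note: the flattened (image, info, image, info, ...) order of vlm_outputs
    # must match the `outputs` wiring in the event listeners below, since
    # handle_sample_selection zips components and updates in that order.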

    # --- UI Update Logic ---
    def handle_dataset_selection(dataset_key):
        """
        Triggered when a new dataset is selected.
        Loads all data for that dataset and populates the sample selector.
        """
        print(f"UI: Dataset selection changed to '{dataset_key}'")
        if not dataset_key:
            return {
                all_data_state: {},
                sample_selector: gr.update(choices=[], value=None),
            }
        data = load_data_for_dataset(dataset_key)
        sample_ids = list(data.keys())
        first_sample = sample_ids[0] if sample_ids else None
        return {
            all_data_state: data,
            sample_selector: gr.update(choices=sample_ids, value=first_sample, visible=True),
        }

    def handle_sample_selection(dataset_key, sample_id, all_data):
        """
        Triggered when a new sample is selected.
        Updates the four columns with the data for that sample.
        """
        print(f"UI: Sample selection changed to '{sample_id}'")
        if not sample_id or not all_data:
            # Create empty updates for all components if there's no selection
            updates = [gr.update(visible=False)] + [gr.update(value=None, visible=False)] * len(vlm_outputs) * 2
            return dict(zip([shared_info_display] + [item for sublist in vlm_outputs for item in sublist], updates))

        sample_data_for_all_vlms = all_data.get(sample_id, {})

        # --- 1. Update Shared Information Display ---
        shared_md_parts = []
        # Use data from the first available VLM to populate shared info
        first_vlm_key = next(iter(VLM_MODELS.values()))
        first_vlm_data = sample_data_for_all_vlms.get(first_vlm_key, {})
        if dataset_key == "AITW":
            shared_md_parts.append(f"**Goal:** `{first_vlm_data.get('episode_goal', 'N/A')}`")
            shared_md_parts.append(f"**Question:** *{first_vlm_data.get('questions', {}).get('question', 'N/A')}*")
        elif dataset_key == "MONDAY":
            shared_md_parts.append(f"**Goal:** `{first_vlm_data.get('goal', 'N/A')}`")
            shared_md_parts.append(f"**OS:** {first_vlm_data.get('os', 'N/A')}")
        elif dataset_key == "RoboVQA":
            shared_md_parts.append(f"**Task Type:** {first_vlm_data.get('task_type', 'N/A')}")
        # Where2Place has its question per-VLM, so no shared info is needed.
        shared_info_update = gr.update(value="\n\n".join(shared_md_parts), visible=bool(shared_md_parts))

        # --- 2. Update Each VLM Column ---
        column_updates = []
        config = DATASET_CONFIG[dataset_key]
        for vlm_display_name, vlm_key in VLM_MODELS.items():
            vlm_data = sample_data_for_all_vlms.get(vlm_key)
            if not vlm_data:
                column_updates.extend([gr.update(value=None, visible=True), gr.update(value="*Data not found for this sample.*")])
                continue

            # Find the image/media path
            media_path = None
            if dataset_key == "AITW":
                media_path = vlm_data.get('screenshot_path')
            elif dataset_key == "Where2Place":
                media_path = vlm_data.get('marked_image_path')
            elif dataset_key == "MONDAY":
                media_path = vlm_data.get('screenshot_path')
            elif dataset_key == "RoboVQA":
                media_path = vlm_data.get('media_path')

            # Construct an absolute path if the stored path is relative
            absolute_media_path = None
            if media_path:
                # The AITW paths are absolute; the others are relative.
                if os.path.isabs(media_path):
                    absolute_media_path = media_path
                else:
                    absolute_media_path = os.path.join(config['base_dir'], media_path)

            # Build the Markdown content for the info box
            md_content = []
            if dataset_key == "AITW":
                md_content.append(f"**Action History:**\n```\n{vlm_data.get('action_history', 'None')}\n```")
                options = vlm_data.get('questions', {}).get('options')
                answer_idx = vlm_data.get('questions', {}).get('correct_answer_index')
                md_content.append(format_mcq_options(options, answer_idx))
            elif dataset_key == "Where2Place":
                md_content.append(f"**Question:** *{vlm_data.get('question', 'N/A')}*")
                options = vlm_data.get('options')
                answer_idx = vlm_data.get('answer')
                md_content.append(format_mcq_options(options, answer_idx))
            elif dataset_key == "MONDAY":
                md_content.append(f"**Question:** *{vlm_data.get('current_question', 'N/A')}*")
                md_content.append(f"**Action History:**\n```\n{vlm_data.get('action_history', 'None')}\n```")
                options = vlm_data.get('options')
                answer_idx = vlm_data.get('answer')
                md_content.append(format_mcq_options(options, answer_idx))
            elif dataset_key == "RoboVQA":
                md_content.append(f"**Question:** *{vlm_data.get('question', 'N/A')}*")
                options = vlm_data.get('options')
                answer_idx = vlm_data.get('answer')
                md_content.append(format_mcq_options(options, answer_idx))

            image_update = gr.update(value=absolute_media_path if absolute_media_path and os.path.exists(absolute_media_path) else None, visible=True)
            info_update = gr.update(value="\n\n".join(md_content))
            column_updates.extend([image_update, info_update])

        # Combine all updates into a single dictionary to return
        output_components = [shared_info_display] + [item for sublist in vlm_outputs for item in sublist]
        return dict(zip(output_components, [shared_info_update] + column_updates))
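
    # Returning a {component: update} dict lets Gradio match each update to its
    # component by identity rather than by position in the outputs list.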

    # --- Event Listeners ---
    # When the app loads, trigger the dataset selection change to load the default dataset
    demo.load(
        fn=handle_dataset_selection,
        inputs=[dataset_selector],
        outputs=[all_data_state, sample_selector],
    )
    # When the dataset is changed by the user
    dataset_selector.change(
        fn=handle_dataset_selection,
        inputs=[dataset_selector],
        outputs=[all_data_state, sample_selector],
    )
    # When a new sample is selected, trigger the main display update.
    # This also fires automatically after a dataset change updates the sample dropdown.
    sample_selector.change(
        fn=handle_sample_selection,
        inputs=[dataset_selector, sample_selector, all_data_state],
        outputs=[shared_info_display] + [item for sublist in vlm_outputs for item in sublist],
    )
if __name__ == "__main__":
demo.launch(share=True, debug=True, allowed_paths=["/n/fs/vision-mix/ag9604/visualizer/"])