import os
import json
import glob
import gradio as gr
from collections import defaultdict
# --- Configuration ---
# Base path where all dataset folders are located
BASE_DATA_DIRECTORY = "./"
# Names of the VLMs and their corresponding keys used in file names
VLM_MODELS = {
"GPT-4o": "4o",
"OpenAI o1": "o1",
"Gemini 2.5 Pro": "gemini",
"Qwen 2.5 VL": "qwen"
}
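
# NOTE: the short model keys above are substituted into the "{model_key}"
# placeholder of each dataset's "json_patterns" below, so renaming a key here
# must be matched by renaming the corresponding result files on disk.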
# Configuration for each dataset
DATASET_CONFIG = {
"AITW": {
"display_name": "AITW",
"base_dir": os.path.join(BASE_DATA_DIRECTORY, ""), # Base dir is the root for aitw
"json_patterns": ["aitw_{model_key}_dataset.json", "aitw_{model_key}_dataset1.json"],
"data_is_nested": True, # The JSON is a dict of episodes, which contain steps
},
"Where2Place": {
"display_name": "Where2Place",
"base_dir": os.path.join(BASE_DATA_DIRECTORY, "where2place"),
"json_patterns": ["where2place_mcq_{model_key}.json"],
},
"MONDAY": {
"display_name": "MONDAY",
"base_dir": os.path.join(BASE_DATA_DIRECTORY, "Monday"),
"json_patterns": ["monday_mcq_test_{model_key}.json", "monday_mcq_test_unseen_os_{model_key}.json"],
},
"RoboVQA": {
"display_name": "RoboVQA",
"base_dir": os.path.join(BASE_DATA_DIRECTORY, "robovqa"),
"json_patterns": ["robovqa_final_dataset_{model_key}.json"],
}
}
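
# For example, for GPT-4o (model key "4o") the AITW patterns above expand to
# "./aitw_4o_dataset.json" and "./aitw_4o_dataset1.json"; glob simply returns
# an empty list for patterns that match no file, so missing result files are
# tolerated rather than raising an error.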
# --- Data Loading and Processing ---
def load_data_for_dataset(dataset_key):
"""
Loads and structures data for a given dataset from its JSON files.
Returns a dictionary where keys are unique sample IDs and values are
dictionaries mapping VLM model keys to their specific data for that sample.
e.g., {'episode_123:step_0': {'4o': {...}, 'o1': {...}}, ...}
"""
if dataset_key not in DATASET_CONFIG:
return {}
config = DATASET_CONFIG[dataset_key]
unified_data = defaultdict(dict)
print(f"Loading data for dataset: {dataset_key}")
for display_name, model_key in VLM_MODELS.items():
all_entries = []
for pattern in config["json_patterns"]:
# Construct the full file path pattern
full_pattern = os.path.join(config["base_dir"], pattern.format(model_key=model_key))
# Find all matching files
json_files = glob.glob(full_pattern)
for file_path in json_files:
print(f" - Reading file: {file_path}")
try:
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
if isinstance(data, list):
all_entries.extend(data)
elif isinstance(data, dict):
# Handle AITW's nested structure
if config.get("data_is_nested"):
for episode_id, episode_data in data.items():
for step in episode_data.get("steps", []):
# Add episode context to each step
step_with_context = step.copy()
step_with_context['episode_id'] = episode_id
step_with_context['episode_goal'] = episode_data.get('episode_goal')
all_entries.append(step_with_context)
except FileNotFoundError:
print(f" - WARNING: File not found: {file_path}")
except json.JSONDecodeError:
print(f" - WARNING: Could not decode JSON from: {file_path}")
# Process loaded entries and add to the unified dictionary
for i, entry in enumerate(all_entries):
sample_id = None
if dataset_key == "AITW":
sample_id = f"{entry.get('episode_id', 'unknown_ep')}:{entry.get('step_id', 'unknown_step')}"
elif dataset_key == "Where2Place":
sample_id = f"q_{entry.get('question_id', i)}"
elif dataset_key == "MONDAY":
sample_id = f"{entry.get('episode_id', 'unknown_ep')}:{entry.get('step_id', i)}"
elif dataset_key == "RoboVQA":
sample_id = f"{entry.get('episode_id', i)}"
if sample_id:
unified_data[sample_id][model_key] = entry
# Sort sample IDs for consistent ordering in the dropdown
sorted_unified_data = {k: unified_data[k] for k in sorted(unified_data.keys())}
print(f"Finished loading. Found {len(sorted_unified_data)} unique samples.")
return sorted_unified_data
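
# NOTE: the sort above is lexicographic on the ID strings, so e.g.
# "ep:step_10" orders before "ep:step_2". This keeps the dropdown order
# stable across reloads but is not strictly numeric step order.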
def format_mcq_options(options, correct_index):
"""Formats MCQ options into a Markdown string, highlighting the correct one."""
if not isinstance(options, list):
return "Options not available."
lines = []
for i, option in enumerate(options):
# The correct answer in JSON can be 1-based or 0-based index. Check both.
is_correct = (i == correct_index)
prefix = "✅ **" if is_correct else ""
suffix = "**" if is_correct else ""
lines.append(f"- {prefix}{option}{suffix}")
return "\n".join(lines)
# --- Gradio UI Application ---
with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 95% !important;}") as demo:
gr.Markdown("# VLM Comparative Benchmark Visualizer")
gr.Markdown("Select a dataset to load evaluation samples. The interface will display the same question/task evaluated across four different VLMs.")
# --- State Management ---
all_data_state = gr.State({})
# --- UI Components ---
with gr.Row():
dataset_selector = gr.Dropdown(
choices=list(DATASET_CONFIG.keys()),
label="1. Select a Dataset",
value="AITW" # Default value
)
sample_selector = gr.Dropdown(
label="2. Select a Sample / Episode Step",
interactive=True,
# Choices will be populated dynamically
)
shared_info_display = gr.Markdown(visible=False) # For goal, common question, etc.
with gr.Row(equal_height=False):
vlm_outputs = []
for vlm_display_name, vlm_key in VLM_MODELS.items():
with gr.Column(scale=1):
with gr.Blocks():
gr.Markdown(f"### {vlm_display_name}")
media_display = gr.Image(label="Media", type="filepath", interactive=False, height=400)
info_display = gr.Markdown()
vlm_outputs.append((media_display, info_display))
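
    # vlm_outputs holds one (media_display, info_display) pair per model, in
    # VLM_MODELS order. The handlers below flatten this list, so any update
    # list they build must follow the same (image, markdown) interleaving.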
    # --- UI Update Logic ---
    def handle_dataset_selection(dataset_key):
        """
        Triggered when a new dataset is selected.
        Loads all data for that dataset and populates the sample selector.
        """
        print(f"UI: Dataset selection changed to '{dataset_key}'")
        if not dataset_key:
            return {
                all_data_state: {},
                sample_selector: gr.update(choices=[], value=None),
            }
        data = load_data_for_dataset(dataset_key)
        sample_ids = list(data.keys())
        first_sample = sample_ids[0] if sample_ids else None
        return {
            all_data_state: data,
            sample_selector: gr.update(choices=sample_ids, value=first_sample, visible=True),
        }
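
    # Gradio lets a handler return a {component: update} dict as long as every
    # key is listed in the event binding's `outputs`; both handlers here rely
    # on that to avoid depending on positional return order.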
    def handle_sample_selection(dataset_key, sample_id, all_data):
        """
        Triggered when a new sample is selected.
        Updates the four columns with the data for that sample.
        """
        print(f"UI: Sample selection changed to '{sample_id}'")
        if not sample_id or not all_data:
            # Create empty updates for all components if there's no selection
            updates = [gr.update(visible=False)] + [gr.update(value=None, visible=False)] * len(vlm_outputs) * 2
            return dict(zip([shared_info_display] + [item for sublist in vlm_outputs for item in sublist], updates))

        sample_data_for_all_vlms = all_data.get(sample_id, {})

        # --- 1. Update Shared Information Display ---
        shared_md_parts = []
        # Use data from the first available VLM to populate shared info
        first_vlm_key = next(iter(VLM_MODELS.values()))
        first_vlm_data = sample_data_for_all_vlms.get(first_vlm_key, {})
        if dataset_key == "AITW":
            shared_md_parts.append(f"**Goal:** `{first_vlm_data.get('episode_goal', 'N/A')}`")
            shared_md_parts.append(f"**Question:** *{first_vlm_data.get('questions', {}).get('question', 'N/A')}*")
        elif dataset_key == "MONDAY":
            shared_md_parts.append(f"**Goal:** `{first_vlm_data.get('goal', 'N/A')}`")
            shared_md_parts.append(f"**OS:** {first_vlm_data.get('os', 'N/A')}")
        elif dataset_key == "RoboVQA":
            shared_md_parts.append(f"**Task Type:** {first_vlm_data.get('task_type', 'N/A')}")
        # Where2Place has its question per-VLM, so no shared info needed.
        shared_info_update = gr.update(value="\n\n".join(shared_md_parts), visible=bool(shared_md_parts))

        # --- 2. Update Each VLM Column ---
        column_updates = []
        config = DATASET_CONFIG[dataset_key]
        for vlm_display_name, vlm_key in VLM_MODELS.items():
            vlm_data = sample_data_for_all_vlms.get(vlm_key)
            if not vlm_data:
                column_updates.extend([gr.update(value=None, visible=True), gr.update(value="*Data not found for this sample.*")])
                continue

            # Find image/media path
            media_path = None
            if dataset_key == "AITW":
                media_path = vlm_data.get('screenshot_path')
            elif dataset_key == "Where2Place":
                media_path = vlm_data.get('marked_image_path')
            elif dataset_key == "MONDAY":
                media_path = vlm_data.get('screenshot_path')
            elif dataset_key == "RoboVQA":
                media_path = vlm_data.get('media_path')

            # Construct absolute path if relative
            absolute_media_path = None
            if media_path:
                # The AITW paths are absolute, others are relative.
                if os.path.isabs(media_path):
                    absolute_media_path = media_path
                else:
                    absolute_media_path = os.path.join(config['base_dir'], media_path)

            # Build the markdown content for the info box
            md_content = []
            if dataset_key == "AITW":
                md_content.append(f"**Action History:**\n```\n{vlm_data.get('action_history', 'None')}\n```")
                options = vlm_data.get('questions', {}).get('options')
                answer_idx = vlm_data.get('questions', {}).get('correct_answer_index')
                md_content.append(format_mcq_options(options, answer_idx))
            elif dataset_key == "Where2Place":
                md_content.append(f"**Question:** *{vlm_data.get('question', 'N/A')}*")
                options = vlm_data.get('options')
                answer_idx = vlm_data.get('answer')
                md_content.append(format_mcq_options(options, answer_idx))
            elif dataset_key == "MONDAY":
                md_content.append(f"**Question:** *{vlm_data.get('current_question', 'N/A')}*")
                md_content.append(f"**Action History:**\n```\n{vlm_data.get('action_history', 'None')}\n```")
                options = vlm_data.get('options')
                answer_idx = vlm_data.get('answer')
                md_content.append(format_mcq_options(options, answer_idx))
            elif dataset_key == "RoboVQA":
                md_content.append(f"**Question:** *{vlm_data.get('question', 'N/A')}*")
                options = vlm_data.get('options')
                answer_idx = vlm_data.get('answer')
                md_content.append(format_mcq_options(options, answer_idx))

            image_update = gr.update(value=absolute_media_path if absolute_media_path and os.path.exists(absolute_media_path) else None, visible=True)
            info_update = gr.update(value="\n\n".join(md_content))
            column_updates.extend([image_update, info_update])

        # Combine all updates into a single dictionary to return
        output_components = [shared_info_display] + [item for sublist in vlm_outputs for item in sublist]
        return dict(zip(output_components, [shared_info_update] + column_updates))
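
    # Both the dict returned by handle_sample_selection and the `outputs`
    # list of the sample_selector.change binding below enumerate the same
    # components: shared_info_display first, then each VLM's (image, info)
    # pair in VLM_MODELS order.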
    # --- Event Listeners ---

    # When the app loads, trigger the dataset selection change to load the default dataset
    demo.load(
        fn=handle_dataset_selection,
        inputs=[dataset_selector],
        outputs=[all_data_state, sample_selector]
    )

    # When the dataset is changed by the user
    dataset_selector.change(
        fn=handle_dataset_selection,
        inputs=[dataset_selector],
        outputs=[all_data_state, sample_selector]
    )

    # When a new sample is selected, trigger the main display update.
    # This also gets triggered automatically after the dataset selection changes the sample dropdown.
    sample_selector.change(
        fn=handle_sample_selection,
        inputs=[dataset_selector, sample_selector, all_data_state],
        outputs=[shared_info_display] + [item for sublist in vlm_outputs for item in sublist]
    )
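
# NOTE: Gradio refuses to serve files outside its allowed locations, so the
# absolute media paths used by AITW must sit under a directory whitelisted
# via `allowed_paths` in launch() below.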
if __name__ == "__main__":
    demo.launch(share=True, debug=True, allowed_paths=["/n/fs/vision-mix/ag9604/visualizer/"])