import argparse
import json
import os
import random
from base64 import b64decode
from io import BytesIO

import matplotlib.patches as patches
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import Dataset


def parse_args():
    """Parse command line arguments for the visualization script.

    Returns:
        argparse.Namespace: Parsed command line arguments containing:
            - img_tsv (str): Path to image TSV file
            - ann_tsv (str): Path to annotation TSV file
            - ann_lineidx (str): Path to annotation lineidx file
            - num_vis (int): Number of samples to visualize
            - output_dir (str): Output directory for visualization images
    """
    parser = argparse.ArgumentParser(
        description="Visualize human reference data with reasoning process"
    )
    parser.add_argument(
        "--img_tsv",
        type=str,
        default="IDEA-Research/HumanRef-CoT-45k/humanref_cot.images.tsv",
        help="Path to image TSV file",
    )
    parser.add_argument(
        "--ann_tsv",
        type=str,
        default="IDEA-Research/HumanRef-CoT-45k/humanref_cot.annotations.tsv",
        help="Path to annotation TSV file",
    )
    parser.add_argument(
        "--ann_lineidx",
        type=str,
        default="IDEA-Research/HumanRef-CoT-45k/humanref_cot.annotations.tsv.lineidx",
        help="Path to annotation lineidx file",
    )
    parser.add_argument(
        "--num_vis", type=int, default=50, help="Number of samples to visualize"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="vis/",
        help="Output directory for visualization images",
    )
    return parser.parse_args()


class TSVDataset(Dataset):
    """Dataset class for loading images and annotations from TSV files.

    This dataset handles loading of images and annotations from TSV-format files,
    where images are stored as base64-encoded strings and annotations as JSON.

    Args:
        img_tsv_file (str): Path to the TSV file containing images
        ann_tsv_file (str): Path to the TSV file containing annotations
        ann_lineidx_file (str): Path to the line index file for annotations

    Attributes:
        data (list): Byte offsets of annotation lines, in shuffled order
        img_handle (file): File handle for the image TSV file (opened lazily)
        ann_handle (file): File handle for the annotation TSV file (opened lazily)
        img_tsv_file (str): Path to image TSV file
        ann_tsv_file (str): Path to annotation TSV file
    """

    def __init__(self, img_tsv_file: str, ann_tsv_file: str, ann_lineidx_file: str):
        super(TSVDataset, self).__init__()
        # Read the byte offset of every annotation line and visit them in random order
        self.data = []
        with open(ann_lineidx_file) as f:
            for line in f:
                self.data.append(int(line.strip()))
        random.shuffle(self.data)

        # File handles are opened lazily on first access in __getitem__
        self.img_handle = None
        self.ann_handle = None
        self.img_tsv_file = img_tsv_file
        self.ann_tsv_file = ann_tsv_file

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Get a sample from the dataset.

        Args:
            idx (int): Index of the sample to retrieve

        Returns:
            tuple: (image, data_dict) where:
                - image (PIL.Image): RGB image
                - data_dict (dict): Dictionary containing:
                    - gt_boxes (list): List of bounding boxes [x0, y0, width, height]
                    - region_map (dict): Mapping from referring expressions to box indices
                    - think (str): Reasoning process text
        """
        # Seek directly to the annotation line using its byte offset
        ann_line_idx = self.data[idx]
        if self.ann_handle is None:
            self.ann_handle = open(self.ann_tsv_file)
        self.ann_handle.seek(ann_line_idx)
        img_line_idx, ann = self.ann_handle.readline().strip().split("\t")
        img_line_idx = int(img_line_idx)

        # Seek to the corresponding image line and decode the base64 payload
        if self.img_handle is None:
            self.img_handle = open(self.img_tsv_file)
        self.img_handle.seek(img_line_idx)
        img = self.img_handle.readline().strip().split("\t")[1]
        # Some rows store the payload as a stringified bytes literal, e.g. b'...';
        # strip that wrapper before decoding
        if img.startswith("b'"):
            img = img[2:-1]
        img = BytesIO(b64decode(img))
        image = Image.open(img).convert("RGB")

        data_dict = json.loads(ann)
        return image, data_dict
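

# The on-disk layout TSVDataset assumes, inferred from the parsing above
# (a sketch, not an official spec for the dataset):
#   humanref_cot.annotations.tsv          each line: <img_line_idx>\t<json annotation>
#   humanref_cot.annotations.tsv.lineidx  one byte offset per annotation line
#   humanref_cot.images.tsv               each line: <key>\t<base64-encoded image bytes>
# Reading one record without the Dataset wrapper would look roughly like:
#   with open("humanref_cot.annotations.tsv.lineidx") as f:
#       offsets = [int(line) for line in f]
#   with open("humanref_cot.annotations.tsv") as f:
#       f.seek(offsets[0])
#       img_line_idx, ann = f.readline().strip().split("\t")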


def visualize(image, data_dict, output_path="visualization.png"):
    """Visualize an image with bounding boxes and the reasoning process.

    The figure has two panels:
        - Left: the image with ground truth boxes (red) and answer boxes (green)
        - Right: the reasoning process text

    Args:
        image (PIL.Image): Input image to visualize
        data_dict (dict): Dictionary containing:
            - gt_boxes (list): List of bounding boxes [x0, y0, width, height]
            - region_map (dict): Mapping from referring expressions to box indices
            - think (str): Reasoning process text
        output_path (str, optional): Path to save the visualization.
            Defaults to "visualization.png".
    """
    # Create a figure with two subplots side by side
    plt.rcParams["figure.dpi"] = 300
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))

    # Display the image on the left subplot
    ax1.imshow(image)

    # Draw all ground truth boxes in red, labelled with their indices
    gt_boxes = data_dict.get("gt_boxes", [])
    for idx, box in enumerate(gt_boxes):
        x0, y0, width, height = box
        rect = patches.Rectangle(
            (x0, y0), width, height, linewidth=2, edgecolor="red", facecolor="none"
        )
        ax1.add_patch(rect)
        ax1.text(
            x0,
            y0 - 5,
            str(idx),
            color="red",
            fontsize=12,
            bbox=dict(facecolor="white", alpha=0.7),
        )

    # Draw the answer boxes referenced by region_map in green
    region_map = data_dict.get("region_map", {})
    for exp_idx, (referring_exp, answer_indices) in enumerate(region_map.items()):
        # Show the referring expression near the top of the image,
        # offsetting each one so multiple expressions do not overlap
        ax1.text(
            10,
            30 + 40 * exp_idx,
            referring_exp,
            color="blue",
            fontsize=12,
            bbox=dict(facecolor="white", alpha=0.7),
        )
        for idx in answer_indices:
            if idx < len(gt_boxes):
                x0, y0, width, height = gt_boxes[idx]
                rect = patches.Rectangle(
                    (x0, y0),
                    width,
                    height,
                    linewidth=3,
                    edgecolor="green",
                    facecolor="none",
                )
                ax1.add_patch(rect)

    # Remove axis ticks from the image panel
    ax1.set_xticks([])
    ax1.set_yticks([])
    ax1.set_title("Image with Bounding Boxes")

    # Show the reasoning text on the right subplot
    ax2.text(0.05, 0.95, data_dict.get("think", ""), wrap=True, fontsize=12, va="top")
    ax2.set_xticks([])
    ax2.set_yticks([])
    ax2.set_title("Reasoning Process")

    plt.tight_layout()
    plt.savefig(output_path, dpi=300)
    # Close the figure so repeated calls do not accumulate open figures
    plt.close(fig)


if __name__ == "__main__":
    args = parse_args()

    # Initialize the dataset
    dataset = TSVDataset(args.img_tsv, args.ann_tsv, args.ann_lineidx)

    vis_root = args.output_dir
    os.makedirs(vis_root, exist_ok=True)

    for i in range(args.num_vis):
        image, data_dict = dataset[i]
        # Save the visualization
        output_path = os.path.join(vis_root, f"visualization_{i}.png")
        visualize(image, data_dict, output_path)
        print(f"Visualization saved to {output_path}")
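
# Example invocation (the file name here is illustrative; the flags match parse_args above):
#   python visualize_humanref_cot.py \
#       --img_tsv IDEA-Research/HumanRef-CoT-45k/humanref_cot.images.tsv \
#       --ann_tsv IDEA-Research/HumanRef-CoT-45k/humanref_cot.annotations.tsv \
#       --ann_lineidx IDEA-Research/HumanRef-CoT-45k/humanref_cot.annotations.tsv.lineidx \
#       --num_vis 10 --output_dir vis/
# This writes vis/visualization_0.png ... vis/visualization_9.png, each showing the image
# with red ground truth boxes, green answer boxes, and the reasoning text alongside.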