Mountchicken committed on
Commit bf00d99 · verified · 1 parent: 37431ac

Upload 3 files

.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  groundingdino/_C.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+ tools/Tahoma.ttf filter=lfs diff=lfs merge=lfs -text
tools/Tahoma.ttf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:359413e76969fc8a03e0acf91b355a98bb13c42472614e54bff5c8e4f4817fbb
+ size 681120
tools/inference_tools.py ADDED
@@ -0,0 +1,406 @@
+ import re
+ from typing import Any, Dict, List, Optional, Union
+
+ import groundingdino.datasets.transforms as T
+ import numpy as np
+ import torch
+ import torchvision.transforms.functional as F
+ from groundingdino.util.inference import load_model, predict
+ from PIL import Image, ImageDraw, ImageFont
+ from qwen_vl_utils import process_vision_info, smart_resize
+
+
+ class ColorGenerator:
+     """A class for generating consistent colors for visualization.
+
+     This class provides methods to generate colors either consistently for all elements
+     or based on text content for better visual distinction.
+
+     Args:
+         color_type (str): Type of color generation strategy. Can be either "same" for a
+             consistent color or "text" for text-based color generation.
+     """
+
+     def __init__(self, color_type: str) -> None:
+         self.color_type = color_type
+
+         if color_type == "same":
+             self.color = tuple((np.random.randint(0, 127, size=3) + 128).tolist())
+         elif color_type == "text":
+             np.random.seed(3396)
+             self.num_colors = 300
+             self.colors = np.random.randint(0, 127, size=(self.num_colors, 3)) + 128
+         else:
+             raise ValueError(f"Unsupported color_type: {color_type}")
+
+     def get_color(self, text):
+         """Get a color based on the text content or return a consistent color.
+
+         Args:
+             text (str): The text to generate a color for.
+
+         Returns:
+             tuple: RGB color values as a tuple.
+
+         Raises:
+             ValueError: If color_type is not supported.
+         """
+         if self.color_type == "same":
+             return self.color
+
+         if self.color_type == "text":
+             # hash() is salted per process (PYTHONHASHSEED), so colors are stable
+             # within a run but not necessarily across runs.
+             text_hash = hash(text)
+             index = text_hash % self.num_colors
+             color = tuple(self.colors[index])
+             return color
+
+         raise ValueError(f"Unsupported color_type: {self.color_type}")
+
+
+ def visualize(
+     image_pil: Image.Image,
+     boxes,
+     scores,
+     labels=None,
+     filter_score=-1,
+     topN=900,
+     font_size=15,
+     draw_width: int = 6,
+     draw_index: bool = True,
+ ) -> Image.Image:
+     """Visualize bounding boxes and labels on an image.
+
+     This function draws bounding boxes and their corresponding labels on the input image.
+     It supports filtering by score, limiting the number of boxes, and customizing the
+     visualization appearance.
+
+     Args:
+         image_pil (PIL.Image.Image): The input image to draw on.
+         boxes (List[List[float]]): List of bounding boxes in [x1, y1, x2, y2] format.
+         scores (List[float]): Confidence scores for each bounding box.
+         labels (List[str], optional): Labels for each bounding box. Defaults to None.
+         filter_score (float, optional): Minimum score threshold for visualization. Defaults to -1.
+         topN (int, optional): Maximum number of boxes to visualize. Defaults to 900.
+         font_size (int, optional): Font size for labels. Defaults to 15.
+         draw_width (int, optional): Width of bounding box lines. Defaults to 6.
+         draw_index (bool, optional): Whether to draw index numbers for unlabeled boxes. Defaults to True.
+
+     Returns:
+         PIL.Image.Image: The image with visualized bounding boxes and labels.
+     """
+     font_path = "tools/Tahoma.ttf"
+     font = ImageFont.truetype(font_path, font_size)
+     # Create a PIL ImageDraw object to draw on the input image
+     draw = ImageDraw.Draw(image_pil)
+     boxes = boxes[:topN]
+     scores = scores[:topN]
+     # Draw each box with its label (or its index, when no label is given)
+     box_idx = 1
+     color_generator = ColorGenerator("text")
+     if labels is None:
+         labels = [""] * len(boxes)
+     for box, score, label in zip(boxes, scores, labels):
+         if score < filter_score:
+             continue
+         # Extract the box coordinates and round them to integer pixels
+         x0, y0, x1, y1 = box
+         x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
+
+         if draw_index and label == "":
+             text = str(box_idx) + f" {label}"
+         else:
+             text = str(label)
+         # Wrap long labels at a fixed number of words per line
+         max_words_per_line = 10
+         words = text.split()
+         lines = []
+         line = ""
+         for word in words:
+             if len(line.split()) < max_words_per_line:
+                 line += word + " "
+             else:
+                 lines.append(line)
+                 line = word + " "
+         lines.append(line)
+         text = "\n".join(lines)
+
+         draw.rectangle(
+             [x0, y0, x1, y1], outline=color_generator.get_color(text), width=draw_width
+         )
+
+         bbox = draw.textbbox((x0, y0), text, font)
+         box_h = bbox[3] - bbox[1]
+
+         # Place the label above the box, or inside it when there is no room above
+         y0_text = y0 - box_h - (draw_width * 2)
+         y1_text = y0 + draw_width
+         box_idx += 1
+         if y0_text < 0:
+             y0_text = 0
+             y1_text = y0 + 2 * draw_width + box_h
+         draw.rectangle(
+             [x0, y0_text, bbox[2] + draw_width * 2, y1_text],
+             fill=color_generator.get_color(text),
+         )
+         draw.text(
+             (x0 + draw_width, y0_text),
+             str(text),
+             fill="black",
+             font=font,
+         )
+     return image_pil
+
+
+ def compute_iou(box1, box2):
+     """Compute Intersection over Union (IoU) between two bounding boxes.
+
+     Args:
+         box1 (List[float]): First bounding box in [x1, y1, x2, y2] format.
+         box2 (List[float]): Second bounding box in [x1, y1, x2, y2] format.
+
+     Returns:
+         float: IoU score between 0 and 1.
+     """
+     x1 = max(box1[0], box2[0])
+     y1 = max(box1[1], box2[1])
+     x2 = min(box1[2], box2[2])
+     y2 = min(box1[3], box2[3])
+
+     inter_area = max(0, x2 - x1) * max(0, y2 - y1)
+     if inter_area == 0:
+         return 0.0
+
+     box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
+     box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
+
+     union_area = box1_area + box2_area - inter_area
+     return inter_area / union_area
+
+
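+ # Worked example for compute_iou (illustrative, not part of the committed file):
+ # box1 = [0, 0, 2, 2] and box2 = [1, 0, 3, 2] intersect over [1, 0, 2, 2] with
+ # area 2; the union is 4 + 4 - 2 = 6, so compute_iou(box1, box2) = 2/6 ≈ 0.333.
+
+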
+ def return_maximum_overlap(gt_box, candidate_boxes, min_iou=0.5):
+     """Find the best matching box from candidate boxes based on IoU.
+
+     Args:
+         gt_box (List[float]): Ground truth bounding box in [x1, y1, x2, y2] format.
+         candidate_boxes (List[List[float]]): List of candidate bounding boxes.
+         min_iou (float, optional): Minimum IoU threshold for matching. Defaults to 0.5.
+
+     Returns:
+         int or None: Index of the best matching box if IoU >= min_iou, None otherwise.
+     """
+     max_iou = 0.0
+     best_box = None
+     for i, box in enumerate(candidate_boxes):
+         iou = compute_iou(gt_box, box)
+         if iou >= min_iou and iou > max_iou:
+             max_iou = iou
+             best_box = i
+     return best_box
+
+
+ def find_best_matched_index(group1, group2):
+     """Find the best matching indices between two groups of bounding boxes.
+
+     Args:
+         group1 (List[List[float]]): First group of bounding boxes.
+         group2 (List[List[float]]): Second group of bounding boxes.
+
+     Returns:
+         List[int]: 1-based indices of the best match in group2 for each box in
+             group1, or 0 when no candidate reaches the IoU threshold.
+     """
+     labels = []
+     for box in group1:
+         best_box = return_maximum_overlap(box, group2)
+         # return_maximum_overlap returns None when nothing clears min_iou;
+         # use 0 as an explicit "no match" marker instead of crashing on None + 1.
+         labels.append(best_box + 1 if best_box is not None else 0)
+     return labels
+
+
+ def gdino_load_image(image: Union[str, Image.Image]) -> torch.Tensor:
+     """Load and transform an image for the Grounding DINO model.
+
+     Args:
+         image (Union[str, Image.Image]): Input image path or PIL Image.
+
+     Returns:
+         torch.Tensor: Transformed image tensor ready for model input.
+     """
+     transform = T.Compose(
+         [
+             T.RandomResize([800], max_size=1333),
+             T.ToTensor(),
+             T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+         ]
+     )
+     if isinstance(image, str):
+         image_source = Image.open(image).convert("RGB")
+     else:
+         image_source = image
+     image_transformed, _ = transform(image_source, None)
+     return image_transformed
+
+
+ def inference_gdino(
+     image: Image.Image,
+     prompts: List[str],
+     gdino_model: Any,
+     TEXT_TRESHOLD: float = 0.25,
+     BOX_TRESHOLD: float = 0.25,
+ ) -> List[List[float]]:
+     """Process an image with the Grounding DINO model to detect objects.
+
+     Args:
+         image (Image.Image): Input PIL image.
+         prompts (List[str]): List of text prompts for object detection.
+         gdino_model (Any): The Grounding DINO model instance.
+         TEXT_TRESHOLD (float, optional): Text confidence threshold. Defaults to 0.25.
+         BOX_TRESHOLD (float, optional): Box confidence threshold. Defaults to 0.25.
+
+     Returns:
+         List[List[float]]: Detected bounding boxes in [x1, y1, x2, y2] format,
+             in absolute pixel coordinates.
+     """
+     # Grounding DINO expects a single caption with "."-separated phrases
+     text_labels = ".".join(prompts)
+     image_transformed = gdino_load_image(image)
+     boxes, _, _ = predict(
+         model=gdino_model,
+         image=image_transformed,
+         caption=text_labels,
+         box_threshold=BOX_TRESHOLD,
+         text_threshold=TEXT_TRESHOLD,
+     )
+     # predict() returns boxes as (cx, cy, w, h), normalized to [0, 1]
+     boxes = boxes * torch.tensor([image.width, image.height, image.width, image.height])
+     # convert to (x1, y1, x2, y2)
+     boxes = torch.cat(
+         (boxes[:, :2] - boxes[:, 2:4] / 2, boxes[:, :2] + boxes[:, 2:4] / 2), dim=1
+     )
+     return boxes.tolist()
+
+
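+ # Illustrative usage sketch (not part of the committed file). The config and
+ # checkpoint paths below are placeholders; point them at a real Grounding DINO
+ # setup. `load_model` is imported above from groundingdino.util.inference.
+ #
+ #     gdino_model = load_model(
+ #         "groundingdino/config/GroundingDINO_SwinT_OGC.py",  # hypothetical path
+ #         "weights/groundingdino_swint_ogc.pth",  # hypothetical path
+ #     )
+ #     image = Image.open("demo.jpg").convert("RGB")  # hypothetical image
+ #     boxes = inference_gdino(image, ["person"], gdino_model)
+ #     vis = visualize(image.copy(), boxes, [1.0] * len(boxes))
+ #     vis.save("demo_vis.png")
+
+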
+ def convert_boxes_from_absolute_to_qwen25_format(gt_boxes, ori_width, ori_height):
+     """Convert bounding boxes from absolute coordinates to the Qwen2.5-VL input space.
+
+     This function resizes bounding boxes according to Qwen2.5-VL's smart_resize
+     requirements while maintaining aspect ratio and pixel constraints.
+
+     Args:
+         gt_boxes (List[List[float]]): List of bounding boxes in absolute coordinates.
+         ori_width (int): Original image width.
+         ori_height (int): Original image height.
+
+     Returns:
+         List[List[int]]: Bounding boxes resized into the smart_resize'd pixel grid.
+     """
+     resized_height, resized_width = smart_resize(
+         ori_height,
+         ori_width,
+         28,  # patch-size factor: each side is snapped to a multiple of 28
+         min_pixels=16 * 28 * 28,
+         max_pixels=1280 * 28 * 28,
+     )
+     resized_gt_boxes = []
+     for box in gt_boxes:
+         # Scale the box into the resized coordinate space
+         x0, y0, x1, y1 = box
+         x0 = int(x0 / ori_width * resized_width)
+         x1 = int(x1 / ori_width * resized_width)
+         y0 = int(y0 / ori_height * resized_height)
+         y1 = int(y1 / ori_height * resized_height)
+
+         # Clamp to the valid pixel range
+         x0 = max(0, min(x0, resized_width - 1))
+         y0 = max(0, min(y0, resized_height - 1))
+         x1 = max(0, min(x1, resized_width - 1))
+         y1 = max(0, min(y1, resized_height - 1))
+         resized_gt_boxes.append([x0, y0, x1, y1])
+     return resized_gt_boxes
+
+
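+ # Illustrative example (not from the committed file): for a 1000x800 image,
+ # smart_resize snaps each side to a multiple of 28 within the pixel budget,
+ # and the boxes are scaled into that resized coordinate space accordingly.
+ #
+ #     boxes_abs = [[100.0, 50.0, 400.0, 300.0]]
+ #     boxes_qwen = convert_boxes_from_absolute_to_qwen25_format(
+ #         boxes_abs, ori_width=1000, ori_height=800
+ #     )
+ #     print(boxes_qwen)  # boxes expressed in the smart_resize'd pixel grid
+
+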
+ def parse_json(json_output):
+     """Extract coordinate arrays from a (possibly JSON-formatted) output string.
+
+     Args:
+         json_output (str): Model output string containing coordinate arrays.
+
+     Returns:
+         List[List[float]]: List of parsed coordinate arrays.
+     """
+     # Match every bracketed list of comma-separated numbers, e.g. [1, 2, 3, 4]
+     pattern = r"\[([0-9\.]+(?:, ?[0-9\.]+)*)\]"
+
+     matches = re.findall(pattern, json_output)
+     coordinates = [
+         [float(num) if "." in num else int(num) for num in match.split(",")]
+         for match in matches
+     ]
+
+     return coordinates
+
+
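+ # Example (illustrative): parse_json pulls every bracketed number list out of
+ # the raw model output, so markdown fences and JSON keys around it are ignored.
+ #
+ #     out = '```json\n[{"bbox_2d": [103, 24, 194, 301]}]\n```'
+ #     print(parse_json(out))  # [[103, 24, 194, 301]]
+
+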
+ def postprocess_and_vis_inference_out(
+     target_image,
+     answer,
+     proposed_box,
+     gdino_boxes,
+     font_size,
+     draw_width,
+     input_height,
+     input_width,
+ ):
+     """Post-process inference results and create visualizations.
+
+     This function parses the model output, matches the parsed boxes against the
+     Grounding DINO results, and creates visualization images.
+
+     Args:
+         target_image (PIL.Image): Target image for visualization.
+         answer (str): Model output containing box coordinates.
+         proposed_box (List[List[float]] or None): Proposed bounding boxes.
+         gdino_boxes (List[List[float]]): Grounding DINO detected boxes.
+         font_size (int): Font size for visualization.
+         draw_width (int): Line width for visualization.
+         input_height (torch.Tensor): Scalar tensor holding the model-input image height.
+         input_width (torch.Tensor): Scalar tensor holding the model-input image width.
+
+     Returns:
+         Tuple[PIL.Image, PIL.Image]: Two visualization images - one for the parsed
+             reference boxes and one for the Grounding DINO boxes.
+     """
+     if proposed_box is None:
+         return target_image, target_image
+
+     w, h = target_image.size
+     json_output = parse_json(answer)
+     final_boxes = []
+     input_height = input_height.item()
+     input_width = input_width.item()
+     # Rescale boxes from the model-input resolution back to the original image
+     for box in json_output:
+         x0, y0, x1, y1 = box
+         x0 = x0 / input_width * w
+         y0 = y0 / input_height * h
+         x1 = x1 / input_width * w
+         y1 = y1 / input_height * h
+
+         final_boxes.append([x0, y0, x1, y1])
+
+     # Find the best-matched Grounding DINO box for each parsed box
+     ref_labels = find_best_matched_index(final_boxes, gdino_boxes)
+
+     print("ref_labels", ref_labels)
+     ref_vis_result = visualize(
+         target_image.copy(),
+         final_boxes,
+         np.ones(len(final_boxes)),
+         labels=ref_labels,
+         font_size=font_size,
+         draw_width=draw_width,
+     )
+     dinox_vis_result = visualize(
+         target_image.copy(),
+         gdino_boxes,
+         np.ones(len(gdino_boxes)),
+         font_size=font_size,
+         draw_width=draw_width,
+     )
+     return ref_vis_result, dinox_vis_result
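+
+
+ # Illustrative end-to-end sketch (not from the committed file): tie the model
+ # answer back to pixel space. `answer` is the raw text from the grounding model;
+ # input_height/input_width are the smart_resize'd dims the model saw, passed as
+ # scalar tensors because the function calls .item() on them.
+ #
+ #     in_h, in_w = smart_resize(image.height, image.width, 28,
+ #                               min_pixels=16 * 28 * 28, max_pixels=1280 * 28 * 28)
+ #     ref_vis, gdino_vis = postprocess_and_vis_inference_out(
+ #         image, answer, proposed_box, gdino_boxes,
+ #         font_size=15, draw_width=6,
+ #         input_height=torch.tensor(in_h), input_width=torch.tensor(in_w),
+ #     )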
tools/visualize_humanref_cot.py ADDED
@@ -0,0 +1,238 @@
+ import argparse
+ import json
+ import os
+ import random
+ from base64 import b64decode
+ from io import BytesIO
+
+ import matplotlib.patches as patches
+ import matplotlib.pyplot as plt
+ from PIL import Image
+ from torch.utils.data import Dataset
+
+
+ def parse_args():
+     """Parse command line arguments for the visualization script.
+
+     Returns:
+         argparse.Namespace: Parsed command line arguments containing:
+             - img_tsv (str): Path to image TSV file
+             - ann_tsv (str): Path to annotation TSV file
+             - ann_lineidx (str): Path to annotation lineidx file
+             - num_vis (int): Number of samples to visualize
+             - output_dir (str): Output directory for visualization images
+     """
+     parser = argparse.ArgumentParser(
+         description="Visualize human reference data with reasoning process"
+     )
+     parser.add_argument(
+         "--img_tsv",
+         type=str,
+         default="IDEA-Research/HumanRef-CoT-45k/humanref_cot.images.tsv",
+         help="Path to image TSV file",
+     )
+     parser.add_argument(
+         "--ann_tsv",
+         type=str,
+         default="IDEA-Research/HumanRef-CoT-45k/humanref_cot.annotations.tsv",
+         help="Path to annotation TSV file",
+     )
+     parser.add_argument(
+         "--ann_lineidx",
+         type=str,
+         default="IDEA-Research/HumanRef-CoT-45k/humanref_cot.annotations.tsv.lineidx",
+         help="Path to annotation lineidx file",
+     )
+     parser.add_argument(
+         "--num_vis", type=int, default=50, help="Number of samples to visualize"
+     )
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         default="vis/",
+         help="Output directory for visualization images",
+     )
+     return parser.parse_args()
+
+
+ class TSVDataset(Dataset):
+     """Dataset class for loading images and annotations from TSV files.
+
+     This dataset class handles loading of images and annotations from TSV-format files,
+     where images are stored as base64-encoded strings and annotations are stored as JSON.
+
+     Args:
+         img_tsv_file (str): Path to the TSV file containing images
+         ann_tsv_file (str): Path to the TSV file containing annotations
+         ann_lineidx_file (str): Path to the line index file for annotations
+
+     Attributes:
+         data (list): Byte offsets of annotation rows in the annotation TSV
+         img_handle (file): File handle for the image TSV file
+         ann_handle (file): File handle for the annotation TSV file
+         img_tsv_file (str): Path to image TSV file
+         ann_tsv_file (str): Path to annotation TSV file
+     """
+
+     def __init__(self, img_tsv_file: str, ann_tsv_file: str, ann_lineidx_file: str):
+         super().__init__()
+         self.data = []
+         with open(ann_lineidx_file) as f:
+             for line in f:
+                 self.data.append(int(line.strip()))
+         random.shuffle(self.data)
+
+         # File handles are opened lazily in __getitem__
+         self.img_handle = None
+         self.ann_handle = None
+         self.img_tsv_file = img_tsv_file
+         self.ann_tsv_file = ann_tsv_file
+
+     def __len__(self):
+         """Get the total number of samples in the dataset.
+
+         Returns:
+             int: Number of samples in the dataset
+         """
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         """Get a sample from the dataset.
+
+         Args:
+             idx (int): Index of the sample to retrieve
+
+         Returns:
+             tuple: (image, data_dict) where:
+                 - image (PIL.Image): RGB image
+                 - data_dict (dict): Dictionary containing:
+                     - gt_boxes (list): List of bounding boxes [x0, y0, w, h]
+                     - region_map (dict): Mapping from referring expressions to box indices
+                     - think (str): Reasoning process text
+         """
+         ann_line_idx = self.data[idx]
+
+         if self.ann_handle is None:
+             self.ann_handle = open(self.ann_tsv_file)
+         self.ann_handle.seek(ann_line_idx)
+
+         img_line_idx, ann = self.ann_handle.readline().strip().split("\t")
+         img_line_idx = int(img_line_idx)
+         if self.img_handle is None:
+             self.img_handle = open(self.img_tsv_file)
+         self.img_handle.seek(img_line_idx)
+         img = self.img_handle.readline().strip().split("\t")[1]
+         # Strip a stray bytes-literal wrapper (b'...') if present; the original
+         # sliced [1:-1], which left the opening quote for b64decode to discard.
+         if img.startswith("b'"):
+             img = img[2:-1]
+         img = BytesIO(b64decode(img))
+         image = Image.open(img).convert("RGB")
+         data_dict = json.loads(ann)
+
+         return image, data_dict
+
+
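+ # A .lineidx file stores the byte offset of each row in its paired .tsv, which
+ # is what makes the seek()-based random access above work. A minimal sketch of
+ # how such an index could be built (hypothetical helper, not in this commit):
+ #
+ #     def build_lineidx(tsv_path, lineidx_path):
+ #         with open(tsv_path, "rb") as fin, open(lineidx_path, "w") as fout:
+ #             offset = 0
+ #             for line in fin:
+ #                 fout.write(f"{offset}\n")
+ #                 offset += len(line)
+
+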
+ def visualize(image, data_dict, output_path="visualization.png"):
+     """Visualize an image with bounding boxes and the reasoning process.
+
+     This function creates a visualization with two panels:
+     - Left panel: Original image with ground truth boxes (red) and answer boxes (green)
+     - Right panel: Reasoning process text
+
+     Args:
+         image (PIL.Image): Input image to visualize
+         data_dict (dict): Dictionary containing:
+             - gt_boxes (list): List of bounding boxes [x0, y0, w, h]
+             - region_map (dict): Mapping from referring expressions to box indices
+             - think (str): Reasoning process text
+         output_path (str, optional): Path to save the visualization. Defaults to "visualization.png".
+     """
+     # Create a figure with two subplots side by side
+     plt.rcParams["figure.dpi"] = 300
+     fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
+
+     # Display the image on the left subplot
+     ax1.imshow(image)
+
+     # Plot all ground truth boxes in red with indices
+     gt_boxes = data_dict.get("gt_boxes", [])
+     for idx, box in enumerate(gt_boxes):
+         x0, y0, width, height = box
+
+         # Create rectangle patch
+         rect = patches.Rectangle(
+             (x0, y0), width, height, linewidth=2, edgecolor="red", facecolor="none"
+         )
+         ax1.add_patch(rect)
+
+         # Add index number
+         ax1.text(
+             x0,
+             y0 - 5,
+             str(idx),
+             color="red",
+             fontsize=12,
+             bbox=dict(facecolor="white", alpha=0.7),
+         )
+
+     # Plot answer boxes from region_map in green
+     region_map = data_dict.get("region_map", {})
+     for referring_exp, answer_indices in region_map.items():
+         # Display the referring expression at the top of the image
+         ax1.text(
+             10,
+             30,
+             referring_exp,
+             color="blue",
+             fontsize=12,
+             bbox=dict(facecolor="white", alpha=0.7),
+         )
+
+         # Plot answer boxes in green
+         for idx in answer_indices:
+             if idx < len(gt_boxes):
+                 box = gt_boxes[idx]
+                 x0, y0, width, height = box
+                 # Create rectangle patch for answer box
+                 rect = patches.Rectangle(
+                     (x0, y0),
+                     width,
+                     height,
+                     linewidth=3,
+                     edgecolor="green",
+                     facecolor="none",
+                 )
+                 ax1.add_patch(rect)
+
+     # Remove axis ticks from the image panel
+     ax1.set_xticks([])
+     ax1.set_yticks([])
+     ax1.set_title("Image with Bounding Boxes")
+
+     # Display the reasoning text on the right subplot
+     ax2.text(0.05, 0.95, data_dict.get("think", ""), wrap=True, fontsize=12, va="top")
+     ax2.set_xticks([])
+     ax2.set_yticks([])
+     ax2.set_title("Reasoning Process")
+
+     # Adjust layout and save
+     plt.tight_layout()
+     plt.savefig(output_path, dpi=300)
+     # Close the figure so repeated calls don't accumulate open figures
+     plt.close(fig)
+
+
+ if __name__ == "__main__":
+     # Parse arguments
+     args = parse_args()
+
+     # Initialize dataset
+     dataset = TSVDataset(args.img_tsv, args.ann_tsv, args.ann_lineidx)
+
+     vis_root = args.output_dir
+     os.makedirs(vis_root, exist_ok=True)
+     for i in range(args.num_vis):
+         image, data_dict = dataset[i]
+         # Save the visualization
+         output_path = os.path.join(vis_root, f"visualization_{i}.png")
+         visualize(image, data_dict, output_path)
+         print(f"Visualization saved to {output_path}")