File size: 6,922 Bytes
863e121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a860625
863e121
 
 
3025cf4
863e121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3025cf4
 
 
 
863e121
 
3025cf4
 
 
 
 
 
 
863e121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a860625
863e121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import gradio as gr
import json
import os
from pycocotools import mask as coco_mask
import argparse
# Command-line configuration: where captions, images and annotation JSONs live.
parser = argparse.ArgumentParser(description='Gradio App for Regional Captioning')
parser.add_argument('--save_path', type=str, default="./captions.json",
                    help='Path of the flattened {image_id: {mask_id: caption}} JSON; '
                         'generated from --caption_path when it does not exist')
parser.add_argument('--caption_path', type=str, default="",
                    help='Path to the source caption JSON (per-mask entries with a "long_caption" field) '
                         'used to generate --save_path')
parser.add_argument('--img_dir', type=str, default="./annotations",
                    help='Directory containing the <image_id>.jpg images')
parser.add_argument('--json_dir', type=str, default="./annotations",
                    help='Directory containing the <image_id>.json annotation files')
args = parser.parse_args()

# Module-level aliases read by the functions below.
caption_file = args.save_path
from_dir = args.caption_path  # source caption JSON ("" when not given)
img_dir = args.img_dir
json_dir = args.json_dir

def format_json(caption_file, source_path=None):
    """Flatten the source caption JSON into ``{image_id: {mask_id: long_caption}}`` and save it.

    Reads the nested caption JSON at ``source_path``, keeps only each mask's
    ``"long_caption"`` field, writes the flattened mapping to ``caption_file``
    (indented JSON) and returns it.

    Args:
        caption_file: destination path for the flattened caption JSON.
        source_path: source caption JSON to read. Defaults to the module-level
            ``from_dir`` (the ``--caption_path`` argument).

    Returns:
        dict mapping image id -> {mask id -> long caption}.

    Raises:
        FileNotFoundError: if no source path is configured or it does not exist.
        KeyError: if a mask entry has no ``"long_caption"`` field.
    """
    if source_path is None:
        source_path = from_dir
    if not source_path:
        # Without this, open("") fails with an error that does not say what is missing.
        raise FileNotFoundError(
            f"Cannot build {caption_file!r}: no source caption file (--caption_path) was given"
        )

    # Explicit encoding: the platform default is not necessarily UTF-8.
    with open(source_path, "r", encoding="utf-8") as f:
        scene = json.load(f)

    display = {
        img_id: {mask_id: mask_value["long_caption"] for mask_id, mask_value in mask_entries.items()}
        for img_id, mask_entries in scene.items()
    }

    with open(caption_file, "w", encoding="utf-8") as f:
        print("saving at", caption_file)
        json.dump(display, f, indent=4)

    return display

if not os.path.exists(caption_file):
    # No flattened caption file yet: build it from the source caption JSON.
    captions_data = format_json(caption_file)
else:
    with open(caption_file, "r", encoding="utf-8") as f:
        captions_data = json.load(f)

# Image ids in mapping order; the slider position indexes into this list.
image_ids = list(captions_data)

# Mask decoding helpers.
def decode_segmentation(segmentation):
    """Decode one annotation's ``"segmentation"`` entry with ``pycocotools.mask.decode``.

    Args:
        segmentation: passed through unchanged to ``coco_mask.decode``
            (presumably a COCO RLE dict — confirm against the annotation files).

    Returns:
        The decoded binary mask produced by ``coco_mask.decode``.
    """
    binary_mask = coco_mask.decode(segmentation)
    return binary_mask

def prepare_masks(filtered_annotations):
    """Decode each annotation's mask and pair it with its id as a text label.

    Args:
        filtered_annotations: iterable of annotation dicts, each with
            ``"segmentation"`` and ``"id"`` keys.

    Returns:
        list of ``(decoded_mask, label)`` tuples in input order, where ``label``
        is the annotation id formatted as a string. This is the label shown on
        the annotated image and used to look up captions.
    """
    return [
        (decode_segmentation(annotation["segmentation"]), f"{annotation['id']}")
        for annotation in filtered_annotations
    ]

# Load image and annotations dynamically
def load_image_and_masks(image_id):
    """Resolve the image path and decode the captioned masks for one image.

    Reads ``<json_dir>/<image_id>.json`` and keeps only the annotations whose id
    appears as a mask key under ``captions_data[image_id]``.

    Args:
        image_id: key into ``captions_data``; also the file stem of the
            ``.jpg`` image and the ``.json`` annotation file.

    Returns:
        ``(image_path, masks)``:
        - ``image_path`` is ``<img_dir>/<image_id>.jpg`` (existence not checked here).
        - ``masks`` is the list of ``(decoded_mask, label)`` pairs from ``prepare_masks``.

    Raises:
        FileNotFoundError: if the annotation JSON is missing.
        KeyError: if the annotation JSON has no ``"annotations"`` list.
    """
    image_path = os.path.join(img_dir, f"{image_id}.jpg")
    annotation_file = os.path.join(json_dir, f"{image_id}.json")

    with open(annotation_file, "r", encoding="utf-8") as f:
        annotations = json.load(f)["annotations"]

    # Caption mask ids are JSON object keys, so always strings, and captions are
    # later looked up by str(label). Compare as strings: this neither crashes on
    # non-numeric keys (as int() did) nor misses ids stored as strings.
    captioned_ids = set(captions_data.get(image_id, {}))
    filtered_annotations = [
        annotation for annotation in annotations if str(annotation["id"]) in captioned_ids
    ]

    return image_path, prepare_masks(filtered_annotations)

# Select handler: caption of the clicked mask.
def display_caption(evt: gr.SelectData, masks, image_id):
    """Return the caption of the mask the user clicked on the annotated image.

    Args:
        evt: selection event supplied by Gradio; ``evt.index`` is used as the
            position of the clicked entry in ``masks``.
        masks: list of ``(mask, label)`` pairs currently displayed.
        image_id: key of the displayed image in ``captions_data``.

    Returns:
        The caption string, or "No caption found" if the image or mask has none.
    """
    label = masks[evt.index][1]
    captions_for_image = captions_data.get(image_id, {})
    return captions_for_image.get(str(label), "No caption found")

# def display_caption(evt: gr.SelectData, masks, image_id):
#     # Get the label, e.g., "mask 1"
#     label = masks[evt.index][1]
#     # Use the label directly as key in the captions_data
#     caption_data = captions_data.get(image_id, {}).get(label, "No caption found")
#     return caption_data

# Gradio event function to update image
def update_image(image_index):
    """Load the image at ``image_index`` and return the values for the image outputs.

    The index is coerced to int and clamped into ``[0, len(image_ids) - 1]``, so
    a stale slider position (e.g. after a reload shrank the data) cannot raise
    IndexError.

    Args:
        image_index: requested position in ``image_ids`` (slider value).

    Returns:
        A 4-tuple:
        - ``(image_path, [(mask, label), ...])`` for the annotated image;
        - the masks list;
        - the image id;
        - the clamped index (written back to the slider).
    """
    image_index = max(0, min(int(image_index), len(image_ids) - 1))
    image_id = image_ids[image_index]
    image_path, masks = load_image_and_masks(image_id)
    return (image_path, [(m[0], m[1]) for m in masks]), masks, image_id, image_index

# Initialize first image and masks shown when the interface opens.
if not image_ids:
    # Fail with an explicit reason instead of an IndexError on image_ids[0].
    raise SystemExit(f"No images listed in {caption_file}; nothing to visualize.")
initial_image_id = image_ids[0]
initial_image_path, initial_masks = load_image_and_masks(initial_image_id)

# Gradio event function to reload JSON data
def reload_data(image_index):
    """Re-read the caption data from disk and reload the image at the current position.

    When ``--caption_path`` was given, the flattened caption file is regenerated
    from it via ``format_json``. Otherwise ``caption_file`` itself is re-read,
    instead of trying to open an empty source path.

    The module-level ``captions_data`` / ``image_ids`` are replaced only after the
    new data is known to contain images, so a bad reload keeps the previous state usable.

    Args:
        image_index: current slider position; coerced to int and clamped into
            the new range.

    Returns:
        ``(annotated_image_value, masks, image_id, image_index, max_index)``, where
        ``max_index`` is ``len(image_ids) - 1`` (for resizing the slider).

    Raises:
        gr.Error: if the reloaded data lists no images.
    """
    global captions_data, image_ids
    if from_dir:
        new_captions = format_json(caption_file)
    else:
        with open(caption_file, "r", encoding="utf-8") as f:
            new_captions = json.load(f)

    new_image_ids = list(new_captions.keys())
    if not new_image_ids:
        raise gr.Error(f"No images found in {caption_file}; keeping the previously loaded data.")
    captions_data, image_ids = new_captions, new_image_ids

    # Keep the current position when possible, otherwise clamp into the new range.
    image_index = max(0, min(int(image_index), len(image_ids) - 1))
    image_id = image_ids[image_index]
    image_path, masks = load_image_and_masks(image_id)
    return (image_path, [(m[0], m[1]) for m in masks]), masks, image_id, image_index, len(image_ids) - 1

# Build the interface: annotated image, caption box, navigation and reload controls.
with gr.Blocks() as demo:
    gr.Markdown("## URECA Dataset Visualization")

    # Annotated image with one overlay per captioned mask.
    with gr.Row():
        annotated_img = gr.AnnotatedImage(
            value=(initial_image_path, [(m[0], m[1]) for m in initial_masks]),
            label="Annotated Image",
            height=400,
        )

    # Caption of the most recently selected mask.
    with gr.Row():
        scene_caption = gr.Textbox(label="Generated Caption", interactive=False, lines=5)

    # Navigation controls
    with gr.Row():
        prev_button = gr.Button("Prev Image")
        slider = gr.Slider(0, len(image_ids) - 1, step=1, value=0, label="Jump to Image")
        next_button = gr.Button("Next Image")
    with gr.Row():
        reload_button = gr.Button("🔄 Reload Data")

    # State holding the masks and image id currently displayed.
    current_masks = gr.State(initial_masks)
    current_image_id = gr.State(initial_image_id)

    # Selecting a mask shows its caption.
    annotated_img.select(
        fn=display_caption,
        inputs=[current_masks, current_image_id],
        outputs=[scene_caption]
    )

    # Outputs shared by every handler that switches the displayed image.
    image_outputs = [annotated_img, current_masks, current_image_id, slider]

    def prev_image(image_index):
        """Show the previous image, stopping at the first one."""
        return update_image(max(0, int(image_index) - 1))

    def next_image(image_index):
        """Show the next image, stopping at the last one."""
        return update_image(min(len(image_ids) - 1, int(image_index) + 1))

    def reload_and_resize(image_index):
        """Reload the data, then set the slider's position and its maximum together.

        ``reload_data`` returns the clamped index and the new maximum index as
        separate values. Routing both to ``slider`` as plain values only sets its
        position (the later one wins) and never resizes its range, so a single
        ``gr.update`` carries both instead.
        """
        image_value, masks, image_id, index, max_index = reload_data(image_index)
        return image_value, masks, image_id, gr.update(value=index, maximum=max_index)

    prev_button.click(fn=prev_image, inputs=slider, outputs=image_outputs)
    next_button.click(fn=next_image, inputs=slider, outputs=image_outputs)
    slider.release(fn=update_image, inputs=slider, outputs=image_outputs)
    reload_button.click(fn=reload_and_resize, inputs=slider, outputs=image_outputs)

# Launch the Gradio app
demo.launch()