File size: 6,922 Bytes
863e121 a860625 863e121 3025cf4 863e121 3025cf4 863e121 3025cf4 863e121 a860625 863e121 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
import gradio as gr
import json
import os
from pycocotools import mask as coco_mask
import argparse
# Command-line configuration: where captions, images and annotations live.
parser = argparse.ArgumentParser(description='Gradio App for Regional Captioning')
parser.add_argument(
    '--save_path',
    type=str,
    default="./captions.json",
    help='Path of the flattened caption file to load (created from --caption_path if it does not exist)',
)
parser.add_argument(
    '--caption_path',
    type=str,
    default="",
    help='Path to the source caption JSON that is flattened into --save_path',
)
parser.add_argument('--img_dir', type=str, default="./annotations", help='Path to the image directory')
parser.add_argument('--json_dir', type=str, default="./annotations", help='Path to the json directory')
args = parser.parse_args()

# Flattened caption file: read at start-up if present, otherwise written by format_json.
caption_file = args.save_path
# Source caption JSON; despite the variable name it is opened as a single file.
from_dir = args.caption_path
# Directory holding the "<image_id>.jpg" files.
img_dir = args.img_dir
# Directory holding the per-image "<image_id>.json" annotation files.
json_dir = args.json_dir
def format_json(caption_file, source_path=None):
    """Flatten the source caption JSON and save the result to ``caption_file``.

    The source file is read as ``{image_id: {mask_id: {"long_caption": ...}}}``;
    only the ``"long_caption"`` entry of every mask is kept.

    Args:
        caption_file: Destination path of the flattened JSON file.
        source_path: Path of the source caption JSON. When omitted, the
            module-level ``from_dir`` is used, which keeps existing
            single-argument calls working unchanged.

    Returns:
        The flattened mapping ``{image_id: {mask_id: long_caption}}``.

    Raises:
        FileNotFoundError: If the source file does not exist.
        KeyError: If a mask entry has no ``"long_caption"`` key.
    """
    if source_path is None:
        source_path = from_dir

    # Explicit UTF-8 so reading does not depend on the platform default
    # encoding (the rest of the file reads JSON as UTF-8 too).
    with open(source_path, "r", encoding="utf-8") as f:
        scene = json.load(f)

    display = {
        img_id: {
            mask_id: mask_value["long_caption"]
            for mask_id, mask_value in value.items()
        }
        for img_id, value in scene.items()
    }

    with open(caption_file, "w", encoding="utf-8") as f:
        print("saving at", caption_file)
        json.dump(display, f, indent=4)
    return display
# Use the flattened caption file when it already exists; otherwise build it
# from the source captions (format_json also writes it to disk).
if not os.path.exists(caption_file):
    captions_data = format_json(caption_file)
else:
    with open(caption_file, "r", encoding="utf-8") as cached:
        captions_data = json.load(cached)

# The image ids are the top-level keys of the caption data.
image_ids = [*captions_data]
# Thin wrapper around the pycocotools mask decoder
def decode_segmentation(segmentation):
    """Decode an encoded segmentation with ``pycocotools.mask.decode``.

    Args:
        segmentation: The encoded segmentation, passed through unchanged.

    Returns:
        Whatever ``coco_mask.decode`` returns for ``segmentation``.
    """
    decoded = coco_mask.decode(segmentation)
    return decoded
def prepare_masks(filtered_annotations):
    """Turn annotations into ``(decoded_mask, label)`` pairs.

    Args:
        filtered_annotations: Iterable of annotation dicts, each providing a
            ``"segmentation"`` and an ``"id"`` entry.

    Returns:
        A list with one ``(decoded_mask, label)`` tuple per annotation, in
        input order; the label is the annotation id formatted as a string.
    """
    def _as_pair(entry):
        # Read both fields first, then decode, then format the id as the label.
        encoded = entry["segmentation"]
        mask_id = entry["id"]
        return decode_segmentation(encoded), f"{mask_id}"

    return [_as_pair(entry) for entry in filtered_annotations]
# Resolve one image's path and decode the masks that have a caption
def load_image_and_masks(image_id):
    """Return the image path and the decoded, captioned masks for ``image_id``.

    The annotations are read from ``<json_dir>/<image_id>.json`` (its
    ``"annotations"`` list) and reduced to the entries whose ``"id"`` appears,
    as an integer, among the mask ids stored for this image in
    ``captions_data``.

    Args:
        image_id: Key of the image in ``captions_data``; also the stem of the
            image and annotation file names.

    Returns:
        ``(image_path, masks)`` where ``image_path`` is
        ``<img_dir>/<image_id>.jpg`` and ``masks`` is the list produced by
        ``prepare_masks``.
    """
    annotation_file = os.path.join(json_dir, f"{image_id}.json")
    image_path = os.path.join(img_dir, f"{image_id}.jpg")

    with open(annotation_file, "r", encoding="utf-8") as handle:
        annotations = json.load(handle)["annotations"]

    # Caption keys are strings; annotation ids are compared as integers.
    captioned_ids = {int(key) for key in captions_data.get(image_id, {})}
    selected = [entry for entry in annotations if entry["id"] in captioned_ids]

    return image_path, prepare_masks(selected)
# Gradio select handler: show the caption of the mask that was clicked
def display_caption(evt: gr.SelectData, masks, image_id):
    """Look up the caption for the selected mask.

    Args:
        evt: Gradio selection event; ``evt.index`` indexes into ``masks``.
        masks: Current list of ``(mask, label)`` pairs.
        image_id: Id of the image currently shown.

    Returns:
        The caption stored in ``captions_data`` for this image and mask
        label, or ``"No caption found"`` when there is none.
    """
    selected = masks[evt.index]
    label = selected[1]  # the label is the mask id
    captions_for_image = captions_data.get(image_id, {})
    return captions_for_image.get(str(label), "No caption found")
# def display_caption(evt: gr.SelectData, masks, image_id):
# # Get the label, e.g., "mask 1"
# label = masks[evt.index][1]
# # Use the label directly as key in the captions_data
# caption_data = captions_data.get(image_id, {}).get(label, "No caption found")
# return caption_data
# Gradio handler: switch the viewer to the image at a given index
def update_image(image_index):
    """Load the image at ``image_index`` and build the values for the UI.

    Args:
        image_index: Position of the image in ``image_ids``.

    Returns:
        A 4-tuple ``((image_path, overlays), masks, image_id, image_index)``:
        the annotated-image value, the mask list, the image id and the
        index that was passed in.
    """
    image_id = image_ids[image_index]
    image_path, masks = load_image_and_masks(image_id)
    overlays = [(mask, label) for mask, label in masks]
    return (image_path, overlays), masks, image_id, image_index
# Initialize first image and masks
# Stop with a readable message when there is nothing to show: indexing an
# empty image list would otherwise abort start-up with a bare IndexError.
if not image_ids:
    raise SystemExit(f"No images found in caption file {caption_file!r}; nothing to display.")
initial_image_id = image_ids[0]
initial_image_path, initial_masks = load_image_and_masks(initial_image_id)
# Gradio event function to reload JSON data
def reload_data(image_index):
    """Reload the caption data and refresh the viewer for the current index.

    When the source caption file (``from_dir``) is available the flattened
    captions are regenerated from it via ``format_json``; otherwise the
    existing flattened file (``caption_file``) is re-read, so the button also
    works when no source path was configured.

    Args:
        image_index: Index of the image currently selected on the slider.

    Returns:
        A 5-tuple ``((image_path, overlays), masks, image_id, slider_update,
        slider_update)``. The last two entries are the same slider update
        (clamped value plus new maximum), one for each slider output slot of
        the reload button.

    Raises:
        gr.Error: If the reloaded caption data contains no images.
    """
    global captions_data, image_ids

    if from_dir and os.path.exists(from_dir):
        refreshed = format_json(caption_file)
    else:
        with open(caption_file, "r", encoding="utf-8") as f:
            refreshed = json.load(f)

    # Keep the previous data when the reload yields nothing to display.
    if not refreshed:
        raise gr.Error("Reloaded caption data contains no images.")

    captions_data = refreshed
    image_ids = list(captions_data)

    # Ensure the current index is within the updated range
    max_index = len(image_ids) - 1
    image_index = max(0, min(int(image_index), max_index))

    # Reload the image and masks for the current index
    image_id = image_ids[image_index]
    image_path, masks = load_image_and_masks(image_id)

    # A bare number would set the slider's value; an update object is needed
    # to change its range. The same update is returned for both slider slots
    # so the result does not depend on which one is applied last.
    slider_update = gr.update(value=image_index, maximum=max_index)
    overlays = [(m[0], m[1]) for m in masks]
    return (image_path, overlays), masks, image_id, slider_update, slider_update
# Build the interface: annotated image, caption box, navigation and reload
with gr.Blocks() as demo:
    gr.Markdown("## URECA Dataset Visualization")

    # Image with clickable mask overlays
    with gr.Row():
        annotated_img = gr.AnnotatedImage(
            value=(initial_image_path, [(mask, label) for mask, label in initial_masks]),
            label="Annotated Image",
            height=400,
        )

    # Read-only box showing the caption of the selected mask
    with gr.Row():
        scene_caption = gr.Textbox(label="Generated Caption", interactive=False, lines=5)
        # object_caption = gr.Textbox(label="Object Caption", interactive=False, lines=5)

    # Navigation controls
    with gr.Row():
        prev_button = gr.Button("Prev Image")
        slider = gr.Slider(0, len(image_ids) - 1, step=1, value=0, label="Jump to Image")
        next_button = gr.Button("Next Image")

    with gr.Row():
        reload_button = gr.Button("🔄 Reload Data")

    # Per-session state: masks and id of the image currently displayed
    current_masks = gr.State(initial_masks)
    current_image_id = gr.State(initial_image_id)

    # Clicking a mask shows its caption
    annotated_img.select(
        fn=display_caption,
        inputs=[current_masks, current_image_id],
        outputs=[scene_caption],
    )

    def prev_image(image_index):
        """Step one image back, stopping at the first image."""
        return update_image(max(0, image_index - 1))

    def next_image(image_index):
        """Step one image forward, stopping at the last image."""
        return update_image(min(len(image_ids) - 1, image_index + 1))

    # All navigation events refresh the same four components.
    nav_outputs = [annotated_img, current_masks, current_image_id, slider]

    for trigger, handler in ((prev_button, prev_image), (next_button, next_image)):
        trigger.click(fn=handler, inputs=slider, outputs=nav_outputs)

    slider.release(fn=update_image, inputs=slider, outputs=nav_outputs)

    # reload_data returns five values; the last two both target the slider.
    reload_button.click(
        fn=reload_data,
        inputs=slider,  # Pass the current image index
        outputs=[annotated_img, current_masks, current_image_id, slider, slider],  # Update slider range as well
    )
# Launch the Gradio app
demo.launch()