import gradio as gr import json import os from pycocotools import mask as coco_mask import argparse # Paths parser = argparse.ArgumentParser(description='Gradio App for Regional Captioning') parser.add_argument('--save_path', type=str, default="./captions.json", help='Path to the caption file') parser.add_argument('--caption_path', type=str, default="", help='Path to the caption file') parser.add_argument('--img_dir', type=str, default="./annotations", help='Path to the image directory') parser.add_argument('--json_dir', type=str, default="./annotations", help='Path to the json directory') args = parser.parse_args() caption_file = args.save_path from_dir = args.caption_path img_dir = args.img_dir json_dir = args.json_dir def format_json(caption_file): display = {} with open(from_dir, "r") as f: scene = json.load(f) for img_id, value in scene.items(): if img_id not in display.keys(): display[img_id] = {} for mask_id, mask_value in value.items(): display[img_id][mask_id] = mask_value["long_caption"] with open(caption_file, "w") as f: print("saving at", caption_file) json.dump(display, f, indent=4) return display if os.path.exists(caption_file): with open(caption_file, "r", encoding="utf-8") as f: captions_data = json.load(f) else: captions_data = format_json(caption_file) # Get list of image_ids from caption data image_ids = list(captions_data.keys()) # Decode segmentation and prepare masks def decode_segmentation(segmentation): return coco_mask.decode(segmentation) def prepare_masks(filtered_annotations): masks = [] for idx, annotation in enumerate(filtered_annotations): segmentation = annotation["segmentation"] mask_id = annotation["id"] decoded_mask = decode_segmentation(segmentation) # Decode binary mask masks.append((decoded_mask, f"{mask_id}")) # Add mask and its label return masks # Load image and annotations dynamically def load_image_and_masks(image_id): # Get image filename and annotation file image_filename = f"{image_id}.jpg" annotation_file = os.path.join(json_dir, f"{image_id}.json") image_path = os.path.join(img_dir, image_filename) # Load annotations with open(annotation_file, "r", encoding="utf-8") as f: annotations_data = json.load(f) annotations = annotations_data["annotations"] # Get relevant mask IDs relevant_mask_ids = set(map(int, captions_data.get(image_id, {}).keys())) # Mask IDs in captions file # Filter annotations to only include relevant masks filtered_annotations = [annotation for annotation in annotations if annotation["id"] in relevant_mask_ids] # Prepare masks masks = prepare_masks(filtered_annotations) return image_path, masks # Gradio event function to display captions def display_caption(evt: gr.SelectData, masks, image_id): # Extract the mask ID from the label # mask_id = int(masks[evt.index][1]) # Get the label corresponding to the selected mask mask_id = masks[evt.index][1] # Get the label corresponding to the selected mask caption_data = captions_data.get(image_id, {}).get(str(mask_id), "No caption found") return caption_data # def display_caption(evt: gr.SelectData, masks, image_id): # # Get the label, e.g., "mask 1" # label = masks[evt.index][1] # # Use the label directly as key in the captions_data # caption_data = captions_data.get(image_id, {}).get(label, "No caption found") # return caption_data # Gradio event function to update image def update_image(image_index): image_id = image_ids[image_index] image_path, masks = load_image_and_masks(image_id) # Return the correct tuple structure return (image_path, [(m[0], m[1]) for m in masks]), masks, image_id, image_index # Initialize first image and masks initial_image_id = image_ids[0] initial_image_path, initial_masks = load_image_and_masks(initial_image_id) # Gradio event function to reload JSON data def reload_data(image_index): global captions_data, image_ids # Reload the captions data captions_data = format_json(caption_file) # Update image IDs image_ids = list(captions_data.keys()) # Ensure the current index is within the updated range image_index = min(image_index, len(image_ids) - 1) # Reload the image and masks for the current index image_id = image_ids[image_index] image_path, masks = load_image_and_masks(image_id) return (image_path, [(m[0], m[1]) for m in masks]), masks, image_id, image_index, len(image_ids) - 1 # Add Reload button to the interface with gr.Blocks() as demo: gr.Markdown("## URECA Dataset Visualization") # Annotated Image component with gr.Row(): annotated_img = gr.AnnotatedImage( value=(initial_image_path, [(m[0], m[1]) for m in initial_masks]), label="Annotated Image", height=400, ) # Caption display with gr.Row(): scene_caption = gr.Textbox(label="Generated Caption", interactive=False, lines=5) # object_caption = gr.Textbox(label="Object Caption", interactive=False, lines=5) # Navigation controls with gr.Row(): prev_button = gr.Button("Prev Image") slider = gr.Slider(0, len(image_ids) - 1, step=1, value=0, label="Jump to Image") next_button = gr.Button("Next Image") with gr.Row(): reload_button = gr.Button("🔄 Reload Data") # State to store current masks and image_id current_masks = gr.State(initial_masks) current_image_id = gr.State(initial_image_id) # Event listener for mask selection annotated_img.select( fn=display_caption, inputs=[current_masks, current_image_id], outputs=[scene_caption] ) # Update image based on slider or button def prev_image(image_index): new_index = max(0, image_index - 1) return update_image(new_index) def next_image(image_index): new_index = min(len(image_ids) - 1, image_index + 1) return update_image(new_index) prev_button.click( fn=prev_image, inputs=slider, outputs=[annotated_img, current_masks, current_image_id, slider] ) next_button.click( fn=next_image, inputs=slider, outputs=[annotated_img, current_masks, current_image_id, slider] ) slider.release( fn=update_image, inputs=slider, outputs=[annotated_img, current_masks, current_image_id, slider] ) # Reload button functionality reload_button.click( fn=reload_data, inputs=slider, # Pass the current image index outputs=[annotated_img, current_masks, current_image_id, slider, slider] # Update slider range as well ) # Launch the Gradio app demo.launch()