"""Gradio demo: segment a trait on a reference image from user clicks with
SAM2, then transfer that segmentation onto a target image."""

from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor

import sam_utils

# Shared by the image predictor and the video predictor loaded below.
SAM2_CHECKPOINT = "checkpoints/sam2_hiera_small.pt"
SAM2_CONFIG = "checkpoints/sam2_hiera_s.yaml"


def segment_reference(image, click):
    """Segment the reference image with SAM2 from the user's point prompts.

    Args:
        image: Reference image (PIL, as delivered by the ``type="pil"`` input).
        click: Sequence of ``[x, y]`` pixel coordinates, one per click. Every
            point is used as a foreground prompt (label 1).

    Returns:
        The masks returned by ``sam2_img.predict`` with
        ``multimask_output=False``.
    """
    print(f"Segmenting reference at point: {click}")
    point_coords = np.asarray(click)
    # One positive (foreground) label per clicked point.
    point_labels = np.ones(len(point_coords), dtype=int)
    sam2_img.set_image(image)
    masks, _, _ = sam2_img.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=False,
    )
    return masks


def segment_target(target_image, ref_image, ref_mask):
    """Transfer the reference segmentation onto the target image.

    The reference image and its mask are loaded into the SAM2 video predictor
    together with the target image, the masks are propagated, and the
    ``'segmentation'`` entry of the last propagated result is returned.
    """
    target_array = np.array(target_image)
    ref_array = np.array(ref_image)
    state = sam_utils.load_masks(sam2_vid, [target_array], ref_array, ref_mask)
    # NOTE(review): assumes the last propagated entry is the target image's
    # result — confirm against sam_utils.propagate_masks.
    return sam_utils.propagate_masks(sam2_vid, state)[-1]["segmentation"]


def _as_mask_stack(masks):
    """Return *masks* as an iterable of 2-D masks.

    A lone 2-D mask would otherwise be iterated row by row, so it is wrapped
    in a one-element list; anything else is returned unchanged.
    """
    if getattr(masks, "ndim", None) == 2:
        return [masks]
    return masks


def visualize_segmentation(image, masks, target_image, target_mask):
    """Render reference and target images side by side with masks overlaid.

    Args:
        image: Reference PIL image.
        masks: Reference mask(s) to overlay on *image*.
        target_image: Target PIL image.
        target_mask: Inferred mask(s) to overlay on *target_image*.

    Returns:
        A PIL image of the rendered two-panel figure.
    """
    fig, (ref_ax, tgt_ax) = plt.subplots(1, 2, figsize=(12, 6))
    try:
        panels = (
            (ref_ax, image, masks, "Reference Image with Expert Segmentation"),
            (tgt_ax, target_image, target_mask, "Target Image with Inferred Segmentation"),
        )
        for ax, img, panel_masks, title in panels:
            ax.imshow(img.convert("L"), cmap="gray")
            for obj_id, mask in enumerate(_as_mask_stack(panel_masks)):
                sam_utils.show_mask(mask, ax, obj_id=obj_id, alpha=0.75)
            ax.axis("off")
            ax.set_title(title)
        fig.tight_layout()
        with BytesIO() as buf:
            fig.savefig(buf, format="png")
            buf.seek(0)
            # copy() forces the pixels to load before the buffer is closed.
            return Image.open(buf).copy()
    finally:
        # Release the figure even if rendering fails; pyplot otherwise keeps it alive.
        plt.close(fig)


def record_click(clicks, evt: gr.SelectData):
    """Append the clicked pixel on the reference image to this session's clicks.

    Returns the new click list (written back to the session state) and a
    status message for the Click Info box.
    """
    # Build a new list rather than mutating the state value in place.
    clicks = [*(clicks or []), [evt.index[0], evt.index[1]]]
    return clicks, f"Clicked at: {clicks}"


def reset_clicks():
    """Forget recorded clicks, e.g. when a different reference image is loaded."""
    return [], ""


def generate(reference_image, target_image, clicks=None):
    """Segment the reference from the clicks and transfer the mask to the target.

    Returns a ``(visualization, status message)`` pair for the output image and
    the Click Info box.
    """
    if not clicks:
        return None, "Click on the reference image first!"
    if reference_image is None or target_image is None:
        return None, "Upload both a reference and a target image first!"
    ref_mask = segment_reference(reference_image, clicks)
    tgt_mask = segment_target(target_image, reference_image, ref_mask)
    vis = visualize_segmentation(reference_image, ref_mask, target_image, tgt_mask)
    return vis, "Done!"


# Load SAM2 once at import time; the callbacks above read these module globals.
sam2_img = SAM2ImagePredictor(
    sam_utils.load_SAM2(ckpt_path=SAM2_CHECKPOINT, model_cfg_path=SAM2_CONFIG)
)
sam2_vid = sam_utils.build_sam2_predictor(
    checkpoint=SAM2_CHECKPOINT, model_cfg=SAM2_CONFIG
)

with gr.Blocks() as demo:
    gr.Markdown("### SST Demo: Label-Efficient Trait Segmentation")
    # Per-session [x, y] clicks; a module-level list would be shared by every visitor.
    clicks_state = gr.State([])
    with gr.Row():
        reference_img = gr.Image(type="pil", label="Reference Image")
        target_img = gr.Image(type="pil", label="Target Image")
    click_info = gr.Textbox(label="Click Info")
    generate_btn = gr.Button("Generate")
    output_mask = gr.Image(type="pil", label="Generated Mask")

    reference_img.select(
        fn=record_click,
        inputs=[clicks_state],
        outputs=[clicks_state, click_info],
    )
    # Clicks refer to pixels of a specific image, so drop them when it changes.
    reference_img.change(
        fn=reset_clicks,
        inputs=None,
        outputs=[clicks_state, click_info],
    )
    generate_btn.click(
        fn=generate,
        inputs=[reference_img, target_img, clicks_state],
        outputs=[output_mask, click_info],
    )

demo.launch()