"""Gradio demo that runs FastSAM "segment everything" inference on an uploaded image."""

import io
import os

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from ultralytics import FastSAM
from ultralytics.models.fastsam import FastSAMPrompt


def fig2img(fig):
    """Rasterize a Matplotlib figure into an in-memory PIL image.

    Args:
        fig (matplotlib.figure.Figure): Figure to render.

    Returns:
        PIL.Image.Image: PNG rendering of ``fig``, fully loaded so it does not
        depend on the temporary buffer used to produce it.
    """
    with io.BytesIO() as buf:
        fig.savefig(buf, format="png")
        buf.seek(0)
        # Image.open is lazy; copy() forces the pixels to be decoded before the
        # buffer is closed, so the returned image is self-contained.
        with Image.open(buf) as img:
            return img.copy()


def _refine_masks(masks):
    """Smooth each mask in place: 3x3 morphological closing, then 8x8 opening."""
    close_kernel = np.ones((3, 3), np.uint8)
    open_kernel = np.ones((8, 8), np.uint8)
    for i, mask in enumerate(masks):
        closed = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, close_kernel)
        masks[i] = cv2.morphologyEx(closed, cv2.MORPH_OPEN, open_kernel)
    return masks


def _contour_overlay(masks, height, width, retina):
    """Build a (height, width, 4) RGBA overlay with every mask outline drawn in blue."""
    contour_all = []
    canvas = np.zeros((height, width, 1))
    for mask in masks:
        mask = mask.astype(np.uint8)
        if not retina:
            # Non-retina masks are at model resolution; scale to the original image first.
            mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
        contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        contour_all.extend(contours)
    cv2.drawContours(canvas, contour_all, -1, (255, 255, 255), 2)
    color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
    return canvas / 255 * color.reshape(1, 1, -1)


def _render_annotation(
    ann,
    prompt_process,
    bbox,
    points,
    point_label,
    mask_random_color,
    better_quality,
    retina,
    with_contours,
):
    """Draw one result (image, masks, optional contours) and return it as a PIL image."""
    image = ann.orig_img[..., ::-1]  # BGR to RGB
    original_h, original_w = ann.orig_shape

    fig = plt.figure(figsize=(original_w / 100, original_h / 100))
    try:
        # Let the image fill the whole figure: no margins, no ticks.
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        ax = plt.gca()
        ax.xaxis.set_major_locator(plt.NullLocator())
        ax.yaxis.set_major_locator(plt.NullLocator())
        plt.imshow(image)

        if ann.masks is not None:
            masks = ann.masks.data
            if isinstance(masks, torch.Tensor):
                # Always work on a NumPy copy: the helpers below use NumPy/OpenCV
                # APIs, and np.array() copies so in-place refinement never
                # mutates the model's tensor.
                masks = np.array(masks.cpu())
            if better_quality:
                masks = _refine_masks(masks)
            prompt_process.fast_show_mask(
                masks,
                ax,
                random_color=mask_random_color,
                bbox=bbox,
                points=points,
                pointlabel=point_label,
                retinamask=retina,
                target_height=original_h,
                target_width=original_w,
            )
            if with_contours:
                plt.imshow(_contour_overlay(masks, original_h, original_w, retina))

        plt.axis("off")
        # Rasterize before closing so the figure is released even on error.
        return fig2img(fig)
    finally:
        plt.close(fig)


def plot(
    annotations,
    prompt_process,
    bbox=None,
    points=None,
    point_label=None,
    mask_random_color=True,
    better_quality=True,
    retina=False,
    with_contours=True,
):
    """
    Render annotations (masks, optional bbox/points, contours) over their images.

    Args:
        annotations (list): Ultralytics result objects; each must expose
            ``orig_img``, ``orig_shape`` and ``masks``.
        prompt_process (FastSAMPrompt): Prompt processor whose ``fast_show_mask``
            draws the masks.
        bbox (list, optional): Bounding box coordinates [x1, y1, x2, y2]. Defaults to None.
        points (list, optional): Points to be plotted. Defaults to None.
        point_label (list, optional): Labels for the points. Defaults to None.
        mask_random_color (bool, optional): Whether to use random color for masks. Defaults to True.
        better_quality (bool, optional): Whether to apply morphological transformations
            for better mask quality. Defaults to True.
        retina (bool, optional): Whether the masks are already at the original image
            resolution (retina masks). Defaults to False.
        with_contours (bool, optional): Whether to plot contours. Defaults to True.

    Returns:
        PIL.Image.Image | None: Rendering of the last annotation, or None when
        ``annotations`` is empty.
    """
    rendered = None
    for ann in annotations:
        rendered = _render_annotation(
            ann,
            prompt_process,
            bbox,
            points,
            point_label,
            mask_random_color,
            better_quality,
            retina,
            with_contours,
        )
    return rendered


# Create a FastSAM model
model = FastSAM("FastSAM-s.pt")  # or FastSAM-x.pt


def generateOutput(source):
    """Run FastSAM "everything" segmentation on ``source`` and return the overlay image.

    Args:
        source (numpy.ndarray | None): RGB image array from the Gradio ``Image``
            input, or None when the user submits without an image.

    Returns:
        PIL.Image.Image | None: The annotated image, or None if there is no input.
    """
    if source is None:
        return None
    # Gradio supplies RGB arrays, but Ultralytics treats NumPy inputs as BGR
    # (OpenCV order); without this the output colors come out channel-swapped.
    source = cv2.cvtColor(source, cv2.COLOR_RGB2BGR)
    everything_results = model(source, retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)
    # Prepare a Prompt Process object
    prompt_process = FastSAMPrompt(source, everything_results, device="cpu")
    # Everything prompt
    results = prompt_process.everything_prompt()
    # retina_masks=True already yields masks at the original resolution.
    return plot(annotations=results, prompt_process=prompt_process, retina=True)


title = "FastSAM Inference Trials"
description = "Shows the FastSAM related Inference Trials"
examples = [
    ["941398-beautiful-farm-animals-wallpaper-2000x1402-for-meizu.jpg"],
    ["Puppies.jpg"],
    ["photo2.JPG"],
    ["MultipleItems.jpg"],
]

demo = gr.Interface(
    generateOutput,
    inputs=[
        gr.Image(width=256, height=256, label="Input Image"),
    ],
    outputs=[
        gr.Image(width=256, height=256, label="Output"),
    ],
    title=title,
    description=description,
    examples=examples,
    cache_examples=False,
)
demo.launch()