"""Gradio demo that runs FastSAM "segment everything" on an uploaded image."""

import io
import os

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from ultralytics import FastSAM
from ultralytics.models.fastsam import FastSAMPrompt


def fig2img(fig):
    """Rasterize a Matplotlib figure into a PIL image.

    Args:
        fig (matplotlib.figure.Figure): Figure to render.

    Returns:
        PIL.Image.Image: Image decoded from the bytes written by
            ``fig.savefig`` (default format and settings).
    """
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    # The returned image keeps a reference to ``buf``, so lazy decoding by PIL
    # remains valid after this function returns.
    return Image.open(buf)


def plot(
    annotations,
    prompt_process,
    bbox=None,
    points=None,
    point_label=None,
    mask_random_color=True,
    better_quality=True,
    retina=False,
    with_contours=True,
):
    """Draw masks (and optionally contours) over the source image and return it.

    NOTE(review): the original file's formatting was lost, so it is unclear
    whether the image was returned inside or after the annotation loop. This
    version returns the image for the first annotation; for the single-image
    call made by ``generateOutput`` both readings behave identically.

    Args:
        annotations (list): Result objects exposing ``orig_img``,
            ``orig_shape`` and ``masks``.
        prompt_process: Object providing ``fast_show_mask`` (a
            ``FastSAMPrompt`` in this script) used to paint the masks.
        bbox (list, optional): Bounding box coordinates [x1, y1, x2, y2],
            forwarded to ``fast_show_mask``. Defaults to None.
        points (list, optional): Points forwarded to ``fast_show_mask``.
            Defaults to None.
        point_label (list, optional): Labels for the points. Defaults to None.
        mask_random_color (bool, optional): Whether to use random colors for
            masks. Defaults to True.
        better_quality (bool, optional): Whether to apply morphological
            close/open to each mask before drawing. Defaults to True.
        retina (bool, optional): Whether masks are already at original
            resolution (skips the resize before contour extraction).
            Defaults to False.
        with_contours (bool, optional): Whether to overlay mask contours.
            Defaults to True.

    Returns:
        PIL.Image.Image | None: The rendered figure, or None when
            ``annotations`` is empty.
    """
    # Structuring elements are loop-invariant, so build them once.
    close_kernel = np.ones((3, 3), np.uint8)
    open_kernel = np.ones((8, 8), np.uint8)

    for ann in annotations:
        image = ann.orig_img[..., ::-1]  # BGR to RGB
        original_h, original_w = ann.orig_shape

        fig = plt.figure(figsize=(original_w / 100, original_h / 100))
        # Remove every margin and tick so the image fills the whole canvas.
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        ax = plt.gca()
        ax.xaxis.set_major_locator(plt.NullLocator())
        ax.yaxis.set_major_locator(plt.NullLocator())
        plt.imshow(image)

        if ann.masks is not None:
            masks = ann.masks.data
            if better_quality:
                if isinstance(masks[0], torch.Tensor):
                    masks = np.array(masks.cpu())
                for i, mask in enumerate(masks):
                    # Close small holes first, then open to drop small specks.
                    mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, close_kernel)
                    masks[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, open_kernel)

            prompt_process.fast_show_mask(
                masks,
                ax,
                random_color=mask_random_color,
                bbox=bbox,
                points=points,
                pointlabel=point_label,
                retinamask=retina,
                target_height=original_h,
                target_width=original_w,
            )

            if with_contours:
                contour_all = []
                temp = np.zeros((original_h, original_w, 1))
                for mask in masks:
                    # Masks are still tensors when better_quality is False;
                    # tensors have no ``astype``, so convert them first.
                    if isinstance(mask, torch.Tensor):
                        mask = mask.cpu().numpy()
                    mask = mask.astype(np.uint8)
                    if not retina:
                        mask = cv2.resize(mask, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
                    contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
                    contour_all.extend(contours)
                cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
                # RGBA overlay: blue contour lines at 80% opacity, transparent elsewhere.
                color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
                contour_mask = temp / 255 * color.reshape(1, 1, -1)
                plt.imshow(contour_mask)

        plt.axis("off")
        # Rasterize before closing so the figure is still fully alive while it
        # is rendered, then close this exact figure to release it from pyplot.
        rendered = fig2img(fig)
        plt.close(fig)
        return rendered

    return None


# Create a FastSAM model
model = FastSAM("FastSAM-s.pt")  # or FastSAM-x.pt


def generateOutput(source):
    """Segment everything in ``source`` with FastSAM and return the overlay.

    Args:
        source: Input image. The Gradio ``Image`` input used below supplies an
            RGB ``numpy`` array; other source types are passed to the model
            unchanged.

    Returns:
        PIL.Image.Image | None: Image with masks and contours drawn, as
            produced by ``plot``.
    """
    # Gradio hands over RGB arrays, whereas Ultralytics treats numpy input as
    # BGR (and ``plot`` flips BGR -> RGB for display). Convert up front so the
    # colors in the output are not swapped.
    if isinstance(source, np.ndarray) and source.ndim == 3 and source.shape[2] == 3:
        source = cv2.cvtColor(source, cv2.COLOR_RGB2BGR)

    everything_results = model(source, retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)

    # Prepare a Prompt Process object
    prompt_process = FastSAMPrompt(source, everything_results, device="cpu")

    # Everything prompt
    results = prompt_process.everything_prompt()

    return plot(annotations=results, prompt_process=prompt_process)


title = "FastSAM Inference Trials"
description = "Shows the FastSAM related Inference Trials"
examples = [["Elephants.jpg"], ["Puppies.jpg"], ["photo2.JPG"], ["MultipleItems.jpg"]]

demo = gr.Interface(
    generateOutput,
    inputs=[
        gr.Image(width=256, height=256, label="Input Image"),
    ],
    outputs=[
        gr.Image(width=256, height=256, label="Output"),
    ],
    title=title,
    description=description,
    examples=examples,
    cache_examples=False,
)
demo.launch()