# Spaces: Sleeping
from ultralytics import FastSAM | |
from ultralytics.models.fastsam import FastSAMPrompt | |
import matplotlib.pyplot as plt | |
import gradio as gr | |
import os | |
import io | |
import numpy as np | |
import torch | |
import cv2 | |
from PIL import Image | |
def fig2img(fig):
    """
    Convert a figure into a PIL image without touching the disk.

    Args:
        fig: Figure-like object exposing ``savefig(file_obj)`` (``plot`` passes the result of ``plt.figure``).

    Returns:
        The object returned by ``Image.open`` on the in-memory buffer the figure was saved into.
    """
    buffer = io.BytesIO()
    fig.savefig(buffer)
    # Rewind so the image reader starts at the first byte that was written.
    buffer.seek(0)
    return Image.open(buffer)
def _smooth_masks(masks):
    """Apply a morphological close followed by an open to every mask, in place, and return the array."""
    close_kernel = np.ones((3, 3), np.uint8)
    open_kernel = np.ones((8, 8), np.uint8)
    for i, mask in enumerate(masks):
        closed = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, close_kernel)
        masks[i] = cv2.morphologyEx(closed, cv2.MORPH_OPEN, open_kernel)
    return masks


def _contour_overlay(masks, height, width, resize):
    """Return a (height, width, 4) overlay holding the outlines of all masks in semi-transparent blue."""
    canvas = np.zeros((height, width, 1))
    contour_all = []
    for mask in masks:
        mask = mask.astype(np.uint8)
        if resize:
            mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
        contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        contour_all.extend(contours)
    cv2.drawContours(canvas, contour_all, -1, (255, 255, 255), 2)
    color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
    # (h, w, 1) scaled to 0..1 and broadcast against (1, 1, 4) gives the 4-channel overlay.
    return canvas / 255 * color.reshape(1, 1, -1)


def plot(
    annotations,
    prompt_process,
    bbox=None,
    points=None,
    point_label=None,
    mask_random_color=True,
    better_quality=True,
    retina=False,
    with_contours=True,
):
    """
    Draws annotations, bounding boxes, and points over their source image and returns the rendering.

    Args:
        annotations (list): Annotations to be plotted. Each item must expose ``orig_img``, ``orig_shape``
            and ``masks``.
        prompt_process: Object providing ``fast_show_mask``, which paints the masks onto the current axes.
        bbox (list, optional): Bounding box coordinates [x1, y1, x2, y2]. Defaults to None.
        points (list, optional): Points to be plotted. Defaults to None.
        point_label (list, optional): Labels for the points. Defaults to None.
        mask_random_color (bool, optional): Whether to use random color for masks. Defaults to True.
        better_quality (bool, optional): Whether to apply morphological transformations for better mask quality.
            Defaults to True.
        retina (bool, optional): Whether to use retina mask. When False, masks are resized to the original
            image size before contours are extracted. Defaults to False.
        with_contours (bool, optional): Whether to plot contours. Defaults to True.

    Returns:
        The image produced by ``fig2img`` for the rendered figure. When several annotations are given, the
        image of the last one is returned; None is returned when ``annotations`` is empty.
    """
    rendered = None
    for ann in annotations:
        image = ann.orig_img[..., ::-1]  # BGR to RGB
        original_h, original_w = ann.orig_shape
        fig = plt.figure(figsize=(original_w / 100, original_h / 100))
        try:
            # Let the image fill the whole figure: no margins, no tick marks.
            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
            plt.margins(0, 0)
            ax = plt.gca()
            ax.xaxis.set_major_locator(plt.NullLocator())
            ax.yaxis.set_major_locator(plt.NullLocator())
            plt.imshow(image)

            if ann.masks is not None:
                masks = ann.masks.data
                if isinstance(masks, torch.Tensor):
                    # Always work on a NumPy copy: the contour pass needs ndarray methods even when
                    # better_quality is off, and np.array() copies, so the in-place smoothing below
                    # does not write into ann.masks.
                    masks = np.array(masks.cpu())
                if better_quality:
                    masks = _smooth_masks(masks)

                prompt_process.fast_show_mask(
                    masks,
                    ax,
                    random_color=mask_random_color,
                    bbox=bbox,
                    points=points,
                    pointlabel=point_label,
                    retinamask=retina,
                    target_height=original_h,
                    target_width=original_w,
                )

                if with_contours:
                    plt.imshow(_contour_overlay(masks, original_h, original_w, resize=not retina))

            plt.axis("off")
            # Render while the figure is still open; it is closed right after.
            rendered = fig2img(fig)
        finally:
            # Close this specific figure even if drawing failed, so figures do not pile up across calls.
            plt.close(fig)
    return rendered
# Build the FastSAM model once at module load; "FastSAM-x.pt" is the alternative weights file.
MODEL_WEIGHTS = "FastSAM-s.pt"
model = FastSAM(MODEL_WEIGHTS)
def generateOutput(source):
    """
    Segment everything in an image with FastSAM and return the annotated rendering.

    Args:
        source: Image to segment. The Gradio interface supplies an RGB numpy array, or None when the
            request carries no image. Any other kind of source is forwarded to the model unchanged.

    Returns:
        The image returned by ``plot`` for the "everything" prompt results, or None when ``source`` is None.
    """
    if source is None:
        # Nothing to segment (e.g. submit pressed without an image); leave the output empty.
        return None

    if isinstance(source, np.ndarray) and source.ndim == 3 and source.shape[-1] == 3:
        # gr.Image hands over RGB arrays, while ultralytics treats ndarray input as BGR and plot()
        # flips orig_img from BGR back to RGB. Convert here so red and blue are not swapped in the
        # rendering. ascontiguousarray avoids passing a negative-stride view downstream.
        source = np.ascontiguousarray(source[..., ::-1])

    everything_results = model(source, retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)

    # Prepare a Prompt Process object
    prompt_process = FastSAMPrompt(source, everything_results, device="cpu")

    # Everything prompt
    results = prompt_process.everything_prompt()

    return plot(annotations=results, prompt_process=prompt_process)
# Title and description passed to the Gradio interface.
title = "FastSAM Inference Trials"
description = "Shows the FastSAM related Inference Trials"

# One row per example; each row holds the single image input of generateOutput.
examples = [
    ["941398-beautiful-farm-animals-wallpaper-2000x1402-for-meizu.jpg"],
    ["Puppies.jpg"],
    ["photo2.JPG"],
    ["MultipleItems.jpg"],
]

# Single image in, single image out, both shown at 256x256.
demo = gr.Interface(
    fn=generateOutput,
    inputs=[gr.Image(width=256, height=256, label="Input Image")],
    outputs=[gr.Image(width=256, height=256, label="Output")],
    title=title,
    description=description,
    examples=examples,
    cache_examples=False,
)

demo.launch()