Spaces:
Sleeping
Sleeping
File size: 5,134 Bytes
9949d8d 20ddd4d 9949d8d 3a9758c 9949d8d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
from ultralytics import FastSAM
from ultralytics.models.fastsam import FastSAMPrompt
import matplotlib.pyplot as plt
import gradio as gr
import os
import io
import numpy as np
import torch
import cv2
from PIL import Image
def fig2img(fig):
    """Render a figure into an in-memory buffer and open it as a PIL image.

    Args:
        fig: Object exposing ``savefig(file_like)``; this file passes the
            figure created by ``plt.figure``.

    Returns:
        The image obtained from ``Image.open`` on the rendered bytes.
    """
    stream = io.BytesIO()
    fig.savefig(stream)
    # Rewind so the reader starts at the first byte that was written.
    stream.seek(0)
    return Image.open(stream)
def _smooth_masks(masks):
    """Clean every mask in place with a 3x3 morphological close, then an 8x8 open."""
    close_kernel = np.ones((3, 3), np.uint8)
    open_kernel = np.ones((8, 8), np.uint8)
    for idx, mask in enumerate(masks):
        closed = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, close_kernel)
        masks[idx] = cv2.morphologyEx(closed.astype(np.uint8), cv2.MORPH_OPEN, open_kernel)


def _contour_overlay(masks, height, width, retina):
    """Return an RGBA overlay with the outlines of all masks drawn on it."""
    canvas = np.zeros((height, width, 1))
    outlines = []
    for mask in masks:
        mask = mask.astype(np.uint8)
        if not retina:
            # Bring the mask to the original image size before tracing it.
            mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
        contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        outlines.extend(contours)
    cv2.drawContours(canvas, outlines, -1, (255, 255, 255), 2)
    # RGBA multiplier: full blue channel, 0.8 alpha where a contour was drawn.
    rgba = np.array([0 / 255, 0 / 255, 1.0, 0.8])
    return canvas / 255 * rgba.reshape(1, 1, -1)


def plot(
    annotations,
    prompt_process,
    bbox=None,
    points=None,
    point_label=None,
    mask_random_color=True,
    better_quality=True,
    retina=False,
    with_contours=True,
):
    """
    Draw an annotation's masks (and optional contours) over its image and return the picture.

    Nothing is written to disk. A single image is returned, so only the first
    annotation in ``annotations`` is rendered.

    Args:
        annotations (iterable): Results to plot; each item must expose ``orig_img``,
            ``orig_shape`` and ``masks``.
        prompt_process: Object providing ``fast_show_mask``, used to paint the masks.
        bbox (list, optional): Bounding box coordinates [x1, y1, x2, y2]. Defaults to None.
        points (list, optional): Points to be plotted. Defaults to None.
        point_label (list, optional): Labels for the points. Defaults to None.
        mask_random_color (bool, optional): Whether to use random color for masks. Defaults to True.
        better_quality (bool, optional): Whether to apply morphological transformations for
            better mask quality. Defaults to True.
        retina (bool, optional): Whether to use retina mask. Defaults to False.
        with_contours (bool, optional): Whether to plot contours. Defaults to True.

    Returns:
        The image produced by ``fig2img`` for the first annotation, or None when
        ``annotations`` is empty.
    """
    for ann in annotations:
        image = ann.orig_img[..., ::-1]  # BGR to RGB
        original_h, original_w = ann.orig_shape
        # For macOS only
        # plt.switch_backend('TkAgg')
        fig = plt.figure(figsize=(original_w / 100, original_h / 100))
        try:
            # Single axes filling the whole figure, without margins or ticks.
            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
            ax = plt.gca()
            ax.margins(0, 0)
            ax.xaxis.set_major_locator(plt.NullLocator())
            ax.yaxis.set_major_locator(plt.NullLocator())
            ax.imshow(image)

            if ann.masks is not None:
                masks = ann.masks.data
                # The contour pass relies on ndarray.astype, so convert tensors
                # regardless of whether better_quality is requested.
                if isinstance(masks, torch.Tensor):
                    masks = np.array(masks.cpu())
                if better_quality:
                    _smooth_masks(masks)
                prompt_process.fast_show_mask(
                    masks,
                    ax,
                    random_color=mask_random_color,
                    bbox=bbox,
                    points=points,
                    pointlabel=point_label,
                    retinamask=retina,
                    target_height=original_h,
                    target_width=original_w,
                )
                if with_contours:
                    ax.imshow(_contour_overlay(masks, original_h, original_w, retina))

            ax.axis("off")
            return fig2img(fig)
        finally:
            # Close this specific figure even if drawing failed, so figures do
            # not accumulate in a long-running process.
            plt.close(fig)
    return None
# Create the FastSAM model once at import time; generateOutput below reuses
# this single instance for every request.
model = FastSAM("FastSAM-s.pt") # or FastSAM-x.pt
def generateOutput(source):
    """Run FastSAM "everything" segmentation on an image and return the rendered overlay.

    Args:
        source: Image from the Gradio input component. Gradio passes None when
            no image was provided and an RGB array otherwise.

    Returns:
        The image returned by ``plot``, or None when ``source`` is None.
    """
    if source is None:
        # Nothing was uploaded; leave the output component empty.
        return None

    # Gradio delivers RGB arrays, while the model treats arrays as BGR and
    # ``plot`` flips the stored image back "BGR to RGB". Convert up front so
    # the colours stay correct end to end.
    if isinstance(source, np.ndarray) and source.ndim == 3 and source.shape[-1] == 3:
        source = np.ascontiguousarray(source[..., ::-1])

    everything_results = model(source, retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)
    # Prepare a Prompt Process object
    prompt_process = FastSAMPrompt(source, everything_results, device="cpu")
    # Everything prompt
    results = prompt_process.everything_prompt()
    return plot(annotations=results, prompt_process=prompt_process)
# Heading and description passed to the Gradio interface.
title = "FastSAM Inference Trials"
description = "Shows the FastSAM related Inference Trials"

# Example inputs offered by the interface, one single-element row per image.
# NOTE(review): these are relative paths, presumably files shipped next to
# this script - confirm they exist in the deployment.
examples = [
    ["941398-beautiful-farm-animals-wallpaper-2000x1402-for-meizu.jpg"],
    ["Puppies.jpg"],
    ["photo2.JPG"],
    ["MultipleItems.jpg"],
]

# One image in, one image out, processed by generateOutput.
demo = gr.Interface(
    generateOutput,
    inputs=[
        gr.Image(width=256, height=256, label="Input Image"),
    ],
    outputs=[
        gr.Image(width=256, height=256, label="Output"),
    ],
    title=title,
    description=description,
    examples=examples,
    cache_examples=False,
)
demo.launch()