Chintan-Shah commited on
Commit
9949d8d
·
verified ·
1 Parent(s): 5f98b9f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -0
app.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ultralytics import FastSAM
2
+ from ultralytics.models.fastsam import FastSAMPrompt
3
+ import matplotlib.pyplot as plt
4
+
5
+ import os
6
+ import io
7
+ import numpy as np
8
+ import torch
9
+ import cv2
10
+ from PIL import Image
11
+
12
+ def fig2img(fig):
13
+ buf = io.BytesIO()
14
+ fig.savefig(buf)
15
+ buf.seek(0)
16
+ img = Image.open(buf)
17
+ return img
18
+
19
+
20
+ def plot(
21
+ annotations,
22
+ prompt_process,
23
+ bbox=None,
24
+ points=None,
25
+ point_label=None,
26
+ mask_random_color=True,
27
+ better_quality=True,
28
+ retina=False,
29
+ with_contours=True,
30
+ ):
31
+ """
32
+ Plots annotations, bounding boxes, and points on images and saves the output.
33
+
34
+ Args:
35
+ annotations (list): Annotations to be plotted.
36
+ output (str or Path): Output directory for saving the plots.
37
+ bbox (list, optional): Bounding box coordinates [x1, y1, x2, y2]. Defaults to None.
38
+ points (list, optional): Points to be plotted. Defaults to None.
39
+ point_label (list, optional): Labels for the points. Defaults to None.
40
+ mask_random_color (bool, optional): Whether to use random color for masks. Defaults to True.
41
+ better_quality (bool, optional): Whether to apply morphological transformations for better mask quality.
42
+ Defaults to True.
43
+ retina (bool, optional): Whether to use retina mask. Defaults to False.
44
+ with_contours (bool, optional): Whether to plot contours. Defaults to True.
45
+ """
46
+
47
+ # pbar = TQDM(annotations, total=len(annotations))
48
+ for ann in annotations:
49
+ result_name = os.path.basename(ann.path)
50
+ image = ann.orig_img[..., ::-1] # BGR to RGB
51
+ original_h, original_w = ann.orig_shape
52
+ # For macOS only
53
+ # plt.switch_backend('TkAgg')
54
+ fig = plt.figure(figsize=(original_w / 100, original_h / 100))
55
+ # Add subplot with no margin.
56
+ plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
57
+ plt.margins(0, 0)
58
+ plt.gca().xaxis.set_major_locator(plt.NullLocator())
59
+ plt.gca().yaxis.set_major_locator(plt.NullLocator())
60
+ plt.imshow(image)
61
+
62
+ if ann.masks is not None:
63
+ masks = ann.masks.data
64
+ if better_quality:
65
+ if isinstance(masks[0], torch.Tensor):
66
+ masks = np.array(masks.cpu())
67
+ for i, mask in enumerate(masks):
68
+ mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
69
+ masks[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
70
+
71
+ prompt_process.fast_show_mask(
72
+ masks,
73
+ plt.gca(),
74
+ random_color=mask_random_color,
75
+ bbox=bbox,
76
+ points=points,
77
+ pointlabel=point_label,
78
+ retinamask=retina,
79
+ target_height=original_h,
80
+ target_width=original_w,
81
+ )
82
+
83
+ if with_contours:
84
+ contour_all = []
85
+ temp = np.zeros((original_h, original_w, 1))
86
+ for i, mask in enumerate(masks):
87
+ mask = mask.astype(np.uint8)
88
+ if not retina:
89
+ mask = cv2.resize(mask, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
90
+ contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
91
+ contour_all.extend(iter(contours))
92
+ cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
93
+ color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
94
+ contour_mask = temp / 255 * color.reshape(1, 1, -1)
95
+ plt.imshow(contour_mask)
96
+
97
+ # Save the figure
98
+ # save_path = Path(output) / result_name
99
+ # save_path.parent.mkdir(exist_ok=True, parents=True)
100
+ plt.axis("off")
101
+ # plt.savefig(save_path, bbox_inches="tight", pad_inches=0, transparent=True)
102
+ plt.close()
103
+ # pbar.set_description(f"Saving {result_name} to {save_path}")
104
+
105
+ return fig2img(fig)
106
+
107
+ # Create a FastSAM model
108
+ model = FastSAM("FastSAM-s.pt") # or FastSAM-x.pt
109
+
110
+ def generateOutput(source):
111
+ everything_results = model(source, retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)
112
+ # Prepare a Prompt Process object
113
+ prompt_process = FastSAMPrompt(source, everything_results, device="cpu")
114
+ # Everything prompt
115
+ results = prompt_process.everything_prompt()
116
+
117
+ outputimage = plot(annotations=results, prompt_process=prompt_process)
118
+
119
+ return(outputimage)
120
+
121
+ title = "FastSAM Inference Trials"
122
+ description = "Shows the FastSAM related Inference Trials"
123
+ examples = [["Elephants.jpg"], ["Puppies.jpg"], ["photo2.JPG"], ["MultipleItems.jpg"]]
124
+ demo = gr.Interface(
125
+ generateOutput,
126
+ inputs = [
127
+ gr.Image(width=256, height=256, label="Input Image"),
128
+ ],
129
+ outputs = [
130
+ gr.Image(width=256, height=256, label="Output"),
131
+ ],
132
+ title = title,
133
+ description = description,
134
+ examples = examples,
135
+ cache_examples=False
136
+ )
137
+ demo.launch()