Spaces:
Running
on
Zero
Running
on
Zero
Update sam2_mask.py
Browse files- sam2_mask.py +21 -16
sam2_mask.py
CHANGED
@@ -12,16 +12,16 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor
|
|
12 |
|
13 |
def preprocess_image(image):
|
14 |
return image, gr.State([]), gr.State([]), image
|
15 |
-
|
16 |
def get_point(point_type, tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
|
17 |
print(f"You selected {evt.value} at {evt.index} from {evt.target}")
|
18 |
-
tracking_points.
|
19 |
-
print(f"TRACKING POINTS: {tracking_points
|
20 |
if point_type == "include":
|
21 |
-
trackings_input_label.
|
22 |
elif point_type == "exclude":
|
23 |
-
trackings_input_label.
|
24 |
-
print(f"TRACKING INPUT LABELS: {trackings_input_label
|
25 |
# Open the image and get its dimensions
|
26 |
transparent_background = Image.open(first_frame_path).convert('RGBA')
|
27 |
w, h = transparent_background.size
|
@@ -30,16 +30,16 @@ def get_point(point_type, tracking_points, trackings_input_label, first_frame_pa
|
|
30 |
radius = int(fraction * min(w, h))
|
31 |
# Create a transparent layer to draw on
|
32 |
transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
|
33 |
-
for index, track in enumerate(tracking_points
|
34 |
-
if trackings_input_label
|
35 |
-
cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
|
36 |
else:
|
37 |
-
cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
|
38 |
# Convert the transparent layer back to an image
|
39 |
transparent_layer = Image.fromarray(transparent_layer, 'RGBA')
|
40 |
selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
|
41 |
return tracking_points, trackings_input_label, selected_point_map
|
42 |
-
|
43 |
def show_mask(mask, ax, random_color=False, borders=True):
|
44 |
if random_color:
|
45 |
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
@@ -99,21 +99,21 @@ def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_l
|
|
99 |
mask_images.append(mask_filename)
|
100 |
plt.close() # Close the figure to free up memory
|
101 |
return combined_images, mask_images
|
102 |
-
|
103 |
@spaces.GPU()
|
104 |
def sam_process(original_image, points, labels):
|
105 |
print(f"Points: {points}")
|
106 |
print(f"Labels: {labels}")
|
|
|
|
|
|
|
107 |
# Convert image to numpy array for SAM2 processing
|
108 |
image = np.array(original_image)
|
109 |
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")
|
110 |
predictor.set_image(image)
|
111 |
input_point = np.array(points)
|
112 |
input_label = np.array(labels)
|
113 |
-
|
114 |
-
print("No points or labels provided, returning None")
|
115 |
-
return None
|
116 |
-
masks, scores, _= predictor.predict(input_point, input_label, multimask_output=False)
|
117 |
sorted_indices = np.argsort(scores)[::-1]
|
118 |
masks = masks[sorted_indices]
|
119 |
# Generate mask image
|
@@ -129,12 +129,14 @@ def create_sam2_tab():
|
|
129 |
with gr.Column():
|
130 |
gr.Markdown("# SAM2 Image Predictor")
|
131 |
gr.Markdown("1. Upload your image\n2. Click points to mask\n3. Submit")
|
|
|
132 |
points_map = gr.Image(label="Points Map", type="pil", interactive=True)
|
133 |
input_image = gr.Image(type="pil", visible=False) # Original image
|
134 |
|
135 |
with gr.Row():
|
136 |
point_type = gr.Radio(["include", "exclude"], value="include", label="Point Type")
|
137 |
clear_button = gr.Button("Clear Points")
|
|
|
138 |
submit_button = gr.Button("Submit")
|
139 |
output_image = gr.Image("Segmented Output")
|
140 |
|
@@ -144,16 +146,19 @@ def create_sam2_tab():
|
|
144 |
inputs=points_map,
|
145 |
outputs=[input_image, first_frame, tracking_points, trackings_input_label]
|
146 |
)
|
|
|
147 |
clear_button.click(
|
148 |
lambda img: ([], [], img),
|
149 |
inputs=first_frame,
|
150 |
outputs=[tracking_points, trackings_input_label, points_map]
|
151 |
)
|
|
|
152 |
points_map.select(
|
153 |
get_point,
|
154 |
inputs=[point_type, tracking_points, trackings_input_label, first_frame],
|
155 |
outputs=[tracking_points, trackings_input_label, points_map]
|
156 |
)
|
|
|
157 |
submit_button.click(
|
158 |
sam_process,
|
159 |
inputs=[input_image, tracking_points, trackings_input_label],
|
|
|
12 |
|
13 |
def preprocess_image(image):
|
14 |
return image, gr.State([]), gr.State([]), image
|
15 |
+
|
16 |
def get_point(point_type, tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
|
17 |
print(f"You selected {evt.value} at {evt.index} from {evt.target}")
|
18 |
+
tracking_points.append(evt.index)
|
19 |
+
print(f"TRACKING POINTS: {tracking_points}")
|
20 |
if point_type == "include":
|
21 |
+
trackings_input_label.append(1)
|
22 |
elif point_type == "exclude":
|
23 |
+
trackings_input_label.append(0)
|
24 |
+
print(f"TRACKING INPUT LABELS: {trackings_input_label}")
|
25 |
# Open the image and get its dimensions
|
26 |
transparent_background = Image.open(first_frame_path).convert('RGBA')
|
27 |
w, h = transparent_background.size
|
|
|
30 |
radius = int(fraction * min(w, h))
|
31 |
# Create a transparent layer to draw on
|
32 |
transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
|
33 |
+
for index, track in enumerate(tracking_points):
|
34 |
+
if trackings_input_label[index] == 1:
|
35 |
+
cv2.circle(transparent_layer, tuple(track), radius, (0, 255, 0, 255), -1)
|
36 |
else:
|
37 |
+
cv2.circle(transparent_layer, tuple(track), radius, (255, 0, 0, 255), -1)
|
38 |
# Convert the transparent layer back to an image
|
39 |
transparent_layer = Image.fromarray(transparent_layer, 'RGBA')
|
40 |
selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
|
41 |
return tracking_points, trackings_input_label, selected_point_map
|
42 |
+
|
43 |
def show_mask(mask, ax, random_color=False, borders=True):
|
44 |
if random_color:
|
45 |
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
|
|
99 |
mask_images.append(mask_filename)
|
100 |
plt.close() # Close the figure to free up memory
|
101 |
return combined_images, mask_images
|
102 |
+
|
103 |
@spaces.GPU()
|
104 |
def sam_process(original_image, points, labels):
|
105 |
print(f"Points: {points}")
|
106 |
print(f"Labels: {labels}")
|
107 |
+
if not points or not labels:
|
108 |
+
print("No points or labels provided, returning None")
|
109 |
+
return None
|
110 |
# Convert image to numpy array for SAM2 processing
|
111 |
image = np.array(original_image)
|
112 |
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")
|
113 |
predictor.set_image(image)
|
114 |
input_point = np.array(points)
|
115 |
input_label = np.array(labels)
|
116 |
+
masks, scores, _ = predictor.predict(input_point, input_label, multimask_output=False)
|
|
|
|
|
|
|
117 |
sorted_indices = np.argsort(scores)[::-1]
|
118 |
masks = masks[sorted_indices]
|
119 |
# Generate mask image
|
|
|
129 |
with gr.Column():
|
130 |
gr.Markdown("# SAM2 Image Predictor")
|
131 |
gr.Markdown("1. Upload your image\n2. Click points to mask\n3. Submit")
|
132 |
+
|
133 |
points_map = gr.Image(label="Points Map", type="pil", interactive=True)
|
134 |
input_image = gr.Image(type="pil", visible=False) # Original image
|
135 |
|
136 |
with gr.Row():
|
137 |
point_type = gr.Radio(["include", "exclude"], value="include", label="Point Type")
|
138 |
clear_button = gr.Button("Clear Points")
|
139 |
+
|
140 |
submit_button = gr.Button("Submit")
|
141 |
output_image = gr.Image("Segmented Output")
|
142 |
|
|
|
146 |
inputs=points_map,
|
147 |
outputs=[input_image, first_frame, tracking_points, trackings_input_label]
|
148 |
)
|
149 |
+
|
150 |
clear_button.click(
|
151 |
lambda img: ([], [], img),
|
152 |
inputs=first_frame,
|
153 |
outputs=[tracking_points, trackings_input_label, points_map]
|
154 |
)
|
155 |
+
|
156 |
points_map.select(
|
157 |
get_point,
|
158 |
inputs=[point_type, tracking_points, trackings_input_label, first_frame],
|
159 |
outputs=[tracking_points, trackings_input_label, points_map]
|
160 |
)
|
161 |
+
|
162 |
submit_button.click(
|
163 |
sam_process,
|
164 |
inputs=[input_image, tracking_points, trackings_input_label],
|