LPX55 committed on
Commit
76f2c71
·
verified ·
1 Parent(s): c4251ab

Update sam2_mask.py

Browse files
Files changed (1) hide show
  1. sam2_mask.py +21 -16
sam2_mask.py CHANGED
@@ -12,16 +12,16 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor
12
 
13
  # Upload handler: passes the image through to both image components and
  # resets the two point-tracking states.
  # NOTE(review): this (pre-change) version wraps the reset values in
  # gr.State([]); the updated version of the surrounding commit treats these
  # state values as plain lists (it calls .append() on them) — verify which
  # form the State outputs actually expect.
  def preprocess_image(image):
14
  return image, gr.State([]), gr.State([]), image
15
-
16
  def get_point(point_type, tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
17
  print(f"You selected {evt.value} at {evt.index} from {evt.target}")
18
- tracking_points.value.append(evt.index)
19
- print(f"TRACKING POINTS: {tracking_points.value}")
20
  if point_type == "include":
21
- trackings_input_label.value.append(1)
22
  elif point_type == "exclude":
23
- trackings_input_label.value.append(0)
24
- print(f"TRACKING INPUT LABELS: {trackings_input_label.value}")
25
  # Open the image and get its dimensions
26
  transparent_background = Image.open(first_frame_path).convert('RGBA')
27
  w, h = transparent_background.size
@@ -30,16 +30,16 @@ def get_point(point_type, tracking_points, trackings_input_label, first_frame_pa
30
  radius = int(fraction * min(w, h))
31
  # Create a transparent layer to draw on
32
  transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
33
- for index, track in enumerate(tracking_points.value):
34
- if trackings_input_label.value[index] == 1:
35
- cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
36
  else:
37
- cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
38
  # Convert the transparent layer back to an image
39
  transparent_layer = Image.fromarray(transparent_layer, 'RGBA')
40
  selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
41
  return tracking_points, trackings_input_label, selected_point_map
42
-
43
  def show_mask(mask, ax, random_color=False, borders=True):
44
  if random_color:
45
  color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
@@ -99,21 +99,21 @@ def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_l
99
  mask_images.append(mask_filename)
100
  plt.close() # Close the figure to free up memory
101
  return combined_images, mask_images
102
-
103
  @spaces.GPU()
104
  def sam_process(original_image, points, labels):
105
  print(f"Points: {points}")
106
  print(f"Labels: {labels}")
 
 
 
107
  # Convert image to numpy array for SAM2 processing
108
  image = np.array(original_image)
109
  predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")
110
  predictor.set_image(image)
111
  input_point = np.array(points)
112
  input_label = np.array(labels)
113
- if not input_point.size or not input_label.size:
114
- print("No points or labels provided, returning None")
115
- return None
116
- masks, scores, _= predictor.predict(input_point, input_label, multimask_output=False)
117
  sorted_indices = np.argsort(scores)[::-1]
118
  masks = masks[sorted_indices]
119
  # Generate mask image
@@ -129,12 +129,14 @@ def create_sam2_tab():
129
  with gr.Column():
130
  gr.Markdown("# SAM2 Image Predictor")
131
  gr.Markdown("1. Upload your image\n2. Click points to mask\n3. Submit")
 
132
  points_map = gr.Image(label="Points Map", type="pil", interactive=True)
133
  input_image = gr.Image(type="pil", visible=False) # Original image
134
 
135
  with gr.Row():
136
  point_type = gr.Radio(["include", "exclude"], value="include", label="Point Type")
137
  clear_button = gr.Button("Clear Points")
 
138
  submit_button = gr.Button("Submit")
139
  output_image = gr.Image("Segmented Output")
140
 
@@ -144,16 +146,19 @@ def create_sam2_tab():
144
  inputs=points_map,
145
  outputs=[input_image, first_frame, tracking_points, trackings_input_label]
146
  )
 
147
  clear_button.click(
148
  lambda img: ([], [], img),
149
  inputs=first_frame,
150
  outputs=[tracking_points, trackings_input_label, points_map]
151
  )
 
152
  points_map.select(
153
  get_point,
154
  inputs=[point_type, tracking_points, trackings_input_label, first_frame],
155
  outputs=[tracking_points, trackings_input_label, points_map]
156
  )
 
157
  submit_button.click(
158
  sam_process,
159
  inputs=[input_image, tracking_points, trackings_input_label],
 
12
 
13
  def preprocess_image(image):
14
  return image, gr.State([]), gr.State([]), image
15
+
16
  def get_point(point_type, tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
17
  print(f"You selected {evt.value} at {evt.index} from {evt.target}")
18
+ tracking_points.append(evt.index)
19
+ print(f"TRACKING POINTS: {tracking_points}")
20
  if point_type == "include":
21
+ trackings_input_label.append(1)
22
  elif point_type == "exclude":
23
+ trackings_input_label.append(0)
24
+ print(f"TRACKING INPUT LABELS: {trackings_input_label}")
25
  # Open the image and get its dimensions
26
  transparent_background = Image.open(first_frame_path).convert('RGBA')
27
  w, h = transparent_background.size
 
30
  radius = int(fraction * min(w, h))
31
  # Create a transparent layer to draw on
32
  transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
33
+ for index, track in enumerate(tracking_points):
34
+ if trackings_input_label[index] == 1:
35
+ cv2.circle(transparent_layer, tuple(track), radius, (0, 255, 0, 255), -1)
36
  else:
37
+ cv2.circle(transparent_layer, tuple(track), radius, (255, 0, 0, 255), -1)
38
  # Convert the transparent layer back to an image
39
  transparent_layer = Image.fromarray(transparent_layer, 'RGBA')
40
  selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
41
  return tracking_points, trackings_input_label, selected_point_map
42
+
43
  def show_mask(mask, ax, random_color=False, borders=True):
44
  if random_color:
45
  color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
 
99
  mask_images.append(mask_filename)
100
  plt.close() # Close the figure to free up memory
101
  return combined_images, mask_images
102
+
103
  @spaces.GPU()
104
  def sam_process(original_image, points, labels):
105
  print(f"Points: {points}")
106
  print(f"Labels: {labels}")
107
+ if not points or not labels:
108
+ print("No points or labels provided, returning None")
109
+ return None
110
  # Convert image to numpy array for SAM2 processing
111
  image = np.array(original_image)
112
  predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")
113
  predictor.set_image(image)
114
  input_point = np.array(points)
115
  input_label = np.array(labels)
116
+ masks, scores, _ = predictor.predict(input_point, input_label, multimask_output=False)
 
 
 
117
  sorted_indices = np.argsort(scores)[::-1]
118
  masks = masks[sorted_indices]
119
  # Generate mask image
 
129
  with gr.Column():
130
  gr.Markdown("# SAM2 Image Predictor")
131
  gr.Markdown("1. Upload your image\n2. Click points to mask\n3. Submit")
132
+
133
  points_map = gr.Image(label="Points Map", type="pil", interactive=True)
134
  input_image = gr.Image(type="pil", visible=False) # Original image
135
 
136
  with gr.Row():
137
  point_type = gr.Radio(["include", "exclude"], value="include", label="Point Type")
138
  clear_button = gr.Button("Clear Points")
139
+
140
  submit_button = gr.Button("Submit")
141
  output_image = gr.Image("Segmented Output")
142
 
 
146
  inputs=points_map,
147
  outputs=[input_image, first_frame, tracking_points, trackings_input_label]
148
  )
149
+
150
  clear_button.click(
151
  lambda img: ([], [], img),
152
  inputs=first_frame,
153
  outputs=[tracking_points, trackings_input_label, points_map]
154
  )
155
+
156
  points_map.select(
157
  get_point,
158
  inputs=[point_type, tracking_points, trackings_input_label, first_frame],
159
  outputs=[tracking_points, trackings_input_label, points_map]
160
  )
161
+
162
  submit_button.click(
163
  sam_process,
164
  inputs=[input_image, tracking_points, trackings_input_label],