LPX55 committed on
Commit
8fa6d2e
·
verified ·
1 Parent(s): 76f2c71

Update sam2_mask.py

Browse files
Files changed (1) hide show
  1. sam2_mask.py +51 -20
sam2_mask.py CHANGED
@@ -39,7 +39,15 @@ def get_point(point_type, tracking_points, trackings_input_label, first_frame_pa
39
  transparent_layer = Image.fromarray(transparent_layer, 'RGBA')
40
  selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
41
  return tracking_points, trackings_input_label, selected_point_map
 
 
 
42
 
 
 
 
 
 
43
  def show_mask(mask, ax, random_color=False, borders=True):
44
  if random_color:
45
  color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
@@ -102,8 +110,12 @@ def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_l
102
 
103
  @spaces.GPU()
104
  def sam_process(original_image, points, labels):
 
105
  print(f"Points: {points}")
106
  print(f"Labels: {labels}")
 
 
 
107
  if not points or not labels:
108
  print("No points or labels provided, returning None")
109
  return None
@@ -111,15 +123,26 @@ def sam_process(original_image, points, labels):
111
  image = np.array(original_image)
112
  predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")
113
  predictor.set_image(image)
114
- input_point = np.array(points)
115
- input_label = np.array(labels)
116
- masks, scores, _ = predictor.predict(input_point, input_label, multimask_output=False)
 
 
 
 
 
 
 
117
  sorted_indices = np.argsort(scores)[::-1]
118
  masks = masks[sorted_indices]
119
- # Generate mask image
120
- mask = masks[0] * 255
121
- mask_image = Image.fromarray(mask.astype(np.uint8))
122
- return mask_image
 
 
 
 
123
 
124
  def create_sam2_tab():
125
  first_frame = gr.State() # Tracks original image
@@ -127,24 +150,32 @@ def create_sam2_tab():
127
  trackings_input_label = gr.State([])
128
 
129
  with gr.Column():
130
- gr.Markdown("# SAM2 Image Predictor")
131
- gr.Markdown("1. Upload your image\n2. Click points to mask\n3. Submit")
132
-
133
- points_map = gr.Image(label="Points Map", type="pil", interactive=True)
134
- input_image = gr.Image(type="pil", visible=False) # Original image
135
-
136
  with gr.Row():
137
- point_type = gr.Radio(["include", "exclude"], value="include", label="Point Type")
138
- clear_button = gr.Button("Clear Points")
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
- submit_button = gr.Button("Submit")
141
- output_image = gr.Image("Segmented Output")
 
142
 
143
  # Event handlers
144
  points_map.upload(
145
  lambda img: (img, img, [], []),
146
  inputs=points_map,
147
- outputs=[input_image, first_frame, tracking_points, trackings_input_label]
148
  )
149
 
150
  clear_button.click(
@@ -161,8 +192,8 @@ def create_sam2_tab():
161
 
162
  submit_button.click(
163
  sam_process,
164
- inputs=[input_image, tracking_points, trackings_input_label],
165
  outputs=output_image
166
  )
167
 
168
- return input_image, points_map, output_image
 
39
  transparent_layer = Image.fromarray(transparent_layer, 'RGBA')
40
  selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
41
  return tracking_points, trackings_input_label, selected_point_map
42
+
43
+ # use bfloat16 for the entire notebook
44
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
45
 
46
+ if torch.cuda.get_device_properties(0).major >= 8:
47
+ # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
48
+ torch.backends.cuda.matmul.allow_tf32 = True
49
+ torch.backends.cudnn.allow_tf32 = True
50
+
51
  def show_mask(mask, ax, random_color=False, borders=True):
52
  if random_color:
53
  color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
 
110
 
111
  @spaces.GPU()
112
  def sam_process(original_image, points, labels):
113
+
114
  print(f"Points: {points}")
115
  print(f"Labels: {labels}")
116
+ image = Image.open(original_image)
117
+ image = np.array(image.convert("RGB"))
118
+
119
  if not points or not labels:
120
  print("No points or labels provided, returning None")
121
  return None
 
123
  image = np.array(original_image)
124
  predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")
125
  predictor.set_image(image)
126
+ input_point = np.array(points.value)
127
+ input_label = np.array(labels.value)
128
+
129
+ print(predictor._features["image_embed"].shape, predictor._features["image_embed"][-1].shape)
130
+
131
+ masks, scores, logits = predictor.predict(
132
+ point_coords=input_point,
133
+ point_labels=input_label,
134
+ multimask_output=False,
135
+ )
136
  sorted_indices = np.argsort(scores)[::-1]
137
  masks = masks[sorted_indices]
138
+ scores = scores[sorted_indices]
139
+ logits = logits[sorted_indices]
140
+ print(masks.shape)
141
+
142
+ results, mask_results = show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label, borders=True)
143
+ print(results)
144
+
145
+ return results[0], mask_results[0]
146
 
147
  def create_sam2_tab():
148
  first_frame = gr.State() # Tracks original image
 
150
  trackings_input_label = gr.State([])
151
 
152
  with gr.Column():
 
 
 
 
 
 
153
  with gr.Row():
154
+ with gr.Column():
155
+ sam_input_image = gr.Image(label="input image", interactive=False, type="filepath", visible=False)
156
+ points_map = gr.Image(
157
+ label="points map",
158
+ type="filepath",
159
+ interactive=True
160
+ )
161
+ with gr.Row():
162
+ point_type = gr.Radio(["include", "exclude"], value="include", label="Point Type")
163
+ clear_button = gr.Button("Clear Points")
164
+ submit_button = gr.Button("Submit")
165
+
166
+ with gr.Column():
167
+ output_image = gr.Image("Segmented Output")
168
+ output_result_mask = gr.Image()
169
 
170
+
171
+
172
+
173
 
174
  # Event handlers
175
  points_map.upload(
176
  lambda img: (img, img, [], []),
177
  inputs=points_map,
178
+ outputs=[sam_input_image, first_frame, tracking_points, trackings_input_label]
179
  )
180
 
181
  clear_button.click(
 
192
 
193
  submit_button.click(
194
  sam_process,
195
+ inputs=[sam_input_image, tracking_points, trackings_input_label],
196
  outputs=output_image
197
  )
198
 
199
+ return sam_input_image, points_map, output_image