LPX55 commited on
Commit
c4299c8
·
1 Parent(s): ea0d88d

temp seg solution

Browse files
Files changed (1) hide show
  1. app.py +83 -79
app.py CHANGED
@@ -10,10 +10,12 @@ from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
10
  from gradio_image_prompter import ImagePrompter
11
  from PIL import Image, ImageDraw
12
  import numpy as np
13
- from sam2.sam2_image_predictor import SAM2ImagePredictor
14
- from sam2_mask import create_sam2_tab
15
  import subprocess
 
16
 
 
17
  subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
18
 
19
  # class SAM2PredictorSingleton:
@@ -79,57 +81,57 @@ def load_default_pipeline():
79
  ).to("cuda")
80
  return gr.update(value="Default pipeline loaded!")
81
 
82
- @spaces.GPU()
83
- def predict_masks(prompts):
84
-
85
- DEVICE = torch.device("cuda")
86
- SAM_MODEL = "facebook/sam2.1-hiera-large"
87
- # if PREDICTOR is None:
88
- # PREDICTOR = SAM2ImagePredictor.from_pretrained(SAM_MODEL, device=DEVICE)
89
- # else:
90
- # PREDICTOR = PREDICTOR
91
- PREDICTOR = SAM2ImagePredictor.from_pretrained(SAM_MODEL, device=DEVICE)
92
-
93
- """Predict a single mask from the image based on selected points."""
94
- image = np.array(prompts["image"]) # Convert the image to a numpy array
95
- points = prompts["points"] # Get the points from prompts
96
-
97
- if not points:
98
- return image # Return the original image if no points are selected
99
-
100
- # Debugging: Print the structure of points
101
- print(f"Points structure: {points}")
102
-
103
- # Ensure points is a list of lists with at least two elements
104
- if isinstance(points, list) and all(isinstance(point, list) and len(point) >= 2 for point in points):
105
- input_points = [[point[0], point[1]] for point in points]
106
- else:
107
- return image # Return the original image if points structure is unexpected
108
-
109
- input_labels = [1] * len(input_points)
110
-
111
- with torch.inference_mode():
112
- PREDICTOR.set_image(image)
113
- masks, _, _ = PREDICTOR.predict(
114
- point_coords=input_points, point_labels=input_labels, multimask_output=False
115
- )
116
-
117
- # Prepare the overlay image
118
- red_mask = np.zeros_like(image)
119
- if masks and len(masks) > 0:
120
- red_mask[:, :, 0] = masks[0].astype(np.uint8) * 255 # Apply the red channel
121
- red_mask = PILImage.fromarray(red_mask)
122
- original_image = PILImage.fromarray(image)
123
- blended_image = PILImage.blend(original_image, red_mask, alpha=0.5)
124
- return np.array(blended_image)
125
- else:
126
- return image
127
-
128
- def update_mask(prompts):
129
- """Update the mask based on the prompts."""
130
- image = prompts["image"]
131
- points = prompts["points"]
132
- return predict_masks(image, points)
133
 
134
 
135
  @spaces.GPU(duration=12)
@@ -558,33 +560,35 @@ with gr.Blocks(css=css, fill_height=True) as demo:
558
  use_as_input_button_outpaint = gr.Button("Use as Input Image", visible=False)
559
  history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
560
  preview_image = gr.Image(label="Preview")
561
- with gr.TabItem("SAM2 Masking"):
562
- input_image, points_map, output_result_mask = create_sam2_tab()
563
- with gr.TabItem("SAM2 Mask"):
564
- gr.Markdown("# Object Segmentation with SAM2")
565
- gr.Markdown(
566
- """
567
- This application utilizes **Segment Anything V2 (SAM2)** to allow you to upload an image and interactively generate a segmentation mask based on multiple points you select on the image.
568
- """
569
- )
570
- with gr.Row():
571
- with gr.Column():
572
- image_input = gr.State()
573
- # Input: ImagePrompter for uploaded image
574
- upload_image_input = ImagePrompter(show_label=False)
575
- with gr.Column():
576
- image_output = gr.Image(label="Segmented Image", type="pil", height=400)
577
- with gr.Row():
578
- # Button to trigger the prediction
579
- predict_button = gr.Button("Predict Mask")
580
 
581
- # Define the action triggered by the predict button
582
- predict_button.click(
583
- fn=predict_masks,
584
- inputs=[upload_image_input],
585
- outputs=[image_output],
586
- show_progress=True,
587
- )
 
 
588
  # Define the action triggered by the upload_image_input change
589
  # upload_image_input.change(
590
  # fn=update_mask,
 
10
  from gradio_image_prompter import ImagePrompter
11
  from PIL import Image, ImageDraw
12
  import numpy as np
13
+ # from sam2.sam2_image_predictor import SAM2ImagePredictor
14
+ # from sam2_mask import create_sam2_tab
15
  import subprocess
16
+ import os
17
 
18
+ HF_TOKEN = os.getenv("HF_TOKEN")
19
  subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
20
 
21
  # class SAM2PredictorSingleton:
 
81
  ).to("cuda")
82
  return gr.update(value="Default pipeline loaded!")
83
 
84
+ # @spaces.GPU()
85
+ # def predict_masks(prompts):
86
+
87
+ # DEVICE = torch.device("cuda")
88
+ # SAM_MODEL = "facebook/sam2.1-hiera-large"
89
+ # # if PREDICTOR is None:
90
+ # # PREDICTOR = SAM2ImagePredictor.from_pretrained(SAM_MODEL, device=DEVICE)
91
+ # # else:
92
+ # # PREDICTOR = PREDICTOR
93
+ # PREDICTOR = SAM2ImagePredictor.from_pretrained(SAM_MODEL, device=DEVICE)
94
+
95
+ # """Predict a single mask from the image based on selected points."""
96
+ # image = np.array(prompts["image"]) # Convert the image to a numpy array
97
+ # points = prompts["points"] # Get the points from prompts
98
+
99
+ # if not points:
100
+ # return image # Return the original image if no points are selected
101
+
102
+ # # Debugging: Print the structure of points
103
+ # print(f"Points structure: {points}")
104
+
105
+ # # Ensure points is a list of lists with at least two elements
106
+ # if isinstance(points, list) and all(isinstance(point, list) and len(point) >= 2 for point in points):
107
+ # input_points = [[point[0], point[1]] for point in points]
108
+ # else:
109
+ # return image # Return the original image if points structure is unexpected
110
+
111
+ # input_labels = [1] * len(input_points)
112
+
113
+ # with torch.inference_mode():
114
+ # PREDICTOR.set_image(image)
115
+ # masks, _, _ = PREDICTOR.predict(
116
+ # point_coords=input_points, point_labels=input_labels, multimask_output=False
117
+ # )
118
+
119
+ # # Prepare the overlay image
120
+ # red_mask = np.zeros_like(image)
121
+ # if masks and len(masks) > 0:
122
+ # red_mask[:, :, 0] = masks[0].astype(np.uint8) * 255 # Apply the red channel
123
+ # red_mask = PILImage.fromarray(red_mask)
124
+ # original_image = PILImage.fromarray(image)
125
+ # blended_image = PILImage.blend(original_image, red_mask, alpha=0.5)
126
+ # return np.array(blended_image)
127
+ # else:
128
+ # return image
129
+
130
+ # def update_mask(prompts):
131
+ # """Update the mask based on the prompts."""
132
+ # image = prompts["image"]
133
+ # points = prompts["points"]
134
+ # return predict_masks(image, points)
135
 
136
 
137
  @spaces.GPU(duration=12)
 
560
  use_as_input_button_outpaint = gr.Button("Use as Input Image", visible=False)
561
  history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
562
  preview_image = gr.Image(label="Preview")
563
+ # with gr.TabItem("SAM2 Masking"):
564
+ # input_image, points_map, output_result_mask = create_sam2_tab()
565
+ # with gr.TabItem("SAM2 Mask"):
566
+ # gr.Markdown("# Object Segmentation with SAM2")
567
+ # gr.Markdown(
568
+ # """
569
+ # This application utilizes **Segment Anything V2 (SAM2)** to allow you to upload an image and interactively generate a segmentation mask based on multiple points you select on the image.
570
+ # """
571
+ # )
572
+ # with gr.Row():
573
+ # with gr.Column():
574
+ # image_input = gr.State()
575
+ # # Input: ImagePrompter for uploaded image
576
+ # upload_image_input = ImagePrompter(show_label=False)
577
+ # with gr.Column():
578
+ # image_output = gr.Image(label="Segmented Image", type="pil", height=400)
579
+ # with gr.Row():
580
+ # # Button to trigger the prediction
581
+ # predict_button = gr.Button("Predict Mask")
582
 
583
+ # # Define the action triggered by the predict button
584
+ # predict_button.click(
585
+ # fn=predict_masks,
586
+ # inputs=[upload_image_input],
587
+ # outputs=[image_output],
588
+ # show_progress=True,
589
+ # )
590
+ with gr.Tab("SAM2.1 Segmented Mask"):
591
+ temp_space = gr.load("LPX55/SAM2-Image-Predictor-CPU", src="spaces", token=HF_TOKEN)
592
  # Define the action triggered by the upload_image_input change
593
  # upload_image_input.change(
594
  # fn=update_mask,