reachomk committed on
Commit 353e8fc · verified · 1 Parent(s): 87fc13a

Upload 4 files

Files changed (4)
  1. app.py +280 -0
  2. gen2seg_mae_pipeline.py +132 -0
  3. gen2seg_sd_pipeline.py +454 -0
  4. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,280 @@
+ import gradio as gr
+ import torch
+ from PIL import Image
+ import numpy as np
+ import time
+ import os
+
+ # --- Import Custom Pipelines ---
+ # Ensure these files are in the same directory or accessible in PYTHONPATH
+ try:
+     from gen2seg_sd_pipeline import Gen2SegSDPipeline
+     from gen2seg_mae_pipeline import gen2segMAEInstancePipeline
+ except ImportError as e:
+     print(f"Error importing pipeline modules: {e}")
+     print("Please ensure gen2seg_sd_pipeline.py and gen2seg_mae_pipeline.py are in the same directory.")
+     # Optionally, raise an error or exit if pipelines are critical at startup
+     # raise ImportError("Could not import custom pipeline modules. Check file paths.") from e
+
+ from transformers import ViTMAEForPreTraining, AutoImageProcessor
+
+ # --- Configuration ---
+ MODEL_IDS = {
+     "SD": "reachomk/gen2seg-sd",
+     "MAE-H": "reachomk/gen2seg-mae-h"
+ }
+
+ # Check if a GPU is available and set the device accordingly
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {DEVICE}")
+
+ # --- Global Variables for Caching Pipelines ---
+ sd_pipe_global = None
+ mae_pipe_global = None
+
+ # --- Model Loading Functions ---
+ def get_sd_pipeline():
+     """Loads and caches the gen2seg Stable Diffusion pipeline."""
+     global sd_pipe_global
+     if sd_pipe_global is None:
+         model_id_sd = MODEL_IDS["SD"]
+         print(f"Attempting to load SD pipeline from Hugging Face Hub: {model_id_sd}")
+         try:
+             sd_pipe_global = Gen2SegSDPipeline.from_pretrained(
+                 model_id_sd,
+                 use_safetensors=True,
+                 # torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,  # Optional: use float16 on GPU
+             ).to(DEVICE)
+             print(f"SD Pipeline loaded successfully from {model_id_sd} on {DEVICE}.")
+         except Exception as e:
+             print(f"Error loading SD pipeline from Hugging Face Hub ({model_id_sd}): {e}")
+             sd_pipe_global = None  # Ensure it remains None on failure
+             # Do not raise gr.Error here; let the main function handle it.
+     return sd_pipe_global
+
+ def get_mae_pipeline():
+     """Loads and caches the gen2seg MAE-H pipeline."""
+     global mae_pipe_global
+     if mae_pipe_global is None:
+         model_id_mae = MODEL_IDS["MAE-H"]
+         print(f"Loading MAE-H pipeline with model {model_id_mae} on {DEVICE}...")
+         try:
+             model = ViTMAEForPreTraining.from_pretrained(model_id_mae)
+             model.to(DEVICE)
+             model.eval()  # Set to evaluation mode
+
+             # Load the official MAE-H image processor
+             # Using "facebook/vit-mae-huge" as per the original app_mae.py
+             image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-huge")
+
+             mae_pipe_global = gen2segMAEInstancePipeline(model=model, image_processor=image_processor)
+             # The custom MAE pipeline's model is already on the DEVICE.
+             print(f"MAE-H Pipeline with model {model_id_mae} loaded successfully on {DEVICE}.")
+         except Exception as e:
+             print(f"Error loading MAE-H model or pipeline from Hugging Face Hub ({model_id_mae}): {e}")
+             mae_pipe_global = None  # Ensure it remains None on failure
+             # Do not raise gr.Error here; let the main function handle it.
+     return mae_pipe_global
+
+ # --- Unified Prediction Function ---
+ def segment_image(input_image: Image.Image, model_choice: str) -> Image.Image:
+     """
+     Takes a PIL Image and model choice, performs segmentation, and returns the segmented image.
+     """
+     if input_image is None:
+         raise gr.Error("No image provided. Please upload an image.")
+
+     print(f"Model selected: {model_choice}")
+     # Ensure image is in RGB format
+     image_rgb = input_image.convert("RGB")
+     original_resolution = image_rgb.size  # (width, height)
+     seg_array = None
+
+     try:
+         if model_choice == "SD":
+             pipe_sd = get_sd_pipeline()
+             if pipe_sd is None:
+                 raise gr.Error("The SD segmentation pipeline could not be loaded. "
+                                "Please check the Space logs for more details, or try again later.")
+
+             print(f"Running SD inference with image size: {image_rgb.size}")
+             start_time = time.time()
+             with torch.no_grad():
+                 # The Gen2SegSDPipeline expects a single image or a list
+                 # The pipeline's __call__ method handles preprocessing internally
+                 seg_output = pipe_sd(image_rgb, match_input_resolution=False).prediction  # Output is before resize
+
+             # seg_output is expected to be a numpy array (N,H,W,1) or (N,1,H,W) or tensor
+             # Based on gen2seg_sd_pipeline.py, if output_type="np" (default), it's [N,H,W,1]
+             # If output_type="pt", it's [N,1,H,W]
+             # The original app_sd.py converted tensor to numpy and squeezed.
+             if isinstance(seg_output, torch.Tensor):
+                 seg_output = seg_output.cpu().numpy()
+
+             if seg_output.ndim == 4 and seg_output.shape[0] == 1:  # Batch size 1
+                 if seg_output.shape[1] == 1:  # Grayscale, (1, 1, H, W)
+                     seg_array = seg_output.squeeze(0).squeeze(0).astype(np.uint8)
+                 elif seg_output.shape[-1] == 1:  # Grayscale, (1, H, W, 1)
+                     seg_array = seg_output.squeeze(0).squeeze(-1).astype(np.uint8)
+                 elif seg_output.shape[1] == 3:  # RGB, (1, 3, H, W) -> (H, W, 3)
+                     seg_array = np.transpose(seg_output.squeeze(0), (1, 2, 0)).astype(np.uint8)
+                 elif seg_output.shape[-1] == 3:  # RGB, (1, H, W, 3)
+                     seg_array = seg_output.squeeze(0).astype(np.uint8)
+                 else:  # Fallback for unexpected shapes
+                     seg_array = seg_output.squeeze().astype(np.uint8)
+
+             elif seg_output.ndim == 3:  # (H, W, C) or (C, H, W)
+                 seg_array = seg_output.astype(np.uint8)
+             elif seg_output.ndim == 2:  # (H, W)
+                 seg_array = seg_output.astype(np.uint8)
+             else:
+                 raise TypeError(f"Unexpected SD segmentation output type/shape: {type(seg_output)}, {seg_output.shape}")
+             end_time = time.time()
+             print(f"SD Inference completed in {end_time - start_time:.2f} seconds.")
+
+
+         elif model_choice == "MAE-H":
+             pipe_mae = get_mae_pipeline()
+             if pipe_mae is None:
+                 raise gr.Error("The MAE-H segmentation pipeline could not be loaded. "
+                                "Please check the Space logs for more details, or try again later.")
+
+             print(f"Running MAE-H inference with image size: {image_rgb.size}")
+             start_time = time.time()
+             with torch.no_grad():
+                 # The gen2segMAEInstancePipeline expects a list of images
+                 # output_type="np" returns a NumPy array
+                 pipe_output = pipe_mae([image_rgb], output_type="np")
+                 # Prediction is (batch_size, height, width, 3) for MAE
+                 prediction_np = pipe_output.prediction[0]  # Get the first (and only) image prediction
+
+             end_time = time.time()
+             print(f"MAE-H Inference completed in {end_time - start_time:.2f} seconds.")
+
+             if not isinstance(prediction_np, np.ndarray):
+                 # This case should ideally not be reached if output_type="np"
+                 prediction_np = prediction_np.cpu().numpy()
+
+             # Ensure it's in the expected (H, W, C) format and uint8
+             if prediction_np.ndim == 3 and prediction_np.shape[-1] == 3:  # Expected (H, W, 3)
+                 seg_array = prediction_np.astype(np.uint8)
+             else:
+                 # Attempt to handle other shapes if necessary, or raise error
+                 raise gr.Error(f"Unexpected MAE-H prediction shape: {prediction_np.shape}. Expected (H, W, 3).")
+
+             # The MAE pipeline already does gamma correction and scaling to 0-255.
+             # It also ensures 3 channels.
+
+         else:
+             raise gr.Error(f"Invalid model choice: {model_choice}. Please select a valid model.")
+
+         if seg_array is None:
+             raise gr.Error("Segmentation array was not generated. An unknown error occurred.")
+
+         print(f"Segmentation array generated with shape: {seg_array.shape}, dtype: {seg_array.dtype}")
+
+         # Convert numpy array to PIL Image
+         # Handle grayscale or RGB based on seg_array channels
+         if seg_array.ndim == 2:  # Grayscale
+             segmented_image_pil = Image.fromarray(seg_array, mode='L')
+         elif seg_array.ndim == 3 and seg_array.shape[-1] == 3:  # RGB
+             segmented_image_pil = Image.fromarray(seg_array, mode='RGB')
+         elif seg_array.ndim == 3 and seg_array.shape[-1] == 1:  # Grayscale with channel dim
+             segmented_image_pil = Image.fromarray(seg_array.squeeze(-1), mode='L')
+         else:
+             raise gr.Error(f"Cannot convert seg_array with shape {seg_array.shape} to PIL Image.")
+
+         # Resize back to original image resolution using LANCZOS for high quality
+         segmented_image_pil = segmented_image_pil.resize(original_resolution, Image.Resampling.LANCZOS)
+
+         print(f"Segmented image processed. Output size: {segmented_image_pil.size}, mode: {segmented_image_pil.mode}")
+         return segmented_image_pil
+
+     except Exception as e:
+         print(f"Error during segmentation with {model_choice}: {e}")
+         # Re-raise as gr.Error for Gradio to display, if not already one
+         if not isinstance(e, gr.Error):
+             # It's often helpful to include the type of the original exception
+             error_type = type(e).__name__
+             raise gr.Error(f"An error occurred during segmentation: {error_type} - {str(e)}")
+         else:
+             raise e  # Re-raise if it's already a gr.Error
+
+ # --- Gradio Interface ---
+ title = "gen2seg: Generative Models Enable Generalizable Instance Segmentation Demo (SD & MAE-H)"
+ description = f"""
+ <div style="text-align: center; font-family: 'Arial', sans-serif;">
+     <p>Upload an image and choose a model architecture to see the instance segmentation result generated by the respective model.</p>
+     <p>
+         Currently, inference is running on CPU.
+         Performance will be significantly better on GPU.
+     </p>
+     <ul>
+         <li><strong>SD</strong>: Based on Stable Diffusion 2.
+             <a href="https://huggingface.co/{MODEL_IDS['SD']}" target="_blank">Model Link</a>.
+             <em>Approx. CPU inference time: ~1-2 minutes per image.</em>
+         </li>
+         <li><strong>MAE-H</strong>: Based on Masked Autoencoder (Huge).
+             <a href="https://huggingface.co/{MODEL_IDS['MAE-H']}" target="_blank">Model Link</a>.
+             <em>Approx. CPU inference time: ~15-45 seconds per image.</em>
+             If you experience tokenizer artifacts or very dark images, you can use gamma correction to handle this.
+         </li>
+     </ul>
+     <p>
+         For faster inference, please check out our GitHub to run the models locally on a GPU:
+         <a href="https://github.com/UCDvision/gen2seg" target="_blank">https://github.com/UCDvision/gen2seg</a>
+     </p>
+     <p>If the demo experiences issues, please open an issue on our GitHub.</p>
+     <p>If you have not already, please see our webpage at <a href="https://reachomk.github.io/gen2seg" target="_blank">https://reachomk.github.io/gen2seg</a>.</p>
+
+ </div>
+ """
+
+ article = """
+ """
+
+ # Define Gradio inputs
+ input_image_component = gr.Image(type="pil", label="Input Image")
+ model_choice_component = gr.Dropdown(
+     choices=list(MODEL_IDS.keys()),
+     value="SD",  # Default model
+     label="Choose Segmentation Model Architecture"
+ )
+
+ # Define Gradio output
+ output_image_component = gr.Image(type="pil", label="Segmented Image")
+
+ # Example images (ensure these paths are correct if you upload examples to your Space)
+ # For example, if you create an "examples" folder in your Space repo:
+ # example_paths = [
+ #     os.path.join("examples", "example1.jpg"),
+ #     os.path.join("examples", "example2.png")
+ # ]
+ # Filter out non-existent example files to prevent errors
+ # example_paths = [ex for ex in example_paths if os.path.exists(ex)]
+ example_paths = []  # Add paths to example images here if you have them
+
+ iface = gr.Interface(
+     fn=segment_image,
+     inputs=[input_image_component, model_choice_component],
+     outputs=output_image_component,
+     title=title,
+     description=description,
+     article=article,
+     examples=example_paths if example_paths else None,  # Pass None if no examples
+     allow_flagging="never",
+     theme=gr.themes.Soft()  # Using a soft theme for a slightly modern look
+ )
+
+ if __name__ == "__main__":
+     # Optional: Pre-load a default model on startup if desired.
+     # This can make the first inference faster but increases startup time.
+     # print("Attempting to pre-load default SD model on startup...")
+     # try:
+     #     get_sd_pipeline()  # Pre-load the default SD model
+     #     print("Default SD model pre-loaded successfully or was already cached.")
+     # except Exception as e:
+     #     print(f"Could not pre-load default SD model: {e}")
+
+     print("Launching Gradio interface...")
+     iface.launch()
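
Note: for local testing outside the Gradio UI, the unified prediction function above can be called directly once the two pipeline files sit next to app.py. This is only a minimal sketch, not part of the commit; the image paths are placeholders, and the first call will download the chosen model from the Hub.

from PIL import Image
from app import segment_image

img = Image.open("test.jpg")              # placeholder path; any local RGB image
result = segment_image(img, "MAE-H")      # or "SD"; loads and caches the model on first use
result.save("test_segmentation.png")
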
gen2seg_mae_pipeline.py ADDED
@@ -0,0 +1,132 @@
+ # gen2seg official inference pipeline code for the MAE model
+ #
+ # Please see our project website at https://reachomk.github.io/gen2seg
+ #
+ # Additionally, if you use our code please cite our paper, along with the two works above.
+
+ from dataclasses import dataclass
+ from typing import Union, List, Optional
+
+ import torch
+ import numpy as np
+ from PIL import Image
+ from einops import rearrange
+
+ from diffusers import DiffusionPipeline
+ from diffusers.utils import BaseOutput, logging
+ from transformers import AutoImageProcessor
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ @dataclass
+ class gen2segMAEInstanceOutput(BaseOutput):
+     """
+     Output class for the ViTMAE Instance Segmentation Pipeline.
+
+     Args:
+         prediction (`np.ndarray` or `torch.Tensor`):
+             Predicted instance segmentation maps. The output has shape
+             `(batch_size, height, width, 3)` with pixel values scaled to [0, 255].
+     """
+     prediction: Union[np.ndarray, torch.Tensor]
+
+
+ class gen2segMAEInstancePipeline(DiffusionPipeline):
+     r"""
+     Pipeline for Instance Segmentation using a fine-tuned ViTMAEForPreTraining model.
+
+     This pipeline takes one or more input images and returns an instance segmentation
+     prediction for each image. The model is assumed to have been fine-tuned using an instance
+     segmentation loss, and the reconstruction is performed by rearranging the model's
+     patch logits into an image.
+
+     Args:
+         model (`ViTMAEForPreTraining`):
+             The fine-tuned ViTMAE model.
+         image_processor (`AutoImageProcessor`):
+             The image processor responsible for preprocessing input images.
+     """
+     def __init__(self, model, image_processor):
+         super().__init__()
+         self.register_modules(model=model, image_processor=image_processor)
+         self.model = model
+         self.image_processor = image_processor
+
+     def check_inputs(
+         self,
+         image: Union[Image.Image, np.ndarray, torch.Tensor, List[Union[Image.Image, np.ndarray, torch.Tensor]]]
+     ) -> List:
+         if not isinstance(image, list):
+             image = [image]
+         # Additional input validations can be added here if desired.
+         return image
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         image: Union[Image.Image, np.ndarray, torch.Tensor, List[Union[Image.Image, np.ndarray, torch.Tensor]]],
+         output_type: str = "np",
+         **kwargs
+     ) -> gen2segMAEInstanceOutput:
+         r"""
+         The call method of the pipeline.
+
+         Args:
+             image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, or a list of these):
+                 The input image(s) for instance segmentation. For arrays/tensors, expected values are in [0, 1].
+             output_type (`str`, optional, defaults to `"np"`):
+                 The format of the output prediction. Choose `"np"` for a NumPy array or `"pt"` for a PyTorch tensor.
+             **kwargs:
+                 Additional keyword arguments passed to the image processor.
+
+         Returns:
+             [`gen2segMAEInstanceOutput`]:
+                 An output object containing the predicted instance segmentation maps.
+         """
+         # 1. Check and prepare input images.
+         images = self.check_inputs(image)
+         inputs = self.image_processor(images=images, return_tensors="pt", **kwargs)
+         pixel_values = inputs["pixel_values"].to(self.device)
+
+         # 2. Forward pass through the model.
+         outputs = self.model(pixel_values=pixel_values)
+         logits = outputs.logits  # Expected shape: (B, num_patches, patch_dim)
+
+         # 3. Retrieve patch size and image size from the model configuration.
+         patch_size = self.model.config.patch_size  # e.g., 16
+         image_size = self.model.config.image_size  # e.g., 224
+         grid_size = image_size // patch_size
+
+         # 4. Rearrange logits into the reconstructed image.
+         # The logits are reshaped from (B, num_patches, patch_dim) to (B, 3, H, W).
+         reconstructed = rearrange(
+             logits,
+             "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
+             h=grid_size,
+             p1=patch_size,
+             p2=patch_size,
+             c=3,
+         )
+
+         # 5. Post-process the reconstructed output.
+         # For each sample, shift and scale the prediction to [0, 255].
+         predictions = []
+         for i in range(reconstructed.shape[0]):
+             sample = reconstructed[i]
+             min_val = torch.abs(sample.min())
+             max_val = torch.abs(sample.max())
+             sample = (sample + min_val) / (max_val + min_val + 1e-5)
+             # Sometimes the image is very dark, so we perform gamma correction to "brighten" it.
+             # In practice we can set this value to whatever we want or disable it entirely.
+             sample = sample**0.6
+             sample = sample * 255.0
+             predictions.append(sample)
+         prediction_tensor = torch.stack(predictions, dim=0).permute(0, 2, 3, 1)
+
+         # 6. Format the output.
+         if output_type == "np":
+             prediction = prediction_tensor.cpu().numpy()
+         else:
+             prediction = prediction_tensor
+         return gen2segMAEInstanceOutput(prediction=prediction)
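
Note: the core of this pipeline is step 4, folding per-patch logits back into an image with einops.rearrange. The shape bookkeeping can be checked in isolation on a dummy tensor; the sketch below assumes ViT-MAE-Base-like dimensions (patch size 16, image size 224), not necessarily the MAE-H checkpoint's configuration.

import torch
from einops import rearrange

patch_size, image_size = 16, 224           # assumed example values, as in the config comments above
grid_size = image_size // patch_size       # 14 patches per side
num_patches = grid_size ** 2               # 196
patch_dim = patch_size * patch_size * 3    # 768 logits per patch (RGB pixels of one patch)

logits = torch.randn(1, num_patches, patch_dim)   # dummy stand-in for outputs.logits
reconstructed = rearrange(
    logits,
    "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
    h=grid_size, p1=patch_size, p2=patch_size, c=3,
)
print(reconstructed.shape)                 # torch.Size([1, 3, 224, 224])
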
gen2seg_sd_pipeline.py ADDED
@@ -0,0 +1,454 @@
+ # gen2seg official inference pipeline code for Stable Diffusion model
+ #
+ # This code was adapted from Marigold and Diffusion E2E Finetuning.
+ #
+ # Please see our project website at https://reachomk.github.io/gen2seg
+ #
+ # Additionally, if you use our code please cite our paper, along with the two works above.
+
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ from PIL import Image
+ from tqdm.auto import tqdm
+ from transformers import CLIPTextModel, CLIPTokenizer
+
+ from diffusers.image_processor import PipelineImageInput
+ from diffusers.models import (
+     AutoencoderKL,
+     UNet2DConditionModel,
+ )
+ from diffusers.schedulers import (
+     DDIMScheduler,
+ )
+ from diffusers.utils import (
+     BaseOutput,
+     logging,
+ )
+ from diffusers import DiffusionPipeline
+ from diffusers.pipelines.marigold.marigold_image_processing import MarigoldImageProcessor
+
+ # add
+ def zeros_tensor(
+     shape: Union[Tuple, List],
+     device: Optional["torch.device"] = None,
+     dtype: Optional["torch.dtype"] = None,
+     layout: Optional["torch.layout"] = None,
+ ):
+     """
+     A helper function to create tensors of zeros on the desired `device`.
+     Mirrors randn_tensor from diffusers.utils.torch_utils.
+     """
+     layout = layout or torch.strided
+     device = device or torch.device("cpu")
+     latents = torch.zeros(list(shape), dtype=dtype, layout=layout).to(device)
+     return latents
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+ @dataclass
+ class Gen2SegSDSegOutput(BaseOutput):
+     """
+     Output class for the gen2seg Instance Segmentation prediction pipeline.
+
+     Args:
+         prediction (`np.ndarray`, `torch.Tensor`):
+             Predicted instance segmentation with values in the range [0, 255]. The shape is always $numimages \times 1 \times height
+             \times width$, regardless of whether the images were passed as a 4D array or a list.
+         latent (`None`, `torch.Tensor`):
+             Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
+             The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
+     """
+
+     prediction: Union[np.ndarray, torch.Tensor]
+     latent: Union[None, torch.Tensor]
+
+
+ class Gen2SegSDPipeline(DiffusionPipeline):
+     """
+     # add
+     Pipeline for Instance Segmentation prediction using our Stable Diffusion model.
+     Implementation is built upon Marigold (https://marigoldmonodepth.github.io) and E2E FT (https://gonzalomartingarcia.github.io/diffusion-e2e-ft/).
+
+     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+     Args:
+         unet (`UNet2DConditionModel`):
+             Conditional U-Net to denoise the segmentation latent, synthesized from image latent.
+         vae (`AutoencoderKL`):
+             Variational Auto-Encoder (VAE) Model to encode and decode images and predictions to and from latent
+             representations.
+         scheduler (`DDIMScheduler`):
+             A scheduler to be used in combination with `unet` to denoise the encoded image latent.
+         text_encoder (`CLIPTextModel`):
+             Text-encoder, for empty text embedding.
+         tokenizer (`CLIPTokenizer`):
+             CLIP tokenizer.
+         default_processing_resolution (`int`, *optional*):
+             The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
+             the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
+             default value is used. This is required to ensure reasonable results with various model flavors trained
+             with varying optimal processing resolution values.
+     """
+
+     model_cpu_offload_seq = "text_encoder->unet->vae"
+
+     def __init__(
+         self,
+         unet: UNet2DConditionModel,
+         vae: AutoencoderKL,
+         scheduler: Union[DDIMScheduler],
+         text_encoder: CLIPTextModel,
+         tokenizer: CLIPTokenizer,
+         default_processing_resolution: Optional[int] = 768,  # add
+     ):
+         super().__init__()
+
+         self.register_modules(
+             unet=unet,
+             vae=vae,
+             scheduler=scheduler,
+             text_encoder=text_encoder,
+             tokenizer=tokenizer,
+         )
+         self.register_to_config(
+             default_processing_resolution=default_processing_resolution,
+         )
+
+         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+         self.default_processing_resolution = default_processing_resolution
+         self.empty_text_embedding = None
+
+         self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+     def check_inputs(
+         self,
+         image: PipelineImageInput,
+         processing_resolution: int,
+         resample_method_input: str,
+         resample_method_output: str,
+         batch_size: int,
+         output_type: str,
+     ) -> int:
+         if processing_resolution is None:
+             raise ValueError(
+                 "`processing_resolution` is not specified and could not be resolved from the model config."
+             )
+         if processing_resolution < 0:
+             raise ValueError(
+                 "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for "
+                 "downsampled processing."
+             )
+         if processing_resolution % self.vae_scale_factor != 0:
+             raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.")
+         if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
+             raise ValueError(
+                 "`resample_method_input` takes string values compatible with PIL library: "
+                 "nearest, nearest-exact, bilinear, bicubic, area."
+             )
+         if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
+             raise ValueError(
+                 "`resample_method_output` takes string values compatible with PIL library: "
+                 "nearest, nearest-exact, bilinear, bicubic, area."
+             )
+         if batch_size < 1:
+             raise ValueError("`batch_size` must be positive.")
+         if output_type not in ["pt", "np"]:
+             raise ValueError("`output_type` must be one of `pt` or `np`.")
+
+         # image checks
+         num_images = 0
+         W, H = None, None
+         if not isinstance(image, list):
+             image = [image]
+         for i, img in enumerate(image):
+             if isinstance(img, np.ndarray) or torch.is_tensor(img):
+                 if img.ndim not in (2, 3, 4):
+                     raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.")
+                 H_i, W_i = img.shape[-2:]
+                 N_i = 1
+                 if img.ndim == 4:
+                     N_i = img.shape[0]
+             elif isinstance(img, Image.Image):
+                 W_i, H_i = img.size
+                 N_i = 1
+             else:
+                 raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.")
+             if W is None:
+                 W, H = W_i, H_i
+             elif (W, H) != (W_i, H_i):
+                 raise ValueError(
+                     f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}"
+                 )
+             num_images += N_i
+
+         return num_images
+
+     def progress_bar(self, iterable=None, total=None, desc=None, leave=True):
+         if not hasattr(self, "_progress_bar_config"):
+             self._progress_bar_config = {}
+         elif not isinstance(self._progress_bar_config, dict):
+             raise ValueError(
+                 f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
+             )
+
+         progress_bar_config = dict(**self._progress_bar_config)
+         progress_bar_config["desc"] = progress_bar_config.get("desc", desc)
+         progress_bar_config["leave"] = progress_bar_config.get("leave", leave)
+         if iterable is not None:
+             return tqdm(iterable, **progress_bar_config)
+         elif total is not None:
+             return tqdm(total=total, **progress_bar_config)
+         else:
+             raise ValueError("Either `total` or `iterable` has to be defined.")
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         image: PipelineImageInput,
+         processing_resolution: Optional[int] = None,
+         match_input_resolution: bool = False,
+         resample_method_input: str = "bilinear",
+         resample_method_output: str = "bilinear",
+         batch_size: int = 1,
+         output_type: str = "np",
+         output_latent: bool = False,
+         return_dict: bool = True,
+     ):
+         """
+         Function invoked when calling the pipeline.
+
+         Args:
+             image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+                 `List[torch.Tensor]`): An input image or images used as an input for the instance segmentation task. For
+                 arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible
+                 by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or
+                 three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the
+                 same width and height.
+             processing_resolution (`int`, *optional*, defaults to `None`):
+                 Effective processing resolution. When set to `0`, matches the larger input image dimension. This
+                 produces crisper predictions, but may also lead to the overall loss of global context. The default
+                 value `None` resolves to the optimal value from the model config.
+             match_input_resolution (`bool`, *optional*, defaults to `False`):
+                 When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer
+                 side of the output will equal `processing_resolution`.
+             resample_method_input (`str`, *optional*, defaults to `"bilinear"`):
+                 Resampling method used to resize input images to `processing_resolution`. The accepted values are:
+                 `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
+             resample_method_output (`str`, *optional*, defaults to `"bilinear"`):
+                 Resampling method used to resize output predictions to match the input resolution. The accepted values
+                 are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
+             batch_size (`int`, *optional*, defaults to `1`):
+                 Batch size; only matters when passing a tensor of images.
+             output_type (`str`, *optional*, defaults to `"np"`):
+                 Preferred format of the output's `prediction`. The accepted values are: `"np"` (numpy array) or `"pt"` (torch tensor).
+             output_latent (`bool`, *optional*, defaults to `False`):
+                 When enabled, the output's `latent` field contains the latent codes corresponding to the predictions
+                 within the ensemble. These codes can be saved, modified, and used for subsequent calls with the
+                 `latents` argument.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`Gen2SegSDSegOutput`] instead of a plain tuple.
+
+         # add
+         E2E FT models are deterministic single step models involving no ensembling, i.e. E=1.
+         """
+
+         # 0. Resolving variables.
+         device = self._execution_device
+         dtype = self.dtype
+
+         # Model-specific optimal default values leading to fast and reasonable results.
+         if processing_resolution is None:
+             processing_resolution = self.default_processing_resolution
+
+         # print(image[0].size)
+         # processing_resolution = 8 * round(max(image[0].size) / 8)
+
+         # 1. Check inputs.
+         num_images = self.check_inputs(
+             image,
+             processing_resolution,
+             resample_method_input,
+             resample_method_output,
+             batch_size,
+             output_type,
+         )
+
+         # 2. Prepare empty text conditioning.
+         # Model invocation: self.tokenizer, self.text_encoder.
+         prompt = ""
+         text_inputs = self.tokenizer(
+             prompt,
+             padding="do_not_pad",
+             max_length=self.tokenizer.model_max_length,
+             truncation=True,
+             return_tensors="pt",
+         )
+         text_input_ids = text_inputs.input_ids.to(device)
+         self.empty_text_embedding = self.text_encoder(text_input_ids)[0]  # [1,2,1024]
+
+         # 3. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`,
+         # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where
+         # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are
+         # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None`
+         # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of
+         # operation and leads to the most reasonable results. Using the native image resolution or any other processing
+         # resolution can lead to loss of either fine details or global context in the output predictions.
+         image, padding, original_resolution = self.image_processor.preprocess(
+             image, processing_resolution, resample_method_input, device, dtype
+         )  # [N,3,PPH,PPW]
+         # image = (image + torch.abs(image.min()))
+         # image = image / (torch.abs(image.max()) + torch.abs(image.min()))
+         # # prediction = prediction**0.5
+         # # prediction = torch.clip(prediction, min=-1, max=1)+1
+         # image = (image) * 2
+         # image = image - 1
+         # 4. Encode input image into latent space. At this step, each of the `N` input images is represented with `E`
+         # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently.
+         # Latents of each such predictions across all input images and all ensemble members are represented in the
+         # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded
+         # into latent space and replicated `E` times. Encoding into latent space happens in batches of size `batch_size`.
+         # Model invocation: self.vae.encoder.
+         image_latent, pred_latent = self.prepare_latents(
+             image, batch_size
+         )  # [N*E,4,h,w], [N*E,4,h,w]
+
+         del image
+
+         batch_empty_text_embedding = self.empty_text_embedding.to(device=device, dtype=dtype).repeat(
+             batch_size, 1, 1
+         )  # [B,1024,2]
+
+         # 5. Process the denoising loop. All `N * E` latents are processed sequentially in batches of size `batch_size`.
+         # The unet model takes concatenated latent spaces of the input image and the predicted modality as an input, and
+         # outputs noise for the predicted modality's latent space.
+         # Model invocation: self.unet.
+         pred_latents = []
+
+         for i in range(0, num_images, batch_size):
+             batch_image_latent = image_latent[i : i + batch_size]  # [B,4,h,w]
+             batch_pred_latent = pred_latent[i : i + batch_size]  # [B,4,h,w]
+             effective_batch_size = batch_image_latent.shape[0]
+             text = batch_empty_text_embedding[:effective_batch_size]  # [B,2,1024]
+
+             # add
+             # Single step inference for E2E FT models
+             self.scheduler.set_timesteps(1, device=device)
+             for t in self.scheduler.timesteps:
+                 batch_latent = batch_image_latent  # torch.cat([batch_image_latent, batch_pred_latent], dim=1)  # [B,8,h,w]
+                 noise = self.unet(batch_latent, t, encoder_hidden_states=text, return_dict=False)[0]  # [B,4,h,w]
+                 batch_pred_latent = self.scheduler.step(
+                     noise, t, batch_image_latent
+                 ).pred_original_sample  # [B,4,h,w], # add
+                 # directly take pred_original_sample rather than prev_sample
+
+             pred_latents.append(batch_pred_latent)
+
+         pred_latent = torch.cat(pred_latents, dim=0)  # [N*E,4,h,w]
+
+         del (
+             pred_latents,
+             image_latent,
+             batch_empty_text_embedding,
+             batch_image_latent,
+             # batch_pred_latent,
+             text,
+             batch_latent,
+             noise,
+         )
+
+         # 6. Decode predictions from latent into pixel space. The resulting `N * E` predictions have shape `(PPH, PPW)`,
+         # which requires slight postprocessing. Decoding into pixel space happens in batches of size `batch_size`.
+         # Model invocation: self.vae.decoder.
+         prediction = torch.cat(
+             [
+                 self.decode_prediction(pred_latent[i : i + batch_size])
+                 for i in range(0, pred_latent.shape[0], batch_size)
+             ],
+             dim=0,
+         )  # [N*E,1,PPH,PPW]
+
+         if not output_latent:
+             pred_latent = None
+
+         # 7. Remove padding. The output shape is (PH, PW).
+         prediction = self.image_processor.unpad_image(prediction, padding)  # [N*E,1,PH,PW]
+
+         # 9. If `match_input_resolution` is set, the output prediction is upsampled to match the
+         # input resolution `(H, W)`. This step may introduce upsampling artifacts, and therefore can be disabled.
+         # Depending on the downstream use-case, upsampling can be also chosen based on the tolerated artifacts by
+         # setting the `resample_method_output` parameter (e.g., to `"nearest"`).
+         if match_input_resolution:
+             prediction = self.image_processor.resize_antialias(
+                 prediction, original_resolution, resample_method_output, is_aa=False
+             )  # [N,1,H,W]
+
+         # 10. Prepare the final outputs.
+         if output_type == "np":
+             prediction = self.image_processor.pt_to_numpy(prediction)  # [N,H,W,1]
+
+         # 11. Offload all models
+         self.maybe_free_model_hooks()
+
+         if not return_dict:
+             return (prediction, pred_latent)
+
+         return Gen2SegSDSegOutput(
+             prediction=prediction,
+             latent=pred_latent,
+         )
+
+     def prepare_latents(
+         self,
+         image: torch.Tensor,
+         batch_size: int,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         def retrieve_latents(encoder_output):
+             if hasattr(encoder_output, "latent_dist"):
+                 return encoder_output.latent_dist.mode()
+             elif hasattr(encoder_output, "latents"):
+                 return encoder_output.latents
+             else:
+                 raise AttributeError("Could not access latents of provided encoder_output")
+
+         image_latent = torch.cat(
+             [
+                 retrieve_latents(self.vae.encode(image[i : i + batch_size]))
+                 for i in range(0, image.shape[0], batch_size)
+             ],
+             dim=0,
+         )  # [N,4,h,w]
+         image_latent = image_latent * self.vae.config.scaling_factor  # [N*E,4,h,w]
+
+         # add
+         # provide zeros as noised latent
+         pred_latent = zeros_tensor(
+             image_latent.shape,
+             device=image_latent.device,
+             dtype=image_latent.dtype,
+         )  # [N*E,4,h,w]
+
+         return image_latent, pred_latent
+
+     def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor:
+         if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels:
+             raise ValueError(
+                 f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}."
+             )
+
+         prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0]  # [B,3,H,W]
+         # print(prediction.max())
+         # print(prediction.min())
+
+         prediction = (prediction + torch.abs(prediction.min()))
+         prediction = prediction / (torch.abs(prediction.max()) + torch.abs(prediction.min()))
+         # prediction = prediction**0.5
+         # prediction = torch.clip(prediction, min=-1, max=1)+1
+         prediction = (prediction) * 255.0
+         # print(prediction.max())
+         # print(prediction.min())
+         return prediction  # [B,3,H,W]
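
Note: the single-step E2E FT behaviour in step 5 hinges on taking `pred_original_sample` from the scheduler instead of `prev_sample`. The scheduler mechanics can be exercised without the real UNet or VAE; in this sketch a random tensor stands in for the UNet's noise prediction, and the DDIMScheduler is built with default settings rather than the checkpoint's shipped config.

import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)   # assumed defaults, not the shipped config
scheduler.set_timesteps(1)                            # single-step inference, as in the pipeline

sample = torch.zeros(1, 4, 96, 96)        # zero "noised" latent, mirroring zeros_tensor()
model_output = torch.randn(1, 4, 96, 96)  # stand-in for the UNet's predicted noise

t = scheduler.timesteps[0]
step = scheduler.step(model_output, t, sample)
print(step.pred_original_sample.shape)    # the one-shot estimate the pipeline decodes
print(step.prev_sample.shape)             # what an iterative sampler would carry to the next step
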
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ gradio
+ torch
+ torchvision
+ Pillow
+ numpy
+ diffusers
+ transformers
+ einops
+ tqdm
+ safetensors