skallewag committed
Commit 62cc23b · verified · 1 Parent(s): d843b3f

Upload 25 files
README.md CHANGED
@@ -1,12 +1,57 @@
  ---
- title: SegMatch
- emoji: 📈
  colorFrom: gray
- colorTo: indigo
  sdk: gradio
- sdk_version: 5.33.0
  app_file: app.py
  pinned: false
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: LaDeco
+ emoji: 👀
  colorFrom: gray
+ colorTo: blue
  sdk: gradio
+ sdk_version: 5.31.0
  app_file: app.py
  pinned: false
+ short_description: 'LaDeco: A tool to analyze visual landscape elements'
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+ # LaDeco - Landscape Environment Semantic Analysis Model
+
+ LaDeco is a tool that analyzes landscape images, performs semantic segmentation to identify the different elements in a scene (sky, vegetation, buildings, etc.), and enables region-based color matching between images.
+
+ ## Features
+
+ ### Semantic Segmentation
+ - Analyzes landscape images and segments them into semantic regions
+ - Provides an area-ratio analysis for each landscape element
+
+ ### Region-Based Color Matching
+ - Matches colors between corresponding semantic regions of two images
+ - Visualizes which regions are matched between the two images
+ - Offers multiple color matching algorithms (a minimal sketch of the statistical approach follows this feature list):
+   - **adain**: Adaptive Instance Normalization - matches the mean and standard deviation of colors
+   - **mkl**: Monge-Kantorovich Linearization - linear transformation of color statistics
+   - **reinhard**: Reinhard color transfer - simple statistical approach that matches mean and standard deviation
+   - **mvgd**: Multi-Variate Gaussian Distribution - uses color covariance matrices for more accurate matching
+   - **hm**: Histogram Matching - matches the full color-distribution histograms
+   - **hm-mvgd-hm**: Histogram + MVGD + Histogram compound method
+   - **hm-mkl-hm**: Histogram + MKL + Histogram compound method
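
For orientation, below is a minimal sketch of the statistical idea behind options such as `adain` and `reinhard`: shift the pixels inside one semantic region of the target image so their per-channel mean and standard deviation match the corresponding region of the reference. It assumes plain NumPy/Pillow and boolean `H×W` masks; it illustrates the idea and is not the implementation used by this Space.

```python
import numpy as np
from PIL import Image

def match_region_mean_std(reference: Image.Image, target: Image.Image,
                          ref_mask: np.ndarray, tgt_mask: np.ndarray) -> Image.Image:
    """Match the mean/std of the masked target pixels to the masked reference pixels."""
    ref = np.asarray(reference, dtype=np.float32)
    tgt = np.asarray(target, dtype=np.float32)
    out = tgt.copy()
    ref_px = ref[ref_mask]          # (N, 3) pixels inside the reference region
    tgt_px = tgt[tgt_mask]          # (M, 3) pixels inside the target region
    eps = 1e-6
    normed = (tgt_px - tgt_px.mean(axis=0)) / (tgt_px.std(axis=0) + eps)
    out[tgt_mask] = np.clip(normed * ref_px.std(axis=0) + ref_px.mean(axis=0), 0, 255)
    return Image.fromarray(out.astype(np.uint8))
```

The app applies this kind of per-region transfer label by label (sky to sky, vegetation to vegetation) and then composites the matched regions back onto the target image.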
+
+ ## Installation
+
+ 1. Clone this repository
+ 2. Create a virtual environment: `python3 -m venv .venv`
+ 3. Activate the virtual environment: `source .venv/bin/activate`
+ 4. Install the requirements: `pip install -r requirements.txt`
+ 5. Run the application: `python app.py`
+
+ ## Usage
+
+ 1. Upload two landscape images - the first is the color reference; the second will be color-matched to it
+ 2. Choose a color matching method from the dropdown menu
+ 3. Click "Start Analysis" to process the images
+ 4. View the results in the Segmentation and Color Matching tabs
+    - The Segmentation tab shows the semantic segmentation and area ratios for both images
+    - The Color Matching tab shows the matched-regions visualization and the color matching result
+
+ ## Reference
+
+ Li-Chih Ho (2023), "LaDeco: A Tool to Analyze Visual Landscape Elements", Ecological Informatics, vol. 78.
+ https://www.sciencedirect.com/science/article/pii/S1574954123003187
app.py ADDED
@@ -0,0 +1,956 @@
1
+ import gradio as gr
2
+ from core import Ladeco
3
+ from matplotlib.figure import Figure
4
+ import matplotlib.pyplot as plt
5
+ import matplotlib as mpl
6
+ import spaces
7
+ from PIL import Image
8
+ import numpy as np
9
+ from color_matching import RegionColorMatcher, create_comparison_figure
10
+ from face_comparison import FaceComparison
11
+ from cdl_smoothing import cdl_edge_smoothing, get_smoothing_stats, cdl_edge_smoothing_apply_to_source
12
+ import tempfile
13
+ import os
14
+ import cv2
15
+
16
+
17
+ plt.rcParams['figure.facecolor'] = '#0b0f19'
18
+ plt.rcParams['text.color'] = '#aab6cc'
19
+ ladeco = Ladeco()
20
+
21
+
22
+ @spaces.GPU
23
+ def infer_two_images(img1: str, img2: str, method: str, enable_face_matching: bool, enable_edge_smoothing: bool) -> tuple[Figure, Figure, Figure, Figure, Figure, Figure, str, str, str]:
24
+ """
25
+ Clean four-step approach (plus an optional CDL edge-smoothing pass, Step 5):
26
+ 1. Segment both images identically
27
+ 2. Determine segment correspondences
28
+ 3. Match each segment pair in isolation
29
+ 4. Composite all matched segments
30
+ """
31
+
32
+ cdl_display = "" # Initialize CDL display string
33
+
34
+ # STEP 1: SEGMENT BOTH IMAGES IDENTICALLY
35
+ # This step is always identical regardless of face matching
36
+ print("Step 1: Segmenting both images...")
37
+ out1 = ladeco.predict(img1)
38
+ out2 = ladeco.predict(img2)
39
+
40
+ # Extract visualization and stats (unchanged)
41
+ seg1 = out1.visualize(level=2)[0].image
42
+ colormap1 = out1.color_map(level=2)
43
+ area1 = out1.area()[0]
44
+
45
+ seg2 = out2.visualize(level=2)[0].image
46
+ colormap2 = out2.color_map(level=2)
47
+ area2 = out2.area()[0]
48
+
49
+ # Process areas for pie charts
50
+ colors1, l2_area1 = [], {}
51
+ for labelname, area_ratio in area1.items():
52
+ if labelname.startswith("l2") and area_ratio > 0:
53
+ colors1.append(colormap1[labelname])
54
+ labelname = labelname.replace("l2_", "").capitalize()
55
+ l2_area1[labelname] = area_ratio
56
+
57
+ colors2, l2_area2 = [], {}
58
+ for labelname, area_ratio in area2.items():
59
+ if labelname.startswith("l2") and area_ratio > 0:
60
+ colors2.append(colormap2[labelname])
61
+ labelname = labelname.replace("l2_", "").capitalize()
62
+ l2_area2[labelname] = area_ratio
63
+
64
+ pie1 = plot_pie(l2_area1, colors=colors1)
65
+ pie2 = plot_pie(l2_area2, colors=colors2)
66
+
67
+ # Set plot sizes
68
+ for fig in [seg1, seg2, pie1, pie2]:
69
+ fig.set_dpi(96)
70
+ fig.set_size_inches(256/96, 256/96)
71
+
72
+ # Extract semantic masks - IDENTICAL for both images regardless of face matching
73
+ masks1 = extract_semantic_masks(out1)
74
+ masks2 = extract_semantic_masks(out2)
75
+
76
+ print(f"Extracted {len(masks1)} masks from img1, {len(masks2)} masks from img2")
77
+
78
+ # STEP 2: DETERMINE SEGMENT CORRESPONDENCES
79
+ print("Step 2: Determining segment correspondences...")
80
+ face_log = ["Step 2: Determining segment correspondences"]
81
+
82
+ # Find common segments between both images
83
+ common_segments = set(masks1.keys()).intersection(set(masks2.keys()))
84
+ face_log.append(f"Found {len(common_segments)} common segments: {sorted(common_segments)}")
85
+
86
+ # Determine which segments to match based on face matching logic
87
+ segments_to_match = determine_segments_to_match(img1, img2, common_segments, enable_face_matching, face_log)
88
+
89
+ face_log.append(f"Final segments to match: {sorted(segments_to_match)}")
90
+
91
+ # STEP 3: MATCH EACH SEGMENT PAIR IN ISOLATION
92
+ print("Step 3: Matching each segment pair in isolation...")
93
+ face_log.append("\nStep 3: Color matching each segment independently")
94
+
95
+ matched_regions = {}
96
+ segment_masks = {} # Store masks for all segments being matched
97
+
98
+ for segment_name in segments_to_match:
99
+ if segment_name in masks1 and segment_name in masks2:
100
+ face_log.append(f" Processing {segment_name}...")
101
+
102
+ # Match this segment in complete isolation
103
+ matched_region, final_mask1, final_mask2 = match_single_segment(
104
+ img1, img2,
105
+ masks1[segment_name], masks2[segment_name],
106
+ segment_name, method, face_log
107
+ )
108
+
109
+ if matched_region is not None:
110
+ matched_regions[segment_name] = matched_region
111
+ segment_masks[segment_name] = final_mask2 # Use mask from target image for compositing
112
+ face_log.append(f" ✅ {segment_name} matched successfully")
113
+ else:
114
+ face_log.append(f" ❌ {segment_name} matching failed")
115
+ elif segment_name.startswith('l4_'):
116
+ # Handle fine-grained segments that need to be generated
117
+ face_log.append(f" Processing fine-grained {segment_name}...")
118
+
119
+ matched_region, final_mask1, final_mask2 = match_single_segment(
120
+ img1, img2, None, None, segment_name, method, face_log
121
+ )
122
+
123
+ if matched_region is not None:
124
+ matched_regions[segment_name] = matched_region
125
+ segment_masks[segment_name] = final_mask2 # Store the generated mask
126
+ face_log.append(f" ✅ {segment_name} matched successfully")
127
+ else:
128
+ face_log.append(f" ❌ {segment_name} matching failed")
129
+
130
+ # STEP 4: COMPOSITE ALL MATCHED SEGMENTS
131
+ print("Step 4: Compositing all matched segments...")
132
+ face_log.append(f"\nStep 4: Compositing {len(matched_regions)} matched segments")
133
+
134
+ final_image = composite_matched_segments(img2, matched_regions, segment_masks, face_log)
135
+
136
+ # STEP 5: OPTIONAL CDL-BASED EDGE SMOOTHING
137
+ if enable_edge_smoothing:
138
+ print("Step 5: Applying CDL-based edge smoothing...")
139
+ face_log.append("\nStep 5: CDL edge smoothing - applying CDL transform to image 2 based on composited result")
140
+
141
+ try:
142
+ # Save the composited result temporarily for CDL calculation
143
+ temp_dir = tempfile.gettempdir()
144
+ temp_composite_path = os.path.join(temp_dir, "temp_composite_for_cdl.png")
145
+ final_image.save(temp_composite_path, "PNG")
146
+
147
+ # Calculate CDL parameters to transform image 2 → composited result
148
+ cdl_stats = get_smoothing_stats(img2, temp_composite_path)
149
+
150
+ # Log the CDL values
151
+ slope = cdl_stats['cdl_slope']
152
+ offset = cdl_stats['cdl_offset']
153
+ power = cdl_stats['cdl_power']
154
+
155
+ # Format CDL values for display
156
+ cdl_display = f"""📊 CDL Parameters (Image 2 → Composited Result):
157
+
158
+ 🔧 Method: Simple Mean/Std Matching (basic statistical approach)
159
+
160
+ 🔸 Slope (Gain):
161
+ Red: {slope[0]:.6f}
162
+ Green: {slope[1]:.6f}
163
+ Blue: {slope[2]:.6f}
164
+
165
+ 🔸 Offset:
166
+ Red: {offset[0]:.6f}
167
+ Green: {offset[1]:.6f}
168
+ Blue: {offset[2]:.6f}
169
+
170
+ 🔸 Power (Gamma):
171
+ Red: {power[0]:.6f}
172
+ Green: {power[1]:.6f}
173
+ Blue: {power[2]:.6f}
174
+
175
+ These CDL values represent the color transformation needed to convert Image 2 into the composited result.
176
+
177
+ The CDL calculation uses the simplest possible approach: matches the mean and standard deviation
178
+ of each color channel between the original and composited images, with simple gamma calculation
179
+ based on brightness relationships.
180
+ """
181
+
182
+ face_log.append(f"📊 CDL Parameters (image 2 → composited result):")
183
+ face_log.append(f" Method: Simple mean/std matching")
184
+ face_log.append(f" Slope (R,G,B): [{slope[0]:.4f}, {slope[1]:.4f}, {slope[2]:.4f}]")
185
+ face_log.append(f" Offset (R,G,B): [{offset[0]:.4f}, {offset[1]:.4f}, {offset[2]:.4f}]")
186
+ face_log.append(f" Power (R,G,B): [{power[0]:.4f}, {power[1]:.4f}, {power[2]:.4f}]")
187
+
188
+ # Apply CDL transformation to image 2 to approximate the composited result
189
+ final_image = cdl_edge_smoothing_apply_to_source(img2, temp_composite_path, factor=1.0)
190
+
191
+ # Clean up temp file
192
+ if os.path.exists(temp_composite_path):
193
+ os.remove(temp_composite_path)
194
+
195
+ face_log.append("✅ CDL edge smoothing completed - transformed image 2 using calculated CDL parameters")
196
+
197
+ except Exception as e:
198
+ face_log.append(f"❌ CDL edge smoothing failed: {e}")
199
+ cdl_display = f"❌ CDL calculation failed: {e}"
200
+ else:
201
+ face_log.append("\nStep 5: CDL edge smoothing disabled")
202
+ cdl_display = "CDL edge smoothing is disabled. Enable it to see CDL parameters."
203
+
204
+ # Save result
205
+ temp_dir = tempfile.gettempdir()
206
+ filename = os.path.basename(img2).split('.')[0]
207
+ temp_filename = f"color_matched_{method}_{filename}.png"
208
+ temp_path = os.path.join(temp_dir, temp_filename)
209
+ final_image.save(temp_path, "PNG")
210
+
211
+ # Create visualizations
212
+ # For visualization, we need to collect the masks that were actually used
213
+ vis_masks1 = {}
214
+ vis_masks2 = {}
215
+
216
+ for segment_name in segments_to_match:
217
+ if segment_name in segment_masks:
218
+ if segment_name.startswith('l4_'):
219
+ # Fine-grained segments - we'll regenerate for visualization
220
+ part_name = segment_name.replace('l4_', '')
221
+ if part_name in ['face', 'hair']:
222
+ from human_parts_segmentation import HumanPartsSegmentation
223
+ segmenter = HumanPartsSegmentation()
224
+ masks_dict1 = segmenter.segment_parts(img1, [part_name])
225
+ masks_dict2 = segmenter.segment_parts(img2, [part_name])
226
+ if part_name in masks_dict1 and part_name in masks_dict2:
227
+ vis_masks1[segment_name] = masks_dict1[part_name]
228
+ vis_masks2[segment_name] = masks_dict2[part_name]
229
+ elif part_name == 'upper_clothes':
230
+ from clothes_segmentation import ClothesSegmentation
231
+ segmenter = ClothesSegmentation()
232
+ mask1 = segmenter.segment_clothes(img1, ["Upper-clothes"])
233
+ mask2 = segmenter.segment_clothes(img2, ["Upper-clothes"])
234
+ if mask1 is not None and mask2 is not None:
235
+ vis_masks1[segment_name] = mask1
236
+ vis_masks2[segment_name] = mask2
237
+ else:
238
+ # Regular segments - use original masks
239
+ if segment_name in masks1 and segment_name in masks2:
240
+ vis_masks1[segment_name] = masks1[segment_name]
241
+ vis_masks2[segment_name] = masks2[segment_name]
242
+
243
+ mask_vis = visualize_matching_masks(img1, img2, vis_masks1, vis_masks2)
244
+
245
+ comparison = create_comparison_figure(Image.open(img2), final_image, f"Color Matching Result ({method})")
246
+
247
+ face_log_text = "\n".join(face_log)
248
+
249
+ return seg1, pie1, seg2, pie2, comparison, mask_vis, temp_path, face_log_text, cdl_display
250
+
251
+
252
+ def determine_segments_to_match(img1: str, img2: str, common_segments: set, enable_face_matching: bool, log: list) -> set:
253
+ """
254
+ Determine which segments should be matched based on face matching logic.
255
+ Returns the set of segment names to process.
256
+ """
257
+ if not enable_face_matching:
258
+ log.append("Face matching disabled - matching all common segments")
259
+ return common_segments
260
+
261
+ log.append("Face matching enabled - checking faces...")
262
+
263
+ # Run face comparison
264
+ face_comparator = FaceComparison()
265
+ faces_match, face_log = face_comparator.run_face_comparison(img1, img2)
266
+ log.extend(face_log)
267
+
268
+ if not faces_match:
269
+ # Remove human/bio segments from matching
270
+ log.append("No face match - excluding human/bio segments")
271
+ non_human_segments = set()
272
+ for segment in common_segments:
273
+ if not any(term in segment.lower() for term in ['l3_human', 'l2_bio']):
274
+ non_human_segments.add(segment)
275
+ else:
276
+ log.append(f" Excluding human segment: {segment}")
277
+
278
+ log.append(f"Matching {len(non_human_segments)} non-human segments")
279
+ return non_human_segments
280
+
281
+ else:
282
+ # Faces match - include all segments + add fine-grained if possible
283
+ log.append("Faces match - including all segments + fine-grained")
284
+
285
+ segments_to_match = common_segments.copy()
286
+
287
+ # Add fine-grained human parts if bio regions exist
288
+ bio_segments = [s for s in common_segments if 'l2_bio' in s.lower()]
289
+ if bio_segments:
290
+ fine_grained_segments = add_fine_grained_segments(img1, img2, common_segments, log)
291
+ segments_to_match.update(fine_grained_segments)
292
+
293
+ return segments_to_match
294
+
295
+
296
+ def add_fine_grained_segments(img1: str, img2: str, common_segments: set, log: list) -> set:
297
+ """
298
+ Add fine-grained human parts segments when faces match.
299
+ Returns set of fine-grained segment names that were successfully added.
300
+ """
301
+ fine_grained_segments = set()
302
+
303
+ try:
304
+ from human_parts_segmentation import HumanPartsSegmentation
305
+ from clothes_segmentation import ClothesSegmentation
306
+
307
+ log.append(" Adding fine-grained human parts...")
308
+
309
+ # Get face and hair masks
310
+ human_segmenter = HumanPartsSegmentation()
311
+ face_hair_masks1 = human_segmenter.segment_parts(img1, ['face', 'hair'])
312
+ face_hair_masks2 = human_segmenter.segment_parts(img2, ['face', 'hair'])
313
+
314
+ # Get clothes masks
315
+ clothes_segmenter = ClothesSegmentation()
316
+ clothes_mask1 = clothes_segmenter.segment_clothes(img1, ["Upper-clothes"])
317
+ clothes_mask2 = clothes_segmenter.segment_clothes(img2, ["Upper-clothes"])
318
+
319
+ # Process face/hair
320
+ for part_name, mask1 in face_hair_masks1.items():
321
+ if (mask1 is not None and part_name in face_hair_masks2 and
322
+ face_hair_masks2[part_name] is not None):
323
+
324
+ if np.sum(mask1 > 0) > 0 and np.sum(face_hair_masks2[part_name] > 0) > 0:
325
+ fine_grained_segments.add(f'l4_{part_name}')
326
+ log.append(f" Added fine-grained: {part_name}")
327
+
328
+ # Process clothes
329
+ if (clothes_mask1 is not None and clothes_mask2 is not None and
330
+ np.sum(clothes_mask1 > 0) > 0 and np.sum(clothes_mask2 > 0) > 0):
331
+ fine_grained_segments.add('l4_upper_clothes')
332
+ log.append(f" Added fine-grained: upper_clothes")
333
+
334
+ except Exception as e:
335
+ log.append(f" Error adding fine-grained segments: {e}")
336
+
337
+ return fine_grained_segments
338
+
339
+
340
+ def match_single_segment(img1_path: str, img2_path: str, mask1: np.ndarray, mask2: np.ndarray,
341
+ segment_name: str, method: str, log: list) -> tuple[Image.Image, np.ndarray, np.ndarray]:
342
+ """
343
+ Match colors of a single segment in complete isolation from other segments.
344
+ Each segment is processed independently with no knowledge of other segments.
345
+ Returns: (matched_image, final_mask1, final_mask2)
346
+ """
347
+ try:
348
+ # Load images
349
+ img1 = Image.open(img1_path).convert("RGB")
350
+ img2 = Image.open(img2_path).convert("RGB")
351
+
352
+ # Convert to numpy
353
+ img1_np = np.array(img1)
354
+ img2_np = np.array(img2)
355
+
356
+ # Handle fine-grained segments
357
+ if segment_name.startswith('l4_'):
358
+ part_name = segment_name.replace('l4_', '')
359
+ if part_name in ['face', 'hair']:
360
+ from human_parts_segmentation import HumanPartsSegmentation
361
+ segmenter = HumanPartsSegmentation()
362
+ masks_dict1 = segmenter.segment_parts(img1_path, [part_name])
363
+ masks_dict2 = segmenter.segment_parts(img2_path, [part_name])
364
+
365
+ if part_name in masks_dict1 and part_name in masks_dict2:
366
+ mask1 = masks_dict1[part_name]
367
+ mask2 = masks_dict2[part_name]
368
+ else:
369
+ return None, None, None
370
+
371
+ elif part_name == 'upper_clothes':
372
+ from clothes_segmentation import ClothesSegmentation
373
+ segmenter = ClothesSegmentation()
374
+ mask1 = segmenter.segment_clothes(img1_path, ["Upper-clothes"])
375
+ mask2 = segmenter.segment_clothes(img2_path, ["Upper-clothes"])
376
+
377
+ if mask1 is None or mask2 is None:
378
+ return None, None, None
379
+
380
+ # Ensure masks are same size as images
381
+ if mask1.shape != img1_np.shape[:2]:
382
+ mask1 = cv2.resize(mask1.astype(np.float32), (img1_np.shape[1], img1_np.shape[0]),
383
+ interpolation=cv2.INTER_NEAREST)
384
+ if mask2.shape != img2_np.shape[:2]:
385
+ mask2 = cv2.resize(mask2.astype(np.float32), (img2_np.shape[1], img2_np.shape[0]),
386
+ interpolation=cv2.INTER_NEAREST)
387
+
388
+ # Convert to binary masks
389
+ mask1_binary = (mask1 > 0.5).astype(np.float32)
390
+ mask2_binary = (mask2 > 0.5).astype(np.float32)
391
+
392
+ # Check if masks have content
393
+ pixels1 = np.sum(mask1_binary > 0)
394
+ pixels2 = np.sum(mask2_binary > 0)
395
+
396
+ if pixels1 == 0 or pixels2 == 0:
397
+ log.append(f" No pixels in {segment_name}: img1={pixels1}, img2={pixels2}")
398
+ return None, None, None
399
+
400
+ log.append(f" {segment_name}: img1={pixels1} pixels, img2={pixels2} pixels")
401
+
402
+ # Create single-segment masks dictionary for color matcher
403
+ masks1_dict = {segment_name: mask1_binary}
404
+ masks2_dict = {segment_name: mask2_binary}
405
+
406
+ # Apply color matching to this segment only
407
+ color_matcher = RegionColorMatcher(factor=0.8, preserve_colors=True,
408
+ preserve_luminance=True, method=method)
409
+
410
+ matched_img = color_matcher.match_regions(img1_path, img2_path, masks1_dict, masks2_dict)
411
+
412
+ return matched_img, mask1_binary, mask2_binary
413
+
414
+ except Exception as e:
415
+ log.append(f" Error matching {segment_name}: {e}")
416
+ return None, None, None
417
+
418
+
419
+ def composite_matched_segments(base_img_path: str, matched_regions: dict, segment_masks: dict, log: list) -> Image.Image:
420
+ """
421
+ Composite all matched segments back together using simple alpha compositing.
422
+ Each matched segment is completely independent and overlaid on the base image.
423
+ """
424
+ # Start with base image
425
+ result = Image.open(base_img_path).convert("RGBA")
426
+ result_np = np.array(result)
427
+
428
+ log.append(f"Compositing {len(matched_regions)} segments onto base image")
429
+
430
+ for segment_name, matched_img in matched_regions.items():
431
+ if segment_name in segment_masks:
432
+ mask = segment_masks[segment_name]
433
+
434
+ # Ensure mask is right size
435
+ if mask.shape != result_np.shape[:2]:
436
+ mask = cv2.resize(mask.astype(np.float32),
437
+ (result_np.shape[1], result_np.shape[0]),
438
+ interpolation=cv2.INTER_NEAREST)
439
+
440
+ # Convert matched image to numpy
441
+ matched_np = np.array(matched_img.convert("RGB"))
442
+
443
+ # Ensure matched image is right size
444
+ if matched_np.shape[:2] != result_np.shape[:2]:
445
+ matched_pil = Image.fromarray(matched_np)
446
+ matched_pil = matched_pil.resize((result_np.shape[1], result_np.shape[0]), Image.LANCZOS)
447
+ matched_np = np.array(matched_pil)
448
+
449
+ # Apply mask with alpha blending
450
+ mask_binary = (mask > 0.5).astype(np.float32)
451
+ alpha = np.expand_dims(mask_binary, axis=2)
452
+
453
+ # Blend: result = result * (1 - alpha) + matched * alpha
454
+ result_np[:, :, :3] = (result_np[:, :, :3] * (1 - alpha) +
455
+ matched_np * alpha).astype(np.uint8)
456
+
457
+ pixels = np.sum(mask_binary > 0)
458
+ log.append(f" Composited {segment_name}: {pixels} pixels")
459
+
460
+ return Image.fromarray(result_np).convert("RGB")
461
+
462
+
463
+ def visualize_matching_masks(img1_path, img2_path, masks1, masks2):
464
+ """
465
+ Create a visualization of the masks being matched between two images.
466
+
467
+ Args:
468
+ img1_path: Path to first image
469
+ img2_path: Path to second image
470
+ masks1: Dictionary of masks for first image {label: binary_mask}
471
+ masks2: Dictionary of masks for second image {label: binary_mask}
472
+
473
+ Returns:
474
+ A matplotlib Figure showing the matched masks
475
+ """
476
+ # Load images
477
+ img1 = Image.open(img1_path).convert("RGB")
478
+ img2 = Image.open(img2_path).convert("RGB")
479
+
480
+ # Convert to numpy arrays
481
+ img1_np = np.array(img1)
482
+ img2_np = np.array(img2)
483
+
484
+ # Separate fine-grained human parts from regular masks
485
+ fine_grained_masks = {}
486
+ regular_masks = {}
487
+
488
+ for label, mask in masks1.items():
489
+ if label.startswith('l4_'): # Fine-grained human parts
490
+ fine_grained_masks[label] = mask
491
+ else:
492
+ regular_masks[label] = mask
493
+
494
+ # Find common labels in both regular and fine-grained masks
495
+ common_regular = set(regular_masks.keys()).intersection(set(masks2.keys()))
496
+
497
+ # Count fine-grained masks that are in both masks1 and masks2
498
+ common_fine_grained = set()
499
+ for label in fine_grained_masks.keys():
500
+ if label.startswith('l4_') and label in masks2:
501
+ part_name = label.replace('l4_', '')
502
+ common_fine_grained.add(part_name)
503
+
504
+ # Count total rows needed
505
+ n_regular_rows = len(common_regular)
506
+ n_fine_rows = len(common_fine_grained)
507
+ n_rows = n_regular_rows + n_fine_rows
508
+
509
+ if n_rows == 0:
510
+ # No common regions found
511
+ fig, ax = plt.subplots(1, 1, figsize=(10, 5))
512
+ ax.text(0.5, 0.5, "No matching regions found between images",
513
+ ha='center', va='center', fontsize=14, color='white')
514
+ ax.axis('off')
515
+ return fig
516
+
517
+ fig, axes = plt.subplots(n_rows, 2, figsize=(12, 3 * n_rows))
518
+
519
+ # If only one row, reshape axes
520
+ if n_rows == 1:
521
+ axes = np.array([axes])
522
+
523
+ row_idx = 0
524
+
525
+ # Visualize regular semantic regions
526
+ for label in sorted(common_regular):
527
+ # Get label display name
528
+ display_name = label.replace("l2_", "").capitalize()
529
+
530
+ # Get masks and resize them to match the image dimensions
531
+ mask1 = regular_masks[label]
532
+ mask2 = masks2[label]
533
+
534
+ # Create visualizations
535
+ masked_img1, masked_img2 = create_mask_overlay(img1_np, img2_np, mask1, mask2, [255, 0, 0]) # Red
536
+
537
+ # Plot the masked images
538
+ axes[row_idx, 0].imshow(masked_img1)
539
+ axes[row_idx, 0].set_title(f"Image 1: {display_name}")
540
+ axes[row_idx, 0].axis('off')
541
+
542
+ axes[row_idx, 1].imshow(masked_img2)
543
+ axes[row_idx, 1].set_title(f"Image 2: {display_name}")
544
+ axes[row_idx, 1].axis('off')
545
+
546
+ row_idx += 1
547
+
548
+ # Visualize fine-grained human parts
549
+ part_colors = {
550
+ 'face': [255, 0, 0], # Red (like other masks)
551
+ 'hair': [255, 0, 0], # Red (like other masks)
552
+ 'upper_clothes': [255, 0, 0] # Red (like other masks)
553
+ }
554
+
555
+ for part_name in sorted(common_fine_grained):
556
+ label = f'l4_{part_name}'
557
+
558
+ if label in fine_grained_masks and label in masks2:
559
+ mask1 = fine_grained_masks[label]
560
+ mask2 = masks2[label]
561
+
562
+ color = part_colors.get(part_name, [255, 0, 0]) # Default to red
563
+
564
+ # Create visualizations
565
+ masked_img1, masked_img2 = create_mask_overlay(img1_np, img2_np, mask1, mask2, color)
566
+
567
+ # Plot the masked images
568
+ display_name = part_name.replace('_', ' ').title()
569
+ axes[row_idx, 0].imshow(masked_img1)
570
+ axes[row_idx, 0].set_title(f"Image 1: {display_name} (Fine-grained)")
571
+ axes[row_idx, 0].axis('off')
572
+
573
+ axes[row_idx, 1].imshow(masked_img2)
574
+ axes[row_idx, 1].set_title(f"Image 2: {display_name} (Fine-grained)")
575
+ axes[row_idx, 1].axis('off')
576
+
577
+ row_idx += 1
578
+
579
+ plt.suptitle("Matched Regions (highlighted with different colors)", fontsize=16, color='white')
580
+ plt.tight_layout()
581
+
582
+ return fig
583
+
584
+
585
+ def create_mask_overlay(img1_np, img2_np, mask1, mask2, overlay_color):
586
+ """
587
+ Create mask overlays on images with the specified color.
588
+
589
+ Args:
590
+ img1_np: First image as numpy array
591
+ img2_np: Second image as numpy array
592
+ mask1: Mask for first image
593
+ mask2: Mask for second image
594
+ overlay_color: RGB color for overlay [R, G, B]
595
+
596
+ Returns:
597
+ Tuple of (masked_img1, masked_img2)
598
+ """
599
+ # Resize masks to match image dimensions if needed
600
+ if mask1.shape != img1_np.shape[:2]:
601
+ mask1_img = Image.fromarray((mask1 * 255).astype(np.uint8))
602
+ mask1_img = mask1_img.resize((img1_np.shape[1], img1_np.shape[0]), Image.NEAREST)
603
+ mask1 = np.array(mask1_img).astype(np.float32) / 255.0
604
+
605
+ if mask2.shape != img2_np.shape[:2]:
606
+ mask2_img = Image.fromarray((mask2 * 255).astype(np.uint8))
607
+ mask2_img = mask2_img.resize((img2_np.shape[1], img2_np.shape[0]), Image.NEAREST)
608
+ mask2 = np.array(mask2_img).astype(np.float32) / 255.0
609
+
610
+ # Create masked versions of the images
611
+ masked_img1 = img1_np.copy()
612
+ masked_img2 = img2_np.copy()
613
+
614
+ # Apply a semi-transparent colored overlay to show the masked region
615
+ overlay_color = np.array(overlay_color, dtype=np.uint8)
616
+
617
+ # Create alpha channel based on the mask (with transparency)
618
+ alpha1 = mask1 * 0.6 # Increased opacity for better visibility
619
+ alpha2 = mask2 * 0.6
620
+
621
+ # Apply the colored overlay to masked regions
622
+ for c in range(3):
623
+ masked_img1[:, :, c] = masked_img1[:, :, c] * (1 - alpha1) + overlay_color[c] * alpha1
624
+ masked_img2[:, :, c] = masked_img2[:, :, c] * (1 - alpha2) + overlay_color[c] * alpha2
625
+
626
+ return masked_img1, masked_img2
627
+
628
+
629
+ def extract_semantic_masks(output):
630
+ """
631
+ Extract binary masks for each semantic region from the LadecoOutput.
632
+
633
+ Args:
634
+ output: LadecoOutput from Ladeco.predict()
635
+
636
+ Returns:
637
+ Dictionary mapping label names to binary masks
638
+ """
639
+ masks = {}
640
+
641
+ # Get the segmentation mask
642
+ seg_mask = output.masks[0].cpu().numpy()
643
+
644
+ # Process each label in level 2 (as we're visualizing at level 2)
645
+ for label, indices in output.ladeco2ade.items():
646
+ if label.startswith("l2_"):
647
+ # Create a binary mask for this label
648
+ binary_mask = np.zeros_like(seg_mask, dtype=np.float32)
649
+
650
+ # Set 1 for pixels matching this label
651
+ for idx in indices:
652
+ binary_mask[seg_mask == idx] = 1.0
653
+
654
+ # Only include labels that have some pixels in the image
655
+ if np.any(binary_mask):
656
+ masks[label] = binary_mask
657
+
658
+ return masks
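+ # Illustrative shape of the result (label names here are examples): one float
+ # {0.0, 1.0} array per level-2 label that actually appears in the frame, e.g.
+ #   {"l2_sky": (H, W) array, "l2_vegetation": (H, W) array, ...}
+ # These dictionaries are what determine_segments_to_match() intersects and what
+ # RegionColorMatcher.match_regions() ultimately consumes, one segment at a time.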
659
+
660
+
661
+ def plot_pie(data: dict[str, float], colors=None) -> Figure:
662
+ fig, ax = plt.subplots()
663
+
664
+ labels = list(data.keys())
665
+ sizes = list(data.values())
666
+
667
+ *_, autotexts = ax.pie(sizes, labels=labels, autopct="%1.1f%%", colors=colors)
668
+
669
+ for percent_text in autotexts:
670
+ percent_text.set_color("k")
671
+
672
+ ax.axis("equal")
673
+
674
+ return fig
675
+
676
+
677
+ def choose_example(imgpath: str, target_component) -> gr.Image:
678
+ img = Image.open(imgpath)
679
+ width, height = img.size
680
+ ratio = 512 / max(width, height)
681
+ img = img.resize((int(width * ratio), int(height * ratio)))
682
+ return gr.Image(value=img, label="Input Image (SVG format not supported)", type="filepath")
683
+
684
+
685
+ css = """
686
+ .reference {
687
+ text-align: center;
688
+ font-size: 1.2em;
689
+ color: #d1d5db;
690
+ margin-bottom: 20px;
691
+ }
692
+ .reference a {
693
+ color: #FB923C;
694
+ text-decoration: none;
695
+ }
696
+ .reference a:hover {
697
+ text-decoration: underline;
698
+ color: #FB923C;
699
+ }
700
+ .description {
701
+ text-align: center;
702
+ font-size: 1.1em;
703
+ color: #d1d5db;
704
+ margin-bottom: 25px;
705
+ }
706
+ .footer {
707
+ text-align: center;
708
+ margin-top: 30px;
709
+ padding-top: 20px;
710
+ border-top: 1px solid #ddd;
711
+ color: #d1d5db;
712
+ font-size: 14px;
713
+ }
714
+ .main-title {
715
+ font-size: 24px;
716
+ font-weight: bold;
717
+ text-align: center;
718
+ margin-bottom: 20px;
719
+ }
720
+ .selected-image {
721
+ height: 756px;
722
+ }
723
+ .example-image {
724
+ height: 220px;
725
+ padding: 25px;
726
+ }
727
+ """.strip()
728
+ theme = gr.themes.Base(
729
+ primary_hue="orange",
730
+ secondary_hue="cyan",
731
+ neutral_hue="gray",
732
+ ).set(
733
+ body_text_color='*neutral_100',
734
+ body_text_color_subdued='*neutral_600',
735
+ background_fill_primary='*neutral_950',
736
+ background_fill_secondary='*neutral_600',
737
+ border_color_accent='*secondary_800',
738
+ color_accent='*primary_50',
739
+ color_accent_soft='*secondary_800',
740
+ code_background_fill='*neutral_700',
741
+ block_background_fill_dark='*body_background_fill',
742
+ block_info_text_color='#6b7280',
743
+ block_label_text_color='*neutral_300',
744
+ block_label_text_weight='700',
745
+ block_title_text_color='*block_label_text_color',
746
+ block_title_text_weight='300',
747
+ panel_background_fill='*neutral_800',
748
+ table_text_color_dark='*secondary_800',
749
+ checkbox_background_color_selected='*primary_500',
750
+ checkbox_label_background_fill='*neutral_500',
751
+ checkbox_label_background_fill_hover='*neutral_700',
752
+ checkbox_label_text_color='*neutral_200',
753
+ input_background_fill='*neutral_700',
754
+ input_background_fill_focus='*neutral_600',
755
+ slider_color='*primary_500',
756
+ table_even_background_fill='*neutral_700',
757
+ table_odd_background_fill='*neutral_600',
758
+ table_row_focus='*neutral_800'
759
+ )
760
+ with gr.Blocks(css=css, theme=theme) as demo:
761
+ gr.HTML(
762
+ """
763
+ <div class="main-title">SegMatch – Zero Shot Segmentation-based color matching</div>
764
+ <div class="description">
765
+ Advanced region-based color matching using semantic segmentation and fine-grained human parts detection for precise, contextually-aware color transfer between images.
766
+ </div>
767
+ """.strip()
768
+ )
769
+
770
+ with gr.Row():
771
+ # First image inputs
772
+ with gr.Column():
773
+ img1 = gr.Image(
774
+ label="First Input Image - Color Reference (SVG not supported)",
775
+ type="filepath",
776
+ height="256px",
777
+ )
778
+ gr.Label("Example Images for First Input", show_label=False)
779
+ with gr.Row():
780
+ ex1_1 = gr.Image(
781
+ value="examples/beach.jpg",
782
+ show_label=False,
783
+ type="filepath",
784
+ elem_classes="example-image",
785
+ interactive=False,
786
+ show_download_button=False,
787
+ show_fullscreen_button=False,
788
+ show_share_button=False,
789
+ )
790
+ ex1_2 = gr.Image(
791
+ value="examples/field.jpg",
792
+ show_label=False,
793
+ type="filepath",
794
+ elem_classes="example-image",
795
+ interactive=False,
796
+ show_download_button=False,
797
+ show_fullscreen_button=False,
798
+ show_share_button=False,
799
+ )
800
+
801
+ # Second image inputs
802
+ with gr.Column():
803
+ img2 = gr.Image(
804
+ label="Second Input Image - To Be Color Matched (SVG not supported)",
805
+ type="filepath",
806
+ height="256px",
807
+ )
808
+ gr.Label("Example Images for Second Input", show_label=False)
809
+ with gr.Row():
810
+ ex2_1 = gr.Image(
811
+ value="examples/field.jpg",
812
+ show_label=False,
813
+ type="filepath",
814
+ elem_classes="example-image",
815
+ interactive=False,
816
+ show_download_button=False,
817
+ show_fullscreen_button=False,
818
+ show_share_button=False,
819
+ )
820
+ ex2_2 = gr.Image(
821
+ value="examples/sky.jpg",
822
+ show_label=False,
823
+ type="filepath",
824
+ elem_classes="example-image",
825
+ interactive=False,
826
+ show_download_button=False,
827
+ show_fullscreen_button=False,
828
+ show_share_button=False,
829
+ )
830
+
831
+ with gr.Row():
832
+ with gr.Column():
833
+ method = gr.Dropdown(
834
+ label="Color Matching Method",
835
+ choices=["adain", "mkl", "hm", "reinhard", "mvgd", "hm-mvgd-hm", "hm-mkl-hm", "coral"],
836
+ value="adain",
837
+ info="Choose the algorithm for color matching between regions"
838
+ )
839
+
840
+ with gr.Column():
841
+ enable_face_matching = gr.Checkbox(
842
+ label="Enable Face Matching for Human Regions",
843
+ value=True,
844
+ info="Only match human regions if faces are similar (requires DeepFace)"
845
+ )
846
+
847
+ with gr.Row():
848
+ with gr.Column():
849
+ enable_edge_smoothing = gr.Checkbox(
850
+ label="Enable CDL Edge Smoothing",
851
+ value=False,
852
+ info="Apply CDL transform to original image using calculated parameters (see log for values)"
853
+ )
854
+
855
+ start = gr.Button("Start Analysis", variant="primary")
856
+
857
+ # Download button positioned right after the start button
858
+ download_btn = gr.File(
859
+ label="📥 Download Color-Matched Image",
860
+ visible=True,
861
+ interactive=False
862
+ )
863
+
864
+ with gr.Tabs():
865
+ with gr.TabItem("Segmentation Results"):
866
+ with gr.Row():
867
+ # First image results
868
+ with gr.Column():
869
+ gr.Label("Results for First Image", show_label=True)
870
+ seg1 = gr.Plot(label="Semantic Segmentation")
871
+ pie1 = gr.Plot(label="Element Area Ratio")
872
+
873
+ # Second image results
874
+ with gr.Column():
875
+ gr.Label("Results for Second Image", show_label=True)
876
+ seg2 = gr.Plot(label="Semantic Segmentation")
877
+ pie2 = gr.Plot(label="Element Area Ratio")
878
+
879
+ with gr.TabItem("Color Matching"):
880
+ gr.Markdown("""
881
+ ### Region-Based Color Matching
882
+
883
+ This tab shows the result of matching the colors of the second image to the first image's colors,
884
+ but only within corresponding semantic regions. For example, sky areas in the second image are
885
+ matched to sky areas in the first image, while vegetation areas are matched separately.
886
+
887
+ #### Face Matching Feature:
888
+ When enabled, the system will detect faces within human/bio regions and only apply color matching
889
+ to human regions where similar faces are found in both images. This ensures that color transfer
890
+ only occurs between images of the same person.
891
+
892
+ #### CDL Edge Smoothing Feature:
893
+ When enabled, calculates Color Decision List (CDL) parameters to transform the original target image
894
+ towards the segment-matched result, then applies those CDL parameters to the original image. This creates
895
+ a "smoothed" version that maintains the original image's overall characteristics while incorporating the
896
+ color relationships found through segment matching.
897
+
898
+ The CDL calculation uses the simplest possible approach: matches the mean and standard deviation
899
+ of each color channel between the original and composited images, with simple gamma calculation
900
+ based on brightness relationships.
901
+
902
+ #### Available Methods:
903
+ - **adain**: Adaptive Instance Normalization - Matches mean and standard deviation of colors
904
+ - **mkl**: Monge-Kantorovich Linearization - Linear transformation of color statistics
905
+ - **reinhard**: Reinhard color transfer - Simple statistical approach that matches mean and standard deviation
906
+ - **mvgd**: Multi-Variate Gaussian Distribution - Uses color covariance matrices for more accurate matching
907
+ - **hm**: Histogram Matching - Matches the full color distribution histograms
908
+ - **hm-mvgd-hm**: Histogram + MVGD + Histogram compound method
909
+ - **hm-mkl-hm**: Histogram + MKL + Histogram compound method
910
+ - **coral**: CORAL (Correlation Alignment) - Advanced covariance-based method for natural color transfer
911
+ """)
912
+
913
+ # CDL Parameters Display
914
+ cdl_display = gr.Textbox(
915
+ label="📊 CDL Parameters",
916
+ lines=15,
917
+ max_lines=20,
918
+ interactive=False,
919
+ info="Color Decision List parameters calculated when CDL edge smoothing is enabled"
920
+ )
921
+
922
+ face_log = gr.Textbox(
923
+ label="Face Matching Log",
924
+ lines=8,
925
+ max_lines=15,
926
+ interactive=False,
927
+ info="Shows details of face detection and matching process"
928
+ )
929
+
930
+ mask_vis = gr.Plot(label="Matched Regions Visualization")
931
+ comparison = gr.Plot(label="Region-Based Color Matching Result")
932
+
933
+ gr.HTML(
934
+ """
935
+ <div class="footer">
936
+ © 2024 SegMatch All Rights Reserved<br>
937
+ Developer: Stefan Allen
938
+ </div>
939
+ """.strip()
940
+ )
941
+
942
+ # Connect the inference function
943
+ start.click(
944
+ fn=infer_two_images,
945
+ inputs=[img1, img2, method, enable_face_matching, enable_edge_smoothing],
946
+ outputs=[seg1, pie1, seg2, pie2, comparison, mask_vis, download_btn, face_log, cdl_display]
947
+ )
948
+
949
+ # Example image selection handlers
950
+ ex1_1.select(fn=lambda x: choose_example(x, img1), inputs=ex1_1, outputs=img1)
951
+ ex1_2.select(fn=lambda x: choose_example(x, img1), inputs=ex1_2, outputs=img1)
952
+ ex2_1.select(fn=lambda x: choose_example(x, img2), inputs=ex2_1, outputs=img2)
953
+ ex2_2.select(fn=lambda x: choose_example(x, img2), inputs=ex2_2, outputs=img2)
954
+
955
+ if __name__ == "__main__":
956
+ demo.launch()
cdl_smoothing.py ADDED
@@ -0,0 +1,497 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ CDL (Color Decision List) based edge smoothing for SegMatch
4
+ """
5
+
6
+ import numpy as np
7
+ from typing import Tuple, Optional
8
+ from PIL import Image
9
+ import cv2
10
+
11
+
12
+ def calculate_cdl_params_face_only(source: np.ndarray, target: np.ndarray,
13
+ source_face_mask: np.ndarray, target_face_mask: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
14
+ """Calculate CDL parameters using only face pixels for focused accuracy.
15
+
16
+ Args:
17
+ source (np.ndarray): Source image as numpy array (0-1 range)
18
+ target (np.ndarray): Target image as numpy array (0-1 range)
19
+ source_face_mask (np.ndarray): Binary mask of face in source image
20
+ target_face_mask (np.ndarray): Binary mask of face in target image
21
+
22
+ Returns:
23
+ Tuple[np.ndarray, np.ndarray, np.ndarray]: (slope, offset, power)
24
+ """
25
+ epsilon = 1e-6
26
+
27
+ # Extract face pixels only
28
+ source_face_pixels = source[source_face_mask > 0.5]
29
+ target_face_pixels = target[target_face_mask > 0.5]
30
+
31
+ # Ensure we have enough face pixels
32
+ if len(source_face_pixels) < 100 or len(target_face_pixels) < 100:
33
+ # Fallback to simple calculation if not enough face pixels
34
+ return calculate_cdl_params_simple(source, target)
35
+
36
+ slopes = []
37
+ offsets = []
38
+ powers = []
39
+
40
+ for channel in range(3):
41
+ src_channel = source_face_pixels[:, channel]
42
+ tgt_channel = target_face_pixels[:, channel]
43
+
44
+ # Use robust percentiles for face pixels
45
+ percentiles = [10, 25, 50, 75, 90]
46
+ src_percentiles = np.percentile(src_channel, percentiles)
47
+ tgt_percentiles = np.percentile(tgt_channel, percentiles)
48
+
49
+ # Calculate slope from face pixel range
50
+ src_range = src_percentiles[4] - src_percentiles[0] # 90th - 10th
51
+ tgt_range = tgt_percentiles[4] - tgt_percentiles[0]
52
+ slope = tgt_range / (src_range + epsilon)
53
+
54
+ # Calculate offset using face median
55
+ src_median = src_percentiles[2]
56
+ tgt_median = tgt_percentiles[2]
57
+ offset = tgt_median - (src_median * slope)
58
+
59
+ # Calculate gamma from face brightness relationship
60
+ src_mean = np.mean(src_channel)
61
+ tgt_mean = np.mean(tgt_channel)
62
+
63
+ if src_mean > epsilon:
64
+ power = np.log(tgt_mean + epsilon) / np.log(src_mean + epsilon)
65
+ power = np.clip(power, 0.3, 3.0)
66
+ else:
67
+ power = 1.0
68
+
69
+ slopes.append(slope)
70
+ offsets.append(offset)
71
+ powers.append(power)
72
+
73
+ return np.array(slopes), np.array(offsets), np.array(powers)
74
+
75
+
76
+ def calculate_cdl_params_simple(source: np.ndarray, target: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
77
+ """Simple CDL calculation as fallback method.
78
+
79
+ Args:
80
+ source (np.ndarray): Source image as numpy array (0-1 range)
81
+ target (np.ndarray): Target image as numpy array (0-1 range)
82
+
83
+ Returns:
84
+ Tuple[np.ndarray, np.ndarray, np.ndarray]: (slope, offset, power)
85
+ """
86
+ epsilon = 1e-6
87
+
88
+ # Calculate mean and standard deviation for each RGB channel
89
+ source_mean = np.mean(source, axis=(0, 1))
90
+ source_std = np.std(source, axis=(0, 1))
91
+ target_mean = np.mean(target, axis=(0, 1))
92
+ target_std = np.std(target, axis=(0, 1))
93
+
94
+ # Calculate slope (gain)
95
+ slope = target_std / (source_std + epsilon)
96
+
97
+ # Calculate offset
98
+ offset = target_mean - (source_mean * slope)
99
+
100
+ # Set power to neutral
101
+ power = np.ones(3)
102
+
103
+ return slope, offset, power
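+ # Sanity-check sketch for the fit above (kept as a comment so importing this
+ # module stays side-effect free): for tgt = 1.5 * src + 0.1 the recovered
+ # parameters should be slope ~= 1.5, offset ~= 0.1, power = [1, 1, 1]:
+ #   src = np.random.default_rng(0).uniform(0.2, 0.4, (8, 8, 3))
+ #   tgt = src * 1.5 + 0.1
+ #   slope, offset, power = calculate_cdl_params_simple(src, tgt)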
104
+
105
+
106
+ def calculate_cdl_params_histogram(source: np.ndarray, target: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
107
+ """Calculate CDL parameters using histogram matching approach.
108
+
109
+ Args:
110
+ source (np.ndarray): Source image as numpy array (0-1 range)
111
+ target (np.ndarray): Target image as numpy array (0-1 range)
112
+
113
+ Returns:
114
+ Tuple[np.ndarray, np.ndarray, np.ndarray]: (slope, offset, power)
115
+ """
116
+ epsilon = 1e-6
117
+
118
+ # Convert to 0-255 range for histogram calculation
119
+ source_255 = (source * 255).astype(np.uint8)
120
+ target_255 = (target * 255).astype(np.uint8)
121
+
122
+ slopes = []
123
+ offsets = []
124
+ powers = []
125
+
126
+ for channel in range(3):
127
+ # Calculate histograms
128
+ hist_source = cv2.calcHist([source_255], [channel], None, [256], [0, 256])
129
+ hist_target = cv2.calcHist([target_255], [channel], None, [256], [0, 256])
130
+
131
+ # Calculate cumulative distributions
132
+ cdf_source = np.cumsum(hist_source) / np.sum(hist_source)
133
+ cdf_target = np.cumsum(hist_target) / np.sum(hist_target)
134
+
135
+ # Find percentile mappings
136
+ p25_src = np.percentile(source[:, :, channel], 25)
137
+ p75_src = np.percentile(source[:, :, channel], 75)
138
+ p25_tgt = np.percentile(target[:, :, channel], 25)
139
+ p75_tgt = np.percentile(target[:, :, channel], 75)
140
+
141
+ # Calculate slope from percentile mapping
142
+ slope = (p75_tgt - p25_tgt) / (p75_src - p25_src + epsilon)
143
+
144
+ # Calculate offset
145
+ median_src = np.percentile(source[:, :, channel], 50)
146
+ median_tgt = np.percentile(target[:, :, channel], 50)
147
+ offset = median_tgt - (median_src * slope)
148
+
149
+ # Estimate power/gamma from the histogram shape
150
+ mean_src = np.mean(source[:, :, channel])
151
+ mean_tgt = np.mean(target[:, :, channel])
152
+ if mean_src > epsilon:
153
+ power = np.log(mean_tgt + epsilon) / np.log(mean_src + epsilon)
154
+ power = np.clip(power, 0.1, 10.0) # Reasonable gamma range
155
+ else:
156
+ power = 1.0
157
+
158
+ slopes.append(slope)
159
+ offsets.append(offset)
160
+ powers.append(power)
161
+
162
+ return np.array(slopes), np.array(offsets), np.array(powers)
163
+
164
+
165
+ def calculate_cdl_params_mask_aware(source: np.ndarray, target: np.ndarray,
166
+ changed_mask: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
167
+ """Calculate CDL parameters focusing only on changed regions.
168
+
169
+ Args:
170
+ source (np.ndarray): Source image as numpy array (0-1 range)
171
+ target (np.ndarray): Target image as numpy array (0-1 range)
172
+ changed_mask (np.ndarray, optional): Binary mask of changed regions
173
+
174
+ Returns:
175
+ Tuple[np.ndarray, np.ndarray, np.ndarray]: (slope, offset, power)
176
+ """
177
+ if changed_mask is not None:
178
+ # Only use pixels where changes occurred
179
+ mask_bool = changed_mask > 0.5
180
+ if np.sum(mask_bool) > 100: # Ensure enough pixels
181
+ source_masked = source[mask_bool]
182
+ target_masked = target[mask_bool]
183
+
184
+ # Reshape back to have channel dimension
185
+ source_masked = source_masked.reshape(-1, 3)
186
+ target_masked = target_masked.reshape(-1, 3)
187
+
188
+ # Calculate statistics on masked regions
189
+ epsilon = 1e-6
190
+ source_mean = np.mean(source_masked, axis=0)
191
+ source_std = np.std(source_masked, axis=0)
192
+ target_mean = np.mean(target_masked, axis=0)
193
+ target_std = np.std(target_masked, axis=0)
194
+
195
+ slope = target_std / (source_std + epsilon)
196
+ offset = target_mean - (source_mean * slope)
197
+ power = np.ones(3)
198
+
199
+ return slope, offset, power
200
+
201
+ # Fallback to simple method if mask is not useful
202
+ return calculate_cdl_params_simple(source, target)
203
+
204
+
205
+ def calculate_cdl_params_lab(source: np.ndarray, target: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
206
+ """Calculate CDL parameters in LAB color space for better perceptual matching.
207
+
208
+ Args:
209
+ source (np.ndarray): Source image as numpy array (0-1 range)
210
+ target (np.ndarray): Target image as numpy array (0-1 range)
211
+
212
+ Returns:
213
+ Tuple[np.ndarray, np.ndarray, np.ndarray]: (slope, offset, power)
214
+ """
215
+ # Convert to LAB color space
216
+ source_lab = cv2.cvtColor((source * 255).astype(np.uint8), cv2.COLOR_RGB2LAB).astype(np.float32)
217
+ target_lab = cv2.cvtColor((target * 255).astype(np.uint8), cv2.COLOR_RGB2LAB).astype(np.float32)
218
+
219
+ # Normalize LAB values
220
+ source_lab[:, :, 0] /= 100.0 # L: 0-100 -> 0-1
221
+ source_lab[:, :, 1] = (source_lab[:, :, 1] + 128) / 255.0 # A: -128-127 -> 0-1
222
+ source_lab[:, :, 2] = (source_lab[:, :, 2] + 128) / 255.0 # B: -128-127 -> 0-1
223
+
224
+ target_lab[:, :, 0] /= 100.0
225
+ target_lab[:, :, 1] = (target_lab[:, :, 1] + 128) / 255.0
226
+ target_lab[:, :, 2] = (target_lab[:, :, 2] + 128) / 255.0
227
+
228
+ # Calculate CDL in LAB space
229
+ epsilon = 1e-6
230
+ source_mean = np.mean(source_lab, axis=(0, 1))
231
+ source_std = np.std(source_lab, axis=(0, 1))
232
+ target_mean = np.mean(target_lab, axis=(0, 1))
233
+ target_std = np.std(target_lab, axis=(0, 1))
234
+
235
+ slope_lab = target_std / (source_std + epsilon)
236
+ offset_lab = target_mean - (source_mean * slope_lab)
237
+
238
+ # Convert back to RGB approximation
239
+ # This is a simplified conversion - for full accuracy we'd need to convert each pixel
240
+ slope = np.array([slope_lab[0], slope_lab[1], slope_lab[2]]) # Rough mapping
241
+ offset = np.array([offset_lab[0], offset_lab[1], offset_lab[2]])
242
+ power = np.ones(3)
243
+
244
+ return slope, offset, power
245
+
246
+
247
+ def calculate_cdl_params(source: np.ndarray, target: np.ndarray,
248
+ source_path: str = None, target_path: str = None) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
249
+ """Calculate CDL parameters using simple mean/std matching - the most basic approach.
250
+
251
+ Args:
252
+ source (np.ndarray): Source image as numpy array (0-1 range)
253
+ target (np.ndarray): Target image as numpy array (0-1 range)
254
+ source_path (str, optional): Ignored - kept for compatibility
255
+ target_path (str, optional): Ignored - kept for compatibility
256
+
257
+ Returns:
258
+ Tuple[np.ndarray, np.ndarray, np.ndarray]: (slope, offset, power)
259
+ """
260
+ epsilon = 1e-6
261
+
262
+ # Calculate simple mean and standard deviation for each RGB channel
263
+ source_mean = np.mean(source, axis=(0, 1))
264
+ source_std = np.std(source, axis=(0, 1))
265
+ target_mean = np.mean(target, axis=(0, 1))
266
+ target_std = np.std(target, axis=(0, 1))
267
+
268
+ # Calculate slope (gain) from std ratio
269
+ slope = target_std / (source_std + epsilon)
270
+
271
+ # Calculate offset from mean difference
272
+ offset = target_mean - (source_mean * slope)
273
+
274
+ # Calculate simple gamma from brightness relationship
275
+ power = []
276
+ for channel in range(3):
277
+ if source_mean[channel] > epsilon:
278
+ gamma = np.log(target_mean[channel] + epsilon) / np.log(source_mean[channel] + epsilon)
279
+ gamma = np.clip(gamma, 0.1, 10.0) # Keep within reasonable bounds
280
+ else:
281
+ gamma = 1.0
282
+ power.append(gamma)
283
+
284
+ power = np.array(power)
285
+
286
+ return slope, offset, power
287
+
288
+
289
+ def calculate_change_mask(original: np.ndarray, composited: np.ndarray, threshold: float = 0.05) -> np.ndarray:
290
+ """Calculate a mask of significantly changed regions between original and composited images.
291
+
292
+ Args:
293
+ original (np.ndarray): Original image (0-1 range)
294
+ composited (np.ndarray): Composited result (0-1 range)
295
+ threshold (float): Threshold for detecting significant changes
296
+
297
+ Returns:
298
+ np.ndarray: Binary mask of changed regions
299
+ """
300
+ # Calculate per-pixel difference
301
+ diff = np.sqrt(np.sum((composited - original) ** 2, axis=2))
302
+
303
+ # Create binary mask where changes exceed threshold
304
+ change_mask = (diff > threshold).astype(np.float32)
305
+
306
+ # Apply morphological operations to clean up the mask
307
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
308
+ change_mask = cv2.morphologyEx(change_mask, cv2.MORPH_CLOSE, kernel)
309
+
310
+ return change_mask
311
+
312
+
313
+ def calculate_channel_stats(array: np.ndarray) -> dict:
314
+ """Calculate per-channel statistics for an image array.
315
+
316
+ Args:
317
+ array: Image array of shape (H, W, 3)
318
+
319
+ Returns:
320
+ dict: Dictionary containing mean, std, min, max for each channel
321
+ """
322
+ stats = {
323
+ 'mean': np.mean(array, axis=(0, 1)),
324
+ 'std': np.std(array, axis=(0, 1)),
325
+ 'min': np.min(array, axis=(0, 1)),
326
+ 'max': np.max(array, axis=(0, 1))
327
+ }
328
+ return stats
329
+
330
+
331
+ def apply_cdl_transform(image: np.ndarray, slope: np.ndarray, offset: np.ndarray, power: np.ndarray,
332
+ factor: float = 0.3) -> np.ndarray:
333
+ """Apply CDL transformation to an image.
334
+
335
+ Args:
336
+ image (np.ndarray): Input image (0-1 range)
337
+ slope (np.ndarray): CDL slope parameters for each channel
338
+ offset (np.ndarray): CDL offset parameters for each channel
339
+ power (np.ndarray): CDL power parameters for each channel
340
+ factor (float): Blending factor (0.0 = no change, 1.0 = full transform)
341
+
342
+ Returns:
343
+ np.ndarray: Transformed image
344
+ """
345
+ # Apply CDL transform: out = ((in * slope) + offset) ** power
346
+ transformed = np.power(np.maximum(image * slope + offset, 0), power)
347
+
348
+ # Clamp to valid range
349
+ transformed = np.clip(transformed, 0.0, 1.0)
350
+
351
+ # Blend with original based on factor
352
+ result = (1 - factor) * image + factor * transformed
353
+
354
+ return result
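+ # Worked single-value example of the transform above (illustrative numbers only):
+ #   in = 0.5, slope = 1.2, offset = -0.05, power = 0.9, factor = 0.3
+ #   transformed = max(0.5 * 1.2 - 0.05, 0) ** 0.9 = 0.55 ** 0.9 ≈ 0.584
+ #   result      = 0.7 * 0.5 + 0.3 * 0.584 ≈ 0.525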
355
+
356
+
357
+ def cdl_edge_smoothing(composited_image_path: str, original_image_path: str, factor: float = 0.3) -> Image.Image:
358
+ """Apply CDL-based edge smoothing between composited result and original image.
359
+
360
+ Args:
361
+ composited_image_path (str): Path to the composited result image
362
+ original_image_path (str): Path to the original target image
363
+ factor (float): Smoothing strength (0.0 = no smoothing, 1.0 = full smoothing)
364
+
365
+ Returns:
366
+ Image.Image: Smoothed result image
367
+ """
368
+ # Load images
369
+ composited_img = Image.open(composited_image_path).convert("RGB")
370
+ original_img = Image.open(original_image_path).convert("RGB")
371
+
372
+ # Ensure same dimensions
373
+ if composited_img.size != original_img.size:
374
+ composited_img = composited_img.resize(original_img.size, Image.LANCZOS)
375
+
376
+ # Convert to numpy arrays (0-1 range)
377
+ composited_np = np.array(composited_img).astype(np.float32) / 255.0
378
+ original_np = np.array(original_img).astype(np.float32) / 255.0
379
+
380
+ # Calculate CDL parameters to transform composited to match original
381
+ slope, offset, power = calculate_cdl_params(composited_np, original_np)
382
+
383
+ # Apply CDL transformation with blending
384
+ smoothed_np = apply_cdl_transform(composited_np, slope, offset, power, factor)
385
+
386
+ # Convert back to PIL Image
387
+ smoothed_img = Image.fromarray((smoothed_np * 255).astype(np.uint8))
388
+
389
+ return smoothed_img
390
+
391
+
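For completeness, a hypothetical call to the smoothing helper above; both file paths are placeholders:

smoothed = cdl_edge_smoothing("outputs/composite.png", "inputs/original.png", factor=0.3)
smoothed.save("outputs/composite_smoothed.png")   # 30% of the CDL correction blended in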
392
+ def get_smoothing_stats(original_image_path: str, composited_image_path: str) -> dict:
393
+ """Get statistics about the CDL transformation for debugging.
394
+
395
+ Args:
396
+ original_image_path (str): Path to the original target image
397
+ composited_image_path (str): Path to the composited result image
398
+
399
+ Returns:
400
+ dict: Statistics about the transformation
401
+ """
402
+ # Load images
403
+ composited_img = Image.open(composited_image_path).convert("RGB")
404
+ original_img = Image.open(original_image_path).convert("RGB")
405
+
406
+ # Ensure same dimensions
407
+ if composited_img.size != original_img.size:
408
+ composited_img = composited_img.resize(original_img.size, Image.LANCZOS)
409
+
410
+ # Convert to numpy arrays (0-1 range)
411
+ composited_np = np.array(composited_img).astype(np.float32) / 255.0
412
+ original_np = np.array(original_img).astype(np.float32) / 255.0
413
+
414
+ # Calculate statistics
415
+ composited_stats = calculate_channel_stats(composited_np)
416
+ original_stats = calculate_channel_stats(original_np)
417
+
418
+ # Calculate CDL parameters using face-based method when possible
419
+ slope, offset, power = calculate_cdl_params(original_np, composited_np,
420
+ original_image_path, composited_image_path)
421
+
422
+ return {
423
+ 'composited_stats': composited_stats,
424
+ 'original_stats': original_stats,
425
+ 'cdl_slope': slope,
426
+ 'cdl_offset': offset,
427
+ 'cdl_power': power
428
+ }
429
+
430
+
431
+ def cdl_edge_smoothing_apply_to_source(source_image_path: str, target_image_path: str, factor: float = 1.0) -> Image.Image:
432
+ """Apply CDL transformation to source image using face-based parameters when possible.
433
+
434
+ This function:
435
+ 1. Calculates CDL parameters to transform source to match target (using face pixels when available)
436
+ 2. Applies those CDL parameters to the entire source image
437
+ 3. Returns the transformed source image
438
+
439
+ Args:
440
+ source_image_path (str): Path to the source image (to be transformed)
441
+ target_image_path (str): Path to the target image (reference for CDL calculation)
442
+ factor (float): Transform strength (0.0 = no change, 1.0 = full transform)
443
+
444
+ Returns:
445
+ Image.Image: Source image with CDL transformation applied
446
+ """
447
+ # Load images
448
+ source_img = Image.open(source_image_path).convert("RGB")
449
+ target_img = Image.open(target_image_path).convert("RGB")
450
+
451
+ # Ensure same dimensions
452
+ if source_img.size != target_img.size:
453
+ target_img = target_img.resize(source_img.size, Image.LANCZOS)
454
+
455
+ # Convert to numpy arrays (0-1 range)
456
+ source_np = np.array(source_img).astype(np.float32) / 255.0
457
+ target_np = np.array(target_img).astype(np.float32) / 255.0
458
+
459
+ # Calculate CDL parameters using face-based method when possible
460
+ slope, offset, power = calculate_cdl_params(source_np, target_np,
461
+ source_image_path, target_image_path)
462
+
463
+ # Apply CDL transformation to the entire source image
464
+ transformed_np = apply_cdl_transform(source_np, slope, offset, power, factor)
465
+
466
+ # Convert back to PIL Image
467
+ transformed_img = Image.fromarray((transformed_np * 255).astype(np.uint8))
468
+
469
+ return transformed_img
470
+
471
+
472
+ def extract_face_mask(image_path: str) -> Optional[np.ndarray]:
473
+ """Extract face mask from an image using human parts segmentation.
474
+
475
+ Args:
476
+ image_path (str): Path to the image
477
+
478
+ Returns:
479
+ np.ndarray or None: Binary face mask, or None if no face found
480
+ """
481
+ try:
482
+ from human_parts_segmentation import HumanPartsSegmentation
483
+
484
+ segmenter = HumanPartsSegmentation()
485
+ masks_dict = segmenter.segment_parts(image_path, ['face'])
486
+
487
+ if 'face' in masks_dict and masks_dict['face'] is not None:
488
+ face_mask = masks_dict['face']
489
+ # Ensure it's a proper binary mask
490
+ if np.sum(face_mask > 0.5) > 100: # At least 100 face pixels
491
+ return face_mask
492
+
493
+ return None
494
+
495
+ except Exception as e:
496
+ print(f"Face extraction failed: {e}")
497
+ return None
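A short sketch of how these helpers might be chained; the paths are placeholders, and calculate_cdl_params is assumed to behave as in the calls above (using face pixels when available, otherwise whole-image statistics):

face_mask = extract_face_mask("inputs/source.jpg")
if face_mask is not None:
    print(f"Face found: {int(np.sum(face_mask > 0.5))} pixels can drive the CDL fit")

# grade the whole source image toward the target's color
graded = cdl_edge_smoothing_apply_to_source("inputs/source.jpg", "inputs/target.jpg", factor=1.0)
graded.save("outputs/source_graded.png")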
clothes_segmentation.py ADDED
@@ -0,0 +1,292 @@
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from typing import Union, Tuple
6
+ from PIL import Image, ImageFilter
7
+ import cv2
8
+ from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
9
+ from huggingface_hub import hf_hub_download
10
+ import shutil
11
+
12
+ # Device configuration
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+
15
+ # Model configuration
16
+ AVAILABLE_MODELS = {
17
+ "segformer_b2_clothes": "1038lab/segformer_clothes"
18
+ }
19
+
20
+ # Model paths
21
+ current_dir = os.path.dirname(os.path.abspath(__file__))
22
+ models_dir = os.path.join(current_dir, "models")
23
+
24
+
25
+ def pil2tensor(image: Image.Image) -> torch.Tensor:
26
+ """Convert PIL Image to tensor."""
27
+ return torch.from_numpy(np.array(image).astype(np.float32) / 255.0)[None,]
28
+
29
+
30
+ def tensor2pil(image: torch.Tensor) -> Image.Image:
31
+ """Convert tensor to PIL Image."""
32
+ return Image.fromarray(np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8))
33
+
34
+
35
+ def image2mask(image: Image.Image) -> torch.Tensor:
36
+ """Convert image to mask tensor."""
37
+ if isinstance(image, Image.Image):
38
+ image = pil2tensor(image)
39
+ return image.squeeze()[..., 0]
40
+
41
+
42
+ def mask2image(mask: torch.Tensor) -> Image.Image:
43
+ """Convert mask tensor to PIL Image."""
44
+ if len(mask.shape) == 2:
45
+ mask = mask.unsqueeze(0)
46
+ return tensor2pil(mask)
47
+
48
+
49
+ class ClothesSegmentation:
50
+ """
51
+ Standalone clothing segmentation using Segformer model.
52
+ """
53
+
54
+ def __init__(self):
55
+ self.processor = None
56
+ self.model = None
57
+ self.cache_dir = os.path.join(models_dir, "RMBG", "segformer_clothes")
58
+
59
+ # Class mapping for segmentation - consistent with latest repo
60
+ self.class_map = {
61
+ "Background": 0, "Hat": 1, "Hair": 2, "Sunglasses": 3,
62
+ "Upper-clothes": 4, "Skirt": 5, "Pants": 6, "Dress": 7,
63
+ "Belt": 8, "Left-shoe": 9, "Right-shoe": 10, "Face": 11,
64
+ "Left-leg": 12, "Right-leg": 13, "Left-arm": 14, "Right-arm": 15,
65
+ "Bag": 16, "Scarf": 17
66
+ }
67
+
68
+ def check_model_cache(self):
69
+ """Check if model files exist in cache."""
70
+ if not os.path.exists(self.cache_dir):
71
+ return False, "Model directory not found"
72
+
73
+ required_files = [
74
+ 'config.json',
75
+ 'model.safetensors',
76
+ 'preprocessor_config.json'
77
+ ]
78
+
79
+ missing_files = [f for f in required_files if not os.path.exists(os.path.join(self.cache_dir, f))]
80
+ if missing_files:
81
+ return False, f"Required model files missing: {', '.join(missing_files)}"
82
+ return True, "Model cache verified"
83
+
84
+ def clear_model(self):
85
+ """Clear model from memory - improved version."""
86
+ if self.model is not None:
87
+ self.model.cpu()
88
+ del self.model
89
+ self.model = None
90
+ self.processor = None
91
+ if torch.cuda.is_available():
92
+ torch.cuda.empty_cache()
93
+
94
+ def download_model_files(self):
95
+ """Download model files from Hugging Face - improved version."""
96
+ model_id = AVAILABLE_MODELS["segformer_b2_clothes"]
97
+ model_files = {
98
+ 'config.json': 'config.json',
99
+ 'model.safetensors': 'model.safetensors',
100
+ 'preprocessor_config.json': 'preprocessor_config.json'
101
+ }
102
+
103
+ os.makedirs(self.cache_dir, exist_ok=True)
104
+ print(f"Downloading Clothes Segformer model files...")
105
+
106
+ try:
107
+ for save_name, repo_path in model_files.items():
108
+ print(f"Downloading {save_name}...")
109
+ downloaded_path = hf_hub_download(
110
+ repo_id=model_id,
111
+ filename=repo_path,
112
+ local_dir=self.cache_dir,
113
+ local_dir_use_symlinks=False
114
+ )
115
+
116
+ if os.path.dirname(downloaded_path) != self.cache_dir:
117
+ target_path = os.path.join(self.cache_dir, save_name)
118
+ shutil.move(downloaded_path, target_path)
119
+ return True, "Model files downloaded successfully"
120
+ except Exception as e:
121
+ return False, f"Error downloading model files: {str(e)}"
122
+
123
+ def load_model(self):
124
+ """Load the clothing segmentation model - improved version."""
125
+ try:
126
+ # Check and download model if needed
127
+ cache_status, message = self.check_model_cache()
128
+ if not cache_status:
129
+ print(f"Cache check: {message}")
130
+ download_status, download_message = self.download_model_files()
131
+ if not download_status:
132
+ print(f"❌ {download_message}")
133
+ return False
134
+
135
+ # Load model if needed
136
+ if self.processor is None:
137
+ print("Loading clothes segmentation model...")
138
+ self.processor = SegformerImageProcessor.from_pretrained(self.cache_dir)
139
+ self.model = AutoModelForSemanticSegmentation.from_pretrained(self.cache_dir)
140
+ self.model.eval()
141
+ for param in self.model.parameters():
142
+ param.requires_grad = False
143
+ self.model.to(device)
144
+ print("✅ Clothes segmentation model loaded successfully")
145
+
146
+ return True
147
+
148
+ except Exception as e:
149
+ print(f"❌ Error loading clothes model: {e}")
150
+ self.clear_model() # Cleanup on error
151
+ return False
152
+
153
+ def segment_clothes(self, image_path: str, target_classes: list = None, process_res: int = 512) -> np.ndarray:
154
+ """
155
+ Segment clothing from an image - improved version with process_res parameter.
156
+
157
+ Args:
158
+ image_path: Path to the image
159
+ target_classes: List of clothing classes to segment (default: ["Upper-clothes"])
160
+ process_res: Processing resolution (default: 512)
161
+
162
+ Returns:
163
+ Binary mask as numpy array
164
+ """
165
+ if target_classes is None:
166
+ target_classes = ["Upper-clothes"]
167
+
168
+ if not self.load_model():
169
+ print("❌ Cannot load clothes segmentation model")
170
+ return None
171
+
172
+ try:
173
+ # Load and preprocess image
174
+ image = cv2.imread(image_path)
175
+ if image is None:
176
+ print(f"❌ Could not load image: {image_path}")
177
+ return None
178
+
179
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
180
+ original_size = image_rgb.shape[:2]
181
+
182
+ # Preprocess image with custom resolution
183
+ pil_image = Image.fromarray(image_rgb)
184
+
185
+ # Resize for processing if needed
186
+ if process_res != 512:
187
+ pil_image = pil_image.resize((process_res, process_res), Image.Resampling.LANCZOS)
188
+
189
+ inputs = self.processor(images=pil_image, return_tensors="pt")
190
+ inputs = {k: v.to(device) for k, v in inputs.items()}
191
+
192
+ # Run inference
193
+ with torch.no_grad():
194
+ outputs = self.model(**inputs)
195
+ logits = outputs.logits.cpu()
196
+
197
+ # Resize logits to original image size
198
+ upsampled_logits = nn.functional.interpolate(
199
+ logits,
200
+ size=original_size,
201
+ mode="bilinear",
202
+ align_corners=False,
203
+ )
204
+ pred_seg = upsampled_logits.argmax(dim=1)[0]
205
+
206
+ # Combine selected class masks
207
+ combined_mask = None
208
+ for class_name in target_classes:
209
+ if class_name in self.class_map:
210
+ mask = (pred_seg == self.class_map[class_name]).float()
211
+ if combined_mask is None:
212
+ combined_mask = mask
213
+ else:
214
+ combined_mask = torch.clamp(combined_mask + mask, 0, 1)
215
+ else:
216
+ print(f"⚠️ Unknown class: {class_name}")
217
+
218
+ if combined_mask is None:
219
+ print(f"❌ No valid classes found in: {target_classes}")
220
+ return None
221
+
222
+ # Convert to numpy
223
+ mask_np = combined_mask.numpy().astype(np.float32)
224
+
225
+ return mask_np
226
+
227
+ except Exception as e:
228
+ print(f"❌ Error in clothes segmentation: {e}")
229
+ return None
230
+ finally:
231
+ # Clean up model if not training (consistent with updated repo)
232
+ if self.model is not None and not self.model.training:
233
+ self.clear_model()
234
+
235
+ def segment_clothes_with_filters(self, image_path: str, target_classes: list = None,
236
+ mask_blur: int = 0, mask_offset: int = 0,
237
+ process_res: int = 512) -> np.ndarray:
238
+ """
239
+ Segment clothing with additional filtering options - new method from updated repo.
240
+
241
+ Args:
242
+ image_path: Path to the image
243
+ target_classes: List of clothing classes to segment
244
+ mask_blur: Blur amount for mask edges
245
+ mask_offset: Expand/Shrink mask boundary
246
+ process_res: Processing resolution
247
+
248
+ Returns:
249
+ Filtered binary mask as numpy array
250
+ """
251
+ # Get initial mask
252
+ mask = self.segment_clothes(image_path, target_classes, process_res)
253
+ if mask is None:
254
+ return None
255
+
256
+ try:
257
+ # Convert to PIL for filtering
258
+ mask_image = Image.fromarray((mask * 255).astype(np.uint8))
259
+
260
+ # Apply blur if specified
261
+ if mask_blur > 0:
262
+ mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=mask_blur))
263
+
264
+ # Apply offset if specified
265
+ if mask_offset != 0:
266
+ if mask_offset > 0:
267
+ mask_image = mask_image.filter(ImageFilter.MaxFilter(size=mask_offset * 2 + 1))
268
+ else:
269
+ mask_image = mask_image.filter(ImageFilter.MinFilter(size=-mask_offset * 2 + 1))
270
+
271
+ # Convert back to numpy
272
+ filtered_mask = np.array(mask_image).astype(np.float32) / 255.0
273
+ return filtered_mask
274
+
275
+ except Exception as e:
276
+ print(f"❌ Error applying filters: {e}")
277
+ return mask
278
+
279
+
280
+ # Standalone function for easy usage
281
+ def segment_upper_clothes(image_path: str) -> np.ndarray:
282
+ """
283
+ Convenience function to segment upper clothes from an image.
284
+
285
+ Args:
286
+ image_path: Path to the image
287
+
288
+ Returns:
289
+ Binary mask as numpy array or None if failed
290
+ """
291
+ segmenter = ClothesSegmentation()
292
+ return segmenter.segment_clothes(image_path, ["Upper-clothes"])
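A hypothetical end-to-end call for the class above, assuming the names from this module (ClothesSegmentation, Image, np) are in scope; the image path is a placeholder:

segmenter = ClothesSegmentation()

# blur the mask edge by 2 px and grow it by 3 px
mask = segmenter.segment_clothes_with_filters(
    "inputs/person.jpg",
    target_classes=["Upper-clothes", "Dress"],
    mask_blur=2,
    mask_offset=3,
)
if mask is not None:
    Image.fromarray((mask * 255).astype(np.uint8)).save("outputs/clothes_mask.png")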
color_matching.py ADDED
@@ -0,0 +1,698 @@
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ import matplotlib.pyplot as plt
5
+ import matplotlib.figure as figure
6
+ from matplotlib.figure import Figure
7
+ import numpy.typing as npt
8
+ import os
9
+ import sys
10
+ import tempfile
11
+ import time
12
+
13
+ class RegionColorMatcher:
14
+ def __init__(self, factor=1.0, preserve_colors=True, preserve_luminance=True, method="adain"):
15
+ """
16
+ Initialize the RegionColorMatcher.
17
+
18
+ Args:
19
+ factor: Strength of the color matching (0.0 to 1.0)
20
+ preserve_colors: If True, convert to YUV and preserve color relationships
21
+ preserve_luminance: If True, preserve the luminance when in YUV mode
22
+ method: The color matching method to use (adain, mkl, hm, reinhard, mvgd, coral, hm-mvgd-hm, hm-mkl-hm)
23
+ """
24
+ self.factor = factor
25
+ self.preserve_colors = preserve_colors
26
+ self.preserve_luminance = preserve_luminance
27
+ self.method = method
28
+
29
+ def match_regions(self, img1_path, img2_path, masks1, masks2):
30
+ """
31
+ Match colors between corresponding masked regions of two images.
32
+
33
+ Args:
34
+ img1_path: Path to first image
35
+ img2_path: Path to second image
36
+ masks1: Dictionary of masks for first image {label: binary_mask}
37
+ masks2: Dictionary of masks for second image {label: binary_mask}
38
+
39
+ Returns:
40
+ A PIL Image with the color-matched result
41
+ """
42
+ print(f"🎨 Color matching with method: {self.method}")
43
+ print(f"📊 Processing {len(masks1)} regions from img1 and {len(masks2)} regions from img2")
44
+
45
+ # Load images
46
+ img1 = Image.open(img1_path).convert("RGB")
47
+ img2 = Image.open(img2_path).convert("RGB")
48
+
49
+ # Convert to numpy arrays and normalize to [0,1]
50
+ img1_np = np.array(img1).astype(np.float32) / 255.0
51
+ img2_np = np.array(img2).astype(np.float32) / 255.0
52
+
53
+ # Create a copy of the second image as our base for color matching
54
+ # We want to make img2 look like img1's colors
55
+ result_np = np.copy(img2_np)
56
+
57
+ # Convert images to PyTorch tensors
58
+ img1_tensor = torch.from_numpy(img1_np)
59
+ img2_tensor = torch.from_numpy(img2_np)
60
+ result_tensor = torch.from_numpy(result_np)
61
+
62
+ # Track coverage to ensure all regions are processed
63
+ total_coverage = np.zeros(img2_np.shape[:2], dtype=np.float32)
64
+ processed_regions = 0
65
+
66
+ # Process each mask region
67
+ for label, mask1 in masks1.items():
68
+ if label not in masks2:
69
+ print(f"⚠️ Skipping {label} - not found in masks2")
70
+ continue
71
+
72
+ mask2 = masks2[label]
73
+
74
+ # Resize masks to match image dimensions if needed
75
+ if mask1.shape != img1_np.shape[:2]:
76
+ mask1 = self._resize_mask(mask1, img1_np.shape[:2])
77
+
78
+ if mask2.shape != img2_np.shape[:2]:
79
+ mask2 = self._resize_mask(mask2, img2_np.shape[:2])
80
+
81
+ # Check mask coverage
82
+ mask1_pixels = np.sum(mask1 > 0)
83
+ mask2_pixels = np.sum(mask2 > 0)
84
+ print(f"🔍 Processing {label}: {mask1_pixels} pixels (img1) → {mask2_pixels} pixels (img2)")
85
+
86
+ if mask1_pixels == 0 or mask2_pixels == 0:
87
+ print(f"⚠️ Skipping {label} - no pixels in mask")
88
+ continue
89
+
90
+ # Track coverage
91
+ total_coverage += (mask2 > 0).astype(np.float32)
92
+ processed_regions += 1
93
+
94
+ # Convert masks to torch tensors
95
+ mask1_tensor = torch.from_numpy(mask1.astype(np.float32))
96
+ mask2_tensor = torch.from_numpy(mask2.astype(np.float32))
97
+
98
+ # Apply color matching for this region based on selected method
99
+ if self.method == "adain":
100
+ result_tensor = self._apply_adain_to_region(
101
+ img1_tensor,
102
+ img2_tensor,
103
+ result_tensor,
104
+ mask1_tensor,
105
+ mask2_tensor
106
+ )
107
+ else:
108
+ result_tensor = self._apply_color_matcher_to_region(
109
+ img1_tensor,
110
+ img2_tensor,
111
+ result_tensor,
112
+ mask1_tensor,
113
+ mask2_tensor,
114
+ self.method
115
+ )
116
+
117
+ print(f"✅ Completed color matching for {label}")
118
+
119
+ # Debug coverage
120
+ total_pixels = img2_np.shape[0] * img2_np.shape[1]
121
+ covered_pixels = np.sum(total_coverage > 0)
122
+ overlap_pixels = np.sum(total_coverage > 1)
123
+
124
+ print(f"📊 Coverage summary:")
125
+ print(f" Total image pixels: {total_pixels}")
126
+ print(f" Covered pixels: {covered_pixels} ({100*covered_pixels/total_pixels:.1f}%)")
127
+ print(f" Overlapping pixels: {overlap_pixels} ({100*overlap_pixels/total_pixels:.1f}%)")
128
+ print(f" Processed regions: {processed_regions}")
129
+
130
+ # Convert back to numpy, scale to [0,255] and convert to uint8
131
+ result_np = (result_tensor.numpy() * 255.0).astype(np.uint8)
132
+
133
+ # Convert to PIL Image
134
+ result_img = Image.fromarray(result_np)
135
+
136
+ return result_img
137
+
138
+ def _resize_mask(self, mask, target_shape):
139
+ """
140
+ Resize a mask to match the target shape.
141
+
142
+ Args:
143
+ mask: Binary mask array
144
+ target_shape: Target shape (height, width)
145
+
146
+ Returns:
147
+ Resized mask array
148
+ """
149
+ # Convert to PIL Image for resizing
150
+ mask_img = Image.fromarray((mask * 255).astype(np.uint8))
151
+
152
+ # Resize to target shape
153
+ mask_img = mask_img.resize((target_shape[1], target_shape[0]), Image.NEAREST)
154
+
155
+ # Convert back to numpy array and normalize to [0,1]
156
+ resized_mask = np.array(mask_img).astype(np.float32) / 255.0
157
+
158
+ return resized_mask
159
+
160
+ def _apply_adain_to_region(self, source_img, target_img, result_img, source_mask, target_mask):
161
+ """
162
+ Apply AdaIN to match the statistics of the masked region in source to the target.
163
+
164
+ Args:
165
+ source_img: Source image tensor [H,W,3] (reference for color matching)
166
+ target_img: Target image tensor [H,W,3] (to be color matched)
167
+ result_img: Result image tensor to modify [H,W,3]
168
+ source_mask: Binary mask for source image [H,W]
169
+ target_mask: Binary mask for target image [H,W]
170
+
171
+ Returns:
172
+ Modified result tensor
173
+ """
174
+ # Ensure masks are binary
175
+ source_mask_binary = (source_mask > 0.5).float()
176
+ target_mask_binary = (target_mask > 0.5).float()
177
+
178
+ # If preserving colors, convert to YUV
179
+ if self.preserve_colors:
180
+ # RGB to YUV conversion matrix
181
+ rgb_to_yuv = torch.tensor([
182
+ [0.299, 0.587, 0.114],
183
+ [-0.14713, -0.28886, 0.436],
184
+ [0.615, -0.51499, -0.10001]
185
+ ])
186
+
187
+ # Convert to YUV
188
+ source_yuv = torch.matmul(source_img, rgb_to_yuv.T)
189
+ target_yuv = torch.matmul(target_img, rgb_to_yuv.T)
190
+ result_yuv = torch.matmul(result_img, rgb_to_yuv.T)
191
+
192
+ # Only normalize Y channel if preserving luminance is False
193
+ channels_to_process = [0] if not self.preserve_luminance else []
194
+
195
+ # Always process U and V channels (chroma)
196
+ channels_to_process.extend([1, 2])
197
+
198
+ # Process selected channels
199
+ for c in channels_to_process:
200
+ # Apply the color matching only to the masked region in the result
201
+ result_channel = result_yuv[:,:,c]
202
+ matched_channel = self._match_channel_statistics(
203
+ source_yuv[:,:,c],
204
+ target_yuv[:,:,c],
205
+ result_channel,
206
+ source_mask_binary,
207
+ target_mask_binary
208
+ )
209
+
210
+ # Only update the masked region in the result
211
+ mask_expanded = target_mask_binary.unsqueeze(-1).expand_as(result_yuv)[:,:,c]
212
+ result_yuv[:,:,c] = torch.where(
213
+ mask_expanded > 0.5,
214
+ matched_channel,
215
+ result_channel
216
+ )
217
+
218
+ # Convert back to RGB
219
+ yuv_to_rgb = torch.tensor([
220
+ [1.0, 0.0, 1.13983],
221
+ [1.0, -0.39465, -0.58060],
222
+ [1.0, 2.03211, 0.0]
223
+ ])
224
+
225
+ result_rgb = torch.matmul(result_yuv, yuv_to_rgb.T)
226
+
227
+ # Only update the masked region in the result
228
+ mask_expanded = target_mask_binary.unsqueeze(-1).expand_as(result_img)
229
+ result_img = torch.where(
230
+ mask_expanded > 0.5,
231
+ result_rgb,
232
+ result_img
233
+ )
234
+
235
+ else:
236
+ # Process each RGB channel separately
237
+ for c in range(3):
238
+ # Apply the color matching only to the masked region in the result
239
+ result_channel = result_img[:,:,c]
240
+ matched_channel = self._match_channel_statistics(
241
+ source_img[:,:,c],
242
+ target_img[:,:,c],
243
+ result_channel,
244
+ source_mask_binary,
245
+ target_mask_binary
246
+ )
247
+
248
+ # Only update the masked region in the result
249
+ mask_expanded = target_mask_binary.unsqueeze(-1).expand_as(result_img)[:,:,c]
250
+ result_img[:,:,c] = torch.where(
251
+ mask_expanded > 0.5,
252
+ matched_channel,
253
+ result_channel
254
+ )
255
+
256
+ # Ensure values are in valid range [0, 1]
257
+ return torch.clamp(result_img, 0.0, 1.0)
258
+
259
+ def _apply_color_matcher_to_region(self, source_img, target_img, result_img, source_mask, target_mask, method):
260
+ """
261
+ Apply color-matcher library methods to match the statistics of the masked region in source to the target.
262
+
263
+ Args:
264
+ source_img: Source image tensor [H,W,3] (reference for color matching)
265
+ target_img: Target image tensor [H,W,3] (to be color matched)
266
+ result_img: Result image tensor to modify [H,W,3]
267
+ source_mask: Binary mask for source image [H,W]
268
+ target_mask: Binary mask for target image [H,W]
269
+ method: The color matching method to use (mkl, hm, reinhard, mvgd, coral, hm-mvgd-hm, hm-mkl-hm)
270
+
271
+ Returns:
272
+ Modified result tensor
273
+ """
274
+ # Ensure masks are binary
275
+ source_mask_binary = (source_mask > 0.5).float()
276
+ target_mask_binary = (target_mask > 0.5).float()
277
+
278
+ # Convert tensors to numpy arrays
279
+ source_np = source_img.detach().cpu().numpy()
280
+ target_np = target_img.detach().cpu().numpy()
281
+ source_mask_np = source_mask_binary.detach().cpu().numpy()
282
+ target_mask_np = target_mask_binary.detach().cpu().numpy()
283
+
284
+ try:
285
+ # Try to import the color_matcher library
286
+ try:
287
+ from color_matcher import ColorMatcher
288
+ from color_matcher.normalizer import Normalizer
289
+ except ImportError:
290
+ self._install_package("color-matcher")
291
+ from color_matcher import ColorMatcher
292
+ from color_matcher.normalizer import Normalizer
293
+
294
+ # Extract only the masked pixels from both images
295
+ source_coords = np.where(source_mask_np > 0.5)
296
+ target_coords = np.where(target_mask_np > 0.5)
297
+
298
+ if len(source_coords[0]) == 0 or len(target_coords[0]) == 0:
299
+ return result_img
300
+
301
+ # Extract pixel values from masked regions
302
+ source_pixels = source_np[source_coords]
303
+ target_pixels = target_np[target_coords]
304
+
305
+ # Initialize color matcher
306
+ cm = ColorMatcher()
307
+
308
+ if method == "mkl":
309
+ # For MKL, calculate mean and standard deviation from masked regions
310
+ source_mean = np.mean(source_pixels, axis=0)
311
+ source_std = np.std(source_pixels, axis=0)
312
+ target_mean = np.mean(target_pixels, axis=0)
313
+ target_std = np.std(target_pixels, axis=0)
314
+
315
+ # Apply the transformation
316
+ result_np = np.copy(target_np)
317
+ for c in range(3):
318
+ # Normalize the target channel and scale by source statistics
319
+ normalized = (target_np[:,:,c] - target_mean[c]) / (target_std[c] + 1e-8) * source_std[c] + source_mean[c]
320
+
321
+ # Only apply to masked region
322
+ result_np[:,:,c] = np.where(target_mask_np > 0.5, normalized, target_np[:,:,c])
323
+
324
+ # Convert back to tensor
325
+ result_tensor = torch.from_numpy(result_np).to(result_img.device)
326
+
327
+ # Blend with original based on factor
328
+ result_img = torch.lerp(result_img, result_tensor, self.factor)
329
+
330
+ elif method == "reinhard":
331
+ # Similar to MKL but with a different normalization approach
332
+ source_mean = np.mean(source_pixels, axis=0)
333
+ source_std = np.std(source_pixels, axis=0)
334
+ target_mean = np.mean(target_pixels, axis=0)
335
+ target_std = np.std(target_pixels, axis=0)
336
+
337
+ # Apply the transformation
338
+ result_np = np.copy(target_np)
339
+ for c in range(3):
340
+ # Normalize the target channel and scale by source statistics
341
+ normalized = (target_np[:,:,c] - target_mean[c]) / (target_std[c] + 1e-8) * source_std[c] + source_mean[c]
342
+
343
+ # Only apply to masked region
344
+ result_np[:,:,c] = np.where(target_mask_np > 0.5, normalized, target_np[:,:,c])
345
+
346
+ # Convert back to tensor
347
+ result_tensor = torch.from_numpy(result_np).to(result_img.device)
348
+
349
+ # Blend with original based on factor
350
+ result_img = torch.lerp(result_img, result_tensor, self.factor)
351
+
352
+ elif method == "mvgd":
353
+ # For MVGD, we need mean and covariance matrices
354
+ source_mean = np.mean(source_pixels, axis=0)
355
+ source_cov = np.cov(source_pixels, rowvar=False)
356
+ target_mean = np.mean(target_pixels, axis=0)
357
+ target_cov = np.cov(target_pixels, rowvar=False)
358
+
359
+ # Check if covariance matrices are valid
360
+ if np.isnan(source_cov).any() or np.isnan(target_cov).any():
361
+ # Fallback to simple statistics matching
362
+ source_std = np.std(source_pixels, axis=0)
363
+ target_std = np.std(target_pixels, axis=0)
364
+
365
+ result_np = np.copy(target_np)
366
+ for c in range(3):
367
+ normalized = (target_np[:,:,c] - target_mean[c]) / (target_std[c] + 1e-8) * source_std[c] + source_mean[c]
368
+ result_np[:,:,c] = np.where(target_mask_np > 0.5, normalized, target_np[:,:,c])
369
+ else:
370
+ # Apply full MVGD transformation to masked pixels
371
+ # Reshape the masked pixels for matrix operations
372
+ target_flat = target_np.reshape(-1, 3)
373
+ result_np = np.copy(target_np)
374
+
375
+ try:
376
+ # Try to compute the full MVGD transformation
377
+ source_cov_sqrt = np.linalg.cholesky(source_cov)
378
+ target_cov_sqrt = np.linalg.cholesky(target_cov)
379
+ target_cov_sqrt_inv = np.linalg.inv(target_cov_sqrt)
380
+
381
+ # Compute the transformation matrix
382
+ temp = target_cov_sqrt_inv @ source_cov @ target_cov_sqrt_inv.T
383
+ temp_sqrt_inv = np.linalg.inv(np.linalg.cholesky(temp))
384
+ A = target_cov_sqrt @ temp_sqrt_inv @ target_cov_sqrt_inv
385
+
386
+ # Apply the transformation to all pixels
387
+ for i in range(target_np.shape[0]):
388
+ for j in range(target_np.shape[1]):
389
+ if target_mask_np[i, j] > 0.5:
390
+ # Only apply to masked pixels
391
+ pixel = target_np[i, j]
392
+ centered = pixel - target_mean
393
+ transformed = centered @ A.T + source_mean
394
+ result_np[i, j] = transformed
395
+ except np.linalg.LinAlgError:
396
+ # Fallback to simple statistics matching
397
+ source_std = np.std(source_pixels, axis=0)
398
+ target_std = np.std(target_pixels, axis=0)
399
+
400
+ for c in range(3):
401
+ normalized = (target_np[:,:,c] - target_mean[c]) / (target_std[c] + 1e-8) * source_std[c] + source_mean[c]
402
+ result_np[:,:,c] = np.where(target_mask_np > 0.5, normalized, target_np[:,:,c])
403
+
404
+ # Convert back to tensor
405
+ result_tensor = torch.from_numpy(result_np).to(result_img.device)
406
+
407
+ # Blend with original based on factor
408
+ result_img = torch.lerp(result_img, result_tensor, self.factor)
409
+
410
+ elif method in ["hm", "hm-mvgd-hm", "hm-mkl-hm"]:
411
+ # For histogram-based methods, we'll create temporary cropped images with just the masked regions
412
+
413
+ # Get the bounding box of the masked regions
414
+ source_min_y, source_min_x = np.min(source_coords[0]), np.min(source_coords[1])
415
+ source_max_y, source_max_x = np.max(source_coords[0]), np.max(source_coords[1])
416
+ target_min_y, target_min_x = np.min(target_coords[0]), np.min(target_coords[1])
417
+ target_max_y, target_max_x = np.max(target_coords[0]), np.max(target_coords[1])
418
+
419
+ # Create cropped images with just the masked regions
420
+ source_crop = source_np[source_min_y:source_max_y+1, source_min_x:source_max_x+1].copy()
421
+ target_crop = target_np[target_min_y:target_max_y+1, target_min_x:target_max_x+1].copy()
422
+
423
+ # Create cropped masks
424
+ source_mask_crop = source_mask_np[source_min_y:source_max_y+1, source_min_x:source_max_x+1]
425
+ target_mask_crop = target_mask_np[target_min_y:target_max_y+1, target_min_x:target_max_x+1]
426
+
427
+ # Apply the mask to the cropped images
428
+ # For non-masked areas, use the average color
429
+ source_avg_color = np.mean(source_pixels, axis=0)
430
+ target_avg_color = np.mean(target_pixels, axis=0)
431
+
432
+ for c in range(3):
433
+ source_crop[:, :, c] = np.where(source_mask_crop > 0.5, source_crop[:, :, c], source_avg_color[c])
434
+ target_crop[:, :, c] = np.where(target_mask_crop > 0.5, target_crop[:, :, c], target_avg_color[c])
435
+
436
+ try:
437
+ # Use the color matcher directly on the masked regions
438
+ matched_crop = cm.transfer(src=target_crop, ref=source_crop, method=method)
439
+
440
+ # Apply the matched colors back to the original image, only in the masked region
441
+ result_np = np.copy(target_np)
442
+
443
+ # Create a mapping from crop coordinates to original image coordinates
444
+ for i in range(target_crop.shape[0]):
445
+ for j in range(target_crop.shape[1]):
446
+ orig_i = target_min_y + i
447
+ orig_j = target_min_x + j
448
+ if orig_i < target_np.shape[0] and orig_j < target_np.shape[1] and target_mask_np[orig_i, orig_j] > 0.5:
449
+ result_np[orig_i, orig_j] = matched_crop[i, j]
450
+
451
+ # Convert back to tensor
452
+ result_tensor = torch.from_numpy(result_np).to(result_img.device)
453
+
454
+ # Blend with original based on factor
455
+ result_img = torch.lerp(result_img, result_tensor, self.factor)
456
+
457
+ except Exception as e:
458
+ # Fallback to AdaIN if color matcher fails
459
+ print(f"Color matcher failed for {method}, using fallback: {str(e)}")
460
+ result_img = self._apply_adain_to_region(
461
+ source_img,
462
+ target_img,
463
+ result_img,
464
+ source_mask_binary,
465
+ target_mask_binary
466
+ )
467
+
468
+ elif method == "coral":
469
+ # For CORAL method, extract masked regions and apply CORAL color transfer
470
+ try:
471
+ # Create masked versions of the images
472
+ source_masked = source_np.copy()
473
+ target_masked = target_np.copy()
474
+
475
+ # Apply masks - set non-masked areas to average color
476
+ source_avg_color = np.mean(source_pixels, axis=0)
477
+ target_avg_color = np.mean(target_pixels, axis=0)
478
+
479
+ for c in range(3):
480
+ source_masked[:, :, c] = np.where(source_mask_np > 0.5, source_masked[:, :, c], source_avg_color[c])
481
+ target_masked[:, :, c] = np.where(target_mask_np > 0.5, target_masked[:, :, c], target_avg_color[c])
482
+
483
+ # Convert to torch tensors and rearrange to [C, H, W]
484
+ source_tensor = torch.from_numpy(source_masked).permute(2, 0, 1).float()
485
+ target_tensor = torch.from_numpy(target_masked).permute(2, 0, 1).float()
486
+
487
+ # Apply CORAL color transfer
488
+ matched_tensor = coral(target_tensor, source_tensor) # target gets matched to source
489
+
490
+ # Convert back to [H, W, C] format
491
+ matched_np = matched_tensor.permute(1, 2, 0).numpy()
492
+
493
+ # Apply the matched colors back to the original image, only in the masked region
494
+ result_np = np.copy(target_np)
495
+ for c in range(3):
496
+ result_np[:, :, c] = np.where(target_mask_np > 0.5, matched_np[:, :, c], target_np[:, :, c])
497
+
498
+ # Convert back to tensor
499
+ result_tensor = torch.from_numpy(result_np).to(result_img.device)
500
+
501
+ # Blend with original based on factor
502
+ result_img = torch.lerp(result_img, result_tensor, self.factor)
503
+
504
+ except Exception as e:
505
+ # Fallback to AdaIN if CORAL fails
506
+ print(f"CORAL failed for {method}, using fallback: {str(e)}")
507
+ result_img = self._apply_adain_to_region(
508
+ source_img,
509
+ target_img,
510
+ result_img,
511
+ source_mask_binary,
512
+ target_mask_binary
513
+ )
514
+ else:
515
+ # Default to AdaIN for unsupported methods
516
+ result_img = self._apply_adain_to_region(
517
+ source_img,
518
+ target_img,
519
+ result_img,
520
+ source_mask_binary,
521
+ target_mask_binary
522
+ )
523
+
524
+ except Exception as e:
525
+ # If all fails, fallback to AdaIN
526
+ print(f"Error in color matching: {str(e)}, using AdaIN as fallback")
527
+ result_img = self._apply_adain_to_region(
528
+ source_img,
529
+ target_img,
530
+ result_img,
531
+ source_mask_binary,
532
+ target_mask_binary
533
+ )
534
+
535
+ return torch.clamp(result_img, 0.0, 1.0)
536
+
537
+ def _match_channel_statistics(self, source_channel, target_channel, result_channel, source_mask, target_mask):
538
+ """
539
+ Match the statistics of a single channel.
540
+
541
+ Args:
542
+ source_channel: Source channel [H,W] (reference for color matching)
543
+ target_channel: Target channel [H,W] (to be color matched)
544
+ result_channel: Result channel to modify [H,W]
545
+ source_mask: Binary mask for source [H,W]
546
+ target_mask: Binary mask for target [H,W]
547
+
548
+ Returns:
549
+ Modified result channel
550
+ """
551
+ # Count non-zero elements in masks
552
+ source_count = torch.sum(source_mask)
553
+ target_count = torch.sum(target_mask)
554
+
555
+ if source_count > 0 and target_count > 0:
556
+ # Calculate statistics only from masked regions
557
+ source_masked = source_channel * source_mask
558
+ target_masked = target_channel * target_mask
559
+
560
+ # Calculate mean
561
+ source_mean = torch.sum(source_masked) / source_count
562
+ target_mean = torch.sum(target_masked) / target_count
563
+
564
+ # Calculate variance
565
+ source_var = torch.sum(((source_channel - source_mean) * source_mask) ** 2) / source_count
566
+ target_var = torch.sum(((target_channel - target_mean) * target_mask) ** 2) / target_count
567
+
568
+ # Calculate std (add small epsilon to avoid division by zero)
569
+ source_std = torch.sqrt(source_var + 1e-8)
570
+ target_std = torch.sqrt(target_var + 1e-8)
571
+
572
+ # Apply AdaIN to the masked region
573
+ normalized = ((target_channel - target_mean) / target_std) * source_std + source_mean
574
+
575
+ # Blend with original based on factor
576
+ result = torch.lerp(target_channel, normalized, self.factor)
577
+
578
+ return result
579
+
580
+ return result_channel
581
+
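As a sanity check on the masked AdaIN formula above, a tiny standalone example with arbitrary values:

import torch

src = torch.tensor([0.8, 0.9, 1.0])   # bright reference region
tgt = torch.tensor([0.2, 0.3, 0.4])   # darker region to be matched

s_mean, s_std = src.mean(), src.std(unbiased=False)
t_mean, t_std = tgt.mean(), tgt.std(unbiased=False)

matched = (tgt - t_mean) / (t_std + 1e-8) * s_std + s_mean
print(matched)   # tensor([0.8000, 0.9000, 1.0000]): same spread, shifted to the reference mean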
582
+ def _install_package(self, package_name):
583
+ """Install a package using pip."""
584
+ import subprocess
585
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package_name])
586
+
587
+
588
+ def create_comparison_figure(original_img, matched_img, title="Color Matching Comparison"):
589
+ """
590
+ Create a matplotlib figure with the original and color-matched images.
591
+
592
+ Args:
593
+ original_img: Original PIL Image
594
+ matched_img: Color-matched PIL Image
595
+ title: Title for the figure
596
+
597
+ Returns:
598
+ matplotlib Figure
599
+ """
600
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
601
+
602
+ ax1.imshow(original_img)
603
+ ax1.set_title("Original")
604
+ ax1.axis('off')
605
+
606
+ ax2.imshow(matched_img)
607
+ ax2.set_title("Color Matched")
608
+ ax2.axis('off')
609
+
610
+ plt.suptitle(title)
611
+ plt.tight_layout()
612
+
613
+ return fig
614
+
615
+ def coral(source, target):
616
+ """
617
+ CORAL (Correlation Alignment) color transfer implementation.
618
+ Based on the original ColorMatchImage approach.
619
+
620
+ Args:
621
+ source: Source image tensor [C, H, W] (to be color matched)
622
+ target: Target image tensor [C, H, W] (reference for color matching)
623
+
624
+ Returns:
625
+ Color-matched source image tensor [C, H, W]
626
+ """
627
+ # Ensure tensors are float
628
+ source = source.float()
629
+ target = target.float()
630
+
631
+ # Reshape to [C, N] where N is number of pixels
632
+ C, H, W = source.shape
633
+ source_flat = source.view(C, -1) # [C, H*W]
634
+ target_flat = target.view(C, -1) # [C, H*W]
635
+
636
+ # Compute means
637
+ source_mean = torch.mean(source_flat, dim=1, keepdim=True) # [C, 1]
638
+ target_mean = torch.mean(target_flat, dim=1, keepdim=True) # [C, 1]
639
+
640
+ # Center the data
641
+ source_centered = source_flat - source_mean # [C, H*W]
642
+ target_centered = target_flat - target_mean # [C, H*W]
643
+
644
+ # Compute covariance matrices
645
+ N = source_centered.shape[1]
646
+ source_cov = torch.mm(source_centered, source_centered.t()) / (N - 1) # [C, C]
647
+ target_cov = torch.mm(target_centered, target_centered.t()) / (N - 1) # [C, C]
648
+
649
+ # Add small epsilon to diagonal for numerical stability
650
+ eps = 1e-5
651
+ source_cov += eps * torch.eye(C, device=source.device)
652
+ target_cov += eps * torch.eye(C, device=source.device)
653
+
654
+ try:
655
+ # Compute the transformation matrix using Cholesky decomposition
656
+ # This is more stable than eigendecomposition for positive definite matrices
657
+
658
+ # Cholesky decomposition: A = L * L^T
659
+ source_chol = torch.linalg.cholesky(source_cov) # Lower triangular
660
+ target_chol = torch.linalg.cholesky(target_cov) # Lower triangular
661
+
662
+ # Compute the transformation matrix
663
+ # We want to transform source covariance to target covariance
664
+ # Transform = target_chol * source_chol^(-1)
665
+ source_chol_inv = torch.linalg.inv(source_chol)
666
+ transform_matrix = torch.mm(target_chol, source_chol_inv)
667
+
668
+ # Apply transformation: result = transform_matrix * (source - source_mean) + target_mean
669
+ result_centered = torch.mm(transform_matrix, source_centered)
670
+ result_flat = result_centered + target_mean
671
+
672
+ # Reshape back to original shape
673
+ result = result_flat.view(C, H, W)
674
+
675
+ # Clamp to valid range
676
+ result = torch.clamp(result, 0.0, 1.0)
677
+
678
+ return result
679
+
680
+ except Exception as e:
681
+ # Fallback to simple mean/std matching if Cholesky fails
682
+ print(f"CORAL Cholesky failed, using simple statistics matching: {e}")
683
+
684
+ # Simple per-channel statistics matching
685
+ source_std = torch.std(source_centered, dim=1, keepdim=True)
686
+ target_std = torch.std(target_centered, dim=1, keepdim=True)
687
+
688
+ # Avoid division by zero
689
+ source_std = torch.clamp(source_std, min=eps)
690
+
691
+ # Apply simple transformation: (source - source_mean) / source_std * target_std + target_mean
692
+ result_flat = (source_centered / source_std) * target_std + target_mean
693
+ result = result_flat.view(C, H, W)
694
+
695
+ # Clamp to valid range
696
+ result = torch.clamp(result, 0.0, 1.0)
697
+
698
+ return result
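A hypothetical driver for RegionColorMatcher, assuming this module's imports (Image, np) are in scope; real mask dictionaries would come from a segmentation step, but trivial full-image masks are enough to show the call shape:

ref_path, tgt_path = "examples/beach.jpg", "examples/field.jpg"

h1, w1 = np.array(Image.open(ref_path)).shape[:2]
h2, w2 = np.array(Image.open(tgt_path)).shape[:2]

masks1 = {"l2_sky": np.ones((h1, w1), dtype=np.float32)}   # placeholder: whole image
masks2 = {"l2_sky": np.ones((h2, w2), dtype=np.float32)}

matcher = RegionColorMatcher(factor=0.8, method="adain")
result = matcher.match_regions(ref_path, tgt_path, masks1, masks2)
result.save("matched.png")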
core.py ADDED
@@ -0,0 +1,356 @@
1
+ """Core part of LaDeco v2
2
+
3
+ Example usage:
4
+ >>> from core import Ladeco
5
+ >>> from PIL import Image
6
+ >>> from pathlib import Path
7
+ >>>
8
+ >>> # predict
9
+ >>> ldc = Ladeco()
10
+ >>> imgs = (thing for thing in Path("example").glob("*.jpg"))
11
+ >>> out = ldc.predict(imgs)
12
+ >>>
13
+ >>> # output - visualization
14
+ >>> segs = out.visualize(level=2)
15
+ >>> segs[0].image.show()
16
+ >>>
17
+ >>> # output - element area
18
+ >>> area = out.area()
19
+ >>> area[0]
20
+ {"fid": "example/.jpg", "l1_nature": 0.673, "l1_man_made": 0.241, ...}
21
+ """
22
+ from matplotlib.patches import Rectangle
23
+ from pathlib import Path
24
+ from PIL import Image
25
+ from transformers import AutoModelForUniversalSegmentation, AutoProcessor
26
+ import math
27
+ import matplotlib as mpl
28
+ import matplotlib.pyplot as plt
29
+ import numpy as np
30
+ import torch
31
+ from functools import lru_cache
32
+ from matplotlib.figure import Figure
33
+ import numpy.typing as npt
34
+ from typing import Iterable, NamedTuple, Generator
35
+ from tqdm import tqdm
36
+
37
+
38
+ class LadecoVisualization(NamedTuple):
39
+ filename: str
40
+ image: Figure
41
+
42
+
43
+ class Ladeco:
44
+
45
+ def __init__(self,
46
+ model_name: str = "shi-labs/oneformer_ade20k_swin_large",
47
+ area_threshold: float = 0.01,
48
+ device: str | None = None,
49
+ ):
50
+ if device is None:
51
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
52
+ else:
53
+ self.device = device
54
+
55
+ self.processor = AutoProcessor.from_pretrained(model_name)
56
+ self.model = AutoModelForUniversalSegmentation.from_pretrained(model_name).to(self.device)
57
+
58
+ self.area_threshold = area_threshold
59
+
60
+ self.ade20k_labels = {
61
+ name.strip(): int(idx)
62
+ for name, idx in self.model.config.label2id.items()
63
+ }
64
+ self.ladeco2ade20k: dict[str, tuple[int]] = _get_ladeco_labels(self.ade20k_labels)
65
+
66
+ def predict(
67
+ self, image_paths: str | Path | Iterable[str | Path], show_progress: bool = False
68
+ ) -> "LadecoOutput":
69
+ if isinstance(image_paths, (str, Path)):
70
+ imgpaths = [image_paths]
71
+ else:
72
+ imgpaths = list(image_paths)
73
+
74
+ images = (
75
+ Image.open(img_path).convert("RGB")
76
+ for img_path in imgpaths
77
+ )
78
+
79
+ # batch inference functionality of OneFormer is broken
80
+ masks: list[torch.Tensor] = []
81
+ for img in tqdm(images, total=len(imgpaths), desc="Segmenting", disable=not show_progress):
82
+ samples = self.processor(
83
+ images=img, task_inputs=["semantic"], return_tensors="pt"
84
+ ).to(self.device)
85
+
86
+ with torch.no_grad():
87
+ outputs = self.model(**samples)
88
+
89
+ masks.append(
90
+ self.processor.post_process_semantic_segmentation(outputs)[0]
91
+ )
92
+
93
+ return LadecoOutput(imgpaths, masks, self.ladeco2ade20k, self.area_threshold)
94
+
95
+
96
+ class LadecoOutput:
97
+
98
+ def __init__(
99
+ self,
100
+ filenames: list[str | Path],
101
+ masks: torch.Tensor,
102
+ ladeco2ade: dict[str, tuple[int]],
103
+ threshold: float,
104
+ ):
105
+ self.filenames = filenames
106
+ self.masks = masks
107
+ self.ladeco2ade: dict[str, tuple[int]] = ladeco2ade
108
+ self.ade2ladeco: dict[int, str] = {
109
+ idx: label
110
+ for label, indices in self.ladeco2ade.items()
111
+ for idx in indices
112
+ }
113
+ self.threshold = threshold
114
+
115
+ def visualize(self, level: int) -> list[LadecoVisualization]:
116
+ return list(self.ivisualize(level))
117
+
118
+ def ivisualize(self, level: int) -> Generator[LadecoVisualization, None, None]:
119
+ colormaps = self.color_map(level)
120
+ labelnames = [name for name in self.ladeco2ade if name.startswith(f"l{level}")]
121
+
122
+ for fname, mask in zip(self.filenames, self.masks):
123
+ size = mask.shape + (3,) # (H, W, RGB)
124
+ vis = torch.zeros(size, dtype=torch.uint8)
125
+ for name in labelnames:
126
+ for idx in self.ladeco2ade[name]:
127
+ color = torch.tensor(colormaps[name] * 255, dtype=torch.uint8)
128
+ vis[mask == idx] = color
129
+
130
+ with Image.open(fname) as img:
131
+ target_size = img.size
132
+ vis = Image.fromarray(vis.numpy(), mode="RGB").resize(target_size)
133
+
134
+ fig, ax = plt.subplots()
135
+ ax.imshow(vis)
136
+ ax.axis('off')
137
+
138
+ yield LadecoVisualization(filename=str(fname), image=fig)
139
+
140
+ def area(self) -> list[dict[str, float | str]]:
141
+ return list(self.iarea())
142
+
143
+ def iarea(self) -> Generator[dict[str, float | str], None, None]:
144
+ n_label_ADE20k = 150
145
+ for filename, mask in zip(self.filenames, self.masks):
146
+ ade_ratios = torch.tensor([(mask == i).count_nonzero() / mask.numel() for i in range(n_label_ADE20k)])
147
+ #breakpoint()
148
+ ldc_ratios: dict[str, float] = {
149
+ label: round(ade_ratios[list(ade_indices)].sum().item(), 4)
150
+ for label, ade_indices in self.ladeco2ade.items()
151
+ }
152
+ ldc_ratios: dict[str, float] = {
153
+ label: 0 if ratio < self.threshold else ratio
154
+ for label, ratio in ldc_ratios.items()
155
+ }
156
+ others = round(1 - ldc_ratios["l1_nature"] - ldc_ratios["l1_man_made"], 4)
157
+ nfi = round(ldc_ratios["l1_nature"]/ (ldc_ratios["l1_nature"] + ldc_ratios.get("l1_man_made", 0) + 1e-6), 4)
158
+
159
+ yield {
160
+ "fid": str(filename), **ldc_ratios, "others": others, "LC_NFI": nfi,
161
+ }
162
+
163
+ def color_map(self, level: int) -> dict[str, npt.NDArray[np.float64]]:
164
+ "returns {'label_name': (R, G, B), ...}, where (R, G, B) in range [0, 1]"
165
+ labels = [
166
+ name for name in self.ladeco2ade.keys() if name.startswith(f"l{level}")
167
+ ]
168
+ if len(labels) == 0:
169
+ raise RuntimeError(
170
+ f"LaDeco only has 4 levels in 1, 2, 3, 4. You assigned {level}."
171
+ )
172
+ colormap = mpl.colormaps["viridis"].resampled(len(labels)).colors[:, :-1]
173
+ # [:, :-1]: discard alpha channel
174
+ return {name: color for name, color in zip(labels, colormap)}
175
+
176
+ def color_legend(self, level: int) -> Figure:
177
+ colors = self.color_map(level)
178
+
179
+ match level:
180
+ case 1:
181
+ ncols = 1
182
+ case 2:
183
+ ncols = 1
184
+ case 3:
185
+ ncols = 2
186
+ case 4:
187
+ ncols = 5
188
+
189
+ cell_width = 212
190
+ cell_height = 22
191
+ swatch_width = 48
192
+ margin = 12
193
+
194
+ nrows = math.ceil(len(colors) / ncols)
195
+
196
+ width = cell_width * ncols + 2 * margin
197
+ height = cell_height * nrows + 2 * margin
198
+ dpi = 72
199
+
200
+ fig, ax = plt.subplots(figsize=(width / dpi, height / dpi), dpi=dpi)
201
+ fig.subplots_adjust(margin/width, margin/height,
202
+ (width-margin)/width, (height-margin*2)/height)
203
+ ax.set_xlim(0, cell_width * ncols)
204
+ ax.set_ylim(cell_height * (nrows-0.5), -cell_height/2.)
205
+ ax.yaxis.set_visible(False)
206
+ ax.xaxis.set_visible(False)
207
+ ax.set_axis_off()
208
+
209
+ for i, name in enumerate(colors):
210
+ row = i % nrows
211
+ col = i // nrows
212
+ y = row * cell_height
213
+
214
+ swatch_start_x = cell_width * col
215
+ text_pos_x = cell_width * col + swatch_width + 7
216
+
217
+ ax.text(text_pos_x, y, name, fontsize=14,
218
+ horizontalalignment='left',
219
+ verticalalignment='center')
220
+
221
+ ax.add_patch(
222
+ Rectangle(xy=(swatch_start_x, y-9), width=swatch_width,
223
+ height=18, facecolor=colors[name], edgecolor='0.7')
224
+ )
225
+
226
+ ax.set_title(f"LaDeco Color Legend - Level {level}")
227
+
228
+ return fig
229
+
230
+
231
+ def _get_ladeco_labels(ade20k: dict[str, int]) -> dict[str, tuple[int]]:
232
+ labels = {
233
+ # level 4 labels
234
+ # under l3_architecture
235
+ "l4_hovel": (ade20k["hovel, hut, hutch, shack, shanty"],),
236
+ "l4_building": (ade20k["building"], ade20k["house"]),
237
+ "l4_skyscraper": (ade20k["skyscraper"],),
238
+ "l4_tower": (ade20k["tower"],),
239
+ # under l3_archi_parts
240
+ "l4_step": (ade20k["step, stair"],),
241
+ "l4_canopy": (ade20k["awning, sunshade, sunblind"], ade20k["canopy"]),
242
+ "l4_arcade": (ade20k["arcade machine"],),
243
+ "l4_door": (ade20k["door"],),
244
+ "l4_window": (ade20k["window"],),
245
+ "l4_wall": (ade20k["wall"],),
246
+ # under l3_roadway
247
+ "l4_stairway": (ade20k["stairway, staircase"],),
248
+ "l4_sidewalk": (ade20k["sidewalk, pavement"],),
249
+ "l4_road": (ade20k["road, route"],),
250
+ # under l3_furniture
251
+ "l4_sculpture": (ade20k["sculpture"],),
252
+ "l4_flag": (ade20k["flag"],),
253
+ "l4_can": (ade20k["trash can"],),
254
+ "l4_chair": (ade20k["chair"],),
255
+ "l4_pot": (ade20k["pot"],),
256
+ "l4_booth": (ade20k["booth"],),
257
+ "l4_streetlight": (ade20k["street lamp"],),
258
+ "l4_bench": (ade20k["bench"],),
259
+ "l4_fence": (ade20k["fence"],),
260
+ "l4_table": (ade20k["table"],),
261
+ # under l3_vehicle
262
+ "l4_bike": (ade20k["bicycle"],),
263
+ "l4_motorbike": (ade20k["minibike, motorbike"],),
264
+ "l4_van": (ade20k["van"],),
265
+ "l4_truck": (ade20k["truck"],),
266
+ "l4_bus": (ade20k["bus"],),
267
+ "l4_car": (ade20k["car"],),
268
+ # under l3_sign
269
+ "l4_traffic_sign": (ade20k["traffic light"],),
270
+ "l4_poster": (ade20k["poster, posting, placard, notice, bill, card"],),
271
+ "l4_signboard": (ade20k["signboard, sign"],),
272
+ # under l3_vert_land
273
+ "l4_rock": (ade20k["rock, stone"],),
274
+ "l4_hill": (ade20k["hill"],),
275
+ "l4_mountain": (ade20k["mountain, mount"],),
276
+ # under l3_hori_land
277
+ "l4_ground": (ade20k["earth, ground"], ade20k["land, ground, soil"]),
278
+ "l4_field": (ade20k["field"],),
279
+ "l4_sand": (ade20k["sand"],),
280
+ "l4_dirt": (ade20k["dirt track"],),
281
+ "l4_path": (ade20k["path"],),
282
+ # under l3_flower
283
+ "l4_flower": (ade20k["flower"],),
284
+ # under l3_grass
285
+ "l4_grass": (ade20k["grass"],),
286
+ # under l3_shrub
287
+ "l4_flora": (ade20k["plant"],),
288
+ # under l3_arbor
289
+ "l4_tree": (ade20k["tree"],),
290
+ "l4_palm": (ade20k["palm, palm tree"],),
291
+ # under l3_hori_water
292
+ "l4_lake": (ade20k["lake"],),
293
+ "l4_pool": (ade20k["pool"],),
294
+ "l4_river": (ade20k["river"],),
295
+ "l4_sea": (ade20k["sea"],),
296
+ "l4_water": (ade20k["water"],),
297
+ # under l3_vert_water
298
+ "l4_fountain": (ade20k["fountain"],),
299
+ "l4_waterfall": (ade20k["falls"],),
300
+ # under l3_human
301
+ "l4_person": (ade20k["person"],),
302
+ # under l3_animal
303
+ "l4_animal": (ade20k["animal"],),
304
+ # under l3_sky
305
+ "l4_sky": (ade20k["sky"],),
306
+ }
307
+ labels = labels | {
308
+ # level 3 labels
309
+ # under l2_landform
310
+ "l3_hori_land": labels["l4_ground"] + labels["l4_field"] + labels["l4_sand"] + labels["l4_dirt"] + labels["l4_path"],
311
+ "l3_vert_land": labels["l4_mountain"] + labels["l4_hill"] + labels["l4_rock"],
312
+ # under l2_vegetation
313
+ "l3_woody_plant": labels["l4_tree"] + labels["l4_palm"] + labels["l4_flora"],
314
+ "l3_herb_plant": labels["l4_grass"],
315
+ "l3_flower": labels["l4_flower"],
316
+ # under l2_water
317
+ "l3_hori_water": labels["l4_water"] + labels["l4_sea"] + labels["l4_river"] + labels["l4_pool"] + labels["l4_lake"],
318
+ "l3_vert_water": labels["l4_fountain"] + labels["l4_waterfall"],
319
+ # under l2_bio
320
+ "l3_human": labels["l4_person"],
321
+ "l3_animal": labels["l4_animal"],
322
+ # under l2_sky
323
+ "l3_sky": labels["l4_sky"],
324
+ # under l2_archi
325
+ "l3_architecture": labels["l4_building"] + labels["l4_hovel"] + labels["l4_tower"] + labels["l4_skyscraper"],
326
+ "l3_archi_parts": labels["l4_wall"] + labels["l4_window"] + labels["l4_door"] + labels["l4_arcade"] + labels["l4_canopy"] + labels["l4_step"],
327
+ # under l2_street
328
+ "l3_roadway": labels["l4_road"] + labels["l4_sidewalk"] + labels["l4_stairway"],
329
+ "l3_furniture": labels["l4_table"] + labels["l4_chair"] + labels["l4_fence"] + labels["l4_bench"] + labels["l4_streetlight"] + labels["l4_booth"] + labels["l4_pot"] + labels["l4_can"] + labels["l4_flag"] + labels["l4_sculpture"],
330
+ "l3_vehicle": labels["l4_car"] + labels["l4_bus"] + labels["l4_truck"] + labels["l4_van"] + labels["l4_motorbike"] + labels["l4_bike"],
331
+ "l3_sign": labels["l4_signboard"] + labels["l4_poster"] + labels["l4_traffic_sign"],
332
+ }
333
+ labels = labels | {
334
+ # level 2 labels
335
+ # under l1_nature
336
+ "l2_landform": labels["l3_hori_land"] + labels["l3_vert_land"],
337
+ "l2_vegetation": labels["l3_woody_plant"] + labels["l3_herb_plant"] + labels["l3_flower"],
338
+ "l2_water": labels["l3_hori_water"] + labels["l3_vert_water"],
339
+ "l2_bio": labels["l3_human"] + labels["l3_animal"],
340
+ "l2_sky": labels["l3_sky"],
341
+ # under l1_man_made
342
+ "l2_archi": labels["l3_architecture"] + labels["l3_archi_parts"],
343
+ "l2_street": labels["l3_roadway"] + labels["l3_furniture"] + labels["l3_vehicle"] + labels["l3_sign"],
344
+ }
345
+ labels = labels | {
346
+ # level 1 labels
347
+ "l1_nature": labels["l2_landform"] + labels["l2_vegetation"] + labels["l2_water"] + labels["l2_bio"] + labels["l2_sky"],
348
+ "l1_man_made": labels["l2_archi"] + labels["l2_street"],
349
+ }
350
+ return labels
351
+
352
+
353
+ if __name__ == "__main__":
354
+ ldc = Ladeco()
355
+ image = Path("images") / "canyon_3011_00002354.jpg"
356
+ out = ldc.predict(image)
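If the __main__ demo above were extended, the output helpers defined earlier could be exercised like this (a sketch; output file names are illustrative):

print(out.area()[0])   # e.g. {"fid": "...", "l1_nature": ..., "l1_man_made": ..., "LC_NFI": ...}

vis = out.visualize(level=2)[0]
vis.image.savefig("canyon_level2.png", bbox_inches="tight")
out.color_legend(level=2).savefig("legend_level2.png", bbox_inches="tight")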
examples/beach.jpg ADDED
examples/field.jpg ADDED
examples/sky.jpg ADDED
face_comparison.py ADDED
@@ -0,0 +1,246 @@
1
+ import cv2
2
+ import numpy as np
3
+ from PIL import Image
4
+ import tempfile
5
+ import os
6
+ import subprocess
7
+ import sys
8
+ import json
9
+ from typing import Dict, List, Tuple, Optional
10
+ import logging
11
+
12
+ # Set up logging to suppress DeepFace warnings
13
+ logging.getLogger('deepface').setLevel(logging.ERROR)
14
+
15
+ try:
16
+ from deepface import DeepFace
17
+ DEEPFACE_AVAILABLE = True
18
+ except ImportError:
19
+ DEEPFACE_AVAILABLE = False
20
+ print("Warning: DeepFace not available. Face comparison will be disabled.")
21
+
22
+
23
+ def run_deepface_in_subprocess(img1_path: str, img2_path: str) -> dict:
24
+ """
25
+ Run DeepFace verification in a separate process to avoid TensorFlow conflicts.
26
+ """
27
+ script_content = f'''
28
+ import sys
29
+ import json
30
+ from deepface import DeepFace
31
+
32
+ try:
33
+ result = DeepFace.verify(img1_path="{img1_path}", img2_path="{img2_path}")
34
+ print(json.dumps(result))
35
+ except Exception as e:
36
+ print(json.dumps({{"error": str(e)}}))
37
+ '''
38
+
39
+ try:
40
+ # Write the script to a temporary file
41
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as script_file:
42
+ script_file.write(script_content)
43
+ script_path = script_file.name
44
+
45
+ # Run the script in a subprocess
46
+ result = subprocess.run([sys.executable, script_path],
47
+ capture_output=True, text=True, timeout=30)
48
+
49
+ # Clean up the script file
50
+ os.unlink(script_path)
51
+
52
+ if result.returncode == 0:
53
+ return json.loads(result.stdout.strip())
54
+ else:
55
+ return {"error": f"Subprocess failed: {result.stderr}"}
56
+
57
+ except Exception as e:
58
+ return {"error": str(e)}
59
+
60
+
61
+ class FaceComparison:
62
+ """
63
+ Handles face detection and comparison on full images.
64
+ Only responsible for determining if faces match - does not handle segmentation.
65
+ """
66
+
67
+ def __init__(self):
68
+ """
69
+ Initialize face comparison using DeepFace's default verification threshold.
70
+ """
71
+ self.available = DEEPFACE_AVAILABLE
72
+ self.face_match_result = None
73
+ self.comparison_log = []
74
+
75
+ def extract_faces(self, image_path: str) -> List[np.ndarray]:
76
+ """
77
+ Extract faces from the full image using DeepFace (exactly like the working script).
78
+
79
+ Args:
80
+ image_path: Path to the image
81
+
82
+ Returns:
83
+ List of face arrays
84
+ """
85
+ if not self.available:
86
+ return []
87
+
88
+ try:
89
+ faces = DeepFace.extract_faces(img_path=image_path, detector_backend='opencv')
90
+ if len(faces) == 0:
91
+ return []
92
+ return [f['face'] for f in faces]
93
+
94
+ except Exception as e:
95
+ print(f"Error extracting faces from {image_path}: {str(e)}")
96
+ return []
97
+
98
+ def compare_all_faces(self, image1_path: str, image2_path: str) -> Tuple[bool, List[str]]:
99
+ """
100
+ Compare all faces between two images (exactly like the working script).
101
+
102
+ Args:
103
+ image1_path: Path to first image
104
+ image2_path: Path to second image
105
+
106
+ Returns:
107
+ Tuple of (match_found, log_messages)
108
+ """
109
+ if not self.available:
110
+ return False, ["Face comparison not available - DeepFace not installed"]
111
+
112
+ log_messages = []
113
+
114
+ try:
115
+ faces1 = self.extract_faces(image1_path)
116
+ faces2 = self.extract_faces(image2_path)
117
+
118
+ match_found = False
119
+
120
+ log_messages.append(f"Found {len(faces1)} face(s) in Image 1 and {len(faces2)} face(s) in Image 2")
121
+
122
+ if len(faces1) == 0 or len(faces2) == 0:
123
+ log_messages.append("❌ No faces found in one or both images")
124
+ return False, log_messages
125
+
126
+ for idx1, face1 in enumerate(faces1):
127
+ for idx2, face2 in enumerate(faces2):
128
+ # Create temporary files instead of permanent ones (exactly like original)
129
+ with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as temp1, \
130
+ tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as temp2:
131
+
132
+ # Convert faces to uint8 and save temporarily (exactly like original)
133
+ face1_uint8 = (face1 * 255).astype(np.uint8)
134
+ face2_uint8 = (face2 * 255).astype(np.uint8)
135
+
136
+ cv2.imwrite(temp1.name, cv2.cvtColor(face1_uint8, cv2.COLOR_RGB2BGR))
137
+ cv2.imwrite(temp2.name, cv2.cvtColor(face2_uint8, cv2.COLOR_RGB2BGR))
138
+
139
+ try:
140
+ # Try subprocess approach first to avoid TensorFlow conflicts
141
+ result = run_deepface_in_subprocess(temp1.name, temp2.name)
142
+
143
+ if "error" in result:
144
+ # If subprocess fails, try direct approach
145
+ result = DeepFace.verify(img1_path=temp1.name, img2_path=temp2.name)
146
+
147
+ similarity = 1 - result['distance']
148
+
149
+ log_messages.append(f"Comparing Face1-{idx1} to Face2-{idx2} | Similarity: {similarity:.3f}")
150
+
151
+ if result['verified']:
152
+ log_messages.append(f"✅ Match found between Face1-{idx1} and Face2-{idx2}")
153
+ match_found = True
154
+ else:
155
+ log_messages.append(f"❌ No match between Face1-{idx1} and Face2-{idx2}")
156
+
157
+ except Exception as e:
158
+ log_messages.append(f"❌ Error comparing Face1-{idx1} to Face2-{idx2}: {str(e)}")
159
+
160
+ # Clean up temporary files immediately
161
+ try:
162
+ os.unlink(temp1.name)
163
+ os.unlink(temp2.name)
164
+ except OSError:
165
+ pass
166
+
167
+ if not match_found:
168
+ log_messages.append("❌ No matching faces found between the two images.")
169
+
170
+ return match_found, log_messages
171
+
172
+ except Exception as e:
173
+ log_messages.append(f"Error in face comparison: {str(e)}")
174
+ return False, log_messages
175
+
176
+ def run_face_comparison(self, img1_path: str, img2_path: str) -> Tuple[bool, List[str]]:
177
+ """
178
+ Run face comparison and store results for later use.
179
+
180
+ Args:
181
+ img1_path: Path to first image
182
+ img2_path: Path to second image
183
+
184
+ Returns:
185
+ Tuple of (faces_match, log_messages)
186
+ """
187
+ faces_match, log_messages = self.compare_all_faces(img1_path, img2_path)
188
+
189
+ # Store results for later filtering
190
+ self.face_match_result = faces_match
191
+ self.comparison_log = log_messages
192
+
193
+ return faces_match, log_messages
194
+
195
+ def filter_human_regions_by_face_match(self, masks: Dict[str, np.ndarray]) -> Tuple[Dict[str, np.ndarray], List[str]]:
196
+ """
197
+ Filter human regions based on previously computed face comparison results.
198
+ This only includes/excludes human regions - fine-grained segmentation happens elsewhere.
199
+
200
+ Args:
201
+ masks: Dictionary of semantic masks
202
+
203
+ Returns:
204
+ Tuple of (filtered_masks, log_messages)
205
+ """
206
+ if not self.available:
207
+ return masks, ["Face comparison not available - DeepFace not installed"]
208
+
209
+ if self.face_match_result is None:
210
+ return masks, ["No face comparison results available. Run face comparison first."]
211
+
212
+ filtered_masks = {}
213
+ log_messages = []
214
+
215
+ # Look for human-specific regions (l3_human, not l2_bio which includes animals)
216
+ human_labels = [label for label in masks.keys() if 'l3_human' in label.lower()]
217
+ bio_labels = [label for label in masks.keys() if 'l2_bio' in label.lower()]
218
+
219
+ log_messages.append(f"Found human labels: {human_labels}")
220
+ log_messages.append(f"Found bio labels: {bio_labels}")
221
+
222
+ # Include all non-human regions regardless of face matching
223
+ for label, mask in masks.items():
224
+ if not any(human_term in label.lower() for human_term in ['l3_human', 'l2_bio']):
225
+ filtered_masks[label] = mask
226
+ log_messages.append(f"✅ Including non-human region: {label}")
227
+ else:
228
+ log_messages.append(f"🔍 Found human/bio region: {label}")
229
+
230
+ # Handle human regions based on face matching results
231
+ if self.face_match_result:
232
+ log_messages.append("✅ Faces matched! Including human regions in color matching.")
233
+ # Include human regions since faces matched
234
+ for label in human_labels + bio_labels:
235
+ if label in masks:
236
+ filtered_masks[label] = masks[label]
237
+ log_messages.append(f"✅ Including human region (faces matched): {label}")
238
+ else:
239
+ log_messages.append("❌ No face match found. Excluding human regions from color matching.")
240
+ # Don't include human regions since faces didn't match
241
+ for label in human_labels + bio_labels:
242
+ log_messages.append(f"❌ Excluding human region (no face match): {label}")
243
+
244
+ log_messages.append(f"📊 Final filtered masks: {list(filtered_masks.keys())}")
245
+
246
+ return filtered_masks, log_messages
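A hedged usage sketch of the `FaceComparison` class defined above: run the comparison once, then reuse the stored result to filter semantic masks. The image paths and the `masks` dictionary are placeholders.

```python
import numpy as np
from face_comparison import FaceComparison

fc = FaceComparison()

# 1) Compare faces between the two inputs (placeholder paths).
faces_match, log = fc.run_face_comparison("examples/img1.jpg", "examples/img2.jpg")
print("\n".join(log))

# 2) Filter semantic masks: human/bio regions are kept only when faces matched.
#    A dummy masks dict stands in for the real segmentation output.
masks = {
    "l2_sky": np.ones((4, 4), dtype=np.uint8),
    "l3_human": np.ones((4, 4), dtype=np.uint8),
}
filtered, filter_log = fc.filter_human_regions_by_face_match(masks)
print(sorted(filtered.keys()))
```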
folder_paths.py ADDED
@@ -0,0 +1,22 @@
+ import os
+
+ # Simple folder_paths module to replace ComfyUI's folder_paths
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ models_dir = os.path.join(current_dir, "models")
+
+ # Model folder mappings
+ model_folder_paths = {}
+
+ def add_model_folder_path(name, path):
+ """Add a model folder path."""
+ model_folder_paths[name] = path
+ os.makedirs(path, exist_ok=True)
+
+ def get_full_path(dirname, filename):
+ """Get the full path for a model file."""
+ if dirname in model_folder_paths:
+ return os.path.join(model_folder_paths[dirname], filename)
+ return os.path.join(models_dir, dirname, filename)
+
+ # Initialize default paths
+ os.makedirs(models_dir, exist_ok=True)
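A short usage sketch of this shim: register a model folder, then resolve a file inside it. The folder name and filename are illustrative.

```python
import folder_paths

# Register a custom model folder (name and path are illustrative).
folder_paths.add_model_folder_path("RMBG", "models/RMBG")

# Resolves inside the registered folder; unregistered names fall back
# to models/<dirname>/<filename>.
print(folder_paths.get_full_path("RMBG", "segformer_clothes"))
print(folder_paths.get_full_path("onnx", "human-parts"))
```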
human_parts_segmentation.py ADDED
@@ -0,0 +1,322 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image, ImageFilter
5
+ import cv2
6
+ import requests
7
+ from typing import Dict, List, Tuple, Optional
8
+ import onnxruntime as ort
9
+
10
+ # Human parts labels based on CCIHP dataset - consistent with latest repo
11
+ HUMAN_PARTS_LABELS = {
12
+ 0: ("background", "Background"),
13
+ 1: ("hat", "Hat: Hat, helmet, cap, hood, veil, headscarf, part covering the skull and hair of a hood/balaclava, crown…"),
14
+ 2: ("hair", "Hair"),
15
+ 3: ("glove", "Glove"),
16
+ 4: ("glasses", "Sunglasses/Glasses: Sunglasses, eyewear, protective glasses…"),
17
+ 5: ("upper_clothes", "UpperClothes: T-shirt, shirt, tank top, sweater under a coat, top of a dress…"),
18
+ 6: ("face_mask", "Face Mask: Protective mask, surgical mask, carnival mask, facial part of a balaclava, visor of a helmet…"),
19
+ 7: ("coat", "Coat: Coat, jacket worn without anything on it, vest with nothing on it, a sweater with nothing on it…"),
20
+ 8: ("socks", "Socks"),
21
+ 9: ("pants", "Pants: Pants, shorts, tights, leggings, swimsuit bottoms… (clothing with 2 legs)"),
22
+ 10: ("torso-skin", "Torso-skin"),
23
+ 11: ("scarf", "Scarf: Scarf, bow tie, tie…"),
24
+ 12: ("skirt", "Skirt: Skirt, kilt, bottom of a dress…"),
25
+ 13: ("face", "Face"),
26
+ 14: ("left-arm", "Left-arm (naked part)"),
27
+ 15: ("right-arm", "Right-arm (naked part)"),
28
+ 16: ("left-leg", "Left-leg (naked part)"),
29
+ 17: ("right-leg", "Right-leg (naked part)"),
30
+ 18: ("left-shoe", "Left-shoe"),
31
+ 19: ("right-shoe", "Right-shoe"),
32
+ 20: ("bag", "Bag: Backpack, shoulder bag, fanny pack… (bag carried on oneself"),
33
+ 21: ("", "Others: Jewelry, tags, bibs, belts, ribbons, pins, head decorations, headphones…"),
34
+ }
35
+
36
+ # Model configuration - updated paths consistent with new repos
37
+ current_dir = os.path.dirname(os.path.abspath(__file__))
38
+ models_dir = os.path.join(current_dir, "models")
39
+ models_dir_path = os.path.join(models_dir, "onnx", "human-parts")
40
+ model_url = "https://huggingface.co/Metal3d/deeplabv3p-resnet50-human/resolve/main/deeplabv3p-resnet50-human.onnx"
41
+ model_name = "deeplabv3p-resnet50-human.onnx"
42
+ model_path = os.path.join(models_dir_path, model_name)
43
+
44
+
45
+ def get_class_index(class_name: str) -> int:
46
+ """Return the index of the class name in the model."""
47
+ if class_name == "":
48
+ return -1
49
+
50
+ for key, value in HUMAN_PARTS_LABELS.items():
51
+ if value[0] == class_name:
52
+ return key
53
+ return -1
54
+
55
+
56
+ def download_model(model_url: str, model_path: str) -> bool:
57
+ """Download the human parts segmentation model if not present - improved version."""
58
+ if os.path.exists(model_path):
59
+ return True
60
+
61
+ try:
62
+ os.makedirs(os.path.dirname(model_path), exist_ok=True)
63
+ print(f"Downloading human parts model to {model_path}...")
64
+
65
+ response = requests.get(model_url, stream=True)
66
+ response.raise_for_status()
67
+
68
+ total_size = int(response.headers.get('content-length', 0))
69
+ downloaded = 0
70
+
71
+ with open(model_path, 'wb') as f:
72
+ for chunk in response.iter_content(chunk_size=8192):
73
+ f.write(chunk)
74
+ downloaded += len(chunk)
75
+ if total_size > 0:
76
+ percent = (downloaded / total_size) * 100
77
+ print(f"\rDownload progress: {percent:.1f}%", end='', flush=True)
78
+
79
+ print("\n✅ Model download completed")
80
+ return True
81
+
82
+ except Exception as e:
83
+ print(f"\n❌ Error downloading model: {e}")
84
+ return False
85
+
86
+
87
+ def get_human_parts_mask(image: torch.Tensor, model: ort.InferenceSession, rotation: float = 0, **kwargs) -> Tuple[torch.Tensor, int]:
88
+ """
89
+ Generate human parts mask using the ONNX model - improved version.
90
+
91
+ Args:
92
+ image: Input image tensor
93
+ model: ONNX inference session
94
+ rotation: Rotation angle (not used currently)
95
+ **kwargs: Part-specific enable flags
96
+
97
+ Returns:
98
+ Tuple of (mask_tensor, score)
99
+ """
100
+ image = image.squeeze(0)
101
+ image_np = image.numpy() * 255
102
+
103
+ pil_image = Image.fromarray(image_np.astype(np.uint8))
104
+ original_size = pil_image.size
105
+
106
+ # Resize to 512x512 as the model expects
107
+ pil_image = pil_image.resize((512, 512))
108
+ center = (256, 256)
109
+
110
+ if rotation != 0:
111
+ pil_image = pil_image.rotate(rotation, center=center)
112
+
113
+ # Normalize the image
114
+ image_np = np.array(pil_image).astype(np.float32) / 127.5 - 1
115
+ image_np = np.expand_dims(image_np, axis=0)
116
+
117
+ # Use the ONNX model to get the segmentation
118
+ input_name = model.get_inputs()[0].name
119
+ output_name = model.get_outputs()[0].name
120
+ result = model.run([output_name], {input_name: image_np})
121
+ result = np.array(result[0]).argmax(axis=3).squeeze(0)
122
+
123
+ # Debug: Check what classes the model actually detected
124
+ unique_classes = np.unique(result)
125
+
126
+ score = 0
127
+ mask = np.zeros_like(result)
128
+
129
+ # Combine masks for enabled classes
130
+ for class_name, enabled in kwargs.items():
131
+ class_index = get_class_index(class_name)
132
+ if enabled and class_index != -1:
133
+ detected = result == class_index
134
+ mask[detected] = 255
135
+ score += mask.sum()
136
+
137
+ # Resize back to original size
138
+ mask_image = Image.fromarray(mask.astype(np.uint8), mode="L")
139
+ if rotation != 0:
140
+ mask_image = mask_image.rotate(-rotation, center=center)
141
+
142
+ mask_image = mask_image.resize(original_size)
143
+
144
+ # Convert back to numpy - improved tensor handling
145
+ mask = np.array(mask_image).astype(np.float32) / 255.0 # Normalize to 0-1 range
146
+
147
+ # Add dimensions for torch tensor - consistent format
148
+ mask = np.expand_dims(mask, axis=0)
149
+ mask = np.expand_dims(mask, axis=0)
150
+
151
+ return torch.from_numpy(mask), score
152
+
153
+
154
+ def numpy_to_torch_tensor(image_np: np.ndarray) -> torch.Tensor:
155
+ """Convert numpy array to torch tensor in the format expected by the models."""
156
+ if len(image_np.shape) == 3:
157
+ return torch.from_numpy(image_np.astype(np.float32) / 255.0).unsqueeze(0)
158
+ return torch.from_numpy(image_np.astype(np.float32) / 255.0)
159
+
160
+
161
+ def torch_tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
162
+ """Convert torch tensor back to numpy array - improved version."""
163
+ if len(tensor.shape) == 4:
164
+ tensor = tensor.squeeze(0)
165
+
166
+ # Always handle as float32 tensor in 0-1 range then convert to binary
167
+ tensor_np = tensor.numpy()
168
+ if tensor_np.dtype == np.float32 and tensor_np.max() <= 1.0:
169
+ return (tensor_np > 0.5).astype(np.float32) # Binary threshold
170
+ else:
171
+ return tensor_np
172
+
173
+
174
+ class HumanPartsSegmentation:
175
+ """
176
+ Standalone human parts segmentation for face and hair using DeepLabV3+ ResNet50.
177
+ """
178
+
179
+ def __init__(self):
180
+ self.model = None
181
+
182
+ def check_model_cache(self):
183
+ """Check if model file exists in cache - consistent with updated repos."""
184
+ if not os.path.exists(model_path):
185
+ return False, "Model file not found"
186
+ return True, "Model cache verified"
187
+
188
+ def clear_model(self):
189
+ """Clear model from memory - improved version."""
190
+ if self.model is not None:
191
+ del self.model
192
+ self.model = None
193
+
194
+ def load_model(self):
195
+ """Load the human parts segmentation model - improved version."""
196
+ try:
197
+ # Check and download model if needed
198
+ cache_status, message = self.check_model_cache()
199
+ if not cache_status:
200
+ print(f"Cache check: {message}")
201
+ if not download_model(model_url, model_path):
202
+ return False
203
+
204
+ # Load model if needed
205
+ if self.model is None:
206
+ print("Loading human parts segmentation model...")
207
+ self.model = ort.InferenceSession(model_path)
208
+ print("✅ Human parts segmentation model loaded successfully")
209
+
210
+ return True
211
+
212
+ except Exception as e:
213
+ print(f"❌ Error loading human parts model: {e}")
214
+ self.clear_model() # Cleanup on error
215
+ return False
216
+
217
+ def segment_parts(self, image_path: str, parts: List[str], mask_blur: int = 0, mask_offset: int = 0) -> Dict[str, np.ndarray]:
218
+ """
219
+ Segment specific human parts from an image - improved version with filtering.
220
+
221
+ Args:
222
+ image_path: Path to the image file
223
+ parts: List of part names to segment (e.g., ['face', 'hair'])
224
+ mask_blur: Blur amount for mask edges
225
+ mask_offset: Expand/Shrink mask boundary
226
+
227
+ Returns:
228
+ Dictionary mapping part names to binary masks
229
+ """
230
+ if not self.load_model():
231
+ print("❌ Cannot load human parts segmentation model")
232
+ return {}
233
+
234
+ try:
235
+ # Load image
236
+ image = cv2.imread(image_path)
237
+ if image is None:
238
+ print(f"❌ Could not load image: {image_path}")
239
+ return {}
240
+
241
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
242
+
243
+ # Convert to tensor format expected by the model
244
+ image_tensor = numpy_to_torch_tensor(image_rgb)
245
+
246
+ # Prepare kwargs for each part
247
+ part_kwargs = {part: True for part in parts}
248
+
249
+ # Get segmentation mask
250
+ mask_tensor, score = get_human_parts_mask(image_tensor, self.model, **part_kwargs)
251
+
252
+ # Convert back to numpy
253
+ if len(mask_tensor.shape) == 4:
254
+ mask_tensor = mask_tensor.squeeze(0).squeeze(0)
255
+ elif len(mask_tensor.shape) == 3:
256
+ mask_tensor = mask_tensor.squeeze(0)
257
+
258
+ # Get the combined mask for all requested parts
259
+ combined_mask = mask_tensor.numpy()
260
+
261
+ # Generate individual masks for each part if multiple parts requested
262
+ result_masks = {}
263
+ if len(parts) == 1:
264
+ # Single part - return the combined mask
265
+ part_name = parts[0]
266
+ final_mask = self._apply_filters(combined_mask, mask_blur, mask_offset)
267
+ if np.sum(final_mask > 0) > 0:
268
+ result_masks[part_name] = final_mask
269
+ else:
270
+ result_masks[part_name] = final_mask # Return empty mask instead of None
271
+ else:
272
+ # Multiple parts - need to segment each individually
273
+ for part in parts:
274
+ single_part_kwargs = {part: True}
275
+ single_mask_tensor, _ = get_human_parts_mask(image_tensor, self.model, **single_part_kwargs)
276
+
277
+ if len(single_mask_tensor.shape) == 4:
278
+ single_mask_tensor = single_mask_tensor.squeeze(0).squeeze(0)
279
+ elif len(single_mask_tensor.shape) == 3:
280
+ single_mask_tensor = single_mask_tensor.squeeze(0)
281
+
282
+ single_mask = single_mask_tensor.numpy()
283
+ final_mask = self._apply_filters(single_mask, mask_blur, mask_offset)
284
+
285
+ result_masks[part] = final_mask # Always add mask, even if empty
286
+
287
+ return result_masks
288
+
289
+ except Exception as e:
290
+ print(f"❌ Error in human parts segmentation: {e}")
291
+ return {}
292
+ finally:
293
+ # Clean up model if not needed
294
+ self.clear_model()
295
+
296
+ def _apply_filters(self, mask: np.ndarray, mask_blur: int = 0, mask_offset: int = 0) -> np.ndarray:
297
+ """Apply filtering to mask - new method from updated repo."""
298
+ if mask_blur == 0 and mask_offset == 0:
299
+ return mask
300
+
301
+ try:
302
+ # Convert to PIL for filtering
303
+ mask_image = Image.fromarray((mask * 255).astype(np.uint8))
304
+
305
+ # Apply blur if specified
306
+ if mask_blur > 0:
307
+ mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=mask_blur))
308
+
309
+ # Apply offset if specified
310
+ if mask_offset != 0:
311
+ if mask_offset > 0:
312
+ mask_image = mask_image.filter(ImageFilter.MaxFilter(size=mask_offset * 2 + 1))
313
+ else:
314
+ mask_image = mask_image.filter(ImageFilter.MinFilter(size=-mask_offset * 2 + 1))
315
+
316
+ # Convert back to numpy
317
+ filtered_mask = np.array(mask_image).astype(np.float32) / 255.0
318
+ return filtered_mask
319
+
320
+ except Exception as e:
321
+ print(f"❌ Error applying filters: {e}")
322
+ return mask
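A hedged sketch of calling `HumanPartsSegmentation.segment_parts` for face and hair masks. The image path is a placeholder; the ONNX model is downloaded on first use by `load_model`.

```python
from human_parts_segmentation import HumanPartsSegmentation

segmenter = HumanPartsSegmentation()

# Placeholder image path; returns {} if the image or model cannot be loaded.
masks = segmenter.segment_parts("examples/portrait.jpg", ["face", "hair"], mask_blur=2)

for part, mask in masks.items():
    coverage = float(mask.sum()) / mask.size  # masks are float arrays in [0, 1]
    print(f"{part}: {coverage:.1%} of pixels")
```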
models/RMBG/segformer_clothes/.cache/huggingface/.gitignore ADDED
@@ -0,0 +1 @@
+ *
models/RMBG/segformer_clothes/.cache/huggingface/download/config.json.lock ADDED
File without changes
models/RMBG/segformer_clothes/.cache/huggingface/download/config.json.metadata ADDED
@@ -0,0 +1,3 @@
+ 2634bcc40712620e414ffb0efd5f5e4ea732ec5d
+ 8352c4562bb0e1f72767dcb170ad6f3f56007836
+ 1748821507.461211
models/RMBG/segformer_clothes/.cache/huggingface/download/model.safetensors.lock ADDED
File without changes
models/RMBG/segformer_clothes/.cache/huggingface/download/model.safetensors.metadata ADDED
@@ -0,0 +1,3 @@
+ 2634bcc40712620e414ffb0efd5f5e4ea732ec5d
+ f70ae566c5773fb335796ebaa8acc924ac25eb97222c2b2967d44d2fc11568e6
+ 1748821512.848557
models/RMBG/segformer_clothes/.cache/huggingface/download/preprocessor_config.json.lock ADDED
File without changes
models/RMBG/segformer_clothes/.cache/huggingface/download/preprocessor_config.json.metadata ADDED
@@ -0,0 +1,3 @@
+ 2634bcc40712620e414ffb0efd5f5e4ea732ec5d
+ b2340cf4e53b37fda4f5b92d28f11c0f33c3d0fd
+ 1748821513.065366
models/RMBG/segformer_clothes/config.json ADDED
@@ -0,0 +1,110 @@
1
+ {
2
+ "_name_or_path": "nvidia/mit-b3",
3
+ "architectures": [
4
+ "SegformerForSemanticSegmentation"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.0,
7
+ "classifier_dropout_prob": 0.1,
8
+ "decoder_hidden_size": 768,
9
+ "depths": [
10
+ 3,
11
+ 4,
12
+ 18,
13
+ 3
14
+ ],
15
+ "downsampling_rates": [
16
+ 1,
17
+ 4,
18
+ 8,
19
+ 16
20
+ ],
21
+ "drop_path_rate": 0.1,
22
+ "hidden_act": "gelu",
23
+ "hidden_dropout_prob": 0.0,
24
+ "hidden_sizes": [
25
+ 64,
26
+ 128,
27
+ 320,
28
+ 512
29
+ ],
30
+ "id2label": {
31
+ "0": "Background",
32
+ "1": "Hat",
33
+ "10": "Right-shoe",
34
+ "11": "Face",
35
+ "12": "Left-leg",
36
+ "13": "Right-leg",
37
+ "14": "Left-arm",
38
+ "15": "Right-arm",
39
+ "16": "Bag",
40
+ "17": "Scarf",
41
+ "2": "Hair",
42
+ "3": "Sunglasses",
43
+ "4": "Upper-clothes",
44
+ "5": "Skirt",
45
+ "6": "Pants",
46
+ "7": "Dress",
47
+ "8": "Belt",
48
+ "9": "Left-shoe"
49
+ },
50
+ "image_size": 224,
51
+ "initializer_range": 0.02,
52
+ "label2id": {
53
+ "Background": "0",
54
+ "Bag": "16",
55
+ "Belt": "8",
56
+ "Dress": "7",
57
+ "Face": "11",
58
+ "Hair": "2",
59
+ "Hat": "1",
60
+ "Left-arm": "14",
61
+ "Left-leg": "12",
62
+ "Left-shoe": "9",
63
+ "Pants": "6",
64
+ "Right-arm": "15",
65
+ "Right-leg": "13",
66
+ "Right-shoe": "10",
67
+ "Scarf": "17",
68
+ "Skirt": "5",
69
+ "Sunglasses": "3",
70
+ "Upper-clothes": "4"
71
+ },
72
+ "layer_norm_eps": 1e-06,
73
+ "mlp_ratios": [
74
+ 4,
75
+ 4,
76
+ 4,
77
+ 4
78
+ ],
79
+ "model_type": "segformer",
80
+ "num_attention_heads": [
81
+ 1,
82
+ 2,
83
+ 5,
84
+ 8
85
+ ],
86
+ "num_channels": 3,
87
+ "num_encoder_blocks": 4,
88
+ "patch_sizes": [
89
+ 7,
90
+ 3,
91
+ 3,
92
+ 3
93
+ ],
94
+ "reshape_last_stage": true,
95
+ "semantic_loss_ignore_index": 255,
96
+ "sr_ratios": [
97
+ 8,
98
+ 4,
99
+ 2,
100
+ 1
101
+ ],
102
+ "strides": [
103
+ 4,
104
+ 2,
105
+ 2,
106
+ 2
107
+ ],
108
+ "torch_dtype": "float32",
109
+ "transformers_version": "4.38.1"
110
+ }
models/RMBG/segformer_clothes/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f70ae566c5773fb335796ebaa8acc924ac25eb97222c2b2967d44d2fc11568e6
+ size 189029000
models/RMBG/segformer_clothes/preprocessor_config.json ADDED
@@ -0,0 +1,23 @@
+ {
+ "do_normalize": true,
+ "do_reduce_labels": false,
+ "do_rescale": true,
+ "do_resize": true,
+ "image_mean": [
+ 0.485,
+ 0.456,
+ 0.406
+ ],
+ "image_processor_type": "SegformerImageProcessor",
+ "image_std": [
+ 0.229,
+ 0.224,
+ 0.225
+ ],
+ "resample": 2,
+ "rescale_factor": 0.00392156862745098,
+ "size": {
+ "height": 512,
+ "width": 512
+ }
+ }
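The two JSON files above describe a `SegformerForSemanticSegmentation` checkpoint with 18 clothing classes and a 512x512 `SegformerImageProcessor`. A hedged sketch of loading it from the local folder with `transformers`; this loading code is not part of the upload, and the example image path is a placeholder.

```python
import torch
from PIL import Image
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor

model_dir = "models/RMBG/segformer_clothes"  # local folder from this upload
processor = SegformerImageProcessor.from_pretrained(model_dir)
model = SegformerForSemanticSegmentation.from_pretrained(model_dir)

image = Image.open("examples/portrait.jpg").convert("RGB")  # placeholder path
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # shape (1, 18, H/4, W/4)
pred = logits.argmax(dim=1)[0]       # per-pixel class ids, see id2label above
print(pred.shape)
```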
models/onnx/human-parts/deeplabv3p-resnet50-human.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a6e823a82da10ba24c29adfb544130684568c46bfac865e215bbace3b4035a71
+ size 47210581
requirements.txt ADDED
@@ -0,0 +1,30 @@
+ # LaDeco requirements
+ torch==2.3.1
+ torchaudio
+ torchvision
+ tf-keras
+ transformers==4.42.4
+ diffusers
+ opencv-python
+ Pillow
+ numpy
+ matplotlib
+ scipy
+ scikit-learn
+
+ # For Gradio interface
+ gradio
+
+ # Face comparison
+ deepface
+
+ # Human parts segmentation
+ onnxruntime
+
+ # Clothing segmentation
+ huggingface-hub>=0.19.0
+ segment-anything>=1.0
+
+ # Color matching dependencies
+ color-matcher
+ spaces
spaces.py ADDED
@@ -0,0 +1,12 @@
+ import functools
+ import os
+
+ def GPU(func):
+ """
+ A decorator to indicate that a function should use GPU acceleration if available.
+ This is used specifically for Hugging Face Spaces.
+ """
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ return func(*args, **kwargs)
+ return wrapper
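This stub mirrors the `@spaces.GPU` decorator from the real `spaces` package (also listed in requirements.txt) so code written for Hugging Face Spaces runs unchanged elsewhere. A minimal usage sketch; the decorated function is hypothetical.

```python
import spaces  # this stub, or the real `spaces` package on Hugging Face Spaces

@spaces.GPU
def run_inference(image_path):
    # Hypothetical heavy function; on Spaces the real decorator requests GPU
    # time, while this stub is a transparent pass-through.
    return f"processed {image_path}"

print(run_inference("examples/beach.jpg"))
```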