Upload 25 files
- README.md +49 -4
- app.py +956 -0
- cdl_smoothing.py +497 -0
- clothes_segmentation.py +292 -0
- color_matching.py +698 -0
- core.py +356 -0
- examples/beach.jpg +0 -0
- examples/field.jpg +0 -0
- examples/sky.jpg +0 -0
- face_comparison.py +246 -0
- folder_paths.py +22 -0
- human_parts_segmentation.py +322 -0
- models/RMBG/segformer_clothes/.cache/huggingface/.gitignore +1 -0
- models/RMBG/segformer_clothes/.cache/huggingface/download/config.json.lock +0 -0
- models/RMBG/segformer_clothes/.cache/huggingface/download/config.json.metadata +3 -0
- models/RMBG/segformer_clothes/.cache/huggingface/download/model.safetensors.lock +0 -0
- models/RMBG/segformer_clothes/.cache/huggingface/download/model.safetensors.metadata +3 -0
- models/RMBG/segformer_clothes/.cache/huggingface/download/preprocessor_config.json.lock +0 -0
- models/RMBG/segformer_clothes/.cache/huggingface/download/preprocessor_config.json.metadata +3 -0
- models/RMBG/segformer_clothes/config.json +110 -0
- models/RMBG/segformer_clothes/model.safetensors +3 -0
- models/RMBG/segformer_clothes/preprocessor_config.json +23 -0
- models/onnx/human-parts/deeplabv3p-resnet50-human.onnx +3 -0
- requirements.txt +30 -0
- spaces.py +12 -0
README.md
CHANGED
@@ -1,12 +1,57 @@
 ---
-title:
-emoji:
+title: LaDeco
+emoji: 👀
 colorFrom: gray
-colorTo:
+colorTo: blue
 sdk: gradio
-sdk_version: 5.
+sdk_version: 5.31.0
 app_file: app.py
 pinned: false
+short_description: 'LaDeco: A tool to analyze visual landscape elements'
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+# LaDeco - Landscape Environment Semantic Analysis Model
+
+LaDeco is a tool that analyzes landscape images, performs semantic segmentation to identify the different elements in a scene (sky, vegetation, buildings, etc.), and enables region-based color matching between images.
+
+## Features
+
+### Semantic Segmentation
+- Analyzes landscape images and segments them into different semantic regions
+- Provides area-ratio analysis for each landscape element
+
+### Region-Based Color Matching
+- Matches colors between corresponding semantic regions of two images
+- Visualizes which regions are being matched between the images
+- Offers multiple color matching algorithms:
+  - **adain**: Adaptive Instance Normalization - matches the mean and standard deviation of colors
+  - **mkl**: Monge-Kantorovich Linearization - linear transformation of color statistics
+  - **reinhard**: Reinhard color transfer - simple statistical approach that matches mean and standard deviation
+  - **mvgd**: Multi-Variate Gaussian Distribution - uses color covariance matrices for more accurate matching
+  - **hm**: Histogram Matching - matches the full color distribution histograms
+  - **hm-mvgd-hm**: Histogram + MVGD + Histogram compound method
+  - **hm-mkl-hm**: Histogram + MKL + Histogram compound method
+
+## Installation
+
+1. Clone this repository
+2. Create a virtual environment: `python3 -m venv .venv`
+3. Activate the virtual environment: `source .venv/bin/activate`
+4. Install the requirements: `pip install -r requirements.txt`
+5. Run the application: `python app.py`
+
+## Usage
+
+1. Upload two landscape images - the first is the color reference; the second is color-matched to the first
+2. Choose a color matching method from the dropdown menu
+3. Click "Start Analysis" to process the images
+4. View the results in the Segmentation and Color Matching tabs
+   - The Segmentation tab shows the semantic segmentation and area ratios for both images
+   - The Color Matching tab shows the matched-regions visualization and the color matching result
+
+## Reference
+
+Li-Chih Ho (2023), "LaDeco: A tool to analyze visual landscape elements", Ecological Informatics, vol. 78.
+https://www.sciencedirect.com/science/article/pii/S1574954123003187
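The adain and reinhard options listed above both work by aligning per-channel statistics inside a semantic region. The repository's RegionColorMatcher implementation is not shown in this commit view, so purely as an illustration of that masked mean/std transfer, here is a minimal, hypothetical sketch (the function name and the `factor` blending parameter are assumptions, not the project's API):

```python
import numpy as np

def match_region_mean_std(src: np.ndarray, ref: np.ndarray,
                          src_mask: np.ndarray, ref_mask: np.ndarray,
                          factor: float = 1.0) -> np.ndarray:
    """Shift the masked pixels of `src` toward the per-channel mean/std of the
    masked pixels of `ref` (the adain/reinhard-style statistic transfer)."""
    out = src.astype(np.float32).copy()
    src_px = out[src_mask > 0.5]                      # (N, 3) pixels inside the source mask
    ref_px = ref.astype(np.float32)[ref_mask > 0.5]   # (M, 3) pixels inside the reference mask
    if len(src_px) == 0 or len(ref_px) == 0:
        return src                                    # nothing to match in one of the regions
    s_mean, s_std = src_px.mean(axis=0), src_px.std(axis=0) + 1e-6
    r_mean, r_std = ref_px.mean(axis=0), ref_px.std(axis=0)
    matched = (src_px - s_mean) / s_std * r_std + r_mean
    out[src_mask > 0.5] = src_px + factor * (matched - src_px)  # blend toward the matched statistics
    return np.clip(out, 0, 255).astype(np.uint8)
```

The covariance-based options (mvgd, mkl) follow the same masked-region idea but align full covariance matrices rather than independent per-channel standard deviations.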
app.py
ADDED
@@ -0,0 +1,956 @@
import gradio as gr
from core import Ladeco
from matplotlib.figure import Figure
import matplotlib.pyplot as plt
import matplotlib as mpl
import spaces
from PIL import Image
import numpy as np
from color_matching import RegionColorMatcher, create_comparison_figure
from face_comparison import FaceComparison
from cdl_smoothing import cdl_edge_smoothing, get_smoothing_stats, cdl_edge_smoothing_apply_to_source
import tempfile
import os
import cv2


plt.rcParams['figure.facecolor'] = '#0b0f19'
plt.rcParams['text.color'] = '#aab6cc'
ladeco = Ladeco()


@spaces.GPU
def infer_two_images(img1: str, img2: str, method: str, enable_face_matching: bool, enable_edge_smoothing: bool) -> tuple[Figure, Figure, Figure, Figure, Figure, Figure, str, str, str]:
    """
    Clean 4-step approach:
    1. Segment both images identically
    2. Determine segment correspondences
    3. Match each segment pair in isolation
    4. Composite all matched segments
    """

    cdl_display = ""  # Initialize CDL display string

    # STEP 1: SEGMENT BOTH IMAGES IDENTICALLY
    # This step is always identical regardless of face matching
    print("Step 1: Segmenting both images...")
    out1 = ladeco.predict(img1)
    out2 = ladeco.predict(img2)

    # Extract visualization and stats (unchanged)
    seg1 = out1.visualize(level=2)[0].image
    colormap1 = out1.color_map(level=2)
    area1 = out1.area()[0]

    seg2 = out2.visualize(level=2)[0].image
    colormap2 = out2.color_map(level=2)
    area2 = out2.area()[0]

    # Process areas for pie charts
    colors1, l2_area1 = [], {}
    for labelname, area_ratio in area1.items():
        if labelname.startswith("l2") and area_ratio > 0:
            colors1.append(colormap1[labelname])
            labelname = labelname.replace("l2_", "").capitalize()
            l2_area1[labelname] = area_ratio

    colors2, l2_area2 = [], {}
    for labelname, area_ratio in area2.items():
        if labelname.startswith("l2") and area_ratio > 0:
            colors2.append(colormap2[labelname])
            labelname = labelname.replace("l2_", "").capitalize()
            l2_area2[labelname] = area_ratio

    pie1 = plot_pie(l2_area1, colors=colors1)
    pie2 = plot_pie(l2_area2, colors=colors2)

    # Set plot sizes
    for fig in [seg1, seg2, pie1, pie2]:
        fig.set_dpi(96)
        fig.set_size_inches(256/96, 256/96)

    # Extract semantic masks - IDENTICAL for both images regardless of face matching
    masks1 = extract_semantic_masks(out1)
    masks2 = extract_semantic_masks(out2)

    print(f"Extracted {len(masks1)} masks from img1, {len(masks2)} masks from img2")

    # STEP 2: DETERMINE SEGMENT CORRESPONDENCES
    print("Step 2: Determining segment correspondences...")
    face_log = ["Step 2: Determining segment correspondences"]

    # Find common segments between both images
    common_segments = set(masks1.keys()).intersection(set(masks2.keys()))
    face_log.append(f"Found {len(common_segments)} common segments: {sorted(common_segments)}")

    # Determine which segments to match based on face matching logic
    segments_to_match = determine_segments_to_match(img1, img2, common_segments, enable_face_matching, face_log)

    face_log.append(f"Final segments to match: {sorted(segments_to_match)}")

    # STEP 3: MATCH EACH SEGMENT PAIR IN ISOLATION
    print("Step 3: Matching each segment pair in isolation...")
    face_log.append("\nStep 3: Color matching each segment independently")

    matched_regions = {}
    segment_masks = {}  # Store masks for all segments being matched

    for segment_name in segments_to_match:
        if segment_name in masks1 and segment_name in masks2:
            face_log.append(f"  Processing {segment_name}...")

            # Match this segment in complete isolation
            matched_region, final_mask1, final_mask2 = match_single_segment(
                img1, img2,
                masks1[segment_name], masks2[segment_name],
                segment_name, method, face_log
            )

            if matched_region is not None:
                matched_regions[segment_name] = matched_region
                segment_masks[segment_name] = final_mask2  # Use mask from target image for compositing
                face_log.append(f"  ✅ {segment_name} matched successfully")
            else:
                face_log.append(f"  ❌ {segment_name} matching failed")
        elif segment_name.startswith('l4_'):
            # Handle fine-grained segments that need to be generated
            face_log.append(f"  Processing fine-grained {segment_name}...")

            matched_region, final_mask1, final_mask2 = match_single_segment(
                img1, img2, None, None, segment_name, method, face_log
            )

            if matched_region is not None:
                matched_regions[segment_name] = matched_region
                segment_masks[segment_name] = final_mask2  # Store the generated mask
                face_log.append(f"  ✅ {segment_name} matched successfully")
            else:
                face_log.append(f"  ❌ {segment_name} matching failed")

    # STEP 4: COMPOSITE ALL MATCHED SEGMENTS
    print("Step 4: Compositing all matched segments...")
    face_log.append(f"\nStep 4: Compositing {len(matched_regions)} matched segments")

    final_image = composite_matched_segments(img2, matched_regions, segment_masks, face_log)

    # STEP 5: OPTIONAL CDL-BASED EDGE SMOOTHING
    if enable_edge_smoothing:
        print("Step 5: Applying CDL-based edge smoothing...")
        face_log.append("\nStep 5: CDL edge smoothing - applying CDL transform to image 2 based on composited result")

        try:
            # Save the composited result temporarily for CDL calculation
            temp_dir = tempfile.gettempdir()
            temp_composite_path = os.path.join(temp_dir, "temp_composite_for_cdl.png")
            final_image.save(temp_composite_path, "PNG")

            # Calculate CDL parameters to transform image 2 → composited result
            cdl_stats = get_smoothing_stats(img2, temp_composite_path)

            # Log the CDL values
            slope = cdl_stats['cdl_slope']
            offset = cdl_stats['cdl_offset']
            power = cdl_stats['cdl_power']

            # Format CDL values for display
            cdl_display = f"""📊 CDL Parameters (Image 2 → Composited Result):

🔧 Method: Simple Mean/Std Matching (basic statistical approach)

🔸 Slope (Gain):
    Red: {slope[0]:.6f}
    Green: {slope[1]:.6f}
    Blue: {slope[2]:.6f}

🔸 Offset:
    Red: {offset[0]:.6f}
    Green: {offset[1]:.6f}
    Blue: {offset[2]:.6f}

🔸 Power (Gamma):
    Red: {power[0]:.6f}
    Green: {power[1]:.6f}
    Blue: {power[2]:.6f}

These CDL values represent the color transformation needed to convert Image 2 into the composited result.

The CDL calculation uses the simplest possible approach: it matches the mean and standard deviation
of each color channel between the original and composited images, with a simple gamma calculation
based on brightness relationships.
"""

            face_log.append(f"📊 CDL Parameters (image 2 → composited result):")
            face_log.append(f"  Method: Simple mean/std matching")
            face_log.append(f"  Slope (R,G,B): [{slope[0]:.4f}, {slope[1]:.4f}, {slope[2]:.4f}]")
            face_log.append(f"  Offset (R,G,B): [{offset[0]:.4f}, {offset[1]:.4f}, {offset[2]:.4f}]")
            face_log.append(f"  Power (R,G,B): [{power[0]:.4f}, {power[1]:.4f}, {power[2]:.4f}]")

            # Apply CDL transformation to image 2 to approximate the composited result
            final_image = cdl_edge_smoothing_apply_to_source(img2, temp_composite_path, factor=1.0)

            # Clean up temp file
            if os.path.exists(temp_composite_path):
                os.remove(temp_composite_path)

            face_log.append("✅ CDL edge smoothing completed - transformed image 2 using calculated CDL parameters")

        except Exception as e:
            face_log.append(f"❌ CDL edge smoothing failed: {e}")
            cdl_display = f"❌ CDL calculation failed: {e}"
    else:
        face_log.append("\nStep 5: CDL edge smoothing disabled")
        cdl_display = "CDL edge smoothing is disabled. Enable it to see CDL parameters."

    # Save result
    temp_dir = tempfile.gettempdir()
    filename = os.path.basename(img2).split('.')[0]
    temp_filename = f"color_matched_{method}_{filename}.png"
    temp_path = os.path.join(temp_dir, temp_filename)
    final_image.save(temp_path, "PNG")

    # Create visualizations
    # For visualization, we need to collect the masks that were actually used
    vis_masks1 = {}
    vis_masks2 = {}

    for segment_name in segments_to_match:
        if segment_name in segment_masks:
            if segment_name.startswith('l4_'):
                # Fine-grained segments - we'll regenerate for visualization
                part_name = segment_name.replace('l4_', '')
                if part_name in ['face', 'hair']:
                    from human_parts_segmentation import HumanPartsSegmentation
                    segmenter = HumanPartsSegmentation()
                    masks_dict1 = segmenter.segment_parts(img1, [part_name])
                    masks_dict2 = segmenter.segment_parts(img2, [part_name])
                    if part_name in masks_dict1 and part_name in masks_dict2:
                        vis_masks1[segment_name] = masks_dict1[part_name]
                        vis_masks2[segment_name] = masks_dict2[part_name]
                elif part_name == 'upper_clothes':
                    from clothes_segmentation import ClothesSegmentation
                    segmenter = ClothesSegmentation()
                    mask1 = segmenter.segment_clothes(img1, ["Upper-clothes"])
                    mask2 = segmenter.segment_clothes(img2, ["Upper-clothes"])
                    if mask1 is not None and mask2 is not None:
                        vis_masks1[segment_name] = mask1
                        vis_masks2[segment_name] = mask2
            else:
                # Regular segments - use original masks
                if segment_name in masks1 and segment_name in masks2:
                    vis_masks1[segment_name] = masks1[segment_name]
                    vis_masks2[segment_name] = masks2[segment_name]

    mask_vis = visualize_matching_masks(img1, img2, vis_masks1, vis_masks2)

    comparison = create_comparison_figure(Image.open(img2), final_image, f"Color Matching Result ({method})")

    face_log_text = "\n".join(face_log)

    return seg1, pie1, seg2, pie2, comparison, mask_vis, temp_path, face_log_text, cdl_display


def determine_segments_to_match(img1: str, img2: str, common_segments: set, enable_face_matching: bool, log: list) -> set:
    """
    Determine which segments should be matched based on face matching logic.
    Returns the set of segment names to process.
    """
    if not enable_face_matching:
        log.append("Face matching disabled - matching all common segments")
        return common_segments

    log.append("Face matching enabled - checking faces...")

    # Run face comparison
    face_comparator = FaceComparison()
    faces_match, face_log = face_comparator.run_face_comparison(img1, img2)
    log.extend(face_log)

    if not faces_match:
        # Remove human/bio segments from matching
        log.append("No face match - excluding human/bio segments")
        non_human_segments = set()
        for segment in common_segments:
            if not any(term in segment.lower() for term in ['l3_human', 'l2_bio']):
                non_human_segments.add(segment)
            else:
                log.append(f"  Excluding human segment: {segment}")

        log.append(f"Matching {len(non_human_segments)} non-human segments")
        return non_human_segments

    else:
        # Faces match - include all segments + add fine-grained if possible
        log.append("Faces match - including all segments + fine-grained")

        segments_to_match = common_segments.copy()

        # Add fine-grained human parts if bio regions exist
        bio_segments = [s for s in common_segments if 'l2_bio' in s.lower()]
        if bio_segments:
            fine_grained_segments = add_fine_grained_segments(img1, img2, common_segments, log)
            segments_to_match.update(fine_grained_segments)

        return segments_to_match


def add_fine_grained_segments(img1: str, img2: str, common_segments: set, log: list) -> set:
    """
    Add fine-grained human parts segments when faces match.
    Returns set of fine-grained segment names that were successfully added.
    """
    fine_grained_segments = set()

    try:
        from human_parts_segmentation import HumanPartsSegmentation
        from clothes_segmentation import ClothesSegmentation

        log.append("  Adding fine-grained human parts...")

        # Get face and hair masks
        human_segmenter = HumanPartsSegmentation()
        face_hair_masks1 = human_segmenter.segment_parts(img1, ['face', 'hair'])
        face_hair_masks2 = human_segmenter.segment_parts(img2, ['face', 'hair'])

        # Get clothes masks
        clothes_segmenter = ClothesSegmentation()
        clothes_mask1 = clothes_segmenter.segment_clothes(img1, ["Upper-clothes"])
        clothes_mask2 = clothes_segmenter.segment_clothes(img2, ["Upper-clothes"])

        # Process face/hair
        for part_name, mask1 in face_hair_masks1.items():
            if (mask1 is not None and part_name in face_hair_masks2 and
                    face_hair_masks2[part_name] is not None):

                if np.sum(mask1 > 0) > 0 and np.sum(face_hair_masks2[part_name] > 0) > 0:
                    fine_grained_segments.add(f'l4_{part_name}')
                    log.append(f"    Added fine-grained: {part_name}")

        # Process clothes
        if (clothes_mask1 is not None and clothes_mask2 is not None and
                np.sum(clothes_mask1 > 0) > 0 and np.sum(clothes_mask2 > 0) > 0):
            fine_grained_segments.add('l4_upper_clothes')
            log.append(f"    Added fine-grained: upper_clothes")

    except Exception as e:
        log.append(f"  Error adding fine-grained segments: {e}")

    return fine_grained_segments


def match_single_segment(img1_path: str, img2_path: str, mask1: np.ndarray, mask2: np.ndarray,
                         segment_name: str, method: str, log: list) -> tuple[Image.Image, np.ndarray, np.ndarray]:
    """
    Match colors of a single segment in complete isolation from other segments.
    Each segment is processed independently with no knowledge of other segments.
    Returns: (matched_image, final_mask1, final_mask2)
    """
    try:
        # Load images
        img1 = Image.open(img1_path).convert("RGB")
        img2 = Image.open(img2_path).convert("RGB")

        # Convert to numpy
        img1_np = np.array(img1)
        img2_np = np.array(img2)

        # Handle fine-grained segments
        if segment_name.startswith('l4_'):
            part_name = segment_name.replace('l4_', '')
            if part_name in ['face', 'hair']:
                from human_parts_segmentation import HumanPartsSegmentation
                segmenter = HumanPartsSegmentation()
                masks_dict1 = segmenter.segment_parts(img1_path, [part_name])
                masks_dict2 = segmenter.segment_parts(img2_path, [part_name])

                if part_name in masks_dict1 and part_name in masks_dict2:
                    mask1 = masks_dict1[part_name]
                    mask2 = masks_dict2[part_name]
                else:
                    return None, None, None

            elif part_name == 'upper_clothes':
                from clothes_segmentation import ClothesSegmentation
                segmenter = ClothesSegmentation()
                mask1 = segmenter.segment_clothes(img1_path, ["Upper-clothes"])
                mask2 = segmenter.segment_clothes(img2_path, ["Upper-clothes"])

                if mask1 is None or mask2 is None:
                    return None, None, None

        # Ensure masks are same size as images
        if mask1.shape != img1_np.shape[:2]:
            mask1 = cv2.resize(mask1.astype(np.float32), (img1_np.shape[1], img1_np.shape[0]),
                               interpolation=cv2.INTER_NEAREST)
        if mask2.shape != img2_np.shape[:2]:
            mask2 = cv2.resize(mask2.astype(np.float32), (img2_np.shape[1], img2_np.shape[0]),
                               interpolation=cv2.INTER_NEAREST)

        # Convert to binary masks
        mask1_binary = (mask1 > 0.5).astype(np.float32)
        mask2_binary = (mask2 > 0.5).astype(np.float32)

        # Check if masks have content
        pixels1 = np.sum(mask1_binary > 0)
        pixels2 = np.sum(mask2_binary > 0)

        if pixels1 == 0 or pixels2 == 0:
            log.append(f"  No pixels in {segment_name}: img1={pixels1}, img2={pixels2}")
            return None, None, None

        log.append(f"  {segment_name}: img1={pixels1} pixels, img2={pixels2} pixels")

        # Create single-segment masks dictionary for color matcher
        masks1_dict = {segment_name: mask1_binary}
        masks2_dict = {segment_name: mask2_binary}

        # Apply color matching to this segment only
        color_matcher = RegionColorMatcher(factor=0.8, preserve_colors=True,
                                           preserve_luminance=True, method=method)

        matched_img = color_matcher.match_regions(img1_path, img2_path, masks1_dict, masks2_dict)

        return matched_img, mask1_binary, mask2_binary

    except Exception as e:
        log.append(f"  Error matching {segment_name}: {e}")
        return None, None, None


def composite_matched_segments(base_img_path: str, matched_regions: dict, segment_masks: dict, log: list) -> Image.Image:
    """
    Composite all matched segments back together using simple alpha compositing.
    Each matched segment is completely independent and overlaid on the base image.
    """
    # Start with base image
    result = Image.open(base_img_path).convert("RGBA")
    result_np = np.array(result)

    log.append(f"Compositing {len(matched_regions)} segments onto base image")

    for segment_name, matched_img in matched_regions.items():
        if segment_name in segment_masks:
            mask = segment_masks[segment_name]

            # Ensure mask is right size
            if mask.shape != result_np.shape[:2]:
                mask = cv2.resize(mask.astype(np.float32),
                                  (result_np.shape[1], result_np.shape[0]),
                                  interpolation=cv2.INTER_NEAREST)

            # Convert matched image to numpy
            matched_np = np.array(matched_img.convert("RGB"))

            # Ensure matched image is right size
            if matched_np.shape[:2] != result_np.shape[:2]:
                matched_pil = Image.fromarray(matched_np)
                matched_pil = matched_pil.resize((result_np.shape[1], result_np.shape[0]), Image.LANCZOS)
                matched_np = np.array(matched_pil)

            # Apply mask with alpha blending
            mask_binary = (mask > 0.5).astype(np.float32)
            alpha = np.expand_dims(mask_binary, axis=2)

            # Blend: result = result * (1 - alpha) + matched * alpha
            result_np[:, :, :3] = (result_np[:, :, :3] * (1 - alpha) +
                                   matched_np * alpha).astype(np.uint8)

            pixels = np.sum(mask_binary > 0)
            log.append(f"  Composited {segment_name}: {pixels} pixels")

    return Image.fromarray(result_np).convert("RGB")


def visualize_matching_masks(img1_path, img2_path, masks1, masks2):
    """
    Create a visualization of the masks being matched between two images.

    Args:
        img1_path: Path to first image
        img2_path: Path to second image
        masks1: Dictionary of masks for first image {label: binary_mask}
        masks2: Dictionary of masks for second image {label: binary_mask}

    Returns:
        A matplotlib Figure showing the matched masks
    """
    # Load images
    img1 = Image.open(img1_path).convert("RGB")
    img2 = Image.open(img2_path).convert("RGB")

    # Convert to numpy arrays
    img1_np = np.array(img1)
    img2_np = np.array(img2)

    # Separate fine-grained human parts from regular masks
    fine_grained_masks = {}
    regular_masks = {}

    for label, mask in masks1.items():
        if label.startswith('l4_'):  # Fine-grained human parts
            fine_grained_masks[label] = mask
        else:
            regular_masks[label] = mask

    # Find common labels in both regular and fine-grained masks
    common_regular = set(regular_masks.keys()).intersection(set(masks2.keys()))

    # Count fine-grained masks that are in both masks1 and masks2
    common_fine_grained = set()
    for label in fine_grained_masks.keys():
        if label.startswith('l4_') and label in masks2:
            part_name = label.replace('l4_', '')
            common_fine_grained.add(part_name)

    # Count total rows needed
    n_regular_rows = len(common_regular)
    n_fine_rows = len(common_fine_grained)
    n_rows = n_regular_rows + n_fine_rows

    if n_rows == 0:
        # No common regions found
        fig, ax = plt.subplots(1, 1, figsize=(10, 5))
        ax.text(0.5, 0.5, "No matching regions found between images",
                ha='center', va='center', fontsize=14, color='white')
        ax.axis('off')
        return fig

    fig, axes = plt.subplots(n_rows, 2, figsize=(12, 3 * n_rows))

    # If only one row, reshape axes
    if n_rows == 1:
        axes = np.array([axes])

    row_idx = 0

    # Visualize regular semantic regions
    for label in sorted(common_regular):
        # Get label display name
        display_name = label.replace("l2_", "").capitalize()

        # Get masks and resize them to match the image dimensions
        mask1 = regular_masks[label]
        mask2 = masks2[label]

        # Create visualizations
        masked_img1, masked_img2 = create_mask_overlay(img1_np, img2_np, mask1, mask2, [255, 0, 0])  # Red

        # Plot the masked images
        axes[row_idx, 0].imshow(masked_img1)
        axes[row_idx, 0].set_title(f"Image 1: {display_name}")
        axes[row_idx, 0].axis('off')

        axes[row_idx, 1].imshow(masked_img2)
        axes[row_idx, 1].set_title(f"Image 2: {display_name}")
        axes[row_idx, 1].axis('off')

        row_idx += 1

    # Visualize fine-grained human parts
    part_colors = {
        'face': [255, 0, 0],          # Red (like other masks)
        'hair': [255, 0, 0],          # Red (like other masks)
        'upper_clothes': [255, 0, 0]  # Red (like other masks)
    }

    for part_name in sorted(common_fine_grained):
        label = f'l4_{part_name}'

        if label in fine_grained_masks and label in masks2:
            mask1 = fine_grained_masks[label]
            mask2 = masks2[label]

            color = part_colors.get(part_name, [255, 0, 0])  # Default to red

            # Create visualizations
            masked_img1, masked_img2 = create_mask_overlay(img1_np, img2_np, mask1, mask2, color)

            # Plot the masked images
            display_name = part_name.replace('_', ' ').title()
            axes[row_idx, 0].imshow(masked_img1)
            axes[row_idx, 0].set_title(f"Image 1: {display_name} (Fine-grained)")
            axes[row_idx, 0].axis('off')

            axes[row_idx, 1].imshow(masked_img2)
            axes[row_idx, 1].set_title(f"Image 2: {display_name} (Fine-grained)")
            axes[row_idx, 1].axis('off')

            row_idx += 1

    plt.suptitle("Matched Regions (highlighted with different colors)", fontsize=16, color='white')
    plt.tight_layout()

    return fig


def create_mask_overlay(img1_np, img2_np, mask1, mask2, overlay_color):
    """
    Create mask overlays on images with the specified color.

    Args:
        img1_np: First image as numpy array
        img2_np: Second image as numpy array
        mask1: Mask for first image
        mask2: Mask for second image
        overlay_color: RGB color for overlay [R, G, B]

    Returns:
        Tuple of (masked_img1, masked_img2)
    """
    # Resize masks to match image dimensions if needed
    if mask1.shape != img1_np.shape[:2]:
        mask1_img = Image.fromarray((mask1 * 255).astype(np.uint8))
        mask1_img = mask1_img.resize((img1_np.shape[1], img1_np.shape[0]), Image.NEAREST)
        mask1 = np.array(mask1_img).astype(np.float32) / 255.0

    if mask2.shape != img2_np.shape[:2]:
        mask2_img = Image.fromarray((mask2 * 255).astype(np.uint8))
        mask2_img = mask2_img.resize((img2_np.shape[1], img2_np.shape[0]), Image.NEAREST)
        mask2 = np.array(mask2_img).astype(np.float32) / 255.0

    # Create masked versions of the images
    masked_img1 = img1_np.copy()
    masked_img2 = img2_np.copy()

    # Apply a semi-transparent colored overlay to show the masked region
    overlay_color = np.array(overlay_color, dtype=np.uint8)

    # Create alpha channel based on the mask (with transparency)
    alpha1 = mask1 * 0.6  # Increased opacity for better visibility
    alpha2 = mask2 * 0.6

    # Apply the colored overlay to masked regions
    for c in range(3):
        masked_img1[:, :, c] = masked_img1[:, :, c] * (1 - alpha1) + overlay_color[c] * alpha1
        masked_img2[:, :, c] = masked_img2[:, :, c] * (1 - alpha2) + overlay_color[c] * alpha2

    return masked_img1, masked_img2


def extract_semantic_masks(output):
    """
    Extract binary masks for each semantic region from the LadecoOutput.

    Args:
        output: LadecoOutput from Ladeco.predict()

    Returns:
        Dictionary mapping label names to binary masks
    """
    masks = {}

    # Get the segmentation mask
    seg_mask = output.masks[0].cpu().numpy()

    # Process each label in level 2 (as we're visualizing at level 2)
    for label, indices in output.ladeco2ade.items():
        if label.startswith("l2_"):
            # Create a binary mask for this label
            binary_mask = np.zeros_like(seg_mask, dtype=np.float32)

            # Set 1 for pixels matching this label
            for idx in indices:
                binary_mask[seg_mask == idx] = 1.0

            # Only include labels that have some pixels in the image
            if np.any(binary_mask):
                masks[label] = binary_mask

    return masks


def plot_pie(data: dict[str, float], colors=None) -> Figure:
    fig, ax = plt.subplots()

    labels = list(data.keys())
    sizes = list(data.values())

    *_, autotexts = ax.pie(sizes, labels=labels, autopct="%1.1f%%", colors=colors)

    for percent_text in autotexts:
        percent_text.set_color("k")

    ax.axis("equal")

    return fig


def choose_example(imgpath: str, target_component) -> gr.Image:
    img = Image.open(imgpath)
    width, height = img.size
    ratio = 512 / max(width, height)
    img = img.resize((int(width * ratio), int(height * ratio)))
    return gr.Image(value=img, label="Input Image (SVG format not supported)", type="filepath")


css = """
.reference {
    text-align: center;
    font-size: 1.2em;
    color: #d1d5db;
    margin-bottom: 20px;
}
.reference a {
    color: #FB923C;
    text-decoration: none;
}
.reference a:hover {
    text-decoration: underline;
    color: #FB923C;
}
.description {
    text-align: center;
    font-size: 1.1em;
    color: #d1d5db;
    margin-bottom: 25px;
}
.footer {
    text-align: center;
    margin-top: 30px;
    padding-top: 20px;
    border-top: 1px solid #ddd;
    color: #d1d5db;
    font-size: 14px;
}
.main-title {
    font-size: 24px;
    font-weight: bold;
    text-align: center;
    margin-bottom: 20px;
}
.selected-image {
    height: 756px;
}
.example-image {
    height: 220px;
    padding: 25px;
}
""".strip()
theme = gr.themes.Base(
    primary_hue="orange",
    secondary_hue="cyan",
    neutral_hue="gray",
).set(
    body_text_color='*neutral_100',
    body_text_color_subdued='*neutral_600',
    background_fill_primary='*neutral_950',
    background_fill_secondary='*neutral_600',
    border_color_accent='*secondary_800',
    color_accent='*primary_50',
    color_accent_soft='*secondary_800',
    code_background_fill='*neutral_700',
    block_background_fill_dark='*body_background_fill',
    block_info_text_color='#6b7280',
    block_label_text_color='*neutral_300',
    block_label_text_weight='700',
    block_title_text_color='*block_label_text_color',
    block_title_text_weight='300',
    panel_background_fill='*neutral_800',
    table_text_color_dark='*secondary_800',
    checkbox_background_color_selected='*primary_500',
    checkbox_label_background_fill='*neutral_500',
    checkbox_label_background_fill_hover='*neutral_700',
    checkbox_label_text_color='*neutral_200',
    input_background_fill='*neutral_700',
    input_background_fill_focus='*neutral_600',
    slider_color='*primary_500',
    table_even_background_fill='*neutral_700',
    table_odd_background_fill='*neutral_600',
    table_row_focus='*neutral_800'
)
with gr.Blocks(css=css, theme=theme) as demo:
    gr.HTML(
        """
        <div class="main-title">SegMatch – Zero Shot Segmentation-based color matching</div>
        <div class="description">
            Advanced region-based color matching using semantic segmentation and fine-grained human parts detection for precise, contextually-aware color transfer between images.
        </div>
        """.strip()
    )

    with gr.Row():
        # First image inputs
        with gr.Column():
            img1 = gr.Image(
                label="First Input Image - Color Reference (SVG not supported)",
                type="filepath",
                height="256px",
            )
            gr.Label("Example Images for First Input", show_label=False)
            with gr.Row():
                ex1_1 = gr.Image(
                    value="examples/beach.jpg",
                    show_label=False,
                    type="filepath",
                    elem_classes="example-image",
                    interactive=False,
                    show_download_button=False,
                    show_fullscreen_button=False,
                    show_share_button=False,
                )
                ex1_2 = gr.Image(
                    value="examples/field.jpg",
                    show_label=False,
                    type="filepath",
                    elem_classes="example-image",
                    interactive=False,
                    show_download_button=False,
                    show_fullscreen_button=False,
                    show_share_button=False,
                )

        # Second image inputs
        with gr.Column():
            img2 = gr.Image(
                label="Second Input Image - To Be Color Matched (SVG not supported)",
                type="filepath",
                height="256px",
            )
            gr.Label("Example Images for Second Input", show_label=False)
            with gr.Row():
                ex2_1 = gr.Image(
                    value="examples/field.jpg",
                    show_label=False,
                    type="filepath",
                    elem_classes="example-image",
                    interactive=False,
                    show_download_button=False,
                    show_fullscreen_button=False,
                    show_share_button=False,
                )
                ex2_2 = gr.Image(
                    value="examples/sky.jpg",
                    show_label=False,
                    type="filepath",
                    elem_classes="example-image",
                    interactive=False,
                    show_download_button=False,
                    show_fullscreen_button=False,
                    show_share_button=False,
                )

    with gr.Row():
        with gr.Column():
            method = gr.Dropdown(
                label="Color Matching Method",
                choices=["adain", "mkl", "hm", "reinhard", "mvgd", "hm-mvgd-hm", "hm-mkl-hm", "coral"],
                value="adain",
                info="Choose the algorithm for color matching between regions"
            )

        with gr.Column():
            enable_face_matching = gr.Checkbox(
                label="Enable Face Matching for Human Regions",
                value=True,
                info="Only match human regions if faces are similar (requires DeepFace)"
            )

    with gr.Row():
        with gr.Column():
            enable_edge_smoothing = gr.Checkbox(
                label="Enable CDL Edge Smoothing",
                value=False,
                info="Apply CDL transform to original image using calculated parameters (see log for values)"
            )

    start = gr.Button("Start Analysis", variant="primary")

    # Download button positioned right after the start button
    download_btn = gr.File(
        label="📥 Download Color-Matched Image",
        visible=True,
        interactive=False
    )

    with gr.Tabs():
        with gr.TabItem("Segmentation Results"):
            with gr.Row():
                # First image results
                with gr.Column():
                    gr.Label("Results for First Image", show_label=True)
                    seg1 = gr.Plot(label="Semantic Segmentation")
                    pie1 = gr.Plot(label="Element Area Ratio")

                # Second image results
                with gr.Column():
                    gr.Label("Results for Second Image", show_label=True)
                    seg2 = gr.Plot(label="Semantic Segmentation")
                    pie2 = gr.Plot(label="Element Area Ratio")

        with gr.TabItem("Color Matching"):
            gr.Markdown("""
### Region-Based Color Matching

This tab shows the result of matching the colors of the second image to the first image's colors,
but only within corresponding semantic regions. For example, sky areas in the second image are
matched to sky areas in the first image, while vegetation areas are matched separately.

#### Face Matching Feature:
When enabled, the system will detect faces within human/bio regions and only apply color matching
to human regions where similar faces are found in both images. This ensures that color transfer
only occurs between images of the same person.

#### CDL Edge Smoothing Feature:
When enabled, Color Decision List (CDL) parameters are calculated to transform the original target image
towards the segment-matched result, and those CDL parameters are then applied to the original image. This creates
a "smoothed" version that maintains the original image's overall characteristics while incorporating the
color relationships found through segment matching.

The CDL calculation uses the simplest possible approach: it matches the mean and standard deviation
of each color channel between the original and composited images, with a simple gamma calculation
based on brightness relationships.

#### Available Methods:
- **adain**: Adaptive Instance Normalization - matches the mean and standard deviation of colors
- **mkl**: Monge-Kantorovich Linearization - linear transformation of color statistics
- **reinhard**: Reinhard color transfer - simple statistical approach that matches mean and standard deviation
- **mvgd**: Multi-Variate Gaussian Distribution - uses color covariance matrices for more accurate matching
- **hm**: Histogram Matching - matches the full color distribution histograms
- **hm-mvgd-hm**: Histogram + MVGD + Histogram compound method
- **hm-mkl-hm**: Histogram + MKL + Histogram compound method
- **coral**: CORAL (Correlation Alignment) - covariance-based method for natural color transfer
            """)

            # CDL Parameters Display
            cdl_display = gr.Textbox(
                label="📊 CDL Parameters",
                lines=15,
                max_lines=20,
                interactive=False,
                info="Color Decision List parameters calculated when CDL edge smoothing is enabled"
            )

            face_log = gr.Textbox(
                label="Face Matching Log",
                lines=8,
                max_lines=15,
                interactive=False,
                info="Shows details of face detection and matching process"
            )

            mask_vis = gr.Plot(label="Matched Regions Visualization")
            comparison = gr.Plot(label="Region-Based Color Matching Result")

    gr.HTML(
        """
        <div class="footer">
            © 2024 SegMatch All Rights Reserved<br>
            Developer: Stefan Allen
        </div>
        """.strip()
    )

    # Connect the inference function
    start.click(
        fn=infer_two_images,
        inputs=[img1, img2, method, enable_face_matching, enable_edge_smoothing],
        outputs=[seg1, pie1, seg2, pie2, comparison, mask_vis, download_btn, face_log, cdl_display]
    )

    # Example image selection handlers
    ex1_1.select(fn=lambda x: choose_example(x, img1), inputs=ex1_1, outputs=img1)
    ex1_2.select(fn=lambda x: choose_example(x, img1), inputs=ex1_2, outputs=img1)
    ex2_1.select(fn=lambda x: choose_example(x, img2), inputs=ex2_1, outputs=img2)
    ex2_2.select(fn=lambda x: choose_example(x, img2), inputs=ex2_2, outputs=img2)

if __name__ == "__main__":
    demo.launch()
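Step 5 above reports ASC-CDL-style slope/offset/power triples and then calls cdl_edge_smoothing_apply_to_source, whose body is not included in this view. As a rough sketch of how such parameters are conventionally applied (out = clamp(in · slope + offset)^power on 0-1 RGB), with the helper name and signature being assumptions rather than the module's actual API:

```python
import numpy as np
from PIL import Image

def apply_cdl(img_path: str, slope, offset, power) -> Image.Image:
    """Apply an ASC-CDL-style transform per RGB channel (hypothetical helper)."""
    rgb = np.asarray(Image.open(img_path).convert("RGB"), dtype=np.float32) / 255.0
    out = np.clip(rgb * np.asarray(slope) + np.asarray(offset), 0.0, 1.0)  # slope then offset
    out = out ** np.asarray(power)                                          # per-channel gamma
    return Image.fromarray((np.clip(out, 0.0, 1.0) * 255).astype(np.uint8))

# e.g. with values like those printed in the "CDL Parameters" textbox:
# smoothed = apply_cdl("image2.png", slope=[1.02, 0.98, 1.01],
#                      offset=[0.01, -0.02, 0.00], power=[1.0, 1.0, 1.0])
```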
cdl_smoothing.py
ADDED
@@ -0,0 +1,497 @@
#!/usr/bin/env python3
"""
CDL (Color Decision List) based edge smoothing for SegMatch
"""

import numpy as np
from typing import Tuple, Optional
from PIL import Image
import cv2


def calculate_cdl_params_face_only(source: np.ndarray, target: np.ndarray,
                                   source_face_mask: np.ndarray, target_face_mask: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Calculate CDL parameters using only face pixels for focused accuracy.

    Args:
        source (np.ndarray): Source image as numpy array (0-1 range)
        target (np.ndarray): Target image as numpy array (0-1 range)
        source_face_mask (np.ndarray): Binary mask of face in source image
        target_face_mask (np.ndarray): Binary mask of face in target image

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: (slope, offset, power)
    """
    epsilon = 1e-6

    # Extract face pixels only
    source_face_pixels = source[source_face_mask > 0.5]
    target_face_pixels = target[target_face_mask > 0.5]

    # Ensure we have enough face pixels
    if len(source_face_pixels) < 100 or len(target_face_pixels) < 100:
        # Fallback to simple calculation if not enough face pixels
        return calculate_cdl_params_simple(source, target)

    slopes = []
    offsets = []
    powers = []

    for channel in range(3):
        src_channel = source_face_pixels[:, channel]
        tgt_channel = target_face_pixels[:, channel]

        # Use robust percentiles for face pixels
        percentiles = [10, 25, 50, 75, 90]
        src_percentiles = np.percentile(src_channel, percentiles)
        tgt_percentiles = np.percentile(tgt_channel, percentiles)

        # Calculate slope from face pixel range
        src_range = src_percentiles[4] - src_percentiles[0]  # 90th - 10th
        tgt_range = tgt_percentiles[4] - tgt_percentiles[0]
        slope = tgt_range / (src_range + epsilon)

        # Calculate offset using face median
        src_median = src_percentiles[2]
        tgt_median = tgt_percentiles[2]
        offset = tgt_median - (src_median * slope)

        # Calculate gamma from face brightness relationship
        src_mean = np.mean(src_channel)
        tgt_mean = np.mean(tgt_channel)

        if src_mean > epsilon:
            power = np.log(tgt_mean + epsilon) / np.log(src_mean + epsilon)
            power = np.clip(power, 0.3, 3.0)
        else:
            power = 1.0

        slopes.append(slope)
        offsets.append(offset)
        powers.append(power)

    return np.array(slopes), np.array(offsets), np.array(powers)


def calculate_cdl_params_simple(source: np.ndarray, target: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Simple CDL calculation as fallback method.

    Args:
        source (np.ndarray): Source image as numpy array (0-1 range)
        target (np.ndarray): Target image as numpy array (0-1 range)

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: (slope, offset, power)
    """
    epsilon = 1e-6

    # Calculate mean and standard deviation for each RGB channel
    source_mean = np.mean(source, axis=(0, 1))
    source_std = np.std(source, axis=(0, 1))
    target_mean = np.mean(target, axis=(0, 1))
    target_std = np.std(target, axis=(0, 1))

    # Calculate slope (gain)
    slope = target_std / (source_std + epsilon)

    # Calculate offset
    offset = target_mean - (source_mean * slope)

    # Set power to neutral
    power = np.ones(3)

    return slope, offset, power


def calculate_cdl_params_histogram(source: np.ndarray, target: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Calculate CDL parameters using histogram matching approach.

    Args:
        source (np.ndarray): Source image as numpy array (0-1 range)
        target (np.ndarray): Target image as numpy array (0-1 range)

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: (slope, offset, power)
    """
    epsilon = 1e-6

    # Convert to 0-255 range for histogram calculation
    source_255 = (source * 255).astype(np.uint8)
    target_255 = (target * 255).astype(np.uint8)

    slopes = []
    offsets = []
    powers = []

    for channel in range(3):
        # Calculate histograms
        hist_source = cv2.calcHist([source_255], [channel], None, [256], [0, 256])
        hist_target = cv2.calcHist([target_255], [channel], None, [256], [0, 256])

        # Calculate cumulative distributions
        cdf_source = np.cumsum(hist_source) / np.sum(hist_source)
        cdf_target = np.cumsum(hist_target) / np.sum(hist_target)

        # Find percentile mappings
        p25_src = np.percentile(source[:, :, channel], 25)
        p75_src = np.percentile(source[:, :, channel], 75)
        p25_tgt = np.percentile(target[:, :, channel], 25)
        p75_tgt = np.percentile(target[:, :, channel], 75)

        # Calculate slope from percentile mapping
        slope = (p75_tgt - p25_tgt) / (p75_src - p25_src + epsilon)

        # Calculate offset
        median_src = np.percentile(source[:, :, channel], 50)
        median_tgt = np.percentile(target[:, :, channel], 50)
        offset = median_tgt - (median_src * slope)

        # Estimate power/gamma from the histogram shape
        mean_src = np.mean(source[:, :, channel])
        mean_tgt = np.mean(target[:, :, channel])
        if mean_src > epsilon:
            power = np.log(mean_tgt + epsilon) / np.log(mean_src + epsilon)
            power = np.clip(power, 0.1, 10.0)  # Reasonable gamma range
        else:
            power = 1.0

        slopes.append(slope)
        offsets.append(offset)
        powers.append(power)

    return np.array(slopes), np.array(offsets), np.array(powers)


def calculate_cdl_params_mask_aware(source: np.ndarray, target: np.ndarray,
                                    changed_mask: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Calculate CDL parameters focusing only on changed regions.

    Args:
        source (np.ndarray): Source image as numpy array (0-1 range)
        target (np.ndarray): Target image as numpy array (0-1 range)
        changed_mask (np.ndarray, optional): Binary mask of changed regions

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: (slope, offset, power)
    """
    if changed_mask is not None:
        # Only use pixels where changes occurred
        mask_bool = changed_mask > 0.5
        if np.sum(mask_bool) > 100:  # Ensure enough pixels
            source_masked = source[mask_bool]
            target_masked = target[mask_bool]

            # Reshape back to have channel dimension
            source_masked = source_masked.reshape(-1, 3)
            target_masked = target_masked.reshape(-1, 3)

            # Calculate statistics on masked regions
            epsilon = 1e-6
            source_mean = np.mean(source_masked, axis=0)
            source_std = np.std(source_masked, axis=0)
            target_mean = np.mean(target_masked, axis=0)
            target_std = np.std(target_masked, axis=0)

            slope = target_std / (source_std + epsilon)
            offset = target_mean - (source_mean * slope)
            power = np.ones(3)

            return slope, offset, power

    # Fallback to simple method if mask is not useful
    return calculate_cdl_params_simple(source, target)


def calculate_cdl_params_lab(source: np.ndarray, target: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Calculate CDL parameters in LAB color space for better perceptual matching.

    Args:
        source (np.ndarray): Source image as numpy array (0-1 range)
        target (np.ndarray): Target image as numpy array (0-1 range)

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: (slope, offset, power)
    """
    # Convert to LAB color space
    source_lab = cv2.cvtColor((source * 255).astype(np.uint8), cv2.COLOR_RGB2LAB).astype(np.float32)
    target_lab = cv2.cvtColor((target * 255).astype(np.uint8), cv2.COLOR_RGB2LAB).astype(np.float32)

    # Normalize LAB values
    source_lab[:, :, 0] /= 100.0  # L: 0-100 -> 0-1
    source_lab[:, :, 1] = (source_lab[:, :, 1] + 128) / 255.0  # A: -128-127 -> 0-1
    source_lab[:, :, 2] = (source_lab[:, :, 2] + 128) / 255.0  # B: -128-127 -> 0-1

    target_lab[:, :, 0] /= 100.0
    target_lab[:, :, 1] = (target_lab[:, :, 1] + 128) / 255.0
    target_lab[:, :, 2] = (target_lab[:, :, 2] + 128) / 255.0

    # Calculate CDL in LAB space
    epsilon = 1e-6
    source_mean = np.mean(source_lab, axis=(0, 1))
    source_std = np.std(source_lab, axis=(0, 1))
    target_mean = np.mean(target_lab, axis=(0, 1))
    target_std = np.std(target_lab, axis=(0, 1))

    slope_lab = target_std / (source_std + epsilon)
    offset_lab = target_mean - (source_mean * slope_lab)

    # Convert back to RGB approximation
|
239 |
+
# This is a simplified conversion - for full accuracy we'd need to convert each pixel
|
240 |
+
slope = np.array([slope_lab[0], slope_lab[1], slope_lab[2]]) # Rough mapping
|
241 |
+
offset = np.array([offset_lab[0], offset_lab[1], offset_lab[2]])
|
242 |
+
power = np.ones(3)
|
243 |
+
|
244 |
+
return slope, offset, power
|
245 |
+
|
246 |
+
|
247 |
+
def calculate_cdl_params(source: np.ndarray, target: np.ndarray,
|
248 |
+
source_path: str = None, target_path: str = None) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
249 |
+
"""Calculate CDL parameters using simple mean/std matching - the most basic approach.
|
250 |
+
|
251 |
+
Args:
|
252 |
+
source (np.ndarray): Source image as numpy array (0-1 range)
|
253 |
+
target (np.ndarray): Target image as numpy array (0-1 range)
|
254 |
+
source_path (str, optional): Ignored - kept for compatibility
|
255 |
+
target_path (str, optional): Ignored - kept for compatibility
|
256 |
+
|
257 |
+
Returns:
|
258 |
+
Tuple[np.ndarray, np.ndarray, np.ndarray]: (slope, offset, power)
|
259 |
+
"""
|
260 |
+
epsilon = 1e-6
|
261 |
+
|
262 |
+
# Calculate simple mean and standard deviation for each RGB channel
|
263 |
+
source_mean = np.mean(source, axis=(0, 1))
|
264 |
+
source_std = np.std(source, axis=(0, 1))
|
265 |
+
target_mean = np.mean(target, axis=(0, 1))
|
266 |
+
target_std = np.std(target, axis=(0, 1))
|
267 |
+
|
268 |
+
# Calculate slope (gain) from std ratio
|
269 |
+
slope = target_std / (source_std + epsilon)
|
270 |
+
|
271 |
+
# Calculate offset from mean difference
|
272 |
+
offset = target_mean - (source_mean * slope)
|
273 |
+
|
274 |
+
# Calculate simple gamma from brightness relationship
|
275 |
+
power = []
|
276 |
+
for channel in range(3):
|
277 |
+
if source_mean[channel] > epsilon:
|
278 |
+
gamma = np.log(target_mean[channel] + epsilon) / np.log(source_mean[channel] + epsilon)
|
279 |
+
gamma = np.clip(gamma, 0.1, 10.0) # Keep within reasonable bounds
|
280 |
+
else:
|
281 |
+
gamma = 1.0
|
282 |
+
power.append(gamma)
|
283 |
+
|
284 |
+
power = np.array(power)
|
285 |
+
|
286 |
+
return slope, offset, power
|
287 |
+
|
288 |
+
|
289 |
+
def calculate_change_mask(original: np.ndarray, composited: np.ndarray, threshold: float = 0.05) -> np.ndarray:
|
290 |
+
"""Calculate a mask of significantly changed regions between original and composited images.
|
291 |
+
|
292 |
+
Args:
|
293 |
+
original (np.ndarray): Original image (0-1 range)
|
294 |
+
composited (np.ndarray): Composited result (0-1 range)
|
295 |
+
threshold (float): Threshold for detecting significant changes
|
296 |
+
|
297 |
+
Returns:
|
298 |
+
np.ndarray: Binary mask of changed regions
|
299 |
+
"""
|
300 |
+
# Calculate per-pixel difference
|
301 |
+
diff = np.sqrt(np.sum((composited - original) ** 2, axis=2))
|
302 |
+
|
303 |
+
# Create binary mask where changes exceed threshold
|
304 |
+
change_mask = (diff > threshold).astype(np.float32)
|
305 |
+
|
306 |
+
# Apply morphological operations to clean up the mask
|
307 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
|
308 |
+
change_mask = cv2.morphologyEx(change_mask, cv2.MORPH_CLOSE, kernel)
|
309 |
+
|
310 |
+
return change_mask
|
311 |
+
|
312 |
+
|
313 |
+
def calculate_channel_stats(array: np.ndarray) -> dict:
|
314 |
+
"""Calculate per-channel statistics for an image array.
|
315 |
+
|
316 |
+
Args:
|
317 |
+
array: Image array of shape (H, W, 3)
|
318 |
+
|
319 |
+
Returns:
|
320 |
+
dict: Dictionary containing mean, std, min, max for each channel
|
321 |
+
"""
|
322 |
+
stats = {
|
323 |
+
'mean': np.mean(array, axis=(0, 1)),
|
324 |
+
'std': np.std(array, axis=(0, 1)),
|
325 |
+
'min': np.min(array, axis=(0, 1)),
|
326 |
+
'max': np.max(array, axis=(0, 1))
|
327 |
+
}
|
328 |
+
return stats
|
329 |
+
|
330 |
+
|
331 |
+
def apply_cdl_transform(image: np.ndarray, slope: np.ndarray, offset: np.ndarray, power: np.ndarray,
|
332 |
+
factor: float = 0.3) -> np.ndarray:
|
333 |
+
"""Apply CDL transformation to an image.
|
334 |
+
|
335 |
+
Args:
|
336 |
+
image (np.ndarray): Input image (0-1 range)
|
337 |
+
slope (np.ndarray): CDL slope parameters for each channel
|
338 |
+
offset (np.ndarray): CDL offset parameters for each channel
|
339 |
+
power (np.ndarray): CDL power parameters for each channel
|
340 |
+
factor (float): Blending factor (0.0 = no change, 1.0 = full transform)
|
341 |
+
|
342 |
+
Returns:
|
343 |
+
np.ndarray: Transformed image
|
344 |
+
"""
|
345 |
+
# Apply CDL transform: out = ((in * slope) + offset) ** power
|
346 |
+
transformed = np.power(np.maximum(image * slope + offset, 0), power)
|
347 |
+
|
348 |
+
# Clamp to valid range
|
349 |
+
transformed = np.clip(transformed, 0.0, 1.0)
|
350 |
+
|
351 |
+
# Blend with original based on factor
|
352 |
+
result = (1 - factor) * image + factor * transformed
|
353 |
+
|
354 |
+
return result
|
355 |
+
|
356 |
+
|
357 |
+
def cdl_edge_smoothing(composited_image_path: str, original_image_path: str, factor: float = 0.3) -> Image.Image:
|
358 |
+
"""Apply CDL-based edge smoothing between composited result and original image.
|
359 |
+
|
360 |
+
Args:
|
361 |
+
composited_image_path (str): Path to the composited result image
|
362 |
+
original_image_path (str): Path to the original target image
|
363 |
+
factor (float): Smoothing strength (0.0 = no smoothing, 1.0 = full smoothing)
|
364 |
+
|
365 |
+
Returns:
|
366 |
+
Image.Image: Smoothed result image
|
367 |
+
"""
|
368 |
+
# Load images
|
369 |
+
composited_img = Image.open(composited_image_path).convert("RGB")
|
370 |
+
original_img = Image.open(original_image_path).convert("RGB")
|
371 |
+
|
372 |
+
# Ensure same dimensions
|
373 |
+
if composited_img.size != original_img.size:
|
374 |
+
composited_img = composited_img.resize(original_img.size, Image.LANCZOS)
|
375 |
+
|
376 |
+
# Convert to numpy arrays (0-1 range)
|
377 |
+
composited_np = np.array(composited_img).astype(np.float32) / 255.0
|
378 |
+
original_np = np.array(original_img).astype(np.float32) / 255.0
|
379 |
+
|
380 |
+
# Calculate CDL parameters to transform composited to match original
|
381 |
+
slope, offset, power = calculate_cdl_params(composited_np, original_np)
|
382 |
+
|
383 |
+
# Apply CDL transformation with blending
|
384 |
+
smoothed_np = apply_cdl_transform(composited_np, slope, offset, power, factor)
|
385 |
+
|
386 |
+
# Convert back to PIL Image
|
387 |
+
smoothed_img = Image.fromarray((smoothed_np * 255).astype(np.uint8))
|
388 |
+
|
389 |
+
return smoothed_img
|
390 |
+
|
391 |
+
|
392 |
+
def get_smoothing_stats(original_image_path: str, composited_image_path: str) -> dict:
|
393 |
+
"""Get statistics about the CDL transformation for debugging.
|
394 |
+
|
395 |
+
Args:
|
396 |
+
original_image_path (str): Path to the original target image
|
397 |
+
composited_image_path (str): Path to the composited result image
|
398 |
+
|
399 |
+
Returns:
|
400 |
+
dict: Statistics about the transformation
|
401 |
+
"""
|
402 |
+
# Load images
|
403 |
+
composited_img = Image.open(composited_image_path).convert("RGB")
|
404 |
+
original_img = Image.open(original_image_path).convert("RGB")
|
405 |
+
|
406 |
+
# Ensure same dimensions
|
407 |
+
if composited_img.size != original_img.size:
|
408 |
+
composited_img = composited_img.resize(original_img.size, Image.LANCZOS)
|
409 |
+
|
410 |
+
# Convert to numpy arrays (0-1 range)
|
411 |
+
composited_np = np.array(composited_img).astype(np.float32) / 255.0
|
412 |
+
original_np = np.array(original_img).astype(np.float32) / 255.0
|
413 |
+
|
414 |
+
# Calculate statistics
|
415 |
+
composited_stats = calculate_channel_stats(composited_np)
|
416 |
+
original_stats = calculate_channel_stats(original_np)
|
417 |
+
|
418 |
+
# Calculate CDL parameters using face-based method when possible
|
419 |
+
slope, offset, power = calculate_cdl_params(original_np, composited_np,
|
420 |
+
original_image_path, composited_image_path)
|
421 |
+
|
422 |
+
return {
|
423 |
+
'composited_stats': composited_stats,
|
424 |
+
'original_stats': original_stats,
|
425 |
+
'cdl_slope': slope,
|
426 |
+
'cdl_offset': offset,
|
427 |
+
'cdl_power': power
|
428 |
+
}
|
429 |
+
|
430 |
+
|
431 |
+
def cdl_edge_smoothing_apply_to_source(source_image_path: str, target_image_path: str, factor: float = 1.0) -> Image.Image:
|
432 |
+
"""Apply CDL transformation to source image using face-based parameters when possible.
|
433 |
+
|
434 |
+
This function:
|
435 |
+
1. Calculates CDL parameters to transform source to match target (using face pixels when available)
|
436 |
+
2. Applies those CDL parameters to the entire source image
|
437 |
+
3. Returns the transformed source image
|
438 |
+
|
439 |
+
Args:
|
440 |
+
source_image_path (str): Path to the source image (to be transformed)
|
441 |
+
target_image_path (str): Path to the target image (reference for CDL calculation)
|
442 |
+
factor (float): Transform strength (0.0 = no change, 1.0 = full transform)
|
443 |
+
|
444 |
+
Returns:
|
445 |
+
Image.Image: Source image with CDL transformation applied
|
446 |
+
"""
|
447 |
+
# Load images
|
448 |
+
source_img = Image.open(source_image_path).convert("RGB")
|
449 |
+
target_img = Image.open(target_image_path).convert("RGB")
|
450 |
+
|
451 |
+
# Ensure same dimensions
|
452 |
+
if source_img.size != target_img.size:
|
453 |
+
target_img = target_img.resize(source_img.size, Image.LANCZOS)
|
454 |
+
|
455 |
+
# Convert to numpy arrays (0-1 range)
|
456 |
+
source_np = np.array(source_img).astype(np.float32) / 255.0
|
457 |
+
target_np = np.array(target_img).astype(np.float32) / 255.0
|
458 |
+
|
459 |
+
# Calculate CDL parameters using face-based method when possible
|
460 |
+
slope, offset, power = calculate_cdl_params(source_np, target_np,
|
461 |
+
source_image_path, target_image_path)
|
462 |
+
|
463 |
+
# Apply CDL transformation to the entire source image
|
464 |
+
transformed_np = apply_cdl_transform(source_np, slope, offset, power, factor)
|
465 |
+
|
466 |
+
# Convert back to PIL Image
|
467 |
+
transformed_img = Image.fromarray((transformed_np * 255).astype(np.uint8))
|
468 |
+
|
469 |
+
return transformed_img
|
470 |
+
|
471 |
+
|
472 |
+
def extract_face_mask(image_path: str) -> Optional[np.ndarray]:
|
473 |
+
"""Extract face mask from an image using human parts segmentation.
|
474 |
+
|
475 |
+
Args:
|
476 |
+
image_path (str): Path to the image
|
477 |
+
|
478 |
+
Returns:
|
479 |
+
np.ndarray or None: Binary face mask, or None if no face found
|
480 |
+
"""
|
481 |
+
try:
|
482 |
+
from human_parts_segmentation import HumanPartsSegmentation
|
483 |
+
|
484 |
+
segmenter = HumanPartsSegmentation()
|
485 |
+
masks_dict = segmenter.segment_parts(image_path, ['face'])
|
486 |
+
|
487 |
+
if 'face' in masks_dict and masks_dict['face'] is not None:
|
488 |
+
face_mask = masks_dict['face']
|
489 |
+
# Ensure it's a proper binary mask
|
490 |
+
if np.sum(face_mask > 0.5) > 100: # At least 100 face pixels
|
491 |
+
return face_mask
|
492 |
+
|
493 |
+
return None
|
494 |
+
|
495 |
+
except Exception as e:
|
496 |
+
print(f"Face extraction failed: {e}")
|
497 |
+
return None
|
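The sketch below (not part of the upload) shows how the CDL helpers in cdl_smoothing.py fit together when called directly: the per-channel parameters follow the usual CDL form `out = clip((in * slope + offset) ** power)`, blended with the input by `factor`. The image paths and the output filename are hypothetical placeholders.

```python
import numpy as np
from PIL import Image

# assumes this file is importable as the module `cdl_smoothing`
from cdl_smoothing import calculate_cdl_params_simple, apply_cdl_transform

# Load two images as float arrays in the 0-1 range (paths are illustrative only)
source = np.asarray(Image.open("source.jpg").convert("RGB"), dtype=np.float32) / 255.0
target = np.asarray(Image.open("target.jpg").convert("RGB"), dtype=np.float32) / 255.0

# slope scales contrast, offset shifts brightness, power acts as a per-channel gamma
slope, offset, power = calculate_cdl_params_simple(source, target)

# Apply the grade at half strength and write the result out
graded = apply_cdl_transform(source, slope, offset, power, factor=0.5)
Image.fromarray((graded * 255).astype(np.uint8)).save("graded.jpg")
```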
clothes_segmentation.py
ADDED
@@ -0,0 +1,292 @@
import os
import torch
import torch.nn as nn
import numpy as np
from typing import Union, Tuple
from PIL import Image, ImageFilter
import cv2
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
from huggingface_hub import hf_hub_download
import shutil

# Device configuration
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model configuration
AVAILABLE_MODELS = {
    "segformer_b2_clothes": "1038lab/segformer_clothes"
}

# Model paths
current_dir = os.path.dirname(os.path.abspath(__file__))
models_dir = os.path.join(current_dir, "models")


def pil2tensor(image: Image.Image) -> torch.Tensor:
    """Convert PIL Image to tensor."""
    return torch.from_numpy(np.array(image).astype(np.float32) / 255.0)[None,]


def tensor2pil(image: torch.Tensor) -> Image.Image:
    """Convert tensor to PIL Image."""
    return Image.fromarray(np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8))


def image2mask(image: Image.Image) -> torch.Tensor:
    """Convert image to mask tensor."""
    if isinstance(image, Image.Image):
        image = pil2tensor(image)
    return image.squeeze()[..., 0]


def mask2image(mask: torch.Tensor) -> Image.Image:
    """Convert mask tensor to PIL Image."""
    if len(mask.shape) == 2:
        mask = mask.unsqueeze(0)
    return tensor2pil(mask)


class ClothesSegmentation:
    """
    Standalone clothing segmentation using Segformer model.
    """

    def __init__(self):
        self.processor = None
        self.model = None
        self.cache_dir = os.path.join(models_dir, "RMBG", "segformer_clothes")

        # Class mapping for segmentation - consistent with latest repo
        self.class_map = {
            "Background": 0, "Hat": 1, "Hair": 2, "Sunglasses": 3,
            "Upper-clothes": 4, "Skirt": 5, "Pants": 6, "Dress": 7,
            "Belt": 8, "Left-shoe": 9, "Right-shoe": 10, "Face": 11,
            "Left-leg": 12, "Right-leg": 13, "Left-arm": 14, "Right-arm": 15,
            "Bag": 16, "Scarf": 17
        }

    def check_model_cache(self):
        """Check if model files exist in cache."""
        if not os.path.exists(self.cache_dir):
            return False, "Model directory not found"

        required_files = [
            'config.json',
            'model.safetensors',
            'preprocessor_config.json'
        ]

        missing_files = [f for f in required_files if not os.path.exists(os.path.join(self.cache_dir, f))]
        if missing_files:
            return False, f"Required model files missing: {', '.join(missing_files)}"
        return True, "Model cache verified"

    def clear_model(self):
        """Clear model from memory - improved version."""
        if self.model is not None:
            self.model.cpu()
            del self.model
            self.model = None
            self.processor = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def download_model_files(self):
        """Download model files from Hugging Face - improved version."""
        model_id = AVAILABLE_MODELS["segformer_b2_clothes"]
        model_files = {
            'config.json': 'config.json',
            'model.safetensors': 'model.safetensors',
            'preprocessor_config.json': 'preprocessor_config.json'
        }

        os.makedirs(self.cache_dir, exist_ok=True)
        print("Downloading Clothes Segformer model files...")

        try:
            for save_name, repo_path in model_files.items():
                print(f"Downloading {save_name}...")
                downloaded_path = hf_hub_download(
                    repo_id=model_id,
                    filename=repo_path,
                    local_dir=self.cache_dir,
                    local_dir_use_symlinks=False
                )

                if os.path.dirname(downloaded_path) != self.cache_dir:
                    target_path = os.path.join(self.cache_dir, save_name)
                    shutil.move(downloaded_path, target_path)
            return True, "Model files downloaded successfully"
        except Exception as e:
            return False, f"Error downloading model files: {str(e)}"

    def load_model(self):
        """Load the clothing segmentation model - improved version."""
        try:
            # Check and download model if needed
            cache_status, message = self.check_model_cache()
            if not cache_status:
                print(f"Cache check: {message}")
                download_status, download_message = self.download_model_files()
                if not download_status:
                    print(f"❌ {download_message}")
                    return False

            # Load model if needed
            if self.processor is None:
                print("Loading clothes segmentation model...")
                self.processor = SegformerImageProcessor.from_pretrained(self.cache_dir)
                self.model = AutoModelForSemanticSegmentation.from_pretrained(self.cache_dir)
                self.model.eval()
                for param in self.model.parameters():
                    param.requires_grad = False
                self.model.to(device)
                print("✅ Clothes segmentation model loaded successfully")

            return True

        except Exception as e:
            print(f"❌ Error loading clothes model: {e}")
            self.clear_model()  # Cleanup on error
            return False

    def segment_clothes(self, image_path: str, target_classes: list = None, process_res: int = 512) -> np.ndarray:
        """
        Segment clothing from an image - improved version with process_res parameter.

        Args:
            image_path: Path to the image
            target_classes: List of clothing classes to segment (default: ["Upper-clothes"])
            process_res: Processing resolution (default: 512)

        Returns:
            Binary mask as numpy array
        """
        if target_classes is None:
            target_classes = ["Upper-clothes"]

        if not self.load_model():
            print("❌ Cannot load clothes segmentation model")
            return None

        try:
            # Load and preprocess image
            image = cv2.imread(image_path)
            if image is None:
                print(f"❌ Could not load image: {image_path}")
                return None

            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            original_size = image_rgb.shape[:2]

            # Preprocess image with custom resolution
            pil_image = Image.fromarray(image_rgb)

            # Resize for processing if needed
            if process_res != 512:
                pil_image = pil_image.resize((process_res, process_res), Image.Resampling.LANCZOS)

            inputs = self.processor(images=pil_image, return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # Run inference
            with torch.no_grad():
                outputs = self.model(**inputs)
                logits = outputs.logits.cpu()

            # Resize logits to original image size
            upsampled_logits = nn.functional.interpolate(
                logits,
                size=original_size,
                mode="bilinear",
                align_corners=False,
            )
            pred_seg = upsampled_logits.argmax(dim=1)[0]

            # Combine selected class masks
            combined_mask = None
            for class_name in target_classes:
                if class_name in self.class_map:
                    mask = (pred_seg == self.class_map[class_name]).float()
                    if combined_mask is None:
                        combined_mask = mask
                    else:
                        combined_mask = torch.clamp(combined_mask + mask, 0, 1)
                else:
                    print(f"⚠️ Unknown class: {class_name}")

            if combined_mask is None:
                print(f"❌ No valid classes found in: {target_classes}")
                return None

            # Convert to numpy
            mask_np = combined_mask.numpy().astype(np.float32)

            return mask_np

        except Exception as e:
            print(f"❌ Error in clothes segmentation: {e}")
            return None
        finally:
            # Clean up model if not training (consistent with updated repo)
            if self.model is not None and not self.model.training:
                self.clear_model()

    def segment_clothes_with_filters(self, image_path: str, target_classes: list = None,
                                     mask_blur: int = 0, mask_offset: int = 0,
                                     process_res: int = 512) -> np.ndarray:
        """
        Segment clothing with additional filtering options - new method from updated repo.

        Args:
            image_path: Path to the image
            target_classes: List of clothing classes to segment
            mask_blur: Blur amount for mask edges
            mask_offset: Expand/Shrink mask boundary
            process_res: Processing resolution

        Returns:
            Filtered binary mask as numpy array
        """
        # Get initial mask
        mask = self.segment_clothes(image_path, target_classes, process_res)
        if mask is None:
            return None

        try:
            # Convert to PIL for filtering
            mask_image = Image.fromarray((mask * 255).astype(np.uint8))

            # Apply blur if specified
            if mask_blur > 0:
                mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=mask_blur))

            # Apply offset if specified
            if mask_offset != 0:
                if mask_offset > 0:
                    mask_image = mask_image.filter(ImageFilter.MaxFilter(size=mask_offset * 2 + 1))
                else:
                    mask_image = mask_image.filter(ImageFilter.MinFilter(size=-mask_offset * 2 + 1))

            # Convert back to numpy
            filtered_mask = np.array(mask_image).astype(np.float32) / 255.0
            return filtered_mask

        except Exception as e:
            print(f"❌ Error applying filters: {e}")
            return mask


# Standalone function for easy usage
def segment_upper_clothes(image_path: str) -> np.ndarray:
    """
    Convenience function to segment upper clothes from an image.

    Args:
        image_path: Path to the image

    Returns:
        Binary mask as numpy array or None if failed
    """
    segmenter = ClothesSegmentation()
    return segmenter.segment_clothes(image_path, ["Upper-clothes"])
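A minimal usage sketch (not part of the upload) for the clothing segmentation module above; the input path and the output filename are hypothetical. It exercises both the convenience function and the filtered variant, whose `mask_offset` dilates (positive) or erodes (negative) the mask and whose `mask_blur` softens its edges.

```python
import numpy as np
from PIL import Image

# assumes this file is importable as the module `clothes_segmentation`
from clothes_segmentation import ClothesSegmentation, segment_upper_clothes

# Quick path: upper clothes only
mask = segment_upper_clothes("person.jpg")

# Full control: several classes, softened edges and a small dilation
segmenter = ClothesSegmentation()
detailed = segmenter.segment_clothes_with_filters(
    "person.jpg",
    target_classes=["Upper-clothes", "Dress", "Scarf"],
    mask_blur=2,     # Gaussian blur radius applied to the mask edges
    mask_offset=3,   # positive = dilate, negative = erode
)

if detailed is not None:
    Image.fromarray((detailed * 255).astype(np.uint8)).save("clothes_mask.png")
```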
color_matching.py
ADDED
@@ -0,0 +1,698 @@
import torch
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import matplotlib.figure as figure
|
6 |
+
from matplotlib.figure import Figure
|
7 |
+
import numpy.typing as npt
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
import tempfile
|
11 |
+
import time
|
12 |
+
|
13 |
+
class RegionColorMatcher:
|
14 |
+
def __init__(self, factor=1.0, preserve_colors=True, preserve_luminance=True, method="adain"):
|
15 |
+
"""
|
16 |
+
Initialize the RegionColorMatcher.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
factor: Strength of the color matching (0.0 to 1.0)
|
20 |
+
preserve_colors: If True, convert to YUV and preserve color relationships
|
21 |
+
preserve_luminance: If True, preserve the luminance when in YUV mode
|
22 |
+
method: The color matching method to use (adain, mkl, hm, reinhard, mvgd, hm-mvgd-hm, hm-mkl-hm)
|
23 |
+
"""
|
24 |
+
self.factor = factor
|
25 |
+
self.preserve_colors = preserve_colors
|
26 |
+
self.preserve_luminance = preserve_luminance
|
27 |
+
self.method = method
|
28 |
+
|
29 |
+
def match_regions(self, img1_path, img2_path, masks1, masks2):
|
30 |
+
"""
|
31 |
+
Match colors between corresponding masked regions of two images.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
img1_path: Path to first image
|
35 |
+
img2_path: Path to second image
|
36 |
+
masks1: Dictionary of masks for first image {label: binary_mask}
|
37 |
+
masks2: Dictionary of masks for second image {label: binary_mask}
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
A PIL Image with the color-matched result
|
41 |
+
"""
|
42 |
+
print(f"🎨 Color matching with method: {self.method}")
|
43 |
+
print(f"📊 Processing {len(masks1)} regions from img1 and {len(masks2)} regions from img2")
|
44 |
+
|
45 |
+
# Load images
|
46 |
+
img1 = Image.open(img1_path).convert("RGB")
|
47 |
+
img2 = Image.open(img2_path).convert("RGB")
|
48 |
+
|
49 |
+
# Convert to numpy arrays and normalize to [0,1]
|
50 |
+
img1_np = np.array(img1).astype(np.float32) / 255.0
|
51 |
+
img2_np = np.array(img2).astype(np.float32) / 255.0
|
52 |
+
|
53 |
+
# Create a copy of the second image as our base for color matching
|
54 |
+
# We want to make img2 look like img1's colors
|
55 |
+
result_np = np.copy(img2_np)
|
56 |
+
|
57 |
+
# Convert images to PyTorch tensors
|
58 |
+
img1_tensor = torch.from_numpy(img1_np)
|
59 |
+
img2_tensor = torch.from_numpy(img2_np)
|
60 |
+
result_tensor = torch.from_numpy(result_np)
|
61 |
+
|
62 |
+
# Track coverage to ensure all regions are processed
|
63 |
+
total_coverage = np.zeros(img2_np.shape[:2], dtype=np.float32)
|
64 |
+
processed_regions = 0
|
65 |
+
|
66 |
+
# Process each mask region
|
67 |
+
for label, mask1 in masks1.items():
|
68 |
+
if label not in masks2:
|
69 |
+
print(f"⚠️ Skipping {label} - not found in masks2")
|
70 |
+
continue
|
71 |
+
|
72 |
+
mask2 = masks2[label]
|
73 |
+
|
74 |
+
# Resize masks to match image dimensions if needed
|
75 |
+
if mask1.shape != img1_np.shape[:2]:
|
76 |
+
mask1 = self._resize_mask(mask1, img1_np.shape[:2])
|
77 |
+
|
78 |
+
if mask2.shape != img2_np.shape[:2]:
|
79 |
+
mask2 = self._resize_mask(mask2, img2_np.shape[:2])
|
80 |
+
|
81 |
+
# Check mask coverage
|
82 |
+
mask1_pixels = np.sum(mask1 > 0)
|
83 |
+
mask2_pixels = np.sum(mask2 > 0)
|
84 |
+
print(f"🔍 Processing {label}: {mask1_pixels} pixels (img1) → {mask2_pixels} pixels (img2)")
|
85 |
+
|
86 |
+
if mask1_pixels == 0 or mask2_pixels == 0:
|
87 |
+
print(f"⚠️ Skipping {label} - no pixels in mask")
|
88 |
+
continue
|
89 |
+
|
90 |
+
# Track coverage
|
91 |
+
total_coverage += (mask2 > 0).astype(np.float32)
|
92 |
+
processed_regions += 1
|
93 |
+
|
94 |
+
# Convert masks to torch tensors
|
95 |
+
mask1_tensor = torch.from_numpy(mask1.astype(np.float32))
|
96 |
+
mask2_tensor = torch.from_numpy(mask2.astype(np.float32))
|
97 |
+
|
98 |
+
# Apply color matching for this region based on selected method
|
99 |
+
if self.method == "adain":
|
100 |
+
result_tensor = self._apply_adain_to_region(
|
101 |
+
img1_tensor,
|
102 |
+
img2_tensor,
|
103 |
+
result_tensor,
|
104 |
+
mask1_tensor,
|
105 |
+
mask2_tensor
|
106 |
+
)
|
107 |
+
else:
|
108 |
+
result_tensor = self._apply_color_matcher_to_region(
|
109 |
+
img1_tensor,
|
110 |
+
img2_tensor,
|
111 |
+
result_tensor,
|
112 |
+
mask1_tensor,
|
113 |
+
mask2_tensor,
|
114 |
+
self.method
|
115 |
+
)
|
116 |
+
|
117 |
+
print(f"✅ Completed color matching for {label}")
|
118 |
+
|
119 |
+
# Debug coverage
|
120 |
+
total_pixels = img2_np.shape[0] * img2_np.shape[1]
|
121 |
+
covered_pixels = np.sum(total_coverage > 0)
|
122 |
+
overlap_pixels = np.sum(total_coverage > 1)
|
123 |
+
|
124 |
+
print(f"📊 Coverage summary:")
|
125 |
+
print(f" Total image pixels: {total_pixels}")
|
126 |
+
print(f" Covered pixels: {covered_pixels} ({100*covered_pixels/total_pixels:.1f}%)")
|
127 |
+
print(f" Overlapping pixels: {overlap_pixels} ({100*overlap_pixels/total_pixels:.1f}%)")
|
128 |
+
print(f" Processed regions: {processed_regions}")
|
129 |
+
|
130 |
+
# Convert back to numpy, scale to [0,255] and convert to uint8
|
131 |
+
result_np = (result_tensor.numpy() * 255.0).astype(np.uint8)
|
132 |
+
|
133 |
+
# Convert to PIL Image
|
134 |
+
result_img = Image.fromarray(result_np)
|
135 |
+
|
136 |
+
return result_img
|
137 |
+
|
138 |
+
def _resize_mask(self, mask, target_shape):
|
139 |
+
"""
|
140 |
+
Resize a mask to match the target shape.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
mask: Binary mask array
|
144 |
+
target_shape: Target shape (height, width)
|
145 |
+
|
146 |
+
Returns:
|
147 |
+
Resized mask array
|
148 |
+
"""
|
149 |
+
# Convert to PIL Image for resizing
|
150 |
+
mask_img = Image.fromarray((mask * 255).astype(np.uint8))
|
151 |
+
|
152 |
+
# Resize to target shape
|
153 |
+
mask_img = mask_img.resize((target_shape[1], target_shape[0]), Image.NEAREST)
|
154 |
+
|
155 |
+
# Convert back to numpy array and normalize to [0,1]
|
156 |
+
resized_mask = np.array(mask_img).astype(np.float32) / 255.0
|
157 |
+
|
158 |
+
return resized_mask
|
159 |
+
|
160 |
+
def _apply_adain_to_region(self, source_img, target_img, result_img, source_mask, target_mask):
|
161 |
+
"""
|
162 |
+
Apply AdaIN to match the statistics of the masked region in source to the target.
|
163 |
+
|
164 |
+
Args:
|
165 |
+
source_img: Source image tensor [H,W,3] (reference for color matching)
|
166 |
+
target_img: Target image tensor [H,W,3] (to be color matched)
|
167 |
+
result_img: Result image tensor to modify [H,W,3]
|
168 |
+
source_mask: Binary mask for source image [H,W]
|
169 |
+
target_mask: Binary mask for target image [H,W]
|
170 |
+
|
171 |
+
Returns:
|
172 |
+
Modified result tensor
|
173 |
+
"""
|
174 |
+
# Ensure masks are binary
|
175 |
+
source_mask_binary = (source_mask > 0.5).float()
|
176 |
+
target_mask_binary = (target_mask > 0.5).float()
|
177 |
+
|
178 |
+
# If preserving colors, convert to YUV
|
179 |
+
if self.preserve_colors:
|
180 |
+
# RGB to YUV conversion matrix
|
181 |
+
rgb_to_yuv = torch.tensor([
|
182 |
+
[0.299, 0.587, 0.114],
|
183 |
+
[-0.14713, -0.28886, 0.436],
|
184 |
+
[0.615, -0.51499, -0.10001]
|
185 |
+
])
|
186 |
+
|
187 |
+
# Convert to YUV
|
188 |
+
source_yuv = torch.matmul(source_img, rgb_to_yuv.T)
|
189 |
+
target_yuv = torch.matmul(target_img, rgb_to_yuv.T)
|
190 |
+
result_yuv = torch.matmul(result_img, rgb_to_yuv.T)
|
191 |
+
|
192 |
+
# Only normalize Y channel if preserving luminance is False
|
193 |
+
channels_to_process = [0] if not self.preserve_luminance else []
|
194 |
+
|
195 |
+
# Always process U and V channels (chroma)
|
196 |
+
channels_to_process.extend([1, 2])
|
197 |
+
|
198 |
+
# Process selected channels
|
199 |
+
for c in channels_to_process:
|
200 |
+
# Apply the color matching only to the masked region in the result
|
201 |
+
result_channel = result_yuv[:,:,c]
|
202 |
+
matched_channel = self._match_channel_statistics(
|
203 |
+
source_yuv[:,:,c],
|
204 |
+
target_yuv[:,:,c],
|
205 |
+
result_channel,
|
206 |
+
source_mask_binary,
|
207 |
+
target_mask_binary
|
208 |
+
)
|
209 |
+
|
210 |
+
# Only update the masked region in the result
|
211 |
+
mask_expanded = target_mask_binary.unsqueeze(-1).expand_as(result_yuv)[:,:,c]
|
212 |
+
result_yuv[:,:,c] = torch.where(
|
213 |
+
mask_expanded > 0.5,
|
214 |
+
matched_channel,
|
215 |
+
result_channel
|
216 |
+
)
|
217 |
+
|
218 |
+
# Convert back to RGB
|
219 |
+
yuv_to_rgb = torch.tensor([
|
220 |
+
[1.0, 0.0, 1.13983],
|
221 |
+
[1.0, -0.39465, -0.58060],
|
222 |
+
[1.0, 2.03211, 0.0]
|
223 |
+
])
|
224 |
+
|
225 |
+
result_rgb = torch.matmul(result_yuv, yuv_to_rgb.T)
|
226 |
+
|
227 |
+
# Only update the masked region in the result
|
228 |
+
mask_expanded = target_mask_binary.unsqueeze(-1).expand_as(result_img)
|
229 |
+
result_img = torch.where(
|
230 |
+
mask_expanded > 0.5,
|
231 |
+
result_rgb,
|
232 |
+
result_img
|
233 |
+
)
|
234 |
+
|
235 |
+
else:
|
236 |
+
# Process each RGB channel separately
|
237 |
+
for c in range(3):
|
238 |
+
# Apply the color matching only to the masked region in the result
|
239 |
+
result_channel = result_img[:,:,c]
|
240 |
+
matched_channel = self._match_channel_statistics(
|
241 |
+
source_img[:,:,c],
|
242 |
+
target_img[:,:,c],
|
243 |
+
result_channel,
|
244 |
+
source_mask_binary,
|
245 |
+
target_mask_binary
|
246 |
+
)
|
247 |
+
|
248 |
+
# Only update the masked region in the result
|
249 |
+
mask_expanded = target_mask_binary.unsqueeze(-1).expand_as(result_img)[:,:,c]
|
250 |
+
result_img[:,:,c] = torch.where(
|
251 |
+
mask_expanded > 0.5,
|
252 |
+
matched_channel,
|
253 |
+
result_channel
|
254 |
+
)
|
255 |
+
|
256 |
+
# Ensure values are in valid range [0, 1]
|
257 |
+
return torch.clamp(result_img, 0.0, 1.0)
|
258 |
+
|
259 |
+
def _apply_color_matcher_to_region(self, source_img, target_img, result_img, source_mask, target_mask, method):
|
260 |
+
"""
|
261 |
+
Apply color-matcher library methods to match the statistics of the masked region in source to the target.
|
262 |
+
|
263 |
+
Args:
|
264 |
+
source_img: Source image tensor [H,W,3] (reference for color matching)
|
265 |
+
target_img: Target image tensor [H,W,3] (to be color matched)
|
266 |
+
result_img: Result image tensor to modify [H,W,3]
|
267 |
+
source_mask: Binary mask for source image [H,W]
|
268 |
+
target_mask: Binary mask for target image [H,W]
|
269 |
+
method: The color matching method to use (mkl, hm, reinhard, mvgd, hm-mvgd-hm, hm-mkl-hm)
|
270 |
+
|
271 |
+
Returns:
|
272 |
+
Modified result tensor
|
273 |
+
"""
|
274 |
+
# Ensure masks are binary
|
275 |
+
source_mask_binary = (source_mask > 0.5).float()
|
276 |
+
target_mask_binary = (target_mask > 0.5).float()
|
277 |
+
|
278 |
+
# Convert tensors to numpy arrays
|
279 |
+
source_np = source_img.detach().cpu().numpy()
|
280 |
+
target_np = target_img.detach().cpu().numpy()
|
281 |
+
source_mask_np = source_mask_binary.detach().cpu().numpy()
|
282 |
+
target_mask_np = target_mask_binary.detach().cpu().numpy()
|
283 |
+
|
284 |
+
try:
|
285 |
+
# Try to import the color_matcher library
|
286 |
+
try:
|
287 |
+
from color_matcher import ColorMatcher
|
288 |
+
from color_matcher.normalizer import Normalizer
|
289 |
+
except ImportError:
|
290 |
+
self._install_package("color-matcher")
|
291 |
+
from color_matcher import ColorMatcher
|
292 |
+
from color_matcher.normalizer import Normalizer
|
293 |
+
|
294 |
+
# Extract only the masked pixels from both images
|
295 |
+
source_coords = np.where(source_mask_np > 0.5)
|
296 |
+
target_coords = np.where(target_mask_np > 0.5)
|
297 |
+
|
298 |
+
if len(source_coords[0]) == 0 or len(target_coords[0]) == 0:
|
299 |
+
return result_img
|
300 |
+
|
301 |
+
# Extract pixel values from masked regions
|
302 |
+
source_pixels = source_np[source_coords]
|
303 |
+
target_pixels = target_np[target_coords]
|
304 |
+
|
305 |
+
# Initialize color matcher
|
306 |
+
cm = ColorMatcher()
|
307 |
+
|
308 |
+
if method == "mkl":
|
309 |
+
# For MKL, calculate mean and standard deviation from masked regions
|
310 |
+
source_mean = np.mean(source_pixels, axis=0)
|
311 |
+
source_std = np.std(source_pixels, axis=0)
|
312 |
+
target_mean = np.mean(target_pixels, axis=0)
|
313 |
+
target_std = np.std(target_pixels, axis=0)
|
314 |
+
|
315 |
+
# Apply the transformation
|
316 |
+
result_np = np.copy(target_np)
|
317 |
+
for c in range(3):
|
318 |
+
# Normalize the target channel and scale by source statistics
|
319 |
+
normalized = (target_np[:,:,c] - target_mean[c]) / (target_std[c] + 1e-8) * source_std[c] + source_mean[c]
|
320 |
+
|
321 |
+
# Only apply to masked region
|
322 |
+
result_np[:,:,c] = np.where(target_mask_np > 0.5, normalized, target_np[:,:,c])
|
323 |
+
|
324 |
+
# Convert back to tensor
|
325 |
+
result_tensor = torch.from_numpy(result_np).to(result_img.device)
|
326 |
+
|
327 |
+
# Blend with original based on factor
|
328 |
+
result_img = torch.lerp(result_img, result_tensor, self.factor)
|
329 |
+
|
330 |
+
elif method == "reinhard":
|
331 |
+
# Similar to MKL but with a different normalization approach
|
332 |
+
source_mean = np.mean(source_pixels, axis=0)
|
333 |
+
source_std = np.std(source_pixels, axis=0)
|
334 |
+
target_mean = np.mean(target_pixels, axis=0)
|
335 |
+
target_std = np.std(target_pixels, axis=0)
|
336 |
+
|
337 |
+
# Apply the transformation
|
338 |
+
result_np = np.copy(target_np)
|
339 |
+
for c in range(3):
|
340 |
+
# Normalize the target channel and scale by source statistics
|
341 |
+
normalized = (target_np[:,:,c] - target_mean[c]) / (target_std[c] + 1e-8) * source_std[c] + source_mean[c]
|
342 |
+
|
343 |
+
# Only apply to masked region
|
344 |
+
result_np[:,:,c] = np.where(target_mask_np > 0.5, normalized, target_np[:,:,c])
|
345 |
+
|
346 |
+
# Convert back to tensor
|
347 |
+
result_tensor = torch.from_numpy(result_np).to(result_img.device)
|
348 |
+
|
349 |
+
# Blend with original based on factor
|
350 |
+
result_img = torch.lerp(result_img, result_tensor, self.factor)
|
351 |
+
|
352 |
+
elif method == "mvgd":
|
353 |
+
# For MVGD, we need mean and covariance matrices
|
354 |
+
source_mean = np.mean(source_pixels, axis=0)
|
355 |
+
source_cov = np.cov(source_pixels, rowvar=False)
|
356 |
+
target_mean = np.mean(target_pixels, axis=0)
|
357 |
+
target_cov = np.cov(target_pixels, rowvar=False)
|
358 |
+
|
359 |
+
# Check if covariance matrices are valid
|
360 |
+
if np.isnan(source_cov).any() or np.isnan(target_cov).any():
|
361 |
+
# Fallback to simple statistics matching
|
362 |
+
source_std = np.std(source_pixels, axis=0)
|
363 |
+
target_std = np.std(target_pixels, axis=0)
|
364 |
+
|
365 |
+
result_np = np.copy(target_np)
|
366 |
+
for c in range(3):
|
367 |
+
normalized = (target_np[:,:,c] - target_mean[c]) / (target_std[c] + 1e-8) * source_std[c] + source_mean[c]
|
368 |
+
result_np[:,:,c] = np.where(target_mask_np > 0.5, normalized, target_np[:,:,c])
|
369 |
+
else:
|
370 |
+
# Apply full MVGD transformation to masked pixels
|
371 |
+
# Reshape the masked pixels for matrix operations
|
372 |
+
target_flat = target_np.reshape(-1, 3)
|
373 |
+
result_np = np.copy(target_np)
|
374 |
+
|
375 |
+
try:
|
376 |
+
# Try to compute the full MVGD transformation
|
377 |
+
source_cov_sqrt = np.linalg.cholesky(source_cov)
|
378 |
+
target_cov_sqrt = np.linalg.cholesky(target_cov)
|
379 |
+
target_cov_sqrt_inv = np.linalg.inv(target_cov_sqrt)
|
380 |
+
|
381 |
+
# Compute the transformation matrix
|
382 |
+
temp = target_cov_sqrt_inv @ source_cov @ target_cov_sqrt_inv.T
|
383 |
+
temp_sqrt_inv = np.linalg.inv(np.linalg.cholesky(temp))
|
384 |
+
A = target_cov_sqrt @ temp_sqrt_inv @ target_cov_sqrt_inv
|
385 |
+
|
386 |
+
# Apply the transformation to all pixels
|
387 |
+
for i in range(target_np.shape[0]):
|
388 |
+
for j in range(target_np.shape[1]):
|
389 |
+
if target_mask_np[i, j] > 0.5:
|
390 |
+
# Only apply to masked pixels
|
391 |
+
pixel = target_np[i, j]
|
392 |
+
centered = pixel - target_mean
|
393 |
+
transformed = centered @ A.T + source_mean
|
394 |
+
result_np[i, j] = transformed
|
395 |
+
except np.linalg.LinAlgError:
|
396 |
+
# Fallback to simple statistics matching
|
397 |
+
source_std = np.std(source_pixels, axis=0)
|
398 |
+
target_std = np.std(target_pixels, axis=0)
|
399 |
+
|
400 |
+
for c in range(3):
|
401 |
+
normalized = (target_np[:,:,c] - target_mean[c]) / (target_std[c] + 1e-8) * source_std[c] + source_mean[c]
|
402 |
+
result_np[:,:,c] = np.where(target_mask_np > 0.5, normalized, target_np[:,:,c])
|
403 |
+
|
404 |
+
# Convert back to tensor
|
405 |
+
result_tensor = torch.from_numpy(result_np).to(result_img.device)
|
406 |
+
|
407 |
+
# Blend with original based on factor
|
408 |
+
result_img = torch.lerp(result_img, result_tensor, self.factor)
|
409 |
+
|
410 |
+
elif method in ["hm", "hm-mvgd-hm", "hm-mkl-hm"]:
|
411 |
+
# For histogram-based methods, we'll create temporary cropped images with just the masked regions
|
412 |
+
|
413 |
+
# Get the bounding box of the masked regions
|
414 |
+
source_min_y, source_min_x = np.min(source_coords[0]), np.min(source_coords[1])
|
415 |
+
source_max_y, source_max_x = np.max(source_coords[0]), np.max(source_coords[1])
|
416 |
+
target_min_y, target_min_x = np.min(target_coords[0]), np.min(target_coords[1])
|
417 |
+
target_max_y, target_max_x = np.max(target_coords[0]), np.max(target_coords[1])
|
418 |
+
|
419 |
+
# Create cropped images with just the masked regions
|
420 |
+
source_crop = source_np[source_min_y:source_max_y+1, source_min_x:source_max_x+1].copy()
|
421 |
+
target_crop = target_np[target_min_y:target_max_y+1, target_min_x:target_max_x+1].copy()
|
422 |
+
|
423 |
+
# Create cropped masks
|
424 |
+
source_mask_crop = source_mask_np[source_min_y:source_max_y+1, source_min_x:source_max_x+1]
|
425 |
+
target_mask_crop = target_mask_np[target_min_y:target_max_y+1, target_min_x:target_max_x+1]
|
426 |
+
|
427 |
+
# Apply the mask to the cropped images
|
428 |
+
# For non-masked areas, use the average color
|
429 |
+
source_avg_color = np.mean(source_pixels, axis=0)
|
430 |
+
target_avg_color = np.mean(target_pixels, axis=0)
|
431 |
+
|
432 |
+
for c in range(3):
|
433 |
+
source_crop[:, :, c] = np.where(source_mask_crop > 0.5, source_crop[:, :, c], source_avg_color[c])
|
434 |
+
target_crop[:, :, c] = np.where(target_mask_crop > 0.5, target_crop[:, :, c], target_avg_color[c])
|
435 |
+
|
436 |
+
try:
|
437 |
+
# Use the color matcher directly on the masked regions
|
438 |
+
matched_crop = cm.transfer(src=target_crop, ref=source_crop, method=method)
|
439 |
+
|
440 |
+
# Apply the matched colors back to the original image, only in the masked region
|
441 |
+
result_np = np.copy(target_np)
|
442 |
+
|
443 |
+
# Create a mapping from crop coordinates to original image coordinates
|
444 |
+
for i in range(target_crop.shape[0]):
|
445 |
+
for j in range(target_crop.shape[1]):
|
446 |
+
orig_i = target_min_y + i
|
447 |
+
orig_j = target_min_x + j
|
448 |
+
if orig_i < target_np.shape[0] and orig_j < target_np.shape[1] and target_mask_np[orig_i, orig_j] > 0.5:
|
449 |
+
result_np[orig_i, orig_j] = matched_crop[i, j]
|
450 |
+
|
451 |
+
# Convert back to tensor
|
452 |
+
result_tensor = torch.from_numpy(result_np).to(result_img.device)
|
453 |
+
|
454 |
+
# Blend with original based on factor
|
455 |
+
result_img = torch.lerp(result_img, result_tensor, self.factor)
|
456 |
+
|
457 |
+
except Exception as e:
|
458 |
+
# Fallback to AdaIN if color matcher fails
|
459 |
+
print(f"Color matcher failed for {method}, using fallback: {str(e)}")
|
460 |
+
result_img = self._apply_adain_to_region(
|
461 |
+
source_img,
|
462 |
+
target_img,
|
463 |
+
result_img,
|
464 |
+
source_mask_binary,
|
465 |
+
target_mask_binary
|
466 |
+
)
|
467 |
+
|
468 |
+
elif method == "coral":
|
469 |
+
# For CORAL method, extract masked regions and apply CORAL color transfer
|
470 |
+
try:
|
471 |
+
# Create masked versions of the images
|
472 |
+
source_masked = source_np.copy()
|
473 |
+
target_masked = target_np.copy()
|
474 |
+
|
475 |
+
# Apply masks - set non-masked areas to average color
|
476 |
+
source_avg_color = np.mean(source_pixels, axis=0)
|
477 |
+
target_avg_color = np.mean(target_pixels, axis=0)
|
478 |
+
|
479 |
+
for c in range(3):
|
480 |
+
source_masked[:, :, c] = np.where(source_mask_np > 0.5, source_masked[:, :, c], source_avg_color[c])
|
481 |
+
target_masked[:, :, c] = np.where(target_mask_np > 0.5, target_masked[:, :, c], target_avg_color[c])
|
482 |
+
|
483 |
+
# Convert to torch tensors and rearrange to [C, H, W]
|
484 |
+
source_tensor = torch.from_numpy(source_masked).permute(2, 0, 1).float()
|
485 |
+
target_tensor = torch.from_numpy(target_masked).permute(2, 0, 1).float()
|
486 |
+
|
487 |
+
# Apply CORAL color transfer
|
488 |
+
matched_tensor = coral(target_tensor, source_tensor) # target gets matched to source
|
489 |
+
|
490 |
+
# Convert back to [H, W, C] format
|
491 |
+
matched_np = matched_tensor.permute(1, 2, 0).numpy()
|
492 |
+
|
493 |
+
# Apply the matched colors back to the original image, only in the masked region
|
494 |
+
result_np = np.copy(target_np)
|
495 |
+
for c in range(3):
|
496 |
+
result_np[:, :, c] = np.where(target_mask_np > 0.5, matched_np[:, :, c], target_np[:, :, c])
|
497 |
+
|
498 |
+
# Convert back to tensor
|
499 |
+
result_tensor = torch.from_numpy(result_np).to(result_img.device)
|
500 |
+
|
501 |
+
# Blend with original based on factor
|
502 |
+
result_img = torch.lerp(result_img, result_tensor, self.factor)
|
503 |
+
|
504 |
+
except Exception as e:
|
505 |
+
# Fallback to AdaIN if CORAL fails
|
506 |
+
print(f"CORAL failed for {method}, using fallback: {str(e)}")
|
507 |
+
result_img = self._apply_adain_to_region(
|
508 |
+
source_img,
|
509 |
+
target_img,
|
510 |
+
result_img,
|
511 |
+
source_mask_binary,
|
512 |
+
target_mask_binary
|
513 |
+
)
|
514 |
+
else:
|
515 |
+
# Default to AdaIN for unsupported methods
|
516 |
+
result_img = self._apply_adain_to_region(
|
517 |
+
source_img,
|
518 |
+
target_img,
|
519 |
+
result_img,
|
520 |
+
source_mask_binary,
|
521 |
+
target_mask_binary
|
522 |
+
)
|
523 |
+
|
524 |
+
except Exception as e:
|
525 |
+
# If all fails, fallback to AdaIN
|
526 |
+
print(f"Error in color matching: {str(e)}, using AdaIN as fallback")
|
527 |
+
result_img = self._apply_adain_to_region(
|
528 |
+
source_img,
|
529 |
+
target_img,
|
530 |
+
result_img,
|
531 |
+
source_mask_binary,
|
532 |
+
target_mask_binary
|
533 |
+
)
|
534 |
+
|
535 |
+
return torch.clamp(result_img, 0.0, 1.0)
|
536 |
+
|
537 |
+
def _match_channel_statistics(self, source_channel, target_channel, result_channel, source_mask, target_mask):
|
538 |
+
"""
|
539 |
+
Match the statistics of a single channel.
|
540 |
+
|
541 |
+
Args:
|
542 |
+
source_channel: Source channel [H,W] (reference for color matching)
|
543 |
+
target_channel: Target channel [H,W] (to be color matched)
|
544 |
+
result_channel: Result channel to modify [H,W]
|
545 |
+
source_mask: Binary mask for source [H,W]
|
546 |
+
target_mask: Binary mask for target [H,W]
|
547 |
+
|
548 |
+
Returns:
|
549 |
+
Modified result channel
|
550 |
+
"""
|
551 |
+
# Count non-zero elements in masks
|
552 |
+
source_count = torch.sum(source_mask)
|
553 |
+
target_count = torch.sum(target_mask)
|
554 |
+
|
555 |
+
if source_count > 0 and target_count > 0:
|
556 |
+
# Calculate statistics only from masked regions
|
557 |
+
source_masked = source_channel * source_mask
|
558 |
+
target_masked = target_channel * target_mask
|
559 |
+
|
560 |
+
# Calculate mean
|
561 |
+
source_mean = torch.sum(source_masked) / source_count
|
562 |
+
target_mean = torch.sum(target_masked) / target_count
|
563 |
+
|
564 |
+
# Calculate variance
|
565 |
+
            source_var = torch.sum(((source_channel - source_mean) * source_mask) ** 2) / source_count
            target_var = torch.sum(((target_channel - target_mean) * target_mask) ** 2) / target_count

            # Calculate std (add small epsilon to avoid division by zero)
            source_std = torch.sqrt(source_var + 1e-8)
            target_std = torch.sqrt(target_var + 1e-8)

            # Apply AdaIN to the masked region
            normalized = ((target_channel - target_mean) / target_std) * source_std + source_mean

            # Blend with original based on factor
            result = torch.lerp(target_channel, normalized, self.factor)

            return result

        return result_channel

    def _install_package(self, package_name):
        """Install a package using pip."""
        import subprocess
        subprocess.check_call([sys.executable, "-m", "pip", "install", package_name])


def create_comparison_figure(original_img, matched_img, title="Color Matching Comparison"):
    """
    Create a matplotlib figure with the original and color-matched images.

    Args:
        original_img: Original PIL Image
        matched_img: Color-matched PIL Image
        title: Title for the figure

    Returns:
        matplotlib Figure
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

    ax1.imshow(original_img)
    ax1.set_title("Original")
    ax1.axis('off')

    ax2.imshow(matched_img)
    ax2.set_title("Color Matched")
    ax2.axis('off')

    plt.suptitle(title)
    plt.tight_layout()

    return fig


def coral(source, target):
    """
    CORAL (CORrelation ALignment) color transfer implementation.
    Based on the original ColorMatchImage approach.

    Args:
        source: Source image tensor [C, H, W] (to be color matched)
        target: Target image tensor [C, H, W] (reference for color matching)

    Returns:
        Color-matched source image tensor [C, H, W]
    """
    # Ensure tensors are float
    source = source.float()
    target = target.float()

    # Reshape to [C, N] where N is number of pixels
    C, H, W = source.shape
    source_flat = source.view(C, -1)  # [C, H*W]
    target_flat = target.view(C, -1)  # [C, H*W]

    # Compute means
    source_mean = torch.mean(source_flat, dim=1, keepdim=True)  # [C, 1]
    target_mean = torch.mean(target_flat, dim=1, keepdim=True)  # [C, 1]

    # Center the data
    source_centered = source_flat - source_mean  # [C, H*W]
    target_centered = target_flat - target_mean  # [C, H*W]

    # Compute covariance matrices
    N = source_centered.shape[1]
    source_cov = torch.mm(source_centered, source_centered.t()) / (N - 1)  # [C, C]
    target_cov = torch.mm(target_centered, target_centered.t()) / (N - 1)  # [C, C]

    # Add small epsilon to diagonal for numerical stability
    eps = 1e-5
    source_cov += eps * torch.eye(C, device=source.device)
    target_cov += eps * torch.eye(C, device=source.device)

    try:
        # Compute the transformation matrix using Cholesky decomposition
        # This is more stable than eigendecomposition for positive definite matrices

        # Cholesky decomposition: A = L * L^T
        source_chol = torch.linalg.cholesky(source_cov)  # Lower triangular
        target_chol = torch.linalg.cholesky(target_cov)  # Lower triangular

        # Compute the transformation matrix
        # We want to transform source covariance to target covariance
        # Transform = target_chol * source_chol^(-1)
        source_chol_inv = torch.linalg.inv(source_chol)
        transform_matrix = torch.mm(target_chol, source_chol_inv)

        # Apply transformation: result = transform_matrix * (source - source_mean) + target_mean
        result_centered = torch.mm(transform_matrix, source_centered)
        result_flat = result_centered + target_mean

        # Reshape back to original shape
        result = result_flat.view(C, H, W)

        # Clamp to valid range
        result = torch.clamp(result, 0.0, 1.0)

        return result

    except Exception as e:
        # Fallback to simple mean/std matching if Cholesky fails
        print(f"CORAL Cholesky failed, using simple statistics matching: {e}")

        # Simple per-channel statistics matching
        source_std = torch.std(source_centered, dim=1, keepdim=True)
        target_std = torch.std(target_centered, dim=1, keepdim=True)

        # Avoid division by zero
        source_std = torch.clamp(source_std, min=eps)

        # Apply simple transformation: (source - source_mean) / source_std * target_std + target_mean
        result_flat = (source_centered / source_std) * target_std + target_mean
        result = result_flat.view(C, H, W)

        # Clamp to valid range
        result = torch.clamp(result, 0.0, 1.0)

        return result
core.py
ADDED
@@ -0,0 +1,356 @@
"""Core part of LaDeco v2

Example usage:
>>> from core import Ladeco
>>> from PIL import Image
>>> from pathlib import Path
>>>
>>> # predict
>>> ldc = Ladeco()
>>> imgs = (thing for thing in Path("example").glob("*.jpg"))
>>> out = ldc.predict(imgs)
>>>
>>> # output - visualization
>>> segs = out.visualize(level=2)
>>> segs[0].image.show()
>>>
>>> # output - element area
>>> area = out.area()
>>> area[0]
{"fid": "example/.jpg", "l1_nature": 0.673, "l1_man_made": 0.241, ...}
"""
from matplotlib.patches import Rectangle
from pathlib import Path
from PIL import Image
from transformers import AutoModelForUniversalSegmentation, AutoProcessor
import math
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
from functools import lru_cache
from matplotlib.figure import Figure
import numpy.typing as npt
from typing import Iterable, NamedTuple, Generator
from tqdm import tqdm


class LadecoVisualization(NamedTuple):
    filename: str
    image: Figure


class Ladeco:

    def __init__(self,
        model_name: str = "shi-labs/oneformer_ade20k_swin_large",
        area_threshold: float = 0.01,
        device: str | None = None,
    ):
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = AutoModelForUniversalSegmentation.from_pretrained(model_name).to(self.device)

        self.area_threshold = area_threshold

        self.ade20k_labels = {
            name.strip(): int(idx)
            for name, idx in self.model.config.label2id.items()
        }
        self.ladeco2ade20k: dict[str, tuple[int]] = _get_ladeco_labels(self.ade20k_labels)

    def predict(
        self, image_paths: str | Path | Iterable[str | Path], show_progress: bool = False
    ) -> "LadecoOutput":
        if isinstance(image_paths, (str, Path)):
            imgpaths = [image_paths]
        else:
            imgpaths = list(image_paths)

        images = (
            Image.open(img_path).convert("RGB")
            for img_path in imgpaths
        )

        # batch inference functionality of OneFormer is broken
        masks: list[torch.Tensor] = []
        for img in tqdm(images, total=len(imgpaths), desc="Segmenting", disable=not show_progress):
            samples = self.processor(
                images=img, task_inputs=["semantic"], return_tensors="pt"
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**samples)

            masks.append(
                self.processor.post_process_semantic_segmentation(outputs)[0]
            )

        return LadecoOutput(imgpaths, masks, self.ladeco2ade20k, self.area_threshold)


class LadecoOutput:

    def __init__(
        self,
        filenames: list[str | Path],
        masks: torch.Tensor,
        ladeco2ade: dict[str, tuple[int]],
        threshold: float,
    ):
        self.filenames = filenames
        self.masks = masks
        self.ladeco2ade: dict[str, tuple[int]] = ladeco2ade
        self.ade2ladeco: dict[int, str] = {
            idx: label
            for label, indices in self.ladeco2ade.items()
            for idx in indices
        }
        self.threshold = threshold

    def visualize(self, level: int) -> list[LadecoVisualization]:
        return list(self.ivisualize(level))

    def ivisualize(self, level: int) -> Generator[LadecoVisualization, None, None]:
        colormaps = self.color_map(level)
        labelnames = [name for name in self.ladeco2ade if name.startswith(f"l{level}")]

        for fname, mask in zip(self.filenames, self.masks):
            size = mask.shape + (3,)  # (H, W, RGB)
            vis = torch.zeros(size, dtype=torch.uint8)
            for name in labelnames:
                for idx in self.ladeco2ade[name]:
                    color = torch.tensor(colormaps[name] * 255, dtype=torch.uint8)
                    vis[mask == idx] = color

            with Image.open(fname) as img:
                target_size = img.size
            vis = Image.fromarray(vis.numpy(), mode="RGB").resize(target_size)

            fig, ax = plt.subplots()
            ax.imshow(vis)
            ax.axis('off')

            yield LadecoVisualization(filename=str(fname), image=fig)

    def area(self) -> list[dict[str, float | str]]:
        return list(self.iarea())

    def iarea(self) -> Generator[dict[str, float | str], None, None]:
        n_label_ADE20k = 150
        for filename, mask in zip(self.filenames, self.masks):
            ade_ratios = torch.tensor([(mask == i).count_nonzero() / mask.numel() for i in range(n_label_ADE20k)])
            #breakpoint()
            ldc_ratios: dict[str, float] = {
                label: round(ade_ratios[list(ade_indices)].sum().item(), 4)
                for label, ade_indices in self.ladeco2ade.items()
            }
            ldc_ratios: dict[str, float] = {
                label: 0 if ratio < self.threshold else ratio
                for label, ratio in ldc_ratios.items()
            }
            others = round(1 - ldc_ratios["l1_nature"] - ldc_ratios["l1_man_made"], 4)
            nfi = round(ldc_ratios["l1_nature"] / (ldc_ratios["l1_nature"] + ldc_ratios.get("l1_man_made", 0) + 1e-6), 4)

            yield {
                "fid": str(filename), **ldc_ratios, "others": others, "LC_NFI": nfi,
            }

    def color_map(self, level: int) -> dict[str, npt.NDArray[np.float64]]:
        "returns {'label_name': (R, G, B), ...}, where (R, G, B) in range [0, 1]"
        labels = [
            name for name in self.ladeco2ade.keys() if name.startswith(f"l{level}")
        ]
        if len(labels) == 0:
            raise RuntimeError(
                f"LaDeco only has 4 levels in 1, 2, 3, 4. You assigned {level}."
            )
        colormap = mpl.colormaps["viridis"].resampled(len(labels)).colors[:, :-1]
        # [:, :-1]: discard alpha channel
        return {name: color for name, color in zip(labels, colormap)}

    def color_legend(self, level: int) -> Figure:
        colors = self.color_map(level)

        match level:
            case 1:
                ncols = 1
            case 2:
                ncols = 1
            case 3:
                ncols = 2
            case 4:
                ncols = 5

        cell_width = 212
        cell_height = 22
        swatch_width = 48
        margin = 12

        nrows = math.ceil(len(colors) / ncols)

        width = cell_width * ncols + 2 * margin
        height = cell_height * nrows + 2 * margin
        dpi = 72

        fig, ax = plt.subplots(figsize=(width / dpi, height / dpi), dpi=dpi)
        fig.subplots_adjust(margin/width, margin/height,
                            (width-margin)/width, (height-margin*2)/height)
        ax.set_xlim(0, cell_width * ncols)
        ax.set_ylim(cell_height * (nrows-0.5), -cell_height/2.)
        ax.yaxis.set_visible(False)
        ax.xaxis.set_visible(False)
        ax.set_axis_off()

        for i, name in enumerate(colors):
            row = i % nrows
            col = i // nrows
            y = row * cell_height

            swatch_start_x = cell_width * col
            text_pos_x = cell_width * col + swatch_width + 7

            ax.text(text_pos_x, y, name, fontsize=14,
                    horizontalalignment='left',
                    verticalalignment='center')

            ax.add_patch(
                Rectangle(xy=(swatch_start_x, y-9), width=swatch_width,
                          height=18, facecolor=colors[name], edgecolor='0.7')
            )

        ax.set_title(f"LaDeco Color Legend - Level {level}")

        return fig


def _get_ladeco_labels(ade20k: dict[str, int]) -> dict[str, tuple[int]]:
    labels = {
        # level 4 labels
        # under l3_architecture
        "l4_hovel": (ade20k["hovel, hut, hutch, shack, shanty"],),
        "l4_building": (ade20k["building"], ade20k["house"]),
        "l4_skyscraper": (ade20k["skyscraper"],),
        "l4_tower": (ade20k["tower"],),
        # under l3_archi_parts
        "l4_step": (ade20k["step, stair"],),
        "l4_canopy": (ade20k["awning, sunshade, sunblind"], ade20k["canopy"]),
        "l4_arcade": (ade20k["arcade machine"],),
        "l4_door": (ade20k["door"],),
        "l4_window": (ade20k["window"],),
        "l4_wall": (ade20k["wall"],),
        # under l3_roadway
        "l4_stairway": (ade20k["stairway, staircase"],),
        "l4_sidewalk": (ade20k["sidewalk, pavement"],),
        "l4_road": (ade20k["road, route"],),
        # under l3_furniture
        "l4_sculpture": (ade20k["sculpture"],),
        "l4_flag": (ade20k["flag"],),
        "l4_can": (ade20k["trash can"],),
        "l4_chair": (ade20k["chair"],),
        "l4_pot": (ade20k["pot"],),
        "l4_booth": (ade20k["booth"],),
        "l4_streetlight": (ade20k["street lamp"],),
        "l4_bench": (ade20k["bench"],),
        "l4_fence": (ade20k["fence"],),
        "l4_table": (ade20k["table"],),
        # under l3_vehicle
        "l4_bike": (ade20k["bicycle"],),
        "l4_motorbike": (ade20k["minibike, motorbike"],),
        "l4_van": (ade20k["van"],),
        "l4_truck": (ade20k["truck"],),
        "l4_bus": (ade20k["bus"],),
        "l4_car": (ade20k["car"],),
        # under l3_sign
        "l4_traffic_sign": (ade20k["traffic light"],),
        "l4_poster": (ade20k["poster, posting, placard, notice, bill, card"],),
        "l4_signboard": (ade20k["signboard, sign"],),
        # under l3_vert_land
        "l4_rock": (ade20k["rock, stone"],),
        "l4_hill": (ade20k["hill"],),
        "l4_mountain": (ade20k["mountain, mount"],),
        # under l3_hori_land
        "l4_ground": (ade20k["earth, ground"], ade20k["land, ground, soil"]),
        "l4_field": (ade20k["field"],),
        "l4_sand": (ade20k["sand"],),
        "l4_dirt": (ade20k["dirt track"],),
        "l4_path": (ade20k["path"],),
        # under l3_flower
        "l4_flower": (ade20k["flower"],),
        # under l3_grass
        "l4_grass": (ade20k["grass"],),
        # under l3_shrub
        "l4_flora": (ade20k["plant"],),
        # under l3_arbor
        "l4_tree": (ade20k["tree"],),
        "l4_palm": (ade20k["palm, palm tree"],),
        # under l3_hori_water
        "l4_lake": (ade20k["lake"],),
        "l4_pool": (ade20k["pool"],),
        "l4_river": (ade20k["river"],),
        "l4_sea": (ade20k["sea"],),
        "l4_water": (ade20k["water"],),
        # under l3_vert_water
        "l4_fountain": (ade20k["fountain"],),
        "l4_waterfall": (ade20k["falls"],),
        # under l3_human
        "l4_person": (ade20k["person"],),
        # under l3_animal
        "l4_animal": (ade20k["animal"],),
        # under l3_sky
        "l4_sky": (ade20k["sky"],),
    }
    labels = labels | {
        # level 3 labels
        # under l2_landform
        "l3_hori_land": labels["l4_ground"] + labels["l4_field"] + labels["l4_sand"] + labels["l4_dirt"] + labels["l4_path"],
        "l3_vert_land": labels["l4_mountain"] + labels["l4_hill"] + labels["l4_rock"],
        # under l2_vegetation
        "l3_woody_plant": labels["l4_tree"] + labels["l4_palm"] + labels["l4_flora"],
        "l3_herb_plant": labels["l4_grass"],
        "l3_flower": labels["l4_flower"],
        # under l2_water
        "l3_hori_water": labels["l4_water"] + labels["l4_sea"] + labels["l4_river"] + labels["l4_pool"] + labels["l4_lake"],
        "l3_vert_water": labels["l4_fountain"] + labels["l4_waterfall"],
        # under l2_bio
        "l3_human": labels["l4_person"],
        "l3_animal": labels["l4_animal"],
        # under l2_sky
        "l3_sky": labels["l4_sky"],
        # under l2_archi
        "l3_architecture": labels["l4_building"] + labels["l4_hovel"] + labels["l4_tower"] + labels["l4_skyscraper"],
        "l3_archi_parts": labels["l4_wall"] + labels["l4_window"] + labels["l4_door"] + labels["l4_arcade"] + labels["l4_canopy"] + labels["l4_step"],
        # under l2_street
        "l3_roadway": labels["l4_road"] + labels["l4_sidewalk"] + labels["l4_stairway"],
        "l3_furniture": labels["l4_table"] + labels["l4_chair"] + labels["l4_fence"] + labels["l4_bench"] + labels["l4_streetlight"] + labels["l4_booth"] + labels["l4_pot"] + labels["l4_can"] + labels["l4_flag"] + labels["l4_sculpture"],
        "l3_vehicle": labels["l4_car"] + labels["l4_bus"] + labels["l4_truck"] + labels["l4_van"] + labels["l4_motorbike"] + labels["l4_bike"],
        "l3_sign": labels["l4_signboard"] + labels["l4_poster"] + labels["l4_traffic_sign"],
    }
    labels = labels | {
        # level 2 labels
        # under l1_nature
        "l2_landform": labels["l3_hori_land"] + labels["l3_vert_land"],
        "l2_vegetation": labels["l3_woody_plant"] + labels["l3_herb_plant"] + labels["l3_flower"],
        "l2_water": labels["l3_hori_water"] + labels["l3_vert_water"],
        "l2_bio": labels["l3_human"] + labels["l3_animal"],
        "l2_sky": labels["l3_sky"],
        # under l1_man_made
        "l2_archi": labels["l3_architecture"] + labels["l3_archi_parts"],
        "l2_street": labels["l3_roadway"] + labels["l3_furniture"] + labels["l3_vehicle"] + labels["l3_sign"],
    }
    labels = labels | {
        # level 1 labels
        "l1_nature": labels["l2_landform"] + labels["l2_vegetation"] + labels["l2_water"] + labels["l2_bio"] + labels["l2_sky"],
        "l1_man_made": labels["l2_archi"] + labels["l2_street"],
    }
    return labels


if __name__ == "__main__":
    ldc = Ladeco()
    image = Path("images") / "canyon_3011_00002354.jpg"
    out = ldc.predict(image)
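A small driver in the same spirit as the module docstring, run against the example images shipped in this upload (illustrative only; the output filenames are made up, and the OneFormer checkpoint is downloaded from the Hub on first use):

```python
from pathlib import Path

from core import Ladeco

ldc = Ladeco()
out = ldc.predict(sorted(Path("examples").glob("*.jpg")), show_progress=True)

# Area ratios per image, including LC_NFI ≈ l1_nature / (l1_nature + l1_man_made)
for row in out.iarea():
    print(row["fid"], row["l1_nature"], row["l1_man_made"], row["LC_NFI"])

# Level-2 segmentations and the matching color legend
for vis in out.ivisualize(level=2):
    vis.image.savefig(Path(vis.filename).stem + "_l2.png")
out.color_legend(level=2).savefig("legend_l2.png")
```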
examples/beach.jpg
ADDED
examples/field.jpg
ADDED
examples/sky.jpg
ADDED
face_comparison.py
ADDED
@@ -0,0 +1,246 @@
import cv2
import numpy as np
from PIL import Image
import tempfile
import os
import subprocess
import sys
import json
from typing import Dict, List, Tuple, Optional
import logging

# Set up logging to suppress DeepFace warnings
logging.getLogger('deepface').setLevel(logging.ERROR)

try:
    from deepface import DeepFace
    DEEPFACE_AVAILABLE = True
except ImportError:
    DEEPFACE_AVAILABLE = False
    print("Warning: DeepFace not available. Face comparison will be disabled.")


def run_deepface_in_subprocess(img1_path: str, img2_path: str) -> dict:
    """
    Run DeepFace verification in a separate process to avoid TensorFlow conflicts.
    """
    script_content = f'''
import sys
import json
from deepface import DeepFace

try:
    result = DeepFace.verify(img1_path="{img1_path}", img2_path="{img2_path}")
    print(json.dumps(result))
except Exception as e:
    print(json.dumps({{"error": str(e)}}))
'''

    try:
        # Write the script to a temporary file
        with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as script_file:
            script_file.write(script_content)
            script_path = script_file.name

        # Run the script in a subprocess
        result = subprocess.run([sys.executable, script_path],
                                capture_output=True, text=True, timeout=30)

        # Clean up the script file
        os.unlink(script_path)

        if result.returncode == 0:
            return json.loads(result.stdout.strip())
        else:
            return {"error": f"Subprocess failed: {result.stderr}"}

    except Exception as e:
        return {"error": str(e)}


class FaceComparison:
    """
    Handles face detection and comparison on full images.
    Only responsible for determining if faces match - does not handle segmentation.
    """

    def __init__(self):
        """
        Initialize face comparison using DeepFace's default verification threshold.
        """
        self.available = DEEPFACE_AVAILABLE
        self.face_match_result = None
        self.comparison_log = []

    def extract_faces(self, image_path: str) -> List[np.ndarray]:
        """
        Extract faces from the full image using DeepFace (exactly like the working script).

        Args:
            image_path: Path to the image

        Returns:
            List of face arrays
        """
        if not self.available:
            return []

        try:
            faces = DeepFace.extract_faces(img_path=image_path, detector_backend='opencv')
            if len(faces) == 0:
                return []
            return [f['face'] for f in faces]

        except Exception as e:
            print(f"Error extracting faces from {image_path}: {str(e)}")
            return []

    def compare_all_faces(self, image1_path: str, image2_path: str) -> Tuple[bool, List[str]]:
        """
        Compare all faces between two images (exactly like the working script).

        Args:
            image1_path: Path to first image
            image2_path: Path to second image

        Returns:
            Tuple of (match_found, log_messages)
        """
        if not self.available:
            return False, ["Face comparison not available - DeepFace not installed"]

        log_messages = []

        try:
            faces1 = self.extract_faces(image1_path)
            faces2 = self.extract_faces(image2_path)

            match_found = False

            log_messages.append(f"Found {len(faces1)} face(s) in Image 1 and {len(faces2)} face(s) in Image 2")

            if len(faces1) == 0 or len(faces2) == 0:
                log_messages.append("❌ No faces found in one or both images")
                return False, log_messages

            for idx1, face1 in enumerate(faces1):
                for idx2, face2 in enumerate(faces2):
                    # Create temporary files instead of permanent ones (exactly like original)
                    with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as temp1, \
                         tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as temp2:

                        # Convert faces to uint8 and save temporarily (exactly like original)
                        face1_uint8 = (face1 * 255).astype(np.uint8)
                        face2_uint8 = (face2 * 255).astype(np.uint8)

                        cv2.imwrite(temp1.name, cv2.cvtColor(face1_uint8, cv2.COLOR_RGB2BGR))
                        cv2.imwrite(temp2.name, cv2.cvtColor(face2_uint8, cv2.COLOR_RGB2BGR))

                        try:
                            # Try subprocess approach first to avoid TensorFlow conflicts
                            result = run_deepface_in_subprocess(temp1.name, temp2.name)

                            if "error" in result:
                                # If subprocess fails, try direct approach
                                result = DeepFace.verify(img1_path=temp1.name, img2_path=temp2.name)

                            similarity = 1 - result['distance']

                            log_messages.append(f"Comparing Face1-{idx1} to Face2-{idx2} | Similarity: {similarity:.3f}")

                            if result['verified']:
                                log_messages.append(f"✅ Match found between Face1-{idx1} and Face2-{idx2}")
                                match_found = True
                            else:
                                log_messages.append(f"❌ No match between Face1-{idx1} and Face2-{idx2}")

                        except Exception as e:
                            log_messages.append(f"❌ Error comparing Face1-{idx1} to Face2-{idx2}: {str(e)}")

                        # Clean up temporary files immediately
                        try:
                            os.unlink(temp1.name)
                            os.unlink(temp2.name)
                        except:
                            pass

            if not match_found:
                log_messages.append("❌ No matching faces found between the two images.")

            return match_found, log_messages

        except Exception as e:
            log_messages.append(f"Error in face comparison: {str(e)}")
            return False, log_messages

    def run_face_comparison(self, img1_path: str, img2_path: str) -> Tuple[bool, List[str]]:
        """
        Run face comparison and store results for later use.

        Args:
            img1_path: Path to first image
            img2_path: Path to second image

        Returns:
            Tuple of (faces_match, log_messages)
        """
        faces_match, log_messages = self.compare_all_faces(img1_path, img2_path)

        # Store results for later filtering
        self.face_match_result = faces_match
        self.comparison_log = log_messages

        return faces_match, log_messages

    def filter_human_regions_by_face_match(self, masks: Dict[str, np.ndarray]) -> Tuple[Dict[str, np.ndarray], List[str]]:
        """
        Filter human regions based on previously computed face comparison results.
        This only includes/excludes human regions - fine-grained segmentation happens elsewhere.

        Args:
            masks: Dictionary of semantic masks

        Returns:
            Tuple of (filtered_masks, log_messages)
        """
        if not self.available:
            return masks, ["Face comparison not available - DeepFace not installed"]

        if self.face_match_result is None:
            return masks, ["No face comparison results available. Run face comparison first."]

        filtered_masks = {}
        log_messages = []

        # Look for human-specific regions (l3_human, not l2_bio which includes animals)
        human_labels = [label for label in masks.keys() if 'l3_human' in label.lower()]
        bio_labels = [label for label in masks.keys() if 'l2_bio' in label.lower()]

        log_messages.append(f"Found human labels: {human_labels}")
        log_messages.append(f"Found bio labels: {bio_labels}")

        # Include all non-human regions regardless of face matching
        for label, mask in masks.items():
            if not any(human_term in label.lower() for human_term in ['l3_human', 'l2_bio']):
                filtered_masks[label] = mask
                log_messages.append(f"✅ Including non-human region: {label}")
            else:
                log_messages.append(f"🔍 Found human/bio region: {label}")

        # Handle human regions based on face matching results
        if self.face_match_result:
            log_messages.append("✅ Faces matched! Including human regions in color matching.")
            # Include human regions since faces matched
            for label in human_labels + bio_labels:
                if label in masks:
                    filtered_masks[label] = masks[label]
                    log_messages.append(f"✅ Including human region (faces matched): {label}")
        else:
            log_messages.append("❌ No face match found. Excluding human regions from color matching.")
            # Don't include human regions since faces didn't match
            for label in human_labels + bio_labels:
                log_messages.append(f"❌ Excluding human region (no face match): {label}")

        log_messages.append(f"📊 Final filtered masks: {list(filtered_masks.keys())}")

        return filtered_masks, log_messages
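To show how `FaceComparison` is meant to gate human regions (a sketch, not part of the commit; the portrait paths are placeholders and the mask dictionary is a dummy standing in for the LaDeco output):

```python
import numpy as np

from face_comparison import FaceComparison

fc = FaceComparison()
# Placeholder portrait paths; with DeepFace installed this checks whether the same person appears in both.
same_person, log = fc.run_face_comparison("person_a.jpg", "person_b.jpg")
print("\n".join(log))

# Gate human/bio regions on the stored verification result; keys follow the LaDeco level naming.
masks = {"l3_human": np.zeros((4, 4)), "l2_sky": np.ones((4, 4))}   # dummy masks for illustration
filtered, filter_log = fc.filter_human_regions_by_face_match(masks)
print(sorted(filtered))   # "l2_sky" is always kept; "l3_human" only survives when faces matched
```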
folder_paths.py
ADDED
@@ -0,0 +1,22 @@
import os

# Simple folder_paths module to replace ComfyUI's folder_paths
current_dir = os.path.dirname(os.path.abspath(__file__))
models_dir = os.path.join(current_dir, "models")

# Model folder mappings
model_folder_paths = {}

def add_model_folder_path(name, path):
    """Add a model folder path."""
    model_folder_paths[name] = path
    os.makedirs(path, exist_ok=True)

def get_full_path(dirname, filename):
    """Get the full path for a model file."""
    if dirname in model_folder_paths:
        return os.path.join(model_folder_paths[dirname], filename)
    return os.path.join(models_dir, dirname, filename)

# Initialize default paths
os.makedirs(models_dir, exist_ok=True)
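Usage of this shim mirrors ComfyUI's `folder_paths` module; a brief sketch (the "RMBG" registration below is only an example):

```python
import folder_paths

# Register a custom location for the clothes-segmentation weights, then resolve files inside it.
folder_paths.add_model_folder_path("RMBG", "./models/RMBG")
print(folder_paths.get_full_path("RMBG", "segformer_clothes/model.safetensors"))

# Unregistered names fall back to <repo>/models/<dirname>/<filename>.
print(folder_paths.get_full_path("onnx", "human-parts/deeplabv3p-resnet50-human.onnx"))
```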
human_parts_segmentation.py
ADDED
@@ -0,0 +1,322 @@
import os
import numpy as np
import torch
from PIL import Image, ImageFilter
import cv2
import requests
from typing import Dict, List, Tuple, Optional
import onnxruntime as ort

# Human parts labels based on CCIHP dataset - consistent with latest repo
HUMAN_PARTS_LABELS = {
    0: ("background", "Background"),
    1: ("hat", "Hat: Hat, helmet, cap, hood, veil, headscarf, part covering the skull and hair of a hood/balaclava, crown…"),
    2: ("hair", "Hair"),
    3: ("glove", "Glove"),
    4: ("glasses", "Sunglasses/Glasses: Sunglasses, eyewear, protective glasses…"),
    5: ("upper_clothes", "UpperClothes: T-shirt, shirt, tank top, sweater under a coat, top of a dress…"),
    6: ("face_mask", "Face Mask: Protective mask, surgical mask, carnival mask, facial part of a balaclava, visor of a helmet…"),
    7: ("coat", "Coat: Coat, jacket worn without anything on it, vest with nothing on it, a sweater with nothing on it…"),
    8: ("socks", "Socks"),
    9: ("pants", "Pants: Pants, shorts, tights, leggings, swimsuit bottoms… (clothing with 2 legs)"),
    10: ("torso-skin", "Torso-skin"),
    11: ("scarf", "Scarf: Scarf, bow tie, tie…"),
    12: ("skirt", "Skirt: Skirt, kilt, bottom of a dress…"),
    13: ("face", "Face"),
    14: ("left-arm", "Left-arm (naked part)"),
    15: ("right-arm", "Right-arm (naked part)"),
    16: ("left-leg", "Left-leg (naked part)"),
    17: ("right-leg", "Right-leg (naked part)"),
    18: ("left-shoe", "Left-shoe"),
    19: ("right-shoe", "Right-shoe"),
    20: ("bag", "Bag: Backpack, shoulder bag, fanny pack… (bag carried on oneself"),
    21: ("", "Others: Jewelry, tags, bibs, belts, ribbons, pins, head decorations, headphones…"),
}

# Model configuration - updated paths consistent with new repos
current_dir = os.path.dirname(os.path.abspath(__file__))
models_dir = os.path.join(current_dir, "models")
models_dir_path = os.path.join(models_dir, "onnx", "human-parts")
model_url = "https://huggingface.co/Metal3d/deeplabv3p-resnet50-human/resolve/main/deeplabv3p-resnet50-human.onnx"
model_name = "deeplabv3p-resnet50-human.onnx"
model_path = os.path.join(models_dir_path, model_name)


def get_class_index(class_name: str) -> int:
    """Return the index of the class name in the model."""
    if class_name == "":
        return -1

    for key, value in HUMAN_PARTS_LABELS.items():
        if value[0] == class_name:
            return key
    return -1


def download_model(model_url: str, model_path: str) -> bool:
    """Download the human parts segmentation model if not present - improved version."""
    if os.path.exists(model_path):
        return True

    try:
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        print(f"Downloading human parts model to {model_path}...")

        response = requests.get(model_url, stream=True)
        response.raise_for_status()

        total_size = int(response.headers.get('content-length', 0))
        downloaded = 0

        with open(model_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
                downloaded += len(chunk)
                if total_size > 0:
                    percent = (downloaded / total_size) * 100
                    print(f"\rDownload progress: {percent:.1f}%", end='', flush=True)

        print("\n✅ Model download completed")
        return True

    except Exception as e:
        print(f"\n❌ Error downloading model: {e}")
        return False


def get_human_parts_mask(image: torch.Tensor, model: ort.InferenceSession, rotation: float = 0, **kwargs) -> Tuple[torch.Tensor, int]:
    """
    Generate human parts mask using the ONNX model - improved version.

    Args:
        image: Input image tensor
        model: ONNX inference session
        rotation: Rotation angle (not used currently)
        **kwargs: Part-specific enable flags

    Returns:
        Tuple of (mask_tensor, score)
    """
    image = image.squeeze(0)
    image_np = image.numpy() * 255

    pil_image = Image.fromarray(image_np.astype(np.uint8))
    original_size = pil_image.size

    # Resize to 512x512 as the model expects
    pil_image = pil_image.resize((512, 512))
    center = (256, 256)

    if rotation != 0:
        pil_image = pil_image.rotate(rotation, center=center)

    # Normalize the image
    image_np = np.array(pil_image).astype(np.float32) / 127.5 - 1
    image_np = np.expand_dims(image_np, axis=0)

    # Use the ONNX model to get the segmentation
    input_name = model.get_inputs()[0].name
    output_name = model.get_outputs()[0].name
    result = model.run([output_name], {input_name: image_np})
    result = np.array(result[0]).argmax(axis=3).squeeze(0)

    # Debug: Check what classes the model actually detected
    unique_classes = np.unique(result)

    score = 0
    mask = np.zeros_like(result)

    # Combine masks for enabled classes
    for class_name, enabled in kwargs.items():
        class_index = get_class_index(class_name)
        if enabled and class_index != -1:
            detected = result == class_index
            mask[detected] = 255
            score += mask.sum()

    # Resize back to original size
    mask_image = Image.fromarray(mask.astype(np.uint8), mode="L")
    if rotation != 0:
        mask_image = mask_image.rotate(-rotation, center=center)

    mask_image = mask_image.resize(original_size)

    # Convert back to numpy - improved tensor handling
    mask = np.array(mask_image).astype(np.float32) / 255.0  # Normalize to 0-1 range

    # Add dimensions for torch tensor - consistent format
    mask = np.expand_dims(mask, axis=0)
    mask = np.expand_dims(mask, axis=0)

    return torch.from_numpy(mask), score


def numpy_to_torch_tensor(image_np: np.ndarray) -> torch.Tensor:
    """Convert numpy array to torch tensor in the format expected by the models."""
    if len(image_np.shape) == 3:
        return torch.from_numpy(image_np.astype(np.float32) / 255.0).unsqueeze(0)
    return torch.from_numpy(image_np.astype(np.float32) / 255.0)


def torch_tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Convert torch tensor back to numpy array - improved version."""
    if len(tensor.shape) == 4:
        tensor = tensor.squeeze(0)

    # Always handle as float32 tensor in 0-1 range then convert to binary
    tensor_np = tensor.numpy()
    if tensor_np.dtype == np.float32 and tensor_np.max() <= 1.0:
        return (tensor_np > 0.5).astype(np.float32)  # Binary threshold
    else:
        return tensor_np


class HumanPartsSegmentation:
    """
    Standalone human parts segmentation for face and hair using DeepLabV3+ ResNet50.
    """

    def __init__(self):
        self.model = None

    def check_model_cache(self):
        """Check if model file exists in cache - consistent with updated repos."""
        if not os.path.exists(model_path):
            return False, "Model file not found"
        return True, "Model cache verified"

    def clear_model(self):
        """Clear model from memory - improved version."""
        if self.model is not None:
            del self.model
            self.model = None

    def load_model(self):
        """Load the human parts segmentation model - improved version."""
        try:
            # Check and download model if needed
            cache_status, message = self.check_model_cache()
            if not cache_status:
                print(f"Cache check: {message}")
                if not download_model(model_url, model_path):
                    return False

            # Load model if needed
            if self.model is None:
                print("Loading human parts segmentation model...")
                self.model = ort.InferenceSession(model_path)
                print("✅ Human parts segmentation model loaded successfully")

            return True

        except Exception as e:
            print(f"❌ Error loading human parts model: {e}")
            self.clear_model()  # Cleanup on error
            return False

    def segment_parts(self, image_path: str, parts: List[str], mask_blur: int = 0, mask_offset: int = 0) -> Dict[str, np.ndarray]:
        """
        Segment specific human parts from an image - improved version with filtering.

        Args:
            image_path: Path to the image file
            parts: List of part names to segment (e.g., ['face', 'hair'])
            mask_blur: Blur amount for mask edges
            mask_offset: Expand/Shrink mask boundary

        Returns:
            Dictionary mapping part names to binary masks
        """
        if not self.load_model():
            print("❌ Cannot load human parts segmentation model")
            return {}

        try:
            # Load image
            image = cv2.imread(image_path)
            if image is None:
                print(f"❌ Could not load image: {image_path}")
                return {}

            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            # Convert to tensor format expected by the model
            image_tensor = numpy_to_torch_tensor(image_rgb)

            # Prepare kwargs for each part
            part_kwargs = {part: True for part in parts}

            # Get segmentation mask
            mask_tensor, score = get_human_parts_mask(image_tensor, self.model, **part_kwargs)

            # Convert back to numpy
            if len(mask_tensor.shape) == 4:
                mask_tensor = mask_tensor.squeeze(0).squeeze(0)
            elif len(mask_tensor.shape) == 3:
                mask_tensor = mask_tensor.squeeze(0)

            # Get the combined mask for all requested parts
            combined_mask = mask_tensor.numpy()

            # Generate individual masks for each part if multiple parts requested
            result_masks = {}
            if len(parts) == 1:
                # Single part - return the combined mask
                part_name = parts[0]
                final_mask = self._apply_filters(combined_mask, mask_blur, mask_offset)
                if np.sum(final_mask > 0) > 0:
                    result_masks[part_name] = final_mask
                else:
                    result_masks[part_name] = final_mask  # Return empty mask instead of None
            else:
                # Multiple parts - need to segment each individually
                for part in parts:
                    single_part_kwargs = {part: True}
                    single_mask_tensor, _ = get_human_parts_mask(image_tensor, self.model, **single_part_kwargs)

                    if len(single_mask_tensor.shape) == 4:
                        single_mask_tensor = single_mask_tensor.squeeze(0).squeeze(0)
                    elif len(single_mask_tensor.shape) == 3:
                        single_mask_tensor = single_mask_tensor.squeeze(0)

                    single_mask = single_mask_tensor.numpy()
                    final_mask = self._apply_filters(single_mask, mask_blur, mask_offset)

                    result_masks[part] = final_mask  # Always add mask, even if empty

            return result_masks

        except Exception as e:
            print(f"❌ Error in human parts segmentation: {e}")
            return {}
        finally:
            # Clean up model if not needed
            self.clear_model()

    def _apply_filters(self, mask: np.ndarray, mask_blur: int = 0, mask_offset: int = 0) -> np.ndarray:
        """Apply filtering to mask - new method from updated repo."""
        if mask_blur == 0 and mask_offset == 0:
            return mask

        try:
            # Convert to PIL for filtering
            mask_image = Image.fromarray((mask * 255).astype(np.uint8))

            # Apply blur if specified
            if mask_blur > 0:
                mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=mask_blur))

            # Apply offset if specified
            if mask_offset != 0:
                if mask_offset > 0:
                    mask_image = mask_image.filter(ImageFilter.MaxFilter(size=mask_offset * 2 + 1))
                else:
                    mask_image = mask_image.filter(ImageFilter.MinFilter(size=-mask_offset * 2 + 1))

            # Convert back to numpy
            filtered_mask = np.array(mask_image).astype(np.float32) / 255.0
            return filtered_mask

        except Exception as e:
            print(f"❌ Error applying filters: {e}")
            return mask
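A minimal sketch of calling the segmenter above (the portrait path is a placeholder; part names come from `HUMAN_PARTS_LABELS`, and the ONNX weights bundled under `models/onnx/human-parts` are downloaded automatically if missing):

```python
from human_parts_segmentation import HumanPartsSegmentation

seg = HumanPartsSegmentation()
masks = seg.segment_parts("portrait.jpg", parts=["face", "hair"], mask_blur=2, mask_offset=1)

for part, mask in masks.items():
    coverage = float(mask.mean()) if mask.size else 0.0   # masks are float arrays in [0, 1]
    print(f"{part}: {coverage:.1%} of pixels")
```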
models/RMBG/segformer_clothes/.cache/huggingface/.gitignore
ADDED
@@ -0,0 +1 @@
*
models/RMBG/segformer_clothes/.cache/huggingface/download/config.json.lock
ADDED
File without changes
models/RMBG/segformer_clothes/.cache/huggingface/download/config.json.metadata
ADDED
@@ -0,0 +1,3 @@
2634bcc40712620e414ffb0efd5f5e4ea732ec5d
8352c4562bb0e1f72767dcb170ad6f3f56007836
1748821507.461211
models/RMBG/segformer_clothes/.cache/huggingface/download/model.safetensors.lock
ADDED
File without changes
models/RMBG/segformer_clothes/.cache/huggingface/download/model.safetensors.metadata
ADDED
@@ -0,0 +1,3 @@
2634bcc40712620e414ffb0efd5f5e4ea732ec5d
f70ae566c5773fb335796ebaa8acc924ac25eb97222c2b2967d44d2fc11568e6
1748821512.848557
models/RMBG/segformer_clothes/.cache/huggingface/download/preprocessor_config.json.lock
ADDED
File without changes
models/RMBG/segformer_clothes/.cache/huggingface/download/preprocessor_config.json.metadata
ADDED
@@ -0,0 +1,3 @@
2634bcc40712620e414ffb0efd5f5e4ea732ec5d
b2340cf4e53b37fda4f5b92d28f11c0f33c3d0fd
1748821513.065366
models/RMBG/segformer_clothes/config.json
ADDED
@@ -0,0 +1,110 @@
{
  "_name_or_path": "nvidia/mit-b3",
  "architectures": [
    "SegformerForSemanticSegmentation"
  ],
  "attention_probs_dropout_prob": 0.0,
  "classifier_dropout_prob": 0.1,
  "decoder_hidden_size": 768,
  "depths": [
    3,
    4,
    18,
    3
  ],
  "downsampling_rates": [
    1,
    4,
    8,
    16
  ],
  "drop_path_rate": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_sizes": [
    64,
    128,
    320,
    512
  ],
  "id2label": {
    "0": "Background",
    "1": "Hat",
    "10": "Right-shoe",
    "11": "Face",
    "12": "Left-leg",
    "13": "Right-leg",
    "14": "Left-arm",
    "15": "Right-arm",
    "16": "Bag",
    "17": "Scarf",
    "2": "Hair",
    "3": "Sunglasses",
    "4": "Upper-clothes",
    "5": "Skirt",
    "6": "Pants",
    "7": "Dress",
    "8": "Belt",
    "9": "Left-shoe"
  },
  "image_size": 224,
  "initializer_range": 0.02,
  "label2id": {
    "Background": "0",
    "Bag": "16",
    "Belt": "8",
    "Dress": "7",
    "Face": "11",
    "Hair": "2",
    "Hat": "1",
    "Left-arm": "14",
    "Left-leg": "12",
    "Left-shoe": "9",
    "Pants": "6",
    "Right-arm": "15",
    "Right-leg": "13",
    "Right-shoe": "10",
    "Scarf": "17",
    "Skirt": "5",
    "Sunglasses": "3",
    "Upper-clothes": "4"
  },
  "layer_norm_eps": 1e-06,
  "mlp_ratios": [
    4,
    4,
    4,
    4
  ],
  "model_type": "segformer",
  "num_attention_heads": [
    1,
    2,
    5,
    8
  ],
  "num_channels": 3,
  "num_encoder_blocks": 4,
  "patch_sizes": [
    7,
    3,
    3,
    3
  ],
  "reshape_last_stage": true,
  "semantic_loss_ignore_index": 255,
  "sr_ratios": [
    8,
    4,
    2,
    1
  ],
  "strides": [
    4,
    2,
    2,
    2
  ],
  "torch_dtype": "float32",
  "transformers_version": "4.38.1"
}
models/RMBG/segformer_clothes/model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f70ae566c5773fb335796ebaa8acc924ac25eb97222c2b2967d44d2fc11568e6
size 189029000
models/RMBG/segformer_clothes/preprocessor_config.json
ADDED
@@ -0,0 +1,23 @@
{
  "do_normalize": true,
  "do_reduce_labels": false,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.485,
    0.456,
    0.406
  ],
  "image_processor_type": "SegformerImageProcessor",
  "image_std": [
    0.229,
    0.224,
    0.225
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 512,
    "width": 512
  }
}
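For reference, a sketch of loading the bundled SegFormer clothes checkpoint with the two config files above (the local path assumes the repository layout of this upload; the snippet itself is illustrative and not part of the commit):

```python
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor

processor = SegformerImageProcessor.from_pretrained("models/RMBG/segformer_clothes")
model = SegformerForSemanticSegmentation.from_pretrained("models/RMBG/segformer_clothes")

print(model.config.id2label[4])   # "Upper-clothes"
```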
models/onnx/human-parts/deeplabv3p-resnet50-human.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a6e823a82da10ba24c29adfb544130684568c46bfac865e215bbace3b4035a71
size 47210581
requirements.txt
ADDED
@@ -0,0 +1,30 @@
# LaDeco requirements
torch==2.3.1
torchaudio
torchvision
tf-keras
transformers==4.42.4
diffusers
opencv-python
Pillow
numpy
matplotlib
scipy
scikit-learn

# For Gradio interface
gradio

# Face comparison
deepface

# Human parts segmentation
onnxruntime

# Clothing segmentation
huggingface-hub>=0.19.0
segment-anything>=1.0

# Color matching dependencies
color-matcher
spaces
spaces.py
ADDED
@@ -0,0 +1,12 @@
import functools
import os

def GPU(func):
    """
    A decorator to indicate that a function should use GPU acceleration if available.
    This is used specifically for Hugging Face Spaces.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper
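The stub keeps Spaces-style code importable outside Hugging Face; note that, because it sits at the repository root, this `spaces.py` takes precedence over the `spaces` package pinned in requirements.txt for imports resolved from the repo. A tiny illustration (the decorated function is hypothetical):

```python
import spaces

@spaces.GPU
def run_inference(x):
    # On Spaces the real decorator requests GPU time; here it is a transparent pass-through.
    return x * 2

print(run_inference(21))   # 42
```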