|
import gradio as gr |
|
from core import Ladeco |
|
from matplotlib.figure import Figure |
|
import matplotlib.pyplot as plt |
|
import matplotlib as mpl |
|
import spaces |
|
from PIL import Image |
|
import numpy as np |
|
from color_matching import RegionColorMatcher, create_comparison_figure |
|
from face_comparison import FaceComparison |
|
from cdl_smoothing import cdl_edge_smoothing, get_smoothing_stats, cdl_edge_smoothing_apply_to_source |
|
import tempfile |
|
import os |
|
import cv2 |
|
|
|
|
|
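# Style Matplotlib output to match the dark Gradio theme defined further down.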
plt.rcParams['figure.facecolor'] = '#0b0f19' |
|
plt.rcParams['text.color'] = '#aab6cc' |
|
ladeco = Ladeco() |
|
|
|
|
|
@spaces.GPU |
|
def infer_two_images(img1: str, img2: str, method: str, enable_face_matching: bool, enable_edge_smoothing: bool) -> tuple[Figure, Figure, Figure, Figure, Figure, Figure, str, str, str]: |
|
""" |
|
Clean 4-step approach: |
|
1. Segment both images identically |
|
2. Determine segment correspondences |
|
3. Match each segment pair in isolation |
|
4. Composite all matched segments |
|
""" |
|
|
|
cdl_display = "" |
|
|
|
|
|
|
|
print("Step 1: Segmenting both images...") |
|
out1 = ladeco.predict(img1) |
|
out2 = ladeco.predict(img2) |
|
|
|
|
|
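# Level-2 segmentation figures, color maps, and per-class area ratios for both images.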
seg1 = out1.visualize(level=2)[0].image |
|
colormap1 = out1.color_map(level=2) |
|
area1 = out1.area()[0] |
|
|
|
seg2 = out2.visualize(level=2)[0].image |
|
colormap2 = out2.color_map(level=2) |
|
area2 = out2.area()[0] |
|
|
|
|
|
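# Keep only non-empty level-2 classes, reusing the segmentation color map so the pie charts match the overlays.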
colors1, l2_area1 = [], {} |
|
for labelname, area_ratio in area1.items(): |
|
if labelname.startswith("l2") and area_ratio > 0: |
|
colors1.append(colormap1[labelname]) |
|
labelname = labelname.replace("l2_", "").capitalize() |
|
l2_area1[labelname] = area_ratio |
|
|
|
colors2, l2_area2 = [], {} |
|
for labelname, area_ratio in area2.items(): |
|
if labelname.startswith("l2") and area_ratio > 0: |
|
colors2.append(colormap2[labelname]) |
|
labelname = labelname.replace("l2_", "").capitalize() |
|
l2_area2[labelname] = area_ratio |
|
|
|
pie1 = plot_pie(l2_area1, colors=colors1) |
|
pie2 = plot_pie(l2_area2, colors=colors2) |
|
|
|
|
|
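# Render all four figures at 256x256 px (96 DPI) so they display consistently in the UI.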
for fig in [seg1, seg2, pie1, pie2]: |
|
fig.set_dpi(96) |
|
fig.set_size_inches(256/96, 256/96) |
|
|
|
|
|
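# Binary masks for every non-empty level-2 class, used for correspondence checks and per-segment matching.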
masks1 = extract_semantic_masks(out1) |
|
masks2 = extract_semantic_masks(out2) |
|
|
|
print(f"Extracted {len(masks1)} masks from img1, {len(masks2)} masks from img2") |
|
|
|
|
|
print("Step 2: Determining segment correspondences...") |
|
face_log = ["Step 2: Determining segment correspondences"] |
|
|
|
|
|
common_segments = set(masks1.keys()).intersection(set(masks2.keys())) |
|
face_log.append(f"Found {len(common_segments)} common segments: {sorted(common_segments)}") |
|
|
|
|
|
segments_to_match = determine_segments_to_match(img1, img2, common_segments, enable_face_matching, face_log) |
|
|
|
face_log.append(f"Final segments to match: {sorted(segments_to_match)}") |
|
|
|
|
|
print("Step 3: Matching each segment pair in isolation...") |
|
face_log.append("\nStep 3: Color matching each segment independently") |
|
|
|
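# Collect each segment's color-matched result and the mask used to composite it back onto image 2.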
matched_regions = {} |
|
segment_masks = {} |
|
|
|
for segment_name in segments_to_match: |
|
if segment_name in masks1 and segment_name in masks2: |
|
face_log.append(f" Processing {segment_name}...") |
|
|
|
|
|
matched_region, final_mask1, final_mask2 = match_single_segment( |
|
img1, img2, |
|
masks1[segment_name], masks2[segment_name], |
|
segment_name, method, face_log |
|
) |
|
|
|
if matched_region is not None: |
|
matched_regions[segment_name] = matched_region |
|
segment_masks[segment_name] = final_mask2 |
|
face_log.append(f" β
{segment_name} matched successfully") |
|
else: |
|
face_log.append(f" β {segment_name} matching failed") |
|
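# Fine-grained (l4_) segments have no entries in masks1/masks2; match_single_segment re-segments them on demand.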
elif segment_name.startswith('l4_'): |
|
|
|
face_log.append(f" Processing fine-grained {segment_name}...") |
|
|
|
matched_region, final_mask1, final_mask2 = match_single_segment( |
|
img1, img2, None, None, segment_name, method, face_log |
|
) |
|
|
|
if matched_region is not None: |
|
matched_regions[segment_name] = matched_region |
|
segment_masks[segment_name] = final_mask2 |
|
face_log.append(f" β
{segment_name} matched successfully") |
|
else: |
|
face_log.append(f" β {segment_name} matching failed") |
|
|
|
|
|
print("Step 4: Compositing all matched segments...") |
|
face_log.append(f"\nStep 4: Compositing {len(matched_regions)} matched segments") |
|
|
|
final_image = composite_matched_segments(img2, matched_regions, segment_masks, face_log) |
|
|
|
|
|
if enable_edge_smoothing: |
|
print("Step 5: Applying CDL-based edge smoothing...") |
|
face_log.append("\nStep 5: CDL edge smoothing - applying CDL transform to image 2 based on composited result") |
|
|
|
try: |
|
|
|
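# Write the composited result to a temporary PNG so the CDL helpers can read it from disk.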
temp_dir = tempfile.gettempdir() |
|
temp_composite_path = os.path.join(temp_dir, "temp_composite_for_cdl.png") |
|
final_image.save(temp_composite_path, "PNG") |
|
|
|
|
|
cdl_stats = get_smoothing_stats(img2, temp_composite_path) |
|
|
|
|
|
slope = cdl_stats['cdl_slope'] |
|
offset = cdl_stats['cdl_offset'] |
|
power = cdl_stats['cdl_power'] |
|
|
|
|
|
cdl_display = f"""π CDL Parameters (Image 2 β Composited Result): |
|
|
|
🔧 Method: Simple Mean/Std Matching (basic statistical approach)
|
|
|
🔸 Slope (Gain):
|
Red: {slope[0]:.6f} |
|
Green: {slope[1]:.6f} |
|
Blue: {slope[2]:.6f} |
|
|
|
🔸 Offset:
|
Red: {offset[0]:.6f} |
|
Green: {offset[1]:.6f} |
|
Blue: {offset[2]:.6f} |
|
|
|
🔸 Power (Gamma):
|
Red: {power[0]:.6f} |
|
Green: {power[1]:.6f} |
|
Blue: {power[2]:.6f} |
|
|
|
These CDL values represent the color transformation needed to convert Image 2 into the composited result. |
|
|
|
The CDL calculation uses the simplest possible approach: it matches the mean and standard
deviation of each color channel between the original and composited images, and estimates
gamma from the brightness relationship between the two.
|
""" |
|
|
|
face_log.append(f"π CDL Parameters (image 2 β composited result):") |
|
face_log.append(f" Method: Simple mean/std matching") |
|
face_log.append(f" Slope (R,G,B): [{slope[0]:.4f}, {slope[1]:.4f}, {slope[2]:.4f}]") |
|
face_log.append(f" Offset (R,G,B): [{offset[0]:.4f}, {offset[1]:.4f}, {offset[2]:.4f}]") |
|
face_log.append(f" Power (R,G,B): [{power[0]:.4f}, {power[1]:.4f}, {power[2]:.4f}]") |
|
|
|
|
|
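# Apply the computed CDL transform to the original image 2 at full strength (factor=1.0),
# replacing the per-segment composite with a globally smoothed result.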
final_image = cdl_edge_smoothing_apply_to_source(img2, temp_composite_path, factor=1.0) |
|
|
|
|
|
if os.path.exists(temp_composite_path): |
|
os.remove(temp_composite_path) |
|
|
|
face_log.append("β
CDL edge smoothing completed - transformed image 2 using calculated CDL parameters") |
|
|
|
except Exception as e: |
|
face_log.append(f"β CDL edge smoothing failed: {e}") |
|
cdl_display = f"β CDL calculation failed: {e}" |
|
else: |
|
face_log.append("\nStep 5: CDL edge smoothing disabled") |
|
cdl_display = "CDL edge smoothing is disabled. Enable it to see CDL parameters." |
|
|
|
|
|
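# Save the final image to a temporary file so Gradio can offer it for download.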
temp_dir = tempfile.gettempdir() |
|
filename = os.path.basename(img2).split('.')[0] |
|
temp_filename = f"color_matched_{method}_{filename}.png" |
|
temp_path = os.path.join(temp_dir, temp_filename) |
|
final_image.save(temp_path, "PNG") |
|
|
|
|
|
|
|
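# Re-derive the masks that were actually matched (including fine-grained ones) for the side-by-side
# visualization; note this re-runs the part/clothes segmenters purely for display purposes.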
vis_masks1 = {} |
|
vis_masks2 = {} |
|
|
|
for segment_name in segments_to_match: |
|
if segment_name in segment_masks: |
|
if segment_name.startswith('l4_'): |
|
|
|
part_name = segment_name.replace('l4_', '') |
|
if part_name in ['face', 'hair']: |
|
from human_parts_segmentation import HumanPartsSegmentation |
|
segmenter = HumanPartsSegmentation() |
|
masks_dict1 = segmenter.segment_parts(img1, [part_name]) |
|
masks_dict2 = segmenter.segment_parts(img2, [part_name]) |
|
if part_name in masks_dict1 and part_name in masks_dict2: |
|
vis_masks1[segment_name] = masks_dict1[part_name] |
|
vis_masks2[segment_name] = masks_dict2[part_name] |
|
elif part_name == 'upper_clothes': |
|
from clothes_segmentation import ClothesSegmentation |
|
segmenter = ClothesSegmentation() |
|
mask1 = segmenter.segment_clothes(img1, ["Upper-clothes"]) |
|
mask2 = segmenter.segment_clothes(img2, ["Upper-clothes"]) |
|
if mask1 is not None and mask2 is not None: |
|
vis_masks1[segment_name] = mask1 |
|
vis_masks2[segment_name] = mask2 |
|
else: |
|
|
|
if segment_name in masks1 and segment_name in masks2: |
|
vis_masks1[segment_name] = masks1[segment_name] |
|
vis_masks2[segment_name] = masks2[segment_name] |
|
|
|
mask_vis = visualize_matching_masks(img1, img2, vis_masks1, vis_masks2) |
|
|
|
comparison = create_comparison_figure(Image.open(img2), final_image, f"Color Matching Result ({method})") |
|
|
|
face_log_text = "\n".join(face_log) |
|
|
|
return seg1, pie1, seg2, pie2, comparison, mask_vis, temp_path, face_log_text, cdl_display |
|
|
|
|
|
def determine_segments_to_match(img1: str, img2: str, common_segments: set, enable_face_matching: bool, log: list) -> set: |
|
""" |
|
Determine which segments should be matched based on face matching logic. |
|
Returns the set of segment names to process. |
|
""" |
|
if not enable_face_matching: |
|
log.append("Face matching disabled - matching all common segments") |
|
return common_segments |
|
|
|
log.append("Face matching enabled - checking faces...") |
|
|
|
|
|
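# Compare the faces in both images; the outcome decides whether human/bio segments take part in matching.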
face_comparator = FaceComparison() |
|
faces_match, face_log = face_comparator.run_face_comparison(img1, img2) |
|
log.extend(face_log) |
|
|
|
if not faces_match: |
|
|
|
log.append("No face match - excluding human/bio segments") |
|
non_human_segments = set() |
|
for segment in common_segments: |
|
if not any(term in segment.lower() for term in ['l3_human', 'l2_bio']): |
|
non_human_segments.add(segment) |
|
else: |
|
log.append(f" Excluding human segment: {segment}") |
|
|
|
log.append(f"Matching {len(non_human_segments)} non-human segments") |
|
return non_human_segments |
|
|
|
else: |
|
|
|
log.append("Faces match - including all segments + fine-grained") |
|
|
|
segments_to_match = common_segments.copy() |
|
|
|
|
|
bio_segments = [s for s in common_segments if 'l2_bio' in s.lower()] |
|
if bio_segments: |
|
fine_grained_segments = add_fine_grained_segments(img1, img2, common_segments, log) |
|
segments_to_match.update(fine_grained_segments) |
|
|
|
return segments_to_match |
|
|
|
|
|
def add_fine_grained_segments(img1: str, img2: str, common_segments: set, log: list) -> set: |
|
""" |
|
Add fine-grained human parts segments when faces match. |
|
Returns set of fine-grained segment names that were successfully added. |
|
""" |
|
fine_grained_segments = set() |
|
|
|
try: |
|
from human_parts_segmentation import HumanPartsSegmentation |
|
from clothes_segmentation import ClothesSegmentation |
|
|
|
log.append(" Adding fine-grained human parts...") |
|
|
|
|
|
human_segmenter = HumanPartsSegmentation() |
|
face_hair_masks1 = human_segmenter.segment_parts(img1, ['face', 'hair']) |
|
face_hair_masks2 = human_segmenter.segment_parts(img2, ['face', 'hair']) |
|
|
|
|
|
clothes_segmenter = ClothesSegmentation() |
|
clothes_mask1 = clothes_segmenter.segment_clothes(img1, ["Upper-clothes"]) |
|
clothes_mask2 = clothes_segmenter.segment_clothes(img2, ["Upper-clothes"]) |
|
|
|
|
|
for part_name, mask1 in face_hair_masks1.items(): |
|
if (mask1 is not None and part_name in face_hair_masks2 and |
|
face_hair_masks2[part_name] is not None): |
|
|
|
if np.sum(mask1 > 0) > 0 and np.sum(face_hair_masks2[part_name] > 0) > 0: |
|
fine_grained_segments.add(f'l4_{part_name}') |
|
log.append(f" Added fine-grained: {part_name}") |
|
|
|
|
|
if (clothes_mask1 is not None and clothes_mask2 is not None and |
|
np.sum(clothes_mask1 > 0) > 0 and np.sum(clothes_mask2 > 0) > 0): |
|
fine_grained_segments.add('l4_upper_clothes') |
|
log.append(f" Added fine-grained: upper_clothes") |
|
|
|
except Exception as e: |
|
log.append(f" Error adding fine-grained segments: {e}") |
|
|
|
return fine_grained_segments |
|
|
|
|
|
def match_single_segment(img1_path: str, img2_path: str, mask1: np.ndarray, mask2: np.ndarray, |
|
segment_name: str, method: str, log: list) -> tuple[Image.Image, np.ndarray, np.ndarray]: |
|
""" |
|
Match colors of a single segment in complete isolation from other segments. |
|
Each segment is processed independently with no knowledge of other segments. |
|
Returns: (matched_image, final_mask1, final_mask2) |
|
""" |
|
try: |
|
|
|
img1 = Image.open(img1_path).convert("RGB") |
|
img2 = Image.open(img2_path).convert("RGB") |
|
|
|
|
|
img1_np = np.array(img1) |
|
img2_np = np.array(img2) |
|
|
|
|
|
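# Fine-grained (l4_) segments are segmented here on demand; level-2 segments arrive with mask1/mask2 already set.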
if segment_name.startswith('l4_'): |
|
part_name = segment_name.replace('l4_', '') |
|
if part_name in ['face', 'hair']: |
|
from human_parts_segmentation import HumanPartsSegmentation |
|
segmenter = HumanPartsSegmentation() |
|
masks_dict1 = segmenter.segment_parts(img1_path, [part_name]) |
|
masks_dict2 = segmenter.segment_parts(img2_path, [part_name]) |
|
|
|
if part_name in masks_dict1 and part_name in masks_dict2: |
|
mask1 = masks_dict1[part_name] |
|
mask2 = masks_dict2[part_name] |
|
else: |
|
return None, None, None |
|
|
|
elif part_name == 'upper_clothes': |
|
from clothes_segmentation import ClothesSegmentation |
|
segmenter = ClothesSegmentation() |
|
mask1 = segmenter.segment_clothes(img1_path, ["Upper-clothes"]) |
|
mask2 = segmenter.segment_clothes(img2_path, ["Upper-clothes"]) |
|
|
|
if mask1 is None or mask2 is None: |
|
return None, None, None |
|
|
|
|
|
if mask1.shape != img1_np.shape[:2]: |
|
mask1 = cv2.resize(mask1.astype(np.float32), (img1_np.shape[1], img1_np.shape[0]), |
|
interpolation=cv2.INTER_NEAREST) |
|
if mask2.shape != img2_np.shape[:2]: |
|
mask2 = cv2.resize(mask2.astype(np.float32), (img2_np.shape[1], img2_np.shape[0]), |
|
interpolation=cv2.INTER_NEAREST) |
|
|
|
|
|
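# Threshold soft masks to binary so only clearly-inside pixels are treated as part of the segment.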
mask1_binary = (mask1 > 0.5).astype(np.float32) |
|
mask2_binary = (mask2 > 0.5).astype(np.float32) |
|
|
|
|
|
pixels1 = np.sum(mask1_binary > 0) |
|
pixels2 = np.sum(mask2_binary > 0) |
|
|
|
if pixels1 == 0 or pixels2 == 0: |
|
log.append(f" No pixels in {segment_name}: img1={pixels1}, img2={pixels2}") |
|
return None, None, None |
|
|
|
log.append(f" {segment_name}: img1={pixels1} pixels, img2={pixels2} pixels") |
|
|
|
|
|
masks1_dict = {segment_name: mask1_binary} |
|
masks2_dict = {segment_name: mask2_binary} |
|
|
|
|
|
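# Color-match just this segment pair in isolation using the selected method.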
color_matcher = RegionColorMatcher(factor=0.8, preserve_colors=True, |
|
preserve_luminance=True, method=method) |
|
|
|
matched_img = color_matcher.match_regions(img1_path, img2_path, masks1_dict, masks2_dict) |
|
|
|
return matched_img, mask1_binary, mask2_binary |
|
|
|
except Exception as e: |
|
log.append(f" Error matching {segment_name}: {e}") |
|
return None, None, None |
|
|
|
|
|
def composite_matched_segments(base_img_path: str, matched_regions: dict, segment_masks: dict, log: list) -> Image.Image: |
|
""" |
|
Composite all matched segments back together using simple alpha compositing. |
|
Each matched segment is completely independent and overlaid on the base image. |
|
""" |
|
|
|
result = Image.open(base_img_path).convert("RGBA") |
|
result_np = np.array(result) |
|
|
|
log.append(f"Compositing {len(matched_regions)} segments onto base image") |
|
|
|
for segment_name, matched_img in matched_regions.items(): |
|
if segment_name in segment_masks: |
|
mask = segment_masks[segment_name] |
|
|
|
|
|
if mask.shape != result_np.shape[:2]: |
|
mask = cv2.resize(mask.astype(np.float32), |
|
(result_np.shape[1], result_np.shape[0]), |
|
interpolation=cv2.INTER_NEAREST) |
|
|
|
|
|
matched_np = np.array(matched_img.convert("RGB")) |
|
|
|
|
|
if matched_np.shape[:2] != result_np.shape[:2]: |
|
matched_pil = Image.fromarray(matched_np) |
|
matched_pil = matched_pil.resize((result_np.shape[1], result_np.shape[0]), Image.LANCZOS) |
|
matched_np = np.array(matched_pil) |
|
|
|
|
|
mask_binary = (mask > 0.5).astype(np.float32) |
|
alpha = np.expand_dims(mask_binary, axis=2) |
|
|
|
|
|
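# Hard alpha blend: keep base pixels outside the mask, take the matched pixels inside it.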
result_np[:, :, :3] = (result_np[:, :, :3] * (1 - alpha) + |
|
matched_np * alpha).astype(np.uint8) |
|
|
|
pixels = np.sum(mask_binary > 0) |
|
log.append(f" Composited {segment_name}: {pixels} pixels") |
|
|
|
return Image.fromarray(result_np).convert("RGB") |
|
|
|
|
|
def visualize_matching_masks(img1_path, img2_path, masks1, masks2): |
|
""" |
|
Create a visualization of the masks being matched between two images. |
|
|
|
Args: |
|
img1_path: Path to first image |
|
img2_path: Path to second image |
|
masks1: Dictionary of masks for first image {label: binary_mask} |
|
masks2: Dictionary of masks for second image {label: binary_mask} |
|
|
|
Returns: |
|
A matplotlib Figure showing the matched masks |
|
""" |
|
|
|
img1 = Image.open(img1_path).convert("RGB") |
|
img2 = Image.open(img2_path).convert("RGB") |
|
|
|
|
|
img1_np = np.array(img1) |
|
img2_np = np.array(img2) |
|
|
|
|
|
fine_grained_masks = {} |
|
regular_masks = {} |
|
|
|
for label, mask in masks1.items(): |
|
if label.startswith('l4_'): |
|
fine_grained_masks[label] = mask |
|
else: |
|
regular_masks[label] = mask |
|
|
|
|
|
common_regular = set(regular_masks.keys()).intersection(set(masks2.keys())) |
|
|
|
|
|
common_fine_grained = set() |
|
for label in fine_grained_masks.keys(): |
|
if label.startswith('l4_') and label in masks2: |
|
part_name = label.replace('l4_', '') |
|
common_fine_grained.add(part_name) |
|
|
|
|
|
n_regular_rows = len(common_regular) |
|
n_fine_rows = len(common_fine_grained) |
|
n_rows = n_regular_rows + n_fine_rows |
|
|
|
if n_rows == 0: |
|
|
|
fig, ax = plt.subplots(1, 1, figsize=(10, 5)) |
|
ax.text(0.5, 0.5, "No matching regions found between images", |
|
ha='center', va='center', fontsize=14, color='white') |
|
ax.axis('off') |
|
return fig |
|
|
|
fig, axes = plt.subplots(n_rows, 2, figsize=(12, 3 * n_rows)) |
|
|
|
|
|
if n_rows == 1: |
|
axes = np.array([axes]) |
|
|
|
row_idx = 0 |
|
|
|
|
|
for label in sorted(common_regular): |
|
|
|
display_name = label.replace("l2_", "").capitalize() |
|
|
|
|
|
mask1 = regular_masks[label] |
|
mask2 = masks2[label] |
|
|
|
|
|
masked_img1, masked_img2 = create_mask_overlay(img1_np, img2_np, mask1, mask2, [255, 0, 0]) |
|
|
|
|
|
axes[row_idx, 0].imshow(masked_img1) |
|
axes[row_idx, 0].set_title(f"Image 1: {display_name}") |
|
axes[row_idx, 0].axis('off') |
|
|
|
axes[row_idx, 1].imshow(masked_img2) |
|
axes[row_idx, 1].set_title(f"Image 2: {display_name}") |
|
axes[row_idx, 1].axis('off') |
|
|
|
row_idx += 1 |
|
|
|
|
|
part_colors = { |
|
'face': [255, 0, 0], |
|
'hair': [255, 0, 0], |
|
'upper_clothes': [255, 0, 0] |
|
} |
|
|
|
for part_name in sorted(common_fine_grained): |
|
label = f'l4_{part_name}' |
|
|
|
if label in fine_grained_masks and label in masks2: |
|
mask1 = fine_grained_masks[label] |
|
mask2 = masks2[label] |
|
|
|
color = part_colors.get(part_name, [255, 0, 0]) |
|
|
|
|
|
masked_img1, masked_img2 = create_mask_overlay(img1_np, img2_np, mask1, mask2, color) |
|
|
|
|
|
display_name = part_name.replace('_', ' ').title() |
|
axes[row_idx, 0].imshow(masked_img1) |
|
axes[row_idx, 0].set_title(f"Image 1: {display_name} (Fine-grained)") |
|
axes[row_idx, 0].axis('off') |
|
|
|
axes[row_idx, 1].imshow(masked_img2) |
|
axes[row_idx, 1].set_title(f"Image 2: {display_name} (Fine-grained)") |
|
axes[row_idx, 1].axis('off') |
|
|
|
row_idx += 1 |
|
|
|
plt.suptitle("Matched Regions (highlighted with different colors)", fontsize=16, color='white') |
|
plt.tight_layout() |
|
|
|
return fig |
|
|
|
|
|
def create_mask_overlay(img1_np, img2_np, mask1, mask2, overlay_color): |
|
""" |
|
Create mask overlays on images with the specified color. |
|
|
|
Args: |
|
img1_np: First image as numpy array |
|
img2_np: Second image as numpy array |
|
mask1: Mask for first image |
|
mask2: Mask for second image |
|
overlay_color: RGB color for overlay [R, G, B] |
|
|
|
Returns: |
|
Tuple of (masked_img1, masked_img2) |
|
""" |
|
|
|
if mask1.shape != img1_np.shape[:2]: |
|
mask1_img = Image.fromarray((mask1 * 255).astype(np.uint8)) |
|
mask1_img = mask1_img.resize((img1_np.shape[1], img1_np.shape[0]), Image.NEAREST) |
|
mask1 = np.array(mask1_img).astype(np.float32) / 255.0 |
|
|
|
if mask2.shape != img2_np.shape[:2]: |
|
mask2_img = Image.fromarray((mask2 * 255).astype(np.uint8)) |
|
mask2_img = mask2_img.resize((img2_np.shape[1], img2_np.shape[0]), Image.NEAREST) |
|
mask2 = np.array(mask2_img).astype(np.float32) / 255.0 |
|
|
|
|
|
masked_img1 = img1_np.copy() |
|
masked_img2 = img2_np.copy() |
|
|
|
|
|
overlay_color = np.array(overlay_color, dtype=np.uint8) |
|
|
|
|
|
alpha1 = mask1 * 0.6 |
|
alpha2 = mask2 * 0.6 |
|
|
|
|
|
for c in range(3): |
|
masked_img1[:, :, c] = masked_img1[:, :, c] * (1 - alpha1) + overlay_color[c] * alpha1 |
|
masked_img2[:, :, c] = masked_img2[:, :, c] * (1 - alpha2) + overlay_color[c] * alpha2 |
|
|
|
return masked_img1, masked_img2 |
|
|
|
|
|
def extract_semantic_masks(output): |
|
""" |
|
Extract binary masks for each semantic region from the LadecoOutput. |
|
|
|
Args: |
|
output: LadecoOutput from Ladeco.predict() |
|
|
|
Returns: |
|
Dictionary mapping label names to binary masks |
|
""" |
|
masks = {} |
|
|
|
|
|
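# Per-pixel class-index map for the (single) input image; ladeco2ade maps each LADECO label
# to the class indices that compose it.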
seg_mask = output.masks[0].cpu().numpy() |
|
|
|
|
|
for label, indices in output.ladeco2ade.items(): |
|
if label.startswith("l2_"): |
|
|
|
binary_mask = np.zeros_like(seg_mask, dtype=np.float32) |
|
|
|
|
|
for idx in indices: |
|
binary_mask[seg_mask == idx] = 1.0 |
|
|
|
|
|
if np.any(binary_mask): |
|
masks[label] = binary_mask |
|
|
|
return masks |
|
|
|
|
|
def plot_pie(data: dict[str, float], colors=None) -> Figure: |
|
fig, ax = plt.subplots() |
|
|
|
labels = list(data.keys()) |
|
sizes = list(data.values()) |
|
|
|
*_, autotexts = ax.pie(sizes, labels=labels, autopct="%1.1f%%", colors=colors) |
|
|
|
for percent_text in autotexts: |
|
percent_text.set_color("k") |
|
|
|
ax.axis("equal") |
|
|
|
return fig |
|
|
|
|
|
def choose_example(imgpath: str, target_component) -> gr.Image: |
|
img = Image.open(imgpath) |
|
width, height = img.size |
|
ratio = 512 / max(width, height) |
|
img = img.resize((int(width * ratio), int(height * ratio))) |
|
return gr.Image(value=img, label="Input Image (SVG format not supported)", type="filepath") |
|
|
|
|
|
css = """ |
|
.reference { |
|
text-align: center; |
|
font-size: 1.2em; |
|
color: #d1d5db; |
|
margin-bottom: 20px; |
|
} |
|
.reference a { |
|
color: #FB923C; |
|
text-decoration: none; |
|
} |
|
.reference a:hover { |
|
text-decoration: underline; |
|
color: #FB923C; |
|
} |
|
.description { |
|
text-align: center; |
|
font-size: 1.1em; |
|
color: #d1d5db; |
|
margin-bottom: 25px; |
|
} |
|
.footer { |
|
text-align: center; |
|
margin-top: 30px; |
|
padding-top: 20px; |
|
border-top: 1px solid #ddd; |
|
color: #d1d5db; |
|
font-size: 14px; |
|
} |
|
.main-title { |
|
font-size: 24px; |
|
font-weight: bold; |
|
text-align: center; |
|
margin-bottom: 20px; |
|
} |
|
.selected-image { |
|
height: 756px; |
|
} |
|
.example-image { |
|
height: 220px; |
|
padding: 25px; |
|
} |
|
""".strip() |
|
theme = gr.themes.Base( |
|
primary_hue="orange", |
|
secondary_hue="cyan", |
|
neutral_hue="gray", |
|
).set( |
|
body_text_color='*neutral_100', |
|
body_text_color_subdued='*neutral_600', |
|
background_fill_primary='*neutral_950', |
|
background_fill_secondary='*neutral_600', |
|
border_color_accent='*secondary_800', |
|
color_accent='*primary_50', |
|
color_accent_soft='*secondary_800', |
|
code_background_fill='*neutral_700', |
|
block_background_fill_dark='*body_background_fill', |
|
block_info_text_color='#6b7280', |
|
block_label_text_color='*neutral_300', |
|
block_label_text_weight='700', |
|
block_title_text_color='*block_label_text_color', |
|
block_title_text_weight='300', |
|
panel_background_fill='*neutral_800', |
|
table_text_color_dark='*secondary_800', |
|
checkbox_background_color_selected='*primary_500', |
|
checkbox_label_background_fill='*neutral_500', |
|
checkbox_label_background_fill_hover='*neutral_700', |
|
checkbox_label_text_color='*neutral_200', |
|
input_background_fill='*neutral_700', |
|
input_background_fill_focus='*neutral_600', |
|
slider_color='*primary_500', |
|
table_even_background_fill='*neutral_700', |
|
table_odd_background_fill='*neutral_600', |
|
table_row_focus='*neutral_800' |
|
) |
|
with gr.Blocks(css=css, theme=theme) as demo: |
|
gr.HTML( |
|
""" |
|
<div class="main-title">SegMatch β Zero Shot Segmentation-based color matching</div> |
|
<div class="description"> |
|
Advanced region-based color matching that uses semantic segmentation and fine-grained human-parts detection for precise, context-aware color transfer between images.
|
</div> |
|
""".strip() |
|
) |
|
|
|
with gr.Row(): |
|
|
|
with gr.Column(): |
|
img1 = gr.Image( |
|
label="First Input Image - Color Reference (SVG not supported)", |
|
type="filepath", |
|
height="256px", |
|
) |
|
gr.Label("Example Images for First Input", show_label=False) |
|
with gr.Row(): |
|
ex1_1 = gr.Image( |
|
value="examples/beach.jpg", |
|
show_label=False, |
|
type="filepath", |
|
elem_classes="example-image", |
|
interactive=False, |
|
show_download_button=False, |
|
show_fullscreen_button=False, |
|
show_share_button=False, |
|
) |
|
ex1_2 = gr.Image( |
|
value="examples/field.jpg", |
|
show_label=False, |
|
type="filepath", |
|
elem_classes="example-image", |
|
interactive=False, |
|
show_download_button=False, |
|
show_fullscreen_button=False, |
|
show_share_button=False, |
|
) |
|
|
|
|
|
with gr.Column(): |
|
img2 = gr.Image( |
|
label="Second Input Image - To Be Color Matched (SVG not supported)", |
|
type="filepath", |
|
height="256px", |
|
) |
|
gr.Label("Example Images for Second Input", show_label=False) |
|
with gr.Row(): |
|
ex2_1 = gr.Image( |
|
value="examples/field.jpg", |
|
show_label=False, |
|
type="filepath", |
|
elem_classes="example-image", |
|
interactive=False, |
|
show_download_button=False, |
|
show_fullscreen_button=False, |
|
show_share_button=False, |
|
) |
|
ex2_2 = gr.Image( |
|
value="examples/sky.jpg", |
|
show_label=False, |
|
type="filepath", |
|
elem_classes="example-image", |
|
interactive=False, |
|
show_download_button=False, |
|
show_fullscreen_button=False, |
|
show_share_button=False, |
|
) |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
method = gr.Dropdown( |
|
label="Color Matching Method", |
|
choices=["adain", "mkl", "hm", "reinhard", "mvgd", "hm-mvgd-hm", "hm-mkl-hm", "coral"], |
|
value="adain", |
|
info="Choose the algorithm for color matching between regions" |
|
) |
|
|
|
with gr.Column(): |
|
enable_face_matching = gr.Checkbox( |
|
label="Enable Face Matching for Human Regions", |
|
value=True, |
|
info="Only match human regions if faces are similar (requires DeepFace)" |
|
) |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
enable_edge_smoothing = gr.Checkbox( |
|
label="Enable CDL Edge Smoothing", |
|
value=False, |
|
info="Apply CDL transform to original image using calculated parameters (see log for values)" |
|
) |
|
|
|
start = gr.Button("Start Analysis", variant="primary") |
|
|
|
|
|
download_btn = gr.File( |
|
label="π₯ Download Color-Matched Image", |
|
visible=True, |
|
interactive=False |
|
) |
|
|
|
with gr.Tabs(): |
|
with gr.TabItem("Segmentation Results"): |
|
with gr.Row(): |
|
|
|
with gr.Column(): |
|
gr.Label("Results for First Image", show_label=True) |
|
seg1 = gr.Plot(label="Semantic Segmentation") |
|
pie1 = gr.Plot(label="Element Area Ratio") |
|
|
|
|
|
with gr.Column(): |
|
gr.Label("Results for Second Image", show_label=True) |
|
seg2 = gr.Plot(label="Semantic Segmentation") |
|
pie2 = gr.Plot(label="Element Area Ratio") |
|
|
|
with gr.TabItem("Color Matching"): |
|
gr.Markdown(""" |
|
### Region-Based Color Matching |
|
|
|
This tab shows the result of matching the colors of the second image to the first image's colors, |
|
but only within corresponding semantic regions. For example, sky areas in the second image are |
|
matched to sky areas in the first image, while vegetation areas are matched separately. |
|
|
|
#### Face Matching Feature: |
|
When enabled, the system will detect faces within human/bio regions and only apply color matching |
|
to human regions where similar faces are found in both images. This ensures that color transfer |
|
only occurs between images of the same person. |
|
|
|
#### CDL Edge Smoothing Feature: |
|
When enabled, calculates Color Decision List (CDL) parameters to transform the original target image |
|
towards the segment-matched result, then applies those CDL parameters to the original image. This creates |
|
a "smoothed" version that maintains the original image's overall characteristics while incorporating the |
|
color relationships found through segment matching. |
|
|
|
The CDL calculation uses the simplest possible approach: it matches the mean and standard
deviation of each color channel between the original and composited images, and estimates
gamma from the brightness relationship between the two.
|
|
|
#### Available Methods: |
|
- **adain**: Adaptive Instance Normalization - Matches mean and standard deviation of colors |
|
- **mkl**: Monge-Kantorovich Linearization - Linear transformation of color statistics |
|
- **reinhard**: Reinhard color transfer - Simple statistical approach that matches mean and standard deviation |
|
- **mvgd**: Multi-Variate Gaussian Distribution - Uses color covariance matrices for more accurate matching |
|
- **hm**: Histogram Matching - Matches the full color distribution histograms |
|
- **hm-mvgd-hm**: Histogram + MVGD + Histogram compound method |
|
- **hm-mkl-hm**: Histogram + MKL + Histogram compound method |
|
- **coral**: CORAL (CORrelation ALignment) - Aligns second-order color statistics (covariances) between images for natural-looking transfer
|
""") |
|
|
|
|
|
cdl_display = gr.Textbox( |
|
label="π CDL Parameters", |
|
lines=15, |
|
max_lines=20, |
|
interactive=False, |
|
info="Color Decision List parameters calculated when CDL edge smoothing is enabled" |
|
) |
|
|
|
face_log = gr.Textbox( |
|
label="Face Matching Log", |
|
lines=8, |
|
max_lines=15, |
|
interactive=False, |
|
info="Shows details of face detection and matching process" |
|
) |
|
|
|
mask_vis = gr.Plot(label="Matched Regions Visualization") |
|
comparison = gr.Plot(label="Region-Based Color Matching Result") |
|
|
|
gr.HTML( |
|
""" |
|
<div class="footer"> |
|
© 2024 SegMatch. All Rights Reserved.<br>
|
Developer: Stefan Allen |
|
</div> |
|
""".strip() |
|
) |
|
|
|
|
|
start.click( |
|
fn=infer_two_images, |
|
inputs=[img1, img2, method, enable_face_matching, enable_edge_smoothing], |
|
outputs=[seg1, pie1, seg2, pie2, comparison, mask_vis, download_btn, face_log, cdl_display] |
|
) |
|
|
|
|
|
ex1_1.select(fn=lambda x: choose_example(x, img1), inputs=ex1_1, outputs=img1) |
|
ex1_2.select(fn=lambda x: choose_example(x, img1), inputs=ex1_2, outputs=img1) |
|
ex2_1.select(fn=lambda x: choose_example(x, img2), inputs=ex2_1, outputs=img2) |
|
ex2_2.select(fn=lambda x: choose_example(x, img2), inputs=ex2_2, outputs=img2) |
|
|
|
if __name__ == "__main__": |
|
demo.launch() |
|
|