#!/usr/bin/env python3
"""
Hair Studio - Example Usage Script
Demonstrates how to use individual AI models for hair analysis
This script shows how to use each model independently for developers
who want to integrate specific functionality into their own applications.
"""
import torch
import numpy as np
from PIL import Image
import cv2
from transformers import (
    SegformerFeatureExtractor,
    SegformerForSemanticSegmentation,
    AutoImageProcessor,
    AutoModelForImageClassification,
    AutoModel
)

def setup_models():
    """Initialize all AI models"""
    print("Loading AI models...")
    models = {}
    try:
        # Hair Segmentation Model
        print("- Loading hair segmentation model...")
        models['hair_seg_extractor'] = SegformerFeatureExtractor.from_pretrained(
            "Allison/segformer-hair-segmentation-10k-steps"
        )
        models['hair_seg_model'] = SegformerForSemanticSegmentation.from_pretrained(
            "Allison/segformer-hair-segmentation-10k-steps"
        )
        print("  βœ“ Hair segmentation model loaded")

        # Hair Type Classification Model
        print("- Loading hair classification model...")
        models['hair_class_processor'] = AutoImageProcessor.from_pretrained(
            "dima806/hair_type_image_detection"
        )
        models['hair_class_model'] = AutoModelForImageClassification.from_pretrained(
            "dima806/hair_type_image_detection"
        )
        print("  βœ“ Hair classification model loaded")

        # Skin Tone Analysis Model
        print("- Loading skin analysis model...")
        models['skin_processor'] = AutoImageProcessor.from_pretrained(
            "google/derm-foundation"
        )
        models['skin_model'] = AutoModel.from_pretrained(
            "google/derm-foundation"
        )
        print("  βœ“ Skin analysis model loaded")

        print("πŸŽ‰ All models loaded successfully!")
        return models
    except Exception as e:
        print(f"❌ Error loading models: {e}")
        return None
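
# Optional sketch (an assumption about your environment, not part of the
# original pipeline): move the PyTorch models to a GPU when one is available.
# The examples below run on CPU as written; if you move a model to a device,
# remember to move its inputs there as well.
def move_models_to_device(models, device=None):
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
    for key in ('hair_seg_model', 'hair_class_model', 'skin_model'):
        if key in models:
            models[key] = models[key].to(device)
    return models
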
def segment_hair_example(image_path, models):
    """
    Example: Hair Segmentation
    Returns a binary mask highlighting hair regions
    """
    print("\n🎯 Hair Segmentation Example")
    print(f"Processing: {image_path}")
    # Load image
    image = Image.open(image_path).convert("RGB")
    # Preprocess
    inputs = models['hair_seg_extractor'](images=image, return_tensors="pt")
    # Get segmentation
    with torch.no_grad():
        outputs = models['hair_seg_model'](**inputs)
    # SegFormer logits come out at a reduced resolution; upsample them to the
    # original image size so the mask lines up with the input pixels
    logits = torch.nn.functional.interpolate(
        outputs.logits, size=image.size[::-1], mode="bilinear", align_corners=False
    )
    segmentation = logits.argmax(dim=1)[0]
    # Convert to binary mask (class 1 = hair)
    hair_mask = (segmentation == 1).numpy().astype(np.uint8) * 255
    # Save result
    mask_image = Image.fromarray(hair_mask)
    output_path = image_path.replace('.jpg', '_hair_mask.jpg')
    mask_image.save(output_path)
    print(f"βœ… Hair mask saved to: {output_path}")
    print(f"πŸ“Š Hair pixels detected: {np.sum(hair_mask > 0):,}")
    return hair_mask
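
# Illustrative helper (an addition for this example, not part of the original
# pipeline): tint the detected hair region red so the mask can be checked
# visually against the photo. Expects the full-resolution mask returned by
# segment_hair_example().
def overlay_hair_mask(image_path, hair_mask, alpha=0.5):
    image = np.array(Image.open(image_path).convert("RGB")).astype(np.float32)
    red = np.array([255.0, 0.0, 0.0])
    image[hair_mask > 0] = image[hair_mask > 0] * (1 - alpha) + red * alpha
    return Image.fromarray(image.astype(np.uint8))
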
def classify_hair_type_example(image_path, models):
    """
    Example: Hair Type Classification
    Returns hair type probabilities
    """
    print("\nπŸ” Hair Type Classification Example")
    print(f"Processing: {image_path}")
    # Load image
    image = Image.open(image_path).convert("RGB")
    # Preprocess
    inputs = models['hair_class_processor'](images=image, return_tensors="pt")
    # Get predictions
    with torch.no_grad():
        outputs = models['hair_class_model'](**inputs)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    # Get labels
    if hasattr(models['hair_class_model'].config, 'id2label'):
        labels = models['hair_class_model'].config.id2label
    else:
        labels = {0: "Straight", 1: "Wavy", 2: "Curly", 3: "Coily"}
    # Create results
    results = {}
    for i, prob in enumerate(probabilities[0]):
        label = labels.get(i, f"Type_{i}")
        results[label] = float(prob)
    print("πŸ“Š Hair Type Probabilities:")
    for hair_type, probability in sorted(results.items(), key=lambda x: x[1], reverse=True):
        print(f"  {hair_type:12}: {probability:.1%}")
    # Get top prediction
    top_type = max(results, key=results.get)
    print(f"🎯 Predicted Hair Type: {top_type} ({results[top_type]:.1%} confidence)")
    return results
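
# Illustrative sketch (an assumption, not part of the original pipeline):
# classify using only the hair region by cropping the photo to the mask's
# bounding box, which can reduce the influence of the background on the
# prediction. Expects the full-resolution mask from segment_hair_example().
def classify_hair_crop(image_path, hair_mask, models):
    image = Image.open(image_path).convert("RGB")
    ys, xs = np.where(hair_mask > 0)
    if len(xs) == 0:
        return None  # no hair detected; caller can fall back to the full image
    crop = image.crop((xs.min(), ys.min(), xs.max() + 1, ys.max() + 1))
    inputs = models['hair_class_processor'](images=crop, return_tensors="pt")
    with torch.no_grad():
        logits = models['hair_class_model'](**inputs).logits
    return torch.nn.functional.softmax(logits, dim=-1)[0]
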
def analyze_skin_tone_example(image_path, models):
    """
    Example: Skin Tone Analysis
    Returns skin tone characteristics using a simple LAB color-space
    heuristic (the loaded skin model is not required by this example)
    """
    print("\n🌈 Skin Tone Analysis Example")
    print(f"Processing: {image_path}")
    # Load and preprocess image
    image = Image.open(image_path).convert("RGB")
    img_array = np.array(image)
    # Focus on face region (center area)
    h, w = img_array.shape[:2]
    face_region = img_array[h//4:3*h//4, w//4:3*w//4]
    # Convert to LAB color space for analysis
    lab_image = cv2.cvtColor(face_region, cv2.COLOR_RGB2LAB)
    # OpenCV stores 8-bit LAB with L scaled to [0, 255] and a/b offset by 128;
    # rescale to standard CIELAB ranges so the thresholds below make sense
    avg_lab = np.mean(lab_image.reshape(-1, 3), axis=0)
    L = avg_lab[0] * 100.0 / 255.0
    a = avg_lab[1] - 128.0
    b = avg_lab[2] - 128.0
    # Determine undertone
    if a > 5 and b > 8:
        undertone = "Warm"
        confidence = min(0.95, (a + b) / 25)
    elif a > 3 and b < -2:
        undertone = "Cool"
        confidence = min(0.95, a / 15)
    else:
        undertone = "Neutral"
        confidence = 0.8
    # Determine depth
    if L > 75:
        depth = "Very Light"
    elif L > 60:
        depth = "Light"
    elif L > 45:
        depth = "Medium"
    elif L > 30:
        depth = "Deep"
    else:
        depth = "Very Deep"
    results = {
        "undertone": undertone,
        "confidence": confidence,
        "depth": depth,
        "lab_values": {
            "L": float(L),
            "a": float(a),
            "b": float(b)
        }
    }
    print("πŸ“Š Skin Tone Analysis Results:")
    print(f"  Undertone: {undertone} ({confidence:.1%} confidence)")
    print(f"  Depth: {depth}")
    print(f"  LAB Values: L={L:.1f}, a={a:.1f}, b={b:.1f}")
    return results
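
# Quick sanity check for the thresholds above (a sketch with an illustrative
# input, not part of the original pipeline): convert one RGB swatch to the
# same rescaled CIELAB ranges used by analyze_skin_tone_example().
def rgb_to_cielab(rgb):
    swatch = np.uint8([[rgb]])  # shape (1, 1, 3): a one-pixel "image"
    lab = cv2.cvtColor(swatch, cv2.COLOR_RGB2LAB)[0, 0].astype(np.float64)
    return lab[0] * 100.0 / 255.0, lab[1] - 128.0, lab[2] - 128.0

# Example: rgb_to_cielab((224, 172, 105)), a tan swatch, yields a mildly
# positive a and a strongly positive b, which the heuristic reads as Warm.
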
def generate_color_recommendations(skin_analysis, hair_type):
    """
    Example: Generate color recommendations based on analysis
    (hair_type is accepted for API symmetry but unused by this simple lookup)
    """
    print("\n🎨 Color Recommendation Example")
    # Simple color database
    color_database = {
        "warm": {
            "light": ["Honey Blonde", "Golden Brown", "Caramel"],
            "medium": ["Chocolate Brown", "Auburn", "Copper"],
            "deep": ["Rich Mahogany", "Dark Chocolate", "Burgundy"]
        },
        "cool": {
            "light": ["Ash Blonde", "Platinum", "Cool Brown"],
            "medium": ["Ash Brown", "Steel Brown", "Cool Brunette"],
            "deep": ["Cool Black", "Dark Ash", "Blue Black"]
        },
        "neutral": {
            "light": ["Natural Blonde", "Sandy Brown", "Light Brown"],
            "medium": ["Medium Brown", "Chestnut", "Hazelnut"],
            "deep": ["Dark Brown", "Natural Black", "Espresso"]
        }
    }
    undertone = skin_analysis["undertone"].lower()
    # Collapse "Very Light"/"Very Deep" onto the light/deep palettes
    depth = skin_analysis["depth"].lower().replace(" ", "_").replace("very_", "")
    recommendations = color_database.get(undertone, {}).get(depth, ["Natural Brown"])
    print(f"πŸ’‘ Recommended Colors for {skin_analysis['undertone']} {skin_analysis['depth']} skin:")
    for i, color in enumerate(recommendations, 1):
        print(f"  {i}. {color}")
    return recommendations
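
# Example lookup (hypothetical input values, shown for illustration):
#
#   recs = generate_color_recommendations(
#       {"undertone": "Warm", "depth": "Medium"}, hair_type=None
#   )
#   # -> ["Chocolate Brown", "Auburn", "Copper"]
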
def comprehensive_analysis_example(image_path):
    """
    Example: Complete hair and skin analysis pipeline
    """
    print("\nπŸ”¬ Comprehensive Analysis Example")
    print("=" * 50)
    # Load models
    models = setup_models()
    if not models:
        print("❌ Failed to load models")
        return
    try:
        # Perform all analyses
        hair_mask = segment_hair_example(image_path, models)
        hair_type = classify_hair_type_example(image_path, models)
        skin_analysis = analyze_skin_tone_example(image_path, models)
        # Generate recommendations
        recommendations = generate_color_recommendations(skin_analysis, hair_type)
        print("\nβœ… Analysis Complete!")
        print("πŸ“Š Summary:")
        print(f"  - Hair segmentation: {np.sum(hair_mask > 0):,} pixels")
        print(f"  - Hair type: {max(hair_type, key=hair_type.get)}")
        print(f"  - Skin undertone: {skin_analysis['undertone']}")
        print(f"  - Color recommendations: {len(recommendations)} options")
    except Exception as e:
        print(f"❌ Analysis failed: {e}")

def virtual_tryon_example(image_path, target_color_hex="#8B4513"):
    """
    Example: Simple virtual hair color try-on
    """
    print("\nπŸͺ„ Virtual Try-On Example")
    print(f"Applying color {target_color_hex} to {image_path}")
    # Load models for segmentation
    models = setup_models()
    if not models:
        return
    # Get hair mask, upsampled to the original image size so it can be
    # applied directly to the photo's pixels
    image = Image.open(image_path).convert("RGB")
    inputs = models['hair_seg_extractor'](images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = models['hair_seg_model'](**inputs)
    logits = torch.nn.functional.interpolate(
        outputs.logits, size=image.size[::-1], mode="bilinear", align_corners=False
    )
    segmentation = logits.argmax(dim=1)[0]
    hair_mask = (segmentation == 1).numpy().astype(np.uint8)
    # Parse "#RRGGBB" into an (R, G, B) tuple
    img_array = np.array(image)
    target_color = tuple(int(target_color_hex[i:i+2], 16) for i in (1, 3, 5))
    # Simple per-channel alpha blend inside the hair mask
    colored_img = img_array.astype(np.float32)
    for i in range(3):
        colored_img[:, :, i] = np.where(
            hair_mask > 0,
            img_array[:, :, i] * 0.6 + target_color[i] * 0.4,
            img_array[:, :, i]
        )
    # Save result
    result_image = Image.fromarray(colored_img.astype(np.uint8))
    output_path = image_path.replace('.jpg', '_colored.jpg')
    result_image.save(output_path)
    print(f"βœ… Virtual try-on saved to: {output_path}")
def main():
    """
    Main example function - demonstrates all capabilities
    """
    print("πŸš€ Hair Studio AI Models - Example Usage")
    print("=" * 50)
    # Example image path - replace with your image
    image_path = "example_image.jpg"  # You need to provide this
    # Check if image exists
    try:
        Image.open(image_path)
    except FileNotFoundError:
        print(f"❌ Image not found: {image_path}")
        print("Please provide a valid image path in the script.")
        return
    # Run comprehensive analysis
    comprehensive_analysis_example(image_path)
    # Try virtual color application
    virtual_tryon_example(image_path)
    print("\nπŸŽ‰ All examples completed!")
    print("Check the output files for results.")

if __name__ == "__main__":
    main()

# Additional utility functions for developers
def batch_process_images(image_directory):
    """
    Example: Process multiple images in batch
    """
    import os
    models = setup_models()
    if not models:
        return
    results = []
    for filename in os.listdir(image_directory):
        if filename.lower().endswith(('.jpg', '.jpeg', '.png')):
            image_path = os.path.join(image_directory, filename)
            try:
                # Quick analysis (convert to RGB so grayscale/RGBA files work)
                image = Image.open(image_path).convert("RGB")
                inputs = models['hair_class_processor'](images=image, return_tensors="pt")
                with torch.no_grad():
                    outputs = models['hair_class_model'](**inputs)
                    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
                # Store result
                results.append({
                    'filename': filename,
                    'hair_type_scores': probabilities[0].numpy()
                })
                print(f"βœ… Processed: {filename}")
            except Exception as e:
                print(f"❌ Failed to process {filename}: {e}")
    return results
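
# Usage sketch (the folder name "photos" is illustrative):
#
#   results = batch_process_images("photos")
#   for r in results:
#       print(r['filename'], int(r['hair_type_scores'].argmax()))
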
def model_performance_benchmark():
    """
    Example: Benchmark model performance
    """
    import time
    models = setup_models()
    if not models:
        return
    # Create dummy image
    dummy_image = Image.new('RGB', (512, 512), color='red')
    # Benchmark each model (single-pass timings; the first inference can
    # include one-time warm-up costs, so treat these as rough numbers)
    benchmarks = {}
    # Hair segmentation
    start_time = time.time()
    inputs = models['hair_seg_extractor'](images=dummy_image, return_tensors="pt")
    with torch.no_grad():
        outputs = models['hair_seg_model'](**inputs)
    benchmarks['hair_segmentation'] = time.time() - start_time
    # Hair classification
    start_time = time.time()
    inputs = models['hair_class_processor'](images=dummy_image, return_tensors="pt")
    with torch.no_grad():
        outputs = models['hair_class_model'](**inputs)
    benchmarks['hair_classification'] = time.time() - start_time
    print("⚑ Performance Benchmark Results:")
    for model_name, inference_time in benchmarks.items():
        print(f"  {model_name:20}: {inference_time:.3f}s")
    return benchmarks
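
# For steadier numbers, average several timed runs after a warm-up pass
# (a sketch; the run count is arbitrary):
def timed_average(fn, runs=5):
    import time
    fn()  # warm-up: excludes one-time allocation costs from the measurement
    start = time.time()
    for _ in range(runs):
        fn()
    return (time.time() - start) / runs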