File size: 15,006 Bytes
f60e836 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 |
# app.py
# Image Upscale and Enhancement with Multiple Models
# By FebryEnsz
# SDK: Gradio
# Hosted on Hugging Face Spaces
import gradio as gr
import torch
from PIL import Image
import numpy as np
from huggingface_hub import hf_hub_download
import cv2
import os # Import os for path handling
# --- Dependency Imports (Need to be installed via pip or manual clone) ---
# BasicSR related imports (for SwinIR, EDSR, CodeFormer utilities)
try:
from basicsr.archs.swinir_arch import SwinIR as SwinIR_Arch
from basicsr.archs.edsr_arch import EDSR as EDSR_Arch
from basicsr.utils import img2tensor, tensor2img
BASESR_AVAILABLE = True
except ImportError:
print("Warning: basicsr not found. SwinIR, EDSR, and CodeFormer (using basicsr utils) will not be available.")
BASESR_AVAILABLE = False
# RealESRGAN import
try:
from realesrgan import RealESRGAN
REALESRGAN_AVAILABLE = True
except ImportError:
print("Warning: realesrgan not found. Real-ESRGAN-x4 will not be available.")
REALESRGAN_AVAILABLE = False
# CodeFormer import (This assumes CodeFormer is installed and importable,
# or integrated into basicsr's structure) - often requires manual setup.
# We will use basicsr's utilities for CodeFormer if available, and try a direct import if possible.
try:
# Attempting a common import path if CodeFormer is installed separately
from CodeFormer import CodeFormer # Adjust import based on your CodeFormer install
CODEFORMER_AVAILABLE = True
except ImportError:
print("Warning: CodeFormer not found. CodeFormer (Face Enhancement) will not be available.")
CODEFORMER_AVAILABLE = False
# --- Model Configuration ---
# Registry of the models offered in the UI, keyed by the label shown in the dropdown.
# Entry format:
#   "UI Name": {
#       "repo_id": Hugging Face Hub repository holding the weights,
#       "filename": weight file inside that repository,
#       "type": "upscale" or "face",
#       "scale": integer upscaling factor (upscale models only),
#   }
# NOTE(review): the repo ids / filenames below are not verified here;
# hf_hub_download() in load_model() fails if they do not exist on the Hub.
# An entry is None when the library it needs failed to import; such entries are
# filtered out below so the dropdown only lists models that can actually load.
MODEL_CONFIGS = {
    "Real-ESRGAN-x4": (
        {"repo_id": "RealESRGAN/RealESRGAN_x4plus", "filename": "RealESRGAN_x4plus.pth", "type": "upscale", "scale": 4}
        if REALESRGAN_AVAILABLE else None
    ),
    "SwinIR-4x": (
        {"repo_id": "SwinIR/SwinIR-Large", "filename": "SwinIR_4x.pth", "type": "upscale", "scale": 4}
        if BASESR_AVAILABLE else None
    ),
    "EDSR-x4": (
        {"repo_id": "EDSR/edsr_x4", "filename": "edsr_x4.pth", "type": "upscale", "scale": 4}
        if BASESR_AVAILABLE else None
    ),
    # load_model() can only build the CodeFormer network from the CodeFormer
    # library itself (basicsr utilities alone raise ImportError there), so the
    # option is offered only when that library imported successfully.
    "CodeFormer (Face Enhancement)": (
        {"repo_id": "CodeFormer/codeformer", "filename": "codeformer.pth", "type": "face"}
        if CODEFORMER_AVAILABLE else None
    ),
}
# Filter out unavailable models
MODEL_CONFIGS = {name: cfg for name, cfg in MODEL_CONFIGS.items() if cfg is not None}
# --- Model Loading Cache ---
# Single-slot cache: only the most recently loaded model is kept, so repeated
# runs with the same model skip re-downloading and re-building it.
# None means "nothing cached"; load_model() tests `cached_model is not None`,
# so the sentinel must be None (an empty dict would pass that test).
cached_model = None
cached_model_name = None
# Inference device: CUDA when torch reports it available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# --- Model loading ---
def _extract_state_dict(checkpoint):
    """Return the bare parameter dict stored in a loaded checkpoint.

    Checkpoints may wrap their weights as ``{"params_ema": ...}`` or
    ``{"params": ...}`` (``params_ema`` is preferred when present), and weights
    saved from ``nn.DataParallel`` carry a ``"module."`` key prefix. Both wrappers
    are removed so the result can be passed to ``load_state_dict(strict=True)``.
    """
    for wrapper_key in ("params_ema", "params"):
        if isinstance(checkpoint, dict) and wrapper_key in checkpoint:
            checkpoint = checkpoint[wrapper_key]
            break
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint.items()
    }


def _load_weights(model, model_path):
    """Load the checkpoint at *model_path* into *model*; return it on ``device`` in eval mode."""
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from
    # trusted repositories (newer torch versions accept weights_only=True).
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(_extract_state_dict(checkpoint), strict=True)
    model.eval()
    return model.to(device)


def load_model(model_name):
    """Load the model registered under *model_name*, reusing the single-slot cache.

    Weights are fetched with ``hf_hub_download`` and the built model is stored in
    the module-level ``cached_model`` / ``cached_model_name`` globals.

    Args:
        model_name: A key of ``MODEL_CONFIGS`` (the dropdown label).

    Returns:
        ``(model, model_type)`` on success, where ``model_type`` is the entry's
        ``"type"`` (``"upscale"`` or ``"face"``), or ``(None, error_message)`` on
        failure. Exceptions raised while loading are caught and reported through
        the returned message.
    """
    global cached_model, cached_model_name
    config = MODEL_CONFIGS.get(model_name)
    if config is None:
        # Unknown name (e.g. None) or a model filtered out for missing dependencies.
        # Checked before the cache so an empty cache can never "match" a bad name.
        return None, f"Error: Model '{model_name}' not supported or dependencies missing."
    if model_name == cached_model_name and cached_model is not None:
        print(f"Using cached model: {model_name}")
        return cached_model, config['type']
    print(f"Loading model: {model_name}")
    try:
        model_type = config['type']
        model_path = hf_hub_download(repo_id=config['repo_id'], filename=config['filename'])
        if model_name == "Real-ESRGAN-x4":
            if not REALESRGAN_AVAILABLE:
                raise ImportError("realesrgan not installed.")
            model = RealESRGAN(device, scale=config['scale'])
            model.load_weights(model_path)
        elif model_name == "SwinIR-4x":
            if not BASESR_AVAILABLE:
                raise ImportError("basicsr not installed.")
            # NOTE(review): compress_ratio / dilate_basis / res_range / attn_type do
            # not appear to be parameters of basicsr's SwinIR (they seem to be
            # swallowed by its **kwargs), and its default upsampler='' would not
            # upscale at all. The arguments must match the checkpoint's training
            # config; strict=True loading rejects a mismatch - confirm them.
            model = SwinIR_Arch(
                upscale=config['scale'], in_chans=3, img_size=64, window_size=8,
                compress_ratio=-1, dilate_basis=-1, res_range=-1, attn_type='linear',
            )
            model = _load_weights(model, model_path)
        elif model_name == "EDSR-x4":
            if not BASESR_AVAILABLE:
                raise ImportError("basicsr not installed.")
            # basicsr's EDSR takes the input/output channel counts as required
            # arguments; leaving them out raises TypeError.
            model = EDSR_Arch(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=config['scale'])
            model = _load_weights(model, model_path)
        elif model_name == "CodeFormer (Face Enhancement)":
            if not CODEFORMER_AVAILABLE:
                # basicsr utilities alone cannot build the CodeFormer network.
                raise ImportError("CodeFormer library not found. BasicSR utilities alone are not enough to instantiate CodeFormer.")
            # The reference CodeFormer inference script builds the network with
            # n_layers=9 (not n_layer). NOTE(review): confirm against the installed library.
            model = CodeFormer(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9)
            model = _load_weights(model, model_path)
        else:
            raise ValueError(f"Configuration missing for model: {model_name}")
        # Cache the loaded model
        cached_model = model
        cached_model_name = model_name
        return model, model_type
    except ImportError as ie:
        print(f"Dependency missing for {model_name}: {ie}")
        return None, f"Error: Missing dependency - {ie}. Please ensure model libraries are installed."
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        # Clear the cache on error (this also releases the previously cached model).
        cached_model = None
        cached_model_name = None
        return None, f"Error loading model: {str(e)}"
# Convert a PIL image (any mode) to an OpenCV-style BGR uint8 array.
def preprocess_image(image: Image.Image) -> np.ndarray:
    """Return *image* as an H x W x 3 uint8 BGR array.

    Images in any mode other than "RGB" (e.g. "L" grayscale, "P" palette,
    "RGBA", "CMYK") are converted to "RGB" first. Without this step,
    single-channel modes reach cv2.cvtColor as 2-D arrays and fail, and modes
    such as CMYK do not hold RGB data for the RGB->BGR swap below.
    """
    if image.mode != "RGB":
        # Any alpha channel is dropped: the models in this app take 3-channel input.
        image = image.convert("RGB")
    rgb = np.array(image)
    # OpenCV and the basicsr helpers work in BGR order; PIL uses RGB.
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
# Convert an OpenCV-style BGR array back into a PIL RGB image.
def postprocess_image(img: np.ndarray) -> Image.Image:
    """Return the BGR array *img* as an RGB PIL image.

    Arrays that are not already uint8 are clipped to [0, 255] and cast to uint8
    first, so Image.fromarray receives 8-bit data.
    """
    as_bytes = img if img.dtype == np.uint8 else np.clip(img, 0, 255).astype(np.uint8)
    rgb = cv2.cvtColor(as_bytes, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb)
# --- Inference ---
def _upscale_realesrgan(model, img_np_bgr):
    """Run the Real-ESRGAN wrapper on a BGR uint8 array; return a BGR uint8 array.

    The RealESRGAN(device, scale=...) / load_weights / predict API used by
    load_model() appears to be ai-forever's Real-ESRGAN package, whose predict()
    expects RGB pixels and returns a PIL image. Feeding it BGR swapped red and
    blue, and passing its PIL result to postprocess_image() failed on `.dtype`.
    NOTE(review): confirm against the installed realesrgan package.
    """
    img_rgb = cv2.cvtColor(img_np_bgr, cv2.COLOR_BGR2RGB)
    result = model.predict(img_rgb)
    # np.asarray accepts a PIL image as well as an ndarray.
    out_rgb = np.asarray(result)
    return cv2.cvtColor(out_rgb, cv2.COLOR_RGB2BGR)


def _upscale_basicsr(model, img_np_bgr):
    """Run a basicsr network (SwinIR / EDSR) on a BGR uint8 array; return BGR uint8.

    The image becomes a 1 x 3 x H x W RGB float tensor in [0, 1]. Window-based
    models (SwinIR) need H and W to be multiples of their window size, so the
    tensor is padded up to the next multiple. The padding is then cropped off the
    output using the measured output/input size ratio. This mirrors the padding
    step of basicsr's SwinIR inference script.
    """
    # HWC BGR uint8 -> 1xCHW RGB float (0-1) on the inference device.
    img_tensor = img2tensor(img_np_bgr.astype(np.float32) / 255., bgr2rgb=True, float32=True).unsqueeze(0).to(device)
    height, width = img_tensor.shape[-2:]
    # Models without a window_size attribute (e.g. EDSR) get no padding.
    window = int(getattr(model, "window_size", 1) or 1)
    pad_h = (-height) % window
    pad_w = (-width) % window
    if pad_h or pad_w:
        # Reflect padding requires pad < dimension; tiny images fall back to replicate.
        mode = "reflect" if pad_h < height and pad_w < width else "replicate"
        img_tensor = torch.nn.functional.pad(img_tensor, (0, pad_w, 0, pad_h), mode=mode)
    with torch.no_grad():
        output_tensor = model(img_tensor)
    scale_h = output_tensor.shape[-2] // img_tensor.shape[-2]
    scale_w = output_tensor.shape[-1] // img_tensor.shape[-1]
    output_tensor = output_tensor[..., : height * scale_h, : width * scale_w]
    # CHW RGB float (0-1) -> HWC RGB uint8 -> BGR for postprocessing.
    output_np_rgb = tensor2img(output_tensor, rgb2bgr=False, min_max=(0, 1))
    return cv2.cvtColor(output_np_rgb, cv2.COLOR_RGB2BGR)


def enhance_image(image: Image.Image, model_name: str):
    """Gradio callback: enhance *image* with the model chosen in the dropdown.

    Returns:
        ``(status_message, enhanced_image)``. ``enhanced_image`` is None whenever
        the status is an error or the upload prompt.
    """
    if image is None:
        return "Please upload an image.", None
    # load_model returns (model, type) on success, or (None, error message).
    model, model_info = load_model(model_name)
    if model is None:
        return model_info, None
    model_type = model_info  # 'upscale' or 'face'
    try:
        # PIL (any mode) -> OpenCV BGR uint8.
        img_np_bgr = preprocess_image(image)
        if model_type == "upscale":
            print(f"Applying {model_name} upscaling...")
            if model_name == "Real-ESRGAN-x4":
                output_np_bgr = _upscale_realesrgan(model, img_np_bgr)
            elif model_name in ["SwinIR-4x", "EDSR-x4"]:
                if not BASESR_AVAILABLE:
                    raise ImportError(f"basicsr is required for {model_name}")
                output_np_bgr = _upscale_basicsr(model, img_np_bgr)
            else:
                raise ValueError(f"Unknown upscale model: {model_name}")
            status_message = f"Image upscaled successfully with {model_name}!"
        elif model_type == "face":
            print(f"Applying {model_name} face enhancement...")
            if model_name == "CodeFormer (Face Enhancement)":
                if not (CODEFORMER_AVAILABLE or BASESR_AVAILABLE):
                    raise ImportError(f"CodeFormer or basicsr is required for {model_name}")
                # Make sure the weights live on the inference device before the call.
                if next(model.parameters()).device != device:
                    model.to(device)
                # NOTE(review): placeholder call. It assumes the installed CodeFormer
                # object exposes enhance(bgr_uint8, w=..., adain=...) returning a
                # sequence whose first item is the restored BGR image. The reference
                # CodeFormer network does not appear to provide such a method (its
                # forward() works on aligned face crops) - confirm the API or wrap it
                # with a face detection/alignment pipeline.
                output_np_bgr = model.enhance(img_np_bgr, w=0.5, adain=True)[0]
            else:
                raise ValueError(f"Unknown face enhancement model: {model_name}")
            status_message = f"Face enhancement applied successfully with {model_name}!"
        else:
            # Without this, output_np_bgr would be unbound and the error unclear.
            raise ValueError(f"Unknown model type: {model_type}")
        # OpenCV BGR -> PIL RGB
        enhanced_image = postprocess_image(output_np_bgr)
        return status_message, enhanced_image
    except ImportError as ie:
        return f"Error processing image: Missing dependency - {ie}", None
    except Exception as e:
        print(f"Error during processing: {e}")
        import traceback
        traceback.print_exc()  # Full traceback in the server log for debugging
        return f"Error processing image: {str(e)}", None
# Gradio interface
# Layout: one row with two columns - inputs (image upload, model picker, button)
# on the left, outputs (status text, enhanced image) on the right. Components are
# rendered in the order they are created inside the `with` blocks.
with gr.Blocks(title="Image Upscale & Enhancement - By FebryEnsz") as demo:
    gr.Markdown(
        """
# Image Upscale & Enhancement
**By FebryEnsz**
Upload an image and select a model to enhance it. Choose from multiple models for upscaling (to make it 'HD' or higher resolution) or face enhancement (to improve facial details and focus).
**Note:** This app requires specific Python libraries (`torch`, `basicsr`, `realesrgan`, `CodeFormer`) to be installed for all models to be available. If a model option is missing, its required library is not installed or found.
"""
    )
    with gr.Row():
        with gr.Column():
            # type="pil" makes Gradio pass a PIL.Image to enhance_image().
            image_input = gr.Image(label="Upload Image", type="pil")
            # MODEL_CONFIGS already has unavailable models filtered out.
            available_models = list(MODEL_CONFIGS.keys())
            if not available_models:
                # No usable model: show a read-only notice and a disabled button.
                model_choice = gr.Textbox(label="Select Model", value="No models available. Check dependencies.", interactive=False)
                enhance_button = gr.Button("Enhance Image", interactive=False)
                print("No models are available because dependencies are missing.")
            else:
                model_choice = gr.Dropdown(
                    choices=available_models,
                    label="Select Model",
                    value=available_models[0]  # Default to the first available model
                )
                # No scale control: every upscale model configured here is fixed at x4.
                enhance_button = gr.Button("Enhance Image")
        with gr.Column():
            output_text = gr.Textbox(label="Status", max_lines=2)
            output_image = gr.Image(label="Enhanced Image")
    # Wire the button to the callback; outputs map to (status, image) in order.
    if available_models:  # Only connect if models are available
        enhance_button.click(
            fn=enhance_image,
            inputs=[image_input, model_choice],
            outputs=[output_text, output_image]
        )
# --- Entry point: launch the Gradio app ---
if __name__ == "__main__":
    # On PyTorch builds with Apple MPS support, allow ops that MPS does not
    # implement to fall back to the CPU instead of raising.
    # NOTE(review): `device` is only ever "cuda" or "cpu" in this file, so this flag
    # does not affect inference as written; PyTorch may also read it only when torch
    # is first imported - confirm before relying on it.
    mps_present = torch.backends.mps.is_available()
    if mps_present:
        os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
    demo.launch()