# Image-HDR / app.py
# Uploaded by ApaCu - commit f60e836 ("Create app.py", verified), 15 kB.
# app.py
# Image Upscale and Enhancement with Multiple Models
# By FebryEnsz
# SDK: Gradio
# Hosted on Hugging Face Spaces
import gradio as gr
import torch
from PIL import Image
import numpy as np
from huggingface_hub import hf_hub_download
import cv2
import os # Import os for path handling
# --- Dependency Imports (Need to be installed via pip or manual clone) ---
# Each optional library sets an *_AVAILABLE flag; MODEL_CONFIGS and load_model()
# use these flags to hide or refuse models whose library is missing.

# BasicSR related imports (for SwinIR, EDSR, CodeFormer utilities)
try:
    from basicsr.archs.swinir_arch import SwinIR as SwinIR_Arch
    from basicsr.archs.edsr_arch import EDSR as EDSR_Arch
    from basicsr.utils import img2tensor, tensor2img
    BASESR_AVAILABLE = True
except ImportError:
    print("Warning: basicsr not found. SwinIR, EDSR, and CodeFormer (using basicsr utils) will not be available.")
    BASESR_AVAILABLE = False

# RealESRGAN import.
# This file drives the model through `RealESRGAN(device, scale=...)`,
# `load_weights(path)` and `predict(image)`, which is the API of the
# ai-forever/Real-ESRGAN package. That package's module name is capitalised
# ("RealESRGAN"), so it is tried first. The lower-case "realesrgan" module is
# kept only as a fallback.
# NOTE(review): the xinntao "realesrgan" package is believed to expose
# `RealESRGANer` rather than `RealESRGAN`; confirm before relying on the fallback.
try:
    from RealESRGAN import RealESRGAN
    REALESRGAN_AVAILABLE = True
except ImportError:
    try:
        from realesrgan import RealESRGAN
        REALESRGAN_AVAILABLE = True
    except ImportError:
        print("Warning: realesrgan not found. Real-ESRGAN-x4 will not be available.")
        REALESRGAN_AVAILABLE = False

# CodeFormer import. This usually needs a manual setup (clone + install).
# NOTE(review): the module/class path below depends on how CodeFormer was
# installed; adjust it to match your installation.
try:
    from CodeFormer import CodeFormer
    CODEFORMER_AVAILABLE = True
except ImportError:
    print("Warning: CodeFormer not found. CodeFormer (Face Enhancement) will not be available.")
    CODEFORMER_AVAILABLE = False
# --- Model Configuration ---
# Available models and their configuration, keyed by the name shown in the UI:
#   "UI Name": {"repo_id": <HF Hub repo>, "filename": <weight file in that repo>,
#               "type": "upscale" | "face", "scale": <int, upscalers only>}
# An entry is only registered when the library that load_model() needs to
# build it imported successfully. Otherwise it is set to None and filtered out
# below, so it never appears in the dropdown.
# NOTE(review): confirm every repo_id/filename pair exists on the Hugging Face Hub.
MODEL_CONFIGS = {
    "Real-ESRGAN-x4": {"repo_id": "RealESRGAN/RealESRGAN_x4plus", "filename": "RealESRGAN_x4plus.pth", "type": "upscale", "scale": 4} if REALESRGAN_AVAILABLE else None,
    "SwinIR-4x": {"repo_id": "SwinIR/SwinIR-Large", "filename": "SwinIR_4x.pth", "type": "upscale", "scale": 4} if BASESR_AVAILABLE else None,
    "EDSR-x4": {"repo_id": "EDSR/edsr_x4", "filename": "edsr_x4.pth", "type": "upscale", "scale": 4} if BASESR_AVAILABLE else None,
    # load_model() can only instantiate CodeFormer from the CodeFormer library
    # itself. With basicsr alone it raises ImportError, so offering the option
    # in that case would only ever produce an error. Gate on CodeFormer alone.
    "CodeFormer (Face Enhancement)": {"repo_id": "CodeFormer/codeformer", "filename": "codeformer.pth", "type": "face"} if CODEFORMER_AVAILABLE else None,
}
# Filter out unavailable models
MODEL_CONFIGS = {k: v for k, v in MODEL_CONFIGS.items() if v is not None}
# --- Model Loading Cache ---
# Single-entry cache: load_model() overwrites these on every successful load,
# so only the most recently used model stays in memory.
# None (not an empty dict) means "nothing cached". load_model() tests
# `cached_model is not None`. With a `{}` placeholder, a call with
# model_name=None would match cached_model_name=None and hit the cache.
cached_model = None
cached_model_name = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Function to load the selected model
def _extract_state_dict(checkpoint, preferred_keys=("params", "params_ema")):
    """Return the bare parameter mapping stored in a loaded checkpoint.

    Args:
        checkpoint: Object returned by ``torch.load``. It is either a plain
            state dict or a wrapper dict holding the weights under one of
            ``preferred_keys`` (BasicSR-style checkpoints use ``"params"`` /
            ``"params_ema"``).
        preferred_keys: Wrapper keys to look for, in priority order.

    Returns:
        The unwrapped name -> tensor mapping. Any leading ``module.`` prefix
        (left by ``nn.DataParallel``) is stripped from the keys.
    """
    state_dict = checkpoint
    if isinstance(checkpoint, dict):
        wrapper_key = next((key for key in preferred_keys if key in checkpoint), None)
        if wrapper_key is not None:
            state_dict = checkpoint[wrapper_key]
    prefix = "module."
    if any(name.startswith(prefix) for name in state_dict):
        state_dict = {
            (name[len(prefix):] if name.startswith(prefix) else name): value
            for name, value in state_dict.items()
        }
    return state_dict


def _load_checkpoint_into(model, model_path, preferred_keys=("params", "params_ema")):
    """Strictly load the weights at ``model_path`` into ``model``.

    Returns the model in eval mode on ``device``.
    """
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(_extract_state_dict(checkpoint, preferred_keys), strict=True)
    model.eval()
    return model.to(device)


def load_model(model_name):
    """Load the model selected in the UI, or reuse the single cached one.

    Weights are fetched with ``hf_hub_download`` using the ``repo_id`` and
    ``filename`` from ``MODEL_CONFIGS``. A successful load replaces the
    previously cached model.

    Args:
        model_name: A key of ``MODEL_CONFIGS``.

    Returns:
        ``(model, model_type)`` on success, where ``model_type`` is the config's
        ``"type"`` (``"upscale"`` or ``"face"``). On failure it returns
        ``(None, error_message)``. Failures are reported through the return
        value rather than raised.
    """
    global cached_model, cached_model_name
    config = MODEL_CONFIGS.get(model_name)
    # Only trust the cache for a known model name. An unknown/None name must
    # never be served from the cache.
    if config is not None and model_name == cached_model_name and cached_model is not None:
        print(f"Using cached model: {model_name}")
        return cached_model, config['type']
    print(f"Loading model: {model_name}")
    if config is None:
        return None, f"Error: Model '{model_name}' not supported or dependencies missing."
    try:
        model_type = config['type']
        model_path = hf_hub_download(repo_id=config['repo_id'], filename=config['filename'])
        if model_name == "Real-ESRGAN-x4":
            if not REALESRGAN_AVAILABLE:
                raise ImportError("realesrgan not installed.")
            # The wrapper builds its own network and loads the weights itself.
            model = RealESRGAN(device, scale=config['scale'])
            model.load_weights(model_path)
        elif model_name == "SwinIR-4x":
            if not BASESR_AVAILABLE:
                raise ImportError("basicsr not installed.")
            # Classical-SR x4 configuration (SwinIR-M, 64px patches, window 8)
            # from the official SwinIR test script. Without
            # `upsampler='pixelshuffle'` SwinIR keeps the input resolution.
            # NOTE(review): confirm this matches the checkpoint at
            # config['repo_id']; the large real-world SwinIR weights use a
            # different configuration.
            model = SwinIR_Arch(
                upscale=config['scale'], in_chans=3, img_size=64, window_size=8,
                img_range=1.0, depths=[6, 6, 6, 6, 6, 6], embed_dim=180,
                num_heads=[6, 6, 6, 6, 6, 6], mlp_ratio=2,
                upsampler='pixelshuffle', resi_connection='1conv',
            )
            model = _load_checkpoint_into(model, model_path)
        elif model_name == "EDSR-x4":
            if not BASESR_AVAILABLE:
                raise ImportError("basicsr not installed.")
            # basicsr's EDSR requires num_in_ch/num_out_ch. 64 features and
            # 16 blocks are the standard EDSR x4 settings.
            model = EDSR_Arch(
                num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16,
                upscale=config['scale'],
            )
            model = _load_checkpoint_into(model, model_path)
        elif model_name == "CodeFormer (Face Enhancement)":
            if not (CODEFORMER_AVAILABLE or BASESR_AVAILABLE):
                raise ImportError("CodeFormer or basicsr not installed.")
            if not CODEFORMER_AVAILABLE:
                # basicsr alone does not provide the CodeFormer architecture.
                raise ImportError("CodeFormer library not found. BasicSR utilities alone are not enough to instantiate CodeFormer.")
            # Arguments follow CodeFormer's inference script. The keyword is
            # `n_layers` (plural).
            # NOTE(review): adjust if your CodeFormer install exposes a different constructor.
            model = CodeFormer(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9)
            # CodeFormer checkpoints keep their weights under 'params_ema'.
            model = _load_checkpoint_into(model, model_path, preferred_keys=("params_ema", "params"))
        else:
            raise ValueError(f"Configuration missing for model: {model_name}")
        # Cache the loaded model
        cached_model = model
        cached_model_name = model_name
        return model, model_type
    except ImportError as ie:
        # The cache is left untouched: it still holds a valid, previously loaded model.
        print(f"Dependency missing for {model_name}: {ie}")
        return None, f"Error: Missing dependency - {ie}. Please ensure model libraries are installed."
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        # Clear cache on error
        cached_model = None
        cached_model_name = None
        return None, f"Error loading model: {str(e)}"
# Function to preprocess image (PIL RGB to OpenCV BGR numpy)
def preprocess_image(image: Image.Image) -> np.ndarray:
    """Convert a PIL image into an OpenCV-style BGR uint8 array of shape (H, W, 3).

    Images that are not already RGB are converted first. Greyscale ("L"),
    palette ("P"), "LA" and 16-bit/float modes would otherwise become 2-D or
    2-channel arrays that ``cv2.COLOR_RGB2BGR`` rejects, and "CMYK" would be
    misread as RGBA. RGBA input loses its alpha channel either way.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    img = np.array(image)
    # OpenCV orders channels BGR, PIL orders them RGB.
    return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
# Function to postprocess image (OpenCV BGR numpy to PIL RGB)
def postprocess_image(img: np.ndarray) -> Image.Image:
    """Turn an OpenCV-style BGR array into an RGB PIL image.

    Input that is not already ``uint8`` is clipped to [0, 255] and cast with
    ``astype``, which truncates fractional values, before the channel swap.
    """
    already_bytes = img.dtype == np.uint8
    bgr = img if already_bytes else np.clip(img, 0, 255).astype(np.uint8)
    return Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
# Main processing function
def _run_realesrgan(model, rgb_image):
    """Upscale with the ``RealESRGAN`` wrapper and return a BGR uint8 array.

    ``RealESRGAN.predict`` (ai-forever/Real-ESRGAN) takes an RGB image and
    returns a PIL image. It is therefore given RGB data, not OpenCV BGR, and
    its result is converted back to BGR for ``postprocess_image``.
    NOTE(review): if another RealESRGAN build returns a BGR ndarray, drop the
    colour conversion below.
    """
    result = model.predict(rgb_image)
    if isinstance(result, Image.Image):
        result = result.convert("RGB")
    return cv2.cvtColor(np.asarray(result), cv2.COLOR_RGB2BGR)


def _run_basicsr_sr(model, img_np_bgr, pad_multiple=1):
    """Run a BasicSR super-resolution network (SwinIR/EDSR) on a BGR uint8 image.

    The input is scaled to [0, 1] RGB, reflect-padded on the bottom/right so H
    and W are multiples of ``pad_multiple``, and passed through ``model``. The
    padded border is then cropped from the output. Returns a BGR uint8 array.
    """
    img_tensor = img2tensor(img_np_bgr.astype(np.float32) / 255., bgr2rgb=True, float32=True).unsqueeze(0).to(device)
    _, _, height, width = img_tensor.shape
    pad_h = (-height) % pad_multiple
    pad_w = (-width) % pad_multiple
    if pad_h or pad_w:
        # Reflect padding needs the pad to be smaller than the padded dimension.
        mode = "reflect" if pad_h < height and pad_w < width else "replicate"
        img_tensor = torch.nn.functional.pad(img_tensor, (0, pad_w, 0, pad_h), mode=mode)
    with torch.no_grad():
        output_tensor = model(img_tensor)
    # Derive the scale from the network's actual output so the crop is right
    # whatever the upscale factor is.
    scale = output_tensor.shape[-2] // img_tensor.shape[-2]
    output_tensor = output_tensor[..., : height * scale, : width * scale]
    # CHW RGB float in [0, 1] -> HWC BGR uint8.
    return tensor2img(output_tensor, rgb2bgr=True, min_max=(0, 1))


def _run_codeformer(model, img_np_bgr, fidelity=0.5):
    """Restore a face with the CodeFormer network; returns BGR uint8 at the input size.

    Follows the network-call convention of CodeFormer's inference script:
    512x512 RGB input normalised to [-1, 1], ``model(x, w=fidelity,
    adain=True)[0]`` as the restored face, output mapped back from [-1, 1].
    NOTE(review): CodeFormer's full pipeline first detects, aligns and crops
    faces and pastes them back afterwards. That is not done here: the whole
    image is treated as one face crop. Results are only sensible for
    already-cropped, roughly centred face photos.
    """
    height, width = img_np_bgr.shape[:2]
    face_bgr = cv2.resize(img_np_bgr, (512, 512), interpolation=cv2.INTER_LINEAR)
    face_rgb = cv2.cvtColor(face_bgr, cv2.COLOR_BGR2RGB)
    face_tensor = torch.from_numpy(face_rgb).permute(2, 0, 1).float().div(255.0)
    face_tensor = ((face_tensor - 0.5) / 0.5).unsqueeze(0).to(device)
    with torch.no_grad():
        restored = model(face_tensor, w=fidelity, adain=True)[0]
    restored = restored.squeeze(0).clamp(-1.0, 1.0)
    restored = ((restored + 1.0) / 2.0 * 255.0).round().to(torch.uint8)
    restored_rgb = np.ascontiguousarray(restored.permute(1, 2, 0).cpu().numpy())
    restored_bgr = cv2.cvtColor(restored_rgb, cv2.COLOR_RGB2BGR)
    return cv2.resize(restored_bgr, (width, height), interpolation=cv2.INTER_LINEAR)


def enhance_image(image: Image.Image, model_name: str):
    """Gradio callback: enhance ``image`` with the model chosen in the dropdown.

    Args:
        image: Uploaded PIL image, or None when nothing has been uploaded.
        model_name: A key of ``MODEL_CONFIGS``.

    Returns:
        ``(status_message, enhanced_image)``. ``enhanced_image`` is a PIL RGB
        image on success and None on any failure, in which case the status
        message carries the error.
    """
    if image is None:
        return "Please upload an image.", None
    # Load the selected model and its type
    model, model_info = load_model(model_name)
    if model is None:
        # model_info contains the error message if loading failed
        return model_info, None
    model_type = model_info  # model_info is the type string ('upscale' or 'face')
    try:
        # Normalise uploads (RGBA, greyscale, palette PNGs...) to 3-channel RGB
        # before any colour-space conversion.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            rgb_image = image.convert("RGB")
        else:
            rgb_image = image
        # Preprocess the image (PIL RGB -> OpenCV BGR)
        img_np_bgr = preprocess_image(rgb_image)
        if model_type == "upscale":
            print(f"Applying {model_name} upscaling...")
            if model_name == "Real-ESRGAN-x4":
                output_np_bgr = _run_realesrgan(model, rgb_image)
            elif model_name in ("SwinIR-4x", "EDSR-x4"):
                if not BASESR_AVAILABLE:
                    raise ImportError(f"basicsr is required for {model_name}")
                # SwinIR's window attention needs H and W divisible by its
                # window_size (8, as built in load_model). EDSR needs no padding.
                pad_multiple = 8 if model_name == "SwinIR-4x" else 1
                output_np_bgr = _run_basicsr_sr(model, img_np_bgr, pad_multiple)
            else:
                raise ValueError(f"Unknown upscale model: {model_name}")
            status_message = f"Image upscaled successfully with {model_name}!"
        elif model_type == "face":
            print(f"Applying {model_name} face enhancement...")
            if model_name == "CodeFormer (Face Enhancement)":
                if not (CODEFORMER_AVAILABLE or BASESR_AVAILABLE):
                    raise ImportError(f"CodeFormer or basicsr is required for {model_name}")
                output_np_bgr = _run_codeformer(model, img_np_bgr)
            else:
                raise ValueError(f"Unknown face enhancement model: {model_name}")
            status_message = f"Face enhancement applied successfully with {model_name}!"
        else:
            # Previously fell through to an undefined-variable NameError.
            raise ValueError(f"Unknown model type: {model_type}")
        # Postprocess the output image (OpenCV BGR -> PIL RGB)
        enhanced_image = postprocess_image(output_np_bgr)
        return status_message, enhanced_image
    except ImportError as ie:
        return f"Error processing image: Missing dependency - {ie}", None
    except Exception as e:
        print(f"Error during processing: {e}")
        import traceback
        traceback.print_exc()  # Print full traceback for debugging
        return f"Error processing image: {str(e)}", None
# Gradio interface.
# Builds `demo`: an input column (image upload, model selector, button) next to
# an output column (status text, enhanced image). The button is wired to
# enhance_image only when at least one model is available.
with gr.Blocks(title="Image Upscale & Enhancement - By FebryEnsz") as demo:
    gr.Markdown(
        """
# Image Upscale & Enhancement
**By FebryEnsz**
Upload an image and select a model to enhance it. Choose from multiple models for upscaling (to make it 'HD' or higher resolution) or face enhancement (to improve facial details and focus).
**Note:** This app requires specific Python libraries (`torch`, `basicsr`, `realesrgan`, `CodeFormer`) to be installed for all models to be available. If a model option is missing, its required library is not installed or found.
"""
    )
    with gr.Row():
        with gr.Column():
            # type="pil" delivers the upload to enhance_image as a PIL image (None when empty).
            image_input = gr.Image(label="Upload Image", type="pil")
            # MODEL_CONFIGS only contains models whose dependencies imported successfully.
            available_models = list(MODEL_CONFIGS.keys())
            if not available_models:
                # Nothing can run: show a read-only notice and a disabled button instead of a dropdown.
                model_choice = gr.Textbox(label="Select Model", value="No models available. Check dependencies.", interactive=False)
                enhance_button = gr.Button("Enhance Image", interactive=False)
                print("No models are available because dependencies are missing.")
            else:
                model_choice = gr.Dropdown(
                    choices=available_models,
                    label="Select Model",
                    value=available_models[0]  # Default to the first available model
                )
                # No scale control: every upscaler in MODEL_CONFIGS has a fixed x4 scale.
                enhance_button = gr.Button("Enhance Image")
        with gr.Column():
            output_text = gr.Textbox(label="Status", max_lines=2)
            output_image = gr.Image(label="Enhanced Image")
    # enhance_image returns (status_message, image), matching the two outputs below.
    if available_models:  # Only connect if models are available
        enhance_button.click(
            fn=enhance_image,
            inputs=[image_input, model_choice],
            outputs=[output_text, output_image]
        )
# Launch the Gradio app when run as a script (not when imported).
if __name__ == "__main__":
    # On machines with an MPS backend, opt in to CPU fallback for operators
    # that lack an MPS kernel.
    # NOTE(review): `device` is only ever "cuda" or "cpu", so no model here runs
    # on MPS. Also confirm whether this variable is still honoured when set
    # after `import torch`.
    mps_backend = torch.backends.mps
    if mps_backend.is_available():
        os.environ.update({'PYTORCH_ENABLE_MPS_FALLBACK': '1'})
    demo.launch()