# app.py
# Image Upscale and Enhancement with Multiple Models
# By FebryEnsz
# SDK: Gradio
# Hosted on Hugging Face Spaces

import os  # used to set an environment variable for the MPS fallback below
import traceback

import cv2
import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

# --- Dependency Imports (Need to be installed via pip or manual clone) ---

# BasicSR related imports (SwinIR / EDSR architectures and tensor utilities)
try:
    from basicsr.archs.swinir_arch import SwinIR as SwinIR_Arch
    from basicsr.archs.edsr_arch import EDSR as EDSR_Arch
    from basicsr.utils import img2tensor, tensor2img
    BASESR_AVAILABLE = True
except ImportError:
    print("Warning: basicsr not found. SwinIR and EDSR will not be available.")
    BASESR_AVAILABLE = False

# RealESRGAN import.
# The API used below (RealESRGAN(device, scale=...), .load_weights(), .predict())
# appears to be that of the `RealESRGAN` package (ai-forever / sberbank-ai), which
# is imported as `RealESRGAN`. The lowercase `realesrgan` package is kept as a
# fallback, although it seems to expose `RealESRGANer` rather than `RealESRGAN`.
try:
    try:
        from RealESRGAN import RealESRGAN
    except ImportError:
        from realesrgan import RealESRGAN
    REALESRGAN_AVAILABLE = True
except ImportError:
    print("Warning: realesrgan not found. Real-ESRGAN-x4 will not be available.")
    REALESRGAN_AVAILABLE = False

# CodeFormer import. CodeFormer usually needs a manual setup; adjust the import
# path below to match your installation.
try:
    from CodeFormer import CodeFormer  # Adjust import based on your CodeFormer install
    CODEFORMER_AVAILABLE = True
except ImportError:
    print("Warning: CodeFormer not found. CodeFormer (Face Enhancement) will not be available.")
    CODEFORMER_AVAILABLE = False

# --- Model Configuration ---
# Maps each UI name to where its weights live and how it is used:
#   "UI Name": {"repo_id": <HF repo id>, "filename": <weight file>,
#               "type": "upscale" | "face", "scale": <int, upscale models only>}
# A model is listed only if the library needed to *instantiate* it imported.
# NOTE(review): confirm each repo_id/filename pair actually exists on the Hub.
MODEL_CONFIGS = {
    "Real-ESRGAN-x4": {"repo_id": "RealESRGAN/RealESRGAN_x4plus", "filename": "RealESRGAN_x4plus.pth", "type": "upscale", "scale": 4} if REALESRGAN_AVAILABLE else None,
    "SwinIR-4x": {"repo_id": "SwinIR/SwinIR-Large", "filename": "SwinIR_4x.pth", "type": "upscale", "scale": 4} if BASESR_AVAILABLE else None,
    "EDSR-x4": {"repo_id": "EDSR/edsr_x4", "filename": "edsr_x4.pth", "type": "upscale", "scale": 4} if BASESR_AVAILABLE else None,
    # basicsr alone cannot build the CodeFormer network (load_model would always
    # fail with ImportError), so only offer it when the CodeFormer package imported.
    "CodeFormer (Face Enhancement)": {"repo_id": "CodeFormer/codeformer", "filename": "codeformer.pth", "type": "face"} if CODEFORMER_AVAILABLE else None,
}

# Filter out unavailable models
MODEL_CONFIGS = {k: v for k, v in MODEL_CONFIGS.items() if v is not None}

# --- Model Loading Cache ---
# Only the most recently loaded model is kept, so repeated requests with the
# same model do not reload it.
cached_model = None
cached_model_name = None

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


def _load_state_dict(model_path):
    """Load a checkpoint file and return a plain ``state_dict``.

    If the checkpoint nests its weights under ``'params_ema'`` (preferred) or
    ``'params'``, that inner dict is used. A ``'module.'`` prefix (left by
    ``nn.DataParallel``) is stripped from every parameter name that has it.
    Other checkpoints are used as-is.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted repos.
    checkpoint = torch.load(model_path, map_location=device)
    if isinstance(checkpoint, dict):
        for key in ('params_ema', 'params'):
            if key in checkpoint:
                checkpoint = checkpoint[key]
                break
    prefix = 'module.'
    return {
        (name[len(prefix):] if name.startswith(prefix) else name): tensor
        for name, tensor in checkpoint.items()
    }


def _build_model(model_name, config, model_path):
    """Instantiate the network for *model_name* and load weights from *model_path*.

    Raises:
        ImportError: the library providing the architecture is not installed.
        ValueError: *model_name* has no construction recipe here.
    """
    if model_name == "Real-ESRGAN-x4":
        if not REALESRGAN_AVAILABLE:
            raise ImportError("realesrgan not installed.")
        # The wrapper takes the device directly and loads its own weights.
        model = RealESRGAN(device, scale=config['scale'])
        model.load_weights(model_path)
        return model

    if model_name == "SwinIR-4x":
        if not BASESR_AVAILABLE:
            raise ImportError("basicsr not installed.")
        # NOTE(review): in basicsr's SwinIR, compress_ratio / dilate_basis /
        # res_range / attn_type do not appear to be constructor parameters (they
        # would be absorbed by **kwargs), and `upsampler` is left at its default.
        # Confirm these settings match SwinIR_4x.pth; otherwise strict loading
        # fails or the output is not actually upscaled.
        net = SwinIR_Arch(
            upscale=config['scale'],
            in_chans=3,
            img_size=64,
            window_size=8,
            compress_ratio=-1,
            dilate_basis=-1,
            res_range=-1,
            attn_type='linear',
        )
    elif model_name == "EDSR-x4":
        if not BASESR_AVAILABLE:
            raise ImportError("basicsr not installed.")
        # basicsr's EDSR requires the input/output channel counts (omitting them
        # raises TypeError); num_feat/num_block are assumed typical x4 values.
        net = EDSR_Arch(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=config['scale'])
    elif model_name == "CodeFormer (Face Enhancement)":
        if not CODEFORMER_AVAILABLE:
            raise ImportError("CodeFormer library not found. BasicSR utilities alone are not enough to instantiate CodeFormer.")
        # Simplified instantiation; a full CodeFormer pipeline may need more
        # arguments and related models (e.g. for background enhancement).
        net = CodeFormer(dim_embd=512, codebook_size=1024, n_head=8, n_layer=9)
    else:
        raise ValueError(f"Configuration missing for model: {model_name}")

    net.load_state_dict(_load_state_dict(model_path), strict=True)
    net.eval()  # inference mode (disables dropout / batch-norm updates)
    return net.to(device)


def load_model(model_name):
    """Return ``(model, model_type)`` for *model_name*, loading it if not cached.

    On success the second element is the model's ``"type"`` from MODEL_CONFIGS
    (``"upscale"`` or ``"face"``). On failure the first element is ``None``
    and the second is a human-readable error message.
    """
    global cached_model, cached_model_name

    if model_name == cached_model_name and cached_model is not None:
        print(f"Using cached model: {model_name}")
        return cached_model, MODEL_CONFIGS[model_name]['type']

    print(f"Loading model: {model_name}")
    config = MODEL_CONFIGS.get(model_name)
    if config is None:
        return None, f"Error: Model '{model_name}' not supported or dependencies missing."

    # Release the previously cached model before loading the next one so two
    # models are never held at once (lower peak GPU/host memory).
    cached_model = None
    cached_model_name = None
    if device.type == "cuda":
        torch.cuda.empty_cache()

    try:
        model_path = hf_hub_download(repo_id=config['repo_id'], filename=config['filename'])
        model = _build_model(model_name, config, model_path)
    except ImportError as ie:
        print(f"Dependency missing for {model_name}: {ie}")
        return None, f"Error: Missing dependency - {ie}. Please ensure model libraries are installed."
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        return None, f"Error loading model: {str(e)}"

    cached_model = model
    cached_model_name = model_name
    return model, config['type']


def preprocess_image(image: Image.Image) -> np.ndarray:
    """Convert a PIL image to an OpenCV-style BGR ``uint8`` array.

    Non-RGB modes (grayscale, palette, RGBA) are converted to RGB first:
    cv2's RGB->BGR conversion needs a 3- or 4-channel array, and any alpha
    channel would be dropped by it anyway.
    """
    img = np.array(image.convert("RGB"))
    return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)


def postprocess_image(img: np.ndarray) -> Image.Image:
    """Convert an OpenCV-style BGR array back to a PIL RGB image.

    Non-``uint8`` input is clipped to [0, 255] and cast to ``uint8`` first.
    """
    if img.dtype != np.uint8:
        img = np.clip(img, 0, 255).astype(np.uint8)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return Image.fromarray(img)


def enhance_image(image: Image.Image, model_name: str):
    """Gradio callback: run *model_name* on *image*.

    Returns:
        ``(status_message, enhanced_image)``. ``enhanced_image`` is ``None``
        whenever the status message reports an error.
    """
    if image is None:
        return "Please upload an image.", None

    # On failure load_model returns (None, error_message) instead of (model, type).
    model, model_info = load_model(model_name)
    if model is None:
        return model_info, None
    model_type = model_info

    try:
        # PIL RGB -> OpenCV BGR
        img_np_bgr = preprocess_image(image)

        if model_type == "upscale":
            print(f"Applying {model_name} upscaling...")
            if model_name == "Real-ESRGAN-x4":
                # predict() may return a PIL image rather than an array (the
                # ai-forever implementation does), and that implementation is
                # written for RGB input. Feed RGB, then normalise the result
                # back to a BGR array for postprocess_image.
                img_np_rgb = cv2.cvtColor(img_np_bgr, cv2.COLOR_BGR2RGB)
                result = model.predict(img_np_rgb)
                output_np_bgr = cv2.cvtColor(np.asarray(result), cv2.COLOR_RGB2BGR)
            elif model_name in ["SwinIR-4x", "EDSR-x4"]:
                if not BASESR_AVAILABLE:
                    raise ImportError(f"basicsr is required for {model_name}")
                # HWC BGR uint8 -> 1xCHW RGB float in [0, 1] on the model's device
                img_tensor = img2tensor(img_np_bgr.astype(np.float32) / 255., bgr2rgb=True, float32=True).unsqueeze(0).to(device)
                with torch.no_grad():
                    output_tensor = model(img_tensor)
                # CHW RGB float in [0, 1] -> HWC RGB uint8 -> BGR for postprocessing
                output_np_rgb = tensor2img(output_tensor, rgb2bgr=False, min_max=(0, 1))
                output_np_bgr = cv2.cvtColor(output_np_rgb, cv2.COLOR_RGB2BGR)
            else:
                raise ValueError(f"Unknown upscale model: {model_name}")
            status_message = f"Image upscaled successfully with {model_name}!"

        elif model_type == "face":
            print(f"Applying {model_name} face enhancement...")
            if model_name != "CodeFormer (Face Enhancement)":
                raise ValueError(f"Unknown face enhancement model: {model_name}")
            if not CODEFORMER_AVAILABLE:
                raise ImportError(f"CodeFormer is required for {model_name}")
            # Ensure model is on the correct device before the call.
            if next(model.parameters()).device != device:
                model.to(device)
            # NOTE(review): placeholder call. It assumes the CodeFormer object has an
            # `enhance(bgr_uint8, w=..., adain=...)` method whose first return value
            # is the restored BGR image; confirm against the installed CodeFormer
            # version. If it is missing, the AttributeError is reported below.
            output_np_bgr = model.enhance(img_np_bgr, w=0.5, adain=True)[0]
            status_message = f"Face enhancement applied successfully with {model_name}!"

        else:
            raise ValueError(f"Unknown model type: {model_type}")

        # OpenCV BGR -> PIL RGB
        enhanced_image = postprocess_image(output_np_bgr)
        return status_message, enhanced_image

    except ImportError as ie:
        return f"Error processing image: Missing dependency - {ie}", None
    except Exception as e:
        print(f"Error during processing: {e}")
        traceback.print_exc()  # Print full traceback for debugging
        return f"Error processing image: {str(e)}", None


# Kept flush-left so Markdown does not render indented lines as code blocks.
_DESCRIPTION = """
# Image Upscale & Enhancement
**By FebryEnsz**

Upload an image and select a model to enhance it.
Choose from multiple models for upscaling (to make it 'HD' or higher resolution) or face enhancement (to improve facial details and focus).

**Note:** This app requires specific Python libraries (`torch`, `basicsr`, `realesrgan`, `CodeFormer`) to be installed for all models to be available. If a model option is missing, its required library is not installed or found.
"""

# Gradio interface
with gr.Blocks(title="Image Upscale & Enhancement - By FebryEnsz") as demo:
    gr.Markdown(_DESCRIPTION)

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Image", type="pil")

            # Only models whose dependencies imported are offered.
            available_models = list(MODEL_CONFIGS.keys())
            if not available_models:
                model_choice = gr.Textbox(label="Select Model", value="No models available. Check dependencies.", interactive=False)
                enhance_button = gr.Button("Enhance Image", interactive=False)
                print("No models are available because dependencies are missing.")
            else:
                model_choice = gr.Dropdown(
                    choices=available_models,
                    label="Select Model",
                    value=available_models[0],  # Default to the first available model
                )
                # No scale slider: every configured model has a fixed x4 scale.
                enhance_button = gr.Button("Enhance Image")

        with gr.Column():
            output_text = gr.Textbox(label="Status", max_lines=2)
            output_image = gr.Image(label="Enhanced Image")

    # Connect the button to the processing function only if a model exists.
    if available_models:
        enhance_button.click(
            fn=enhance_image,
            inputs=[image_input, model_choice],
            outputs=[output_text, output_image],
        )

# Launch the Gradio app
if __name__ == "__main__":
    # NOTE(review): `device` above only chooses between CUDA and CPU, so this flag
    # does not change where the models run. Also confirm whether PyTorch honours
    # this variable when it is set after `torch` has already been imported.
    if torch.backends.mps.is_available():
        os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # Optional: enable fallback for MPS
    demo.launch()