|
|
|
|
|
|
|
|
|
|
|
|
|
import gradio as gr |
|
import torch |
|
from PIL import Image |
|
import numpy as np |
|
from huggingface_hub import hf_hub_download |
|
import cv2 |
|
import os |
|
|
|
|
|
|
|
# Optional dependency: basicsr provides the SwinIR / EDSR architectures and the
# img2tensor / tensor2img helpers used below. Start from "available" and flip
# the flag if any of the imports fails, so the module still loads without it.
BASESR_AVAILABLE = True
try:
    from basicsr.archs.swinir_arch import SwinIR as SwinIR_Arch
    from basicsr.archs.edsr_arch import EDSR as EDSR_Arch
    from basicsr.utils import img2tensor, tensor2img
except ImportError:
    BASESR_AVAILABLE = False
    print("Warning: basicsr not found. SwinIR, EDSR, and CodeFormer (using basicsr utils) will not be available.")
|
|
|
|
|
# Optional dependency: the Real-ESRGAN wrapper class used by load_model().
#
# NOTE(review): the way this file drives the class -- RealESRGAN(device,
# scale=...), .load_weights(path), .predict(image) -- looks like the ai-forever
# "Real-ESRGAN" package, whose import name is capitalised (`RealESRGAN`).
# The lowercase spelling is tried first to keep the original behaviour, then
# the capitalised one as a fallback. Confirm which distribution is intended.
try:
    from realesrgan import RealESRGAN
    REALESRGAN_AVAILABLE = True
except ImportError:
    try:
        from RealESRGAN import RealESRGAN
        REALESRGAN_AVAILABLE = True
    except ImportError:
        print("Warning: realesrgan not found. Real-ESRGAN-x4 will not be available.")
        REALESRGAN_AVAILABLE = False
|
|
|
|
|
|
|
|
|
# Optional dependency: CodeFormer (face enhancement). Same pattern as the other
# optional imports: assume success, downgrade the flag on ImportError.
CODEFORMER_AVAILABLE = True
try:
    from CodeFormer import CodeFormer
except ImportError:
    CODEFORMER_AVAILABLE = False
    print("Warning: CodeFormer not found. CodeFormer (Face Enhancement) will not be available.")
|
|
|
|
|
|
|
|
|
|
|
# Catalogue of selectable models, keyed by the label shown in the UI.
# Each entry names the Hugging Face repo/file holding the weights and the kind
# of processing ("upscale" or "face"). An entry is None when the library it
# needs failed to import; those are dropped right below.
MODEL_CONFIGS = {
    "Real-ESRGAN-x4": {
        "repo_id": "RealESRGAN/RealESRGAN_x4plus",
        "filename": "RealESRGAN_x4plus.pth",
        "type": "upscale",
        "scale": 4,
    } if REALESRGAN_AVAILABLE else None,
    "SwinIR-4x": {
        "repo_id": "SwinIR/SwinIR-Large",
        "filename": "SwinIR_4x.pth",
        "type": "upscale",
        "scale": 4,
    } if BASESR_AVAILABLE else None,
    "EDSR-x4": {
        "repo_id": "EDSR/edsr_x4",
        "filename": "edsr_x4.pth",
        "type": "upscale",
        "scale": 4,
    } if BASESR_AVAILABLE else None,
    # Gated on CodeFormer alone: load_model() raises ImportError when only
    # basicsr is installed, so listing the model in that case would offer an
    # option that can never load.
    "CodeFormer (Face Enhancement)": {
        "repo_id": "CodeFormer/codeformer",
        "filename": "codeformer.pth",
        "type": "face",
    } if CODEFORMER_AVAILABLE else None,
}

# Keep only the models whose dependencies are actually present.
MODEL_CONFIGS = {name: cfg for name, cfg in MODEL_CONFIGS.items() if cfg is not None}
|
|
|
|
|
|
|
# Single-slot model cache shared with load_model(): the most recently loaded
# model object and the MODEL_CONFIGS key it was loaded for. None means
# "nothing cached", which is the sentinel load_model() tests for and resets to.
cached_model = None
cached_model_name = None

# Run on the GPU when CUDA is usable, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
|
|
|
|
|
def _finalize_model(model, state_dict):
    """Load ``state_dict`` into ``model`` (strict), switch to eval mode, move to ``device``."""
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    model.to(device)
    return model


def load_model(model_name):
    """Return the model registered under ``model_name``, loading it if needed.

    The last successfully loaded model is kept in the module-level
    ``cached_model`` / ``cached_model_name`` pair and reused when the same
    name is requested again. Weights are fetched with ``hf_hub_download``
    using the repo/file listed in ``MODEL_CONFIGS``.

    Args:
        model_name: A key of ``MODEL_CONFIGS``.

    Returns:
        ``(model, model_type)`` on success, where ``model_type`` is the
        config's ``'type'`` value; ``(None, error_message)`` on failure.
        No exception is propagated to the caller.
    """
    global cached_model, cached_model_name

    if model_name == cached_model_name and cached_model is not None:
        print(f"Using cached model: {model_name}")
        return cached_model, MODEL_CONFIGS[model_name]['type']

    print(f"Loading model: {model_name}")
    config = MODEL_CONFIGS.get(model_name)
    if config is None:
        return None, f"Error: Model '{model_name}' not supported or dependencies missing."

    try:
        model_type = config['type']
        model_path = hf_hub_download(repo_id=config['repo_id'], filename=config['filename'])

        # NOTE(security): torch.load unpickles the downloaded file. Only point
        # MODEL_CONFIGS at repositories you trust.

        if model_name == "Real-ESRGAN-x4":
            if not REALESRGAN_AVAILABLE:
                raise ImportError("realesrgan not installed.")
            model = RealESRGAN(device, scale=config['scale'])
            model.load_weights(model_path)

        elif model_name == "SwinIR-4x":
            if not BASESR_AVAILABLE:
                raise ImportError("basicsr not installed.")
            # NOTE(review): compress_ratio, dilate_basis, res_range and
            # attn_type do not look like SwinIR constructor parameters;
            # presumably they are swallowed by **kwargs -- confirm. Also
            # confirm the remaining defaults match the checkpoint, since the
            # weights are loaded with strict=True.
            model = SwinIR_Arch(
                upscale=config['scale'], in_chans=3, img_size=64, window_size=8,
                compress_ratio=-1, dilate_basis=-1, res_range=-1, attn_type='linear'
            )
            model = _finalize_model(model, torch.load(model_path, map_location=device))

        elif model_name == "EDSR-x4":
            if not BASESR_AVAILABLE:
                raise ImportError("basicsr not installed.")
            # basicsr's EDSR takes num_in_ch / num_out_ch as required
            # arguments; omitting them raises TypeError. The pipeline feeds
            # and expects 3-channel images.
            model = EDSR_Arch(
                num_in_ch=3, num_out_ch=3,
                num_feat=64, num_block=16, upscale=config['scale']
            )
            model = _finalize_model(model, torch.load(model_path, map_location=device))

        elif model_name == "CodeFormer (Face Enhancement)":
            if not (CODEFORMER_AVAILABLE or BASESR_AVAILABLE):
                raise ImportError("CodeFormer or basicsr not installed.")
            if not CODEFORMER_AVAILABLE:
                raise ImportError("CodeFormer library not found. BasicSR utilities alone are not enough to instantiate CodeFormer.")

            model = CodeFormer(dim_embd=512, codebook_size=1024, n_head=8, n_layer=9)
            state_dict = torch.load(model_path, map_location=device)['params_ema']
            # Drop a leading 'module.' from each key that has one, so the
            # names match the bare model's parameter names.
            prefix = 'module.'
            state_dict = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()
            }
            model = _finalize_model(model, state_dict)

        else:
            raise ValueError(f"Configuration missing for model: {model_name}")

        # Remember the model only once it has loaded completely.
        cached_model = model
        cached_model_name = model_name

        return model, model_type

    except ImportError as ie:
        print(f"Dependency missing for {model_name}: {ie}")
        return None, f"Error: Missing dependency - {ie}. Please ensure model libraries are installed."
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        # Any other failure clears the cache slot.
        cached_model = None
        cached_model_name = None
        return None, f"Error loading model: {str(e)}"
|
|
|
|
|
def preprocess_image(image: Image.Image) -> np.ndarray:
    """Convert a PIL image into a BGR NumPy array for the OpenCV-style pipeline.

    The image is first normalised to 3-channel RGB. Without that step a
    grayscale or palette image becomes a 2-D array, which
    ``cv2.cvtColor(..., COLOR_RGB2BGR)`` rejects.

    Args:
        image: Input PIL image in any mode ``Image.convert("RGB")`` accepts.

    Returns:
        The pixel data as an array with channels in BGR order.
    """
    rgb = np.array(image.convert("RGB"))
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
|
|
|
|
|
def postprocess_image(img: np.ndarray) -> Image.Image:
    """Turn a BGR array back into a PIL image.

    Arrays that are not already ``uint8`` are clipped to ``[0, 255]`` and cast
    to ``uint8`` before the BGR -> RGB channel swap.
    """
    as_uint8 = img if img.dtype == np.uint8 else np.clip(img, 0, 255).astype(np.uint8)
    return Image.fromarray(cv2.cvtColor(as_uint8, cv2.COLOR_BGR2RGB))
|
|
|
|
|
def enhance_image(image: Image.Image, model_name: str):
    """Run the selected model on ``image`` and report the outcome.

    Args:
        image: Uploaded PIL image, or ``None`` when nothing was uploaded.
        model_name: Key of ``MODEL_CONFIGS`` naming the model to apply.

    Returns:
        ``(status_message, enhanced_image)``. ``enhanced_image`` is a PIL
        image on success and ``None`` on any failure, in which case
        ``status_message`` carries the error text. Exceptions raised while
        processing are caught and reported through the message.
    """
    if image is None:
        return "Please upload an image.", None

    model, model_info = load_model(model_name)
    if model is None:
        # On failure load_model() puts the error message in the second slot.
        return model_info, None
    model_type = model_info

    try:
        img_np_bgr = preprocess_image(image)

        if model_type == "upscale":
            print(f"Applying {model_name} upscaling...")
            if model_name == "Real-ESRGAN-x4":
                # NOTE(review): some Real-ESRGAN wrappers return a PIL image
                # from predict(); np.asarray makes either return type usable
                # by postprocess_image(). Assumes predict() preserves the
                # channel order it is given -- confirm.
                output_np_bgr = np.asarray(model.predict(img_np_bgr))

            elif model_name in ("SwinIR-4x", "EDSR-x4"):
                if not BASESR_AVAILABLE:
                    raise ImportError(f"basicsr is required for {model_name}")

                # Scale to [0, 1], convert BGR -> RGB, add a batch dimension.
                img_tensor = img2tensor(img_np_bgr.astype(np.float32) / 255., bgr2rgb=True, float32=True)
                img_tensor = img_tensor.unsqueeze(0).to(device)

                with torch.no_grad():
                    output_tensor = model(img_tensor)

                # Let tensor2img do the RGB -> BGR swap itself instead of
                # asking for RGB and converting again afterwards.
                output_np_bgr = tensor2img(output_tensor, rgb2bgr=True, min_max=(0, 1))

            else:
                raise ValueError(f"Unknown upscale model: {model_name}")

            status_message = f"Image upscaled successfully with {model_name}!"

        elif model_type == "face":
            print(f"Applying {model_name} face enhancement...")
            if model_name == "CodeFormer (Face Enhancement)":
                if not (CODEFORMER_AVAILABLE or BASESR_AVAILABLE):
                    raise ImportError(f"CodeFormer or basicsr is required for {model_name}")

                # Module.to() leaves a model that is already on the target
                # device in place, so no device probe is needed first.
                model.to(device)

                # NOTE(review): assumes the CodeFormer object exposes
                # enhance(...) and that its first result is a BGR image --
                # confirm against the installed library.
                output_np_bgr = model.enhance(img_np_bgr, w=0.5, adain=True)[0]

            else:
                raise ValueError(f"Unknown face enhancement model: {model_name}")

            status_message = f"Face enhancement applied successfully with {model_name}!"

        else:
            # Without this branch an unexpected type would fall through and
            # fail later on an unbound local.
            raise ValueError(f"Unknown model type: {model_type}")

        return status_message, postprocess_image(output_np_bgr)

    except ImportError as ie:
        return f"Error processing image: Missing dependency - {ie}", None
    except Exception as e:
        print(f"Error during processing: {e}")
        import traceback
        traceback.print_exc()
        return f"Error processing image: {str(e)}", None
|
|
|
|
|
# Gradio interface: an upload + model picker column on the left, a status box
# and result image on the right. The button is wired to enhance_image() only
# when at least one model is available.
with gr.Blocks(title="Image Upscale & Enhancement - By FebryEnsz") as demo:
    gr.Markdown(
        """
        # Image Upscale & Enhancement
        **By FebryEnsz**

        Upload an image and select a model to enhance it. Choose from multiple models for upscaling (to make it 'HD' or higher resolution) or face enhancement (to improve facial details and focus).

        **Note:** This app requires specific Python libraries (`torch`, `basicsr`, `realesrgan`, `CodeFormer`) to be installed for all models to be available. If a model option is missing, its required library is not installed or found.
        """
    )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Image", type="pil")

            # MODEL_CONFIGS only contains models whose dependencies imported.
            available_models = list(MODEL_CONFIGS)

            if available_models:
                model_choice = gr.Dropdown(
                    choices=available_models,
                    label="Select Model",
                    value=available_models[0]
                )
                enhance_button = gr.Button("Enhance Image")
            else:
                # Nothing can run: show a read-only notice and a dead button.
                model_choice = gr.Textbox(label="Select Model", value="No models available. Check dependencies.", interactive=False)
                enhance_button = gr.Button("Enhance Image", interactive=False)
                print("No models are available because dependencies are missing.")

        with gr.Column():
            output_text = gr.Textbox(label="Status", max_lines=2)
            output_image = gr.Image(label="Enhanced Image")

    # Registered after both columns exist, because the outputs live in the
    # second one.
    if available_models:
        enhance_button.click(
            fn=enhance_image,
            inputs=[image_input, model_choice],
            outputs=[output_text, output_image]
        )
|
|
|
|
|
if __name__ == "__main__":
    # torch.backends.mps is absent on PyTorch builds without the MPS backend;
    # probe for it so start-up does not fail with AttributeError there.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        # NOTE(review): torch is already imported at this point -- confirm the
        # variable is still honoured when set this late. `device` above is
        # only ever "cuda" or "cpu".
        os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

    demo.launch()