File size: 15,006 Bytes
f60e836
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
# app.py
# Image Upscale and Enhancement with Multiple Models
# By FebryEnsz
# SDK: Gradio
# Hosted on Hugging Face Spaces

import gradio as gr
import torch
from PIL import Image
import numpy as np
from huggingface_hub import hf_hub_download
import cv2
import os # Import os for path handling

# --- Optional Dependency Imports (need to be installed via pip or manual clone) ---
# Every import below is optional. When one fails, a warning is printed and the
# matching *_AVAILABLE flag is set to False so the rest of the file can tell
# which model back-ends are usable.

# BasicSR: SwinIR / EDSR architectures plus the img2tensor / tensor2img helpers.
# NOTE: the flag is spelled BASESR_AVAILABLE (not BASICSR_...); other code in
# this file refers to it by that exact name.
try:
    from basicsr.archs.swinir_arch import SwinIR as SwinIR_Arch
    from basicsr.archs.edsr_arch import EDSR as EDSR_Arch
    from basicsr.utils import img2tensor, tensor2img
    BASESR_AVAILABLE = True
except ImportError:
    print("Warning: basicsr not found. SwinIR, EDSR, and CodeFormer (using basicsr utils) will not be available.")
    BASESR_AVAILABLE = False

# RealESRGAN import
# NOTE(review): confirm that the installed `realesrgan` package really exports a
# class named `RealESRGAN`; if it does not, this import fails and the model is
# simply reported as unavailable.
try:
    from realesrgan import RealESRGAN
    REALESRGAN_AVAILABLE = True
except ImportError:
    print("Warning: realesrgan not found. Real-ESRGAN-x4 will not be available.")
    REALESRGAN_AVAILABLE = False

# CodeFormer import. This assumes CodeFormer is installed as an importable
# top-level module, which often requires manual setup. The model class itself
# comes only from this import; basicsr utilities are not a substitute for it.
try:
    # Attempting a common import path if CodeFormer is installed separately
    from CodeFormer import CodeFormer # Adjust import based on your CodeFormer install
    CODEFORMER_AVAILABLE = True
except ImportError:
    print("Warning: CodeFormer not found. CodeFormer (Face Enhancement) will not be available.")
    CODEFORMER_AVAILABLE = False


# --- Model Configuration ---
# Maps the name shown in the UI to the settings needed to fetch and run a model:
#   "UI Name": {"repo_id": "hf_repo_id", "filename": "weight_filename",
#               "type": "upscale" or "face", "scale": int (upscale models only)}
# An entry evaluates to None when its library failed to import; those entries
# are removed right after the literal so the UI only offers usable models.
MODEL_CONFIGS = {
    "Real-ESRGAN-x4": {"repo_id": "RealESRGAN/RealESRGAN_x4plus", "filename": "RealESRGAN_x4plus.pth", "type": "upscale", "scale": 4} if REALESRGAN_AVAILABLE else None,
    "SwinIR-4x": {"repo_id": "SwinIR/SwinIR-Large", "filename": "SwinIR_4x.pth", "type": "upscale", "scale": 4} if BASESR_AVAILABLE else None,
    "EDSR-x4": {"repo_id": "EDSR/edsr_x4", "filename": "edsr_x4.pth", "type": "upscale", "scale": 4} if BASESR_AVAILABLE else None,
    # The CodeFormer model class comes from the CodeFormer package only. With
    # just basicsr installed the loader always raises ImportError, so basicsr
    # alone must not make this entry selectable.
    "CodeFormer (Face Enhancement)": {"repo_id": "CodeFormer/codeformer", "filename": "codeformer.pth", "type": "face"} if CODEFORMER_AVAILABLE else None,
}

# Filter out unavailable models
MODEL_CONFIGS = {name: cfg for name, cfg in MODEL_CONFIGS.items() if cfg is not None}

# --- Model Loading Cache ---
# Single-slot cache: the most recently loaded model and the UI name it was
# loaded under. None means "nothing cached" for both variables; the loader
# tests `cached_model is not None`, so the empty state must be None and not an
# empty container.
cached_model = None
cached_model_name = None
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Helper: reduce a loaded checkpoint to a plain parameter mapping
def _extract_state_dict(checkpoint):
    """Return the parameter mapping stored in *checkpoint*.

    Checkpoints may be a flat state dict or a dict that wraps the weights
    under ``'params_ema'`` or ``'params'``; keys may also carry the
    ``'module.'`` prefix left by ``DataParallel``. Both are normalised here so
    every model branch loads weights the same way.
    """
    state = checkpoint
    if isinstance(checkpoint, dict):
        for wrapper_key in ("params_ema", "params"):
            if wrapper_key in checkpoint:
                state = checkpoint[wrapper_key]
                break

    # Only rebuild the mapping when a prefix is actually present, so an
    # untouched state dict is handed to load_state_dict exactly as loaded.
    if any(key.startswith("module.") for key in state):
        prefix_len = len("module.")
        state = {
            (key[prefix_len:] if key.startswith("module.") else key): value
            for key, value in state.items()
        }
    return state


# Function to load the selected model
def load_model(model_name):
    """Load (or fetch from the single-slot cache) the model named *model_name*.

    Args:
        model_name: A key of ``MODEL_CONFIGS`` (the name shown in the UI).

    Returns:
        ``(model, model_type)`` on success, where ``model_type`` is the
        config's ``'type'`` string. On failure ``(None, error_message)``;
        callers tell the cases apart by checking whether the first item is
        ``None``. No exception is propagated for load failures.
    """
    global cached_model, cached_model_name

    # Validate the name before consulting the cache: a None/unknown name must
    # produce an error tuple, never a KeyError from indexing MODEL_CONFIGS.
    config = MODEL_CONFIGS.get(model_name)
    if config is None:
        return None, f"Error: Model '{model_name}' not supported or dependencies missing."

    if model_name == cached_model_name and cached_model is not None:
        print(f"Using cached model: {model_name}")
        return cached_model, config['type']

    print(f"Loading model: {model_name}")

    try:
        model_type = config['type']
        model_path = hf_hub_download(repo_id=config['repo_id'], filename=config['filename'])

        # SECURITY NOTE: torch.load below unpickles the downloaded file. Only
        # point MODEL_CONFIGS at repositories whose weights you trust.

        if model_name == "Real-ESRGAN-x4":
            if not REALESRGAN_AVAILABLE:
                raise ImportError("realesrgan not installed.")
            model = RealESRGAN(device, scale=config['scale'])
            model.load_weights(model_path)

        elif model_name == "SwinIR-4x":
            if not BASESR_AVAILABLE:
                raise ImportError("basicsr not installed.")
            # NOTE(review): compress_ratio / dilate_basis / res_range /
            # attn_type are kept from the original call; confirm they are
            # meaningful for the installed SwinIR architecture and checkpoint.
            model = SwinIR_Arch(
                upscale=config['scale'], in_chans=3, img_size=64, window_size=8,
                compress_ratio= -1, dilate_basis=-1, res_range=-1, attn_type='linear'
            )
            checkpoint = torch.load(model_path, map_location=device)
            model.load_state_dict(_extract_state_dict(checkpoint), strict=True)
            model.eval()  # inference only
            model.to(device)

        elif model_name == "EDSR-x4":
            if not BASESR_AVAILABLE:
                raise ImportError("basicsr not installed.")
            # NOTE(review): num_feat / num_block are assumed values for this
            # checkpoint; strict loading will fail loudly if they are wrong.
            model = EDSR_Arch(num_feat=64, num_block=16, upscale=config['scale'])
            checkpoint = torch.load(model_path, map_location=device)
            model.load_state_dict(_extract_state_dict(checkpoint), strict=True)
            model.eval()
            model.to(device)

        elif model_name == "CodeFormer (Face Enhancement)":
            if not (CODEFORMER_AVAILABLE or BASESR_AVAILABLE):
                raise ImportError("CodeFormer or basicsr not installed.")
            if not CODEFORMER_AVAILABLE:
                # basicsr alone does not provide the CodeFormer architecture.
                raise ImportError("CodeFormer library not found. BasicSR utilities alone are not enough to instantiate CodeFormer.")
            # Simplified instantiation; a real CodeFormer setup may need more
            # arguments or companion models.
            model = CodeFormer(dim_embd=512, codebook_size=1024, n_head=8, n_layer=9)
            checkpoint = torch.load(model_path, map_location=device)
            model.load_state_dict(_extract_state_dict(checkpoint), strict=True)
            model.eval()
            model.to(device)

        else:
            raise ValueError(f"Configuration missing for model: {model_name}")

        # Cache the loaded model (replaces whatever was cached before)
        cached_model = model
        cached_model_name = model_name

        return model, model_type

    except ImportError as ie:
        print(f"Dependency missing for {model_name}: {ie}")
        return None, f"Error: Missing dependency - {ie}. Please ensure model libraries are installed."
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        # Clear cache on error
        cached_model = None
        cached_model_name = None
        return None, f"Error loading model: {str(e)}"

# Function to preprocess image (PIL image of any mode to OpenCV BGR numpy)
def preprocess_image(image: Image.Image) -> np.ndarray:
    """Convert a PIL image into an OpenCV-style BGR array.

    Args:
        image: Input PIL image. Modes other than ``"RGB"`` (e.g. RGBA, L, P)
            are converted to RGB first.

    Returns:
        The image as a height x width x 3 numpy array in BGR channel order.
    """
    # cv2.COLOR_RGB2BGR needs a 3-channel input; images with an alpha channel,
    # a palette or a single grey channel would make cvtColor raise.
    if image.mode != "RGB":
        image = image.convert("RGB")
    rgb = np.array(image)
    # OpenCV uses BGR, PIL uses RGB.
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

# Function to postprocess image (OpenCV BGR numpy to PIL RGB)
def postprocess_image(img: np.ndarray) -> Image.Image:
    """Turn a BGR numpy image into a PIL RGB image.

    Arrays that are not already ``uint8`` are clipped to the 0-255 range and
    cast to ``uint8`` before the channel order is swapped.
    """
    already_uint8 = img.dtype == np.uint8
    bgr = img if already_uint8 else np.clip(img, 0, 255).astype(np.uint8)
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb)

# Main processing function
def enhance_image(image: Image.Image, model_name: str):
    """Run the selected model on *image* and report the outcome.

    Args:
        image: The uploaded PIL image, or ``None`` when nothing was uploaded.
        model_name: Name of the model to apply (a key of ``MODEL_CONFIGS``).

    Returns:
        ``(status_message, enhanced_image)``. ``enhanced_image`` is a PIL
        image on success and ``None`` whenever an error or missing input is
        reported; in that case ``status_message`` carries the explanation.
    """
    if image is None:
        return "Please upload an image.", None

    # An empty selection would otherwise reach load_model, which is called
    # outside the try block below.
    if not model_name:
        return "Please select a model.", None

    # Load the selected model and its type
    model, model_info = load_model(model_name)

    if model is None:
        # model_info contains the error message if loading failed
        return model_info, None

    model_type = model_info  # the type string ('upscale' or 'face')

    try:
        # Preprocess the image (PIL RGB -> OpenCV BGR)
        img_np_bgr = preprocess_image(image)

        # Process based on model type and specific model implementation
        if model_type == "upscale":
            print(f"Applying {model_name} upscaling...")
            if model_name == "Real-ESRGAN-x4":
                # NOTE(review): assumes predict() accepts and returns a uint8
                # BGR numpy array - confirm against the installed library.
                output_np_bgr = model.predict(img_np_bgr)
            elif model_name in ["SwinIR-4x", "EDSR-x4"]:
                if not BASESR_AVAILABLE:
                    raise ImportError(f"basicsr is required for {model_name}")
                # HWC BGR uint8 -> 1xCxHxW RGB float tensor in [0, 1] on the device
                img_tensor = img2tensor(
                    img_np_bgr.astype(np.float32) / 255., bgr2rgb=True, float32=True
                ).unsqueeze(0).to(device)

                with torch.no_grad():
                    output_tensor = model(img_tensor)

                # Let tensor2img do the RGB -> BGR swap itself so the result
                # is already in the layout postprocess_image expects.
                output_np_bgr = tensor2img(output_tensor, rgb2bgr=True, min_max=(0, 1))
            else:
                raise ValueError(f"Unknown upscale model: {model_name}")

            status_message = f"Image upscaled successfully with {model_name}!"

        elif model_type == "face":
            print(f"Applying {model_name} face enhancement...")
            if model_name == "CodeFormer (Face Enhancement)":
                if not (CODEFORMER_AVAILABLE or BASESR_AVAILABLE):
                    raise ImportError(f"CodeFormer or basicsr is required for {model_name}")
                # Ensure model is on correct device before call
                if next(model.parameters()).device != device:
                    model.to(device)

                # Placeholder call: assumes the installed CodeFormer exposes
                # enhance(bgr_image, w=..., adain=...) and returns a sequence
                # whose first element is the restored image. Adjust to the
                # actual API of your CodeFormer version.
                output_np_bgr = model.enhance(img_np_bgr, w=0.5, adain=True)[0]
            else:
                raise ValueError(f"Unknown face enhancement model: {model_name}")

            status_message = f"Face enhancement applied successfully with {model_name}!"

        else:
            # Without this branch an unexpected type would surface later as a
            # confusing NameError on output_np_bgr.
            raise ValueError(f"Unknown model type: {model_type}")

        # Postprocess the output image (OpenCV BGR -> PIL RGB)
        enhanced_image = postprocess_image(output_np_bgr)

        return status_message, enhanced_image

    except ImportError as ie:
        return f"Error processing image: Missing dependency - {ie}", None
    except Exception as e:
        print(f"Error during processing: {e}")
        import traceback
        traceback.print_exc()  # Print full traceback for debugging
        return f"Error processing image: {str(e)}", None

# Gradio interface
# Names of the models that survived the dependency filter; drives both the
# model selector and whether the button is wired up at all.
available_models = list(MODEL_CONFIGS)

with gr.Blocks(title="Image Upscale & Enhancement - By FebryEnsz") as demo:
    gr.Markdown(
        """
        # Image Upscale & Enhancement
        **By FebryEnsz**
        
        Upload an image and select a model to enhance it. Choose from multiple models for upscaling (to make it 'HD' or higher resolution) or face enhancement (to improve facial details and focus).
        
        **Note:** This app requires specific Python libraries (`torch`, `basicsr`, `realesrgan`, `CodeFormer`) to be installed for all models to be available. If a model option is missing, its required library is not installed or found.
        """
    )

    with gr.Row():
        # Left column: inputs and the action button
        with gr.Column():
            image_input = gr.Image(label="Upload Image", type="pil")

            if available_models:
                # Default to the first available model; all models are fixed x4,
                # so there is no scale control.
                model_choice = gr.Dropdown(
                    choices=available_models,
                    label="Select Model",
                    value=available_models[0],
                )
                enhance_button = gr.Button("Enhance Image")
            else:
                # Nothing usable: show a read-only notice and a disabled button.
                model_choice = gr.Textbox(
                    label="Select Model",
                    value="No models available. Check dependencies.",
                    interactive=False,
                )
                enhance_button = gr.Button("Enhance Image", interactive=False)
                print("No models are available because dependencies are missing.")

        # Right column: status text and the resulting image
        with gr.Column():
            output_text = gr.Textbox(label="Status", max_lines=2)
            output_image = gr.Image(label="Enhanced Image")

    # Wire the button to the processing function only when a model can be chosen
    if available_models:
        enhance_button.click(
            fn=enhance_image,
            inputs=[image_input, model_choice],
            outputs=[output_text, output_image],
        )

# Launch the Gradio app
if __name__ == "__main__":
    # torch.backends.mps is absent on older PyTorch builds; look it up
    # defensively so such an install still reaches demo.launch().
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        # Request CPU fallback for operations the MPS backend does not implement.
        # NOTE(review): PyTorch may read this variable when it is imported, in
        # which case setting it here is too late - confirm, and if so export it
        # in the environment before starting the app.
        os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

    demo.launch()