#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
""" | |
Video Upscaler for Blissful Tuner Extension | |
License: Apache 2.0 | |
Created on Wed Apr 23 10:19:19 2025 | |
@author: blyss | |
""" | |
from typing import List | |
import torch | |
import numpy as np | |
from tqdm import tqdm | |
from rich.traceback import install as install_rich_tracebacks | |
from swinir.network_swinir import SwinIR | |
from spandrel import ImageModelDescriptor, ModelLoader | |
from video_processing_common import BlissfulVideoProcessor, set_seed, setup_parser_video_common | |
from utils import setup_compute_context, load_torch_file, BlissfulLogger | |
logger = BlissfulLogger(__name__, "#8e00ed") | |
install_rich_tracebacks() | |
def upscale_frames_swin(
    model: torch.nn.Module,
    frames: List[np.ndarray],
    VideoProcessor: BlissfulVideoProcessor,
) -> None:
    """
    Upscale a list of RGB frames with a SwinIR model, writing each result to PNG.

    Args:
        model: Loaded SwinIR upsampler (already on the target device/dtype).
        frames: List of HxWx3 float32 RGB arrays in [0, 1].
        VideoProcessor: Converts frames to tensors and buffers the upscaled
            output as PNGs; its device/dtype were fixed at construction.

    Returns:
        None. Each upscaled frame is handed to
        ``VideoProcessor.write_np_or_tensor_to_png`` instead of being returned.
    """
    window_size = 8  # SwinIR requires H and W to be multiples of its window size
    for img in tqdm(frames, desc="Upscaling SwinIR"):
        # Mark a step boundary for CUDA graph capture (no-op when not capturing)
        torch.compiler.cudagraph_mark_step_begin()
        # Convert HWC RGB -> NCHW tensor on the processor's device/dtype
        tensor = VideoProcessor.np_image_to_tensor(img)
        # Reflect-pad height and width up to the next multiple of window_size
        _, _, h, w = tensor.shape
        h_pad = ((h + window_size - 1) // window_size) * window_size - h
        w_pad = ((w + window_size - 1) // window_size) * window_size - w
        tensor = torch.cat([tensor, torch.flip(tensor, [2])], 2)[:, :, : h + h_pad, :]
        tensor = torch.cat([tensor, torch.flip(tensor, [3])], 3)[:, :, :, : w + w_pad]
        # Inference without autograd bookkeeping
        with torch.no_grad():
            out = model(tensor)
        VideoProcessor.write_np_or_tensor_to_png(out)
def load_swin_model(
    model_path: str,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.nn.Module:
    """
    Build a SwinIR network and populate it from a checkpoint.

    Args:
        model_path: Checkpoint location (.pth or safetensors).
        device: Device the model should live on.
        dtype: Weight/computation dtype for the model.

    Returns:
        The SwinIR model, moved to ``device``/``dtype`` and set to eval mode.
    """
    logger.info(f"Loading SwinIR model ({dtype})…")
    net = SwinIR(
        upscale=4,
        in_chans=3,
        img_size=64,
        window_size=8,
        img_range=1.0,
        depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
        embed_dim=240,
        num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
        mlp_ratio=2,
        upsampler='nearest+conv',
        resi_connection='3conv',
    )
    state = load_torch_file(model_path)
    # Prefer the EMA weights when the checkpoint provides them.
    if 'params_ema' in state:
        state = state['params_ema']
    net.load_state_dict(state, strict=True)
    net.to(device, dtype).eval()
    return net
def load_esrgan_model(
    model_path: str,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.nn.Module:
    """
    Load an ESRGAN (or RRDBNet) style model via the Spandrel loader.

    Args:
        model_path: Path to the ESRGAN checkpoint.
        device: Device the model should live on.
        dtype: Weight/computation dtype for the model.

    Returns:
        Model in eval mode on ``device``/``dtype``, ready for inference.

    Raises:
        TypeError: If the checkpoint does not describe a single-image model
            (Spandrel may return video/masked-image descriptors instead).
    """
    logger.info(f"Loading ESRGAN model ({dtype})…")
    descriptor = ModelLoader().load_from_file(model_path)
    # `assert` is stripped under `python -O`, which would let an incompatible
    # checkpoint slip through; raise explicitly with a clear message instead.
    if not isinstance(descriptor, ImageModelDescriptor):
        raise TypeError(
            f"Expected a single-image upscaling model, got {type(descriptor).__name__}"
        )
    model = descriptor.model.eval().to(device, dtype)
    return model
def main() -> None:
    """
    CLI entry point: parse arguments, load input frames, run the selected
    upscaler, and encode the result to a video.
    """
    parser = setup_parser_video_common(description="Video upscaling using SwinIR or ESRGAN models")
    parser.add_argument(
        "--scale", type=float, default=2,
        help="Final scale multiplier for output resolution"
    )
    parser.add_argument(
        "--mode", choices=["swinir", "esrgan"], default="swinir",
        help="Model architecture to use"
    )
    args = parser.parse_args()
    args.mode = args.mode.lower()

    # Resolve compute device and dtype from the CLI choice
    device, dtype = setup_compute_context(None, args.dtype)
    vp = BlissfulVideoProcessor(device, dtype)
    vp.prepare_files_and_path(args.input, args.output, args.mode.upper())
    frames, fps, w, h = vp.load_frames(make_rgb=True)
    set_seed(args.seed)

    # Dispatch on architecture (choices= restricts mode to these two values)
    if args.mode == "esrgan":
        model = load_esrgan_model(args.model, device, dtype)
        logger.info("Processing with ESRGAN...")
        for frame in tqdm(frames, desc="Upscaling ESRGAN"):
            inp = vp.np_image_to_tensor(frame)
            with torch.no_grad():
                sr = model(inp)
            vp.write_np_or_tensor_to_png(sr)
    else:
        model = load_swin_model(args.model, device, dtype)
        upscale_frames_swin(model, frames, vp)

    # Assemble the buffered PNG frames into the output video
    logger.info("Encoding output video...")
    out_w, out_h = int(w * args.scale), int(h * args.scale)
    vp.write_buffered_frames_to_output(fps, args.keep_pngs, (out_w, out_h))
    logger.info("Done!")


if __name__ == "__main__":
    main()