import os
import cv2
import argparse
import glob
import torch
import numpy as np
from tqdm import tqdm
from torchvision.transforms.functional import normalize

from basicsr.utils import imwrite, img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from basicsr.utils.misc import gpu_is_available, get_device
from scipy.ndimage import gaussian_filter1d
from facelib.utils.face_restoration_helper import FaceRestoreHelper
from facelib.utils.misc import is_gray
from basicsr.utils.video_util import VideoReader, VideoWriter
from basicsr.utils.registry import ARCH_REGISTRY

import gradio as gr
from torch.hub import download_url_to_file

title = r"""

KEEP: Kalman-Inspired Feature Propagation for Video Face Super-Resolution

""" description = r""" Official Gradio demo for Kalman-Inspired FEaturE Propagation for Video Face Super-Resolution (ECCV 2024).
🔥 KEEP is a robust video face super-resolution algorithm.
🤗 Try to drop your own face video, and get the restored results!
""" post_article = r""" If you found KEEP helpful, please consider ⭐ the Github Repo. Thanks! [![GitHub Stars](https://img.shields.io/github/stars/jnjaby/KEEP)](https://github.com/jnjaby/KEEP) --- 📝 **Citation**
If our work is useful for your research, please consider citing:
```bibtex
@InProceedings{feng2024keep,
    title     = {Kalman-Inspired FEaturE Propagation for Video Face Super-Resolution},
    author    = {Feng, Ruicheng and Li, Chongyi and Loy, Chen Change},
    booktitle = {European Conference on Computer Vision (ECCV)},
    year      = {2024}
}
```

📋 **License**
This project is licensed under the S-Lab License 1.0. Redistribution and use for non-commercial purposes should follow this license.

📧 **Contact**
If you have any questions, please feel free to reach out via ruicheng002@ntu.edu.sg.
"""


def interpolate_sequence(sequence):
    """Fill NaN entries of a 1D sequence by linear interpolation over the valid entries."""
    interpolated_sequence = np.copy(sequence)
    missing_indices = np.isnan(sequence)
    if np.any(missing_indices):
        valid_indices = ~missing_indices
        x = np.arange(len(sequence))
        interpolated_sequence[missing_indices] = np.interp(
            x[missing_indices], x[valid_indices], sequence[valid_indices])
    return interpolated_sequence


def set_realesrgan():
    from basicsr.archs.rrdbnet_arch import RRDBNet
    from basicsr.utils.realesrgan_utils import RealESRGANer

    use_half = False
    if torch.cuda.is_available():
        # fp16 inference is disabled on GPUs known to misbehave with half precision
        no_half_gpu_list = ['1650', '1660']
        if not any(gpu in torch.cuda.get_device_name(0) for gpu in no_half_gpu_list):
            use_half = True

    model = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=2)
    upsampler = RealESRGANer(
        scale=2,
        model_path="https://github.com/jnjaby/KEEP/releases/download/v1.0.0/RealESRGAN_x2plus.pth",
        model=model,
        tile=400,
        tile_pad=40,
        pre_pad=0,
        half=use_half)

    if not gpu_is_available():  # CPU
        import warnings
        warnings.warn(
            'Running on CPU now! Make sure your PyTorch version matches your CUDA. '
            'The unoptimized RealESRGAN is slow on CPU.',
            category=RuntimeWarning)
    return upsampler


def process_video(input_video, draw_box, bg_enhancement):
    device = get_device()
    args = argparse.Namespace(
        input_path=input_video,
        upscale=1,
        max_length=20,
        has_aligned=False,
        only_center_face=True,
        draw_box=draw_box,
        detection_model='retinaface_resnet50',
        bg_enhancement=bg_enhancement,
        face_upsample=False,
        bg_tile=400,
        suffix=None,
        save_video_fps=None,
        model_type='KEEP')

    output_dir = './results/'
    os.makedirs(output_dir, exist_ok=True)

    model_configs = {
        'KEEP': {
            'architecture': {
                'img_size': 512,
                'emb_dim': 256,
                'dim_embd': 512,
                'n_head': 8,
                'n_layers': 9,
                'codebook_size': 1024,
                'cft_list': ['16', '32', '64'],
                'kalman_attn_head_dim': 48,
                'num_uncertainty_layers': 3,
                'cfa_list': ['16', '32'],
                'cfa_nhead': 4,
                'cfa_dim': 256,
                'cond': 1
            },
            'checkpoint_dir': '/home/user/app/weights/KEEP',
            'checkpoint_url': 'https://github.com/jnjaby/KEEP/releases/download/v1.0.0/KEEP-b76feb75.pth'
        },
    }

    # ------------------ set up background and face upsamplers ------------------
    if args.bg_enhancement:
        bg_upsampler = set_realesrgan()
    else:
        bg_upsampler = None

    if args.face_upsample:
        face_upsampler = bg_upsampler if bg_upsampler is not None else set_realesrgan()
    else:
        face_upsampler = None

    # ------------------ set up the KEEP restorer ------------------
    if args.model_type not in model_configs:
        raise ValueError(
            f"Unknown model type: {args.model_type}. "
            f"Available options: {list(model_configs.keys())}")

    config = model_configs[args.model_type]
    net = ARCH_REGISTRY.get('KEEP')(**config['architecture']).to(device)

    ckpt_path = load_file_from_url(
        url=config['checkpoint_url'],
        model_dir=config['checkpoint_dir'],
        progress=True,
        file_name=None)
    checkpoint = torch.load(ckpt_path, weights_only=True)
    net.load_state_dict(checkpoint['params_ema'])
    net.eval()

    if not args.has_aligned:
        print(f'Face detection model: {args.detection_model}')
    if bg_upsampler is not None:
        print(f'Background upsampling: True, Face upsampling: {args.face_upsample}')
    else:
        print(f'Background upsampling: False, Face upsampling: {args.face_upsample}')

    face_helper = FaceRestoreHelper(
        args.upscale,
        face_size=512,
        crop_ratio=(1, 1),
        det_model=args.detection_model,
        save_ext='png',
        use_parse=True,
        device=device)
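    # The alignment below proceeds in three steps: detect 5-point landmarks for every frame,
    # fill frames where detection failed via interpolate_sequence (linear interpolation over
    # time), then smooth each landmark coordinate with a temporal Gaussian filter so the face
    # crop does not jitter between frames.
    #
    # Illustrative sketch of the gap filling (not executed; values chosen for demonstration):
    #   interpolate_sequence(np.array([1.0, np.nan, 3.0]))  ->  array([1., 2., 3.])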
    # ------------------ read the input video ------------------
    input_img_list = []
    if args.input_path.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')):
        vidreader = VideoReader(args.input_path)
        image = vidreader.get_frame()
        while image is not None:
            input_img_list.append(image)
            image = vidreader.get_frame()
        fps = vidreader.get_fps() if args.save_video_fps is None else args.save_video_fps
        vidreader.close()
        clip_name = os.path.basename(args.input_path)[:-4]
    else:
        raise TypeError(f'Unrecognized type of input video {args.input_path}.')

    if len(input_img_list) == 0:
        raise FileNotFoundError('No input image/video is found...')

    # ------------------ detect keypoints and smooth the alignment ------------------
    print('Detecting keypoints and smooth alignment ...')
    if not args.has_aligned:
        raw_landmarks = []
        for i, img in enumerate(input_img_list):
            face_helper.clean_all()
            face_helper.read_image(img)
            num_det_faces = face_helper.get_face_landmarks_5(
                only_center_face=args.only_center_face,
                resize=640,
                eye_dist_threshold=5,
                only_keep_largest=True)
            if num_det_faces == 1:
                raw_landmarks.append(face_helper.all_landmarks_5[0].reshape((10,)))
            elif num_det_faces == 0:
                # no face detected in this frame; fill with NaN and interpolate later
                raw_landmarks.append(np.array([np.nan] * 10))

        raw_landmarks = np.array(raw_landmarks)
        for i in range(10):
            raw_landmarks[:, i] = interpolate_sequence(raw_landmarks[:, i])
        video_length = len(input_img_list)
        avg_landmarks = gaussian_filter1d(raw_landmarks, 5, axis=0).reshape(video_length, 5, 2)

    cropped_faces = []
    for i, img in enumerate(input_img_list):
        face_helper.clean_all()
        face_helper.read_image(img)
        face_helper.all_landmarks_5 = [avg_landmarks[i]]
        face_helper.align_warp_face()

        cropped_face_t = img2tensor(
            face_helper.cropped_faces[0] / 255., bgr2rgb=True, float32=True)
        normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
        cropped_faces.append(cropped_face_t)
    cropped_faces = torch.stack(cropped_faces, dim=0).unsqueeze(0).to(device)

    # ------------------ restore faces with KEEP ------------------
    print('Restoring faces ...')
    with torch.no_grad():
        video_length = cropped_faces.shape[1]
        output = []
        # process the clip in chunks of at most args.max_length frames to bound GPU memory
        for start_idx in range(0, video_length, args.max_length):
            end_idx = min(start_idx + args.max_length, video_length)
            if end_idx - start_idx == 1:
                # a single-frame chunk is duplicated so the network still sees a sequence
                output.append(net(
                    cropped_faces[:, [start_idx, start_idx], ...],
                    need_upscale=False)[:, 0:1, ...])
            else:
                output.append(net(
                    cropped_faces[:, start_idx:end_idx, ...], need_upscale=False))
        output = torch.cat(output, dim=1).squeeze(0)
        assert output.shape[0] == video_length, "Different number of frames"

        restored_faces = [tensor2img(x, rgb2bgr=True, min_max=(-1, 1)) for x in output]
        del output
        torch.cuda.empty_cache()

    # ------------------ paste the restored faces back into the frames ------------------
    print('Pasting faces back ...')
    restored_frames = []
    for i, img in enumerate(input_img_list):
        face_helper.clean_all()

        if args.has_aligned:
            # the input frames are already cropped and aligned faces
            img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
            face_helper.is_gray = is_gray(img, threshold=10)
            if face_helper.is_gray:
                print('Grayscale input: True')
            face_helper.cropped_faces = [img]
        else:
            face_helper.read_image(img)
            face_helper.all_landmarks_5 = [avg_landmarks[i]]
            face_helper.align_warp_face()

        face_helper.add_restored_face(restored_faces[i].astype('uint8'))

        if not args.has_aligned:
            # upsample the background if requested, then warp the restored face back
            if bg_upsampler is not None:
                bg_img = bg_upsampler.enhance(img, outscale=args.upscale)[0]
            else:
                bg_img = None
            face_helper.get_inverse_affine(None)
            if args.face_upsample and face_upsampler is not None:
                restored_img = face_helper.paste_faces_to_input_image(
                    upsample_img=bg_img, draw_box=args.draw_box, face_upsampler=face_upsampler)
            else:
                restored_img = face_helper.paste_faces_to_input_image(
                    upsample_img=bg_img, draw_box=args.draw_box)

        restored_frames.append(restored_img)
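    # At this point every frame has been restored and pasted back at the input resolution
    # (args.upscale is fixed to 1 here); the writer below reuses the input frame rate unless
    # args.save_video_fps overrides it.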
    # ------------------ save the output video ------------------
    print('Saving video ...')
    height, width = restored_frames[0].shape[:2]
    save_restore_path = os.path.join(output_dir, f'{clip_name}.mp4')
    vidwriter = VideoWriter(save_restore_path, height, width, fps)
    for f in restored_frames:
        vidwriter.write_frame(f)
    vidwriter.close()

    print(f'All results are saved in {save_restore_path}.')
    return save_restore_path


# ------------------ download the pretrained models and sample videos ------------------
sample_videos_dir = os.path.join("/home/user/app/hugging_face/", "test_sample/")
os.makedirs(sample_videos_dir, exist_ok=True)
download_url_to_file("https://github.com/jnjaby/KEEP/releases/download/media/real_1.mp4",
                     os.path.join(sample_videos_dir, "real_1.mp4"))
download_url_to_file("https://github.com/jnjaby/KEEP/releases/download/media/real_2.mp4",
                     os.path.join(sample_videos_dir, "real_2.mp4"))
download_url_to_file("https://github.com/jnjaby/KEEP/releases/download/media/real_3.mp4",
                     os.path.join(sample_videos_dir, "real_3.mp4"))
download_url_to_file("https://github.com/jnjaby/KEEP/releases/download/media/real_4.mp4",
                     os.path.join(sample_videos_dir, "real_4.mp4"))

model_dir = "/home/user/app/weights/"
model_url = "https://github.com/jnjaby/KEEP/releases/download/v1.0.0/"
_ = load_file_from_url(url=os.path.join(model_url, 'KEEP-b76feb75.pth'),
                       model_dir=os.path.join(model_dir, "KEEP"), progress=True, file_name=None)
_ = load_file_from_url(url=os.path.join(model_url, 'detection_Resnet50_Final.pth'),
                       model_dir=os.path.join(model_dir, "facelib"), progress=True, file_name=None)
_ = load_file_from_url(url=os.path.join(model_url, 'detection_mobilenet0.25_Final.pth'),
                       model_dir=os.path.join(model_dir, "facelib"), progress=True, file_name=None)
_ = load_file_from_url(url=os.path.join(model_url, 'yolov5n-face.pth'),
                       model_dir=os.path.join(model_dir, "facelib"), progress=True, file_name=None)
_ = load_file_from_url(url=os.path.join(model_url, 'yolov5l-face.pth'),
                       model_dir=os.path.join(model_dir, "facelib"), progress=True, file_name=None)
_ = load_file_from_url(url=os.path.join(model_url, 'parsing_parsenet.pth'),
                       model_dir=os.path.join(model_dir, "facelib"), progress=True, file_name=None)
_ = load_file_from_url(url=os.path.join(model_url, 'RealESRGAN_x2plus.pth'),
                       model_dir=os.path.join(model_dir, "realesrgan"), progress=True, file_name=None)

# ------------------ launch the Gradio interface ------------------
demo = gr.Interface(
    fn=process_video,
    title=title,
    description=description,
    inputs=[
        gr.Video(label="Input Video"),
        gr.Checkbox(label="Draw Box", value=False),
        gr.Checkbox(label="Background Enhancement", value=False),
    ],
    outputs=gr.Video(label="Processed Video"),
    examples=[
        [os.path.join(os.path.dirname(__file__), sample_videos_dir, "real_1.mp4"), True, False],
        [os.path.join(os.path.dirname(__file__), sample_videos_dir, "real_2.mp4"), True, False],
        [os.path.join(os.path.dirname(__file__), sample_videos_dir, "real_3.mp4"), True, False],
        [os.path.join(os.path.dirname(__file__), sample_videos_dir, "real_4.mp4"), True, False],
    ],
    article=post_article)

demo.launch(share=True)
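# The pipeline can also be exercised without the UI by calling process_video directly,
# e.g. (illustrative sketch; assumes the sample downloads above succeeded):
#   out_path = process_video(os.path.join(sample_videos_dir, "real_1.mp4"),
#                            draw_box=False, bg_enhancement=False)
#   print(out_path)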