import os
import torch
import torch.utils.checkpoint
from PIL import Image
import numpy as np
from omegaconf import OmegaConf
from tqdm import tqdm
import cv2

from diffusers import AutoencoderKLTemporalDecoder
from diffusers.schedulers import EulerDiscreteScheduler
from transformers import WhisperModel, CLIPVisionModelWithProjection, AutoFeatureExtractor

from src.utils.util import save_videos_grid, seed_everything
from src.dataset.test_preprocess import process_bbox, image_audio_to_tensor
from src.models.base.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel, add_ip_adapters
from src.pipelines.pipeline_sonic import SonicPipeline
from src.models.audio_adapter.audio_proj import AudioProjModel
from src.models.audio_adapter.audio_to_bucket import Audio2bucketModel
from src.utils.RIFE.RIFE_HDv3 import RIFEModel
from src.dataset.face_align.align import AlignImage

BASE_DIR = os.path.dirname(os.path.abspath(__file__))


def test(
    pipe,
    config,
    wav_enc,
    audio_pe,
    audio2bucket,
    image_encoder,
    width,
    height,
    batch,
):
    """Generate a video tensor for the given batch."""
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            batch[k] = v.unsqueeze(0).to(pipe.device).float()

    ref_img = batch['ref_img']
    clip_img = batch['clip_images']
    face_mask = batch['face_mask']
    image_embeds = image_encoder(clip_img).image_embeds

    audio_feature = batch['audio_feature']
    audio_len = batch['audio_len']
    step = int(config.step)

    # Encode the audio with Whisper in fixed-size windows to bound memory use.
    window = 3000
    audio_prompts = []
    last_audio_prompts = []
    for i in range(0, audio_feature.shape[-1], window):
        audio_prompt = wav_enc.encoder(audio_feature[:, :, i:i+window], output_hidden_states=True).hidden_states
        last_audio_prompt = wav_enc.encoder(audio_feature[:, :, i:i+window]).last_hidden_state
        last_audio_prompt = last_audio_prompt.unsqueeze(-2)
        audio_prompt = torch.stack(audio_prompt, dim=2)
        audio_prompts.append(audio_prompt)
        last_audio_prompts.append(last_audio_prompt)

    # Trim to the actual audio length and zero-pad both ends so the sliding
    # clip windows below never run past the sequence boundaries.
    audio_prompts = torch.cat(audio_prompts, dim=1)
    audio_prompts = audio_prompts[:, :audio_len*2]
    audio_prompts = torch.cat([
        torch.zeros_like(audio_prompts[:, :4]),
        audio_prompts,
        torch.zeros_like(audio_prompts[:, :6])
    ], 1)

    last_audio_prompts = torch.cat(last_audio_prompts, dim=1)
    last_audio_prompts = last_audio_prompts[:, :audio_len*2]
    last_audio_prompts = torch.cat([
        torch.zeros_like(last_audio_prompts[:, :24]),
        last_audio_prompts,
        torch.zeros_like(last_audio_prompts[:, :26])
    ], 1)

    # Build per-clip audio conditioning tokens and a motion bucket for each frame.
    ref_tensor_list = []
    audio_tensor_list = []
    uncond_audio_tensor_list = []
    motion_buckets = []
    for i in tqdm(range(audio_len//step), ncols=0):
        audio_clip = audio_prompts[:, i*2*step:i*2*step+10].unsqueeze(0)
        audio_clip_for_bucket = last_audio_prompts[:, i*2*step:i*2*step+50].unsqueeze(0)
        motion_bucket = audio2bucket(audio_clip_for_bucket, image_embeds)
        motion_bucket = motion_bucket * 16 + 16  # rescale the raw prediction into the motion-bucket range
        motion_buckets.append(motion_bucket[0])

        cond_audio_clip = audio_pe(audio_clip).squeeze(0)
        uncond_audio_clip = audio_pe(torch.zeros_like(audio_clip)).squeeze(0)

        ref_tensor_list.append(ref_img[0])
        audio_tensor_list.append(cond_audio_clip[0])
        uncond_audio_tensor_list.append(uncond_audio_clip[0])

    video = pipe(
        ref_img,
        clip_img,
        face_mask,
        audio_tensor_list,
        uncond_audio_tensor_list,
        motion_buckets,
        height=height,
        width=width,
        num_frames=len(audio_tensor_list),
        decode_chunk_size=config.decode_chunk_size,
        motion_bucket_scale=config.motion_bucket_scale,
        fps=config.fps,
        noise_aug_strength=config.noise_aug_strength,
        min_guidance_scale1=config.min_appearance_guidance_scale,
        max_guidance_scale1=config.max_appearance_guidance_scale,
        min_guidance_scale2=config.audio_guidance_scale,
        max_guidance_scale2=config.audio_guidance_scale,
        overlap=config.overlap,
        shift_offset=config.shift_offset,
        frames_per_batch=config.n_sample_frames,
        num_inference_steps=config.num_inference_steps,
        i2i_noise_strength=config.i2i_noise_strength,
    ).frames

    # Map the output from [-1, 1] to [0, 1] and move it back to the CPU.
    video = (video * 0.5 + 0.5).clamp(0, 1)
    video = torch.cat([video.to(pipe.device)], dim=0).cpu()

    return video


class Sonic:
    """High-level interface for the Sonic portrait animation pipeline."""

    config_file = os.path.join(BASE_DIR, 'config/inference/sonic.yaml')
    config = OmegaConf.load(config_file)

    def __init__(self, device_id: int = 0, enable_interpolate_frame: bool = True):
        config = self.config
        config.use_interframe = enable_interpolate_frame

        device = f'cuda:{device_id}' if device_id > -1 else 'cpu'
        self.device = device

        config.pretrained_model_name_or_path = os.path.join(BASE_DIR, config.pretrained_model_name_or_path)

        # Load the pretrained backbone components (fp16 variants where available).
        vae = AutoencoderKLTemporalDecoder.from_pretrained(
            config.pretrained_model_name_or_path, subfolder='vae', variant='fp16')
        val_noise_scheduler = EulerDiscreteScheduler.from_pretrained(
            config.pretrained_model_name_or_path, subfolder='scheduler')
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            config.pretrained_model_name_or_path, subfolder='image_encoder', variant='fp16')
        unet = UNetSpatioTemporalConditionModel.from_pretrained(
            config.pretrained_model_name_or_path, subfolder='unet', variant='fp16')
        # Attach IP-Adapter modules to the UNet for audio conditioning.
        add_ip_adapters(unet, [32], [config.ip_audio_scale])

        audio2token = AudioProjModel(seq_len=10, blocks=5, channels=384, intermediate_dim=1024,
                                     output_dim=1024, context_tokens=32).to(device)
        audio2bucket = Audio2bucketModel(seq_len=50, blocks=1, channels=384, clip_channels=1024,
                                         intermediate_dim=1024, output_dim=1, context_tokens=2).to(device)

        unet.load_state_dict(
            torch.load(os.path.join(BASE_DIR, config.unet_checkpoint_path), map_location='cpu'),
            strict=True)
        audio2token.load_state_dict(
            torch.load(os.path.join(BASE_DIR, config.audio2token_checkpoint_path), map_location='cpu'),
            strict=True)
        audio2bucket.load_state_dict(
            torch.load(os.path.join(BASE_DIR, config.audio2bucket_checkpoint_path), map_location='cpu'),
            strict=True)

        dtype_map = {'fp16': torch.float16, 'fp32': torch.float32, 'bf16': torch.bfloat16}
        weight_dtype = dtype_map.get(config.weight_dtype)
        if weight_dtype is None:
            raise ValueError(f"Unsupported weight dtype: {config.weight_dtype}")

        # Frozen Whisper encoder used to extract audio features.
        whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, 'checkpoints/whisper-tiny/')).to(device).eval()
        whisper.requires_grad_(False)

        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
            os.path.join(BASE_DIR, 'checkpoints/whisper-tiny/'))

        self.face_det = AlignImage(device, det_path=os.path.join(BASE_DIR, 'checkpoints/yoloface_v5m.pt'))
        if config.use_interframe:
            # RIFE model for optional frame interpolation of the generated video.
            self.rife = RIFEModel(device=device)
            self.rife.load_model(os.path.join(BASE_DIR, 'checkpoints', 'RIFE/'))

        image_encoder.to(weight_dtype)
        vae.to(weight_dtype)
        unet.to(weight_dtype)

        pipe = SonicPipeline(
            unet=unet,
            image_encoder=image_encoder,
            vae=vae,
            scheduler=val_noise_scheduler)
        self.pipe = pipe.to(device=device, dtype=weight_dtype)
        self.whisper = whisper
        self.audio2token = audio2token
        self.audio2bucket = audio2bucket
        self.image_encoder = image_encoder

        print('Sonic initialization complete.')

    def preprocess(self, image_path: str, expand_ratio: float = 1.0):
        """Detect the largest face and return its expanded crop bbox."""
        face_image = cv2.imread(image_path)
        h, w = face_image.shape[:2]
        _, _, bboxes = self.face_det(face_image, maxface=True)
        face_num = len(bboxes)
        bbox_s = []
        if face_num > 0:
            x1, y1, ww, hh = bboxes[0]
            x2, y2 = x1 + ww, y1 + hh
            bbox_s = process_bbox((x1, y1, x2, y2), expand_radio=expand_ratio, height=h, width=w)

        return {'face_num': face_num, 'crop_bbox': bbox_s}

    def crop_image(self, input_image_path: str, output_image_path: str, crop_bbox):
        face_image = cv2.imread(input_image_path)
        crop_img = face_image[crop_bbox[1]:crop_bbox[3], crop_bbox[0]:crop_bbox[2]]
        cv2.imwrite(output_image_path, crop_img)

    @torch.no_grad()
    def process(self,
                image_path,
                audio_path,
                output_path,
                min_resolution=512,
                inference_steps=25,
                dynamic_scale=1.0,
                keep_resolution=False,
                seed=None):
        config = self.config
        device = self.device
        pipe = self.pipe
        whisper = self.whisper
        audio2token = self.audio2token
        audio2bucket = self.audio2bucket
        image_encoder = self.image_encoder

        if seed is not None:
            config.seed = seed
        seed_everything(config.seed)

        config.num_inference_steps = inference_steps
        config.frame_num = config.fps * 60
        config.motion_bucket_scale = dynamic_scale

        video_path = output_path.replace('.mp4', '_noaudio.mp4')
        audio_video_path = output_path

        imSrc_ = Image.open(image_path).convert('RGB')
        raw_w, raw_h = imSrc_.size

        # Preprocess the reference image and audio into model-ready tensors.
        test_data = image_audio_to_tensor(
            self.face_det,
            self.feature_extractor,
            image_path,
            audio_path,
            limit=config.frame_num,
            image_size=min_resolution,
            area=config.area)
        if test_data is None:
            return -1

        height, width = test_data['ref_img'].shape[-2:]
        resolution = f"{width}x{height}" if not keep_resolution else f"{raw_w//2*2}x{raw_h//2*2}"

        video = test(
            pipe,
            config,
            wav_enc=whisper,
            audio_pe=audio2token,
            audio2bucket=audio2bucket,
            image_encoder=image_encoder,
            width=width,
            height=height,
            batch=test_data)

        if config.use_interframe:
            # Double the frame rate by inserting a RIFE-interpolated frame between neighbors.
            out = video.to(device)
            results = []
            for idx in tqdm(range(out.shape[2]-1), ncols=0):
                I1 = out[:, :, idx]
                I2 = out[:, :, idx+1]
                mid = self.rife.inference(I1, I2).clamp(0, 1).detach()
                results.extend([out[:, :, idx], mid])
            results.append(out[:, :, -1])
            video = torch.stack(results, 2).cpu()

        save_videos_grid(video, video_path, n_rows=video.shape[0],
                         fps=config.fps * (2 if config.use_interframe else 1))
        # Mux the generated frames with the source audio, then remove the silent intermediate file.
        os.system(
            f"ffmpeg -i '{video_path}' -i '{audio_path}' -s {resolution} "
            f"-vcodec libx264 -acodec aac -crf 18 -shortest '{audio_video_path}' -y; "
            f"rm '{video_path}'")
        return 0
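

# Minimal usage sketch (not part of the library): shows how the Sonic class above is
# intended to be driven end to end. The example file paths and the expand_ratio value
# are placeholders chosen for illustration, not assets shipped with this module.
if __name__ == '__main__':
    sonic = Sonic(device_id=0, enable_interpolate_frame=True)

    # Hypothetical inputs; replace with real files.
    image_path = 'examples/image/portrait.png'
    audio_path = 'examples/wav/speech.wav'
    output_path = 'output/result.mp4'

    # Optionally crop the portrait to the detected face region before animating.
    info = sonic.preprocess(image_path, expand_ratio=0.5)
    if info['face_num'] > 0:
        cropped_path = image_path.replace('.png', '_crop.png')
        sonic.crop_image(image_path, cropped_path, info['crop_bbox'])
        image_path = cropped_path

    # Run the full pipeline: returns 0 on success, -1 if preprocessing fails.
    sonic.process(image_path, audio_path, output_path,
                  min_resolution=512, inference_steps=25, dynamic_scale=1.0)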