# inference_engine.py
import os
import torch
import decord
import imageio
from PIL import Image
from models import MTVCrafterPipeline, Encoder, VectorQuantizer, Decoder, SMPL_VQVAE
from torchvision.transforms import ToPILImage, transforms, InterpolationMode, functional as F
import numpy as np
import pickle
import copy
from huggingface_hub import hf_hub_download
from draw_pose import get_pose_images
from utils import concat_images_grid, sample_video, get_sample_indexes, get_new_height_width

# Hugging Face repo and in-repo filename of the 4DMoT motion tokenizer (VQ-VAE) checkpoint.
_HF_REPO_ID = "yanboding/MTVCrafter"
_VQVAE_CHECKPOINT = "4DMoT/mp_rank_00_model_states.pt"


def _load_vqvae(device, repo_id=_HF_REPO_ID, filename=_VQVAE_CHECKPOINT):
    """Download the motion VQ-VAE checkpoint, build the model on *device* and set eval mode."""
    # Load from the local path hf_hub_download returns; the repo-relative filename only
    # resolves if a copy of the checkpoint happens to exist under the working directory.
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
    # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True; if this
    # checkpoint stores non-tensor objects the load may fail there - confirm.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    motion_encoder = Encoder(in_channels=3, mid_channels=[128, 512], out_channels=3072,
                             downsample_time=[2, 2], downsample_joint=[1, 1])
    motion_quant = VectorQuantizer(nb_code=8192, code_dim=3072, is_train=False)
    motion_decoder = Decoder(in_channels=3072, mid_channels=[512, 128], out_channels=3,
                             upsample_rate=2.0, frame_upsample_rate=[2.0, 2.0],
                             joint_upsample_rate=[1.0, 1.0])
    vqvae = SMPL_VQVAE(motion_encoder, motion_decoder, motion_quant).to(device)
    vqvae.load_state_dict(state_dict['module'], strict=True)
    vqvae.eval()  # inference only: disable any train-time behaviour
    return vqvae


def _extract_poses(data, sample_indexes):
    """Return the joints of *data* at *sample_indexes*.

    Two layouts are accepted: ``data['pose']['joints3d_nonparam']`` as a sequence whose
    items are indexed ``[0][0]`` down to a tensor-like object with ``.cpu().numpy()``,
    or ``data['pose']`` itself being indexable by *sample_indexes*.
    """
    try:
        all_poses = np.array([pose[0][0].cpu().numpy()
                              for pose in data['pose']['joints3d_nonparam']])
    except (KeyError, IndexError, TypeError, AttributeError):
        # Not the nested 'joints3d_nonparam' layout; fall back to data['pose'] directly.
        all_poses = data['pose']
    return all_poses[sample_indexes]


def _render_poses(joints, offset, new_width, new_height, crop_box):
    """Draw *joints* with get_pose_images, then resize and crop each image like the video frames."""
    return [image.resize((new_width, new_height)).crop(crop_box)
            for image in get_pose_images(joints, offset)]


def run_inference(device, motion_data_path, ref_image_path='', dst_width=512, dst_height=512,
                  num_inference_steps=50, guidance_scale=3.0, seed=6666,
                  output_path="output.mp4"):
    """Animate a reference image with the motion from a pickled sample and save a comparison video.

    Args:
        device: Device the diffusion pipeline and the motion VQ-VAE are moved to.
        motion_data_path: Pickle file holding one sample dict or a list of them; only the
            first sample is used. Keys read: 'video_path', 'video_length', 'video_height',
            'video_width' and 'pose'.
        ref_image_path: Optional reference image. When empty/None, the first sampled video
            frame is used as the reference for every frame.
        dst_width, dst_height: Size of the center crop fed to the pipeline.
        num_inference_steps, guidance_scale, seed: Forwarded to the pipeline call.
        output_path: Where the mp4 is written (default "output.mp4").

    Returns:
        ``output_path``. Each video frame shows, side by side: the input frame, the raw
        poses, the VQ-VAE-reconstructed poses and the generated frame.

    Raises:
        ValueError: if the pickle contains an empty list.
    """
    num_frames = 49
    to_pil = ToPILImage()
    normalize = transforms.Normalize([0.5], [0.5])
    pretrained_model_path = "THUDM/CogVideoX-5b"
    transformer_path = "yanboding/MTVCrafter/MV-DiT/CogVideoX"

    # NOTE(review): pickle.load can execute arbitrary code; only pass trusted motion files.
    with open(motion_data_path, 'rb') as f:
        data_list = pickle.load(f)
    if not isinstance(data_list, list):
        data_list = [data_list]
    if not data_list:
        # Fail before loading the (large) models.
        raise ValueError(f"no samples found in {motion_data_path!r}")

    # Joint normalization statistics, resolved relative to the current working directory.
    pe_mean = np.load('data/mean.npy')
    pe_std = np.load('data/std.npy')

    pipe = MTVCrafterPipeline.from_pretrained(
        model_path=pretrained_model_path,
        transformer_model_path=transformer_path,
        torch_dtype=torch.bfloat16,
        scheduler_type='dpm',
    ).to(device)
    pipe.vae.enable_tiling()
    pipe.vae.enable_slicing()

    vqvae = _load_vqvae(device)

    # Only the first sample is run here.
    data = data_list[0]
    new_height, new_width = get_new_height_width(data, dst_height, dst_width)
    # Offsets of a centered dst_height x dst_width crop inside the resized frame.
    x1 = (new_width - dst_width) // 2
    y1 = (new_height - dst_height) // 2
    crop_box = (x1, y1, x1 + dst_width, y1 + dst_height)

    sample_indexes = get_sample_indexes(data['video_length'], num_frames, stride=1)
    input_images = sample_video(decord.VideoReader(data['video_path']), sample_indexes)
    # presumably (T, H, W, C) frames -> (T, C, H, W) for torchvision; TODO confirm layout
    input_images = torch.from_numpy(input_images).permute(0, 3, 1, 2).contiguous()
    input_images = F.resize(input_images, (new_height, new_width), InterpolationMode.BILINEAR)
    input_images = F.crop(input_images, y1, x1, dst_height, dst_width)

    if ref_image_path:
        with Image.open(ref_image_path) as img:
            ref_image = img.convert("RGB")
        ref_image = torch.from_numpy(np.array(ref_image)).permute(2, 0, 1).contiguous()
        # Resize/crop once, then repeat the identical result for every frame.
        ref_image = F.resize(ref_image, (new_height, new_width), InterpolationMode.BILINEAR)
        ref_image = F.crop(ref_image, y1, x1, dst_height, dst_width)
        ref_images = ref_image.unsqueeze(0).repeat(num_frames, 1, 1, 1)
    else:
        # No reference given: every reference frame is the first sampled video frame.
        ref_images = input_images[:1].repeat(len(input_images), 1, 1, 1)

    poses = _extract_poses(data, sample_indexes)
    # float32 to match the VQ-VAE weights.
    norm_poses = torch.as_tensor((poses - pe_mean) / pe_std, dtype=torch.float32)
    offset = [data['video_height'], data['video_width'], 0]
    pose_images_before = _render_poses(copy.deepcopy(poses), offset,
                                       new_width, new_height, crop_box)

    # Scale pixel values to [0, 1] then normalize with mean/std 0.5, i.e. to [-1, 1].
    input_images = normalize(input_images / 255.0)
    ref_images = normalize(ref_images / 255.0)

    with torch.no_grad():
        input_smpl_joints = norm_poses.unsqueeze(0).to(device)
        motion_tokens, _ = vqvae(input_smpl_joints, return_vq=True)
        output_motion, _ = vqvae(input_smpl_joints)
        # De-normalize the reconstruction; numpy like the poses rendered above.
        reconstructed = output_motion[0].detach().float().cpu().numpy() * pe_std + pe_mean
        pose_images_after = _render_poses(reconstructed, offset,
                                          new_width, new_height, crop_box)

        output_images = pipe(
            height=dst_height,
            width=dst_width,
            num_frames=num_frames,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            seed=seed,
            ref_images=ref_images,
            motion_embeds=motion_tokens,
            joint_mean=pe_mean,
            joint_std=pe_std,
        ).frames[0]

    vis_images = []
    for input_frame, pose_before, pose_after, output_frame in zip(
            input_images, pose_images_before, pose_images_after, output_images):
        # Undo the [-1, 1] normalization to get a displayable uint8 image.
        input_pil = to_pil(((input_frame + 1) * 127.5).clamp(0, 255).to(torch.uint8))
        panels = [input_pil, pose_before, pose_after, output_frame]
        grid = concat_images_grid(panels, cols=len(panels), pad=2)
        vis_images.append(np.asarray(grid))  # imageio writers consume ndarrays

    imageio.mimsave(output_path, vis_images, fps=15)
    return output_path