# inference_engine.py
import os
import copy
import pickle

import torch
import decord
import imageio
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
from torchvision import transforms
from torchvision.transforms import ToPILImage, InterpolationMode
from torchvision.transforms import functional as F

from models import MTVCrafterPipeline, Encoder, VectorQuantizer, Decoder, SMPL_VQVAE
from draw_pose import get_pose_images
from utils import concat_images_grid, sample_video, get_sample_indexes, get_new_height_width
def run_inference(device, motion_data_path, ref_image_path='', dst_width=512, dst_height=512,
                  num_inference_steps=50, guidance_scale=3.0, seed=6666):
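    """Animate a reference image with a motion sequence and save a visualization grid.

    Loads the CogVideoX-based MTVCrafter pipeline and the 4DMoT motion tokenizer
    (VQVAE), tokenizes the SMPL motion from `motion_data_path`, and writes the
    generated frames next to the input frames and pose renderings to `output.mp4`.
    """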
    num_frames = 49
    to_pil = ToPILImage()
    normalize = transforms.Normalize([0.5], [0.5])
    pretrained_model_path = "THUDM/CogVideoX-5b"
    transformer_path = "yanboding/MTVCrafter/MV-DiT/CogVideoX"
    tokenizer_path = "4DMoT/mp_rank_00_model_states.pt"
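    # motion data: a pickled list of per-clip dicts (video path, size, length, SMPL pose data)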
    with open(motion_data_path, 'rb') as f:
        data_list = pickle.load(f)
    if not isinstance(data_list, list):
        data_list = [data_list]
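    # dataset-wide joint statistics used to normalize / de-normalize the 3D joints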
    pe_mean = np.load('data/mean.npy')
    pe_std = np.load('data/std.npy')
    pipe = MTVCrafterPipeline.from_pretrained(
        model_path=pretrained_model_path,
        transformer_model_path=transformer_path,
        torch_dtype=torch.bfloat16,
        scheduler_type='dpm',
    ).to(device)
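    # tile/slice VAE computation to reduce peak GPU memory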
    pipe.vae.enable_tiling()
    pipe.vae.enable_slicing()
    # load the 4DMoT motion tokenizer (VQVAE) checkpoint from the Hub
    vqvae_model_path = hf_hub_download(
        repo_id="yanboding/MTVCrafter",
        filename=tokenizer_path,
    )
    state_dict = torch.load(vqvae_model_path, map_location="cpu")
    motion_encoder = Encoder(in_channels=3, mid_channels=[128, 512], out_channels=3072,
                             downsample_time=[2, 2], downsample_joint=[1, 1])
    motion_quant = VectorQuantizer(nb_code=8192, code_dim=3072, is_train=False)
    motion_decoder = Decoder(in_channels=3072, mid_channels=[512, 128], out_channels=3, upsample_rate=2.0,
                             frame_upsample_rate=[2.0, 2.0], joint_upsample_rate=[1.0, 1.0])
    vqvae = SMPL_VQVAE(motion_encoder, motion_decoder, motion_quant).to(device)
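    # DeepSpeed-style checkpoint: the model weights live under the 'module' key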
    vqvae.load_state_dict(state_dict['module'], strict=True)
    # only run the first sample here
    data = data_list[0]
    new_height, new_width = get_new_height_width(data, dst_height, dst_width)
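    # top-left offsets for a centered dst_width x dst_height crop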
    x1 = (new_width - dst_width) // 2
    y1 = (new_height - dst_height) // 2
    sample_indexes = get_sample_indexes(data['video_length'], num_frames, stride=1)
    input_images = sample_video(decord.VideoReader(data['video_path']), sample_indexes)
    input_images = torch.from_numpy(input_images).permute(0, 3, 1, 2).contiguous()
    input_images = F.resize(input_images, (new_height, new_width), InterpolationMode.BILINEAR)
    input_images = F.crop(input_images, y1, x1, dst_height, dst_width)
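    # reference frames: a user-provided image, or the clip's first frame, repeated for every frame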
    if ref_image_path != '':
        ref_image = Image.open(ref_image_path).convert("RGB")
        ref_image = torch.from_numpy(np.array(ref_image)).permute(2, 0, 1).contiguous()
        ref_images = torch.stack([ref_image.clone() for _ in range(num_frames)])
        ref_images = F.resize(ref_images, (new_height, new_width), InterpolationMode.BILINEAR)
        ref_images = F.crop(ref_images, y1, x1, dst_height, dst_width)
    else:
        ref_images = copy.deepcopy(input_images)
        frame0 = input_images[0]
        ref_images[:, :, :, :] = frame0
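    # per-frame 3D joints: some pickles store them under ['pose']['joints3d_nonparam'],
    # others as a plain pose array indexed directly by frame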
    try:
        smpl_poses = np.array([pose[0][0].cpu().numpy() for pose in data['pose']['joints3d_nonparam']])
        poses = smpl_poses[sample_indexes]
    except (KeyError, IndexError, TypeError):
        poses = data['pose'][sample_indexes]
    norm_poses = torch.tensor((poses - pe_mean) / pe_std)
    offset = [data['video_height'], data['video_width'], 0]
    pose_images_before = get_pose_images(copy.deepcopy(poses), offset)
    pose_images_before = [image.resize((new_width, new_height)).crop((x1, y1, x1 + dst_width, y1 + dst_height)) for image in pose_images_before]
    input_smpl_joints = norm_poses.unsqueeze(0).to(device)
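    # quantized motion tokens condition the video DiT; a full VQVAE forward pass
    # gives the de-tokenized reconstruction used only for visualization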
    motion_tokens, vq_loss = vqvae(input_smpl_joints, return_vq=True)
    output_motion, _ = vqvae(input_smpl_joints)
    pose_images_after = get_pose_images(output_motion[0].cpu().detach() * pe_std + pe_mean, offset)
    pose_images_after = [image.resize((new_width, new_height)).crop((x1, y1, x1 + dst_width, y1 + dst_height)) for image in pose_images_after]
    # normalize images to [-1, 1]
    input_images = input_images / 255.0
    ref_images = ref_images / 255.0
    input_images = normalize(input_images)
    ref_images = normalize(ref_images)
    # infer
    output_images = pipe(
        height=dst_height,
        width=dst_width,
        num_frames=num_frames,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        seed=seed,
        ref_images=ref_images,
        motion_embeds=motion_tokens,
        joint_mean=pe_mean,
        joint_std=pe_std,
    ).frames[0]
    # save result
    vis_images = []
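    # each row of the grid: input frame | input pose | VQVAE-reconstructed pose | generated frame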
    for k in range(len(output_images)):
        vis_image = [
            to_pil(((input_images[k] + 1) * 127.5).clamp(0, 255).to(torch.uint8)),
            pose_images_before[k],
            pose_images_after[k],
            output_images[k],
        ]
        vis_image = concat_images_grid(vis_image, cols=len(vis_image), pad=2)
        vis_images.append(vis_image)
    output_path = "output.mp4"
    imageio.mimsave(output_path, vis_images, fps=15)
    return output_path
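
# Minimal usage sketch (illustrative): the motion-data pickle path below is a
# placeholder, and a single CUDA GPU is assumed when available.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    result = run_inference(
        device,
        motion_data_path="data/motion_sample.pkl",  # placeholder path to a motion pickle
        ref_image_path="",                          # empty: reuse the clip's first frame as reference
    )
    print(f"Saved visualization to {result}")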