# inference_engine.py
import os
import torch
import decord
import imageio
from PIL import Image
from models import MTVCrafterPipeline, Encoder, VectorQuantizer, Decoder, SMPL_VQVAE
from torchvision.transforms import ToPILImage, transforms, InterpolationMode, functional as F
import numpy as np
import pickle
import copy
from huggingface_hub import hf_hub_download
from draw_pose import get_pose_images
from utils import concat_images_grid, sample_video, get_sample_indexes, get_new_height_width

def run_inference(device, motion_data_path, ref_image_path='', dst_width=512, dst_height=512, num_inference_steps=50, guidance_scale=3.0, seed=6666):
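    """Run MTVCrafter pose-driven video generation for one motion sample.

    Loads the CogVideoX-based MTVCrafter pipeline and the 4D motion tokenizer
    (VQ-VAE), conditions generation on the tokenized motion and a reference
    image, and writes a side-by-side visualization to "output.mp4".
    """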
    num_frames = 49
    to_pil = ToPILImage()
    normalize = transforms.Normalize([0.5], [0.5])
    pretrained_model_path = "THUDM/CogVideoX-5b"
    transformer_path = "yanboding/MTVCrafter/MV-DiT/CogVideoX"
    tokenizer_path = "4DMoT/mp_rank_00_model_states.pt"
    
    with open(motion_data_path, 'rb') as f:
        data_list = pickle.load(f)
    if not isinstance(data_list, list):
        data_list = [data_list]
    
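    # global mean/std used to normalize the 3D joint positions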
    pe_mean = np.load('data/mean.npy')
    pe_std = np.load('data/std.npy')

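    # load the MTVCrafter diffusion pipeline (CogVideoX-5b backbone with the MV-DiT transformer)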
    pipe = MTVCrafterPipeline.from_pretrained(
        model_path=pretrained_model_path,
        transformer_model_path=transformer_path,
        torch_dtype=torch.bfloat16,
        scheduler_type='dpm',
    ).to(device)
    pipe.vae.enable_tiling()
    pipe.vae.enable_slicing()

    # load the 4D motion tokenizer (VQ-VAE) checkpoint from the Hub
    vqvae_model_path = hf_hub_download(
        repo_id="yanboding/MTVCrafter",
        filename=tokenizer_path,
    )
    state_dict = torch.load(vqvae_model_path, map_location="cpu")
    motion_encoder = Encoder(in_channels=3, mid_channels=[128, 512], out_channels=3072, downsample_time=[2, 2], downsample_joint=[1, 1])
    motion_quant = VectorQuantizer(nb_code=8192, code_dim=3072, is_train=False)
    motion_decoder = Decoder(in_channels=3072, mid_channels=[512, 128], out_channels=3, upsample_rate=2.0, frame_upsample_rate=[2.0, 2.0], joint_upsample_rate=[1.0, 1.0])
    vqvae = SMPL_VQVAE(motion_encoder, motion_decoder, motion_quant).to(device)
    vqvae.load_state_dict(state_dict['module'], strict=True)

    # process only the first sample in the motion data list
    data = data_list[0]
    new_height, new_width = get_new_height_width(data, dst_height, dst_width)
    x1 = (new_width - dst_width) // 2
    y1 = (new_height - dst_height) // 2

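    # sample num_frames frames from the source video, then resize and center-crop to the target size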
    sample_indexes = get_sample_indexes(data['video_length'], num_frames, stride=1)
    input_images = sample_video(decord.VideoReader(data['video_path']), sample_indexes)
    input_images = torch.from_numpy(input_images).permute(0, 3, 1, 2).contiguous()
    input_images = F.resize(input_images, (new_height, new_width), InterpolationMode.BILINEAR)
    input_images = F.crop(input_images, y1, x1, dst_height, dst_width)

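    # build the reference image sequence: a user-supplied image if given, otherwise the first video frame repeated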
    if ref_image_path != '':
        ref_image = Image.open(ref_image_path).convert("RGB")
        ref_image = torch.from_numpy(np.array(ref_image)).permute(2, 0, 1).contiguous()
        ref_images = torch.stack([ref_image.clone() for _ in range(num_frames)])
        ref_images = F.resize(ref_images, (new_height, new_width), InterpolationMode.BILINEAR)
        ref_images = F.crop(ref_images, y1, x1, dst_height, dst_width)
    else:
        ref_images = copy.deepcopy(input_images)
        frame0 = input_images[0]
        ref_images[:, :, :, :] = frame0

    # prefer SMPL non-parametric 3D joints when present, otherwise fall back to the raw pose array
    try:
        smpl_poses = np.array([pose[0][0].cpu().numpy() for pose in data['pose']['joints3d_nonparam']])
        poses = smpl_poses[sample_indexes]
    except Exception:
        poses = data['pose'][sample_indexes]
    norm_poses = torch.tensor((poses - pe_mean) / pe_std)

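    # draw pose skeletons before/after VQ-VAE reconstruction and tokenize the motion for conditioning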
    offset = [data['video_height'], data['video_width'], 0]
    pose_images_before = get_pose_images(copy.deepcopy(poses), offset)
    pose_images_before = [image.resize((new_width, new_height)).crop((x1, y1, x1+dst_width, y1+dst_height)) for image in pose_images_before]
    input_smpl_joints = norm_poses.unsqueeze(0).to(device)
    motion_tokens, vq_loss = vqvae(input_smpl_joints, return_vq=True)
    output_motion, _ = vqvae(input_smpl_joints)
    pose_images_after = get_pose_images(output_motion[0].cpu().detach() * pe_std + pe_mean, offset)
    pose_images_after = [image.resize((new_width, new_height)).crop((x1, y1, x1+dst_width, y1+dst_height)) for image in pose_images_after]

    # scale images to [0, 1], then normalize to [-1, 1]
    input_images = input_images / 255.0
    ref_images = ref_images / 255.0
    input_images = normalize(input_images)
    ref_images = normalize(ref_images)

    # run the diffusion pipeline conditioned on the reference images and motion tokens
    output_images = pipe(
        height=dst_height,
        width=dst_width,
        num_frames=num_frames,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        seed=seed,
        ref_images=ref_images,
        motion_embeds=motion_tokens,
        joint_mean=pe_mean,
        joint_std=pe_std,
    ).frames[0]

    # save a side-by-side visualization: input frame | pose (input) | pose (reconstructed) | generated frame
    vis_images = []
    for k in range(len(output_images)):
        vis_image = [to_pil(((input_images[k] + 1) * 127.5).clamp(0, 255).to(torch.uint8)), pose_images_before[k], pose_images_after[k], output_images[k]]
        vis_image = concat_images_grid(vis_image, cols=len(vis_image), pad=2)
        vis_images.append(vis_image)

    output_path = "output.mp4"
    imageio.mimsave(output_path, vis_images, fps=15)

    return output_path
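

if __name__ == "__main__":
    # Minimal usage sketch. The motion pickle path below is a placeholder
    # (assumption), not a file shipped with this script; point it at motion
    # data produced by the repo's preprocessing step.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    result = run_inference(
        device,
        motion_data_path="data/motion.pkl",  # placeholder path
        ref_image_path="",                   # empty -> use the first video frame as reference
        dst_width=512,
        dst_height=512,
        num_inference_steps=50,
        guidance_scale=3.0,
        seed=6666,
    )
    print(f"Saved visualization to {result}")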