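"""Gradio demo app for MTVCrafter.

The app extracts human motion (a .pkl file) from an uploaded dancing video and
uses it to animate an optional reference image; if no reference image is given,
the first frame of the video serves as the reference. `inference_engine` and
`motion_extractor` are assumed to be sibling modules in this repository.
"""
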
import os
import traceback

import torch
import gradio as gr

from inference_engine import run_inference
from motion_extractor import extract_pkl_from_video

device = "cuda" if torch.cuda.is_available() else "cpu"

def full_pipeline(video_file, ref_image=None, width=512, height=512, steps=50, scale=3.0, seed=6666):
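    """Run the full MTVCrafter pipeline: motion extraction followed by animation.

    Args:
        video_file: uploaded video (gr.File); `.name` is its path on disk.
        ref_image: optional PIL reference image. If None, an empty path is passed
            and the first frame of the video is used as the reference.
        width, height: output resolution forwarded to `run_inference`.
        steps: number of inference steps.
        scale: guidance scale.
        seed: random seed.

    Returns:
        Path to the generated video file.
    """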
    # 1. Extract the motion .pkl from the input video
    video_path = video_file.name
    motion_pkl_path = extract_pkl_from_video(video_path)
    gr.Info("โณ Extract motion finished and begin animation...", visible=True)

    # 2. Handle the optional reference image
    if ref_image is not None:
        ref_path = "temp_ref.png"
        ref_image.save(ref_path)
    else:
        ref_path = ""

    # 3. Run inference
    output_path = run_inference(
        device, 
        motion_pkl_path, 
        ref_path, 
        dst_width=width, 
        dst_height=height,
        num_inference_steps=steps,
        guidance_scale=scale,
        seed=seed,
    )

    return output_path


def run_pipeline_with_feedback(video_file, ref_image, width, height, steps, scale, seed):
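    """Wrap `full_pipeline` with input validation and Gradio info/warning feedback.

    Returns the path of the generated video, or None if inference fails.
    """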
    if video_file is None:
        raise gr.Error("Please upload a dancing video (.mp4/.mov/.avi).")
    try:
        # Show a progress hint to the user
        gr.Info("⏳ Processing... this may take several minutes.", visible=True)
        result = full_pipeline(video_file, ref_image, width, height, steps, scale, seed)
        gr.Info("✅ Inference done, please enjoy the result!", visible=True)
        return result
    except Exception as e:
        traceback.print_exc()
        gr.Warning("⚠️ Inference failed: " + str(e))
        return None

# Build the Gradio UI
with gr.Blocks(title="MTVCrafter Inference Demo") as demo:
    gr.Markdown(
    """
    # 🎨💃 MTVCrafter Inference Demo

    💡 **Tip:** Upload a dancing video in **MP4/MOV/AVI** format, and optionally a reference image (e.g., PNG or JPG).  
    This demo will extract human motion from the input video and animate the reference image accordingly.  
    If no reference image is provided, the **first frame** of the video will be used as the reference.

    🎞️ **Note:** The generated output video will contain exactly **49 frames**.
    """
    )

    with gr.Row():
        with gr.Column(scale=1):
            video_input = gr.File(label="📹 Input Video (Required)", file_types=[".mp4", ".mov", ".avi"])
            video_preview = gr.Video(label="👀 Preview of Uploaded Video", height=280)  # fixed height keeps the two columns aligned

            def show_video_preview(video_file):
                return video_file.name if video_file else None

            video_input.change(fn=show_video_preview, inputs=video_input, outputs=video_preview)

        with gr.Column(scale=1):
            ref_image = gr.Image(type="pil", label="🖼️ Reference Image (Optional)", height=538)
    
    with gr.Accordion("⚙️ Advanced Settings", open=False):
        with gr.Row():
            width = gr.Slider(384, 1024, value=512, step=16, label="Output Width")
            height = gr.Slider(384, 1024, value=512, step=16, label="Output Height")
        with gr.Row():
            steps = gr.Slider(20, 100, value=50, step=5, label="Inference Steps")
            scale = gr.Slider(0.0, 10.0, value=3.0, step=0.25, label="Guidance Scale")
            seed = gr.Number(value=6666, label="Random Seed")
    
    with gr.Row():
        output_video = gr.Video(label="🎬 Generated Video", interactive=False)

    run_btn = gr.Button("🚀 Run MTVCrafter", variant="primary")

    run_btn.click(
        fn=run_pipeline_with_feedback,
        inputs=[video_input, ref_image, width, height, steps, scale, seed],
        outputs=output_video,
    )

if __name__ == "__main__":
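    # Route Hugging Face downloads through the hf-mirror endpoint and bypass the
    # proxy for local addresses (assumed deployment environment).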
    os.environ["HF_ENDPOINT"] = "https://hf-mirror.com/"
    os.environ["NO_PROXY"] = "localhost,127.0.0.1/8,::1"
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)