import os
import traceback

import torch
import gradio as gr

from inference_engine import run_inference
from motion_extractor import extract_pkl_from_video

device = "cuda" if torch.cuda.is_available() else "cpu"
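# NOTE: the diffusion sampling below is compute-heavy; on CPU it will be far
# slower than on CUDA (assumption: there is no lighter CPU-specific path here).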

def full_pipeline(video_file, ref_image=None, width=512, height=512, steps=50, scale=3.0, seed=6666):
    """Extract motion from the input video, then animate the reference image."""
    # 1. Extract the motion .pkl from the uploaded video
    video_path = video_file.name
    motion_pkl_path = extract_pkl_from_video(video_path)
    gr.Info("⏳ Motion extraction finished; starting animation...")

    # 2. Handle the reference image (optional); an empty path means the
    #    first frame of the video is used as the reference
    if ref_image is not None:
        ref_path = "temp_ref.png"
        ref_image.save(ref_path)
    else:
        ref_path = ""

    # 3. Run inference
    output_path = run_inference(
        device,
        motion_pkl_path,
        ref_path,
        dst_width=width,
        dst_height=height,
        num_inference_steps=steps,
        guidance_scale=scale,
        seed=seed,
    )
    return output_path
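
# A minimal sketch of calling the pipeline outside Gradio (assumed usage: the
# first argument only needs a `.name` attribute, mirroring what gr.File passes
# in; "dance.mp4" and "ref.png" are hypothetical paths):
#
#   from types import SimpleNamespace
#   from PIL import Image
#   out = full_pipeline(SimpleNamespace(name="dance.mp4"), Image.open("ref.png"))
#   print(out)  # path to the generated video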

def run_pipeline_with_feedback(video_file, ref_image, width, height, steps, scale, seed):
    try:
        if video_file is None:
            raise gr.Error("Please upload a dancing video (.mp4/.mov/.avi).")
        # Show progress feedback in the UI
        gr.Info("⏳ Processing... this may take several minutes.")
        result = full_pipeline(video_file, ref_image, width, height, steps, scale, seed)
        gr.Info("✅ Inference done, please enjoy it!")
        return result
    except Exception as e:
        traceback.print_exc()
        gr.Warning("⚠️ Inference failed: " + str(e))
        return None

# Build the UI
with gr.Blocks(title="MTVCrafter Inference Demo") as demo:
    gr.Markdown(
        """
        # 🎨💃 MTVCrafter Inference Demo

        💡 **Tip:** Upload a dancing video in **MP4/MOV/AVI** format, and optionally a reference image (e.g., PNG or JPG).
        This demo extracts human motion from the input video and animates the reference image accordingly.
        If no reference image is provided, the **first frame** of the video is used as the reference.

        🎞️ **Note:** The generated output video contains exactly **49 frames**.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            video_input = gr.File(label="📹 Input Video (Required)", file_types=[".mp4", ".mov", ".avi"])
            video_preview = gr.Video(label="👀 Preview of Uploaded Video", height=280)  # fixed height to keep the columns aligned

            def show_video_preview(video_file):
                # gr.File passes a file object whose .name is the temp path
                return video_file.name if video_file else None

            video_input.change(fn=show_video_preview, inputs=video_input, outputs=video_preview)
        with gr.Column(scale=1):
            ref_image = gr.Image(type="pil", label="🖼️ Reference Image (Optional)", height=538)

    with gr.Accordion("⚙️ Advanced Settings", open=False):
        with gr.Row():
            width = gr.Slider(384, 1024, value=512, step=16, label="Output Width")
            height = gr.Slider(384, 1024, value=512, step=16, label="Output Height")
        with gr.Row():
            steps = gr.Slider(20, 100, value=50, step=5, label="Inference Steps")
            scale = gr.Slider(0.0, 10.0, value=3.0, step=0.25, label="Guidance Scale")
            seed = gr.Number(value=6666, label="Random Seed")

    with gr.Row():
        output_video = gr.Video(label="🎬 Generated Video", interactive=False)

    run_btn = gr.Button("🚀 Run MTVCrafter", variant="primary")

    run_btn.click(
        fn=run_pipeline_with_feedback,
        inputs=[video_input, ref_image, width, height, steps, scale, seed],
        outputs=output_video,
    )

if __name__ == "__main__":
    # Use the HF mirror endpoint and bypass any proxy for local addresses
    os.environ["HF_ENDPOINT"] = "https://hf-mirror.com/"
    os.environ["NO_PROXY"] = "localhost,127.0.0.1/8,::1"
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
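
# To run locally (assumed invocation): `python app.py`, then open
# http://localhost:7860. With share=True, Gradio additionally prints a
# temporary public *.gradio.live link in the console.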