import os
import traceback

import torch
import gradio as gr

from inference_engine import run_inference
from motion_extractor import extract_pkl_from_video

device = "cuda" if torch.cuda.is_available() else "cpu"


def full_pipeline(video_file, ref_image=None, width=512, height=512, steps=50, scale=3.0, seed=6666):
    # 1. Extract the motion .pkl from the uploaded video
    video_path = video_file.name
    motion_pkl_path = extract_pkl_from_video(video_path)
    gr.Info("⏳ Extract motion finished and begin animation...", visible=True)

    # 2. Handle the optional reference image
    if ref_image is not None:
        ref_path = "temp_ref.png"
        ref_image.save(ref_path)
    else:
        ref_path = ""

    # 3. Run inference
    output_path = run_inference(
        device,
        motion_pkl_path,
        ref_path,
        dst_width=width,
        dst_height=height,
        num_inference_steps=steps,
        guidance_scale=scale,
        seed=seed,
    )
    return output_path


def run_pipeline_with_feedback(video_file, ref_image, width, height, steps, scale, seed):
    try:
        if video_file is None:
            raise gr.Error("Please upload a dancing video (.mp4/.mov/.avi).")

        # Progress notifications for the user
        gr.Info("⏳ Processing... Please wait several minutes.", visible=True)
        result = full_pipeline(video_file, ref_image, width, height, steps, scale, seed)
        gr.Info("✅ Inference done, please enjoy it!", visible=True)
        return result
    except Exception as e:
        traceback.print_exc()
        gr.Warning("⚠️ Inference failed: " + str(e))
        return None


# Build the UI
with gr.Blocks(title="MTVCrafter Inference Demo") as demo:
    gr.Markdown(
        """
        # 🎨💃 MTVCrafter Inference Demo

        💡 **Tip:** Upload a dancing video in **MP4/MOV/AVI** format, and optionally a reference image (e.g., PNG or JPG).
        This demo will extract human motion from the input video and animate the reference image accordingly.
        If no reference image is provided, the **first frame** of the video will be used as the reference.

        🎞️ **Note:** The generated output video will contain exactly **49 frames**.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            video_input = gr.File(label="📹 Input Video (Required)", file_types=[".mp4", ".mov", ".avi"])
            video_preview = gr.Video(label="👀 Preview of Uploaded Video", height=280)  # fixed height to keep the columns aligned

            def show_video_preview(video_file):
                return video_file.name if video_file else None

            video_input.change(fn=show_video_preview, inputs=video_input, outputs=video_preview)

        with gr.Column(scale=1):
            ref_image = gr.Image(type="pil", label="🖼️ Reference Image (Optional)", height=538)

    with gr.Accordion("⚙️ Advanced Settings", open=False):
        with gr.Row():
            width = gr.Slider(384, 1024, value=512, step=16, label="Output Width")
            height = gr.Slider(384, 1024, value=512, step=16, label="Output Height")
        with gr.Row():
            steps = gr.Slider(20, 100, value=50, step=5, label="Inference Steps")
            scale = gr.Slider(0.0, 10.0, value=3.0, step=0.25, label="Guidance Scale")
            seed = gr.Number(value=6666, label="Random Seed")

    with gr.Row():
        output_video = gr.Video(label="🎬 Generated Video", interactive=False)

    run_btn = gr.Button("🚀 Run MTVCrafter", variant="primary")
    run_btn.click(
        fn=run_pipeline_with_feedback,
        inputs=[video_input, ref_image, width, height, steps, scale, seed],
        outputs=output_video,
    )


if __name__ == "__main__":
    os.environ["HF_ENDPOINT"] = "https://hf-mirror.com/"
    os.environ["NO_PROXY"] = "localhost,127.0.0.1/8,::1"
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
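
# ---------------------------------------------------------------------------
# Optional: programmatic invocation sketch. A minimal, hedged example of
# driving this demo headlessly with the `gradio_client` package, assuming the
# server above is running and that Gradio exposed the click handler under its
# default api_name (the handler's function name). "dance.mp4" and "ref.png"
# are hypothetical local files; the snippet is left commented out because
# `demo.launch()` blocks this process.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://127.0.0.1:7860/")
#   result = client.predict(
#       handle_file("dance.mp4"),  # video_input
#       handle_file("ref.png"),    # ref_image (or None to use the first frame)
#       512, 512,                  # output width / height
#       50, 3.0,                   # inference steps / guidance scale
#       6666,                      # random seed
#       api_name="/run_pipeline_with_feedback",
#   )
#   print(result)  # path to the generated video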