# NOTE: removed HF Spaces page-scrape residue ("Spaces: / Runtime error" status lines)
# that preceded the code — this file is the Space's Gradio app entry point.
import os
import tempfile
import traceback

import cv2
import gradio as gr
import torch
from PIL import Image

from inference_engine import run_inference
from motion_extractor import extract_pkl_from_video
# Pick the compute device once at import time: prefer the GPU when present.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def full_pipeline(video_file, ref_image=None, width=512, height=512, steps=50, scale=3.0, seed=6666):
    """Run the full motion-transfer pipeline: extract motion, then animate.

    Args:
        video_file: uploaded file object exposing a ``.name`` filesystem path
            to the driving dance video.
        ref_image: optional PIL image to animate; when ``None`` an empty path
            is forwarded and the backend falls back to the video's first frame.
        width, height: output resolution forwarded to the diffusion model.
        steps: number of denoising steps.
        scale: classifier-free guidance scale.
        seed: RNG seed for reproducible sampling.

    Returns:
        Path of the generated video, as returned by ``run_inference``.
    """
    # 1. Extract the motion .pkl from the driving video.
    video_path = video_file.name
    motion_pkl_path = extract_pkl_from_video(video_path)
    # Emoji restored from mojibake (hourglass) in the original source.
    gr.Info("⏳ Extract motion finished and begin animation...", visible=True)

    # 2. Persist the optional reference image. A unique temp file replaces the
    #    original fixed "temp_ref.png", which concurrent requests would clobber.
    if ref_image is not None:
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            ref_path = tmp.name
        ref_image.save(ref_path)
    else:
        ref_path = ""

    # 3. Run diffusion inference on the selected device.
    output_path = run_inference(
        device,
        motion_pkl_path,
        ref_path,
        dst_width=width,
        dst_height=height,
        num_inference_steps=steps,
        guidance_scale=scale,
        seed=seed,
    )
    return output_path
def run_pipeline_with_feedback(video_file, ref_image, width, height, steps, scale, seed):
    """UI wrapper around ``full_pipeline`` that reports progress and errors.

    Returns the generated video path, or ``None`` when inference fails
    (a warning toast is shown instead of crashing the UI).
    """
    # Validate BEFORE the try block: in the original, the raised gr.Error was
    # immediately caught by the generic handler below and downgraded to a
    # warning — raising here lets Gradio show the proper error modal.
    if video_file is None:
        raise gr.Error("Please upload a dancing video (.mp4/.mov/.avi).")
    try:
        # Progress hint for the user (emoji restored from mojibake).
        gr.Info("⏳ Processing... Please wait several minutes.", visible=True)
        result = full_pipeline(video_file, ref_image, width, height, steps, scale, seed)
        gr.Info("✅ Inference done, please enjoy it!", visible=True)
        return result
    except Exception as e:
        # Log the full stack trace server-side; surface a friendly warning in the UI.
        traceback.print_exc()
        gr.Warning("⚠️ Inference failed: " + str(e))
        return None
# Build the UI.
# NOTE(review): the emoji in the UI strings below are mojibake (UTF-8 emoji
# bytes decoded as TIS-620); left byte-identical here since they are
# user-facing runtime strings.
with gr.Blocks(title="MTVCrafter Inference Demo") as demo:
    gr.Markdown(
        """
# ๐จ๐ MTVCrafter Inference Demo
๐ก **Tip:** Upload a dancing video in **MP4/MOV/AVI** format, and optionally a reference image (e.g., PNG or JPG).
This demo will extract human motion from the input video and animate the reference image accordingly.
If no reference image is provided, the **first frame** of the video will be used as the reference.
๐๏ธ **Note:** The generated output video will contain exactly **49 frames**.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            # Required driving video; gr.File hands the callback an object with a .name path.
            video_input = gr.File(label="๐น Input Video (Required)", file_types=[".mp4", ".mov", ".avi"])
            # Fixed height to keep the two columns visually aligned.
            video_preview = gr.Video(label="๐ Preview of Uploaded Video", height=280)

            def show_video_preview(video_file):
                # Mirror the uploaded file into the preview player; None clears it.
                return video_file.name if video_file else None

            video_input.change(fn=show_video_preview, inputs=video_input, outputs=video_preview)
        with gr.Column(scale=1):
            # Optional reference image, delivered to the pipeline as a PIL image.
            ref_image = gr.Image(type="pil", label="๐ผ๏ธ Reference Image (Optional)", height=538)
    # Generation knobs, collapsed by default.
    with gr.Accordion("โ๏ธ Advanced Settings", open=False):
        with gr.Row():
            width = gr.Slider(384, 1024, value=512, step=16, label="Output Width")
            height = gr.Slider(384, 1024, value=512, step=16, label="Output Height")
        with gr.Row():
            steps = gr.Slider(20, 100, value=50, step=5, label="Inference Steps")
            scale = gr.Slider(0.0, 10.0, value=3.0, step=0.25, label="Guidance Scale")
            seed = gr.Number(value=6666, label="Random Seed")
    with gr.Row(scale=1):
        output_video = gr.Video(label="๐ฌ Generated Video", interactive=False)
    run_btn = gr.Button("๐ Run MTVCrafter", variant="primary")
    # Wire the button to the feedback wrapper; inputs must match its signature order.
    run_btn.click(
        fn=run_pipeline_with_feedback,
        inputs=[video_input, ref_image, width, height, steps, scale, seed],
        outputs=output_video,
    )
if __name__ == "__main__":
    # Point Hugging Face downloads at the mirror and bypass any proxy for
    # local addresses, then expose the app on all interfaces.
    os.environ.update(
        HF_ENDPOINT="https://hf-mirror.com/",
        NO_PROXY="localhost,127.0.0.1/8,::1",
    )
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)