import os
import traceback

import torch
import gradio as gr

from inference_engine import run_inference
from motion_extractor import extract_pkl_from_video
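
# Run inference on the GPU when one is available; CPU works as a fallback but is much slower.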
device = "cuda" if torch.cuda.is_available() else "cpu"


def full_pipeline(video_file, ref_image=None, width=512, height=512, steps=50, scale=3.0, seed=6666):
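    """Full pipeline: extract motion from the driving video, then animate.

    `video_file` is the uploaded driving video; `ref_image` is an optional PIL
    image to animate (when empty, inference falls back to the video's first
    frame). Returns the path of the generated video file.
    """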
    # 1. Extract the motion pkl from the input video
    video_path = video_file.name
    motion_pkl_path = extract_pkl_from_video(video_path)
    gr.Info("⏳ Motion extraction finished, starting animation...", visible=True)

    # 2. Prepare the reference image (optional)
    if ref_image is not None:
        ref_path = "temp_ref.png"
        ref_image.save(ref_path)
    else:
        ref_path = ""

    # 3. Run inference
    output_path = run_inference(
        device,
        motion_pkl_path,
        ref_path,
        dst_width=width,
        dst_height=height,
        num_inference_steps=steps,
        guidance_scale=scale,
        seed=seed,
    )
    return output_path

def run_pipeline_with_feedback(video_file, ref_image, width, height, steps, scale, seed):
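    """UI wrapper around full_pipeline: validates input and surfaces progress and errors."""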
    try:
        if video_file is None:
            raise gr.Error("Please upload a dancing video (.mp4/.mov/.avi).")
        # Show a progress hint before the long-running job
        gr.Info("⏳ Processing... Please wait several minutes.", visible=True)
        result = full_pipeline(video_file, ref_image, width, height, steps, scale, seed)
        gr.Info("✅ Inference done, please enjoy it!", visible=True)
        return result
    except Exception as e:
        traceback.print_exc()
        gr.Warning("⚠️ Inference failed: " + str(e))
        return None

# Build the UI
with gr.Blocks(title="MTVCrafter Inference Demo") as demo:
    gr.Markdown(
        """
        # 🎨💃 MTVCrafter Inference Demo

        💡 **Tip:** Upload a dancing video in **MP4/MOV/AVI** format, and optionally a reference image (e.g., PNG or JPG).
        This demo will extract human motion from the input video and animate the reference image accordingly.
        If no reference image is provided, the **first frame** of the video will be used as the reference.

        🎞️ **Note:** The generated output video will contain exactly **49 frames**.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            video_input = gr.File(label="📹 Input Video (Required)", file_types=[".mp4", ".mov", ".avi"])
            # Fixed height so the two columns stay vertically aligned
            video_preview = gr.Video(label="👀 Preview of Uploaded Video", height=280)

            def show_video_preview(video_file):
                return video_file.name if video_file else None

            video_input.change(fn=show_video_preview, inputs=video_input, outputs=video_preview)

        with gr.Column(scale=1):
            ref_image = gr.Image(type="pil", label="🖼️ Reference Image (Optional)", height=538)
    with gr.Accordion("⚙️ Advanced Settings", open=False):
        with gr.Row():
            width = gr.Slider(384, 1024, value=512, step=16, label="Output Width")
            height = gr.Slider(384, 1024, value=512, step=16, label="Output Height")
        with gr.Row():
            steps = gr.Slider(20, 100, value=50, step=5, label="Inference Steps")
            scale = gr.Slider(0.0, 10.0, value=3.0, step=0.25, label="Guidance Scale")
            seed = gr.Number(value=6666, label="Random Seed")
    with gr.Row():
        output_video = gr.Video(label="🎬 Generated Video", interactive=False)

    run_btn = gr.Button("🚀 Run MTVCrafter", variant="primary")
    run_btn.click(
        fn=run_pipeline_with_feedback,
        inputs=[video_input, ref_image, width, height, steps, scale, seed],
        outputs=output_video,
    )
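
# HF_ENDPOINT points Hugging Face downloads at the hf-mirror.com mirror; NO_PROXY
# keeps local Gradio traffic from being routed through any configured HTTP proxy.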
if __name__ == "__main__":
    os.environ["HF_ENDPOINT"] = "https://hf-mirror.com/"
    os.environ["NO_PROXY"] = "localhost,127.0.0.1/8,::1"
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)