# DiffuEraser-demo / gradio_app.py
# Author: fffiloni
# Commit e0180e9 (verified): "Update gradio_app.py"
import spaces
import torch
import os
import time
import datetime
from moviepy.editor import VideoFileClip
import gradio as gr
# Download Weights
from huggingface_hub import snapshot_download
# List of subdirectories to create inside "weights"
_WEIGHT_REPOS = {
    # local subfolder under ./weights  ->  Hugging Face repo id
    "diffuEraser": "lixiaowen/diffuEraser",
    "stable-diffusion-v1-5": "stable-diffusion-v1-5/stable-diffusion-v1-5",
    "PCM_Weights": "wangfuyun/PCM_Weights",
    "propainter": "camenduru/ProPainter",
    "sd-vae-ft-mse": "stabilityai/sd-vae-ft-mse",
}
subfolders = list(_WEIGHT_REPOS)

# Make sure every target directory exists before any download starts.
for subfolder in subfolders:
    os.makedirs(os.path.join("weights", subfolder), exist_ok=True)

# Fetch each repo snapshot into its matching directory, in declaration order.
for _subdir, _repo_id in _WEIGHT_REPOS.items():
    snapshot_download(repo_id=_repo_id, local_dir=f"./weights/{_subdir}")
# Import model classes
from diffueraser.diffueraser import DiffuEraser
from propainter.inference import Propainter, get_device
_weights_root = "weights"
base_model_path = f"{_weights_root}/stable-diffusion-v1-5"
vae_path = f"{_weights_root}/sd-vae-ft-mse"
diffueraser_path = f"{_weights_root}/diffuEraser"
propainter_model_dir = f"{_weights_root}/propainter"

# Model setup: one device shared by both pipelines, built once at import time.
ckpt = "2-Step"  # checkpoint variant name handed to DiffuEraser
device = get_device()
video_inpainting_sd = DiffuEraser(
    device,
    base_model_path,
    vae_path,
    diffueraser_path,
    ckpt=ckpt,
)
propainter = Propainter(propainter_model_dir, device=device)
# Helper function to trim videos
def trim_video(input_path, output_path, max_duration=5):
    """Write the first ``max_duration`` seconds of a video to a new file.

    Args:
        input_path: Path of the source video.
        output_path: Destination file; encoded with ``libx264`` video and
            ``aac`` audio.
        max_duration: Maximum length to keep, in seconds. Videos shorter than
            this are written out in full.

    Both the source clip and the subclip are closed even when encoding
    raises, so their readers are not leaked.
    """
    clip = VideoFileClip(input_path)
    try:
        # Never ask for a subclip longer than the source itself.
        trimmed_clip = clip.subclip(0, min(max_duration, clip.duration))
        try:
            trimmed_clip.write_videofile(output_path, codec="libx264", audio_codec="aac")
        finally:
            trimmed_clip.close()
    finally:
        clip.close()
def _probe_frame_count(video_path, fallback_fps=30):
    """Return ``int(duration * fps)`` for the video at ``video_path``.

    Uses the clip's own frame rate; ``fallback_fps`` applies only when the
    file reports no usable fps. The clip is always closed.
    """
    clip = VideoFileClip(video_path)
    try:
        fps = clip.fps or fallback_fps
        return int(clip.duration * fps)
    finally:
        clip.close()


@spaces.GPU(duration=100)
def infer(input_video, input_mask):
    """Inpaint the masked region of a video with ProPainter + DiffuEraser.

    Both inputs are first trimmed to at most 5 seconds. ProPainter then
    produces a prior video, which DiffuEraser refines into the final result.

    Args:
        input_video: Path to the source MP4 (as supplied by ``gr.Video``).
        input_mask: Path to the mask MP4 aligned with ``input_video``.

    Returns:
        Path of the DiffuEraser result video under ``results/``.

    Raises:
        gr.Error: If either input is missing.
    """
    # gr.Video yields None when nothing was uploaded; fail with a readable
    # message instead of a moviepy traceback.
    if not input_video or not input_mask:
        raise gr.Error("Please provide both an input video and a mask video.")

    # Setup paths and parameters
    save_path = "results"
    mask_dilation_iter = 8
    max_img_size = 960
    ref_stride = 10
    neighbor_length = 10
    subvideo_length = 50

    os.makedirs(save_path, exist_ok=True)

    # Microsecond resolution so two requests in the same second do not
    # overwrite each other's intermediate or result files.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    trimmed_video_path = os.path.join(save_path, f"trimmed_video_{timestamp}.mp4")
    trimmed_mask_path = os.path.join(save_path, f"trimmed_mask_{timestamp}.mp4")
    priori_path = os.path.join(save_path, f"priori_{timestamp}.mp4")
    output_path = os.path.join(save_path, f"diffueraser_result_{timestamp}.mp4")

    # Trim input videos
    trim_video(input_video, trimmed_video_path)
    trim_video(input_mask, trimmed_mask_path)

    # Frame count of the trimmed clip, using its real fps rather than
    # assuming 30. NOTE(review): the forward() signatures are not visible
    # here; confirm the unit they expect for video_length.
    video_length = _probe_frame_count(trimmed_video_path)

    start_time = time.perf_counter()
    try:
        # ProPainter (priori)
        propainter.forward(trimmed_video_path, trimmed_mask_path, priori_path,
                           video_length=video_length, ref_stride=ref_stride,
                           neighbor_length=neighbor_length, subvideo_length=subvideo_length,
                           mask_dilation=mask_dilation_iter)
        # DiffuEraser
        guidance_scale = None
        video_inpainting_sd.forward(trimmed_video_path, trimmed_mask_path, priori_path, output_path,
                                    max_img_size=max_img_size, video_length=video_length,
                                    mask_dilation_iter=mask_dilation_iter,
                                    guidance_scale=guidance_scale)
    finally:
        # Release cached GPU memory even when inference fails.
        torch.cuda.empty_cache()
    end_time = time.perf_counter()
    print(f"DiffuEraser inference time: {end_time - start_time:.2f} seconds")
    return output_path
# Gradio interface
# Layout: a header column, then inputs (left) beside the result (right),
# followed by clickable examples. The "ArXiv-Paper" badge now points at the
# arXiv paper (it previously duplicated the project-page link).
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# DiffuEraser: A Diffusion Model for Video Inpainting")
        gr.Markdown("DiffuEraser is a diffusion model for video inpainting, which outperforms state-of-the-art model ProPainter in both content completeness and temporal consistency while maintaining acceptable efficiency.")
        gr.HTML("""
<div style="display:flex;column-gap:4px;">
<a href="https://github.com/lixiaowen-xw/DiffuEraser">
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>
<a href="https://lixiaowen-xw.github.io/DiffuEraser-page">
<img src='https://img.shields.io/badge/Project-Page-green'>
</a>
<a href="https://arxiv.org/abs/2501.10018">
<img src='https://img.shields.io/badge/ArXiv-Paper-red'>
</a>
<a href="https://huggingface.co/spaces/fffiloni/DiffuEraser-demo?duplicate=true">
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
</a>
</div>
""")
        with gr.Row():
            with gr.Column():
                input_video = gr.Video(label="Input Video (MP4 ONLY)")
                input_mask = gr.Video(label="Input Mask Video (MP4 ONLY)")
                submit_btn = gr.Button("Submit")
            with gr.Column():
                video_result = gr.Video(label="Result")
        gr.Examples(
            examples=[
                ["./examples/example1/video.mp4", "./examples/example1/mask.mp4"],
                ["./examples/example2/video.mp4", "./examples/example2/mask.mp4"],
                ["./examples/example3/video.mp4", "./examples/example3/mask.mp4"],
            ],
            inputs=[input_video, input_mask]
        )
    # Event listeners must be registered inside the Blocks context.
    submit_btn.click(fn=infer, inputs=[input_video, input_mask], outputs=[video_result])
demo.queue().launch(show_api=True, show_error=True, ssr_mode=False)