fffiloni committed on
Commit
ea799b5
·
verified ·
1 Parent(s): 7b0a8f1

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +69 -77
gradio_app.py CHANGED
@@ -1,16 +1,15 @@
1
  import spaces
2
  import torch
3
- import os
4
  import time
5
- import argparse
6
- from diffueraser.diffueraser import DiffuEraser
7
- from propainter.inference import Propainter, get_device
8
  import gradio as gr
9
 
10
  # Download Weights
11
  from huggingface_hub import snapshot_download
12
 
13
- # List of subdirectories to create inside "checkpoints"
14
  subfolders = [
15
  "diffuEraser",
16
  "stable-diffusion-v1-5",
@@ -18,46 +17,35 @@ subfolders = [
18
  "propainter",
19
  "sd-vae-ft-mse"
20
  ]
21
- # Create each subdirectory
22
  for subfolder in subfolders:
23
- os.makedirs(os.path.join("weigths", subfolder), exist_ok=True)
24
-
25
- snapshot_download(
26
- repo_id = "lixiaowen/diffuEraser",
27
- local_dir = "./weights/diffuEraser"
28
- )
29
-
30
- snapshot_download(
31
- repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5",
32
- local_dir = "./weights/stable-diffusion-v1-5"
33
- )
34
 
35
- snapshot_download(
36
- repo_id = "wangfuyun/PCM_Weights",
37
- local_dir = "./weights/PCM_Weights"
38
- )
 
39
 
40
- snapshot_download(
41
- repo_id = "camenduru/ProPainter",
42
- local_dir = "./weights/propainter"
43
- )
44
-
45
- snapshot_download(
46
- repo_id = "stabilityai/sd-vae-ft-mse",
47
- local_dir = "./weights/sd-vae-ft-mse"
48
- )
49
 
50
- # —————————————————————
 
 
 
 
 
 
51
 
52
  @spaces.GPU(duration=120)
53
  def infer(input_video, input_mask):
54
-
55
- video_length = 10 # The maximum length of output video
56
- mask_dilation_iter = 8 # Adjust it to change the degree of mask expansion
57
- max_img_size = 960 # The maximum length of output width and height
58
- save_path = "results" # Path to the output
59
-
60
- ref_stride = 10
61
  neighbor_length = 10
62
  subvideo_length = 50
63
 
@@ -65,50 +53,66 @@ def infer(input_video, input_mask):
65
  vae_path = "weights/sd-vae-ft-mse"
66
  diffueraser_path = "weights/diffuEraser"
67
  propainter_model_dir = "weights/propainter"
68
-
69
  if not os.path.exists(save_path):
70
  os.makedirs(save_path)
71
- priori_path = os.path.join(save_path, "priori.mp4")
72
- output_path = os.path.join(save_path, "diffueraser_result.mp4")
73
-
74
- ## model initialization
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  device = get_device()
76
- # PCM params
77
  ckpt = "2-Step"
78
  video_inpainting_sd = DiffuEraser(device, base_model_path, vae_path, diffueraser_path, ckpt=ckpt)
79
  propainter = Propainter(propainter_model_dir, device=device)
80
-
81
- start_time = time.time()
82
 
83
- ## priori
84
- propainter.forward(input_video, input_mask, priori_path, video_length=video_length,
85
- ref_stride=ref_stride, neighbor_length=neighbor_length, subvideo_length = subvideo_length,
86
- mask_dilation = mask_dilation_iter)
87
 
88
- ## diffueraser
89
- guidance_scale = None # The default value is 0.
90
- video_inpainting_sd.forward(input_video, input_mask, priori_path, output_path,
91
- max_img_size = max_img_size, video_length=video_length, mask_dilation_iter=mask_dilation_iter,
 
 
 
 
 
 
 
92
  guidance_scale=guidance_scale)
93
-
94
- end_time = time.time()
95
- inference_time = end_time - start_time
96
- print(f"DiffuEraser inference time: {inference_time:.4f} s")
97
 
98
- torch.cuda.empty_cache()
 
99
 
 
100
  return output_path
101
 
 
102
  with gr.Blocks() as demo:
103
-
104
  with gr.Column():
105
  gr.Markdown("# DiffuEraser: A Diffusion Model for Video Inpainting")
106
- gr.Markdown("DiffuEraser is a diffusion model for video inpainting, which outperforms state-of-the-art model Propainter in both content completeness and temporal consistency while maintaining acceptable efficiency.")
 
107
  gr.HTML("""
108
  <div style="display:flex;column-gap:4px;">
109
  <a href="https://github.com/lixiaowen-xw/DiffuEraser">
110
  <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
111
- </a>
112
  <a href="https://lixiaowen-xw.github.io/DiffuEraser-page">
113
  <img src='https://img.shields.io/badge/Project-Page-green'>
114
  </a>
@@ -122,34 +126,22 @@ with gr.Blocks() as demo:
122
  """)
123
 
124
  with gr.Row():
125
-
126
  with gr.Column():
127
-
128
  input_video = gr.Video(label="Input Video (MP4 ONLY)")
129
  input_mask = gr.Video(label="Input Mask Video (MP4 ONLY)")
130
  submit_btn = gr.Button("Submit")
131
 
132
  with gr.Column():
133
-
134
  video_result = gr.Video(label="Result")
135
  gr.Examples(
136
- examples = [
137
  ["./examples/example1/video.mp4", "./examples/example1/mask.mp4"],
138
  ["./examples/example2/video.mp4", "./examples/example2/mask.mp4"],
139
  ["./examples/example3/video.mp4", "./examples/example3/mask.mp4"],
140
  ],
141
- inputs = [input_video, input_mask]
142
  )
143
 
144
-
145
- submit_btn.click(
146
- fn = infer,
147
- inputs = [input_video, input_mask],
148
- outputs = [video_result]
149
- )
150
 
151
  demo.queue().launch(show_api=False, show_error=True)
152
-
153
-
154
-
155
-
 
1
  import spaces
2
  import torch
3
+ import os
4
  import time
5
+ import datetime
6
+ from moviepy.editor import VideoFileClip
 
7
  import gradio as gr
8
 
9
  # Download Weights
10
  from huggingface_hub import snapshot_download
11
 
12
+ # List of subdirectories to create inside "weights"
13
  subfolders = [
14
  "diffuEraser",
15
  "stable-diffusion-v1-5",
 
17
  "propainter",
18
  "sd-vae-ft-mse"
19
  ]
20
+ # Create directories
21
  for subfolder in subfolders:
22
+ os.makedirs(os.path.join("weights", subfolder), exist_ok=True)
 
 
 
 
 
 
 
 
 
 
23
 
24
+ snapshot_download(repo_id="lixiaowen/diffuEraser", local_dir="./weights/diffuEraser")
25
+ snapshot_download(repo_id="stable-diffusion-v1-5/stable-diffusion-v1-5", local_dir="./weights/stable-diffusion-v1-5")
26
+ snapshot_download(repo_id="wangfuyun/PCM_Weights", local_dir="./weights/PCM_Weights")
27
+ snapshot_download(repo_id="camenduru/ProPainter", local_dir="./weights/propainter")
28
+ snapshot_download(repo_id="stabilityai/sd-vae-ft-mse", local_dir="./weights/sd-vae-ft-mse")
29
 
30
+ # Import model classes
31
+ from diffueraser.diffueraser import DiffuEraser
32
+ from propainter.inference import Propainter, get_device
 
 
 
 
 
 
33
 
34
+ # Helper function to trim videos
35
def trim_video(input_path, output_path, max_duration=5):
    """Write a copy of a video, cut to at most ``max_duration``, to ``output_path``.

    The clip is cut from 0 to ``min(max_duration, clip.duration)`` and
    re-encoded with the ``libx264`` video codec and the ``aac`` audio codec.

    Args:
        input_path: Path of the video file to read.
        output_path: Path the trimmed video is written to.
        max_duration: Upper bound on the kept duration, in the time unit
            used by ``VideoFileClip.duration`` (seconds, per moviepy's docs).

    Raises:
        ValueError: If ``input_path`` is ``None`` or empty (e.g. the caller
            was invoked without an uploaded file).
    """
    if not input_path:
        # Fail early with a readable message instead of an opaque error
        # from inside the video reader.
        raise ValueError("trim_video: no input video path was provided")

    clip = VideoFileClip(input_path)
    trimmed_clip = None
    try:
        trimmed_clip = clip.subclip(0, min(max_duration, clip.duration))
        trimmed_clip.write_videofile(output_path, codec="libx264", audio_codec="aac")
    finally:
        # Always release the readers, even when trimming or encoding fails;
        # the close order matches the original implementation.
        clip.close()
        if trimmed_clip is not None:
            trimmed_clip.close()
41
 
42
  @spaces.GPU(duration=120)
43
  def infer(input_video, input_mask):
44
+ # Setup paths and parameters
45
+ save_path = "results"
46
+ mask_dilation_iter = 8
47
+ max_img_size = 960
48
+ ref_stride = 10
 
 
49
  neighbor_length = 10
50
  subvideo_length = 50
51
 
 
53
  vae_path = "weights/sd-vae-ft-mse"
54
  diffueraser_path = "weights/diffuEraser"
55
  propainter_model_dir = "weights/propainter"
56
+
57
  if not os.path.exists(save_path):
58
  os.makedirs(save_path)
59
+
60
+ # Timestamp for unique filenames
61
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
62
+ trimmed_video_path = os.path.join(save_path, f"trimmed_video_{timestamp}.mp4")
63
+ trimmed_mask_path = os.path.join(save_path, f"trimmed_mask_{timestamp}.mp4")
64
+ priori_path = os.path.join(save_path, f"priori_{timestamp}.mp4")
65
+ output_path = os.path.join(save_path, f"diffueraser_result_{timestamp}.mp4")
66
+
67
+ # Trim input videos
68
+ trim_video(input_video, trimmed_video_path)
69
+ trim_video(input_mask, trimmed_mask_path)
70
+
71
+ # Dynamically compute video_length (in frames) assuming 30 fps
72
+ clip = VideoFileClip(trimmed_video_path)
73
+ video_duration = clip.duration
74
+ clip.close()
75
+ video_length = int(video_duration * 30)
76
+
77
+ # Model setup
78
  device = get_device()
 
79
  ckpt = "2-Step"
80
  video_inpainting_sd = DiffuEraser(device, base_model_path, vae_path, diffueraser_path, ckpt=ckpt)
81
  propainter = Propainter(propainter_model_dir, device=device)
 
 
82
 
83
+ # Run models
84
+ start_time = time.time()
 
 
85
 
86
+ # ProPainter (priori)
87
+ propainter.forward(trimmed_video_path, trimmed_mask_path, priori_path,
88
+ video_length=video_length, ref_stride=ref_stride,
89
+ neighbor_length=neighbor_length, subvideo_length=subvideo_length,
90
+ mask_dilation=mask_dilation_iter)
91
+
92
+ # DiffuEraser
93
+ guidance_scale = None
94
+ video_inpainting_sd.forward(trimmed_video_path, trimmed_mask_path, priori_path, output_path,
95
+ max_img_size=max_img_size, video_length=video_length,
96
+ mask_dilation_iter=mask_dilation_iter,
97
  guidance_scale=guidance_scale)
 
 
 
 
98
 
99
+ end_time = time.time()
100
+ print(f"DiffuEraser inference time: {end_time - start_time:.2f} seconds")
101
 
102
+ torch.cuda.empty_cache()
103
  return output_path
104
 
105
+ # Gradio interface
106
  with gr.Blocks() as demo:
 
107
  with gr.Column():
108
  gr.Markdown("# DiffuEraser: A Diffusion Model for Video Inpainting")
109
+ gr.Markdown("DiffuEraser is a diffusion model for video inpainting, which outperforms state-of-the-art model ProPainter in both content completeness and temporal consistency while maintaining acceptable efficiency.")
110
+
111
  gr.HTML("""
112
  <div style="display:flex;column-gap:4px;">
113
  <a href="https://github.com/lixiaowen-xw/DiffuEraser">
114
  <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
115
+ </a>
116
  <a href="https://lixiaowen-xw.github.io/DiffuEraser-page">
117
  <img src='https://img.shields.io/badge/Project-Page-green'>
118
  </a>
 
126
  """)
127
 
128
  with gr.Row():
 
129
  with gr.Column():
 
130
  input_video = gr.Video(label="Input Video (MP4 ONLY)")
131
  input_mask = gr.Video(label="Input Mask Video (MP4 ONLY)")
132
  submit_btn = gr.Button("Submit")
133
 
134
  with gr.Column():
 
135
  video_result = gr.Video(label="Result")
136
  gr.Examples(
137
+ examples=[
138
  ["./examples/example1/video.mp4", "./examples/example1/mask.mp4"],
139
  ["./examples/example2/video.mp4", "./examples/example2/mask.mp4"],
140
  ["./examples/example3/video.mp4", "./examples/example3/mask.mp4"],
141
  ],
142
+ inputs=[input_video, input_mask]
143
  )
144
 
145
+ submit_btn.click(fn=infer, inputs=[input_video, input_mask], outputs=[video_result])
 
 
 
 
 
146
 
147
  demo.queue().launch(show_api=False, show_error=True)