aleafy commited on
Commit
3dd9098
·
1 Parent(s): 2c33647
Files changed (1) hide show
  1. app.py +12 -7
app.py CHANGED
@@ -5,6 +5,8 @@ from enum import Enum
5
  import db_examples
6
  import cv2
7
 
 
 
8
  from demo_utils1 import *
9
 
10
  from misc_utils.train_utils import unit_test_create_model
@@ -27,12 +29,6 @@ import os
27
  from pl_trainer.inference.inference import InferenceIP2PVideo
28
  from tqdm import tqdm
29
 
30
- # 下载文件
31
- os.makedirs('models', exist_ok=True)
32
- model_path = "models/relvid_mm_sd15_fbc_unet.pth"
33
-
34
- if not os.path.exists(model_path):
35
- download_url_to_file(url='https://huggingface.co/aleafy/RelightVid/resolve/main/relvid_mm_sd15_fbc_unet.pth', dst=model_path)
36
 
37
  # if not os.path.exists(filename):
38
  # original_path = os.getcwd()
@@ -158,7 +154,8 @@ def clear_cache(output_path):
158
  #! 加载模型
159
  # 配置路径和加载模型
160
  config_path = 'configs/instruct_v2v_ic_gradio.yaml'
161
- diffusion_model = unit_test_create_model(config_path).cuda()
 
162
 
163
  # 加载模型检查点
164
  # ckpt_path = 'models/relvid_mm_sd15_fbc_unet.pth' #! change
@@ -166,6 +163,13 @@ diffusion_model = unit_test_create_model(config_path).cuda()
166
  ckpt = torch.load(model_path, map_location='cpu')
167
  diffusion_model.load_state_dict(ckpt, strict=False)
168
 
 
 
 
 
 
 
 
169
  # import pdb; pdb.set_trace()
170
 
171
  # 更改全局临时目录
@@ -212,6 +216,7 @@ def save_video_from_frames(image_pred, save_pth, fps=8):
212
 
213
 
214
  # 伪函数占位(生成空白视频)
 
215
  def dummy_process(input_fg, input_bg):
216
  # import pdb; pdb.set_trace()
217
  fg_tensor = load_and_process_video(input_fg).cuda().unsqueeze(0)
 
5
  import db_examples
6
  import cv2
7
 
8
+ import spaces
9
+
10
  from demo_utils1 import *
11
 
12
  from misc_utils.train_utils import unit_test_create_model
 
29
  from pl_trainer.inference.inference import InferenceIP2PVideo
30
  from tqdm import tqdm
31
 
 
 
 
 
 
 
32
 
33
  # if not os.path.exists(filename):
34
  # original_path = os.getcwd()
 
154
  #! 加载模型
155
  # 配置路径和加载模型
156
  config_path = 'configs/instruct_v2v_ic_gradio.yaml'
157
+ diffusion_model = unit_test_create_model(config_path)
158
+ diffusion_model = diffusion_model.to('cuda')
159
 
160
  # 加载模型检查点
161
  # ckpt_path = 'models/relvid_mm_sd15_fbc_unet.pth' #! change
 
163
  ckpt = torch.load(model_path, map_location='cpu')
164
  diffusion_model.load_state_dict(ckpt, strict=False)
165
 
166
+ # 下载文件
167
+ os.makedirs('models', exist_ok=True)
168
+ model_path = "models/relvid_mm_sd15_fbc_unet.pth"
169
+
170
+ if not os.path.exists(model_path):
171
+ download_url_to_file(url='https://huggingface.co/aleafy/RelightVid/resolve/main/relvid_mm_sd15_fbc_unet.pth', dst=model_path)
172
+
173
  # import pdb; pdb.set_trace()
174
 
175
  # 更改全局临时目录
 
216
 
217
 
218
  # 伪函数占位(生成空白视频)
219
+ @spaces.GPU
220
  def dummy_process(input_fg, input_bg):
221
  # import pdb; pdb.set_trace()
222
  fg_tensor = load_and_process_video(input_fg).cuda().unsqueeze(0)