aleafy commited on
Commit
6dc7213
·
1 Parent(s): 0a63786
Files changed (1) hide show
  1. app.py +11 -7
app.py CHANGED
@@ -5,7 +5,7 @@ from enum import Enum
5
  import db_examples
6
  import cv2
7
 
8
- # from demo_utils1 import *
9
 
10
  from misc_utils.train_utils import unit_test_create_model
11
  from misc_utils.image_utils import save_tensor_to_gif, save_tensor_to_images
@@ -28,7 +28,10 @@ from tqdm import tqdm
28
 
29
  # 下载文件
30
  os.makedirs('models', exist_ok=True)
31
- filename = "models/iclight_sd15_fbc.safetensors"
 
 
 
32
 
33
  # if not os.path.exists(filename):
34
  # original_path = os.getcwd()
@@ -157,15 +160,16 @@ config_path = 'configs/instruct_v2v_ic_gradio.yaml'
157
  diffusion_model = unit_test_create_model(config_path).cuda()
158
 
159
  # 加载模型检查点
160
- ckpt_path = 'models/pytorch_model.bin' #! change
161
- ckpt = torch.load(ckpt_path, map_location='cpu')
 
162
  diffusion_model.load_state_dict(ckpt, strict=False)
163
 
164
  # import pdb; pdb.set_trace()
165
 
166
- # # 更改全局临时目录
167
- # new_tmp_dir = "./demo/gradio_bg"
168
- # os.makedirs(new_tmp_dir, exist_ok=True)
169
 
170
  # import pdb; pdb.set_trace()
171
 
 
5
  import db_examples
6
  import cv2
7
 
8
+ from demo_utils1 import *
9
 
10
  from misc_utils.train_utils import unit_test_create_model
11
  from misc_utils.image_utils import save_tensor_to_gif, save_tensor_to_images
 
28
 
29
  # 下载文件
30
  os.makedirs('models', exist_ok=True)
31
+ model_path = "models/relvid_mm_sd15_fbc_unet.pth"
32
+
33
+ if not os.path.exists(filename):
34
+ download_url_to_file(url='https://huggingface.co/aleafy/RelightVid/resolve/main/relvid_mm_sd15_fbc_unet.pth', dst=model_path)
35
 
36
  # if not os.path.exists(filename):
37
  # original_path = os.getcwd()
 
160
  diffusion_model = unit_test_create_model(config_path).cuda()
161
 
162
  # 加载模型检查点
163
+ # ckpt_path = 'models/relvid_mm_sd15_fbc_unet.pth' #! change
164
+ # ckpt_path = 'tmp/pytorch_model.bin'
165
+ ckpt = torch.load(model_path, map_location='cpu')
166
  diffusion_model.load_state_dict(ckpt, strict=False)
167
 
168
  # import pdb; pdb.set_trace()
169
 
170
+ # 更改全局临时目录
171
+ new_tmp_dir = "./demo/gradio_bg"
172
+ os.makedirs(new_tmp_dir, exist_ok=True)
173
 
174
  # import pdb; pdb.set_trace()
175