aleafy commited on
Commit
b49b3d1
·
1 Parent(s): 7699d8b
Files changed (1) hide show
  1. app.py +19 -19
app.py CHANGED
@@ -152,24 +152,24 @@ def clear_cache(output_path):
152
 
153
  #! 加载模型
154
  # 配置路径和加载模型
155
- # config_path = 'configs/instruct_v2v_ic_gradio.yaml'
156
- # diffusion_model = unit_test_create_model(config_path)
157
- # diffusion_model = diffusion_model.to('cuda')
158
 
159
- # # 加载模型检查点
160
- # # ckpt_path = 'models/relvid_mm_sd15_fbc_unet.pth' #! change
161
- # # ckpt_path = 'tmp/pytorch_model.bin'
162
- # # 下载文件
163
 
164
- # os.makedirs('models', exist_ok=True)
165
- # model_path = "models/relvid_mm_sd15_fbc_unet.pth"
166
 
167
- # if not os.path.exists(model_path):
168
- # download_url_to_file(url='https://huggingface.co/aleafy/RelightVid/resolve/main/relvid_mm_sd15_fbc_unet.pth', dst=model_path)
169
 
170
 
171
- # ckpt = torch.load(model_path, map_location='cpu')
172
- # diffusion_model.load_state_dict(ckpt, strict=False)
173
 
174
 
175
  # import pdb; pdb.set_trace()
@@ -216,12 +216,12 @@ def save_video_from_frames(image_pred, save_pth, fps=8):
216
  out.release()
217
  print(f"视频已保存至 {save_pth}")
218
 
219
- #! model
220
- # inf_pipe = InferenceIP2PVideo(
221
- # diffusion_model.unet,
222
- # scheduler='ddpm',
223
- # num_ddim_steps=20
224
- # )
225
 
226
 
227
  def process_example(*args):
 
152
 
153
  #! 加载模型
154
  # 配置路径和加载模型
155
+ config_path = 'configs/instruct_v2v_ic_gradio.yaml'
156
+ diffusion_model = unit_test_create_model(config_path)
157
+ diffusion_model = diffusion_model.to('cuda')
158
 
159
+ # 加载模型检查点
160
+ # ckpt_path = 'models/relvid_mm_sd15_fbc_unet.pth' #! change
161
+ # ckpt_path = 'tmp/pytorch_model.bin'
162
+ # 下载文件
163
 
164
+ os.makedirs('models', exist_ok=True)
165
+ model_path = "models/relvid_mm_sd15_fbc_unet.pth"
166
 
167
+ if not os.path.exists(model_path):
168
+ download_url_to_file(url='https://huggingface.co/aleafy/RelightVid/resolve/main/relvid_mm_sd15_fbc_unet.pth', dst=model_path)
169
 
170
 
171
+ ckpt = torch.load(model_path, map_location='cpu')
172
+ diffusion_model.load_state_dict(ckpt, strict=False)
173
 
174
 
175
  # import pdb; pdb.set_trace()
 
216
  out.release()
217
  print(f"视频已保存至 {save_pth}")
218
 
219
+
220
+ inf_pipe = InferenceIP2PVideo(
221
+ diffusion_model.unet,
222
+ scheduler='ddpm',
223
+ num_ddim_steps=20
224
+ )
225
 
226
 
227
  def process_example(*args):