cocktailpeanut committed on
Commit
8081e40
·
1 Parent(s): ce1de29
diffueraser/diffueraser.py CHANGED
@@ -22,6 +22,7 @@ from libs.unet_motion_model import MotionAdapter, UNetMotionModel
22
  from libs.brushnet_CA import BrushNetModel
23
  from libs.unet_2d_condition import UNet2DConditionModel
24
  from diffueraser.pipeline_diffueraser import StableDiffusionDiffuEraserPipeline
 
25
 
26
 
27
  checkpoints = {
@@ -318,7 +319,7 @@ class DiffuEraser:
318
  latents.append(self.vae.encode(pixel_values[i : i + num]).latent_dist.sample())
319
  latents = torch.cat(latents, dim=0)
320
  latents = latents * self.vae.config.scaling_factor #[(b f), c1, h, w], c1=4
321
- torch.cuda.empty_cache()
322
  timesteps = torch.tensor([0], device=self.device)
323
  timesteps = timesteps.long()
324
 
@@ -349,7 +350,7 @@ class DiffuEraser:
349
  guidance_scale=guidance_scale_final,
350
  latents=latents_pre,
351
  ).latents
352
- torch.cuda.empty_cache()
353
 
354
  def decode_latents(latents, weight_dtype):
355
  latents = 1 / self.vae.config.scaling_factor * latents
@@ -363,7 +364,7 @@ class DiffuEraser:
363
  with torch.no_grad():
364
  video_tensor_temp = decode_latents(latents_pre_out, weight_dtype=torch.float16)
365
  images_pre_out = self.image_processor.postprocess(video_tensor_temp, output_type="pil")
366
- torch.cuda.empty_cache()
367
 
368
  ## replace input frames with updated frames
369
  black_image = Image.new('L', validation_masks_input[0].size, color=0)
@@ -376,7 +377,7 @@ class DiffuEraser:
376
  latents_pre_out=None
377
  sample_index=None
378
  gc.collect()
379
- torch.cuda.empty_cache()
380
 
381
  ################ Frame-by-frame inference ################
382
  ## add priori
@@ -396,7 +397,7 @@ class DiffuEraser:
396
  images = images[:real_video_length]
397
 
398
  gc.collect()
399
- torch.cuda.empty_cache()
400
 
401
  ################ Compose ################
402
  binary_masks = validation_masks_input_ori
 
22
  from libs.brushnet_CA import BrushNetModel
23
  from libs.unet_2d_condition import UNet2DConditionModel
24
  from diffueraser.pipeline_diffueraser import StableDiffusionDiffuEraserPipeline
25
+ import devicetorch
26
 
27
 
28
  checkpoints = {
 
319
  latents.append(self.vae.encode(pixel_values[i : i + num]).latent_dist.sample())
320
  latents = torch.cat(latents, dim=0)
321
  latents = latents * self.vae.config.scaling_factor #[(b f), c1, h, w], c1=4
322
+ devicetorch.empty_cache(torch)
323
  timesteps = torch.tensor([0], device=self.device)
324
  timesteps = timesteps.long()
325
 
 
350
  guidance_scale=guidance_scale_final,
351
  latents=latents_pre,
352
  ).latents
353
+ devicetorch.empty_cache(torch)
354
 
355
  def decode_latents(latents, weight_dtype):
356
  latents = 1 / self.vae.config.scaling_factor * latents
 
364
  with torch.no_grad():
365
  video_tensor_temp = decode_latents(latents_pre_out, weight_dtype=torch.float16)
366
  images_pre_out = self.image_processor.postprocess(video_tensor_temp, output_type="pil")
367
+ devicetorch.empty_cache(torch)
368
 
369
  ## replace input frames with updated frames
370
  black_image = Image.new('L', validation_masks_input[0].size, color=0)
 
377
  latents_pre_out=None
378
  sample_index=None
379
  gc.collect()
380
+ devicetorch.empty_cache(torch)
381
 
382
  ################ Frame-by-frame inference ################
383
  ## add priori
 
397
  images = images[:real_video_length]
398
 
399
  gc.collect()
400
+ devicetorch.empty_cache(torch)
401
 
402
  ################ Compose ################
403
  binary_masks = validation_masks_input_ori
diffueraser/pipeline_diffueraser.py CHANGED
@@ -36,6 +36,7 @@ from diffusers import (
36
  from libs.unet_2d_condition import UNet2DConditionModel
37
  from libs.brushnet_CA import BrushNetModel
38
 
 
39
 
40
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
 
@@ -1326,7 +1327,7 @@ class StableDiffusionDiffuEraserPipeline(
1326
  if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1327
  self.unet.to("cpu")
1328
  self.brushnet.to("cpu")
1329
- torch.cuda.empty_cache()
1330
 
1331
  if output_type == "latent":
1332
  image = latents
 
36
  from libs.unet_2d_condition import UNet2DConditionModel
37
  from libs.brushnet_CA import BrushNetModel
38
 
39
+ import devicetorch
40
 
41
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
42
 
 
1327
  if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1328
  self.unet.to("cpu")
1329
  self.brushnet.to("cpu")
1330
+ devicetorch.empty_cache(torch)
1331
 
1332
  if output_type == "latent":
1333
  image = latents
gradio_app.py CHANGED
@@ -8,6 +8,7 @@ import gradio as gr
8
 
9
  # Download Weights
10
  from huggingface_hub import snapshot_download
 
11
 
12
  # List of subdirectories to create inside "checkpoints"
13
  subfolders = [
@@ -93,7 +94,7 @@ def infer(input_video, input_mask):
93
  inference_time = end_time - start_time
94
  print(f"DiffuEraser inference time: {inference_time:.4f} s")
95
 
96
- torch.cuda.empty_cache()
97
 
98
  return output_path
99
 
@@ -150,4 +151,4 @@ demo.queue().launch(show_api=False, show_error=True)
150
 
151
 
152
 
153
-
 
8
 
9
  # Download Weights
10
  from huggingface_hub import snapshot_download
11
+ import devicetorch
12
 
13
  # List of subdirectories to create inside "checkpoints"
14
  subfolders = [
 
94
  inference_time = end_time - start_time
95
  print(f"DiffuEraser inference time: {inference_time:.4f} s")
96
 
97
+ devicetorch.empty_cache(torch)
98
 
99
  return output_path
100
 
 
151
 
152
 
153
 
154
+
propainter/inference.py CHANGED
@@ -24,6 +24,7 @@ except:
24
  from propainter.core.utils import to_tensors
25
  from propainter.model.misc import get_device
26
 
 
27
  import warnings
28
  warnings.filterwarnings("ignore")
29
 
@@ -247,15 +248,15 @@ class Propainter:
247
 
248
  gt_flows_f_list.append(flows_f)
249
  gt_flows_b_list.append(flows_b)
250
- torch.cuda.empty_cache()
251
 
252
  gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
253
  gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
254
  gt_flows_bi = (gt_flows_f, gt_flows_b)
255
  else:
256
  gt_flows_bi = self.fix_raft(frames, iters=raft_iter)
257
- torch.cuda.empty_cache()
258
- torch.cuda.empty_cache()
259
  gc.collect()
260
 
261
  if use_half:
@@ -284,7 +285,7 @@ class Propainter:
284
 
285
  pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e])
286
  pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e])
287
- torch.cuda.empty_cache()
288
 
289
  pred_flows_f = torch.cat(pred_flows_f, dim=1)
290
  pred_flows_b = torch.cat(pred_flows_b, dim=1)
@@ -292,8 +293,8 @@ class Propainter:
292
  else:
293
  pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks)
294
  pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks)
295
- torch.cuda.empty_cache()
296
- torch.cuda.empty_cache()
297
  gc.collect()
298
 
299
 
@@ -321,15 +322,15 @@ class Propainter:
321
 
322
  gt_flows_f_list.append(flows_f)
323
  gt_flows_b_list.append(flows_b)
324
- torch.cuda.empty_cache()
325
 
326
  gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
327
  gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
328
  sample_gt_flows_bi = (gt_flows_f, gt_flows_b)
329
  else:
330
  sample_gt_flows_bi = self.fix_raft(sample_frames, iters=raft_iter)
331
- torch.cuda.empty_cache()
332
- torch.cuda.empty_cache()
333
  gc.collect()
334
 
335
  if use_half:
@@ -356,7 +357,7 @@ class Propainter:
356
 
357
  pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e])
358
  pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e])
359
- torch.cuda.empty_cache()
360
 
361
  pred_flows_f = torch.cat(pred_flows_f, dim=1)
362
  pred_flows_b = torch.cat(pred_flows_b, dim=1)
@@ -364,8 +365,8 @@ class Propainter:
364
  else:
365
  sample_pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(sample_gt_flows_bi, sample_flow_masks)
366
  sample_pred_flows_bi = self.fix_flow_complete.combine_flow(sample_gt_flows_bi, sample_pred_flows_bi, sample_flow_masks)
367
- torch.cuda.empty_cache()
368
- torch.cuda.empty_cache()
369
  gc.collect()
370
 
371
  masked_frames = sample_frames * (1 - sample_masks_dilated)
@@ -391,7 +392,7 @@ class Propainter:
391
 
392
  updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e])
393
  updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e])
394
- torch.cuda.empty_cache()
395
 
396
  updated_frames = torch.cat(updated_frames, dim=1)
397
  updated_masks = torch.cat(updated_masks, dim=1)
@@ -400,7 +401,7 @@ class Propainter:
400
  prop_imgs, updated_local_masks = self.model.img_propagation(masked_frames, sample_pred_flows_bi, sample_masks_dilated, 'nearest')
401
  updated_frames = sample_frames * (1 - sample_masks_dilated) + prop_imgs.view(b, t, 3, h, w) * sample_masks_dilated
402
  updated_masks = updated_local_masks.view(b, t, 1, h, w)
403
- torch.cuda.empty_cache()
404
 
405
  ## replace input frames/masks with updated frames/masks
406
  for i,index in enumerate(index_sample):
@@ -432,7 +433,7 @@ class Propainter:
432
 
433
  updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e])
434
  updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e])
435
- torch.cuda.empty_cache()
436
 
437
  updated_frames = torch.cat(updated_frames, dim=1)
438
  updated_masks = torch.cat(updated_masks, dim=1)
@@ -441,7 +442,7 @@ class Propainter:
441
  prop_imgs, updated_local_masks = self.model.img_propagation(masked_frames, pred_flows_bi, masks_dilated, 'nearest')
442
  updated_frames = frames * (1 - masks_dilated) + prop_imgs.view(b, t, 3, h, w) * masks_dilated
443
  updated_masks = updated_local_masks.view(b, t, 1, h, w)
444
- torch.cuda.empty_cache()
445
 
446
  comp_frames = [None] * video_length
447
 
@@ -451,7 +452,7 @@ class Propainter:
451
  else:
452
  ref_num = -1
453
 
454
- torch.cuda.empty_cache()
455
  # ---- feature propagation + transformer ----
456
  for f in tqdm(range(0, video_length, neighbor_stride)):
457
  neighbor_ids = [
@@ -488,7 +489,7 @@ class Propainter:
488
 
489
  comp_frames[idx] = comp_frames[idx].astype(np.uint8)
490
 
491
- torch.cuda.empty_cache()
492
 
493
  ##save composed video##
494
  comp_frames = [cv2.resize(f, out_size) for f in comp_frames]
@@ -499,7 +500,7 @@ class Propainter:
499
  writer.write(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
500
  writer.release()
501
 
502
- torch.cuda.empty_cache()
503
 
504
  return output_path
505
 
@@ -517,4 +518,4 @@ if __name__ == '__main__':
517
  res = propainter.forward(video, mask, output)
518
 
519
 
520
-
 
24
  from propainter.core.utils import to_tensors
25
  from propainter.model.misc import get_device
26
 
27
+ import devicetorch
28
  import warnings
29
  warnings.filterwarnings("ignore")
30
 
 
248
 
249
  gt_flows_f_list.append(flows_f)
250
  gt_flows_b_list.append(flows_b)
251
+ devicetorch.empty_cache(torch)
252
 
253
  gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
254
  gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
255
  gt_flows_bi = (gt_flows_f, gt_flows_b)
256
  else:
257
  gt_flows_bi = self.fix_raft(frames, iters=raft_iter)
258
+ devicetorch.empty_cache(torch)
259
+ devicetorch.empty_cache(torch)
260
  gc.collect()
261
 
262
  if use_half:
 
285
 
286
  pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e])
287
  pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e])
288
+ devicetorch.empty_cache(torch)
289
 
290
  pred_flows_f = torch.cat(pred_flows_f, dim=1)
291
  pred_flows_b = torch.cat(pred_flows_b, dim=1)
 
293
  else:
294
  pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks)
295
  pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks)
296
+ devicetorch.empty_cache(torch)
297
+ devicetorch.empty_cache(torch)
298
  gc.collect()
299
 
300
 
 
322
 
323
  gt_flows_f_list.append(flows_f)
324
  gt_flows_b_list.append(flows_b)
325
+ devicetorch.empty_cache(torch)
326
 
327
  gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
328
  gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
329
  sample_gt_flows_bi = (gt_flows_f, gt_flows_b)
330
  else:
331
  sample_gt_flows_bi = self.fix_raft(sample_frames, iters=raft_iter)
332
+ devicetorch.empty_cache(torch)
333
+ devicetorch.empty_cache(torch)
334
  gc.collect()
335
 
336
  if use_half:
 
357
 
358
  pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e])
359
  pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e])
360
+ devicetorch.empty_cache(torch)
361
 
362
  pred_flows_f = torch.cat(pred_flows_f, dim=1)
363
  pred_flows_b = torch.cat(pred_flows_b, dim=1)
 
365
  else:
366
  sample_pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(sample_gt_flows_bi, sample_flow_masks)
367
  sample_pred_flows_bi = self.fix_flow_complete.combine_flow(sample_gt_flows_bi, sample_pred_flows_bi, sample_flow_masks)
368
+ devicetorch.empty_cache(torch)
369
+ devicetorch.empty_cache(torch)
370
  gc.collect()
371
 
372
  masked_frames = sample_frames * (1 - sample_masks_dilated)
 
392
 
393
  updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e])
394
  updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e])
395
+ devicetorch.empty_cache(torch)
396
 
397
  updated_frames = torch.cat(updated_frames, dim=1)
398
  updated_masks = torch.cat(updated_masks, dim=1)
 
401
  prop_imgs, updated_local_masks = self.model.img_propagation(masked_frames, sample_pred_flows_bi, sample_masks_dilated, 'nearest')
402
  updated_frames = sample_frames * (1 - sample_masks_dilated) + prop_imgs.view(b, t, 3, h, w) * sample_masks_dilated
403
  updated_masks = updated_local_masks.view(b, t, 1, h, w)
404
+ devicetorch.empty_cache(torch)
405
 
406
  ## replace input frames/masks with updated frames/masks
407
  for i,index in enumerate(index_sample):
 
433
 
434
  updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e])
435
  updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e])
436
+ devicetorch.empty_cache(torch)
437
 
438
  updated_frames = torch.cat(updated_frames, dim=1)
439
  updated_masks = torch.cat(updated_masks, dim=1)
 
442
  prop_imgs, updated_local_masks = self.model.img_propagation(masked_frames, pred_flows_bi, masks_dilated, 'nearest')
443
  updated_frames = frames * (1 - masks_dilated) + prop_imgs.view(b, t, 3, h, w) * masks_dilated
444
  updated_masks = updated_local_masks.view(b, t, 1, h, w)
445
+ devicetorch.empty_cache(torch)
446
 
447
  comp_frames = [None] * video_length
448
 
 
452
  else:
453
  ref_num = -1
454
 
455
+ devicetorch.empty_cache(torch)
456
  # ---- feature propagation + transformer ----
457
  for f in tqdm(range(0, video_length, neighbor_stride)):
458
  neighbor_ids = [
 
489
 
490
  comp_frames[idx] = comp_frames[idx].astype(np.uint8)
491
 
492
+ devicetorch.empty_cache(torch)
493
 
494
  ##save composed video##
495
  comp_frames = [cv2.resize(f, out_size) for f in comp_frames]
 
500
  writer.write(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
501
  writer.release()
502
 
503
+ devicetorch.empty_cache(torch)
504
 
505
  return output_path
506
 
 
518
  res = propainter.forward(video, mask, output)
519
 
520
 
521
+
propainter/model/misc.py CHANGED
@@ -7,6 +7,7 @@ import torch.nn as nn
7
  import logging
8
  import numpy as np
9
  from os import path as osp
 
10
 
11
  def constant_init(module, val, bias=0):
12
  if hasattr(module, 'weight') and module.weight is not None:
@@ -81,8 +82,13 @@ def set_random_seed(seed):
81
  random.seed(seed)
82
  np.random.seed(seed)
83
  torch.manual_seed(seed)
84
- torch.cuda.manual_seed(seed)
85
- torch.cuda.manual_seed_all(seed)
 
 
 
 
 
86
 
87
 
88
  def get_time_str():
@@ -128,4 +134,4 @@ def scandir(dir_path, suffix=None, recursive=False, full_path=False):
128
  else:
129
  continue
130
 
131
- return _scandir(dir_path, suffix=suffix, recursive=recursive)
 
7
  import logging
8
  import numpy as np
9
  from os import path as osp
10
+ import devicetorch
11
 
12
  def constant_init(module, val, bias=0):
13
  if hasattr(module, 'weight') and module.weight is not None:
 
82
  random.seed(seed)
83
  np.random.seed(seed)
84
  torch.manual_seed(seed)
85
+
86
+ if torch.cuda.is_available():
87
+ torch.cuda.manual_seed(seed)
88
+ torch.cuda.manual_seed_all(seed)
89
+
90
+ if torch.backends.mps.is_available():
91
+ torch.mps.manual_seed(seed)
92
 
93
 
94
  def get_time_str():
 
134
  else:
135
  continue
136
 
137
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
run_diffueraser.py CHANGED
@@ -4,6 +4,7 @@ import time
4
  import argparse
5
  from diffueraser.diffueraser import DiffuEraser
6
  from propainter.inference import Propainter, get_device
 
7
 
8
  def main():
9
 
@@ -53,10 +54,10 @@ def main():
53
  inference_time = end_time - start_time
54
  print(f"DiffuEraser inference time: {inference_time:.4f} s")
55
 
56
- torch.cuda.empty_cache()
57
 
58
  if __name__ == '__main__':
59
  main()
60
 
61
 
62
-
 
4
  import argparse
5
  from diffueraser.diffueraser import DiffuEraser
6
  from propainter.inference import Propainter, get_device
7
+ import devicetorch
8
 
9
  def main():
10
 
 
54
  inference_time = end_time - start_time
55
  print(f"DiffuEraser inference time: {inference_time:.4f} s")
56
 
57
+ devicetorch.empty_cache(torch)
58
 
59
  if __name__ == '__main__':
60
  main()
61
 
62
 
63
+