Commit 8081e40 · Parent(s): ce1de29
update
Files changed:
- diffueraser/diffueraser.py +6 -5
- diffueraser/pipeline_diffueraser.py +2 -1
- gradio_app.py +3 -2
- propainter/inference.py +21 -20
- propainter/model/misc.py +9 -3
- run_diffueraser.py +3 -2
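
Taken together, the changes make the Space device-agnostic: every file gains an import devicetorch, each memory-heavy stage is followed by devicetorch.empty_cache(torch), and the seeding helper in propainter/model/misc.py becomes conditional on the available backend. As a rough sketch only (the real devicetorch package may differ in detail), a device-agnostic cache flush of this shape could look like:

    import torch

    def empty_cache(torch_module) -> None:
        # Flush whichever accelerator allocator is present; no-op on CPU-only hosts.
        if torch_module.cuda.is_available():
            torch_module.cuda.empty_cache()    # release cached CUDA allocator blocks
        elif torch_module.backends.mps.is_available():
            torch_module.mps.empty_cache()     # release cached MPS memory (PyTorch >= 2.0)

    empty_cache(torch)   # mirrors the devicetorch.empty_cache(torch) calls in the diff below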
diffueraser/diffueraser.py
CHANGED
@@ -22,6 +22,7 @@ from libs.unet_motion_model import MotionAdapter, UNetMotionModel
 from libs.brushnet_CA import BrushNetModel
 from libs.unet_2d_condition import UNet2DConditionModel
 from diffueraser.pipeline_diffueraser import StableDiffusionDiffuEraserPipeline
+import devicetorch
 
 
 checkpoints = {
@@ -318,7 +319,7 @@ class DiffuEraser:
 latents.append(self.vae.encode(pixel_values[i : i + num]).latent_dist.sample())
 latents = torch.cat(latents, dim=0)
 latents = latents * self.vae.config.scaling_factor #[(b f), c1, h, w], c1=4
-
+devicetorch.empty_cache(torch)
 timesteps = torch.tensor([0], device=self.device)
 timesteps = timesteps.long()
 
@@ -349,7 +350,7 @@ class DiffuEraser:
 guidance_scale=guidance_scale_final,
 latents=latents_pre,
 ).latents
-
+devicetorch.empty_cache(torch)
 
 def decode_latents(latents, weight_dtype):
 latents = 1 / self.vae.config.scaling_factor * latents
@@ -363,7 +364,7 @@ class DiffuEraser:
 with torch.no_grad():
 video_tensor_temp = decode_latents(latents_pre_out, weight_dtype=torch.float16)
 images_pre_out = self.image_processor.postprocess(video_tensor_temp, output_type="pil")
-
+devicetorch.empty_cache(torch)
 
 ## replace input frames with updated frames
 black_image = Image.new('L', validation_masks_input[0].size, color=0)
@@ -376,7 +377,7 @@ class DiffuEraser:
 latents_pre_out=None
 sample_index=None
 gc.collect()
-
+devicetorch.empty_cache(torch)
 
 ################ Frame-by-frame inference ################
 ## add priori
@@ -396,7 +397,7 @@ class DiffuEraser:
 images = images[:real_video_length]
 
 gc.collect()
-
+devicetorch.empty_cache(torch)
 
 ################ Compose ################
 binary_masks = validation_masks_input_ori
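
A note on the pattern in the hunks above: the code already drops Python references (latents_pre_out=None, sample_index=None) and calls gc.collect(); the commit adds the allocator flush as a third step. A minimal illustration of that idiom, with big_tensor as a placeholder name:

    import gc
    import torch
    import devicetorch  # assumed installed, as this commit requires

    big_tensor = torch.zeros(1024, 1024)  # stand-in for e.g. latents_pre_out
    big_tensor = None                     # 1. drop the Python reference
    gc.collect()                          # 2. reclaim Python-side objects
    devicetorch.empty_cache(torch)        # 3. return cached device memory to the pool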
diffueraser/pipeline_diffueraser.py
CHANGED
@@ -36,6 +36,7 @@ from diffusers import (
 from libs.unet_2d_condition import UNet2DConditionModel
 from libs.brushnet_CA import BrushNetModel
 
+import devicetorch
 
 logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
@@ -1326,7 +1327,7 @@ class StableDiffusionDiffuEraserPipeline(
 if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
 self.unet.to("cpu")
 self.brushnet.to("cpu")
-
+devicetorch.empty_cache(torch)
 
 if output_type == "latent":
 image = latents
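
The flush after the CPU offload above is the case where it matters most: moving unet and brushnet to "cpu" frees their device tensors, but the caching allocator keeps the freed blocks reserved for this process until empty_cache is called. A small self-contained illustration (hypothetical toy model; devicetorch assumed installed):

    import torch
    import devicetorch  # assumed installed

    model = torch.nn.Linear(4096, 4096)
    if torch.cuda.is_available():
        model.to("cuda")               # allocator reserves device memory
    model.to("cpu")                    # parameters leave the accelerator...
    devicetorch.empty_cache(torch)     # ...and the cached blocks are actually released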
gradio_app.py
CHANGED
@@ -8,6 +8,7 @@ import gradio as gr
 
 # Download Weights
 from huggingface_hub import snapshot_download
+import devicetorch
 
 # List of subdirectories to create inside "checkpoints"
 subfolders = [
@@ -93,7 +94,7 @@ def infer(input_video, input_mask):
 inference_time = end_time - start_time
 print(f"DiffuEraser inference time: {inference_time:.4f} s")
 
-
+devicetorch.empty_cache(torch)
 
 return output_path
 
@@ -150,4 +151,4 @@ demo.queue().launch(show_api=False, show_error=True)
 
 
 
-
+
propainter/inference.py
CHANGED
@@ -24,6 +24,7 @@ except:
 from propainter.core.utils import to_tensors
 from propainter.model.misc import get_device
 
+import devicetorch
 import warnings
 warnings.filterwarnings("ignore")
 
@@ -247,15 +248,15 @@ class Propainter:
 
 gt_flows_f_list.append(flows_f)
 gt_flows_b_list.append(flows_b)
-
+devicetorch.empty_cache(torch)
 
 gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
 gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
 gt_flows_bi = (gt_flows_f, gt_flows_b)
 else:
 gt_flows_bi = self.fix_raft(frames, iters=raft_iter)
-
-
+devicetorch.empty_cache(torch)
+devicetorch.empty_cache(torch)
 gc.collect()
 
 if use_half:
@@ -284,7 +285,7 @@ class Propainter:
 
 pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e])
 pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e])
-
+devicetorch.empty_cache(torch)
 
 pred_flows_f = torch.cat(pred_flows_f, dim=1)
 pred_flows_b = torch.cat(pred_flows_b, dim=1)
@@ -292,8 +293,8 @@ class Propainter:
 else:
 pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks)
 pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks)
-
-
+devicetorch.empty_cache(torch)
+devicetorch.empty_cache(torch)
 gc.collect()
 
 
@@ -321,15 +322,15 @@ class Propainter:
 
 gt_flows_f_list.append(flows_f)
 gt_flows_b_list.append(flows_b)
-
+devicetorch.empty_cache(torch)
 
 gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
 gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
 sample_gt_flows_bi = (gt_flows_f, gt_flows_b)
 else:
 sample_gt_flows_bi = self.fix_raft(sample_frames, iters=raft_iter)
-
-
+devicetorch.empty_cache(torch)
+devicetorch.empty_cache(torch)
 gc.collect()
 
 if use_half:
@@ -356,7 +357,7 @@ class Propainter:
 
 pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e])
 pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e])
-
+devicetorch.empty_cache(torch)
 
 pred_flows_f = torch.cat(pred_flows_f, dim=1)
 pred_flows_b = torch.cat(pred_flows_b, dim=1)
@@ -364,8 +365,8 @@ class Propainter:
 else:
 sample_pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(sample_gt_flows_bi, sample_flow_masks)
 sample_pred_flows_bi = self.fix_flow_complete.combine_flow(sample_gt_flows_bi, sample_pred_flows_bi, sample_flow_masks)
-
-
+devicetorch.empty_cache(torch)
+devicetorch.empty_cache(torch)
 gc.collect()
 
 masked_frames = sample_frames * (1 - sample_masks_dilated)
@@ -391,7 +392,7 @@ class Propainter:
 
 updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e])
 updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e])
-
+devicetorch.empty_cache(torch)
 
 updated_frames = torch.cat(updated_frames, dim=1)
 updated_masks = torch.cat(updated_masks, dim=1)
@@ -400,7 +401,7 @@ class Propainter:
 prop_imgs, updated_local_masks = self.model.img_propagation(masked_frames, sample_pred_flows_bi, sample_masks_dilated, 'nearest')
 updated_frames = sample_frames * (1 - sample_masks_dilated) + prop_imgs.view(b, t, 3, h, w) * sample_masks_dilated
 updated_masks = updated_local_masks.view(b, t, 1, h, w)
-
+devicetorch.empty_cache(torch)
 
 ## replace input frames/masks with updated frames/masks
 for i,index in enumerate(index_sample):
@@ -432,7 +433,7 @@ class Propainter:
 
 updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e])
 updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e])
-
+devicetorch.empty_cache(torch)
 
 updated_frames = torch.cat(updated_frames, dim=1)
 updated_masks = torch.cat(updated_masks, dim=1)
@@ -441,7 +442,7 @@ class Propainter:
 prop_imgs, updated_local_masks = self.model.img_propagation(masked_frames, pred_flows_bi, masks_dilated, 'nearest')
 updated_frames = frames * (1 - masks_dilated) + prop_imgs.view(b, t, 3, h, w) * masks_dilated
 updated_masks = updated_local_masks.view(b, t, 1, h, w)
-
+devicetorch.empty_cache(torch)
 
 comp_frames = [None] * video_length
 
@@ -451,7 +452,7 @@ class Propainter:
 else:
 ref_num = -1
 
-
+devicetorch.empty_cache(torch)
 # ---- feature propagation + transformer ----
 for f in tqdm(range(0, video_length, neighbor_stride)):
 neighbor_ids = [
@@ -488,7 +489,7 @@ class Propainter:
 
 comp_frames[idx] = comp_frames[idx].astype(np.uint8)
 
-
+devicetorch.empty_cache(torch)
 
 ##save composed video##
 comp_frames = [cv2.resize(f, out_size) for f in comp_frames]
@@ -499,7 +500,7 @@ class Propainter:
 writer.write(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
 writer.release()
 
-
+devicetorch.empty_cache(torch)
 
 return output_path
 
@@ -517,4 +518,4 @@ if __name__ == '__main__':
 res = propainter.forward(video, mask, output)
 
 
-
+
propainter/model/misc.py
CHANGED
@@ -7,6 +7,7 @@ import torch.nn as nn
 import logging
 import numpy as np
 from os import path as osp
+import devicetorch
 
 def constant_init(module, val, bias=0):
 if hasattr(module, 'weight') and module.weight is not None:
@@ -81,8 +82,13 @@ def set_random_seed(seed):
 random.seed(seed)
 np.random.seed(seed)
 torch.manual_seed(seed)
-torch.cuda.manual_seed(seed)
-torch.cuda.manual_seed_all(seed)
+
+if torch.cuda.is_available():
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+if torch.backends.mps.is_available():
+    torch.mps.manual_seed(seed)
 
 
 def get_time_str():
@@ -128,4 +134,4 @@ def scandir(dir_path, suffix=None, recursive=False, full_path=False):
 else:
 continue
 
-return _scandir(dir_path, suffix=suffix, recursive=recursive)
+return _scandir(dir_path, suffix=suffix, recursive=recursive)
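
The set_random_seed change is what makes seeding portable: the previously unconditional torch.cuda.manual_seed calls are now guarded by torch.cuda.is_available(), and an MPS branch (torch.mps.manual_seed, available in recent PyTorch builds) covers Apple silicon. Usage is unchanged:

    from propainter.model.misc import set_random_seed

    set_random_seed(42)  # seeds CPU plus whichever of CUDA / MPS is present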
run_diffueraser.py
CHANGED
@@ -4,6 +4,7 @@ import time
 import argparse
 from diffueraser.diffueraser import DiffuEraser
 from propainter.inference import Propainter, get_device
+import devicetorch
 
 def main():
 
@@ -53,10 +54,10 @@ def main():
 inference_time = end_time - start_time
 print(f"DiffuEraser inference time: {inference_time:.4f} s")
 
-
+devicetorch.empty_cache(torch)
 
 if __name__ == '__main__':
 main()
 
 
-
+