Framepack-H111 / utils /device_utils.py
rahul7star's picture
Upload 303 files
e0336bc verified
raw
history blame contribute delete
472 Bytes
import torch
def clean_memory_on_device(device):
    """Release cached, unused allocator memory for ``device``'s backend.

    Args:
        device: A ``torch.device`` or a device string such as ``"cuda"`` or
            ``"cuda:0"``. Only the device *type* is inspected; the index is
            not passed to the backend's ``empty_cache()`` call.

    The ``cuda``, ``xpu`` and ``mps`` device types call the matching
    ``torch.<type>.empty_cache()``. Any other type (e.g. ``"cpu"`` or
    ``"meta"``) is a no-op.
    """
    # Accept plain strings too, so callers can pass "cuda:0" directly.
    if isinstance(device, str):
        device = torch.device(device)

    device_type = device.type
    if device_type == "cuda":
        torch.cuda.empty_cache()
    elif device_type == "xpu":
        # Previously missing: synchronize_device supports XPU, so cache
        # cleanup should too, otherwise XPU memory is never released here.
        torch.xpu.empty_cache()
    elif device_type == "mps":  # NOTE: not tested (carried over from original)
        torch.mps.empty_cache()
def synchronize_device(device: torch.device):
    """Call the backend-level ``synchronize()`` for ``device``'s type.

    Only the ``cuda``, ``xpu`` and ``mps`` device types are dispatched, to
    ``torch.cuda``, ``torch.xpu`` and ``torch.mps`` respectively. Every
    other device type (e.g. ``"cpu"``) returns immediately without doing
    anything. The device index is not forwarded to the backend call.
    """
    backend_name = device.type
    if backend_name not in ("cuda", "xpu", "mps"):
        return
    # Look the backend module up only once we know it is needed, so a torch
    # build lacking e.g. ``torch.xpu`` is only touched for that device type.
    getattr(torch, backend_name).synchronize()