Spaces:
Running
Running
| """ | |
| A copy from https://github.com/deepinsight/insightface/blob/master/recognition/arcface_torch/dataset.py | |
| """ | |
| import queue as Queue | |
| import threading | |
| import torch | |
| from torch.utils.data import DataLoader | |
class BackgroundGenerator(threading.Thread):
    """Drain an iterator on a daemon thread, buffering items in a bounded queue.

    The thread starts as soon as the object is constructed. It first calls
    ``torch.cuda.set_device(local_rank)`` and then pushes every item of
    ``generator`` into a ``queue.Queue(max_prefetch)``, so the producer runs
    ahead of the consumer by at most ``max_prefetch`` items.

    The object is itself an iterator over the produced items. An exception
    raised in the worker thread (including one from ``set_device``) is
    re-raised in the consuming thread rather than leaving it blocked forever.

    Args:
        generator: Iterable whose items are produced in the background.
        local_rank: Device index passed to ``torch.cuda.set_device`` in the
            worker thread.
        max_prefetch: ``maxsize`` of the internal queue (a non-positive value
            makes it unbounded, per ``queue.Queue``).
    """

    # Private end-of-stream marker. Unlike ``None`` it cannot collide with a
    # real item, so ``None`` items are delivered instead of ending iteration.
    _END = object()

    class _Failure(object):
        """Carries an exception raised in the worker across the queue."""

        __slots__ = ("exc",)

        def __init__(self, exc):
            self.exc = exc

    def __init__(self, generator, local_rank, max_prefetch=6):
        super(BackgroundGenerator, self).__init__()
        self.queue = Queue.Queue(max_prefetch)
        self.generator = generator
        self.local_rank = local_rank
        self.daemon = True
        # Becomes True once the terminal marker (end or failure) is consumed.
        self._exhausted = False
        self.start()

    def run(self):
        """Worker body: enqueue every item, then exactly one terminal marker."""
        try:
            torch.cuda.set_device(self.local_rank)
            for item in self.generator:
                self.queue.put(item)
        except BaseException as exc:
            # Forward every failure; otherwise the consumer's get() never returns.
            self.queue.put(self._Failure(exc))
        else:
            self.queue.put(self._END)

    def next(self):
        """Return the next item, blocking until the worker has produced it.

        Raises:
            StopIteration: when the source is exhausted, and on every later call.
            BaseException: whatever the worker thread raised (re-raised once).
        """
        if self._exhausted:
            raise StopIteration
        item = self.queue.get()
        if item is self._END:
            self._exhausted = True
            raise StopIteration
        if isinstance(item, self._Failure):
            self._exhausted = True
            raise item.exc
        return item

    def __next__(self):
        return self.next()

    def __iter__(self):
        return self
class DataLoaderX(DataLoader):
    """``DataLoader`` that prefetches batches on a background thread and copies
    them to the GPU on a dedicated CUDA stream.

    ``__iter__`` returns ``self``: the next batch's host-to-device copy is
    issued on ``self.stream`` while the current batch is being consumed.

    Args:
        local_rank: CUDA device index. Used to create the side stream, passed
            to ``BackgroundGenerator``, and used as the target of the copies.
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``.
    """

    def __init__(self, local_rank, **kwargs):
        super(DataLoaderX, self).__init__(**kwargs)
        self.stream = torch.cuda.Stream(local_rank)
        self.local_rank = local_rank

    def __iter__(self):
        """Start a fresh pass over the dataset and stage its first batch.

        NOTE(review): iteration state is stored on ``self``, so only one pass
        can be active at a time; starting a new pass abandons the previous one.
        """
        self.iter = super(DataLoaderX, self).__iter__()
        self.iter = BackgroundGenerator(self.iter, self.local_rank)
        self.preload()
        return self

    def preload(self):
        """Fetch the next host batch and start its copy on ``self.stream``.

        Sets ``self.batch`` to ``None`` when the source is exhausted. Each
        element is replaced in place by its ``.to(device=local_rank)`` copy, so
        the batch must support item assignment (e.g. a list).
        """
        self.batch = next(self.iter, None)
        if self.batch is None:
            return None
        with torch.cuda.stream(self.stream):
            for k in range(len(self.batch)):
                self.batch[k] = self.batch[k].to(device=self.local_rank,
                                                 non_blocking=True)

    def __next__(self):
        """Return the staged batch and start staging the following one."""
        current = torch.cuda.current_stream()
        # Order the consumer's stream after the side-stream copies.
        current.wait_stream(self.stream)
        batch = self.batch
        if batch is None:
            raise StopIteration
        # These tensors were allocated on ``self.stream``. Record the consumer
        # stream on them so the caching allocator does not hand their memory to
        # a later side-stream copy while work queued on ``current`` may still
        # read it.
        for item in batch:
            if isinstance(item, torch.Tensor):
                item.record_stream(current)
        self.preload()
        return batch