Spaces:
Running
on
Zero
Running
on
Zero
Upload 5 files
Browse files- trellis/utils/data_utils.py +226 -0
- trellis/utils/dist_utils.py +93 -0
- trellis/utils/elastic_utils.py +228 -0
- trellis/utils/grad_clip_utils.py +81 -0
- trellis/utils/loss_utils.py +92 -0
trellis/utils/data_utils.py
ADDED
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from torch.utils.data import Sampler, Dataset, DataLoader, DistributedSampler
|
6 |
+
import torch.distributed as dist
|
7 |
+
|
8 |
+
|
9 |
+
def recursive_to_device(
|
10 |
+
data: Any,
|
11 |
+
device: torch.device,
|
12 |
+
non_blocking: bool = False,
|
13 |
+
) -> Any:
|
14 |
+
"""
|
15 |
+
Recursively move all tensors in a data structure to a device.
|
16 |
+
"""
|
17 |
+
if hasattr(data, "to"):
|
18 |
+
return data.to(device, non_blocking=non_blocking)
|
19 |
+
elif isinstance(data, (list, tuple)):
|
20 |
+
return type(data)(recursive_to_device(d, device, non_blocking) for d in data)
|
21 |
+
elif isinstance(data, dict):
|
22 |
+
return {k: recursive_to_device(v, device, non_blocking) for k, v in data.items()}
|
23 |
+
else:
|
24 |
+
return data
|
25 |
+
|
26 |
+
|
27 |
+
def load_balanced_group_indices(
|
28 |
+
load: List[int],
|
29 |
+
num_groups: int,
|
30 |
+
equal_size: bool = False,
|
31 |
+
) -> List[List[int]]:
|
32 |
+
"""
|
33 |
+
Split indices into groups with balanced load.
|
34 |
+
"""
|
35 |
+
if equal_size:
|
36 |
+
group_size = len(load) // num_groups
|
37 |
+
indices = np.argsort(load)[::-1]
|
38 |
+
groups = [[] for _ in range(num_groups)]
|
39 |
+
group_load = np.zeros(num_groups)
|
40 |
+
for idx in indices:
|
41 |
+
min_group_idx = np.argmin(group_load)
|
42 |
+
groups[min_group_idx].append(idx)
|
43 |
+
if equal_size and len(groups[min_group_idx]) == group_size:
|
44 |
+
group_load[min_group_idx] = float('inf')
|
45 |
+
else:
|
46 |
+
group_load[min_group_idx] += load[idx]
|
47 |
+
return groups
|
48 |
+
|
49 |
+
|
50 |
+
def cycle(data_loader: DataLoader) -> Iterator:
|
51 |
+
while True:
|
52 |
+
for data in data_loader:
|
53 |
+
if isinstance(data_loader.sampler, ResumableSampler):
|
54 |
+
data_loader.sampler.idx += data_loader.batch_size # type: ignore[attr-defined]
|
55 |
+
yield data
|
56 |
+
if isinstance(data_loader.sampler, DistributedSampler):
|
57 |
+
data_loader.sampler.epoch += 1
|
58 |
+
if isinstance(data_loader.sampler, ResumableSampler):
|
59 |
+
data_loader.sampler.epoch += 1
|
60 |
+
data_loader.sampler.idx = 0
|
61 |
+
|
62 |
+
|
63 |
+
class ResumableSampler(Sampler):
|
64 |
+
"""
|
65 |
+
Distributed sampler that is resumable.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
dataset: Dataset used for sampling.
|
69 |
+
rank (int, optional): Rank of the current process within :attr:`num_replicas`.
|
70 |
+
By default, :attr:`rank` is retrieved from the current distributed
|
71 |
+
group.
|
72 |
+
shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
|
73 |
+
indices.
|
74 |
+
seed (int, optional): random seed used to shuffle the sampler if
|
75 |
+
:attr:`shuffle=True`. This number should be identical across all
|
76 |
+
processes in the distributed group. Default: ``0``.
|
77 |
+
drop_last (bool, optional): if ``True``, then the sampler will drop the
|
78 |
+
tail of the data to make it evenly divisible across the number of
|
79 |
+
replicas. If ``False``, the sampler will add extra indices to make
|
80 |
+
the data evenly divisible across the replicas. Default: ``False``.
|
81 |
+
"""
|
82 |
+
|
83 |
+
def __init__(
|
84 |
+
self,
|
85 |
+
dataset: Dataset,
|
86 |
+
shuffle: bool = True,
|
87 |
+
seed: int = 0,
|
88 |
+
drop_last: bool = False,
|
89 |
+
) -> None:
|
90 |
+
self.dataset = dataset
|
91 |
+
self.epoch = 0
|
92 |
+
self.idx = 0
|
93 |
+
self.drop_last = drop_last
|
94 |
+
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
|
95 |
+
self.rank = dist.get_rank() if dist.is_initialized() else 0
|
96 |
+
# If the dataset length is evenly divisible by # of replicas, then there
|
97 |
+
# is no need to drop any data, since the dataset will be split equally.
|
98 |
+
if self.drop_last and len(self.dataset) % self.world_size != 0: # type: ignore[arg-type]
|
99 |
+
# Split to nearest available length that is evenly divisible.
|
100 |
+
# This is to ensure each rank receives the same amount of data when
|
101 |
+
# using this Sampler.
|
102 |
+
self.num_samples = math.ceil(
|
103 |
+
(len(self.dataset) - self.world_size) / self.world_size # type: ignore[arg-type]
|
104 |
+
)
|
105 |
+
else:
|
106 |
+
self.num_samples = math.ceil(len(self.dataset) / self.world_size) # type: ignore[arg-type]
|
107 |
+
self.total_size = self.num_samples * self.world_size
|
108 |
+
self.shuffle = shuffle
|
109 |
+
self.seed = seed
|
110 |
+
|
111 |
+
def __iter__(self) -> Iterator:
|
112 |
+
if self.shuffle:
|
113 |
+
# deterministically shuffle based on epoch and seed
|
114 |
+
g = torch.Generator()
|
115 |
+
g.manual_seed(self.seed + self.epoch)
|
116 |
+
indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
|
117 |
+
else:
|
118 |
+
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
|
119 |
+
|
120 |
+
if not self.drop_last:
|
121 |
+
# add extra samples to make it evenly divisible
|
122 |
+
padding_size = self.total_size - len(indices)
|
123 |
+
if padding_size <= len(indices):
|
124 |
+
indices += indices[:padding_size]
|
125 |
+
else:
|
126 |
+
indices += (indices * math.ceil(padding_size / len(indices)))[
|
127 |
+
:padding_size
|
128 |
+
]
|
129 |
+
else:
|
130 |
+
# remove tail of data to make it evenly divisible.
|
131 |
+
indices = indices[: self.total_size]
|
132 |
+
assert len(indices) == self.total_size
|
133 |
+
|
134 |
+
# subsample
|
135 |
+
indices = indices[self.rank : self.total_size : self.world_size]
|
136 |
+
|
137 |
+
# resume from previous state
|
138 |
+
indices = indices[self.idx:]
|
139 |
+
|
140 |
+
return iter(indices)
|
141 |
+
|
142 |
+
def __len__(self) -> int:
|
143 |
+
return self.num_samples
|
144 |
+
|
145 |
+
def state_dict(self) -> dict[str, int]:
|
146 |
+
return {
|
147 |
+
'epoch': self.epoch,
|
148 |
+
'idx': self.idx,
|
149 |
+
}
|
150 |
+
|
151 |
+
def load_state_dict(self, state_dict):
|
152 |
+
self.epoch = state_dict['epoch']
|
153 |
+
self.idx = state_dict['idx']
|
154 |
+
|
155 |
+
|
156 |
+
class BalancedResumableSampler(ResumableSampler):
|
157 |
+
"""
|
158 |
+
Distributed sampler that is resumable and balances the load among the processes.
|
159 |
+
|
160 |
+
Args:
|
161 |
+
dataset: Dataset used for sampling.
|
162 |
+
rank (int, optional): Rank of the current process within :attr:`num_replicas`.
|
163 |
+
By default, :attr:`rank` is retrieved from the current distributed
|
164 |
+
group.
|
165 |
+
shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
|
166 |
+
indices.
|
167 |
+
seed (int, optional): random seed used to shuffle the sampler if
|
168 |
+
:attr:`shuffle=True`. This number should be identical across all
|
169 |
+
processes in the distributed group. Default: ``0``.
|
170 |
+
drop_last (bool, optional): if ``True``, then the sampler will drop the
|
171 |
+
tail of the data to make it evenly divisible across the number of
|
172 |
+
replicas. If ``False``, the sampler will add extra indices to make
|
173 |
+
the data evenly divisible across the replicas. Default: ``False``.
|
174 |
+
"""
|
175 |
+
|
176 |
+
def __init__(
|
177 |
+
self,
|
178 |
+
dataset: Dataset,
|
179 |
+
shuffle: bool = True,
|
180 |
+
seed: int = 0,
|
181 |
+
drop_last: bool = False,
|
182 |
+
batch_size: int = 1,
|
183 |
+
) -> None:
|
184 |
+
assert hasattr(dataset, 'loads'), 'Dataset must have "loads" attribute to use BalancedResumableSampler'
|
185 |
+
super().__init__(dataset, shuffle, seed, drop_last)
|
186 |
+
self.batch_size = batch_size
|
187 |
+
self.loads = dataset.loads
|
188 |
+
|
189 |
+
def __iter__(self) -> Iterator:
|
190 |
+
if self.shuffle:
|
191 |
+
# deterministically shuffle based on epoch and seed
|
192 |
+
g = torch.Generator()
|
193 |
+
g.manual_seed(self.seed + self.epoch)
|
194 |
+
indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
|
195 |
+
else:
|
196 |
+
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
|
197 |
+
|
198 |
+
if not self.drop_last:
|
199 |
+
# add extra samples to make it evenly divisible
|
200 |
+
padding_size = self.total_size - len(indices)
|
201 |
+
if padding_size <= len(indices):
|
202 |
+
indices += indices[:padding_size]
|
203 |
+
else:
|
204 |
+
indices += (indices * math.ceil(padding_size / len(indices)))[
|
205 |
+
:padding_size
|
206 |
+
]
|
207 |
+
else:
|
208 |
+
# remove tail of data to make it evenly divisible.
|
209 |
+
indices = indices[: self.total_size]
|
210 |
+
assert len(indices) == self.total_size
|
211 |
+
|
212 |
+
# balance load among processes
|
213 |
+
num_batches = len(indices) // (self.batch_size * self.world_size)
|
214 |
+
balanced_indices = []
|
215 |
+
for i in range(num_batches):
|
216 |
+
start_idx = i * self.batch_size * self.world_size
|
217 |
+
end_idx = (i + 1) * self.batch_size * self.world_size
|
218 |
+
batch_indices = indices[start_idx:end_idx]
|
219 |
+
batch_loads = [self.loads[idx] for idx in batch_indices]
|
220 |
+
groups = load_balanced_group_indices(batch_loads, self.world_size, equal_size=True)
|
221 |
+
balanced_indices.extend([batch_indices[j] for j in groups[self.rank]])
|
222 |
+
|
223 |
+
# resume from previous state
|
224 |
+
indices = balanced_indices[self.idx:]
|
225 |
+
|
226 |
+
return iter(indices)
|
trellis/utils/dist_utils.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import io
|
3 |
+
from contextlib import contextmanager
|
4 |
+
import torch
|
5 |
+
import torch.distributed as dist
|
6 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
7 |
+
|
8 |
+
|
9 |
+
def setup_dist(rank, local_rank, world_size, master_addr, master_port):
|
10 |
+
os.environ['MASTER_ADDR'] = master_addr
|
11 |
+
os.environ['MASTER_PORT'] = master_port
|
12 |
+
os.environ['WORLD_SIZE'] = str(world_size)
|
13 |
+
os.environ['RANK'] = str(rank)
|
14 |
+
os.environ['LOCAL_RANK'] = str(local_rank)
|
15 |
+
torch.cuda.set_device(local_rank)
|
16 |
+
dist.init_process_group('nccl', rank=rank, world_size=world_size)
|
17 |
+
|
18 |
+
|
19 |
+
def read_file_dist(path):
|
20 |
+
"""
|
21 |
+
Read the binary file distributedly.
|
22 |
+
File is only read once by the rank 0 process and broadcasted to other processes.
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
data (io.BytesIO): The binary data read from the file.
|
26 |
+
"""
|
27 |
+
if dist.is_initialized() and dist.get_world_size() > 1:
|
28 |
+
# read file
|
29 |
+
size = torch.LongTensor(1).cuda()
|
30 |
+
if dist.get_rank() == 0:
|
31 |
+
with open(path, 'rb') as f:
|
32 |
+
data = f.read()
|
33 |
+
data = torch.ByteTensor(
|
34 |
+
torch.UntypedStorage.from_buffer(data, dtype=torch.uint8)
|
35 |
+
).cuda()
|
36 |
+
size[0] = data.shape[0]
|
37 |
+
# broadcast size
|
38 |
+
dist.broadcast(size, src=0)
|
39 |
+
if dist.get_rank() != 0:
|
40 |
+
data = torch.ByteTensor(size[0].item()).cuda()
|
41 |
+
# broadcast data
|
42 |
+
dist.broadcast(data, src=0)
|
43 |
+
# convert to io.BytesIO
|
44 |
+
data = data.cpu().numpy().tobytes()
|
45 |
+
data = io.BytesIO(data)
|
46 |
+
return data
|
47 |
+
else:
|
48 |
+
with open(path, 'rb') as f:
|
49 |
+
data = f.read()
|
50 |
+
data = io.BytesIO(data)
|
51 |
+
return data
|
52 |
+
|
53 |
+
|
54 |
+
def unwrap_dist(model):
|
55 |
+
"""
|
56 |
+
Unwrap the model from distributed training.
|
57 |
+
"""
|
58 |
+
if isinstance(model, DDP):
|
59 |
+
return model.module
|
60 |
+
return model
|
61 |
+
|
62 |
+
|
63 |
+
@contextmanager
|
64 |
+
def master_first():
|
65 |
+
"""
|
66 |
+
A context manager that ensures master process executes first.
|
67 |
+
"""
|
68 |
+
if not dist.is_initialized():
|
69 |
+
yield
|
70 |
+
else:
|
71 |
+
if dist.get_rank() == 0:
|
72 |
+
yield
|
73 |
+
dist.barrier()
|
74 |
+
else:
|
75 |
+
dist.barrier()
|
76 |
+
yield
|
77 |
+
|
78 |
+
|
79 |
+
@contextmanager
|
80 |
+
def local_master_first():
|
81 |
+
"""
|
82 |
+
A context manager that ensures local master process executes first.
|
83 |
+
"""
|
84 |
+
if not dist.is_initialized():
|
85 |
+
yield
|
86 |
+
else:
|
87 |
+
if dist.get_rank() % torch.cuda.device_count() == 0:
|
88 |
+
yield
|
89 |
+
dist.barrier()
|
90 |
+
else:
|
91 |
+
dist.barrier()
|
92 |
+
yield
|
93 |
+
|
trellis/utils/elastic_utils.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
from contextlib import contextmanager
|
3 |
+
from typing import Tuple
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
|
9 |
+
class MemoryController:
|
10 |
+
"""
|
11 |
+
Base class for memory management during training.
|
12 |
+
"""
|
13 |
+
|
14 |
+
_last_input_size = None
|
15 |
+
_last_mem_ratio = []
|
16 |
+
|
17 |
+
@contextmanager
|
18 |
+
def record(self):
|
19 |
+
pass
|
20 |
+
|
21 |
+
def update_run_states(self, input_size=None, mem_ratio=None):
|
22 |
+
if self._last_input_size is None:
|
23 |
+
self._last_input_size = input_size
|
24 |
+
elif self._last_input_size!= input_size:
|
25 |
+
raise ValueError(f'Input size should not change for different ElasticModules.')
|
26 |
+
self._last_mem_ratio.append(mem_ratio)
|
27 |
+
|
28 |
+
@abstractmethod
|
29 |
+
def get_mem_ratio(self, input_size):
|
30 |
+
pass
|
31 |
+
|
32 |
+
@abstractmethod
|
33 |
+
def state_dict(self):
|
34 |
+
pass
|
35 |
+
|
36 |
+
@abstractmethod
|
37 |
+
def log(self):
|
38 |
+
pass
|
39 |
+
|
40 |
+
|
41 |
+
class LinearMemoryController(MemoryController):
|
42 |
+
"""
|
43 |
+
A simple controller for memory management during training.
|
44 |
+
The memory usage is modeled as a linear function of:
|
45 |
+
- the number of input parameters
|
46 |
+
- the ratio of memory the model use compared to the maximum usage (with no checkpointing)
|
47 |
+
memory_usage = k * input_size * mem_ratio + b
|
48 |
+
The controller keeps track of the memory usage and gives the
|
49 |
+
expected memory ratio to keep the memory usage under a target
|
50 |
+
"""
|
51 |
+
def __init__(
|
52 |
+
self,
|
53 |
+
buffer_size=1000,
|
54 |
+
update_every=500,
|
55 |
+
target_ratio=0.8,
|
56 |
+
available_memory=None,
|
57 |
+
max_mem_ratio_start=0.1,
|
58 |
+
params=None,
|
59 |
+
device=None
|
60 |
+
):
|
61 |
+
self.buffer_size = buffer_size
|
62 |
+
self.update_every = update_every
|
63 |
+
self.target_ratio = target_ratio
|
64 |
+
self.device = device or torch.cuda.current_device()
|
65 |
+
self.available_memory = available_memory or torch.cuda.get_device_properties(self.device).total_memory / 1024**3
|
66 |
+
|
67 |
+
self._memory = np.zeros(buffer_size, dtype=np.float32)
|
68 |
+
self._input_size = np.zeros(buffer_size, dtype=np.float32)
|
69 |
+
self._mem_ratio = np.zeros(buffer_size, dtype=np.float32)
|
70 |
+
self._buffer_ptr = 0
|
71 |
+
self._buffer_length = 0
|
72 |
+
self._params = tuple(params) if params is not None else (0.0, 0.0)
|
73 |
+
self._max_mem_ratio = max_mem_ratio_start
|
74 |
+
self.step = 0
|
75 |
+
|
76 |
+
def __repr__(self):
|
77 |
+
return f'LinearMemoryController(target_ratio={self.target_ratio}, available_memory={self.available_memory})'
|
78 |
+
|
79 |
+
def _add_sample(self, memory, input_size, mem_ratio):
|
80 |
+
self._memory[self._buffer_ptr] = memory
|
81 |
+
self._input_size[self._buffer_ptr] = input_size
|
82 |
+
self._mem_ratio[self._buffer_ptr] = mem_ratio
|
83 |
+
self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
|
84 |
+
self._buffer_length = min(self._buffer_length + 1, self.buffer_size)
|
85 |
+
|
86 |
+
@contextmanager
|
87 |
+
def record(self):
|
88 |
+
torch.cuda.reset_peak_memory_stats(self.device)
|
89 |
+
self._last_input_size = None
|
90 |
+
self._last_mem_ratio = []
|
91 |
+
yield
|
92 |
+
self._last_memory = torch.cuda.max_memory_allocated(self.device) / 1024**3
|
93 |
+
self._last_mem_ratio = sum(self._last_mem_ratio) / len(self._last_mem_ratio)
|
94 |
+
self._add_sample(self._last_memory, self._last_input_size, self._last_mem_ratio)
|
95 |
+
self.step += 1
|
96 |
+
if self.step % self.update_every == 0:
|
97 |
+
self._max_mem_ratio = min(1.0, self._max_mem_ratio + 0.1)
|
98 |
+
self._fit_params()
|
99 |
+
|
100 |
+
def _fit_params(self):
|
101 |
+
memory_usage = self._memory[:self._buffer_length]
|
102 |
+
input_size = self._input_size[:self._buffer_length]
|
103 |
+
mem_ratio = self._mem_ratio[:self._buffer_length]
|
104 |
+
|
105 |
+
x = input_size * mem_ratio
|
106 |
+
y = memory_usage
|
107 |
+
k, b = np.polyfit(x, y, 1)
|
108 |
+
self._params = (k, b)
|
109 |
+
# self._visualize()
|
110 |
+
|
111 |
+
def _visualize(self):
|
112 |
+
import matplotlib.pyplot as plt
|
113 |
+
memory_usage = self._memory[:self._buffer_length]
|
114 |
+
input_size = self._input_size[:self._buffer_length]
|
115 |
+
mem_ratio = self._mem_ratio[:self._buffer_length]
|
116 |
+
k, b = self._params
|
117 |
+
|
118 |
+
plt.scatter(input_size * mem_ratio, memory_usage, c=mem_ratio, cmap='viridis')
|
119 |
+
x = np.array([0.0, 20000.0])
|
120 |
+
plt.plot(x, k * x + b, c='r')
|
121 |
+
plt.savefig(f'linear_memory_controller_{self.step}.png')
|
122 |
+
plt.cla()
|
123 |
+
|
124 |
+
def get_mem_ratio(self, input_size):
|
125 |
+
k, b = self._params
|
126 |
+
if k == 0: return np.random.rand() * self._max_mem_ratio
|
127 |
+
pred = (self.available_memory * self.target_ratio - b) / (k * input_size)
|
128 |
+
return min(self._max_mem_ratio, max(0.0, pred))
|
129 |
+
|
130 |
+
def state_dict(self):
|
131 |
+
return {
|
132 |
+
'params': self._params,
|
133 |
+
}
|
134 |
+
|
135 |
+
def load_state_dict(self, state_dict):
|
136 |
+
self._params = tuple(state_dict['params'])
|
137 |
+
|
138 |
+
def log(self):
|
139 |
+
return {
|
140 |
+
'params/k': self._params[0],
|
141 |
+
'params/b': self._params[1],
|
142 |
+
'memory': self._last_memory,
|
143 |
+
'input_size': self._last_input_size,
|
144 |
+
'mem_ratio': self._last_mem_ratio,
|
145 |
+
}
|
146 |
+
|
147 |
+
|
148 |
+
class ElasticModule(nn.Module):
|
149 |
+
"""
|
150 |
+
Module for training with elastic memory management.
|
151 |
+
"""
|
152 |
+
def __init__(self):
|
153 |
+
super().__init__()
|
154 |
+
self._memory_controller: MemoryController = None
|
155 |
+
|
156 |
+
@abstractmethod
|
157 |
+
def _get_input_size(self, *args, **kwargs) -> int:
|
158 |
+
"""
|
159 |
+
Get the size of the input data.
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
int: The size of the input data.
|
163 |
+
"""
|
164 |
+
pass
|
165 |
+
|
166 |
+
@abstractmethod
|
167 |
+
def _forward_with_mem_ratio(self, *args, mem_ratio=0.0, **kwargs) -> Tuple[float, Tuple]:
|
168 |
+
"""
|
169 |
+
Forward with a given memory ratio.
|
170 |
+
"""
|
171 |
+
pass
|
172 |
+
|
173 |
+
def register_memory_controller(self, memory_controller: MemoryController):
|
174 |
+
self._memory_controller = memory_controller
|
175 |
+
|
176 |
+
def forward(self, *args, **kwargs):
|
177 |
+
if self._memory_controller is None or not torch.is_grad_enabled() or not self.training:
|
178 |
+
_, ret = self._forward_with_mem_ratio(*args, **kwargs)
|
179 |
+
else:
|
180 |
+
input_size = self._get_input_size(*args, **kwargs)
|
181 |
+
mem_ratio = self._memory_controller.get_mem_ratio(input_size)
|
182 |
+
mem_ratio, ret = self._forward_with_mem_ratio(*args, mem_ratio=mem_ratio, **kwargs)
|
183 |
+
self._memory_controller.update_run_states(input_size, mem_ratio)
|
184 |
+
return ret
|
185 |
+
|
186 |
+
|
187 |
+
class ElasticModuleMixin:
|
188 |
+
"""
|
189 |
+
Mixin for training with elastic memory management.
|
190 |
+
"""
|
191 |
+
def __init__(self, *args, **kwargs):
|
192 |
+
super().__init__(*args, **kwargs)
|
193 |
+
self._memory_controller: MemoryController = None
|
194 |
+
|
195 |
+
@abstractmethod
|
196 |
+
def _get_input_size(self, *args, **kwargs) -> int:
|
197 |
+
"""
|
198 |
+
Get the size of the input data.
|
199 |
+
|
200 |
+
Returns:
|
201 |
+
int: The size of the input data.
|
202 |
+
"""
|
203 |
+
pass
|
204 |
+
|
205 |
+
@abstractmethod
|
206 |
+
@contextmanager
|
207 |
+
def with_mem_ratio(self, mem_ratio=1.0) -> float:
|
208 |
+
"""
|
209 |
+
Context manager for training with a reduced memory ratio compared to the full memory usage.
|
210 |
+
|
211 |
+
Returns:
|
212 |
+
float: The exact memory ratio used during the forward pass.
|
213 |
+
"""
|
214 |
+
pass
|
215 |
+
|
216 |
+
def register_memory_controller(self, memory_controller: MemoryController):
|
217 |
+
self._memory_controller = memory_controller
|
218 |
+
|
219 |
+
def forward(self, *args, **kwargs):
|
220 |
+
if self._memory_controller is None or not torch.is_grad_enabled() or not self.training:
|
221 |
+
ret = super().forward(*args, **kwargs)
|
222 |
+
else:
|
223 |
+
input_size = self._get_input_size(*args, **kwargs)
|
224 |
+
mem_ratio = self._memory_controller.get_mem_ratio(input_size)
|
225 |
+
with self.with_mem_ratio(mem_ratio) as exact_mem_ratio:
|
226 |
+
ret = super().forward(*args, **kwargs)
|
227 |
+
self._memory_controller.update_run_states(input_size, exact_mem_ratio)
|
228 |
+
return ret
|
trellis/utils/grad_clip_utils.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import torch.utils
|
5 |
+
|
6 |
+
|
7 |
+
class AdaptiveGradClipper:
|
8 |
+
"""
|
9 |
+
Adaptive gradient clipping for training.
|
10 |
+
"""
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
max_norm=None,
|
14 |
+
clip_percentile=95.0,
|
15 |
+
buffer_size=1000,
|
16 |
+
):
|
17 |
+
self.max_norm = max_norm
|
18 |
+
self.clip_percentile = clip_percentile
|
19 |
+
self.buffer_size = buffer_size
|
20 |
+
|
21 |
+
self._grad_norm = np.zeros(buffer_size, dtype=np.float32)
|
22 |
+
self._max_norm = max_norm
|
23 |
+
self._buffer_ptr = 0
|
24 |
+
self._buffer_length = 0
|
25 |
+
|
26 |
+
def __repr__(self):
|
27 |
+
return f'AdaptiveGradClipper(max_norm={self.max_norm}, clip_percentile={self.clip_percentile})'
|
28 |
+
|
29 |
+
def state_dict(self):
|
30 |
+
return {
|
31 |
+
'grad_norm': self._grad_norm,
|
32 |
+
'max_norm': self._max_norm,
|
33 |
+
'buffer_ptr': self._buffer_ptr,
|
34 |
+
'buffer_length': self._buffer_length,
|
35 |
+
}
|
36 |
+
|
37 |
+
def load_state_dict(self, state_dict):
|
38 |
+
self._grad_norm = state_dict['grad_norm']
|
39 |
+
self._max_norm = state_dict['max_norm']
|
40 |
+
self._buffer_ptr = state_dict['buffer_ptr']
|
41 |
+
self._buffer_length = state_dict['buffer_length']
|
42 |
+
|
43 |
+
def log(self):
|
44 |
+
return {
|
45 |
+
'max_norm': self._max_norm,
|
46 |
+
}
|
47 |
+
|
48 |
+
def __call__(self, parameters, norm_type=2.0, error_if_nonfinite=False, foreach=None):
|
49 |
+
"""Clip the gradient norm of an iterable of parameters.
|
50 |
+
|
51 |
+
The norm is computed over all gradients together, as if they were
|
52 |
+
concatenated into a single vector. Gradients are modified in-place.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
|
56 |
+
single Tensor that will have gradients normalized
|
57 |
+
norm_type (float): type of the used p-norm. Can be ``'inf'`` for
|
58 |
+
infinity norm.
|
59 |
+
error_if_nonfinite (bool): if True, an error is thrown if the total
|
60 |
+
norm of the gradients from :attr:`parameters` is ``nan``,
|
61 |
+
``inf``, or ``-inf``. Default: False (will switch to True in the future)
|
62 |
+
foreach (bool): use the faster foreach-based implementation.
|
63 |
+
If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
|
64 |
+
fall back to the slow implementation for other device types.
|
65 |
+
Default: ``None``
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
Total norm of the parameter gradients (viewed as a single vector).
|
69 |
+
"""
|
70 |
+
max_norm = self._max_norm if self._max_norm is not None else float('inf')
|
71 |
+
grad_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm, norm_type=norm_type, error_if_nonfinite=error_if_nonfinite, foreach=foreach)
|
72 |
+
|
73 |
+
if torch.isfinite(grad_norm):
|
74 |
+
self._grad_norm[self._buffer_ptr] = grad_norm
|
75 |
+
self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
|
76 |
+
self._buffer_length = min(self._buffer_length + 1, self.buffer_size)
|
77 |
+
if self._buffer_length == self.buffer_size:
|
78 |
+
self._max_norm = np.percentile(self._grad_norm, self.clip_percentile)
|
79 |
+
self._max_norm = min(self._max_norm, self.max_norm) if self.max_norm is not None else self._max_norm
|
80 |
+
|
81 |
+
return grad_norm
|
trellis/utils/loss_utils.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torch.autograd import Variable
|
4 |
+
from math import exp
|
5 |
+
from lpips import LPIPS
|
6 |
+
|
7 |
+
|
8 |
+
def smooth_l1_loss(pred, target, beta=1.0):
|
9 |
+
diff = torch.abs(pred - target)
|
10 |
+
loss = torch.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
|
11 |
+
return loss.mean()
|
12 |
+
|
13 |
+
|
14 |
+
def l1_loss(network_output, gt):
|
15 |
+
return torch.abs((network_output - gt)).mean()
|
16 |
+
|
17 |
+
|
18 |
+
def l2_loss(network_output, gt):
|
19 |
+
return ((network_output - gt) ** 2).mean()
|
20 |
+
|
21 |
+
|
22 |
+
def gaussian(window_size, sigma):
|
23 |
+
gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
|
24 |
+
return gauss / gauss.sum()
|
25 |
+
|
26 |
+
|
27 |
+
def create_window(window_size, channel):
|
28 |
+
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
29 |
+
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
|
30 |
+
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
|
31 |
+
return window
|
32 |
+
|
33 |
+
|
34 |
+
def psnr(img1, img2, max_val=1.0):
|
35 |
+
mse = F.mse_loss(img1, img2)
|
36 |
+
return 20 * torch.log10(max_val / torch.sqrt(mse))
|
37 |
+
|
38 |
+
|
39 |
+
def ssim(img1, img2, window_size=11, size_average=True):
|
40 |
+
channel = img1.size(-3)
|
41 |
+
window = create_window(window_size, channel)
|
42 |
+
|
43 |
+
if img1.is_cuda:
|
44 |
+
window = window.cuda(img1.get_device())
|
45 |
+
window = window.type_as(img1)
|
46 |
+
|
47 |
+
return _ssim(img1, img2, window, window_size, channel, size_average)
|
48 |
+
|
49 |
+
def _ssim(img1, img2, window, window_size, channel, size_average=True):
|
50 |
+
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
|
51 |
+
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
|
52 |
+
|
53 |
+
mu1_sq = mu1.pow(2)
|
54 |
+
mu2_sq = mu2.pow(2)
|
55 |
+
mu1_mu2 = mu1 * mu2
|
56 |
+
|
57 |
+
sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
|
58 |
+
sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
|
59 |
+
sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
|
60 |
+
|
61 |
+
C1 = 0.01 ** 2
|
62 |
+
C2 = 0.03 ** 2
|
63 |
+
|
64 |
+
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
|
65 |
+
|
66 |
+
if size_average:
|
67 |
+
return ssim_map.mean()
|
68 |
+
else:
|
69 |
+
return ssim_map.mean(1).mean(1).mean(1)
|
70 |
+
|
71 |
+
|
72 |
+
loss_fn_vgg = None
|
73 |
+
def lpips(img1, img2, value_range=(0, 1)):
|
74 |
+
global loss_fn_vgg
|
75 |
+
if loss_fn_vgg is None:
|
76 |
+
loss_fn_vgg = LPIPS(net='vgg').cuda().eval()
|
77 |
+
# normalize to [-1, 1]
|
78 |
+
img1 = (img1 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1
|
79 |
+
img2 = (img2 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1
|
80 |
+
return loss_fn_vgg(img1, img2).mean()
|
81 |
+
|
82 |
+
|
83 |
+
def normal_angle(pred, gt):
|
84 |
+
pred = pred * 2.0 - 1.0
|
85 |
+
gt = gt * 2.0 - 1.0
|
86 |
+
norms = pred.norm(dim=-1) * gt.norm(dim=-1)
|
87 |
+
cos_sim = (pred * gt).sum(-1) / (norms + 1e-9)
|
88 |
+
cos_sim = torch.clamp(cos_sim, -1.0, 1.0)
|
89 |
+
ang = torch.rad2deg(torch.acos(cos_sim[norms > 1e-9])).mean()
|
90 |
+
if ang.isnan():
|
91 |
+
return -1
|
92 |
+
return ang
|