cavargas10 committed on
Commit d3296f5 · verified · 1 Parent(s): 69b99a0

Upload 5 files

trellis/utils/data_utils.py ADDED
@@ -0,0 +1,226 @@
+from typing import *
+import math
+import torch
+import numpy as np
+from torch.utils.data import Sampler, Dataset, DataLoader, DistributedSampler
+import torch.distributed as dist
+
+
+def recursive_to_device(
+    data: Any,
+    device: torch.device,
+    non_blocking: bool = False,
+) -> Any:
+    """
+    Recursively move all tensors in a data structure to a device.
+    """
+    if hasattr(data, "to"):
+        return data.to(device, non_blocking=non_blocking)
+    elif isinstance(data, (list, tuple)):
+        return type(data)(recursive_to_device(d, device, non_blocking) for d in data)
+    elif isinstance(data, dict):
+        return {k: recursive_to_device(v, device, non_blocking) for k, v in data.items()}
+    else:
+        return data
+
+
+def load_balanced_group_indices(
+    load: List[int],
+    num_groups: int,
+    equal_size: bool = False,
+) -> List[List[int]]:
+    """
+    Split indices into groups with balanced load (greedy: heaviest item goes
+    to the currently lightest group).
+    """
+    if equal_size:
+        group_size = len(load) // num_groups
+    indices = np.argsort(load)[::-1]
+    groups = [[] for _ in range(num_groups)]
+    group_load = np.zeros(num_groups)
+    for idx in indices:
+        min_group_idx = np.argmin(group_load)
+        groups[min_group_idx].append(idx)
+        if equal_size and len(groups[min_group_idx]) == group_size:
+            group_load[min_group_idx] = float('inf')   # group is full; stop assigning to it
+        else:
+            group_load[min_group_idx] += load[idx]
+    return groups
+
+
+def cycle(data_loader: DataLoader) -> Iterator:
+    """Yield batches forever, keeping the sampler's resume state up to date."""
+    while True:
+        for data in data_loader:
+            if isinstance(data_loader.sampler, ResumableSampler):
+                data_loader.sampler.idx += data_loader.batch_size   # type: ignore[attr-defined]
+            yield data
+        if isinstance(data_loader.sampler, DistributedSampler):
+            data_loader.sampler.epoch += 1
+        if isinstance(data_loader.sampler, ResumableSampler):
+            data_loader.sampler.epoch += 1
+            data_loader.sampler.idx = 0
+
+
+class ResumableSampler(Sampler):
+    """
+    Distributed sampler that is resumable.
+
+    The rank and world size are retrieved from the current distributed group.
+
+    Args:
+        dataset: Dataset used for sampling.
+        shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
+            indices.
+        seed (int, optional): random seed used to shuffle the sampler if
+            :attr:`shuffle=True`. This number should be identical across all
+            processes in the distributed group. Default: ``0``.
+        drop_last (bool, optional): if ``True``, then the sampler will drop the
+            tail of the data to make it evenly divisible across the number of
+            replicas. If ``False``, the sampler will add extra indices to make
+            the data evenly divisible across the replicas. Default: ``False``.
+    """
+
+    def __init__(
+        self,
+        dataset: Dataset,
+        shuffle: bool = True,
+        seed: int = 0,
+        drop_last: bool = False,
+    ) -> None:
+        self.dataset = dataset
+        self.epoch = 0
+        self.idx = 0
+        self.drop_last = drop_last
+        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
+        self.rank = dist.get_rank() if dist.is_initialized() else 0
+        # If the dataset length is evenly divisible by # of replicas, then there
+        # is no need to drop any data, since the dataset will be split equally.
+        if self.drop_last and len(self.dataset) % self.world_size != 0:   # type: ignore[arg-type]
+            # Split to nearest available length that is evenly divisible.
+            # This is to ensure each rank receives the same amount of data when
+            # using this Sampler.
+            self.num_samples = math.ceil(
+                (len(self.dataset) - self.world_size) / self.world_size   # type: ignore[arg-type]
+            )
+        else:
+            self.num_samples = math.ceil(len(self.dataset) / self.world_size)   # type: ignore[arg-type]
+        self.total_size = self.num_samples * self.world_size
+        self.shuffle = shuffle
+        self.seed = seed
+
+    def __iter__(self) -> Iterator:
+        if self.shuffle:
+            # deterministically shuffle based on epoch and seed
+            g = torch.Generator()
+            g.manual_seed(self.seed + self.epoch)
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()   # type: ignore[arg-type]
+        else:
+            indices = list(range(len(self.dataset)))   # type: ignore[arg-type]
+
+        if not self.drop_last:
+            # add extra samples to make it evenly divisible
+            padding_size = self.total_size - len(indices)
+            if padding_size <= len(indices):
+                indices += indices[:padding_size]
+            else:
+                indices += (indices * math.ceil(padding_size / len(indices)))[
+                    :padding_size
+                ]
+        else:
+            # remove tail of data to make it evenly divisible.
+            indices = indices[: self.total_size]
+        assert len(indices) == self.total_size
+
+        # subsample
+        indices = indices[self.rank : self.total_size : self.world_size]
+
+        # resume from previous state
+        indices = indices[self.idx:]
+
+        return iter(indices)
+
+    def __len__(self) -> int:
+        return self.num_samples
+
+    def state_dict(self) -> Dict[str, int]:
+        return {
+            'epoch': self.epoch,
+            'idx': self.idx,
+        }
+
+    def load_state_dict(self, state_dict):
+        self.epoch = state_dict['epoch']
+        self.idx = state_dict['idx']
+
+
+class BalancedResumableSampler(ResumableSampler):
+    """
+    Distributed sampler that is resumable and balances the load among the processes.
+
+    The rank and world size are retrieved from the current distributed group.
+
+    Args:
+        dataset: Dataset used for sampling. Must expose a ``loads`` attribute
+            giving the per-sample load.
+        shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
+            indices.
+        seed (int, optional): random seed used to shuffle the sampler if
+            :attr:`shuffle=True`. This number should be identical across all
+            processes in the distributed group. Default: ``0``.
+        drop_last (bool, optional): if ``True``, then the sampler will drop the
+            tail of the data to make it evenly divisible across the number of
+            replicas. If ``False``, the sampler will add extra indices to make
+            the data evenly divisible across the replicas. Default: ``False``.
+        batch_size (int, optional): Number of samples per rank in each global
+            batch. Default: ``1``.
+    """
+
+    def __init__(
+        self,
+        dataset: Dataset,
+        shuffle: bool = True,
+        seed: int = 0,
+        drop_last: bool = False,
+        batch_size: int = 1,
+    ) -> None:
+        assert hasattr(dataset, 'loads'), 'Dataset must have "loads" attribute to use BalancedResumableSampler'
+        super().__init__(dataset, shuffle, seed, drop_last)
+        self.batch_size = batch_size
+        self.loads = dataset.loads
+
+    def __iter__(self) -> Iterator:
+        if self.shuffle:
+            # deterministically shuffle based on epoch and seed
+            g = torch.Generator()
+            g.manual_seed(self.seed + self.epoch)
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()   # type: ignore[arg-type]
+        else:
+            indices = list(range(len(self.dataset)))   # type: ignore[arg-type]
+
+        if not self.drop_last:
+            # add extra samples to make it evenly divisible
+            padding_size = self.total_size - len(indices)
+            if padding_size <= len(indices):
+                indices += indices[:padding_size]
+            else:
+                indices += (indices * math.ceil(padding_size / len(indices)))[
+                    :padding_size
+                ]
+        else:
+            # remove tail of data to make it evenly divisible.
+            indices = indices[: self.total_size]
+        assert len(indices) == self.total_size
+
+        # balance load among processes
+        num_batches = len(indices) // (self.batch_size * self.world_size)
+        balanced_indices = []
+        for i in range(num_batches):
+            start_idx = i * self.batch_size * self.world_size
+            end_idx = (i + 1) * self.batch_size * self.world_size
+            batch_indices = indices[start_idx:end_idx]
+            batch_loads = [self.loads[idx] for idx in batch_indices]
+            groups = load_balanced_group_indices(batch_loads, self.world_size, equal_size=True)
+            balanced_indices.extend([batch_indices[j] for j in groups[self.rank]])
+
+        # resume from previous state
+        indices = balanced_indices[self.idx:]
+
+        return iter(indices)
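
A minimal sketch of how these pieces compose: the dataset exposes a per-sample `loads` list, the sampler balances each global batch across ranks, and `cycle` turns the loader into an infinite, resumable stream. `ToyDataset` and the numbers are illustrative, not part of the upload; on a single (non-distributed) process the sampler degenerates to rank 0 of world size 1.

import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Toy dataset whose per-sample cost is exposed via `loads`."""
    def __init__(self, n=64):
        self.data = [torch.randn(3) for _ in range(n)]
        self.loads = [i % 7 + 1 for i in range(n)]   # proxy for per-sample cost
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]

dataset = ToyDataset()
sampler = BalancedResumableSampler(dataset, shuffle=True, seed=0, batch_size=4)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)
stream = cycle(loader)

batch = next(stream)            # advances sampler.idx by batch_size
state = sampler.state_dict()    # {'epoch': 0, 'idx': 4}
# ... later, to resume mid-epoch:
sampler.load_state_dict(state)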
trellis/utils/dist_utils.py ADDED
@@ -0,0 +1,93 @@
+import os
+import io
+from contextlib import contextmanager
+import torch
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+
+def setup_dist(rank, local_rank, world_size, master_addr, master_port):
+    """Initialize the NCCL process group and pin this process to its GPU."""
+    os.environ['MASTER_ADDR'] = master_addr
+    os.environ['MASTER_PORT'] = master_port
+    os.environ['WORLD_SIZE'] = str(world_size)
+    os.environ['RANK'] = str(rank)
+    os.environ['LOCAL_RANK'] = str(local_rank)
+    torch.cuda.set_device(local_rank)
+    dist.init_process_group('nccl', rank=rank, world_size=world_size)
+
+
+def read_file_dist(path):
+    """
+    Read a binary file in a distributed setting.
+    The file is read only once, by the rank 0 process, and broadcast to the
+    other processes.
+
+    Returns:
+        data (io.BytesIO): The binary data read from the file.
+    """
+    if dist.is_initialized() and dist.get_world_size() > 1:
+        # read file on rank 0
+        size = torch.LongTensor(1).cuda()
+        if dist.get_rank() == 0:
+            with open(path, 'rb') as f:
+                data = f.read()
+            data = torch.ByteTensor(
+                torch.UntypedStorage.from_buffer(data, dtype=torch.uint8)
+            ).cuda()
+            size[0] = data.shape[0]
+        # broadcast size
+        dist.broadcast(size, src=0)
+        if dist.get_rank() != 0:
+            data = torch.ByteTensor(size[0].item()).cuda()
+        # broadcast data
+        dist.broadcast(data, src=0)
+        # convert to io.BytesIO
+        data = data.cpu().numpy().tobytes()
+        data = io.BytesIO(data)
+        return data
+    else:
+        with open(path, 'rb') as f:
+            data = f.read()
+        data = io.BytesIO(data)
+        return data
+
+
+def unwrap_dist(model):
+    """
+    Unwrap the model from distributed training.
+    """
+    if isinstance(model, DDP):
+        return model.module
+    return model
+
+
+@contextmanager
+def master_first():
+    """
+    A context manager that ensures the master process executes first.
+    """
+    if not dist.is_initialized():
+        yield
+    else:
+        if dist.get_rank() == 0:
+            yield
+            dist.barrier()
+        else:
+            dist.barrier()
+            yield
+
+
+@contextmanager
+def local_master_first():
+    """
+    A context manager that ensures the local master process executes first.
+    """
+    if not dist.is_initialized():
+        yield
+    else:
+        if dist.get_rank() % torch.cuda.device_count() == 0:
+            yield
+            dist.barrier()
+        else:
+            dist.barrier()
+            yield
+
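A hedged sketch of the intended call pattern under a torchrun-style launch. The environment-variable names are the standard torch.distributed ones; the checkpoint path is illustrative, and `prepare_data_if_missing` is a hypothetical helper standing in for any once-per-node setup work.

import os
import torch

rank = int(os.environ.get('RANK', 0))
local_rank = int(os.environ.get('LOCAL_RANK', 0))
world_size = int(os.environ.get('WORLD_SIZE', 1))
setup_dist(rank, local_rank, world_size,
           os.environ.get('MASTER_ADDR', 'localhost'),
           os.environ.get('MASTER_PORT', '29500'))

# Download/extract datasets once per node; the other local ranks wait at the barrier.
with local_master_first():
    prepare_data_if_missing()   # hypothetical helper

# Every rank gets the same bytes; only rank 0 touches the filesystem.
ckpt = torch.load(read_file_dist('checkpoint.pt'), map_location='cpu')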
trellis/utils/elastic_utils.py ADDED
@@ -0,0 +1,228 @@
+from abc import abstractmethod
+from contextlib import contextmanager
+from typing import Tuple
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+class MemoryController:
+    """
+    Base class for memory management during training.
+    """
+
+    _last_input_size = None
+    _last_mem_ratio = []
+
+    @contextmanager
+    def record(self):
+        yield
+
+    def update_run_states(self, input_size=None, mem_ratio=None):
+        if self._last_input_size is None:
+            self._last_input_size = input_size
+        elif self._last_input_size != input_size:
+            raise ValueError(
+                f'Input size should not change across ElasticModules: '
+                f'{self._last_input_size} vs {input_size}.'
+            )
+        self._last_mem_ratio.append(mem_ratio)
+
+    @abstractmethod
+    def get_mem_ratio(self, input_size):
+        pass
+
+    @abstractmethod
+    def state_dict(self):
+        pass
+
+    @abstractmethod
+    def log(self):
+        pass
+
+
+class LinearMemoryController(MemoryController):
+    """
+    A simple controller for memory management during training.
+    The memory usage is modeled as a linear function of:
+        - the number of input parameters
+        - the ratio of memory the model uses compared to the maximum usage (with no checkpointing)
+            memory_usage = k * input_size * mem_ratio + b
+    The controller tracks the memory usage and predicts the memory ratio that
+    keeps the usage under a target fraction of the available memory.
+    """
+    def __init__(
+        self,
+        buffer_size=1000,
+        update_every=500,
+        target_ratio=0.8,
+        available_memory=None,
+        max_mem_ratio_start=0.1,
+        params=None,
+        device=None
+    ):
+        self.buffer_size = buffer_size
+        self.update_every = update_every
+        self.target_ratio = target_ratio
+        self.device = device or torch.cuda.current_device()
+        self.available_memory = available_memory or torch.cuda.get_device_properties(self.device).total_memory / 1024**3
+
+        self._memory = np.zeros(buffer_size, dtype=np.float32)
+        self._input_size = np.zeros(buffer_size, dtype=np.float32)
+        self._mem_ratio = np.zeros(buffer_size, dtype=np.float32)
+        self._buffer_ptr = 0
+        self._buffer_length = 0
+        self._params = tuple(params) if params is not None else (0.0, 0.0)
+        self._max_mem_ratio = max_mem_ratio_start
+        self.step = 0
+
+    def __repr__(self):
+        return f'LinearMemoryController(target_ratio={self.target_ratio}, available_memory={self.available_memory})'
+
+    def _add_sample(self, memory, input_size, mem_ratio):
+        self._memory[self._buffer_ptr] = memory
+        self._input_size[self._buffer_ptr] = input_size
+        self._mem_ratio[self._buffer_ptr] = mem_ratio
+        self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
+        self._buffer_length = min(self._buffer_length + 1, self.buffer_size)
+
+    @contextmanager
+    def record(self):
+        torch.cuda.reset_peak_memory_stats(self.device)
+        self._last_input_size = None
+        self._last_mem_ratio = []
+        yield
+        self._last_memory = torch.cuda.max_memory_allocated(self.device) / 1024**3
+        self._last_mem_ratio = sum(self._last_mem_ratio) / len(self._last_mem_ratio)
+        self._add_sample(self._last_memory, self._last_input_size, self._last_mem_ratio)
+        self.step += 1
+        if self.step % self.update_every == 0:
+            self._max_mem_ratio = min(1.0, self._max_mem_ratio + 0.1)
+            self._fit_params()
+
+    def _fit_params(self):
+        memory_usage = self._memory[:self._buffer_length]
+        input_size = self._input_size[:self._buffer_length]
+        mem_ratio = self._mem_ratio[:self._buffer_length]
+
+        x = input_size * mem_ratio
+        y = memory_usage
+        k, b = np.polyfit(x, y, 1)
+        self._params = (k, b)
+        # self._visualize()
+
+    def _visualize(self):
+        import matplotlib.pyplot as plt
+        memory_usage = self._memory[:self._buffer_length]
+        input_size = self._input_size[:self._buffer_length]
+        mem_ratio = self._mem_ratio[:self._buffer_length]
+        k, b = self._params
+
+        plt.scatter(input_size * mem_ratio, memory_usage, c=mem_ratio, cmap='viridis')
+        x = np.array([0.0, 20000.0])
+        plt.plot(x, k * x + b, c='r')
+        plt.savefig(f'linear_memory_controller_{self.step}.png')
+        plt.cla()
+
+    def get_mem_ratio(self, input_size):
+        k, b = self._params
+        if k == 0:
+            return np.random.rand() * self._max_mem_ratio
+        pred = (self.available_memory * self.target_ratio - b) / (k * input_size)
+        return min(self._max_mem_ratio, max(0.0, pred))
+
+    def state_dict(self):
+        return {
+            'params': self._params,
+        }
+
+    def load_state_dict(self, state_dict):
+        self._params = tuple(state_dict['params'])
+
+    def log(self):
+        return {
+            'params/k': self._params[0],
+            'params/b': self._params[1],
+            'memory': self._last_memory,
+            'input_size': self._last_input_size,
+            'mem_ratio': self._last_mem_ratio,
+        }
+
+
+class ElasticModule(nn.Module):
+    """
+    Module for training with elastic memory management.
+    """
+    def __init__(self):
+        super().__init__()
+        self._memory_controller: MemoryController = None
+
+    @abstractmethod
+    def _get_input_size(self, *args, **kwargs) -> int:
+        """
+        Get the size of the input data.
+
+        Returns:
+            int: The size of the input data.
+        """
+        pass
+
+    @abstractmethod
+    def _forward_with_mem_ratio(self, *args, mem_ratio=0.0, **kwargs) -> Tuple[float, Tuple]:
+        """
+        Forward with a given memory ratio.
+        """
+        pass
+
+    def register_memory_controller(self, memory_controller: MemoryController):
+        self._memory_controller = memory_controller
+
+    def forward(self, *args, **kwargs):
+        if self._memory_controller is None or not torch.is_grad_enabled() or not self.training:
+            _, ret = self._forward_with_mem_ratio(*args, **kwargs)
+        else:
+            input_size = self._get_input_size(*args, **kwargs)
+            mem_ratio = self._memory_controller.get_mem_ratio(input_size)
+            mem_ratio, ret = self._forward_with_mem_ratio(*args, mem_ratio=mem_ratio, **kwargs)
+            self._memory_controller.update_run_states(input_size, mem_ratio)
+        return ret
+
+
+class ElasticModuleMixin:
+    """
+    Mixin for training with elastic memory management.
+    """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._memory_controller: MemoryController = None
+
+    @abstractmethod
+    def _get_input_size(self, *args, **kwargs) -> int:
+        """
+        Get the size of the input data.
+
+        Returns:
+            int: The size of the input data.
+        """
+        pass
+
+    @abstractmethod
+    @contextmanager
+    def with_mem_ratio(self, mem_ratio=1.0):
+        """
+        Context manager for running the forward pass with a reduced memory
+        ratio compared to the full memory usage.
+
+        Yields:
+            float: The exact memory ratio used during the forward pass.
+        """
+        pass
+
+    def register_memory_controller(self, memory_controller: MemoryController):
+        self._memory_controller = memory_controller
+
+    def forward(self, *args, **kwargs):
+        if self._memory_controller is None or not torch.is_grad_enabled() or not self.training:
+            ret = super().forward(*args, **kwargs)
+        else:
+            input_size = self._get_input_size(*args, **kwargs)
+            mem_ratio = self._memory_controller.get_mem_ratio(input_size)
+            with self.with_mem_ratio(mem_ratio) as exact_mem_ratio:
+                ret = super().forward(*args, **kwargs)
+            self._memory_controller.update_run_states(input_size, exact_mem_ratio)
+        return ret
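
One way to read the contract, following the controller's docstring: `mem_ratio` is the fraction of the no-checkpointing memory footprint, so 1.0 means no activation checkpointing and lower values mean more recomputation. Below is a hedged sketch of wiring the mixin onto a stand-in model; `Backbone`, the block count, and the input-size proxy are assumptions for illustration, and a CUDA device is required since the controller reads torch.cuda statistics.

from contextlib import contextmanager
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Backbone(nn.Module):
    """Stand-in model: a stack of blocks with optional checkpointing."""
    def __init__(self, n_blocks=12, dim=256):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_blocks))
        self._ckpt_blocks = 0
    def forward(self, x):
        for i, blk in enumerate(self.blocks):
            if self.training and i < self._ckpt_blocks:
                x = checkpoint(blk, x, use_reentrant=False)   # recompute in backward
            else:
                x = blk(x)
        return x

class ElasticBackbone(ElasticModuleMixin, Backbone):
    def _get_input_size(self, x):
        return x.shape[0] * x.shape[1]          # batch * feature proxy (assumption)
    @contextmanager
    def with_mem_ratio(self, mem_ratio=1.0):
        # checkpoint roughly (1 - mem_ratio) of the blocks
        k = round(len(self.blocks) * (1 - mem_ratio))
        self._ckpt_blocks = k
        try:
            yield 1 - k / len(self.blocks)      # exact ratio actually applied
        finally:
            self._ckpt_blocks = 0

model = ElasticBackbone().cuda().train()
controller = LinearMemoryController(target_ratio=0.8)
model.register_memory_controller(controller)
x = torch.randn(8, 256, device='cuda', requires_grad=True)
with controller.record():                       # captures peak memory of fwd+bwd
    loss = model(x).sum()
    loss.backward()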
trellis/utils/grad_clip_utils.py ADDED
@@ -0,0 +1,81 @@
+from typing import *
+import torch
+import numpy as np
+import torch.nn.utils
+
+
+class AdaptiveGradClipper:
+    """
+    Adaptive gradient clipping for training.
+
+    Keeps a rolling buffer of recent gradient norms; once the buffer is full,
+    the clipping threshold is set to a percentile of those norms (optionally
+    capped by ``max_norm``).
+    """
+    def __init__(
+        self,
+        max_norm=None,
+        clip_percentile=95.0,
+        buffer_size=1000,
+    ):
+        self.max_norm = max_norm
+        self.clip_percentile = clip_percentile
+        self.buffer_size = buffer_size
+
+        self._grad_norm = np.zeros(buffer_size, dtype=np.float32)
+        self._max_norm = max_norm
+        self._buffer_ptr = 0
+        self._buffer_length = 0
+
+    def __repr__(self):
+        return f'AdaptiveGradClipper(max_norm={self.max_norm}, clip_percentile={self.clip_percentile})'
+
+    def state_dict(self):
+        return {
+            'grad_norm': self._grad_norm,
+            'max_norm': self._max_norm,
+            'buffer_ptr': self._buffer_ptr,
+            'buffer_length': self._buffer_length,
+        }
+
+    def load_state_dict(self, state_dict):
+        self._grad_norm = state_dict['grad_norm']
+        self._max_norm = state_dict['max_norm']
+        self._buffer_ptr = state_dict['buffer_ptr']
+        self._buffer_length = state_dict['buffer_length']
+
+    def log(self):
+        return {
+            'max_norm': self._max_norm,
+        }
+
+    def __call__(self, parameters, norm_type=2.0, error_if_nonfinite=False, foreach=None):
+        """Clip the gradient norm of an iterable of parameters.
+
+        The norm is computed over all gradients together, as if they were
+        concatenated into a single vector. Gradients are modified in-place.
+
+        Args:
+            parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+                single Tensor that will have gradients normalized
+            norm_type (float): type of the used p-norm. Can be ``'inf'`` for
+                infinity norm.
+            error_if_nonfinite (bool): if True, an error is thrown if the total
+                norm of the gradients from :attr:`parameters` is ``nan``,
+                ``inf``, or ``-inf``. Default: False (will switch to True in the future)
+            foreach (bool): use the faster foreach-based implementation.
+                If ``None``, use the foreach implementation for CUDA and CPU native tensors
+                and silently fall back to the slow implementation for other device types.
+                Default: ``None``
+
+        Returns:
+            Total norm of the parameter gradients (viewed as a single vector).
+        """
+        max_norm = self._max_norm if self._max_norm is not None else float('inf')
+        grad_norm = torch.nn.utils.clip_grad_norm_(
+            parameters, max_norm=max_norm, norm_type=norm_type,
+            error_if_nonfinite=error_if_nonfinite, foreach=foreach,
+        )
+
+        if torch.isfinite(grad_norm):
+            self._grad_norm[self._buffer_ptr] = grad_norm
+            self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
+            self._buffer_length = min(self._buffer_length + 1, self.buffer_size)
+            if self._buffer_length == self.buffer_size:
+                self._max_norm = np.percentile(self._grad_norm, self.clip_percentile)
+                if self.max_norm is not None:
+                    self._max_norm = min(self._max_norm, self.max_norm)
+
+        return grad_norm
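
A short sketch of where the clipper sits in a training step: it is called exactly like torch.nn.utils.clip_grad_norm_, and once its buffer fills the effective max_norm tracks the chosen percentile of recent gradient norms. The model, data, and loop below are toy placeholders.

import torch
import torch.nn as nn

model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
clipper = AdaptiveGradClipper(max_norm=10.0, clip_percentile=95.0)

for _ in range(3):                            # toy training steps
    optimizer.zero_grad()
    x = torch.randn(32, 16)
    loss = model(x).pow(2).mean()
    loss.backward()
    grad_norm = clipper(model.parameters())   # clips in-place and records the norm
    optimizer.step()

state = clipper.state_dict()                  # adaptive threshold travels with checkpoints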
trellis/utils/loss_utils.py ADDED
@@ -0,0 +1,92 @@
+import torch
+import torch.nn.functional as F
+from math import exp
+from lpips import LPIPS
+
+
+def smooth_l1_loss(pred, target, beta=1.0):
+    diff = torch.abs(pred - target)
+    loss = torch.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
+    return loss.mean()
+
+
+def l1_loss(network_output, gt):
+    return torch.abs(network_output - gt).mean()
+
+
+def l2_loss(network_output, gt):
+    return ((network_output - gt) ** 2).mean()
+
+
+def gaussian(window_size, sigma):
+    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
+    return gauss / gauss.sum()
+
+
+def create_window(window_size, channel):
+    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
+    return window
+
+
+def psnr(img1, img2, max_val=1.0):
+    mse = F.mse_loss(img1, img2)
+    return 20 * torch.log10(max_val / torch.sqrt(mse))
+
+
+def ssim(img1, img2, window_size=11, size_average=True):
+    channel = img1.size(-3)
+    window = create_window(window_size, channel)
+
+    if img1.is_cuda:
+        window = window.cuda(img1.get_device())
+    window = window.type_as(img1)
+
+    return _ssim(img1, img2, window, window_size, channel, size_average)
+
+
+def _ssim(img1, img2, window, window_size, channel, size_average=True):
+    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
+    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
+
+    mu1_sq = mu1.pow(2)
+    mu2_sq = mu2.pow(2)
+    mu1_mu2 = mu1 * mu2
+
+    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
+    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
+    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
+
+    C1 = 0.01 ** 2
+    C2 = 0.03 ** 2
+
+    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+
+    if size_average:
+        return ssim_map.mean()
+    else:
+        return ssim_map.mean(1).mean(1).mean(1)
+
+
+loss_fn_vgg = None
+
+def lpips(img1, img2, value_range=(0, 1)):
+    # lazily build the LPIPS network once and reuse it
+    global loss_fn_vgg
+    if loss_fn_vgg is None:
+        loss_fn_vgg = LPIPS(net='vgg').cuda().eval()
+    # normalize to [-1, 1]
+    img1 = (img1 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1
+    img2 = (img2 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1
+    return loss_fn_vgg(img1, img2).mean()
+
+
+def normal_angle(pred, gt):
+    # map normals from [0, 1] back to [-1, 1]
+    pred = pred * 2.0 - 1.0
+    gt = gt * 2.0 - 1.0
+    norms = pred.norm(dim=-1) * gt.norm(dim=-1)
+    cos_sim = (pred * gt).sum(-1) / (norms + 1e-9)
+    cos_sim = torch.clamp(cos_sim, -1.0, 1.0)
+    ang = torch.rad2deg(torch.acos(cos_sim[norms > 1e-9])).mean()
+    if ang.isnan():
+        return -1
+    return ang
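
For reference, a quick check of the metrics on a pair of image batches. Shapes follow the (N, C, H, W) convention that the conv2d-based SSIM expects; lpips needs a CUDA device since the VGG network is created with .cuda(), so it is guarded below. The tensors are random placeholders.

import torch

img1 = torch.rand(4, 3, 64, 64)
img2 = torch.rand(4, 3, 64, 64)

print(psnr(img1, img2))                      # dB, higher is better
print(ssim(img1, img2))                      # 1.0 for identical images
print(l1_loss(img1, img2), smooth_l1_loss(img1, img2))
if torch.cuda.is_available():
    print(lpips(img1.cuda(), img2.cuda(), value_range=(0, 1)))   # lower is better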