cavargas10 commited on
Commit
12ea35b
·
verified ·
1 Parent(s): 3ff9d59

Update trellis/utils/general_utils.py

Browse files
Files changed (1) hide show
  1. trellis/utils/general_utils.py +201 -187
trellis/utils/general_utils.py CHANGED
@@ -1,187 +1,201 @@
1
- import numpy as np
2
- import cv2
3
- import torch
4
-
5
-
6
- # Dictionary utils
7
- def _dict_merge(dicta, dictb, prefix=''):
8
- """
9
- Merge two dictionaries.
10
- """
11
- assert isinstance(dicta, dict), 'input must be a dictionary'
12
- assert isinstance(dictb, dict), 'input must be a dictionary'
13
- dict_ = {}
14
- all_keys = set(dicta.keys()).union(set(dictb.keys()))
15
- for key in all_keys:
16
- if key in dicta.keys() and key in dictb.keys():
17
- if isinstance(dicta[key], dict) and isinstance(dictb[key], dict):
18
- dict_[key] = _dict_merge(dicta[key], dictb[key], prefix=f'{prefix}.{key}')
19
- else:
20
- raise ValueError(f'Duplicate key {prefix}.{key} found in both dictionaries. Types: {type(dicta[key])}, {type(dictb[key])}')
21
- elif key in dicta.keys():
22
- dict_[key] = dicta[key]
23
- else:
24
- dict_[key] = dictb[key]
25
- return dict_
26
-
27
-
28
- def dict_merge(dicta, dictb):
29
- """
30
- Merge two dictionaries.
31
- """
32
- return _dict_merge(dicta, dictb, prefix='')
33
-
34
-
35
- def dict_foreach(dic, func, special_func={}):
36
- """
37
- Recursively apply a function to all non-dictionary leaf values in a dictionary.
38
- """
39
- assert isinstance(dic, dict), 'input must be a dictionary'
40
- for key in dic.keys():
41
- if isinstance(dic[key], dict):
42
- dic[key] = dict_foreach(dic[key], func)
43
- else:
44
- if key in special_func.keys():
45
- dic[key] = special_func[key](dic[key])
46
- else:
47
- dic[key] = func(dic[key])
48
- return dic
49
-
50
-
51
- def dict_reduce(dicts, func, special_func={}):
52
- """
53
- Reduce a list of dictionaries. Leaf values must be scalars.
54
- """
55
- assert isinstance(dicts, list), 'input must be a list of dictionaries'
56
- assert all([isinstance(d, dict) for d in dicts]), 'input must be a list of dictionaries'
57
- assert len(dicts) > 0, 'input must be a non-empty list of dictionaries'
58
- all_keys = set([key for dict_ in dicts for key in dict_.keys()])
59
- reduced_dict = {}
60
- for key in all_keys:
61
- vlist = [dict_[key] for dict_ in dicts if key in dict_.keys()]
62
- if isinstance(vlist[0], dict):
63
- reduced_dict[key] = dict_reduce(vlist, func, special_func)
64
- else:
65
- if key in special_func.keys():
66
- reduced_dict[key] = special_func[key](vlist)
67
- else:
68
- reduced_dict[key] = func(vlist)
69
- return reduced_dict
70
-
71
-
72
- def dict_any(dic, func):
73
- """
74
- Recursively apply a function to all non-dictionary leaf values in a dictionary.
75
- """
76
- assert isinstance(dic, dict), 'input must be a dictionary'
77
- for key in dic.keys():
78
- if isinstance(dic[key], dict):
79
- if dict_any(dic[key], func):
80
- return True
81
- else:
82
- if func(dic[key]):
83
- return True
84
- return False
85
-
86
-
87
- def dict_all(dic, func):
88
- """
89
- Recursively apply a function to all non-dictionary leaf values in a dictionary.
90
- """
91
- assert isinstance(dic, dict), 'input must be a dictionary'
92
- for key in dic.keys():
93
- if isinstance(dic[key], dict):
94
- if not dict_all(dic[key], func):
95
- return False
96
- else:
97
- if not func(dic[key]):
98
- return False
99
- return True
100
-
101
-
102
- def dict_flatten(dic, sep='.'):
103
- """
104
- Flatten a nested dictionary into a dictionary with no nested dictionaries.
105
- """
106
- assert isinstance(dic, dict), 'input must be a dictionary'
107
- flat_dict = {}
108
- for key in dic.keys():
109
- if isinstance(dic[key], dict):
110
- sub_dict = dict_flatten(dic[key], sep=sep)
111
- for sub_key in sub_dict.keys():
112
- flat_dict[str(key) + sep + str(sub_key)] = sub_dict[sub_key]
113
- else:
114
- flat_dict[key] = dic[key]
115
- return flat_dict
116
-
117
-
118
- def make_grid(images, nrow=None, ncol=None, aspect_ratio=None):
119
- num_images = len(images)
120
- if nrow is None and ncol is None:
121
- if aspect_ratio is not None:
122
- nrow = int(np.round(np.sqrt(num_images / aspect_ratio)))
123
- else:
124
- nrow = int(np.sqrt(num_images))
125
- ncol = (num_images + nrow - 1) // nrow
126
- elif nrow is None and ncol is not None:
127
- nrow = (num_images + ncol - 1) // ncol
128
- elif nrow is not None and ncol is None:
129
- ncol = (num_images + nrow - 1) // nrow
130
- else:
131
- assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images'
132
-
133
- grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype)
134
- for i, img in enumerate(images):
135
- row = i // ncol
136
- col = i % ncol
137
- grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img
138
- return grid
139
-
140
-
141
- def notes_on_image(img, notes=None):
142
- img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0)
143
- img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
144
- if notes is not None:
145
- img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1)
146
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
147
- return img
148
-
149
-
150
- def save_image_with_notes(img, path, notes=None):
151
- """
152
- Save an image with notes.
153
- """
154
- if isinstance(img, torch.Tensor):
155
- img = img.cpu().numpy().transpose(1, 2, 0)
156
- if img.dtype == np.float32 or img.dtype == np.float64:
157
- img = np.clip(img * 255, 0, 255).astype(np.uint8)
158
- img = notes_on_image(img, notes)
159
- cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
160
-
161
-
162
- # debug utils
163
-
164
- def atol(x, y):
165
- """
166
- Absolute tolerance.
167
- """
168
- return torch.abs(x - y)
169
-
170
-
171
- def rtol(x, y):
172
- """
173
- Relative tolerance.
174
- """
175
- return torch.abs(x - y) / torch.clamp_min(torch.maximum(torch.abs(x), torch.abs(y)), 1e-12)
176
-
177
-
178
- # print utils
179
- def indent(s, n=4):
180
- """
181
- Indent a string.
182
- """
183
- lines = s.split('\n')
184
- for i in range(1, len(lines)):
185
- lines[i] = ' ' * n + lines[i]
186
- return '\n'.join(lines)
187
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import numpy as np
3
+ import cv2
4
+ import torch
5
+ import contextlib
6
+
7
+
8
+ # Dictionary utils
9
+ def _dict_merge(dicta, dictb, prefix=''):
10
+ """
11
+ Merge two dictionaries.
12
+ """
13
+ assert isinstance(dicta, dict), 'input must be a dictionary'
14
+ assert isinstance(dictb, dict), 'input must be a dictionary'
15
+ dict_ = {}
16
+ all_keys = set(dicta.keys()).union(set(dictb.keys()))
17
+ for key in all_keys:
18
+ if key in dicta.keys() and key in dictb.keys():
19
+ if isinstance(dicta[key], dict) and isinstance(dictb[key], dict):
20
+ dict_[key] = _dict_merge(dicta[key], dictb[key], prefix=f'{prefix}.{key}')
21
+ else:
22
+ raise ValueError(f'Duplicate key {prefix}.{key} found in both dictionaries. Types: {type(dicta[key])}, {type(dictb[key])}')
23
+ elif key in dicta.keys():
24
+ dict_[key] = dicta[key]
25
+ else:
26
+ dict_[key] = dictb[key]
27
+ return dict_
28
+
29
+
30
+ def dict_merge(dicta, dictb):
31
+ """
32
+ Merge two dictionaries.
33
+ """
34
+ return _dict_merge(dicta, dictb, prefix='')
35
+
36
+
37
+ def dict_foreach(dic, func, special_func={}):
38
+ """
39
+ Recursively apply a function to all non-dictionary leaf values in a dictionary.
40
+ """
41
+ assert isinstance(dic, dict), 'input must be a dictionary'
42
+ for key in dic.keys():
43
+ if isinstance(dic[key], dict):
44
+ dic[key] = dict_foreach(dic[key], func)
45
+ else:
46
+ if key in special_func.keys():
47
+ dic[key] = special_func[key](dic[key])
48
+ else:
49
+ dic[key] = func(dic[key])
50
+ return dic
51
+
52
+
53
+ def dict_reduce(dicts, func, special_func={}):
54
+ """
55
+ Reduce a list of dictionaries. Leaf values must be scalars.
56
+ """
57
+ assert isinstance(dicts, list), 'input must be a list of dictionaries'
58
+ assert all([isinstance(d, dict) for d in dicts]), 'input must be a list of dictionaries'
59
+ assert len(dicts) > 0, 'input must be a non-empty list of dictionaries'
60
+ all_keys = set([key for dict_ in dicts for key in dict_.keys()])
61
+ reduced_dict = {}
62
+ for key in all_keys:
63
+ vlist = [dict_[key] for dict_ in dicts if key in dict_.keys()]
64
+ if isinstance(vlist[0], dict):
65
+ reduced_dict[key] = dict_reduce(vlist, func, special_func)
66
+ else:
67
+ if key in special_func.keys():
68
+ reduced_dict[key] = special_func[key](vlist)
69
+ else:
70
+ reduced_dict[key] = func(vlist)
71
+ return reduced_dict
72
+
73
+
74
+ def dict_any(dic, func):
75
+ """
76
+ Recursively apply a function to all non-dictionary leaf values in a dictionary.
77
+ """
78
+ assert isinstance(dic, dict), 'input must be a dictionary'
79
+ for key in dic.keys():
80
+ if isinstance(dic[key], dict):
81
+ if dict_any(dic[key], func):
82
+ return True
83
+ else:
84
+ if func(dic[key]):
85
+ return True
86
+ return False
87
+
88
+
89
+ def dict_all(dic, func):
90
+ """
91
+ Recursively apply a function to all non-dictionary leaf values in a dictionary.
92
+ """
93
+ assert isinstance(dic, dict), 'input must be a dictionary'
94
+ for key in dic.keys():
95
+ if isinstance(dic[key], dict):
96
+ if not dict_all(dic[key], func):
97
+ return False
98
+ else:
99
+ if not func(dic[key]):
100
+ return False
101
+ return True
102
+
103
+
104
+ def dict_flatten(dic, sep='.'):
105
+ """
106
+ Flatten a nested dictionary into a dictionary with no nested dictionaries.
107
+ """
108
+ assert isinstance(dic, dict), 'input must be a dictionary'
109
+ flat_dict = {}
110
+ for key in dic.keys():
111
+ if isinstance(dic[key], dict):
112
+ sub_dict = dict_flatten(dic[key], sep=sep)
113
+ for sub_key in sub_dict.keys():
114
+ flat_dict[str(key) + sep + str(sub_key)] = sub_dict[sub_key]
115
+ else:
116
+ flat_dict[key] = dic[key]
117
+ return flat_dict
118
+
119
+
120
+ # Context utils
121
+ @contextlib.contextmanager
122
+ def nested_contexts(*contexts):
123
+ with contextlib.ExitStack() as stack:
124
+ for ctx in contexts:
125
+ stack.enter_context(ctx())
126
+ yield
127
+
128
+
129
+ # Image utils
130
+ def make_grid(images, nrow=None, ncol=None, aspect_ratio=None):
131
+ num_images = len(images)
132
+ if nrow is None and ncol is None:
133
+ if aspect_ratio is not None:
134
+ nrow = int(np.round(np.sqrt(num_images / aspect_ratio)))
135
+ else:
136
+ nrow = int(np.sqrt(num_images))
137
+ ncol = (num_images + nrow - 1) // nrow
138
+ elif nrow is None and ncol is not None:
139
+ nrow = (num_images + ncol - 1) // ncol
140
+ elif nrow is not None and ncol is None:
141
+ ncol = (num_images + nrow - 1) // nrow
142
+ else:
143
+ assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images'
144
+
145
+ if images[0].ndim == 2:
146
+ grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1]), dtype=images[0].dtype)
147
+ else:
148
+ grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype)
149
+ for i, img in enumerate(images):
150
+ row = i // ncol
151
+ col = i % ncol
152
+ grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img
153
+ return grid
154
+
155
+
156
+ def notes_on_image(img, notes=None):
157
+ img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0)
158
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
159
+ if notes is not None:
160
+ img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1)
161
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
162
+ return img
163
+
164
+
165
+ def save_image_with_notes(img, path, notes=None):
166
+ """
167
+ Save an image with notes.
168
+ """
169
+ if isinstance(img, torch.Tensor):
170
+ img = img.cpu().numpy().transpose(1, 2, 0)
171
+ if img.dtype == np.float32 or img.dtype == np.float64:
172
+ img = np.clip(img * 255, 0, 255).astype(np.uint8)
173
+ img = notes_on_image(img, notes)
174
+ cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
175
+
176
+
177
+ # debug utils
178
+
179
+ def atol(x, y):
180
+ """
181
+ Absolute tolerance.
182
+ """
183
+ return torch.abs(x - y)
184
+
185
+
186
+ def rtol(x, y):
187
+ """
188
+ Relative tolerance.
189
+ """
190
+ return torch.abs(x - y) / torch.clamp_min(torch.maximum(torch.abs(x), torch.abs(y)), 1e-12)
191
+
192
+
193
+ # print utils
194
+ def indent(s, n=4):
195
+ """
196
+ Indent a string.
197
+ """
198
+ lines = s.split('\n')
199
+ for i in range(1, len(lines)):
200
+ lines[i] = ' ' * n + lines[i]
201
+ return '\n'.join(lines)