cavargas10 committed on
Commit
43de49b
·
verified ·
1 Parent(s): fbb5fb5

Update trellis/datasets/structured_latent.py

Browse files
Files changed (1) hide show
  1. trellis/datasets/structured_latent.py +218 -218
trellis/datasets/structured_latent.py CHANGED
@@ -1,218 +1,218 @@
1
- import json
2
- import os
3
- from typing import *
4
- import numpy as np
5
- import torch
6
- import utils3d.torch
7
- from .components import StandardDatasetBase, TextConditionedMixin, ImageConditionedMixin
8
- from ..modules.sparse.basic import SparseTensor
9
- from .. import models
10
- from ..utils.render_utils import get_renderer
11
- from ..utils.dist_utils import read_file_dist
12
- from ..utils.data_utils import load_balanced_group_indices
13
-
14
-
15
- class SLatVisMixin:
16
- def __init__(
17
- self,
18
- *args,
19
- pretrained_slat_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16',
20
- slat_dec_path: Optional[str] = None,
21
- slat_dec_ckpt: Optional[str] = None,
22
- **kwargs
23
- ):
24
- super().__init__(*args, **kwargs)
25
- self.slat_dec = None
26
- self.pretrained_slat_dec = pretrained_slat_dec
27
- self.slat_dec_path = slat_dec_path
28
- self.slat_dec_ckpt = slat_dec_ckpt
29
-
30
- def _loading_slat_dec(self):
31
- if self.slat_dec is not None:
32
- return
33
- if self.slat_dec_path is not None:
34
- cfg = json.load(open(os.path.join(self.slat_dec_path, 'config.json'), 'r'))
35
- decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
36
- ckpt_path = os.path.join(self.slat_dec_path, 'ckpts', f'decoder_{self.slat_dec_ckpt}.pt')
37
- decoder.load_state_dict(torch.load(read_file_dist(ckpt_path), map_location='cpu', weights_only=True))
38
- else:
39
- decoder = models.from_pretrained(self.pretrained_slat_dec)
40
- self.slat_dec = decoder.cuda().eval()
41
-
42
- def _delete_slat_dec(self):
43
- del self.slat_dec
44
- self.slat_dec = None
45
-
46
- @torch.no_grad()
47
- def decode_latent(self, z, batch_size=4):
48
- self._loading_slat_dec()
49
- reps = []
50
- if self.normalization is not None:
51
- z = z * self.std.to(z.device) + self.mean.to(z.device)
52
- for i in range(0, z.shape[0], batch_size):
53
- reps.append(self.slat_dec(z[i:i+batch_size]))
54
- reps = sum(reps, [])
55
- self._delete_slat_dec()
56
- return reps
57
-
58
- @torch.no_grad()
59
- def visualize_sample(self, x_0: Union[SparseTensor, dict]):
60
- x_0 = x_0 if isinstance(x_0, SparseTensor) else x_0['x_0']
61
- reps = self.decode_latent(x_0.cuda())
62
-
63
- # Build camera
64
- yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
65
- yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
66
- yaws = [y + yaws_offset for y in yaws]
67
- pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
68
-
69
- exts = []
70
- ints = []
71
- for yaw, pitch in zip(yaws, pitch):
72
- orig = torch.tensor([
73
- np.sin(yaw) * np.cos(pitch),
74
- np.cos(yaw) * np.cos(pitch),
75
- np.sin(pitch),
76
- ]).float().cuda() * 2
77
- fov = torch.deg2rad(torch.tensor(40)).cuda()
78
- extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
79
- intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
80
- exts.append(extrinsics)
81
- ints.append(intrinsics)
82
-
83
- renderer = get_renderer(reps[0])
84
- images = []
85
- for representation in reps:
86
- image = torch.zeros(3, 1024, 1024).cuda()
87
- tile = [2, 2]
88
- for j, (ext, intr) in enumerate(zip(exts, ints)):
89
- res = renderer.render(representation, ext, intr)
90
- image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
91
- images.append(image)
92
- images = torch.stack(images)
93
-
94
- return images
95
-
96
-
97
- class SLat(SLatVisMixin, StandardDatasetBase):
98
- """
99
- structured latent dataset
100
-
101
- Args:
102
- roots (str): path to the dataset
103
- latent_model (str): name of the latent model
104
- min_aesthetic_score (float): minimum aesthetic score
105
- max_num_voxels (int): maximum number of voxels
106
- normalization (dict): normalization stats
107
- pretrained_slat_dec (str): name of the pretrained slat decoder
108
- slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec
109
- slat_dec_ckpt (str): name of the slat decoder checkpoint
110
- """
111
- def __init__(self,
112
- roots: str,
113
- *,
114
- latent_model: str,
115
- min_aesthetic_score: float = 5.0,
116
- max_num_voxels: int = 32768,
117
- normalization: Optional[dict] = None,
118
- pretrained_slat_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16',
119
- slat_dec_path: Optional[str] = None,
120
- slat_dec_ckpt: Optional[str] = None,
121
- ):
122
- self.normalization = normalization
123
- self.latent_model = latent_model
124
- self.min_aesthetic_score = min_aesthetic_score
125
- self.max_num_voxels = max_num_voxels
126
- self.value_range = (0, 1)
127
-
128
- super().__init__(
129
- roots,
130
- pretrained_slat_dec=pretrained_slat_dec,
131
- slat_dec_path=slat_dec_path,
132
- slat_dec_ckpt=slat_dec_ckpt,
133
- )
134
-
135
- self.loads = [self.metadata.loc[sha256, 'num_voxels'] for _, sha256 in self.instances]
136
-
137
- if self.normalization is not None:
138
- self.mean = torch.tensor(self.normalization['mean']).reshape(1, -1)
139
- self.std = torch.tensor(self.normalization['std']).reshape(1, -1)
140
-
141
- def filter_metadata(self, metadata):
142
- stats = {}
143
- metadata = metadata[metadata[f'latent_{self.latent_model}']]
144
- stats['With latent'] = len(metadata)
145
- metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
146
- stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
147
- metadata = metadata[metadata['num_voxels'] <= self.max_num_voxels]
148
- stats[f'Num voxels <= {self.max_num_voxels}'] = len(metadata)
149
- return metadata, stats
150
-
151
- def get_instance(self, root, instance):
152
- data = np.load(os.path.join(root, 'latents', self.latent_model, f'{instance}.npz'))
153
- coords = torch.tensor(data['coords']).int()
154
- feats = torch.tensor(data['feats']).float()
155
- if self.normalization is not None:
156
- feats = (feats - self.mean) / self.std
157
- return {
158
- 'coords': coords,
159
- 'feats': feats,
160
- }
161
-
162
- @staticmethod
163
- def collate_fn(batch, split_size=None):
164
- if split_size is None:
165
- group_idx = [list(range(len(batch)))]
166
- else:
167
- group_idx = load_balanced_group_indices([b['coords'].shape[0] for b in batch], split_size)
168
- packs = []
169
- for group in group_idx:
170
- sub_batch = [batch[i] for i in group]
171
- pack = {}
172
- coords = []
173
- feats = []
174
- layout = []
175
- start = 0
176
- for i, b in enumerate(sub_batch):
177
- coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1))
178
- feats.append(b['feats'])
179
- layout.append(slice(start, start + b['coords'].shape[0]))
180
- start += b['coords'].shape[0]
181
- coords = torch.cat(coords)
182
- feats = torch.cat(feats)
183
- pack['x_0'] = SparseTensor(
184
- coords=coords,
185
- feats=feats,
186
- )
187
- pack['x_0']._shape = torch.Size([len(group), *sub_batch[0]['feats'].shape[1:]])
188
- pack['x_0'].register_spatial_cache('layout', layout)
189
-
190
- # collate other data
191
- keys = [k for k in sub_batch[0].keys() if k not in ['coords', 'feats']]
192
- for k in keys:
193
- if isinstance(sub_batch[0][k], torch.Tensor):
194
- pack[k] = torch.stack([b[k] for b in sub_batch])
195
- elif isinstance(sub_batch[0][k], list):
196
- pack[k] = sum([b[k] for b in sub_batch], [])
197
- else:
198
- pack[k] = [b[k] for b in sub_batch]
199
-
200
- packs.append(pack)
201
-
202
- if split_size is None:
203
- return packs[0]
204
- return packs
205
-
206
-
207
- class TextConditionedSLat(TextConditionedMixin, SLat):
208
- """
209
- Text conditioned structured latent dataset
210
- """
211
- pass
212
-
213
-
214
- class ImageConditionedSLat(ImageConditionedMixin, SLat):
215
- """
216
- Image conditioned structured latent dataset
217
- """
218
- pass
 
1
+ import json
2
+ import os
3
+ from typing import *
4
+ import numpy as np
5
+ import torch
6
+ import utils3d.torch
7
+ from .components import StandardDatasetBase, TextConditionedMixin, ImageConditionedMixin
8
+ from ..modules.sparse.basic import SparseTensor
9
+ from .. import models
10
+ from ..utils.render_utils import get_renderer
11
+ from ..utils.dist_utils import read_file_dist
12
+ from ..utils.data_utils import load_balanced_group_indices
13
+
14
+
15
class SLatVisMixin:
    """Mixin that decodes structured latents and renders preview images.

    Lazily loads a SLat decoder (either from a local experiment directory
    or from a pretrained checkpoint name), decodes latents into renderable
    representations, and renders a 2x2 grid of random views per sample.
    """

    def __init__(
        self,
        *args,
        pretrained_slat_dec: str = 'cavargas10/TRELLIS/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16',
        slat_dec_path: Optional[str] = None,
        slat_dec_ckpt: Optional[str] = None,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.slat_dec = None  # decoder is loaded lazily on first use
        self.pretrained_slat_dec = pretrained_slat_dec
        self.slat_dec_path = slat_dec_path
        self.slat_dec_ckpt = slat_dec_ckpt

    def _loading_slat_dec(self):
        """Load the SLat decoder onto the GPU if not already loaded.

        A local experiment directory (``slat_dec_path``) takes precedence
        over the pretrained model name.
        """
        if self.slat_dec is not None:
            return
        if self.slat_dec_path is not None:
            # Fix: close the config file deterministically instead of relying
            # on the GC (original used bare json.load(open(...))).
            with open(os.path.join(self.slat_dec_path, 'config.json'), 'r') as f:
                cfg = json.load(f)
            decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
            ckpt_path = os.path.join(self.slat_dec_path, 'ckpts', f'decoder_{self.slat_dec_ckpt}.pt')
            decoder.load_state_dict(torch.load(read_file_dist(ckpt_path), map_location='cpu', weights_only=True))
        else:
            decoder = models.from_pretrained(self.pretrained_slat_dec)
        self.slat_dec = decoder.cuda().eval()

    def _delete_slat_dec(self):
        """Drop the decoder reference so GPU memory can be reclaimed."""
        del self.slat_dec
        self.slat_dec = None

    @torch.no_grad()
    def decode_latent(self, z, batch_size=4):
        """Decode latents ``z`` into renderable representations.

        Args:
            z: batched latent; de-normalized first when normalization stats
                are configured on the dataset (``self.normalization``).
            batch_size (int): mini-batch size bounding decoder memory use.

        Returns:
            Flat list of decoded representations, one per sample.
        """
        self._loading_slat_dec()
        reps = []
        if self.normalization is not None:
            # Undo the dataset-side normalization before decoding.
            z = z * self.std.to(z.device) + self.mean.to(z.device)
        for i in range(0, z.shape[0], batch_size):
            reps.append(self.slat_dec(z[i:i+batch_size]))
        # The decoder returns a list per mini-batch; flatten into one list.
        reps = sum(reps, [])
        self._delete_slat_dec()
        return reps

    @torch.no_grad()
    def visualize_sample(self, x_0: Union[SparseTensor, dict]):
        """Render each decoded sample from 4 random views as a 2x2 grid.

        Args:
            x_0: latent SparseTensor, or a batch dict holding it at 'x_0'.

        Returns:
            Tensor of shape (N, 3, 1024, 1024) with the rendered grids.
        """
        x_0 = x_0 if isinstance(x_0, SparseTensor) else x_0['x_0']
        reps = self.decode_latent(x_0.cuda())

        # Build cameras: 4 azimuths 90 degrees apart sharing one random
        # offset, each with an independent random pitch.
        yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
        yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
        yaws = [y + yaws_offset for y in yaws]
        # Fix: distinct name so the loop variable below does not shadow the
        # list (original reused 'pitch' for both, relying on zip's early
        # iterator capture).
        pitches = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]

        exts = []
        ints = []
        for yaw, pitch in zip(yaws, pitches):
            # Camera position on a sphere of radius 2 looking at the origin,
            # with +z up.
            orig = torch.tensor([
                np.sin(yaw) * np.cos(pitch),
                np.cos(yaw) * np.cos(pitch),
                np.sin(pitch),
            ]).float().cuda() * 2
            fov = torch.deg2rad(torch.tensor(40)).cuda()
            extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
            intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
            exts.append(extrinsics)
            ints.append(intrinsics)

        renderer = get_renderer(reps[0])
        images = []
        for representation in reps:
            # Compose the 4 rendered 512x512 views into one 1024x1024 tile.
            image = torch.zeros(3, 1024, 1024).cuda()
            tile = [2, 2]
            for j, (ext, intr) in enumerate(zip(exts, ints)):
                res = renderer.render(representation, ext, intr)
                image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
            images.append(image)
        images = torch.stack(images)

        return images
95
+
96
+
97
class SLat(SLatVisMixin, StandardDatasetBase):
    """
    structured latent dataset

    Args:
        roots (str): path to the dataset
        latent_model (str): name of the latent model
        min_aesthetic_score (float): minimum aesthetic score
        max_num_voxels (int): maximum number of voxels
        normalization (dict): normalization stats
        pretrained_slat_dec (str): name of the pretrained slat decoder
        slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec
        slat_dec_ckpt (str): name of the slat decoder checkpoint
    """
    def __init__(self,
        roots: str,
        *,
        latent_model: str,
        min_aesthetic_score: float = 5.0,
        max_num_voxels: int = 32768,
        normalization: Optional[dict] = None,
        pretrained_slat_dec: str = 'cavargas10/TRELLIS/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16',
        slat_dec_path: Optional[str] = None,
        slat_dec_ckpt: Optional[str] = None,
    ):
        # Attributes used by filter_metadata must be set before the base
        # __init__ (which loads and filters metadata) runs.
        self.normalization = normalization
        self.latent_model = latent_model
        self.min_aesthetic_score = min_aesthetic_score
        self.max_num_voxels = max_num_voxels
        self.value_range = (0, 1)

        super().__init__(
            roots,
            pretrained_slat_dec=pretrained_slat_dec,
            slat_dec_path=slat_dec_path,
            slat_dec_ckpt=slat_dec_ckpt,
        )

        # Per-instance voxel counts, used for load-balanced batching in
        # collate_fn(split_size=...).
        self.loads = [self.metadata.loc[sha256, 'num_voxels'] for _, sha256 in self.instances]

        if self.normalization is not None:
            # Reshaped to (1, C) so they broadcast over per-voxel features.
            self.mean = torch.tensor(self.normalization['mean']).reshape(1, -1)
            self.std = torch.tensor(self.normalization['std']).reshape(1, -1)

    def filter_metadata(self, metadata):
        """Keep rows that have this latent, meet the aesthetic threshold,
        and fit the voxel budget; return (filtered_metadata, stats)."""
        stats = {}
        metadata = metadata[metadata[f'latent_{self.latent_model}']]
        stats['With latent'] = len(metadata)
        metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
        stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
        metadata = metadata[metadata['num_voxels'] <= self.max_num_voxels]
        stats[f'Num voxels <= {self.max_num_voxels}'] = len(metadata)
        return metadata, stats

    def get_instance(self, root, instance):
        """Load one instance's latent from disk.

        Returns a dict with 'coords' (int32, one row per voxel) and 'feats'
        (float32, normalized when stats are configured).
        """
        data = np.load(os.path.join(root, 'latents', self.latent_model, f'{instance}.npz'))
        coords = torch.tensor(data['coords']).int()
        feats = torch.tensor(data['feats']).float()
        if self.normalization is not None:
            feats = (feats - self.mean) / self.std
        return {
            'coords': coords,
            'feats': feats,
        }

    @staticmethod
    def collate_fn(batch, split_size=None):
        """Collate instance dicts into SparseTensor packs.

        When split_size is None, returns a single pack; otherwise splits the
        batch into split_size load-balanced groups (by voxel count) and
        returns a list of packs.
        """
        if split_size is None:
            group_idx = [list(range(len(batch)))]
        else:
            group_idx = load_balanced_group_indices([b['coords'].shape[0] for b in batch], split_size)
        packs = []
        for group in group_idx:
            sub_batch = [batch[i] for i in group]
            pack = {}
            coords = []
            feats = []
            layout = []
            start = 0
            for i, b in enumerate(sub_batch):
                # Prepend the within-group batch index as the first coord
                # column so voxels remain attributable to their sample.
                coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1))
                feats.append(b['feats'])
                # layout[i] slices sample i's rows out of the packed tensors.
                layout.append(slice(start, start + b['coords'].shape[0]))
                start += b['coords'].shape[0]
            coords = torch.cat(coords)
            feats = torch.cat(feats)
            pack['x_0'] = SparseTensor(
                coords=coords,
                feats=feats,
            )
            # NOTE(review): writes the private _shape directly — presumably
            # the batched logical shape (batch, *feat_dims); confirm against
            # SparseTensor's contract.
            pack['x_0']._shape = torch.Size([len(group), *sub_batch[0]['feats'].shape[1:]])
            pack['x_0'].register_spatial_cache('layout', layout)

            # collate other data
            keys = [k for k in sub_batch[0].keys() if k not in ['coords', 'feats']]
            for k in keys:
                if isinstance(sub_batch[0][k], torch.Tensor):
                    pack[k] = torch.stack([b[k] for b in sub_batch])
                elif isinstance(sub_batch[0][k], list):
                    pack[k] = sum([b[k] for b in sub_batch], [])
                else:
                    pack[k] = [b[k] for b in sub_batch]

            packs.append(pack)

        if split_size is None:
            return packs[0]
        return packs
205
+
206
+
207
class TextConditionedSLat(TextConditionedMixin, SLat):
    """Structured latent dataset with text conditioning.

    All behavior comes from TextConditionedMixin layered over SLat; this
    class only binds the two together.
    """
212
+
213
+
214
class ImageConditionedSLat(ImageConditionedMixin, SLat):
    """Structured latent dataset with image conditioning.

    All behavior comes from ImageConditionedMixin layered over SLat; this
    class only binds the two together.
    """