cavargas10 committed on
Commit
fbb5fb5
·
verified ·
1 Parent(s): 5855c62

Update trellis/datasets/sparse_structure_latent.py

Browse files
trellis/datasets/sparse_structure_latent.py CHANGED
@@ -1,189 +1,189 @@
1
- import os
2
- import json
3
- from typing import *
4
- import numpy as np
5
- import torch
6
- import utils3d
7
- from ..representations.octree import DfsOctree as Octree
8
- from ..renderers import OctreeRenderer
9
- from .components import StandardDatasetBase, TextConditionedMixin, ImageConditionedMixin
10
- from .. import models
11
- from ..utils.dist_utils import read_file_dist
12
-
13
-
14
class SparseStructureLatentVisMixin:
    """
    Visualization mixin for sparse structure latents.

    Decodes latents with a sparse structure decoder and renders the decoded
    occupancy as voxel previews. The decoder is loaded onto the GPU on demand
    and released again after every ``decode_latent`` call.

    The host class must provide ``self.normalization`` and, when it is not
    None, the tensors ``self.mean`` and ``self.std``.

    Args:
        pretrained_ss_dec (str): name of the pretrained sparse structure decoder
        ss_dec_path (str): path to the sparse structure decoder, if given, will override the pretrained_ss_dec
        ss_dec_ckpt (str): name of the sparse structure decoder checkpoint
    """
    def __init__(
        self,
        *args,
        pretrained_ss_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16',
        ss_dec_path: Optional[str] = None,
        ss_dec_ckpt: Optional[str] = None,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.ss_dec = None
        self.pretrained_ss_dec = pretrained_ss_dec
        self.ss_dec_path = ss_dec_path
        self.ss_dec_ckpt = ss_dec_ckpt

    def _loading_ss_dec(self):
        """
        Load the sparse structure decoder onto the GPU in eval mode.

        Does nothing when a decoder is already loaded. A local ``ss_dec_path``
        takes precedence over ``pretrained_ss_dec``.
        """
        if self.ss_dec is not None:
            return
        if self.ss_dec_path is not None:
            # Close the config file deterministically instead of leaking the handle.
            with open(os.path.join(self.ss_dec_path, 'config.json'), 'r') as f:
                cfg = json.load(f)
            decoder_cfg = cfg['models']['decoder']
            decoder = getattr(models, decoder_cfg['name'])(**decoder_cfg['args'])
            ckpt_path = os.path.join(self.ss_dec_path, 'ckpts', f'decoder_{self.ss_dec_ckpt}.pt')
            decoder.load_state_dict(torch.load(read_file_dist(ckpt_path), map_location='cpu', weights_only=True))
        else:
            decoder = models.from_pretrained(self.pretrained_ss_dec)
        self.ss_dec = decoder.cuda().eval()

    def _delete_ss_dec(self):
        """
        Drop the reference to the decoder so its memory can be reclaimed.
        """
        self.ss_dec = None

    @torch.no_grad()
    def decode_latent(self, z, batch_size=4):
        """
        Decode a batch of latents with the sparse structure decoder.

        Args:
            z (torch.Tensor): latents, batched along dim 0. If ``self.normalization``
                is set, they are de-normalized with ``self.std`` / ``self.mean`` first.
            batch_size (int): number of latents decoded per forward pass

        Returns:
            torch.Tensor: decoder outputs concatenated along dim 0
        """
        self._loading_ss_dec()
        try:
            if self.normalization is not None:
                z = z * self.std.to(z.device) + self.mean.to(z.device)
            ss = [self.ss_dec(z[i:i + batch_size]) for i in range(0, z.shape[0], batch_size)]
            return torch.cat(ss, dim=0)
        finally:
            # Release the decoder even if decoding fails.
            self._delete_ss_dec()

    @torch.no_grad()
    def visualize_sample(self, x_0: Union[torch.Tensor, dict]):
        """
        Render decoded sparse structures as 2x2 grids of voxel views.

        Args:
            x_0 (torch.Tensor or dict): latents, or a dict holding them under ``'x_0'``

        Returns:
            torch.Tensor: one 3 x 1024 x 1024 image per sample, stacked along dim 0.
            Each image tiles four 512 x 512 renders taken from randomized viewpoints.
        """
        x_0 = x_0 if isinstance(x_0, torch.Tensor) else x_0['x_0']
        x_0 = self.decode_latent(x_0.cuda())

        view_res = 512
        tile = (2, 2)

        renderer = OctreeRenderer()
        renderer.rendering_options.resolution = view_res
        renderer.rendering_options.near = 0.8
        renderer.rendering_options.far = 1.6
        renderer.rendering_options.bg_color = (0, 0, 0)
        renderer.rendering_options.ssaa = 4
        renderer.pipe.primitive = 'voxel'

        # Build camera: four yaws 90 degrees apart sharing one random offset,
        # each with its own random pitch.
        base_yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
        yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
        yaws = [y + yaws_offset for y in base_yaws]
        pitches = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in yaws]

        # View-independent camera inputs are built once.
        target = torch.tensor([0, 0, 0]).float().cuda()
        up = torch.tensor([0, 0, 1]).float().cuda()
        fov = torch.deg2rad(torch.tensor(30)).cuda()

        exts = []
        ints = []
        for yaw, pitch in zip(yaws, pitches):
            orig = torch.tensor([
                np.sin(yaw) * np.cos(pitch),
                np.cos(yaw) * np.cos(pitch),
                np.sin(pitch),
            ]).float().cuda() * 2
            exts.append(utils3d.torch.extrinsics_look_at(orig, target, up))
            ints.append(utils3d.torch.intrinsics_from_fov_xy(fov, fov))

        images = []

        # Build each representation
        x_0 = x_0.cuda()
        grid_res = x_0.shape[-1]
        grid_depth = int(np.log2(grid_res))
        for i in range(x_0.shape[0]):
            representation = Octree(
                depth=10,
                aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
                device='cuda',
                primitive='voxel',
                sh_degree=0,
                primitive_config={'solid': True},
            )
            # Voxels whose channel-0 value is positive are treated as occupied.
            coords = torch.nonzero(x_0[i, 0] > 0, as_tuple=False)
            representation.position = coords.float() / grid_res
            representation.depth = torch.full((representation.position.shape[0], 1), grid_depth, dtype=torch.uint8, device='cuda')

            image = torch.zeros(3, view_res * tile[0], view_res * tile[1]).cuda()
            for j, (ext, intr) in enumerate(zip(exts, ints)):
                # Voxels are colored by their normalized position.
                res = renderer.render(representation, ext, intr, colors_overwrite=representation.position)
                row, col = divmod(j, tile[1])
                image[:, view_res * row:view_res * (row + 1), view_res * col:view_res * (col + 1)] = res['color']
            images.append(image)

        return torch.stack(images)
116
-
117
-
118
class SparseStructureLatent(SparseStructureLatentVisMixin, StandardDatasetBase):
    """
    Sparse structure latent dataset

    Args:
        roots (str): path to the dataset
        latent_model (str): name of the latent model
        min_aesthetic_score (float): minimum aesthetic score
        normalization (dict): normalization stats, with 'mean' and 'std' entries
        pretrained_ss_dec (str): name of the pretrained sparse structure decoder
        ss_dec_path (str): path to the sparse structure decoder, if given, will override the pretrained_ss_dec
        ss_dec_ckpt (str): name of the sparse structure decoder checkpoint
    """
    def __init__(self,
        roots: str,
        *,
        latent_model: str,
        min_aesthetic_score: float = 5.0,
        normalization: Optional[dict] = None,
        pretrained_ss_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16',
        ss_dec_path: Optional[str] = None,
        ss_dec_ckpt: Optional[str] = None,
    ):
        # These are set before super().__init__(); presumably the base class
        # calls filter_metadata() during construction -- TODO confirm.
        self.latent_model = latent_model
        self.min_aesthetic_score = min_aesthetic_score
        self.normalization = normalization
        self.value_range = (0, 1)

        super().__init__(
            roots,
            pretrained_ss_dec=pretrained_ss_dec,
            ss_dec_path=ss_dec_path,
            ss_dec_ckpt=ss_dec_ckpt,
        )

        if self.normalization is not None:
            # Reshaped to (-1, 1, 1, 1) so the stats broadcast over the three trailing dims.
            self.mean = torch.tensor(self.normalization['mean']).reshape(-1, 1, 1, 1)
            self.std = torch.tensor(self.normalization['std']).reshape(-1, 1, 1, 1)

    def filter_metadata(self, metadata):
        """
        Keep only instances that have a latent for ``latent_model`` and an
        aesthetic score of at least ``min_aesthetic_score``.

        Returns:
            tuple: the filtered metadata and a dict with the number of rows
            remaining after each filter
        """
        stats = {}
        metadata = metadata[metadata[f'ss_latent_{self.latent_model}']]
        stats['With sparse structure latents'] = len(metadata)
        metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
        stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
        return metadata, stats

    def get_instance(self, root, instance):
        """
        Load the latent of one instance.

        Args:
            root (str): dataset root containing the ``ss_latents`` directory
            instance (str): instance identifier, used as the ``.npz`` file stem

        Returns:
            dict: ``{'x_0': z}`` where ``z`` is the float latent, normalized
            when normalization stats were given
        """
        latent_path = os.path.join(root, 'ss_latents', self.latent_model, f'{instance}.npz')
        # np.load keeps the .npz archive open; close it once the array is read
        # so repeated sample loading does not accumulate open file handles.
        with np.load(latent_path) as latent:
            z = torch.tensor(latent['mean']).float()
        if self.normalization is not None:
            z = (z - self.mean) / self.std

        return {'x_0': z}
175
-
176
-
177
class TextConditionedSparseStructureLatent(TextConditionedMixin, SparseStructureLatent):
    """
    Sparse structure latent dataset combined with TextConditionedMixin.

    Adds no members of its own; all behavior comes from the two base classes.
    """
182
-
183
-
184
class ImageConditionedSparseStructureLatent(ImageConditionedMixin, SparseStructureLatent):
    """
    Sparse structure latent dataset combined with ImageConditionedMixin.

    Adds no members of its own; all behavior comes from the two base classes.
    """
189
 
 
1
+ import os
2
+ import json
3
+ from typing import *
4
+ import numpy as np
5
+ import torch
6
+ import utils3d
7
+ from ..representations.octree import DfsOctree as Octree
8
+ from ..renderers import OctreeRenderer
9
+ from .components import StandardDatasetBase, TextConditionedMixin, ImageConditionedMixin
10
+ from .. import models
11
+ from ..utils.dist_utils import read_file_dist
12
+
13
+
14
class SparseStructureLatentVisMixin:
    """
    Visualization mixin for sparse structure latents.

    Decodes latents with a sparse structure decoder and renders the decoded
    occupancy as voxel previews. The decoder is loaded onto the GPU on demand
    and released again after every ``decode_latent`` call.

    The host class must provide ``self.normalization`` and, when it is not
    None, the tensors ``self.mean`` and ``self.std``.

    Args:
        pretrained_ss_dec (str): name of the pretrained sparse structure decoder
        ss_dec_path (str): path to the sparse structure decoder, if given, will override the pretrained_ss_dec
        ss_dec_ckpt (str): name of the sparse structure decoder checkpoint
    """
    def __init__(
        self,
        *args,
        pretrained_ss_dec: str = 'cavargas10/TRELLIS/ckpts/ss_dec_conv3d_16l8_fp16',
        ss_dec_path: Optional[str] = None,
        ss_dec_ckpt: Optional[str] = None,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.ss_dec = None
        self.pretrained_ss_dec = pretrained_ss_dec
        self.ss_dec_path = ss_dec_path
        self.ss_dec_ckpt = ss_dec_ckpt

    def _loading_ss_dec(self):
        """
        Load the sparse structure decoder onto the GPU in eval mode.

        Does nothing when a decoder is already loaded. A local ``ss_dec_path``
        takes precedence over ``pretrained_ss_dec``.
        """
        if self.ss_dec is not None:
            return
        if self.ss_dec_path is not None:
            # Close the config file deterministically instead of leaking the handle.
            with open(os.path.join(self.ss_dec_path, 'config.json'), 'r') as f:
                cfg = json.load(f)
            decoder_cfg = cfg['models']['decoder']
            decoder = getattr(models, decoder_cfg['name'])(**decoder_cfg['args'])
            ckpt_path = os.path.join(self.ss_dec_path, 'ckpts', f'decoder_{self.ss_dec_ckpt}.pt')
            decoder.load_state_dict(torch.load(read_file_dist(ckpt_path), map_location='cpu', weights_only=True))
        else:
            decoder = models.from_pretrained(self.pretrained_ss_dec)
        self.ss_dec = decoder.cuda().eval()

    def _delete_ss_dec(self):
        """
        Drop the reference to the decoder so its memory can be reclaimed.
        """
        self.ss_dec = None

    @torch.no_grad()
    def decode_latent(self, z, batch_size=4):
        """
        Decode a batch of latents with the sparse structure decoder.

        Args:
            z (torch.Tensor): latents, batched along dim 0. If ``self.normalization``
                is set, they are de-normalized with ``self.std`` / ``self.mean`` first.
            batch_size (int): number of latents decoded per forward pass

        Returns:
            torch.Tensor: decoder outputs concatenated along dim 0
        """
        self._loading_ss_dec()
        try:
            if self.normalization is not None:
                z = z * self.std.to(z.device) + self.mean.to(z.device)
            ss = [self.ss_dec(z[i:i + batch_size]) for i in range(0, z.shape[0], batch_size)]
            return torch.cat(ss, dim=0)
        finally:
            # Release the decoder even if decoding fails.
            self._delete_ss_dec()

    @torch.no_grad()
    def visualize_sample(self, x_0: Union[torch.Tensor, dict]):
        """
        Render decoded sparse structures as 2x2 grids of voxel views.

        Args:
            x_0 (torch.Tensor or dict): latents, or a dict holding them under ``'x_0'``

        Returns:
            torch.Tensor: one 3 x 1024 x 1024 image per sample, stacked along dim 0.
            Each image tiles four 512 x 512 renders taken from randomized viewpoints.
        """
        x_0 = x_0 if isinstance(x_0, torch.Tensor) else x_0['x_0']
        x_0 = self.decode_latent(x_0.cuda())

        view_res = 512
        tile = (2, 2)

        renderer = OctreeRenderer()
        renderer.rendering_options.resolution = view_res
        renderer.rendering_options.near = 0.8
        renderer.rendering_options.far = 1.6
        renderer.rendering_options.bg_color = (0, 0, 0)
        renderer.rendering_options.ssaa = 4
        renderer.pipe.primitive = 'voxel'

        # Build camera: four yaws 90 degrees apart sharing one random offset,
        # each with its own random pitch.
        base_yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
        yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
        yaws = [y + yaws_offset for y in base_yaws]
        pitches = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in yaws]

        # View-independent camera inputs are built once.
        target = torch.tensor([0, 0, 0]).float().cuda()
        up = torch.tensor([0, 0, 1]).float().cuda()
        fov = torch.deg2rad(torch.tensor(30)).cuda()

        exts = []
        ints = []
        for yaw, pitch in zip(yaws, pitches):
            orig = torch.tensor([
                np.sin(yaw) * np.cos(pitch),
                np.cos(yaw) * np.cos(pitch),
                np.sin(pitch),
            ]).float().cuda() * 2
            exts.append(utils3d.torch.extrinsics_look_at(orig, target, up))
            ints.append(utils3d.torch.intrinsics_from_fov_xy(fov, fov))

        images = []

        # Build each representation
        x_0 = x_0.cuda()
        grid_res = x_0.shape[-1]
        grid_depth = int(np.log2(grid_res))
        for i in range(x_0.shape[0]):
            representation = Octree(
                depth=10,
                aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
                device='cuda',
                primitive='voxel',
                sh_degree=0,
                primitive_config={'solid': True},
            )
            # Voxels whose channel-0 value is positive are treated as occupied.
            coords = torch.nonzero(x_0[i, 0] > 0, as_tuple=False)
            representation.position = coords.float() / grid_res
            representation.depth = torch.full((representation.position.shape[0], 1), grid_depth, dtype=torch.uint8, device='cuda')

            image = torch.zeros(3, view_res * tile[0], view_res * tile[1]).cuda()
            for j, (ext, intr) in enumerate(zip(exts, ints)):
                # Voxels are colored by their normalized position.
                res = renderer.render(representation, ext, intr, colors_overwrite=representation.position)
                row, col = divmod(j, tile[1])
                image[:, view_res * row:view_res * (row + 1), view_res * col:view_res * (col + 1)] = res['color']
            images.append(image)

        return torch.stack(images)
116
+
117
+
118
class SparseStructureLatent(SparseStructureLatentVisMixin, StandardDatasetBase):
    """
    Sparse structure latent dataset

    Args:
        roots (str): path to the dataset
        latent_model (str): name of the latent model
        min_aesthetic_score (float): minimum aesthetic score
        normalization (dict): normalization stats, with 'mean' and 'std' entries
        pretrained_ss_dec (str): name of the pretrained sparse structure decoder
        ss_dec_path (str): path to the sparse structure decoder, if given, will override the pretrained_ss_dec
        ss_dec_ckpt (str): name of the sparse structure decoder checkpoint
    """
    def __init__(self,
        roots: str,
        *,
        latent_model: str,
        min_aesthetic_score: float = 5.0,
        normalization: Optional[dict] = None,
        pretrained_ss_dec: str = 'cavargas10/TRELLIS/ckpts/ss_dec_conv3d_16l8_fp16',
        ss_dec_path: Optional[str] = None,
        ss_dec_ckpt: Optional[str] = None,
    ):
        # These are set before super().__init__(); presumably the base class
        # calls filter_metadata() during construction -- TODO confirm.
        self.latent_model = latent_model
        self.min_aesthetic_score = min_aesthetic_score
        self.normalization = normalization
        self.value_range = (0, 1)

        super().__init__(
            roots,
            pretrained_ss_dec=pretrained_ss_dec,
            ss_dec_path=ss_dec_path,
            ss_dec_ckpt=ss_dec_ckpt,
        )

        if self.normalization is not None:
            # Reshaped to (-1, 1, 1, 1) so the stats broadcast over the three trailing dims.
            self.mean = torch.tensor(self.normalization['mean']).reshape(-1, 1, 1, 1)
            self.std = torch.tensor(self.normalization['std']).reshape(-1, 1, 1, 1)

    def filter_metadata(self, metadata):
        """
        Keep only instances that have a latent for ``latent_model`` and an
        aesthetic score of at least ``min_aesthetic_score``.

        Returns:
            tuple: the filtered metadata and a dict with the number of rows
            remaining after each filter
        """
        stats = {}
        metadata = metadata[metadata[f'ss_latent_{self.latent_model}']]
        stats['With sparse structure latents'] = len(metadata)
        metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
        stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
        return metadata, stats

    def get_instance(self, root, instance):
        """
        Load the latent of one instance.

        Args:
            root (str): dataset root containing the ``ss_latents`` directory
            instance (str): instance identifier, used as the ``.npz`` file stem

        Returns:
            dict: ``{'x_0': z}`` where ``z`` is the float latent, normalized
            when normalization stats were given
        """
        latent_path = os.path.join(root, 'ss_latents', self.latent_model, f'{instance}.npz')
        # np.load keeps the .npz archive open; close it once the array is read
        # so repeated sample loading does not accumulate open file handles.
        with np.load(latent_path) as latent:
            z = torch.tensor(latent['mean']).float()
        if self.normalization is not None:
            z = (z - self.mean) / self.std

        return {'x_0': z}
175
+
176
+
177
class TextConditionedSparseStructureLatent(TextConditionedMixin, SparseStructureLatent):
    """
    Sparse structure latent dataset combined with TextConditionedMixin.

    Adds no members of its own; all behavior comes from the two base classes.
    """
182
+
183
+
184
class ImageConditionedSparseStructureLatent(ImageConditionedMixin, SparseStructureLatent):
    """
    Sparse structure latent dataset combined with ImageConditionedMixin.

    Adds no members of its own; all behavior comes from the two base classes.
    """
189