cavargas10 commited on
Commit
faab2d5
·
verified ·
1 Parent(s): 43de49b

Update dataset_toolkits/encode_latent.py

Browse files
Files changed (1) hide show
  1. dataset_toolkits/encode_latent.py +127 -127
dataset_toolkits/encode_latent.py CHANGED
@@ -1,127 +1,127 @@
1
- import os
2
- import sys
3
- sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
4
- import copy
5
- import json
6
- import argparse
7
- import torch
8
- import numpy as np
9
- import pandas as pd
10
- from tqdm import tqdm
11
- from easydict import EasyDict as edict
12
- from concurrent.futures import ThreadPoolExecutor
13
- from queue import Queue
14
-
15
- import trellis.models as models
16
- import trellis.modules.sparse as sp
17
-
18
-
19
- torch.set_grad_enabled(False)
20
-
21
-
22
- if __name__ == '__main__':
23
- parser = argparse.ArgumentParser()
24
- parser.add_argument('--output_dir', type=str, required=True,
25
- help='Directory to save the metadata')
26
- parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
27
- help='Filter objects with aesthetic score lower than this value')
28
- parser.add_argument('--feat_model', type=str, default='dinov2_vitl14_reg',
29
- help='Feature model')
30
- parser.add_argument('--enc_pretrained', type=str, default='JeffreyXiang/TRELLIS-image-large/ckpts/slat_enc_swin8_B_64l8_fp16',
31
- help='Pretrained encoder model')
32
- parser.add_argument('--model_root', type=str, default='results',
33
- help='Root directory of models')
34
- parser.add_argument('--enc_model', type=str, default=None,
35
- help='Encoder model. if specified, use this model instead of pretrained model')
36
- parser.add_argument('--ckpt', type=str, default=None,
37
- help='Checkpoint to load')
38
- parser.add_argument('--instances', type=str, default=None,
39
- help='Instances to process')
40
- parser.add_argument('--rank', type=int, default=0)
41
- parser.add_argument('--world_size', type=int, default=1)
42
- opt = parser.parse_args()
43
- opt = edict(vars(opt))
44
-
45
- if opt.enc_model is None:
46
- latent_name = f'{opt.feat_model}_{opt.enc_pretrained.split("/")[-1]}'
47
- encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda()
48
- else:
49
- latent_name = f'{opt.feat_model}_{opt.enc_model}_{opt.ckpt}'
50
- cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r')))
51
- encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda()
52
- ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt')
53
- encoder.load_state_dict(torch.load(ckpt_path), strict=False)
54
- encoder.eval()
55
- print(f'Loaded model from {ckpt_path}')
56
-
57
- os.makedirs(os.path.join(opt.output_dir, 'latents', latent_name), exist_ok=True)
58
-
59
- # get file list
60
- if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
61
- metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
62
- else:
63
- raise ValueError('metadata.csv not found')
64
- if opt.instances is not None:
65
- with open(opt.instances, 'r') as f:
66
- sha256s = [line.strip() for line in f]
67
- metadata = metadata[metadata['sha256'].isin(sha256s)]
68
- else:
69
- if opt.filter_low_aesthetic_score is not None:
70
- metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
71
- metadata = metadata[metadata[f'feature_{opt.feat_model}'] == True]
72
- if f'latent_{latent_name}' in metadata.columns:
73
- metadata = metadata[metadata[f'latent_{latent_name}'] == False]
74
-
75
- start = len(metadata) * opt.rank // opt.world_size
76
- end = len(metadata) * (opt.rank + 1) // opt.world_size
77
- metadata = metadata[start:end]
78
- records = []
79
-
80
- # filter out objects that are already processed
81
- sha256s = list(metadata['sha256'].values)
82
- for sha256 in copy.copy(sha256s):
83
- if os.path.exists(os.path.join(opt.output_dir, 'latents', latent_name, f'{sha256}.npz')):
84
- records.append({'sha256': sha256, f'latent_{latent_name}': True})
85
- sha256s.remove(sha256)
86
-
87
- # encode latents
88
- load_queue = Queue(maxsize=4)
89
- try:
90
- with ThreadPoolExecutor(max_workers=32) as loader_executor, \
91
- ThreadPoolExecutor(max_workers=32) as saver_executor:
92
- def loader(sha256):
93
- try:
94
- feats = np.load(os.path.join(opt.output_dir, 'features', opt.feat_model, f'{sha256}.npz'))
95
- load_queue.put((sha256, feats))
96
- except Exception as e:
97
- print(f"Error loading features for {sha256}: {e}")
98
- loader_executor.map(loader, sha256s)
99
-
100
- def saver(sha256, pack):
101
- save_path = os.path.join(opt.output_dir, 'latents', latent_name, f'{sha256}.npz')
102
- np.savez_compressed(save_path, **pack)
103
- records.append({'sha256': sha256, f'latent_{latent_name}': True})
104
-
105
- for _ in tqdm(range(len(sha256s)), desc="Extracting latents"):
106
- sha256, feats = load_queue.get()
107
- feats = sp.SparseTensor(
108
- feats = torch.from_numpy(feats['patchtokens']).float(),
109
- coords = torch.cat([
110
- torch.zeros(feats['patchtokens'].shape[0], 1).int(),
111
- torch.from_numpy(feats['indices']).int(),
112
- ], dim=1),
113
- ).cuda()
114
- latent = encoder(feats, sample_posterior=False)
115
- assert torch.isfinite(latent.feats).all(), "Non-finite latent"
116
- pack = {
117
- 'feats': latent.feats.cpu().numpy().astype(np.float32),
118
- 'coords': latent.coords[:, 1:].cpu().numpy().astype(np.uint8),
119
- }
120
- saver_executor.submit(saver, sha256, pack)
121
-
122
- saver_executor.shutdown(wait=True)
123
- except:
124
- print("Error happened during processing.")
125
-
126
- records = pd.DataFrame.from_records(records)
127
- records.to_csv(os.path.join(opt.output_dir, f'latent_{latent_name}_{opt.rank}.csv'), index=False)
 
1
+ import os
2
+ import sys
3
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
4
+ import copy
5
+ import json
6
+ import argparse
7
+ import torch
8
+ import numpy as np
9
+ import pandas as pd
10
+ from tqdm import tqdm
11
+ from easydict import EasyDict as edict
12
+ from concurrent.futures import ThreadPoolExecutor
13
+ from queue import Queue
14
+
15
+ import trellis.models as models
16
+ import trellis.modules.sparse as sp
17
+
18
+
19
+ torch.set_grad_enabled(False)
20
+
21
+
22
+ if __name__ == '__main__':
23
+ parser = argparse.ArgumentParser()
24
+ parser.add_argument('--output_dir', type=str, required=True,
25
+ help='Directory to save the metadata')
26
+ parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
27
+ help='Filter objects with aesthetic score lower than this value')
28
+ parser.add_argument('--feat_model', type=str, default='dinov2_vitl14_reg',
29
+ help='Feature model')
30
+ parser.add_argument('--enc_pretrained', type=str, default='cavargas10/TRELLIS/ckpts/slat_enc_swin8_B_64l8_fp16',
31
+ help='Pretrained encoder model')
32
+ parser.add_argument('--model_root', type=str, default='results',
33
+ help='Root directory of models')
34
+ parser.add_argument('--enc_model', type=str, default=None,
35
+ help='Encoder model. if specified, use this model instead of pretrained model')
36
+ parser.add_argument('--ckpt', type=str, default=None,
37
+ help='Checkpoint to load')
38
+ parser.add_argument('--instances', type=str, default=None,
39
+ help='Instances to process')
40
+ parser.add_argument('--rank', type=int, default=0)
41
+ parser.add_argument('--world_size', type=int, default=1)
42
+ opt = parser.parse_args()
43
+ opt = edict(vars(opt))
44
+
45
+ if opt.enc_model is None:
46
+ latent_name = f'{opt.feat_model}_{opt.enc_pretrained.split("/")[-1]}'
47
+ encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda()
48
+ else:
49
+ latent_name = f'{opt.feat_model}_{opt.enc_model}_{opt.ckpt}'
50
+ cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r')))
51
+ encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda()
52
+ ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt')
53
+ encoder.load_state_dict(torch.load(ckpt_path), strict=False)
54
+ encoder.eval()
55
+ print(f'Loaded model from {ckpt_path}')
56
+
57
+ os.makedirs(os.path.join(opt.output_dir, 'latents', latent_name), exist_ok=True)
58
+
59
+ # get file list
60
+ if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
61
+ metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
62
+ else:
63
+ raise ValueError('metadata.csv not found')
64
+ if opt.instances is not None:
65
+ with open(opt.instances, 'r') as f:
66
+ sha256s = [line.strip() for line in f]
67
+ metadata = metadata[metadata['sha256'].isin(sha256s)]
68
+ else:
69
+ if opt.filter_low_aesthetic_score is not None:
70
+ metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
71
+ metadata = metadata[metadata[f'feature_{opt.feat_model}'] == True]
72
+ if f'latent_{latent_name}' in metadata.columns:
73
+ metadata = metadata[metadata[f'latent_{latent_name}'] == False]
74
+
75
+ start = len(metadata) * opt.rank // opt.world_size
76
+ end = len(metadata) * (opt.rank + 1) // opt.world_size
77
+ metadata = metadata[start:end]
78
+ records = []
79
+
80
+ # filter out objects that are already processed
81
+ sha256s = list(metadata['sha256'].values)
82
+ for sha256 in copy.copy(sha256s):
83
+ if os.path.exists(os.path.join(opt.output_dir, 'latents', latent_name, f'{sha256}.npz')):
84
+ records.append({'sha256': sha256, f'latent_{latent_name}': True})
85
+ sha256s.remove(sha256)
86
+
87
+ # encode latents
88
+ load_queue = Queue(maxsize=4)
89
+ try:
90
+ with ThreadPoolExecutor(max_workers=32) as loader_executor, \
91
+ ThreadPoolExecutor(max_workers=32) as saver_executor:
92
+ def loader(sha256):
93
+ try:
94
+ feats = np.load(os.path.join(opt.output_dir, 'features', opt.feat_model, f'{sha256}.npz'))
95
+ load_queue.put((sha256, feats))
96
+ except Exception as e:
97
+ print(f"Error loading features for {sha256}: {e}")
98
+ loader_executor.map(loader, sha256s)
99
+
100
+ def saver(sha256, pack):
101
+ save_path = os.path.join(opt.output_dir, 'latents', latent_name, f'{sha256}.npz')
102
+ np.savez_compressed(save_path, **pack)
103
+ records.append({'sha256': sha256, f'latent_{latent_name}': True})
104
+
105
+ for _ in tqdm(range(len(sha256s)), desc="Extracting latents"):
106
+ sha256, feats = load_queue.get()
107
+ feats = sp.SparseTensor(
108
+ feats = torch.from_numpy(feats['patchtokens']).float(),
109
+ coords = torch.cat([
110
+ torch.zeros(feats['patchtokens'].shape[0], 1).int(),
111
+ torch.from_numpy(feats['indices']).int(),
112
+ ], dim=1),
113
+ ).cuda()
114
+ latent = encoder(feats, sample_posterior=False)
115
+ assert torch.isfinite(latent.feats).all(), "Non-finite latent"
116
+ pack = {
117
+ 'feats': latent.feats.cpu().numpy().astype(np.float32),
118
+ 'coords': latent.coords[:, 1:].cpu().numpy().astype(np.uint8),
119
+ }
120
+ saver_executor.submit(saver, sha256, pack)
121
+
122
+ saver_executor.shutdown(wait=True)
123
+ except:
124
+ print("Error happened during processing.")
125
+
126
+ records = pd.DataFrame.from_records(records)
127
+ records.to_csv(os.path.join(opt.output_dir, f'latent_{latent_name}_{opt.rank}.csv'), index=False)