staswrs committed on
Commit b2a27a7 · 0 Parent(s)

🔥 Clean redeploy with updated app.py

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .DS_Store +0 -0
  2. .gitattributes +35 -0
  3. README.md +13 -0
  4. app.py +249 -0
  5. briarmbg.py +464 -0
  6. get-pip.py +0 -0
  7. image_process.py +149 -0
  8. inference_triposg.py +137 -0
  9. requirements.txt +18 -0
  10. scripts/__pycache__/briarmbg.cpython-310.pyc +0 -0
  11. scripts/__pycache__/image_process.cpython-310.pyc +0 -0
  12. scripts/__pycache__/inference_triposg.cpython-310.pyc +0 -0
  13. scripts/inference_triposg_scribble.py +78 -0
  14. scripts/inference_vae.py +61 -0
  15. triposg/.DS_Store +0 -0
  16. triposg/LICENSE +80 -0
  17. triposg/__pycache__/inference_utils.cpython-310.pyc +0 -0
  18. triposg/diso/__init__.py +0 -0
  19. triposg/diso/__pycache__/__init__.cpython-310.pyc +0 -0
  20. triposg/diso/__pycache__/diff_dmc.cpython-310.pyc +0 -0
  21. triposg/diso/diff_dmc.py +15 -0
  22. triposg/diso/diffusion_utils.py +5 -0
  23. triposg/inference_utils.py +491 -0
  24. triposg/models/.DS_Store +0 -0
  25. triposg/models/__pycache__/attention_processor.cpython-310.pyc +0 -0
  26. triposg/models/__pycache__/embeddings.cpython-310.pyc +0 -0
  27. triposg/models/attention_processor.py +420 -0
  28. triposg/models/autoencoders/__init__.py +1 -0
  29. triposg/models/autoencoders/__pycache__/__init__.cpython-310.pyc +0 -0
  30. triposg/models/autoencoders/__pycache__/autoencoder_kl_triposg.cpython-310.pyc +0 -0
  31. triposg/models/autoencoders/__pycache__/vae.cpython-310.pyc +0 -0
  32. triposg/models/autoencoders/autoencoder_kl_triposg.py +536 -0
  33. triposg/models/autoencoders/vae.py +69 -0
  34. triposg/models/embeddings.py +96 -0
  35. triposg/models/transformers/__init__.py +61 -0
  36. triposg/models/transformers/__pycache__/__init__.cpython-310.pyc +0 -0
  37. triposg/models/transformers/__pycache__/modeling_outputs.cpython-310.pyc +0 -0
  38. triposg/models/transformers/__pycache__/triposg_transformer.cpython-310.pyc +0 -0
  39. triposg/models/transformers/modeling_outputs.py +8 -0
  40. triposg/models/transformers/triposg_transformer.py +775 -0
  41. triposg/pipelines/__pycache__/pipeline_triposg.cpython-310.pyc +0 -0
  42. triposg/pipelines/__pycache__/pipeline_triposg_output.cpython-310.pyc +0 -0
  43. triposg/pipelines/__pycache__/pipeline_utils.cpython-310.pyc +0 -0
  44. triposg/pipelines/pipeline_triposg.py +324 -0
  45. triposg/pipelines/pipeline_triposg_output.py +17 -0
  46. triposg/pipelines/pipeline_triposg_scribble.py +338 -0
  47. triposg/pipelines/pipeline_utils.py +96 -0
  48. triposg/schedulers/__init__.py +5 -0
  49. triposg/schedulers/scheduling_rectified_flow.py +327 -0
  50. triposg/utils/__pycache__/typing.cpython-310.pyc +0 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: EtTripoSg Api Gradio
+ emoji: 🐢
+ colorFrom: pink
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 3.50.2
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,249 @@
+ # import os
+ # import subprocess
+
+ # # 🧹 Remove pyenv, in case a .python-version file is still around
+ # os.environ.pop("PYENV_VERSION", None)
+
+ # # ⚙️ Install torch and diso
+ # subprocess.run(["pip", "install", "torch", "wheel"], check=True)
+
+ # subprocess.run([
+ #     "pip", "install", "--no-build-isolation",
+ #     "diso@git+https://github.com/SarahWeiii/diso.git"
+ # ], check=True)
+
+ # # ✅ Only now import everything else
+ # import gradio as gr
+ # import uuid
+ # import torch
+ # import zipfile
+ # import requests
+
+ # from inference_triposg import run_triposg
+ # from triposg.pipelines.pipeline_triposg import TripoSGPipeline
+ # from briarmbg import BriaRMBG
+
+ # # === Device settings ===
+ # # device = "cuda" if torch.cuda.is_available() else "cpu"
+ # # dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+ # # dtype = torch.float32
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
+ # dtype = torch.float16 if device == "cuda" else torch.float32
+
+ # # === Check for and download weights ===
+ # weights_dir = "pretrained_weights"
+ # triposg_path = os.path.join(weights_dir, "TripoSG")
+ # rmbg_path = os.path.join(weights_dir, "RMBG-1.4")
+
+ # if not (os.path.exists(triposg_path) and os.path.exists(rmbg_path)):
+ #     print("📦 Downloading pretrained weights from Hugging Face Dataset...")
+ #     url = "https://huggingface.co/datasets/endlesstools/pretrained-assets/resolve/main/pretrained_models.zip"
+ #     zip_path = "pretrained_models.zip"
+
+ #     with requests.get(url, stream=True) as r:
+ #         r.raise_for_status()
+ #         with open(zip_path, "wb") as f:
+ #             for chunk in r.iter_content(chunk_size=8192):
+ #                 f.write(chunk)
+
+ #     print("📦 Extracting weights...")
+ #     with zipfile.ZipFile(zip_path, "r") as zip_ref:
+ #         zip_ref.extractall(weights_dir)
+
+ #     os.remove(zip_path)
+ #     print("✅ Weights ready.")
+
+ # # === Load models ===
+ # pipe = TripoSGPipeline.from_pretrained(triposg_path).to(device, dtype)
+ # rmbg_net = BriaRMBG.from_pretrained(rmbg_path).to(device)
+ # rmbg_net.eval()
+
+ # # === Generation function ===
+ # def generate(file):
+ #     temp_id = str(uuid.uuid4())
+ #     input_path = f"/tmp/{temp_id}.png"
+ #     output_path = f"/tmp/{temp_id}.glb"
+
+ #     with open(input_path, "wb") as f:
+ #         f.write(file)
+
+ #     print("[DEBUG] Generating mesh...")
+ #     try:
+ #         mesh = run_triposg(
+ #             pipe=pipe,
+ #             image_input=input_path,
+ #             rmbg_net=rmbg_net,
+ #             seed=42,
+ #             num_inference_steps=25,
+ #             guidance_scale=5.0,
+ #             faces=-1,
+ #         )
+ #         # mesh.export(output_path)
+ #         if mesh is None:
+ #             raise ValueError("Mesh generation failed")
+ #         mesh.export(output_path)
+ #         print(f"[DEBUG] Mesh saved to {output_path}")
+ #         # return output_path
+ #         if os.path.exists(output_path):
+ #             return output_path
+ #         else:
+ #             return "Error: mesh export failed or file not found"
+ #     except Exception as e:
+ #         print("[ERROR]", e)
+ #         return f"Error: {e}"
+
+ # # === Gradio interface ===
+ # demo = gr.Interface(
+ #     fn=generate,
+ #     inputs=gr.File(type="binary", label="Upload image"),
+ #     outputs=gr.File(label="Generated .glb model"),
+ #     title="TripoSG Image-to-3D",
+ #     description="Upload an image and get back a 3D GLB model.",
+ # )
+
+ # # # === IMPORTANT: the variable must be named `app` ===
+ # # app = demo.launch(inline=True, share=False, prevent_thread_lock=True)
+ # demo.launch()
+
+
+
+
+ import os
+ import subprocess
+
+ # Remove pyenv, in case a .python-version file is still around
+ os.environ.pop("PYENV_VERSION", None)
+
+ # Install dependencies
+ subprocess.run(["pip", "install", "torch", "wheel"], check=True)
+ subprocess.run([
+     "pip", "install", "--no-build-isolation",
+     "diso@git+https://github.com/SarahWeiii/diso.git"
+ ], check=True)
+
+ # Imports
+ import gradio as gr
+ import uuid
+ import torch
+ import zipfile
+ import requests
+ import traceback
+
+ from inference_triposg import run_triposg
+ from triposg.pipelines.pipeline_triposg import TripoSGPipeline
+ from briarmbg import BriaRMBG
+
+ # Device settings
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ dtype = torch.float16 if device == "cuda" else torch.float32
+
+ # Download weights
+ weights_dir = "pretrained_weights"
+ triposg_path = os.path.join(weights_dir, "TripoSG")
+ rmbg_path = os.path.join(weights_dir, "RMBG-1.4")
+
+ if not (os.path.exists(triposg_path) and os.path.exists(rmbg_path)):
+     print("📦 Downloading pretrained weights...")
+     url = "https://huggingface.co/datasets/endlesstools/pretrained-assets/resolve/main/pretrained_models.zip"
+     zip_path = "pretrained_models.zip"
+
+     with requests.get(url, stream=True) as r:
+         r.raise_for_status()
+         with open(zip_path, "wb") as f:
+             for chunk in r.iter_content(chunk_size=8192):
+                 f.write(chunk)
+
+     print("📦 Extracting weights...")
+     with zipfile.ZipFile(zip_path, "r") as zip_ref:
+         zip_ref.extractall(weights_dir)
+
+     os.remove(zip_path)
+     print("✅ Weights ready.")
+
+ # Load models
+ pipe = TripoSGPipeline.from_pretrained(triposg_path).to(device, dtype)
+ rmbg_net = BriaRMBG.from_pretrained(rmbg_path).to(device)
+ rmbg_net.eval()
+
+ # Generate a .glb
+ def generate(image_path):
+     print("[API CALL] image_path received:", image_path)
+     print("[API CALL] File exists:", os.path.exists(image_path))
+
+     temp_id = str(uuid.uuid4())
+     output_path = f"/tmp/{temp_id}.glb"
+
+     print("[DEBUG] Generating mesh from:", image_path)
+
+     try:
+         mesh = run_triposg(
+             pipe=pipe,
+             image_input=image_path,
+             rmbg_net=rmbg_net,
+             seed=42,
+             num_inference_steps=25,
+             guidance_scale=5.0,
+             faces=-1,
+         )
+
+         if mesh is None:
+             raise ValueError("Mesh generation failed")
+
+         mesh.export(output_path)
+         print(f"[DEBUG] Mesh saved to {output_path}")
+
+         return output_path if os.path.exists(output_path) else "Error: output file not found"
+     # except Exception as e:
+     #     print("[ERROR]", e)
+     #     return f"Error: {e}"
+     except Exception as e:
+         import traceback
+         print("[ERROR]", e)
+         traceback.print_exc()  # ← prints the full traceback to the logs
+         return f"Error: {e}"
+
+ # Gradio interface
+ demo = gr.Interface(
+     fn=generate,
+     inputs=gr.Image(type="filepath", label="Upload image"),
+     outputs=gr.File(label="Download .glb"),
+     title="TripoSG Image to 3D",
+     description="Upload an image to generate a 3D model (.glb)",
+ )
+
+ # Launch
+ demo.launch()
+
+
+
+
+
+
+ # import gradio as gr
+ # import uuid
+ # import os
+ # import traceback
+
+ # def generate(image_path):
+ #     try:
+ #         print("[DEBUG] got image path:", image_path)
+ #         print("[DEBUG] file exists:", os.path.exists(image_path))
+
+ #         out_path = f"/tmp/{uuid.uuid4()}.txt"
+ #         with open(out_path, "w") as f:
+ #             f.write(f"Received: {image_path}")
+
+ #         return out_path
+
+ #     except Exception as e:
+ #         print("[ERROR]", e)
+ #         traceback.print_exc()
+ #         return f"Error: {e}"
+
+ # demo = gr.Interface(
+ #     fn=generate,
+ #     inputs=gr.Image(type="filepath"),
+ #     outputs=gr.File()
+ # )
+
+ # demo.launch()
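
Because generate is exposed through a standard gr.Interface, the Space can also be driven programmatically once deployed. A minimal client-side sketch, assuming the default /predict endpoint name that Gradio assigns to a single-interface app; the Space id and image path below are placeholders, not values taken from this commit:

from gradio_client import Client

client = Client("OWNER/SPACE_NAME")  # placeholder Space id
result = client.predict(
    "input.png",          # local path or URL for the gr.Image(type="filepath") input
    api_name="/predict",  # default endpoint name for a single gr.Interface
)
print(result)  # path to the downloaded .glb, or an "Error: ..." string from generate

Note that generate returns an error string instead of raising, so a client should check the result before treating it as a file path.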
briarmbg.py ADDED
@@ -0,0 +1,464 @@
1
+ """
2
+ Source and Copyright Notice:
3
+ This code is from briaai/RMBG-1.4
4
+ Original repository: https://huggingface.co/briaai/RMBG-1.4
5
+ Copyright belongs to briaai
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from huggingface_hub import PyTorchModelHubMixin
12
+
13
+ class REBNCONV(nn.Module):
14
+ def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
15
+ super(REBNCONV,self).__init__()
16
+
17
+ self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate,stride=stride)
18
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
19
+ self.relu_s1 = nn.ReLU(inplace=True)
20
+
21
+ def forward(self,x):
22
+
23
+ hx = x
24
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
25
+
26
+ return xout
27
+
28
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
29
+ def _upsample_like(src,tar):
30
+
31
+ src = F.interpolate(src,size=tar.shape[2:],mode='bilinear')
32
+
33
+ return src
34
+
35
+
36
+ ### RSU-7 ###
37
+ class RSU7(nn.Module):
38
+
39
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
40
+ super(RSU7,self).__init__()
41
+
42
+ self.in_ch = in_ch
43
+ self.mid_ch = mid_ch
44
+ self.out_ch = out_ch
45
+
46
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) ## 1 -> 1/2
47
+
48
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
49
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
50
+
51
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
52
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
53
+
54
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
55
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
56
+
57
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
58
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
59
+
60
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
61
+ self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
62
+
63
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
64
+
65
+ self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
66
+
67
+ self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
68
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
69
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
70
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
71
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
72
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
73
+
74
+ def forward(self,x):
75
+ b, c, h, w = x.shape
76
+
77
+ hx = x
78
+ hxin = self.rebnconvin(hx)
79
+
80
+ hx1 = self.rebnconv1(hxin)
81
+ hx = self.pool1(hx1)
82
+
83
+ hx2 = self.rebnconv2(hx)
84
+ hx = self.pool2(hx2)
85
+
86
+ hx3 = self.rebnconv3(hx)
87
+ hx = self.pool3(hx3)
88
+
89
+ hx4 = self.rebnconv4(hx)
90
+ hx = self.pool4(hx4)
91
+
92
+ hx5 = self.rebnconv5(hx)
93
+ hx = self.pool5(hx5)
94
+
95
+ hx6 = self.rebnconv6(hx)
96
+
97
+ hx7 = self.rebnconv7(hx6)
98
+
99
+ hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
100
+ hx6dup = _upsample_like(hx6d,hx5)
101
+
102
+ hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
103
+ hx5dup = _upsample_like(hx5d,hx4)
104
+
105
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
106
+ hx4dup = _upsample_like(hx4d,hx3)
107
+
108
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
109
+ hx3dup = _upsample_like(hx3d,hx2)
110
+
111
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
112
+ hx2dup = _upsample_like(hx2d,hx1)
113
+
114
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
115
+
116
+ return hx1d + hxin
117
+
118
+
119
+ ### RSU-6 ###
120
+ class RSU6(nn.Module):
121
+
122
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
123
+ super(RSU6,self).__init__()
124
+
125
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
126
+
127
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
128
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
129
+
130
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
131
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
132
+
133
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
134
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
135
+
136
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
137
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
138
+
139
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
140
+
141
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
142
+
143
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
144
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
145
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
146
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
147
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
148
+
149
+ def forward(self,x):
150
+
151
+ hx = x
152
+
153
+ hxin = self.rebnconvin(hx)
154
+
155
+ hx1 = self.rebnconv1(hxin)
156
+ hx = self.pool1(hx1)
157
+
158
+ hx2 = self.rebnconv2(hx)
159
+ hx = self.pool2(hx2)
160
+
161
+ hx3 = self.rebnconv3(hx)
162
+ hx = self.pool3(hx3)
163
+
164
+ hx4 = self.rebnconv4(hx)
165
+ hx = self.pool4(hx4)
166
+
167
+ hx5 = self.rebnconv5(hx)
168
+
169
+ hx6 = self.rebnconv6(hx5)
170
+
171
+
172
+ hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
173
+ hx5dup = _upsample_like(hx5d,hx4)
174
+
175
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
176
+ hx4dup = _upsample_like(hx4d,hx3)
177
+
178
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
179
+ hx3dup = _upsample_like(hx3d,hx2)
180
+
181
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
182
+ hx2dup = _upsample_like(hx2d,hx1)
183
+
184
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
185
+
186
+ return hx1d + hxin
187
+
188
+ ### RSU-5 ###
189
+ class RSU5(nn.Module):
190
+
191
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
192
+ super(RSU5,self).__init__()
193
+
194
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
195
+
196
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
197
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
198
+
199
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
200
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
201
+
202
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
203
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
204
+
205
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
206
+
207
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
208
+
209
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
210
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
211
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
212
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
213
+
214
+ def forward(self,x):
215
+
216
+ hx = x
217
+
218
+ hxin = self.rebnconvin(hx)
219
+
220
+ hx1 = self.rebnconv1(hxin)
221
+ hx = self.pool1(hx1)
222
+
223
+ hx2 = self.rebnconv2(hx)
224
+ hx = self.pool2(hx2)
225
+
226
+ hx3 = self.rebnconv3(hx)
227
+ hx = self.pool3(hx3)
228
+
229
+ hx4 = self.rebnconv4(hx)
230
+
231
+ hx5 = self.rebnconv5(hx4)
232
+
233
+ hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
234
+ hx4dup = _upsample_like(hx4d,hx3)
235
+
236
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
237
+ hx3dup = _upsample_like(hx3d,hx2)
238
+
239
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
240
+ hx2dup = _upsample_like(hx2d,hx1)
241
+
242
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
243
+
244
+ return hx1d + hxin
245
+
246
+ ### RSU-4 ###
247
+ class RSU4(nn.Module):
248
+
249
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
250
+ super(RSU4,self).__init__()
251
+
252
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
253
+
254
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
255
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
256
+
257
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
258
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
259
+
260
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
261
+
262
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
263
+
264
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
265
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
266
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
267
+
268
+ def forward(self,x):
269
+
270
+ hx = x
271
+
272
+ hxin = self.rebnconvin(hx)
273
+
274
+ hx1 = self.rebnconv1(hxin)
275
+ hx = self.pool1(hx1)
276
+
277
+ hx2 = self.rebnconv2(hx)
278
+ hx = self.pool2(hx2)
279
+
280
+ hx3 = self.rebnconv3(hx)
281
+
282
+ hx4 = self.rebnconv4(hx3)
283
+
284
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
285
+ hx3dup = _upsample_like(hx3d,hx2)
286
+
287
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
288
+ hx2dup = _upsample_like(hx2d,hx1)
289
+
290
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
291
+
292
+ return hx1d + hxin
293
+
294
+ ### RSU-4F ###
295
+ class RSU4F(nn.Module):
296
+
297
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
298
+ super(RSU4F,self).__init__()
299
+
300
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
301
+
302
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
303
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
304
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
305
+
306
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
307
+
308
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
309
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
310
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
311
+
312
+ def forward(self,x):
313
+
314
+ hx = x
315
+
316
+ hxin = self.rebnconvin(hx)
317
+
318
+ hx1 = self.rebnconv1(hxin)
319
+ hx2 = self.rebnconv2(hx1)
320
+ hx3 = self.rebnconv3(hx2)
321
+
322
+ hx4 = self.rebnconv4(hx3)
323
+
324
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
325
+ hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
326
+ hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
327
+
328
+ return hx1d + hxin
329
+
330
+
331
+ class myrebnconv(nn.Module):
332
+ def __init__(self, in_ch=3,
333
+ out_ch=1,
334
+ kernel_size=3,
335
+ stride=1,
336
+ padding=1,
337
+ dilation=1,
338
+ groups=1):
339
+ super(myrebnconv,self).__init__()
340
+
341
+ self.conv = nn.Conv2d(in_ch,
342
+ out_ch,
343
+ kernel_size=kernel_size,
344
+ stride=stride,
345
+ padding=padding,
346
+ dilation=dilation,
347
+ groups=groups)
348
+ self.bn = nn.BatchNorm2d(out_ch)
349
+ self.rl = nn.ReLU(inplace=True)
350
+
351
+ def forward(self,x):
352
+ return self.rl(self.bn(self.conv(x)))
353
+
354
+
355
+ class BriaRMBG(nn.Module, PyTorchModelHubMixin):
356
+
357
+ def __init__(self,config:dict={"in_ch":3,"out_ch":1}):
358
+ super(BriaRMBG,self).__init__()
359
+ in_ch=config["in_ch"]
360
+ out_ch=config["out_ch"]
361
+ self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
362
+ self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
363
+
364
+ self.stage1 = RSU7(64,32,64)
365
+ self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
366
+
367
+ self.stage2 = RSU6(64,32,128)
368
+ self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
369
+
370
+ self.stage3 = RSU5(128,64,256)
371
+ self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
372
+
373
+ self.stage4 = RSU4(256,128,512)
374
+ self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
375
+
376
+ self.stage5 = RSU4F(512,256,512)
377
+ self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
378
+
379
+ self.stage6 = RSU4F(512,256,512)
380
+
381
+ # decoder
382
+ self.stage5d = RSU4F(1024,256,512)
383
+ self.stage4d = RSU4(1024,128,256)
384
+ self.stage3d = RSU5(512,64,128)
385
+ self.stage2d = RSU6(256,32,64)
386
+ self.stage1d = RSU7(128,16,64)
387
+
388
+ self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
389
+ self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
390
+ self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
391
+ self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
392
+ self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
393
+ self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
394
+
395
+ # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
396
+
397
+ def forward(self,x):
398
+
399
+ hx = x
400
+
401
+ hxin = self.conv_in(hx)
402
+ #hx = self.pool_in(hxin)
403
+
404
+ #stage 1
405
+ hx1 = self.stage1(hxin)
406
+ hx = self.pool12(hx1)
407
+
408
+ #stage 2
409
+ hx2 = self.stage2(hx)
410
+ hx = self.pool23(hx2)
411
+
412
+ #stage 3
413
+ hx3 = self.stage3(hx)
414
+ hx = self.pool34(hx3)
415
+
416
+ #stage 4
417
+ hx4 = self.stage4(hx)
418
+ hx = self.pool45(hx4)
419
+
420
+ #stage 5
421
+ hx5 = self.stage5(hx)
422
+ hx = self.pool56(hx5)
423
+
424
+ #stage 6
425
+ hx6 = self.stage6(hx)
426
+ hx6up = _upsample_like(hx6,hx5)
427
+
428
+ #-------------------- decoder --------------------
429
+ hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
430
+ hx5dup = _upsample_like(hx5d,hx4)
431
+
432
+ hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
433
+ hx4dup = _upsample_like(hx4d,hx3)
434
+
435
+ hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
436
+ hx3dup = _upsample_like(hx3d,hx2)
437
+
438
+ hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
439
+ hx2dup = _upsample_like(hx2d,hx1)
440
+
441
+ hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
442
+
443
+
444
+ #side output
445
+ d1 = self.side1(hx1d)
446
+ d1 = _upsample_like(d1,x)
447
+
448
+ d2 = self.side2(hx2d)
449
+ d2 = _upsample_like(d2,x)
450
+
451
+ d3 = self.side3(hx3d)
452
+ d3 = _upsample_like(d3,x)
453
+
454
+ d4 = self.side4(hx4d)
455
+ d4 = _upsample_like(d4,x)
456
+
457
+ d5 = self.side5(hx5d)
458
+ d5 = _upsample_like(d5,x)
459
+
460
+ d6 = self.side6(hx6)
461
+ d6 = _upsample_like(d6,x)
462
+
463
+ return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)],[hx1d,hx2d,hx3d,hx4d,hx5d,hx6]
464
+
get-pip.py ADDED
The diff for this file is too large to render.
 
image_process.py ADDED
@@ -0,0 +1,149 @@
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ from skimage.morphology import remove_small_objects
4
+ from skimage.measure import label
5
+ import numpy as np
6
+ from PIL import Image
7
+ import cv2
8
+ from torchvision import transforms
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torchvision.transforms.functional as TF
12
+
13
+ def find_bounding_box(gray_image):
14
+ _, binary_image = cv2.threshold(gray_image, 1, 255, cv2.THRESH_BINARY)
15
+ contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
16
+ max_contour = max(contours, key=cv2.contourArea)
17
+ x, y, w, h = cv2.boundingRect(max_contour)
18
+ return x, y, w, h
19
+
20
+ def load_image(img_path, bg_color=None, rmbg_net=None, padding_ratio=0.1):
21
+ img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
22
+ if img is None:
23
+ return f"invalid image path {img_path}"
24
+
25
+ def is_valid_alpha(alpha, min_ratio = 0.01):
26
+ bins = 20
27
+ if isinstance(alpha, np.ndarray):
28
+ hist = cv2.calcHist([alpha], [0], None, [bins], [0, 256])
29
+ else:
30
+ hist = torch.histc(alpha, bins=bins, min=0, max=1)
31
+ min_hist_val = alpha.shape[0] * alpha.shape[1] * min_ratio
32
+ return hist[0] >= min_hist_val and hist[-1] >= min_hist_val
33
+
34
+ def rmbg(image: torch.Tensor) -> torch.Tensor:
35
+ image = TF.normalize(image, [0.5,0.5,0.5], [1.0,1.0,1.0]).unsqueeze(0)
36
+ result=rmbg_net(image)
37
+ return result[0][0]
38
+
39
+ if len(img.shape) == 2:
40
+ num_channels = 1
41
+ else:
42
+ num_channels = img.shape[2]
43
+
44
+ # check if too large
45
+ height, width = img.shape[:2]
46
+ if height > width:
47
+ scale = 2000 / height
48
+ else:
49
+ scale = 2000 / width
50
+ if scale < 1:
51
+ new_size = (int(width * scale), int(height * scale))
52
+ img = cv2.resize(img, new_size, interpolation=cv2.INTER_AREA)
53
+
54
+ if img.dtype != 'uint8':
55
+ img = (img * (255. / np.iinfo(img.dtype).max)).astype(np.uint8)
56
+
57
+ rgb_image = None
58
+ alpha = None
59
+
60
+ if num_channels == 1:
61
+ rgb_image = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
62
+ elif num_channels == 3:
63
+ rgb_image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
64
+ elif num_channels == 4:
65
+ rgb_image = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
66
+
67
+ b, g, r, alpha = cv2.split(img)
68
+ if not is_valid_alpha(alpha):
69
+ alpha = None
70
+ else:
71
+ alpha_gpu = torch.from_numpy(alpha).unsqueeze(0).cuda().float() / 255.
72
+ else:
73
+ return f"invalid image: channels {num_channels}"
74
+
75
+ rgb_image_gpu = torch.from_numpy(rgb_image).cuda().float().permute(2, 0, 1) / 255.
76
+ if alpha is None:
77
+ resize_transform = transforms.Resize((384, 384), antialias=True)
78
+ rgb_image_resized = resize_transform(rgb_image_gpu)
79
+ normalize_image = rgb_image_resized * 2 - 1
80
+
81
+ mean_color = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1).cuda()
82
+ resize_transform = transforms.Resize((1024, 1024), antialias=True)
83
+ rgb_image_resized = resize_transform(rgb_image_gpu)
84
+ max_value = rgb_image_resized.flatten().max()
85
+ if max_value < 1e-3:
86
+ return "invalid image: pure black image"
87
+ normalize_image = rgb_image_resized / max_value - mean_color
88
+ normalize_image = normalize_image.unsqueeze(0)
89
+ resize_transform = transforms.Resize((rgb_image_gpu.shape[1], rgb_image_gpu.shape[2]), antialias=True)
90
+
91
+ # seg from rmbg
92
+ alpha_gpu_rmbg = rmbg(rgb_image_resized)
93
+ alpha_gpu_rmbg = alpha_gpu_rmbg.squeeze(0)
94
+ alpha_gpu_rmbg = resize_transform(alpha_gpu_rmbg)
95
+ ma, mi = alpha_gpu_rmbg.max(), alpha_gpu_rmbg.min()
96
+ alpha_gpu_rmbg = (alpha_gpu_rmbg - mi) / (ma - mi)
97
+
98
+ alpha_gpu = alpha_gpu_rmbg
99
+
100
+ alpha_gpu_tmp = alpha_gpu * 255
101
+ alpha = alpha_gpu_tmp.to(torch.uint8).squeeze().cpu().numpy()
102
+
103
+ _, alpha = cv2.threshold(alpha, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
104
+ labeled_alpha = label(alpha)
105
+ cleaned_alpha = remove_small_objects(labeled_alpha, min_size=200)
106
+ cleaned_alpha = (cleaned_alpha > 0).astype(np.uint8)
107
+ alpha = cleaned_alpha * 255
108
+ alpha_gpu = torch.from_numpy(cleaned_alpha).cuda().float().unsqueeze(0)
109
+ x, y, w, h = find_bounding_box(alpha)
110
+
111
+ # If alpha is provided, the bounds of all foreground are used
112
+ else:
113
+ rows, cols = np.where(alpha > 0)
114
+ if rows.size > 0 and cols.size > 0:
115
+ x_min = np.min(cols)
116
+ y_min = np.min(rows)
117
+ x_max = np.max(cols)
118
+ y_max = np.max(rows)
119
+
120
+ width = x_max - x_min + 1
121
+ height = y_max - y_min + 1
122
+ x, y, w, h = x_min, y_min, width, height
123
+
124
+ if np.all(alpha==0):
125
+ raise ValueError(f"input image too small")
126
+
127
+ bg_gray = bg_color[0]
128
+ bg_color = torch.from_numpy(bg_color).float().cuda().repeat(alpha_gpu.shape[1], alpha_gpu.shape[2], 1).permute(2, 0, 1)
129
+ rgb_image_gpu = rgb_image_gpu * alpha_gpu + bg_color * (1 - alpha_gpu)
130
+ padding_size = [0] * 6
131
+ if w > h:
132
+ padding_size[0] = int(w * padding_ratio)
133
+ padding_size[2] = int(padding_size[0] + (w - h) / 2)
134
+ else:
135
+ padding_size[2] = int(h * padding_ratio)
136
+ padding_size[0] = int(padding_size[2] + (h - w) / 2)
137
+ padding_size[1] = padding_size[0]
138
+ padding_size[3] = padding_size[2]
139
+ padded_tensor = F.pad(rgb_image_gpu[:, y:(y+h), x:(x+w)], pad=tuple(padding_size), mode='constant', value=bg_gray)
140
+
141
+ return padded_tensor
142
+
143
+ def prepare_image(image_path, bg_color, rmbg_net=None):
144
+ if os.path.isfile(image_path):
145
+ img_tensor = load_image(image_path, bg_color=bg_color, rmbg_net=rmbg_net)
146
+ img_np = img_tensor.permute(1,2,0).cpu().numpy()
147
+ img_pil = Image.fromarray((img_np*255).astype(np.uint8))
148
+
149
+ return img_pil
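
prepare_image is the entry point that run_triposg (next file) uses to turn an arbitrary input photo into a white-background, padded PIL image, with RMBG-1.4 supplying the foreground mask whenever the input has no usable alpha channel. A minimal standalone sketch of that call, assuming a CUDA device (load_image creates its tensors with .cuda()) and a placeholder image path:

import numpy as np
from briarmbg import BriaRMBG
from image_process import prepare_image

# Background-removal net used to segment the foreground when no alpha channel is present
rmbg_net = BriaRMBG.from_pretrained("pretrained_weights/RMBG-1.4").to("cuda")
rmbg_net.eval()

# bg_color is an RGB triple in [0, 1]; the foreground is composited over it and padded toward a square crop
img_pil = prepare_image("input.png", bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
img_pil.save("input_processed.png")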
inference_triposg.py ADDED
@@ -0,0 +1,137 @@
+ import argparse
+ import os
+ import sys
+ from glob import glob
+ from typing import Any, Union
+
+ import numpy as np
+ import torch
+ import trimesh
+ from huggingface_hub import snapshot_download
+ from PIL import Image
+
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
+ from triposg.pipelines.pipeline_triposg import TripoSGPipeline
+ from image_process import prepare_image
+ from briarmbg import BriaRMBG
+
+ import pymeshlab
+
+
+ # @torch.no_grad()
+ # def run_triposg(
+ #     pipe: Any,
+ #     image_input: Union[str, Image.Image],
+ #     rmbg_net: Any,
+ #     seed: int,
+ #     num_inference_steps: int = 50,
+ #     guidance_scale: float = 7.0,
+ #     faces: int = -1,
+ # ) -> trimesh.Scene:
+
+ #     img_pil = prepare_image(image_input, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
+
+ #     outputs = pipe(
+ #         image=img_pil,
+ #         generator=torch.Generator(device=pipe.device).manual_seed(seed),
+ #         num_inference_steps=num_inference_steps,
+ #         guidance_scale=guidance_scale,
+ #     ).samples[0]
+ #     mesh = trimesh.Trimesh(outputs[0].astype(np.float32), np.ascontiguousarray(outputs[1]))
+
+ #     if faces > 0:
+ #         mesh = simplify_mesh(mesh, faces)
+
+ #     return mesh
+
+ @torch.no_grad()
+ def run_triposg(
+     pipe: Any,
+     image_input: Union[str, Image.Image],
+     rmbg_net: Any,
+     seed: int,
+     num_inference_steps: int = 50,
+     guidance_scale: float = 7.0,
+     faces: int = -1,
+ ) -> trimesh.Scene:
+     print("[DEBUG] Preparing image...")
+     img_pil = prepare_image(image_input, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
+
+     print("[DEBUG] Running TripoSG pipeline...")
+     outputs = pipe(
+         image=img_pil,
+         generator=torch.Generator(device=pipe.device).manual_seed(seed),
+         num_inference_steps=num_inference_steps,
+         guidance_scale=guidance_scale,
+     ).samples[0]
+
+     print("[DEBUG] TripoSG output keys:", type(outputs), outputs[0].shape, outputs[1].shape)
+
+     mesh = trimesh.Trimesh(outputs[0].astype(np.float32), np.ascontiguousarray(outputs[1]))
+     print(f"[DEBUG] Mesh created: {mesh.vertices.shape[0]} verts / {mesh.faces.shape[0]} faces")
+
+     if faces > 0:
+         print(f"[DEBUG] Simplifying mesh to {faces} faces")
+         mesh = simplify_mesh(mesh, faces)
+
+     return mesh
+
+
+ def mesh_to_pymesh(vertices, faces):
+     mesh = pymeshlab.Mesh(vertex_matrix=vertices, face_matrix=faces)
+     ms = pymeshlab.MeshSet()
+     ms.add_mesh(mesh)
+     return ms
+
+ def pymesh_to_trimesh(mesh):
+     verts = mesh.vertex_matrix()  # .tolist()
+     faces = mesh.face_matrix()  # .tolist()
+     return trimesh.Trimesh(vertices=verts, faces=faces)  # , vID, fID
+
+ def simplify_mesh(mesh: trimesh.Trimesh, n_faces):
+     if mesh.faces.shape[0] > n_faces:
+         ms = mesh_to_pymesh(mesh.vertices, mesh.faces)
+         ms.meshing_merge_close_vertices()
+         ms.meshing_decimation_quadric_edge_collapse(targetfacenum=n_faces)
+         return pymesh_to_trimesh(ms.current_mesh())
+     else:
+         return mesh
+
+ if __name__ == "__main__":
+     device = "cuda"
+     dtype = torch.float16
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--image-input", type=str, required=True)
+     parser.add_argument("--output-path", type=str, default="./output.glb")
+     parser.add_argument("--seed", type=int, default=42)
+     parser.add_argument("--num-inference-steps", type=int, default=50)
+     parser.add_argument("--guidance-scale", type=float, default=7.0)
+     parser.add_argument("--faces", type=int, default=-1)
+     args = parser.parse_args()
+
+     # download pretrained weights
+     triposg_weights_dir = "pretrained_weights/TripoSG"
+     rmbg_weights_dir = "pretrained_weights/RMBG-1.4"
+     snapshot_download(repo_id="VAST-AI/TripoSG", local_dir=triposg_weights_dir)
+     snapshot_download(repo_id="briaai/RMBG-1.4", local_dir=rmbg_weights_dir)
+
+     # init rmbg model for background removal
+     rmbg_net = BriaRMBG.from_pretrained(rmbg_weights_dir).to(device)
+     rmbg_net.eval()
+
+     # init tripoSG pipeline
+     pipe: TripoSGPipeline = TripoSGPipeline.from_pretrained(triposg_weights_dir).to(device, dtype)
+
+     # run inference
+     run_triposg(
+         pipe,
+         image_input=args.image_input,
+         rmbg_net=rmbg_net,
+         seed=args.seed,
+         num_inference_steps=args.num_inference_steps,
+         guidance_scale=args.guidance_scale,
+         faces=args.faces,
+     ).export(args.output_path)
+     print(f"Mesh saved to {args.output_path}")
requirements.txt ADDED
@@ -0,0 +1,18 @@
+ gradio
+ torch
+ torchvision
+ torchaudio
+ trimesh
+ pillow
+ omegaconf
+ diffusers
+ transformers
+ opencv-python
+ scikit-image
+ python-multipart
+ peft
+ jaxtyping
+ typeguard
+ pymeshlab
+ einops
+
scripts/__pycache__/briarmbg.cpython-310.pyc ADDED
Binary file (10 kB).
 
scripts/__pycache__/image_process.cpython-310.pyc ADDED
Binary file (4.49 kB).
 
scripts/__pycache__/inference_triposg.cpython-310.pyc ADDED
Binary file (3.09 kB).
 
scripts/inference_triposg_scribble.py ADDED
@@ -0,0 +1,78 @@
+ import argparse
+ import os
+ import sys
+ from glob import glob
+ from typing import Any, Union
+
+ import numpy as np
+ import torch
+ import trimesh
+ from huggingface_hub import snapshot_download
+ from PIL import Image
+
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
+ from triposg.pipelines.pipeline_triposg_scribble import TripoSGScribblePipeline
+
+
+
+ @torch.no_grad()
+ def run_triposg_scribble(
+     pipe: Any,
+     image_input: Union[str, Image.Image],
+     prompt: str,
+     seed: int,
+     num_inference_steps: int = 16,
+     scribble_confidence: float = 0.4,
+     prompt_confidence: float = 1.0
+ ) -> trimesh.Scene:
+
+     img_pil = Image.open(image_input).convert("RGB")
+
+     outputs = pipe(
+         image=img_pil,
+         prompt=prompt,
+         generator=torch.Generator(device=pipe.device).manual_seed(seed),
+         num_inference_steps=num_inference_steps,
+         guidance_scale=0,  # this is a CFG-distilled model
+         attention_kwargs={"cross_attention_scale": prompt_confidence, "cross_attention_2_scale": scribble_confidence},
+         use_flash_decoder=False,  # there are some boundary problems when using the flash decoder with this model
+         dense_octree_depth=8, hierarchical_octree_depth=8  # 256 resolution for faster inference
+     ).samples[0]
+     mesh = trimesh.Trimesh(outputs[0].astype(np.float32), np.ascontiguousarray(outputs[1]))
+     return mesh
+
+
+ if __name__ == "__main__":
+     device = "cuda"
+     dtype = torch.float16
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--image-input", type=str, required=True)
+     parser.add_argument("--prompt", type=str, required=True)
+     parser.add_argument("--output-path", type=str, default="./output.glb")
+     parser.add_argument("--seed", type=int, default=42)
+     parser.add_argument("--num-inference-steps", type=int, default=16)
+     # feel free to tune the scribble confidence; 0.3-0.5 often gives good results for hand-drawn sketches
+     parser.add_argument("--scribble-conf", type=float, default=0.4)
+     parser.add_argument("--prompt-conf", type=float, default=1.0)
+     args = parser.parse_args()
+
+     # download pretrained weights
+     triposg_scribble_weights_dir = "pretrained_weights/TripoSG-scribble"
+     snapshot_download(repo_id="VAST-AI/TripoSG-scribble", local_dir=triposg_scribble_weights_dir)
+
+     # init tripoSG pipeline
+     pipe: TripoSGScribblePipeline = TripoSGScribblePipeline.from_pretrained(triposg_scribble_weights_dir).to(device, dtype)
+
+     # run inference
+     run_triposg_scribble(
+         pipe,
+         image_input=args.image_input,
+         prompt=args.prompt,
+         seed=args.seed,
+         num_inference_steps=args.num_inference_steps,
+         scribble_confidence=args.scribble_conf,
+         prompt_confidence=args.prompt_conf,
+     ).export(args.output_path)
+     print("Mesh saved to", args.output_path)
scripts/inference_vae.py ADDED
@@ -0,0 +1,61 @@
+ import argparse
+ import numpy as np
+ import torch
+ import trimesh
+
+ from triposg.inference_utils import hierarchical_extract_geometry
+ from triposg.models.autoencoders import TripoSGVAEModel
+ from huggingface_hub import snapshot_download
+
+
+
+ def load_surface(data_path, num_pc=204800):
+     data = np.load(data_path, allow_pickle=True).tolist()
+     surface = data["surface_points"]  # Nx3
+     normal = data["surface_normals"]  # Nx3
+
+     rng = np.random.default_rng()
+     ind = rng.choice(surface.shape[0], num_pc, replace=False)
+     surface = torch.FloatTensor(surface[ind])
+     normal = torch.FloatTensor(normal[ind])
+     surface = torch.cat([surface, normal], dim=-1).unsqueeze(0).cuda()
+
+     return surface
+
+
+ if __name__ == "__main__":
+     device = "cuda"
+     dtype = torch.float16
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--surface-input", type=str, required=True)
+     args = parser.parse_args()
+
+     # download pretrained weights
+     triposg_weights_dir = "pretrained_weights/TripoSG"
+     snapshot_download(repo_id="VAST-AI/TripoSG", local_dir=triposg_weights_dir)
+
+     vae: TripoSGVAEModel = TripoSGVAEModel.from_pretrained(
+         triposg_weights_dir,
+         subfolder="vae",
+     ).to(device, dtype=dtype)
+
+     # load surface from sdf and encode
+     surface = load_surface(
+         args.surface_input, num_pc=204800
+     ).to(device, dtype=dtype)
+     sample = vae.encode(surface).latent_dist.sample()
+
+     # vae infer
+     with torch.no_grad():
+         geometric_func = lambda x: vae.decode(sample, sampled_points=x).sample
+         output = hierarchical_extract_geometry(
+             geometric_func,
+             device,
+             bounds=(-1.005, -1.005, -1.005, 1.005, 1.005, 1.005),
+             dense_octree_depth=8,
+             hierarchical_octree_depth=9,
+         )
+     meshes = [trimesh.Trimesh(mesh_v_f[0].astype(np.float32), mesh_v_f[1]) for mesh_v_f in output]
+
+     meshes[0].export("test_vae.glb")
+
triposg/.DS_Store ADDED
Binary file (6.15 kB).
 
triposg/LICENSE ADDED
@@ -0,0 +1,80 @@
1
+ TENCENT HUNYUAN FLASHVDM COMMUNITY LICENSE AGREEMENT
2
+ Tencent Hunyuan FlashVDM Release Date: March 19, 2025
3
+ THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KORE AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
4
+ By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan FlashVDM Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
5
+ 1. DEFINITIONS.
6
+ a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
7
+ b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan FlashVDM Works or any portion or element thereof set forth herein.
8
+ c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan FlashVDM made publicly available by Tencent.
9
+ d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
10
+ e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan FlashVDM Works for any purpose and in any field of use.
11
+ f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan FlashVDM and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
12
+ g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan FlashVDM or any Model Derivative of Tencent Hunyuan FlashVDM; (ii) works based on Tencent Hunyuan FlashVDM or any Model Derivative of Tencent Hunyuan FlashVDM; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan FlashVDM or any Model Derivative of Tencent Hunyuan FlashVDM, to that model in order to cause that model to perform similarly to Tencent Hunyuan FlashVDM or a Model Derivative of Tencent Hunyuan FlashVDM, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan FlashVDM or a Model Derivative of Tencent Hunyuan FlashVDM for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
13
+ h. “Output” shall mean the information and/or content output of Tencent Hunyuan FlashVDM or a Model Derivative that results from operating or otherwise using Tencent Hunyuan FlashVDM or a Model Derivative, including via a Hosted Service.
14
+ i. “Tencent,” “We” or “Us” shall mean THL A29 Limited.
15
+ j. “Tencent Hunyuan FlashVDM” shall mean the 3D generation models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us at [ https://github.com/Tencent/FlashVDM ].
16
+ k. “Tencent Hunyuan FlashVDM Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
17
+ l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea.
18
+ m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
19
+ n. “including” shall mean including but not limited to.
20
+ 2. GRANT OF RIGHTS.
21
+ We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
22
+ 3. DISTRIBUTION.
23
+ You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan FlashVDM Works, exclusively in the Territory, provided that You meet all of the following conditions:
24
+ a. You must provide all such Third Party recipients of the Tencent Hunyuan FlashVDM Works or products or services using them a copy of this Agreement;
25
+ b. You must cause any modified files to carry prominent notices stating that You changed the files;
26
+ c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan FlashVDM Works; and (ii) mark the products or services developed by using the Tencent Hunyuan FlashVDM Works to indicate that the product/service is “Powered by Tencent Hunyuan ”; and
27
+ d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan FlashVDM is licensed under the Tencent Hunyuan FlashVDM Community License Agreement, Copyright © 2025 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
28
+ You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan FlashVDM Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
29
+ 4. ADDITIONAL COMMERCIAL TERMS.
30
+ If, on the Tencent Hunyuan FlashVDM version release date, the monthly active users of all products or services made available by or for Licensee is greater than 1 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
31
+ Subject to Tencent's written approval, you may request a license for the use of Tencent Hunyuan FlashVDM by submitting the following information to hunyuan3d@tencent.com:
32
+ a. Your company’s name and associated business sector that plans to use Tencent Hunyuan FlashVDM.
33
+ b. Your intended use case and the purpose of using Tencent Hunyuan FlashVDM.
34
+ 5. RULES OF USE.
35
+ a. Your use of the Tencent Hunyuan FlashVDM Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan FlashVDM Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan FlashVDM Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan FlashVDM Works are subject to the use restrictions in these Sections 5(a) and 5(b).
36
+ b. You must not use the Tencent Hunyuan FlashVDM Works or any Output or results of the Tencent Hunyuan FlashVDM Works to improve any other AI model (other than Tencent Hunyuan FlashVDM or Model Derivatives thereof).
37
+ c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan FlashVDM Works, Output or results of the Tencent Hunyuan FlashVDM Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
38
+ 6. INTELLECTUAL PROPERTY.
39
+ a. Subject to Tencent’s ownership of Tencent Hunyuan FlashVDM Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
40
+ b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan FlashVDM Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan FlashVDM Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
41
+ c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan FlashVDM Works.
42
+ d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
43
+ 7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
44
+ a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan FlashVDM Works or to grant any license thereto.
45
+ b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN FLASHVDM WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN FLASHVDM WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN FLASHVDM WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
46
+ c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN FLASHVDM WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
47
+ 8. SURVIVAL AND TERMINATION.
48
+ a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
49
+ b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan FlashVDM Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
50
+ 9. GOVERNING LAW AND JURISDICTION.
51
+ a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
52
+ b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
53
+
54
+ EXHIBIT A
55
+ ACCEPTABLE USE POLICY
56
+
57
+ Tencent reserves the right to update this Acceptable Use Policy from time to time.
58
+ Last modified: November 5, 2024
59
+
60
+ Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan FlashVDM. You agree not to use Tencent Hunyuan FlashVDM or Model Derivatives:
61
+ 1. Outside the Territory;
62
+ 2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
63
+ 3. To harm Yourself or others;
64
+ 4. To repurpose or distribute output from Tencent Hunyuan FlashVDM or any Model Derivatives to harm Yourself or others;
65
+ 5. To override or circumvent the safety guardrails and safeguards We have put in place;
66
+ 6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
67
+ 7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
68
+ 8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
69
+ 9. To intentionally defame, disparage or otherwise harass others;
70
+ 10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
71
+ 11. To generate or disseminate personal identifiable information with the purpose of harming others;
72
+ 12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including – through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
73
+ 13. To impersonate another individual without consent, authorization, or legal right;
74
+ 14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
75
+ 15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
76
+ 16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
77
+ 17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
78
+ 18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
79
+ 19. For military purposes;
80
+ 20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
triposg/__pycache__/inference_utils.cpython-310.pyc ADDED
Binary file (13 kB). View file
 
triposg/diso/__init__.py ADDED
File without changes
triposg/diso/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (156 Bytes). View file
 
triposg/diso/__pycache__/diff_dmc.cpython-310.pyc ADDED
Binary file (611 Bytes). View file
 
triposg/diso/diff_dmc.py ADDED
@@ -0,0 +1,15 @@
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ class DiffDMC(nn.Module):
6
+ def __init__(self, latent_dim=256):
7
+ super().__init__()
8
+ self.linear = nn.Sequential(
9
+ nn.Linear(latent_dim, latent_dim),
10
+ nn.ReLU(),
11
+ nn.Linear(latent_dim, latent_dim)
12
+ )
13
+
14
+ def forward(self, x):
15
+ return self.linear(x)
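
Note: this local `DiffDMC` is only a placeholder MLP kept under `triposg/diso/`. The `DiffDMC` that `triposg/inference_utils.py` actually uses is imported from the external `diso` package (differentiable dual marching cubes), is constructed as `DiffDMC(dtype=torch.float32)` and is called on an SDF volume, so the two are not interchangeable. A minimal sketch of what the stub above does, assuming `triposg` is importable as a package:

```python
import torch

from triposg.diso.diff_dmc import DiffDMC  # the placeholder module defined above

# The stub is just a two-layer MLP over the last dimension; it performs no meshing.
stub = DiffDMC(latent_dim=256)
x = torch.randn(4, 256)
y = stub(x)
print(y.shape)  # torch.Size([4, 256])
```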
triposg/diso/diffusion_utils.py ADDED
@@ -0,0 +1,5 @@
1
+
2
+ # dummy utility functions for diffusion corrections
3
+
4
+ def post_process_latents(latents):
5
+ return latents
triposg/inference_utils.py ADDED
@@ -0,0 +1,491 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import scipy.ndimage
5
+ from skimage import measure
6
+ import torch.nn.functional as F
7
+ from diso import DiffDMC
8
+ from einops import repeat
9
+ from triposg.utils.typing import *
10
+
11
+ def generate_dense_grid_points_gpu(bbox_min: torch.Tensor,
12
+ bbox_max: torch.Tensor,
13
+ octree_depth: int,
14
+ indexing: str = "ij"):
15
+ length = bbox_max - bbox_min
16
+ num_cells = 2 ** octree_depth
17
+ device = bbox_min.device
18
+
19
+ x = torch.linspace(bbox_min[0], bbox_max[0], int(num_cells), dtype=torch.float16, device=device)
20
+ y = torch.linspace(bbox_min[1], bbox_max[1], int(num_cells), dtype=torch.float16, device=device)
21
+ z = torch.linspace(bbox_min[2], bbox_max[2], int(num_cells), dtype=torch.float16, device=device)
22
+
23
+ xs, ys, zs = torch.meshgrid(x, y, z, indexing=indexing)
24
+ xyz = torch.stack((xs, ys, zs), dim=-1)
25
+ xyz = xyz.view(-1, 3)
26
+ grid_size = [int(num_cells), int(num_cells), int(num_cells)]
27
+
28
+ return xyz, grid_size, length
29
+
30
+ def find_mesh_grid_coordinates_fast_gpu(occupancy_grid, n_limits=-1):
31
+ core_grid = occupancy_grid[1:-1, 1:-1, 1:-1]
32
+ occupied = core_grid > 0
33
+
34
+ neighbors_unoccupied = (
35
+ (occupancy_grid[:-2, :-2, :-2] < 0)
36
+ | (occupancy_grid[:-2, :-2, 1:-1] < 0)
37
+ | (occupancy_grid[:-2, :-2, 2:] < 0) # x-1, y-1, z-1/0/1
38
+ | (occupancy_grid[:-2, 1:-1, :-2] < 0)
39
+ | (occupancy_grid[:-2, 1:-1, 1:-1] < 0)
40
+ | (occupancy_grid[:-2, 1:-1, 2:] < 0) # x-1, y0, z-1/0/1
41
+ | (occupancy_grid[:-2, 2:, :-2] < 0)
42
+ | (occupancy_grid[:-2, 2:, 1:-1] < 0)
43
+ | (occupancy_grid[:-2, 2:, 2:] < 0) # x-1, y+1, z-1/0/1
44
+ | (occupancy_grid[1:-1, :-2, :-2] < 0)
45
+ | (occupancy_grid[1:-1, :-2, 1:-1] < 0)
46
+ | (occupancy_grid[1:-1, :-2, 2:] < 0) # x0, y-1, z-1/0/1
47
+ | (occupancy_grid[1:-1, 1:-1, :-2] < 0)
48
+ | (occupancy_grid[1:-1, 1:-1, 2:] < 0) # x0, y0, z-1/1
49
+ | (occupancy_grid[1:-1, 2:, :-2] < 0)
50
+ | (occupancy_grid[1:-1, 2:, 1:-1] < 0)
51
+ | (occupancy_grid[1:-1, 2:, 2:] < 0) # x0, y+1, z-1/0/1
52
+ | (occupancy_grid[2:, :-2, :-2] < 0)
53
+ | (occupancy_grid[2:, :-2, 1:-1] < 0)
54
+ | (occupancy_grid[2:, :-2, 2:] < 0) # x+1, y-1, z-1/0/1
55
+ | (occupancy_grid[2:, 1:-1, :-2] < 0)
56
+ | (occupancy_grid[2:, 1:-1, 1:-1] < 0)
57
+ | (occupancy_grid[2:, 1:-1, 2:] < 0) # x+1, y0, z-1/0/1
58
+ | (occupancy_grid[2:, 2:, :-2] < 0)
59
+ | (occupancy_grid[2:, 2:, 1:-1] < 0)
60
+ | (occupancy_grid[2:, 2:, 2:] < 0) # x+1, y+1, z-1/0/1
61
+ )
62
+ core_mesh_coords = torch.nonzero(occupied & neighbors_unoccupied, as_tuple=False) + 1
63
+
64
+ if n_limits != -1 and core_mesh_coords.shape[0] > n_limits:
65
+ print(f"core mesh coords {core_mesh_coords.shape[0]} is too large, limited to {n_limits}")
66
+ ind = np.random.choice(core_mesh_coords.shape[0], n_limits, True)
67
+ core_mesh_coords = core_mesh_coords[ind]
68
+
69
+ return core_mesh_coords
70
+
71
+ def find_candidates_band(occupancy_grid: torch.Tensor, band_threshold: float, n_limits: int = -1) -> torch.Tensor:
72
+ """
73
+ Returns the coordinates of all voxels in the occupancy_grid where |value| < band_threshold.
74
+
75
+ Args:
76
+ occupancy_grid (torch.Tensor): A 3D tensor of SDF values.
77
+ band_threshold (float): The threshold below which |SDF| must be to include the voxel.
78
+ n_limits (int): Maximum number of points to return (-1 for no limit)
79
+
80
+ Returns:
81
+ torch.Tensor: A 2D tensor of coordinates (N x 3) where each row is [x, y, z].
82
+ """
83
+ core_grid = occupancy_grid[1:-1, 1:-1, 1:-1]
84
+ # logits to sdf
85
+ core_grid = torch.sigmoid(core_grid) * 2 - 1
86
+ # Create a boolean mask for all cells in the band
87
+ in_band = torch.abs(core_grid) < band_threshold
88
+
89
+ # Get coordinates of all voxels in the band
90
+ core_mesh_coords = torch.nonzero(in_band, as_tuple=False) + 1
91
+
92
+ if n_limits != -1 and core_mesh_coords.shape[0] > n_limits:
93
+ print(f"core mesh coords {core_mesh_coords.shape[0]} is too large, limited to {n_limits}")
94
+ ind = np.random.choice(core_mesh_coords.shape[0], n_limits, True)
95
+ core_mesh_coords = core_mesh_coords[ind]
96
+
97
+ return core_mesh_coords
98
+
99
+ def expand_edge_region_fast(edge_coords, grid_size):
100
+ expanded_tensor = torch.zeros(grid_size, grid_size, grid_size, device='cuda', dtype=torch.float16, requires_grad=False)
101
+ expanded_tensor[edge_coords[:, 0], edge_coords[:, 1], edge_coords[:, 2]] = 1
102
+ if grid_size < 512:
103
+ kernel_size = 5
104
+ pooled_tensor = torch.nn.functional.max_pool3d(expanded_tensor.unsqueeze(0).unsqueeze(0), kernel_size=kernel_size, stride=1, padding=2).squeeze()
105
+ else:
106
+ kernel_size = 3
107
+ pooled_tensor = torch.nn.functional.max_pool3d(expanded_tensor.unsqueeze(0).unsqueeze(0), kernel_size=kernel_size, stride=1, padding=1).squeeze()
108
+ expanded_coords_low_res = torch.nonzero(pooled_tensor, as_tuple=False).to(torch.int16)
109
+
110
+ expanded_coords_high_res = torch.stack([
111
+ torch.cat((expanded_coords_low_res[:, 0] * 2, expanded_coords_low_res[:, 0] * 2, expanded_coords_low_res[:, 0] * 2, expanded_coords_low_res[:, 0] * 2, expanded_coords_low_res[:, 0] * 2 + 1, expanded_coords_low_res[:, 0] * 2 + 1, expanded_coords_low_res[:, 0] * 2 + 1, expanded_coords_low_res[:, 0] * 2 + 1)),
112
+ torch.cat((expanded_coords_low_res[:, 1] * 2, expanded_coords_low_res[:, 1] * 2, expanded_coords_low_res[:, 1] * 2+1, expanded_coords_low_res[:, 1] * 2 + 1, expanded_coords_low_res[:, 1] * 2, expanded_coords_low_res[:, 1] * 2, expanded_coords_low_res[:, 1] * 2 + 1, expanded_coords_low_res[:, 1] * 2 + 1)),
113
+ torch.cat((expanded_coords_low_res[:, 2] * 2, expanded_coords_low_res[:, 2] * 2+1, expanded_coords_low_res[:, 2] * 2, expanded_coords_low_res[:, 2] * 2 + 1, expanded_coords_low_res[:, 2] * 2, expanded_coords_low_res[:, 2] * 2+1, expanded_coords_low_res[:, 2] * 2, expanded_coords_low_res[:, 2] * 2 + 1))
114
+ ], dim=1)
115
+
116
+ return expanded_coords_high_res
117
+
118
+ def zoom_block(block, scale_factor, order=3):
119
+ block = block.astype(np.float32)
120
+ return scipy.ndimage.zoom(block, scale_factor, order=order)
121
+
122
+ def parallel_zoom(occupancy_grid, scale_factor):
123
+ result = torch.nn.functional.interpolate(occupancy_grid.unsqueeze(0).unsqueeze(0), scale_factor=scale_factor)
124
+ return result.squeeze(0).squeeze(0)
125
+
126
+
127
+ @torch.no_grad()
128
+ def hierarchical_extract_geometry(geometric_func: Callable,
129
+ device: torch.device,
130
+ bounds: Union[Tuple[float], List[float], float] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
131
+ dense_octree_depth: int = 8,
132
+ hierarchical_octree_depth: int = 9,
133
+ ):
134
+ """
135
+
136
+ Args:
137
+ geometric_func: callable mapping query points of shape (1, N, 3) to occupancy logits of shape (1, N, 1); positive logits are treated as occupied.
138
+ device: torch device on which the dense grid is built (the helpers above assume CUDA).
139
+ bounds: bounding box (xmin, ymin, zmin, xmax, ymax, zmax), or a single float b meaning (-b, ..., b).
140
+ dense_octree_depth: depth of the initial dense grid, i.e. 2**dense_octree_depth cells per axis.
141
+ hierarchical_octree_depth: final depth reached by the coarse-to-fine refinement.
142
+ Returns: a list with one (vertices, faces) tuple, or (None, None) if marching cubes fails.
143
+
144
+ """
145
+ if isinstance(bounds, float):
146
+ bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
147
+
148
+ bbox_min = torch.tensor(bounds[0:3]).to(device)
149
+ bbox_max = torch.tensor(bounds[3:6]).to(device)
150
+ bbox_size = bbox_max - bbox_min
151
+
152
+ xyz_samples, grid_size, length = generate_dense_grid_points_gpu(
153
+ bbox_min=bbox_min,
154
+ bbox_max=bbox_max,
155
+ octree_depth=dense_octree_depth,
156
+ indexing="ij"
157
+ )
158
+
159
+ print(f'step 1 query num: {xyz_samples.shape[0]}')
160
+ grid_logits = geometric_func(xyz_samples.unsqueeze(0)).to(torch.float16).view(grid_size[0], grid_size[1], grid_size[2])
161
+ # print(f'step 1 grid_logits shape: {grid_logits.shape}')
162
+ for i in range(hierarchical_octree_depth - dense_octree_depth):
163
+ curr_octree_depth = dense_octree_depth + i + 1
164
+ # upsample
165
+ grid_size = 2**curr_octree_depth
166
+ normalize_offset = grid_size / 2
167
+ high_res_occupancy = parallel_zoom(grid_logits, 2)
168
+
169
+ band_threshold = 1.0
170
+ edge_coords = find_candidates_band(grid_logits, band_threshold)
171
+ expanded_coords = expand_edge_region_fast(edge_coords, grid_size=int(grid_size/2)).to(torch.float16)
172
+ print(f'step {i+2} query num: {len(expanded_coords)}')
173
+ expanded_coords_norm = (expanded_coords - normalize_offset) * (abs(bounds[0]) / normalize_offset)
174
+
175
+ all_logits = None
176
+
177
+ all_logits = geometric_func(expanded_coords_norm.unsqueeze(0)).to(torch.float16)
178
+ all_logits = torch.cat([expanded_coords_norm, all_logits[0]], dim=1)
179
+ # print("all logits shape = ", all_logits.shape)
180
+
181
+ indices = all_logits[..., :3]
182
+ indices = indices * (normalize_offset / abs(bounds[0])) + normalize_offset
183
+ indices = indices.type(torch.IntTensor)
184
+ values = all_logits[:, 3]
185
+ # breakpoint()
186
+ high_res_occupancy[indices[:, 0], indices[:, 1], indices[:, 2]] = values
187
+ grid_logits = high_res_occupancy
188
+ torch.cuda.empty_cache()
189
+ mesh_v_f = []
190
+ try:
191
+ print("final grids shape = ", grid_logits.shape)
192
+ vertices, faces, normals, _ = measure.marching_cubes(grid_logits.float().cpu().numpy(), 0, method="lewiner")
193
+ vertices = vertices / (2**hierarchical_octree_depth) * bbox_size.cpu().numpy() + bbox_min.cpu().numpy()
194
+ mesh_v_f = (vertices.astype(np.float32), np.ascontiguousarray(faces))
195
+ except Exception as e:
196
+ print(e)
197
+ torch.cuda.empty_cache()
198
+ mesh_v_f = (None, None)
199
+
200
+ return [mesh_v_f]
201
+
202
+ def extract_near_surface_volume_fn(input_tensor: torch.Tensor, alpha: float):
203
+ """
204
+ Args:
205
+ input_tensor: shape [D, D, D], torch.float16
206
+ alpha: isosurface offset
207
+ Returns:
208
+ mask: shape [D, D, D], torch.int32
209
+ """
210
+ device = input_tensor.device
211
+ D = input_tensor.shape[0]
212
+ signed_val = 0.0
213
+
214
+ # add isosurface offset and exclude invalid value
215
+ val = input_tensor + alpha
216
+ valid_mask = val > -9000
217
+
218
+ # obtain neighbors
219
+ def get_neighbor(t, shift, axis):
220
+ if shift == 0:
221
+ return t.clone()
222
+
223
+ pad_dims = [0, 0, 0, 0, 0, 0] # [x_front,x_back,y_front,y_back,z_front,z_back]
224
+
225
+ if axis == 0: # x axis
226
+ pad_idx = 0 if shift > 0 else 1
227
+ pad_dims[pad_idx] = abs(shift)
228
+ elif axis == 1: # y axis
229
+ pad_idx = 2 if shift > 0 else 3
230
+ pad_dims[pad_idx] = abs(shift)
231
+ elif axis == 2: # z axis
232
+ pad_idx = 4 if shift > 0 else 5
233
+ pad_dims[pad_idx] = abs(shift)
234
+
235
+ # Apply padding with replication at boundaries
236
+ padded = F.pad(t.unsqueeze(0).unsqueeze(0), pad_dims[::-1], mode='replicate')
237
+
238
+ # Create dynamic slicing indices
239
+ slice_dims = [slice(None)] * 3
240
+ if axis == 0: # x axis
241
+ if shift > 0:
242
+ slice_dims[0] = slice(shift, None)
243
+ else:
244
+ slice_dims[0] = slice(None, shift)
245
+ elif axis == 1: # y axis
246
+ if shift > 0:
247
+ slice_dims[1] = slice(shift, None)
248
+ else:
249
+ slice_dims[1] = slice(None, shift)
250
+ elif axis == 2: # z axis
251
+ if shift > 0:
252
+ slice_dims[2] = slice(shift, None)
253
+ else:
254
+ slice_dims[2] = slice(None, shift)
255
+
256
+ # Apply slicing and restore dimensions
257
+ padded = padded.squeeze(0).squeeze(0)
258
+ sliced = padded[slice_dims]
259
+ return sliced
260
+
261
+ # Get neighbors in all directions
262
+ left = get_neighbor(val, 1, axis=0) # x axis
263
+ right = get_neighbor(val, -1, axis=0)
264
+ back = get_neighbor(val, 1, axis=1) # y axis
265
+ front = get_neighbor(val, -1, axis=1)
266
+ down = get_neighbor(val, 1, axis=2) # z axis
267
+ up = get_neighbor(val, -1, axis=2)
268
+
269
+ # Handle invalid boundary values
270
+ def safe_where(neighbor):
271
+ return torch.where(neighbor > -9000, neighbor, val)
272
+
273
+ left = safe_where(left)
274
+ right = safe_where(right)
275
+ back = safe_where(back)
276
+ front = safe_where(front)
277
+ down = safe_where(down)
278
+ up = safe_where(up)
279
+
280
+ # Calculate sign consistency
281
+ sign = torch.sign(val.to(torch.float32))
282
+ neighbors_sign = torch.stack([
283
+ torch.sign(left.to(torch.float32)),
284
+ torch.sign(right.to(torch.float32)),
285
+ torch.sign(back.to(torch.float32)),
286
+ torch.sign(front.to(torch.float32)),
287
+ torch.sign(down.to(torch.float32)),
288
+ torch.sign(up.to(torch.float32))
289
+ ], dim=0)
290
+
291
+ # Check if all signs are consistent
292
+ same_sign = torch.all(neighbors_sign == sign, dim=0)
293
+
294
+ # Generate final mask
295
+ mask = (~same_sign).to(torch.int32)
296
+ return mask * valid_mask.to(torch.int32)
297
+
298
+
299
+ def generate_dense_grid_points_2(
300
+ bbox_min: np.ndarray,
301
+ bbox_max: np.ndarray,
302
+ octree_resolution: int,
303
+ indexing: str = "ij",
304
+ ):
305
+ length = bbox_max - bbox_min
306
+ num_cells = octree_resolution
307
+
308
+ x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
309
+ y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
310
+ z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
311
+ [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
312
+ xyz = np.stack((xs, ys, zs), axis=-1)
313
+ grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
314
+
315
+ return xyz, grid_size, length
316
+
317
+ @torch.no_grad()
318
+ def flash_extract_geometry(
319
+ latents: torch.FloatTensor,
320
+ vae: Callable,
321
+ bounds: Union[Tuple[float], List[float], float] = 1.01,
322
+ num_chunks: int = 10000,
323
+ mc_level: float = 0.0,
324
+ octree_depth: int = 8,
325
+ min_resolution: int = 63,
326
+ mini_grid_num: int = 4,
327
+ **kwargs,
328
+ ):
329
+ geo_decoder = vae.decoder
330
+ device = latents.device
331
+ dtype = latents.dtype
332
+ # resolution to depth
333
+ octree_resolution = 2 ** octree_depth
334
+ resolutions = []
335
+ if octree_resolution < min_resolution:
336
+ resolutions.append(octree_resolution)
337
+ while octree_resolution >= min_resolution:
338
+ resolutions.append(octree_resolution)
339
+ octree_resolution = octree_resolution // 2
340
+ resolutions.reverse()
341
+ resolutions[0] = round(resolutions[0] / mini_grid_num) * mini_grid_num - 1
342
+ for i, resolution in enumerate(resolutions[1:]):
343
+ resolutions[i + 1] = resolutions[0] * 2 ** (i + 1)
344
+
345
+
346
+ # 1. generate query points
347
+ if isinstance(bounds, float):
348
+ bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
349
+ bbox_min = np.array(bounds[0:3])
350
+ bbox_max = np.array(bounds[3:6])
351
+ bbox_size = bbox_max - bbox_min
352
+
353
+ xyz_samples, grid_size, length = generate_dense_grid_points_2(
354
+ bbox_min=bbox_min,
355
+ bbox_max=bbox_max,
356
+ octree_resolution=resolutions[0],
357
+ indexing="ij"
358
+ )
359
+
360
+ dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)
361
+ dilate.weight = torch.nn.Parameter(torch.ones(dilate.weight.shape, dtype=dtype, device=device))
362
+
363
+ grid_size = np.array(grid_size)
364
+
365
+ # 2. latents to 3d volume
366
+ xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype)
367
+ batch_size = latents.shape[0]
368
+ mini_grid_size = xyz_samples.shape[0] // mini_grid_num
369
+ xyz_samples = xyz_samples.view(
370
+ mini_grid_num, mini_grid_size,
371
+ mini_grid_num, mini_grid_size,
372
+ mini_grid_num, mini_grid_size, 3
373
+ ).permute(
374
+ 0, 2, 4, 1, 3, 5, 6
375
+ ).reshape(
376
+ -1, mini_grid_size * mini_grid_size * mini_grid_size, 3
377
+ )
378
+ batch_logits = []
379
+ num_batchs = max(num_chunks // xyz_samples.shape[1], 1)
380
+ for start in range(0, xyz_samples.shape[0], num_batchs):
381
+ queries = xyz_samples[start: start + num_batchs, :]
382
+ batch = queries.shape[0]
383
+ batch_latents = repeat(latents.squeeze(0), "p c -> b p c", b=batch)
384
+ # geo_decoder.set_topk(True)
385
+ geo_decoder.set_topk(False)
386
+ logits = vae.decode(batch_latents, queries).sample
387
+ batch_logits.append(logits)
388
+ grid_logits = torch.cat(batch_logits, dim=0).reshape(
389
+ mini_grid_num, mini_grid_num, mini_grid_num,
390
+ mini_grid_size, mini_grid_size,
391
+ mini_grid_size
392
+ ).permute(0, 3, 1, 4, 2, 5).contiguous().view(
393
+ (batch_size, grid_size[0], grid_size[1], grid_size[2])
394
+ )
395
+
396
+ for octree_depth_now in resolutions[1:]:
397
+ grid_size = np.array([octree_depth_now + 1] * 3)
398
+ resolution = bbox_size / octree_depth_now
399
+ next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)
400
+ next_logits = torch.full(next_index.shape, -10000., dtype=dtype, device=device)
401
+ curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level)
402
+ curr_points += grid_logits.squeeze(0).abs() < 0.95
403
+
404
+ if octree_depth_now == resolutions[-1]:
405
+ expand_num = 0
406
+ else:
407
+ expand_num = 1
408
+ for i in range(expand_num):
409
+ curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
410
+ curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
411
+ (cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0)
412
+
413
+ next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
414
+ for i in range(2 - expand_num):
415
+ next_index = dilate(next_index.unsqueeze(0)).squeeze(0)
416
+ nidx = torch.where(next_index > 0)
417
+
418
+ next_points = torch.stack(nidx, dim=1)
419
+ next_points = (next_points * torch.tensor(resolution, dtype=torch.float32, device=device) +
420
+ torch.tensor(bbox_min, dtype=torch.float32, device=device))
421
+
422
+ query_grid_num = 6
423
+ min_val = next_points.min(axis=0).values
424
+ max_val = next_points.max(axis=0).values
425
+ vol_queries_index = (next_points - min_val) / (max_val - min_val) * (query_grid_num - 0.001)
426
+ index = torch.floor(vol_queries_index).long()
427
+ index = index[..., 0] * (query_grid_num ** 2) + index[..., 1] * query_grid_num + index[..., 2]
428
+ index = index.sort()
429
+ next_points = next_points[index.indices].unsqueeze(0).contiguous()
430
+ unique_values = torch.unique(index.values, return_counts=True)
431
+ grid_logits = torch.zeros((next_points.shape[1]), dtype=latents.dtype, device=latents.device)
432
+ input_grid = [[], []]
433
+ logits_grid_list = []
434
+ start_num = 0
435
+ sum_num = 0
436
+ for grid_index, count in zip(unique_values[0].cpu().tolist(), unique_values[1].cpu().tolist()):
437
+ if sum_num + count < num_chunks or sum_num == 0:
438
+ sum_num += count
439
+ input_grid[0].append(grid_index)
440
+ input_grid[1].append(count)
441
+ else:
442
+ # geo_decoder.set_topk(input_grid)
443
+ geo_decoder.set_topk(False)
444
+ logits_grid = vae.decode(latents,next_points[:, start_num:start_num + sum_num]).sample
445
+ start_num = start_num + sum_num
446
+ logits_grid_list.append(logits_grid)
447
+ input_grid = [[grid_index], [count]]
448
+ sum_num = count
449
+ if sum_num > 0:
450
+ # geo_decoder.set_topk(input_grid)
451
+ geo_decoder.set_topk(False)
452
+ logits_grid = vae.decode(latents,next_points[:, start_num:start_num + sum_num]).sample
453
+ logits_grid_list.append(logits_grid)
454
+ logits_grid = torch.cat(logits_grid_list, dim=1)
455
+ grid_logits[index.indices] = logits_grid.squeeze(0).squeeze(-1)
456
+ next_logits[nidx] = grid_logits
457
+ grid_logits = next_logits.unsqueeze(0)
458
+
459
+ grid_logits[grid_logits == -10000.] = float('nan')
460
+ torch.cuda.empty_cache()
461
+ mesh_v_f = []
462
+ grid_logits = grid_logits[0]
463
+ try:
464
+ print("final grids shape = ", grid_logits.shape)
465
+
466
+ dmc = DiffDMC(dtype=torch.float32).to(grid_logits.device)
467
+ sdf = -grid_logits / octree_resolution
468
+ sdf = sdf.to(torch.float32).contiguous()
469
+
470
+ vertices, faces = dmc(sdf)
471
+ vertices = vertices.detach().cpu().numpy()
472
+ faces = faces.detach().cpu().numpy()[:, ::-1]
473
+ vertices = vertices / (2 ** octree_depth) * bbox_size + bbox_min
474
+ # Center the mesh at the origin
475
+ vertices = vertices - vertices.mean(axis=0)
476
+ # Rescale to a target extent (e.g. set target_scale to 100 for centimeters)
477
+ target_scale = 1.0
478
+ max_extent = np.max(np.linalg.norm(vertices, axis=1))
479
+ scale_factor = target_scale / max_extent
480
+ vertices = vertices * scale_factor
481
+
482
+ mesh_v_f = (vertices.astype(np.float32), np.ascontiguousarray(faces))
483
+
484
+
485
+
486
+ except Exception as e:
487
+ print(e)
488
+ torch.cuda.empty_cache()
489
+ mesh_v_f = (None, None)
490
+
491
+ return [mesh_v_f]
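
For reference, `hierarchical_extract_geometry` only needs a callable that maps a (1, N, 3) batch of query points to (1, N, 1) occupancy logits (positive logits are treated as occupied); it evaluates a coarse dense grid, refines the near-surface band, and runs marching cubes. Below is a self-contained sketch that drives it with a hypothetical analytic sphere field instead of the VAE decoder; a CUDA device is assumed because `expand_edge_region_fast` is hard-coded to `cuda`:

```python
import torch

from triposg.inference_utils import hierarchical_extract_geometry

# Hypothetical occupancy field: large positive logits inside a sphere of radius 0.5.
def sphere_logits(xyz):                     # xyz: (1, N, 3)
    r = xyz.norm(dim=-1, keepdim=True)      # (1, N, 1)
    return (0.5 - r) * 200.0                # saturates away from the surface

device = torch.device("cuda")
(verts, faces), = hierarchical_extract_geometry(
    sphere_logits,
    device,
    bounds=(-1.0, -1.0, -1.0, 1.0, 1.0, 1.0),
    dense_octree_depth=6,                   # 64^3 coarse grid
    hierarchical_octree_depth=7,            # refined to 128^3 near the surface
)
print(verts.shape, faces.shape)
```

`flash_extract_geometry` follows the same coarse-to-fine idea but takes VAE latents directly, decodes them in chunks through `vae.decode(latents, queries)`, and meshes the result with the `diso` DiffDMC instead of skimage marching cubes.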
triposg/models/.DS_Store ADDED
Binary file (6.15 kB). View file
 
triposg/models/__pycache__/attention_processor.cpython-310.pyc ADDED
Binary file (8.09 kB). View file
 
triposg/models/__pycache__/embeddings.cpython-310.pyc ADDED
Binary file (3.65 kB). View file
 
triposg/models/attention_processor.py ADDED
@@ -0,0 +1,420 @@
1
+ from typing import Callable, List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from diffusers.models.attention_processor import Attention
6
+ from diffusers.models.embeddings import apply_rotary_emb
7
+ from diffusers.utils import logging
8
+ from einops import rearrange
9
+
10
+
11
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
12
+
13
+
14
+ class FlashTripoSGAttnProcessor2_0:
15
+ r"""
16
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
17
+ used in the TripoSG DiT model. It applies a normalization layer and rotary embedding to the query and key vectors, and keeps only the top-k key/value tokens (the FlashVDM shortcut) before attention.
18
+ """
19
+
20
+ def __init__(self, topk=True):
21
+ if not hasattr(F, "scaled_dot_product_attention"):
22
+ raise ImportError(
23
+ "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
24
+ )
25
+ self.topk = topk
26
+
27
+ def qkv(self, attn, q, k, v, attn_mask, dropout_p, is_causal):
28
+ if k.shape[-2] == 3072:
29
+ topk = 1024
30
+ elif k.shape[-2] == 512:
31
+ topk = 256
32
+ else:
33
+ topk = k.shape[-2] // 3
34
+
35
+ if self.topk is True:
36
+ q1 = q[:, :, ::100, :]
37
+ sim = q1 @ k.transpose(-1, -2)
38
+ sim = torch.mean(sim, -2)
39
+ topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
40
+ topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
41
+ v0 = torch.gather(v, dim=-2, index=topk_ind)
42
+ k0 = torch.gather(k, dim=-2, index=topk_ind)
43
+ out = F.scaled_dot_product_attention(q, k0, v0)
44
+ elif self.topk is False:
45
+ out = F.scaled_dot_product_attention(q, k, v)
46
+ else:
47
+ idx, counts = self.topk
48
+ start = 0
49
+ outs = []
50
+ for grid_coord, count in zip(idx, counts):
51
+ end = start + count
52
+ q_chunk = q[:, :, start:end, :]
53
+ q1 = q_chunk[:, :, ::50, :]
54
+ sim = q1 @ k.transpose(-1, -2)
55
+ sim = torch.mean(sim, -2)
56
+ topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
57
+ topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
58
+ v0 = torch.gather(v, dim=-2, index=topk_ind)
59
+ k0 = torch.gather(k, dim=-2, index=topk_ind)
60
+ out = F.scaled_dot_product_attention(q_chunk, k0, v0)
61
+ outs.append(out)
62
+ start += count
63
+ out = torch.cat(outs, dim=-2)
64
+ self.topk = False
65
+ return out
66
+
67
+ def __call__(
68
+ self,
69
+ attn: Attention,
70
+ hidden_states: torch.Tensor,
71
+ encoder_hidden_states: Optional[torch.Tensor] = None,
72
+ attention_mask: Optional[torch.Tensor] = None,
73
+ temb: Optional[torch.Tensor] = None,
74
+ image_rotary_emb: Optional[torch.Tensor] = None,
75
+ **kwargs,
76
+ ) -> torch.Tensor:
77
+ residual = hidden_states
78
+ if attn.spatial_norm is not None:
79
+ hidden_states = attn.spatial_norm(hidden_states, temb)
80
+
81
+ input_ndim = hidden_states.ndim
82
+
83
+ if input_ndim == 4:
84
+ batch_size, channel, height, width = hidden_states.shape
85
+ hidden_states = hidden_states.view(
86
+ batch_size, channel, height * width
87
+ ).transpose(1, 2)
88
+
89
+ batch_size, sequence_length, _ = (
90
+ hidden_states.shape
91
+ if encoder_hidden_states is None
92
+ else encoder_hidden_states.shape
93
+ )
94
+
95
+ if attention_mask is not None:
96
+ attention_mask = attn.prepare_attention_mask(
97
+ attention_mask, sequence_length, batch_size
98
+ )
99
+ # scaled_dot_product_attention expects attention_mask shape to be
100
+ # (batch, heads, source_length, target_length)
101
+ attention_mask = attention_mask.view(
102
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
103
+ )
104
+
105
+ if attn.group_norm is not None:
106
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
107
+ 1, 2
108
+ )
109
+
110
+ query = attn.to_q(hidden_states)
111
+
112
+ if encoder_hidden_states is None:
113
+ encoder_hidden_states = hidden_states
114
+ elif attn.norm_cross:
115
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
116
+ encoder_hidden_states
117
+ )
118
+
119
+ key = attn.to_k(encoder_hidden_states)
120
+ value = attn.to_v(encoder_hidden_states)
121
+
122
+ # NOTE that tripo2 split heads first then split qkv or kv, like .view(..., attn.heads, 3, dim)
123
+ # instead of .view(..., 3, attn.heads, dim). So we need to re-split here.
124
+ if not attn.is_cross_attention:
125
+ qkv = torch.cat((query, key, value), dim=-1)
126
+ split_size = qkv.shape[-1] // attn.heads // 3
127
+ qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
128
+ query, key, value = torch.split(qkv, split_size, dim=-1)
129
+ else:
130
+ kv = torch.cat((key, value), dim=-1)
131
+ split_size = kv.shape[-1] // attn.heads // 2
132
+ kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
133
+ key, value = torch.split(kv, split_size, dim=-1)
134
+
135
+ head_dim = key.shape[-1]
136
+
137
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
138
+
139
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
140
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
141
+
142
+ if attn.norm_q is not None:
143
+ query = attn.norm_q(query)
144
+ if attn.norm_k is not None:
145
+ key = attn.norm_k(key)
146
+
147
+ # Apply RoPE if needed
148
+ if image_rotary_emb is not None:
149
+ query = apply_rotary_emb(query, image_rotary_emb)
150
+ if not attn.is_cross_attention:
151
+ key = apply_rotary_emb(key, image_rotary_emb)
152
+
153
+ # flashvdm topk
154
+ hidden_states = self.qkv(attn, query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
155
+
156
+ hidden_states = hidden_states.transpose(1, 2).reshape(
157
+ batch_size, -1, attn.heads * head_dim
158
+ )
159
+ hidden_states = hidden_states.to(query.dtype)
160
+
161
+ # linear proj
162
+ hidden_states = attn.to_out[0](hidden_states)
163
+ # dropout
164
+ hidden_states = attn.to_out[1](hidden_states)
165
+
166
+ if input_ndim == 4:
167
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
168
+ batch_size, channel, height, width
169
+ )
170
+
171
+ if attn.residual_connection:
172
+ hidden_states = hidden_states + residual
173
+
174
+ hidden_states = hidden_states / attn.rescale_output_factor
175
+
176
+ return hidden_states
177
+
178
+ class TripoSGAttnProcessor2_0:
179
+ r"""
180
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
181
+ used in the TripoSG model. It applies a normalization layer and rotary embedding to the query and key vectors.
182
+ """
183
+
184
+ def __init__(self):
185
+ if not hasattr(F, "scaled_dot_product_attention"):
186
+ raise ImportError(
187
+ "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
188
+ )
189
+
190
+ def __call__(
191
+ self,
192
+ attn: Attention,
193
+ hidden_states: torch.Tensor,
194
+ encoder_hidden_states: Optional[torch.Tensor] = None,
195
+ attention_mask: Optional[torch.Tensor] = None,
196
+ temb: Optional[torch.Tensor] = None,
197
+ image_rotary_emb: Optional[torch.Tensor] = None,
198
+ ) -> torch.Tensor:
199
+ residual = hidden_states
200
+ if attn.spatial_norm is not None:
201
+ hidden_states = attn.spatial_norm(hidden_states, temb)
202
+
203
+ input_ndim = hidden_states.ndim
204
+
205
+ if input_ndim == 4:
206
+ batch_size, channel, height, width = hidden_states.shape
207
+ hidden_states = hidden_states.view(
208
+ batch_size, channel, height * width
209
+ ).transpose(1, 2)
210
+
211
+ batch_size, sequence_length, _ = (
212
+ hidden_states.shape
213
+ if encoder_hidden_states is None
214
+ else encoder_hidden_states.shape
215
+ )
216
+
217
+ if attention_mask is not None:
218
+ attention_mask = attn.prepare_attention_mask(
219
+ attention_mask, sequence_length, batch_size
220
+ )
221
+ # scaled_dot_product_attention expects attention_mask shape to be
222
+ # (batch, heads, source_length, target_length)
223
+ attention_mask = attention_mask.view(
224
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
225
+ )
226
+
227
+ if attn.group_norm is not None:
228
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
229
+ 1, 2
230
+ )
231
+
232
+ query = attn.to_q(hidden_states)
233
+
234
+ if encoder_hidden_states is None:
235
+ encoder_hidden_states = hidden_states
236
+ elif attn.norm_cross:
237
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
238
+ encoder_hidden_states
239
+ )
240
+
241
+ key = attn.to_k(encoder_hidden_states)
242
+ value = attn.to_v(encoder_hidden_states)
243
+
244
+ # NOTE that pre-trained models split heads first then split qkv or kv, like .view(..., attn.heads, 3, dim)
245
+ # instead of .view(..., 3, attn.heads, dim). So we need to re-split here.
246
+ if not attn.is_cross_attention:
247
+ qkv = torch.cat((query, key, value), dim=-1)
248
+ split_size = qkv.shape[-1] // attn.heads // 3
249
+ qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
250
+ query, key, value = torch.split(qkv, split_size, dim=-1)
251
+ else:
252
+ kv = torch.cat((key, value), dim=-1)
253
+ split_size = kv.shape[-1] // attn.heads // 2
254
+ kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
255
+ key, value = torch.split(kv, split_size, dim=-1)
256
+
257
+ head_dim = key.shape[-1]
258
+
259
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
260
+
261
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
262
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
263
+
264
+ if attn.norm_q is not None:
265
+ query = attn.norm_q(query)
266
+ if attn.norm_k is not None:
267
+ key = attn.norm_k(key)
268
+
269
+ # Apply RoPE if needed
270
+ if image_rotary_emb is not None:
271
+ query = apply_rotary_emb(query, image_rotary_emb)
272
+ if not attn.is_cross_attention:
273
+ key = apply_rotary_emb(key, image_rotary_emb)
274
+
275
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
276
+ # TODO: add support for attn.scale when we move to Torch 2.1
277
+ hidden_states = F.scaled_dot_product_attention(
278
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
279
+ )
280
+
281
+ hidden_states = hidden_states.transpose(1, 2).reshape(
282
+ batch_size, -1, attn.heads * head_dim
283
+ )
284
+ hidden_states = hidden_states.to(query.dtype)
285
+
286
+ # linear proj
287
+ hidden_states = attn.to_out[0](hidden_states)
288
+ # dropout
289
+ hidden_states = attn.to_out[1](hidden_states)
290
+
291
+ if input_ndim == 4:
292
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
293
+ batch_size, channel, height, width
294
+ )
295
+
296
+ if attn.residual_connection:
297
+ hidden_states = hidden_states + residual
298
+
299
+ hidden_states = hidden_states / attn.rescale_output_factor
300
+
301
+ return hidden_states
302
+
303
+
304
+ class FusedTripoSGAttnProcessor2_0:
305
+ r"""
306
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused
307
+ projection layers. This is used in the TripoSG model. It applies a normalization layer and rotary embedding to the
308
+ query and key vectors.
309
+ """
310
+
311
+ def __init__(self):
312
+ if not hasattr(F, "scaled_dot_product_attention"):
313
+ raise ImportError(
314
+ "FusedTripoSGAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
315
+ )
316
+
317
+ def __call__(
318
+ self,
319
+ attn: Attention,
320
+ hidden_states: torch.Tensor,
321
+ encoder_hidden_states: Optional[torch.Tensor] = None,
322
+ attention_mask: Optional[torch.Tensor] = None,
323
+ temb: Optional[torch.Tensor] = None,
324
+ image_rotary_emb: Optional[torch.Tensor] = None,
325
+ ) -> torch.Tensor:
326
+ residual = hidden_states
327
+ if attn.spatial_norm is not None:
328
+ hidden_states = attn.spatial_norm(hidden_states, temb)
329
+
330
+ input_ndim = hidden_states.ndim
331
+
332
+ if input_ndim == 4:
333
+ batch_size, channel, height, width = hidden_states.shape
334
+ hidden_states = hidden_states.view(
335
+ batch_size, channel, height * width
336
+ ).transpose(1, 2)
337
+
338
+ batch_size, sequence_length, _ = (
339
+ hidden_states.shape
340
+ if encoder_hidden_states is None
341
+ else encoder_hidden_states.shape
342
+ )
343
+
344
+ if attention_mask is not None:
345
+ attention_mask = attn.prepare_attention_mask(
346
+ attention_mask, sequence_length, batch_size
347
+ )
348
+ # scaled_dot_product_attention expects attention_mask shape to be
349
+ # (batch, heads, source_length, target_length)
350
+ attention_mask = attention_mask.view(
351
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
352
+ )
353
+
354
+ if attn.group_norm is not None:
355
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
356
+ 1, 2
357
+ )
358
+
359
+ # NOTE that pre-trained split heads first, then split qkv
360
+ if encoder_hidden_states is None:
361
+ qkv = attn.to_qkv(hidden_states)
362
+ split_size = qkv.shape[-1] // attn.heads // 3
363
+ qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
364
+ query, key, value = torch.split(qkv, split_size, dim=-1)
365
+ else:
366
+ if attn.norm_cross:
367
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
368
+ encoder_hidden_states
369
+ )
370
+ query = attn.to_q(hidden_states)
371
+
372
+ kv = attn.to_kv(encoder_hidden_states)
373
+ split_size = kv.shape[-1] // attn.heads // 2
374
+ kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
375
+ key, value = torch.split(kv, split_size, dim=-1)
376
+
377
+ head_dim = key.shape[-1]
378
+
379
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
380
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
381
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
382
+
383
+ if attn.norm_q is not None:
384
+ query = attn.norm_q(query)
385
+ if attn.norm_k is not None:
386
+ key = attn.norm_k(key)
387
+
388
+ # Apply RoPE if needed
389
+ if image_rotary_emb is not None:
390
+ query = apply_rotary_emb(query, image_rotary_emb)
391
+ if not attn.is_cross_attention:
392
+ key = apply_rotary_emb(key, image_rotary_emb)
393
+
394
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
395
+ # TODO: add support for attn.scale when we move to Torch 2.1
396
+ hidden_states = F.scaled_dot_product_attention(
397
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
398
+ )
399
+
400
+ hidden_states = hidden_states.transpose(1, 2).reshape(
401
+ batch_size, -1, attn.heads * head_dim
402
+ )
403
+ hidden_states = hidden_states.to(query.dtype)
404
+
405
+ # linear proj
406
+ hidden_states = attn.to_out[0](hidden_states)
407
+ # dropout
408
+ hidden_states = attn.to_out[1](hidden_states)
409
+
410
+ if input_ndim == 4:
411
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
412
+ batch_size, channel, height, width
413
+ )
414
+
415
+ if attn.residual_connection:
416
+ hidden_states = hidden_states + residual
417
+
418
+ hidden_states = hidden_states / attn.rescale_output_factor
419
+
420
+ return hidden_states
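
The FlashVDM shortcut implemented by `FlashTripoSGAttnProcessor2_0.qkv` above works as follows: take a strided subset of the queries, score every key by its mean similarity to that subset, keep only the top-k keys and values, and then run ordinary scaled dot-product attention against the reduced set. A stand-alone toy of the same gather-then-SDPA pattern, with shapes mirroring the `k.shape[-2] == 3072` branch (this is an illustration, not the processor itself):

```python
import torch
import torch.nn.functional as F

B, H, Nq, Nk, D, K = 1, 8, 4096, 3072, 64, 1024   # K = 1024 matches the 3072-key branch
q = torch.randn(B, H, Nq, D)
k = torch.randn(B, H, Nk, D)
v = torch.randn(B, H, Nk, D)

sim = q[:, :, ::100, :] @ k.transpose(-1, -2)     # score keys against every 100th query
scores = sim.mean(dim=-2)                         # (B, H, Nk)
idx = scores.topk(K, dim=-1).indices.unsqueeze(-1).expand(-1, -1, -1, D)
k_top = torch.gather(k, dim=-2, index=idx)        # (B, H, K, D)
v_top = torch.gather(v, dim=-2, index=idx)        # (B, H, K, D)
out = F.scaled_dot_product_attention(q, k_top, v_top)  # (B, H, Nq, D)
print(out.shape)
```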
triposg/models/autoencoders/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .autoencoder_kl_triposg import TripoSGVAEModel
triposg/models/autoencoders/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (233 Bytes). View file
 
triposg/models/autoencoders/__pycache__/autoencoder_kl_triposg.cpython-310.pyc ADDED
Binary file (15.7 kB). View file
 
triposg/models/autoencoders/__pycache__/vae.cpython-310.pyc ADDED
Binary file (2.34 kB). View file
 
triposg/models/autoencoders/autoencoder_kl_triposg.py ADDED
@@ -0,0 +1,536 @@
1
+ from typing import Dict, List, Optional, Tuple, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
7
+ from diffusers.models.attention_processor import Attention, AttentionProcessor
8
+ from diffusers.models.autoencoders.vae import DecoderOutput
9
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
10
+ from diffusers.models.modeling_utils import ModelMixin
11
+ from diffusers.models.normalization import FP32LayerNorm, LayerNorm
12
+ from diffusers.utils import logging
13
+ from diffusers.utils.accelerate_utils import apply_forward_hook
14
+ from einops import repeat
15
+ # from torch_cluster import fps  # NOTE: _sample_features() below calls fps(); re-enable this import (and install torch_cluster) before using the encoder path
16
+ from tqdm import tqdm
17
+
18
+ from ..attention_processor import FusedTripoSGAttnProcessor2_0, TripoSGAttnProcessor2_0, FlashTripoSGAttnProcessor2_0
19
+ from ..embeddings import FrequencyPositionalEmbedding
20
+ from ..transformers.triposg_transformer import DiTBlock
21
+ from .vae import DiagonalGaussianDistribution
22
+
23
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
24
+
25
+
26
+ class TripoSGEncoder(nn.Module):
27
+ def __init__(
28
+ self,
29
+ in_channels: int = 3,
30
+ dim: int = 512,
31
+ num_attention_heads: int = 8,
32
+ num_layers: int = 8,
33
+ ):
34
+ super().__init__()
35
+
36
+ self.proj_in = nn.Linear(in_channels, dim, bias=True)
37
+
38
+ self.blocks = nn.ModuleList(
39
+ [
40
+ DiTBlock(
41
+ dim=dim,
42
+ num_attention_heads=num_attention_heads,
43
+ use_self_attention=False,
44
+ use_cross_attention=True,
45
+ cross_attention_dim=dim,
46
+ cross_attention_norm_type="layer_norm",
47
+ activation_fn="gelu",
48
+ norm_type="fp32_layer_norm",
49
+ norm_eps=1e-5,
50
+ qk_norm=False,
51
+ qkv_bias=False,
52
+ ) # cross attention
53
+ ]
54
+ + [
55
+ DiTBlock(
56
+ dim=dim,
57
+ num_attention_heads=num_attention_heads,
58
+ use_self_attention=True,
59
+ self_attention_norm_type="fp32_layer_norm",
60
+ use_cross_attention=False,
61
+ activation_fn="gelu",
62
+ norm_type="fp32_layer_norm",
63
+ norm_eps=1e-5,
64
+ qk_norm=False,
65
+ qkv_bias=False,
66
+ )
67
+ for _ in range(num_layers) # self attention
68
+ ]
69
+ )
70
+
71
+ self.norm_out = LayerNorm(dim)
72
+
73
+ def forward(self, sample_1: torch.Tensor, sample_2: torch.Tensor):
74
+ hidden_states = self.proj_in(sample_1)
75
+ encoder_hidden_states = self.proj_in(sample_2)
76
+
77
+ for layer, block in enumerate(self.blocks):
78
+ if layer == 0:
79
+ hidden_states = block(
80
+ hidden_states, encoder_hidden_states=encoder_hidden_states
81
+ )
82
+ else:
83
+ hidden_states = block(hidden_states)
84
+
85
+ hidden_states = self.norm_out(hidden_states)
86
+
87
+ return hidden_states
88
+
89
+
90
+ class TripoSGDecoder(nn.Module):
91
+ def __init__(
92
+ self,
93
+ in_channels: int = 3,
94
+ out_channels: int = 1,
95
+ dim: int = 512,
96
+ num_attention_heads: int = 8,
97
+ num_layers: int = 16,
98
+ grad_type: str = "analytical",
99
+ grad_interval: float = 0.001,
100
+ ):
101
+ super().__init__()
102
+
103
+ if grad_type not in ["numerical", "analytical"]:
104
+ raise ValueError(f"grad_type must be one of ['numerical', 'analytical']")
105
+ self.grad_type = grad_type
106
+ self.grad_interval = grad_interval
107
+
108
+ self.blocks = nn.ModuleList(
109
+ [
110
+ DiTBlock(
111
+ dim=dim,
112
+ num_attention_heads=num_attention_heads,
113
+ use_self_attention=True,
114
+ self_attention_norm_type="fp32_layer_norm",
115
+ use_cross_attention=False,
116
+ activation_fn="gelu",
117
+ norm_type="fp32_layer_norm",
118
+ norm_eps=1e-5,
119
+ qk_norm=False,
120
+ qkv_bias=False,
121
+ )
122
+ for _ in range(num_layers) # self attention
123
+ ]
124
+ + [
125
+ DiTBlock(
126
+ dim=dim,
127
+ num_attention_heads=num_attention_heads,
128
+ use_self_attention=False,
129
+ use_cross_attention=True,
130
+ cross_attention_dim=dim,
131
+ cross_attention_norm_type="layer_norm",
132
+ activation_fn="gelu",
133
+ norm_type="fp32_layer_norm",
134
+ norm_eps=1e-5,
135
+ qk_norm=False,
136
+ qkv_bias=False,
137
+ ) # cross attention
138
+ ]
139
+ )
140
+
141
+ self.proj_query = nn.Linear(in_channels, dim, bias=True)
142
+
143
+ self.norm_out = LayerNorm(dim)
144
+ self.proj_out = nn.Linear(dim, out_channels, bias=True)
145
+
146
+ def set_topk(self, topk):
147
+ self.blocks[-1].set_topk(topk)
148
+
149
+ def set_flash_processor(self, processor):
150
+ self.blocks[-1].set_flash_processor(processor)
151
+
152
+ def query_geometry(
153
+ self,
154
+ model_fn: callable,
155
+ queries: torch.Tensor,
156
+ sample: torch.Tensor,
157
+ grad: bool = False,
158
+ ):
159
+ logits = model_fn(queries, sample)
160
+ if grad:
161
+ with torch.autocast(device_type="cuda", dtype=torch.float32):
162
+ if self.grad_type == "numerical":
163
+ interval = self.grad_interval
164
+ grad_value = []
165
+ for offset in [
166
+ (interval, 0, 0),
167
+ (0, interval, 0),
168
+ (0, 0, interval),
169
+ ]:
170
+ offset_tensor = torch.tensor(offset, device=queries.device)[
171
+ None, :
172
+ ]
173
+ res_p = model_fn(queries + offset_tensor, sample)[..., 0]
174
+ res_n = model_fn(queries - offset_tensor, sample)[..., 0]
175
+ grad_value.append((res_p - res_n) / (2 * interval))
176
+ grad_value = torch.stack(grad_value, dim=-1)
177
+ else:
178
+ queries_d = torch.clone(queries)
179
+ queries_d.requires_grad = True
180
+ with torch.enable_grad():
181
+ res_d = model_fn(queries_d, sample)
182
+ grad_value = torch.autograd.grad(
183
+ res_d,
184
+ [queries_d],
185
+ grad_outputs=torch.ones_like(res_d),
186
+ create_graph=self.training,
187
+ )[0]
188
+ else:
189
+ grad_value = None
190
+
191
+ return logits, grad_value
192
+
193
+ def forward(
194
+ self,
195
+ sample: torch.Tensor,
196
+ queries: torch.Tensor,
197
+ kv_cache: Optional[torch.Tensor] = None,
198
+ ):
199
+ if kv_cache is None:
200
+ hidden_states = sample
201
+ for _, block in enumerate(self.blocks[:-1]):
202
+ hidden_states = block(hidden_states)
203
+ kv_cache = hidden_states
204
+
205
+ # query grid logits by cross attention
206
+ def query_fn(q, kv):
207
+ q = self.proj_query(q)
208
+ l = self.blocks[-1](q, encoder_hidden_states=kv)
209
+ return self.proj_out(self.norm_out(l))
210
+
211
+ logits, grad = self.query_geometry(
212
+ query_fn, queries, kv_cache, grad=self.training
213
+ )
214
+ logits = logits * -1 if not isinstance(logits, Tuple) else logits[0] * -1
215
+
216
+ return logits, kv_cache
217
+
218
+
219
+ class TripoSGVAEModel(ModelMixin, ConfigMixin):
220
+ @register_to_config
221
+ def __init__(
222
+ self,
223
+ in_channels: int = 3, # NOTE xyz instead of feature dim
224
+ latent_channels: int = 64,
225
+ num_attention_heads: int = 8,
226
+ width_encoder: int = 512,
227
+ width_decoder: int = 1024,
228
+ num_layers_encoder: int = 8,
229
+ num_layers_decoder: int = 16,
230
+ embedding_type: str = "frequency",
231
+ embed_frequency: int = 8,
232
+ embed_include_pi: bool = False,
233
+ ):
234
+ super().__init__()
235
+
236
+ self.out_channels = 1
237
+
238
+ if embedding_type == "frequency":
239
+ self.embedder = FrequencyPositionalEmbedding(
240
+ num_freqs=embed_frequency,
241
+ logspace=True,
242
+ input_dim=in_channels,
243
+ include_pi=embed_include_pi,
244
+ )
245
+ else:
246
+ raise NotImplementedError(
247
+ f"Embedding type {embedding_type} is not supported."
248
+ )
249
+
250
+ self.encoder = TripoSGEncoder(
251
+ in_channels=in_channels + self.embedder.out_dim,
252
+ dim=width_encoder,
253
+ num_attention_heads=num_attention_heads,
254
+ num_layers=num_layers_encoder,
255
+ )
256
+ self.decoder = TripoSGDecoder(
257
+ in_channels=self.embedder.out_dim,
258
+ out_channels=self.out_channels,
259
+ dim=width_decoder,
260
+ num_attention_heads=num_attention_heads,
261
+ num_layers=num_layers_decoder,
262
+ )
263
+
264
+ self.quant = nn.Linear(width_encoder, latent_channels * 2, bias=True)
265
+ self.post_quant = nn.Linear(latent_channels, width_decoder, bias=True)
266
+
267
+ self.use_slicing = False
268
+ self.slicing_length = 1
269
+
270
+ def set_flash_decoder(self):
271
+ self.decoder.set_flash_processor(FlashTripoSGAttnProcessor2_0())
272
+
273
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedTripoSGAttnProcessor2_0
274
+ def fuse_qkv_projections(self):
275
+ """
276
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
277
+ are fused. For cross-attention modules, key and value projection matrices are fused.
278
+
279
+ <Tip warning={true}>
280
+
281
+ This API is 🧪 experimental.
282
+
283
+ </Tip>
284
+ """
285
+ self.original_attn_processors = None
286
+
287
+ for _, attn_processor in self.attn_processors.items():
288
+ if "Added" in str(attn_processor.__class__.__name__):
289
+ raise ValueError(
290
+ "`fuse_qkv_projections()` is not supported for models having added KV projections."
291
+ )
292
+
293
+ self.original_attn_processors = self.attn_processors
294
+
295
+ for module in self.modules():
296
+ if isinstance(module, Attention):
297
+ module.fuse_projections(fuse=True)
298
+
299
+ self.set_attn_processor(FusedTripoSGAttnProcessor2_0())
300
+
301
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
302
+ def unfuse_qkv_projections(self):
303
+ """Disables the fused QKV projection if enabled.
304
+
305
+ <Tip warning={true}>
306
+
307
+ This API is 🧪 experimental.
308
+
309
+ </Tip>
310
+
311
+ """
312
+ if self.original_attn_processors is not None:
313
+ self.set_attn_processor(self.original_attn_processors)
314
+
315
+ @property
316
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
317
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
318
+ r"""
319
+ Returns:
320
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
321
+ indexed by its weight name.
322
+ """
323
+ # set recursively
324
+ processors = {}
325
+
326
+ def fn_recursive_add_processors(
327
+ name: str,
328
+ module: torch.nn.Module,
329
+ processors: Dict[str, AttentionProcessor],
330
+ ):
331
+ if hasattr(module, "get_processor"):
332
+ processors[f"{name}.processor"] = module.get_processor()
333
+
334
+ for sub_name, child in module.named_children():
335
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
336
+
337
+ return processors
338
+
339
+ for name, module in self.named_children():
340
+ fn_recursive_add_processors(name, module, processors)
341
+
342
+ return processors
343
+
344
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
345
+ def set_attn_processor(
346
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
347
+ ):
348
+ r"""
349
+ Sets the attention processor to use to compute attention.
350
+
351
+ Parameters:
352
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
353
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
354
+ for **all** `Attention` layers.
355
+
356
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
357
+ processor. This is strongly recommended when setting trainable attention processors.
358
+
359
+ """
360
+ count = len(self.attn_processors.keys())
361
+
362
+ if isinstance(processor, dict) and len(processor) != count:
363
+ raise ValueError(
364
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
365
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
366
+ )
367
+
368
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
369
+ if hasattr(module, "set_processor"):
370
+ if not isinstance(processor, dict):
371
+ module.set_processor(processor)
372
+ else:
373
+ module.set_processor(processor.pop(f"{name}.processor"))
374
+
375
+ for sub_name, child in module.named_children():
376
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
377
+
378
+ for name, module in self.named_children():
379
+ fn_recursive_attn_processor(name, module, processor)
380
+
381
+ def set_default_attn_processor(self):
382
+ """
383
+ Disables custom attention processors and sets the default attention implementation.
384
+ """
385
+ self.set_attn_processor(TripoSGAttnProcessor2_0())
386
+
387
+ def enable_slicing(self, slicing_length: int = 1) -> None:
388
+ r"""
389
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
390
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
391
+ """
392
+ self.use_slicing = True
393
+ self.slicing_length = slicing_length
394
+
395
+ def disable_slicing(self) -> None:
396
+ r"""
397
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
398
+ decoding in one step.
399
+ """
400
+ self.use_slicing = False
401
+
402
+ def _sample_features(
403
+ self, x: torch.Tensor, num_tokens: int = 2048, seed: Optional[int] = None
404
+ ):
405
+ """
406
+ Sample points from features of the input point cloud.
407
+
408
+ Args:
409
+ x (torch.Tensor): The input point cloud. shape: (B, N, C)
410
+ num_tokens (int, optional): The number of points to sample. Defaults to 2048.
411
+ seed (Optional[int], optional): The random seed. Defaults to None.
412
+ """
413
+ rng = np.random.default_rng(seed)
414
+ indices = rng.choice(
415
+ x.shape[1], num_tokens * 4, replace=num_tokens * 4 > x.shape[1]
416
+ )
417
+ selected_points = x[:, indices]
418
+
419
+ batch_size, num_points, num_channels = selected_points.shape
420
+ flattened_points = selected_points.view(batch_size * num_points, num_channels)
421
+ batch_indices = (
422
+ torch.arange(batch_size).to(x.device).repeat_interleave(num_points)
423
+ )
424
+
425
+ # fps sampling
426
+ sampling_ratio = 1.0 / 4
427
+ sampled_indices = fps(
428
+ flattened_points[:, :3],
429
+ batch_indices,
430
+ ratio=sampling_ratio,
431
+ random_start=self.training,
432
+ )
433
+ sampled_points = flattened_points[sampled_indices].view(
434
+ batch_size, -1, num_channels
435
+ )
436
+
437
+ return sampled_points
438
+
439
+ def _encode(
440
+ self, x: torch.Tensor, num_tokens: int = 2048, seed: Optional[int] = None
441
+ ):
442
+ position_channels = self.config.in_channels
443
+ positions, features = x[..., :position_channels], x[..., position_channels:]
444
+ x_kv = torch.cat([self.embedder(positions), features], dim=-1)
445
+
446
+ sampled_x = self._sample_features(x, num_tokens, seed)
447
+ positions, features = (
448
+ sampled_x[..., :position_channels],
449
+ sampled_x[..., position_channels:],
450
+ )
451
+ x_q = torch.cat([self.embedder(positions), features], dim=-1)
452
+
453
+ x = self.encoder(x_q, x_kv)
454
+
455
+ x = self.quant(x)
456
+
457
+ return x
458
+
459
+ @apply_forward_hook
460
+ def encode(
461
+ self, x: torch.Tensor, return_dict: bool = True, **kwargs
462
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
463
+ """
464
+ Encode a batch of point features into latents.
465
+ """
466
+ if self.use_slicing and x.shape[0] > 1:
467
+ encoded_slices = [
468
+ self._encode(x_slice, **kwargs)
469
+ for x_slice in x.split(self.slicing_length)
470
+ ]
471
+ h = torch.cat(encoded_slices)
472
+ else:
473
+ h = self._encode(x, **kwargs)
474
+
475
+ posterior = DiagonalGaussianDistribution(h, feature_dim=-1)
476
+
477
+ if not return_dict:
478
+ return (posterior,)
479
+ return AutoencoderKLOutput(latent_dist=posterior)
480
+
481
+ def _decode(
482
+ self,
483
+ z: torch.Tensor,
484
+ sampled_points: torch.Tensor,
485
+ num_chunks: int = 50000,
486
+ to_cpu: bool = False,
487
+ return_dict: bool = True,
488
+ ) -> Union[DecoderOutput, torch.Tensor]:
489
+ xyz_samples = sampled_points
490
+
491
+ z = self.post_quant(z)
492
+
493
+ num_points = xyz_samples.shape[1]
494
+ kv_cache = None
495
+ dec = []
496
+
497
+ for i in range(0, num_points, num_chunks):
498
+ queries = xyz_samples[:, i : i + num_chunks, :].to(z.device, dtype=z.dtype)
499
+ queries = self.embedder(queries)
500
+
501
+ z_, kv_cache = self.decoder(z, queries, kv_cache)
502
+ dec.append(z_ if not to_cpu else z_.cpu())
503
+
504
+ z = torch.cat(dec, dim=1)
505
+
506
+ if not return_dict:
507
+ return (z,)
508
+
509
+ return DecoderOutput(sample=z)
510
+
511
+ @apply_forward_hook
512
+ def decode(
513
+ self,
514
+ z: torch.Tensor,
515
+ sampled_points: torch.Tensor,
516
+ return_dict: bool = True,
517
+ **kwargs,
518
+ ) -> Union[DecoderOutput, torch.Tensor]:
519
+ if self.use_slicing and z.shape[0] > 1:
520
+ decoded_slices = [
521
+ self._decode(z_slice, p_slice, **kwargs).sample
522
+ for z_slice, p_slice in zip(
523
+ z.split(self.slicing_length),
524
+ sampled_points.split(self.slicing_length),
525
+ )
526
+ ]
527
+ decoded = torch.cat(decoded_slices)
528
+ else:
529
+ decoded = self._decode(z, sampled_points, **kwargs).sample
530
+
531
+ if not return_dict:
532
+ return (decoded,)
533
+ return DecoderOutput(sample=decoded)
534
+
535
+ def forward(self, x: torch.Tensor):
536
+ pass
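For context, a minimal usage sketch of the sliced encode/decode path defined above (not part of this commit; the checkpoint id, subfolder layout and tensor shapes are assumptions):

```python
import torch
from triposg.models.autoencoders import TripoSGVAEModel

# Assumed checkpoint layout; point this at wherever the VAE weights actually live.
vae = TripoSGVAEModel.from_pretrained("VAST-AI/TripoSG", subfolder="vae").eval()
vae.enable_slicing(slicing_length=1)   # encode/decode the batch one sample at a time

# (B, N, position_channels + feature_channels) surface samples; shapes are illustrative.
surface = torch.randn(2, 20480, vae.config.in_channels + 3)

with torch.no_grad():
    posterior = vae.encode(surface, num_tokens=2048).latent_dist   # DiagonalGaussianDistribution
    latents = posterior.sample()

    # Query the decoder at arbitrary 3D points, chunked to bound peak memory.
    queries = torch.rand(2, 100_000, 3) * 2.0 - 1.0
    field = vae.decode(latents, sampled_points=queries, num_chunks=50_000).sample

vae.disable_slicing()
```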
triposg/models/autoencoders/vae.py ADDED
@@ -0,0 +1,69 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ from diffusers.utils.torch_utils import randn_tensor
6
+
7
+
8
+ class DiagonalGaussianDistribution(object):
9
+ def __init__(
10
+ self,
11
+ parameters: torch.Tensor,
12
+ deterministic: bool = False,
13
+ feature_dim: int = 1,
14
+ ):
15
+ self.parameters = parameters
16
+ self.feature_dim = feature_dim
17
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=feature_dim)
18
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
19
+ self.deterministic = deterministic
20
+ self.std = torch.exp(0.5 * self.logvar)
21
+ self.var = torch.exp(self.logvar)
22
+ if self.deterministic:
23
+ self.var = self.std = torch.zeros_like(
24
+ self.mean, device=self.parameters.device, dtype=self.parameters.dtype
25
+ )
26
+
27
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
28
+ # make sure sample is on the same device as the parameters and has same dtype
29
+ sample = randn_tensor(
30
+ self.mean.shape,
31
+ generator=generator,
32
+ device=self.parameters.device,
33
+ dtype=self.parameters.dtype,
34
+ )
35
+ x = self.mean + self.std * sample
36
+ return x
37
+
38
+ def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
39
+ if self.deterministic:
40
+ return torch.Tensor([0.0])
41
+ else:
42
+ if other is None:
43
+ return 0.5 * torch.sum(
44
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
45
+ dim=[1, 2, 3],
46
+ )
47
+ else:
48
+ return 0.5 * torch.sum(
49
+ torch.pow(self.mean - other.mean, 2) / other.var
50
+ + self.var / other.var
51
+ - 1.0
52
+ - self.logvar
53
+ + other.logvar,
54
+ dim=[1, 2, 3],
55
+ )
56
+
57
+ def nll(
58
+ self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]
59
+ ) -> torch.Tensor:
60
+ if self.deterministic:
61
+ return torch.Tensor([0.0])
62
+ logtwopi = np.log(2.0 * np.pi)
63
+ return 0.5 * torch.sum(
64
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
65
+ dim=dims,
66
+ )
67
+
68
+ def mode(self) -> torch.Tensor:
69
+ return self.mean
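A short sketch (not in the commit) of how `DiagonalGaussianDistribution` behaves with channel-last parameters, which is how `TripoSGVAEModel.encode` constructs it (`feature_dim=-1`). Note that `kl()` and `nll()` reduce over dims `[1, 2, 3]`, so they assume 4-D inputs:

```python
import torch
from triposg.models.autoencoders.vae import DiagonalGaussianDistribution

latent_channels = 64                                  # illustrative size
params = torch.randn(2, 2048, 2 * latent_channels)    # mean and logvar concatenated on the last dim
dist = DiagonalGaussianDistribution(params, feature_dim=-1)

z = dist.sample()      # reparameterized sample, shape (2, 2048, 64)
z_mode = dist.mode()   # deterministic alternative: just the mean
print(z.shape, z_mode.shape)
```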
triposg/models/embeddings.py ADDED
@@ -0,0 +1,96 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class FrequencyPositionalEmbedding(nn.Module):
6
+ """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
7
+ each feature dimension of `x[..., i]` into:
8
+ [
9
+ sin(x[..., i]),
10
+ sin(f_1*x[..., i]),
11
+ sin(f_2*x[..., i]),
12
+ ...
13
+ sin(f_N * x[..., i]),
14
+ cos(x[..., i]),
15
+ cos(f_1*x[..., i]),
16
+ cos(f_2*x[..., i]),
17
+ ...
18
+ cos(f_N * x[..., i]),
19
+ x[..., i] # only present if include_input is True.
20
+ ], here f_i is the frequency.
21
+
22
+     If logspace is True, the frequencies are f_i = 2^i for i = 0, ..., num_freqs - 1;
23
+     otherwise, the frequencies are linearly spaced between 1.0 and 2^(num_freqs - 1).
24
+     When include_pi is True, every frequency is additionally multiplied by pi.
25
+
26
+ Args:
27
+ num_freqs (int): the number of frequencies, default is 6;
28
+         logspace (bool): If logspace is True, the frequencies are [2^0, 2^1, ..., 2^(num_freqs - 1)],
29
+ otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
30
+ input_dim (int): the input dimension, default is 3;
31
+ include_input (bool): include the input tensor or not, default is True.
32
+
33
+ Attributes:
34
+         frequencies (torch.Tensor): If logspace is True, the frequencies are [2^0, 2^1, ..., 2^(num_freqs - 1)],
35
+             otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
36
+
37
+ out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
38
+ otherwise, it is input_dim * num_freqs * 2.
39
+
40
+ """
41
+
42
+ def __init__(
43
+ self,
44
+ num_freqs: int = 6,
45
+ logspace: bool = True,
46
+ input_dim: int = 3,
47
+ include_input: bool = True,
48
+ include_pi: bool = True,
49
+ ) -> None:
50
+ """The initialization"""
51
+
52
+ super().__init__()
53
+
54
+ if logspace:
55
+ frequencies = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)
56
+ else:
57
+ frequencies = torch.linspace(
58
+ 1.0, 2.0 ** (num_freqs - 1), num_freqs, dtype=torch.float32
59
+ )
60
+
61
+ if include_pi:
62
+ frequencies *= torch.pi
63
+
64
+ self.register_buffer("frequencies", frequencies, persistent=False)
65
+ self.include_input = include_input
66
+ self.num_freqs = num_freqs
67
+
68
+ self.out_dim = self.get_dims(input_dim)
69
+
70
+ def get_dims(self, input_dim):
71
+ temp = 1 if self.include_input or self.num_freqs == 0 else 0
72
+ out_dim = input_dim * (self.num_freqs * 2 + temp)
73
+
74
+ return out_dim
75
+
76
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
77
+ """Forward process.
78
+
79
+ Args:
80
+ x: tensor of shape [..., dim]
81
+
82
+ Returns:
83
+ embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
84
+ where temp is 1 if include_input is True and 0 otherwise.
85
+ """
86
+
87
+ if self.num_freqs > 0:
88
+ embed = (x[..., None].contiguous() * self.frequencies).view(
89
+ *x.shape[:-1], -1
90
+ )
91
+ if self.include_input:
92
+ return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
93
+ else:
94
+ return torch.cat((embed.sin(), embed.cos()), dim=-1)
95
+ else:
96
+ return x
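A sketch (not in the commit) of the channel count produced by `FrequencyPositionalEmbedding`; with `include_input=True` the output has `input_dim * (2 * num_freqs + 1)` channels:

```python
import torch
from triposg.models.embeddings import FrequencyPositionalEmbedding

embedder = FrequencyPositionalEmbedding(num_freqs=8, input_dim=3, include_input=True, include_pi=False)
points = torch.rand(1, 4096, 3) * 2.0 - 1.0   # illustrative (B, N, 3) coordinates

emb = embedder(points)
assert embedder.out_dim == 3 * (2 * 8 + 1) == 51
assert emb.shape == (1, 4096, 51)
```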
triposg/models/transformers/__init__.py ADDED
@@ -0,0 +1,61 @@
1
+ from typing import Callable, Optional
2
+
3
+ from .triposg_transformer import TripoSGDiTModel
4
+
5
+
6
+ def default_set_attn_proc_func(
7
+ name: str,
8
+ hidden_size: int,
9
+ cross_attention_dim: Optional[int],
10
+ ori_attn_proc: object,
11
+ ) -> object:
12
+ return ori_attn_proc
13
+
14
+
15
+ def set_transformer_attn_processor(
16
+ transformer: TripoSGDiTModel,
17
+ set_self_attn_proc_func: Callable = default_set_attn_proc_func,
18
+ set_cross_attn_1_proc_func: Callable = default_set_attn_proc_func,
19
+ set_cross_attn_2_proc_func: Callable = default_set_attn_proc_func,
20
+ set_self_attn_module_names: Optional[list[str]] = None,
21
+ set_cross_attn_1_module_names: Optional[list[str]] = None,
22
+ set_cross_attn_2_module_names: Optional[list[str]] = None,
23
+ ) -> None:
24
+ do_set_processor = lambda name, module_names: (
25
+ any([name.startswith(module_name) for module_name in module_names])
26
+ if module_names is not None
27
+ else True
28
+ ) # prefix match
29
+
30
+ attn_procs = {}
31
+ for name, attn_processor in transformer.attn_processors.items():
32
+ hidden_size = transformer.config.width
33
+ if name.endswith("attn1.processor"):
34
+ # self attention
35
+ attn_procs[name] = (
36
+ set_self_attn_proc_func(name, hidden_size, None, attn_processor)
37
+ if do_set_processor(name, set_self_attn_module_names)
38
+ else attn_processor
39
+ )
40
+ elif name.endswith("attn2.processor"):
41
+ # cross attention
42
+ cross_attention_dim = transformer.config.cross_attention_dim
43
+ attn_procs[name] = (
44
+ set_cross_attn_1_proc_func(
45
+ name, hidden_size, cross_attention_dim, attn_processor
46
+ )
47
+ if do_set_processor(name, set_cross_attn_1_module_names)
48
+ else attn_processor
49
+ )
50
+ elif name.endswith("attn2_2.processor"):
51
+ # cross attention 2
52
+ cross_attention_dim = transformer.config.cross_attention_2_dim
53
+ attn_procs[name] = (
54
+ set_cross_attn_2_proc_func(
55
+ name, hidden_size, cross_attention_dim, attn_processor
56
+ )
57
+ if do_set_processor(name, set_cross_attn_2_module_names)
58
+ else attn_processor
59
+ )
60
+
61
+ transformer.set_attn_processor(attn_procs)
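A sketch (not in the commit) of how the prefix-matching helper above could swap the first cross-attention's processor on a few late blocks only. The checkpoint id and block indices are assumptions; `TripoSGAttnProcessor2_0` is the default processor from `triposg.models.attention_processor`:

```python
from triposg.models.attention_processor import TripoSGAttnProcessor2_0
from triposg.models.transformers import TripoSGDiTModel, set_transformer_attn_processor

transformer = TripoSGDiTModel.from_pretrained("VAST-AI/TripoSG", subfolder="transformer")  # assumed layout

set_transformer_attn_processor(
    transformer,
    # Called for every "...attn2.processor" entry; here it simply re-instantiates the default processor.
    set_cross_attn_1_proc_func=lambda name, hidden_size, cross_dim, ori_proc: TripoSGAttnProcessor2_0(),
    set_cross_attn_1_module_names=["blocks.18", "blocks.19", "blocks.20"],  # prefix match on module names
)
```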
triposg/models/transformers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.86 kB). View file
 
triposg/models/transformers/__pycache__/modeling_outputs.cpython-310.pyc ADDED
Binary file (475 Bytes). View file
 
triposg/models/transformers/__pycache__/triposg_transformer.cpython-310.pyc ADDED
Binary file (19.6 kB). View file
 
triposg/models/transformers/modeling_outputs.py ADDED
@@ -0,0 +1,8 @@
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+
5
+
6
+ @dataclass
7
+ class Transformer1DModelOutput:
8
+ sample: torch.FloatTensor
triposg/models/transformers/triposg_transformer.py ADDED
@@ -0,0 +1,775 @@
1
+ # Copyright (c) 2025 VAST-AI-Research and contributors
2
+
3
+ # This code is based on Tencent HunyuanDiT (https://huggingface.co/Tencent-Hunyuan/HunyuanDiT),
4
+ # which is licensed under the Tencent Hunyuan Community License Agreement.
5
+ # Portions of this code are copied or adapted from HunyuanDiT.
6
+ # See the original license below:
7
+
8
+ # ---- Start of Tencent Hunyuan Community License Agreement ----
9
+
10
+ # TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
11
+ # Tencent Hunyuan DiT Release Date: 14 May 2024
12
+ # THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
13
+ # By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
14
+ # 1. DEFINITIONS.
15
+ # a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
16
+ # b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
17
+ # c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
18
+ # d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
19
+ # e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
20
+ # f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
21
+ # g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
22
+ # h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
23
+ # i. “Tencent,” “We” or “Us” shall mean THL A29 Limited.
24
+ # j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent Hunyuan DiT released at https://huggingface.co/Tencent-Hunyuan/HunyuanDiT.
25
+ # k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
26
+ # l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union.
27
+ # m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
28
+ # n. “including” shall mean including but not limited to.
29
+ # 2. GRANT OF RIGHTS.
30
+ # We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
31
+ # 3. DISTRIBUTION.
32
+ # You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions:
33
+ # a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
34
+ # b. You must cause any modified files to carry prominent notices stating that You changed the files;
35
+ # c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
36
+ # d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2024 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
37
+ # You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
38
+ # 4. ADDITIONAL COMMERCIAL TERMS.
39
+ # If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
40
+ # 5. RULES OF USE.
41
+ # a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
42
+ # b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other large language model (other than Tencent Hunyuan or Model Derivatives thereof).
43
+ # c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
44
+ # 6. INTELLECTUAL PROPERTY.
45
+ # a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
46
+ # b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
47
+ # c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
48
+ # d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
49
+ # 7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
50
+ # a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
51
+ # b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
52
+ # c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
53
+ # 8. SURVIVAL AND TERMINATION.
54
+ # a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
55
+ # b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
56
+ # 9. GOVERNING LAW AND JURISDICTION.
57
+ # a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
58
+ # b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
59
+ #
60
+ # EXHIBIT A
61
+ # ACCEPTABLE USE POLICY
62
+
63
+ # Tencent reserves the right to update this Acceptable Use Policy from time to time.
64
+ # Last modified: [insert date]
65
+
66
+ # Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
67
+ # 1. Outside the Territory;
68
+ # 2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
69
+ # 3. To harm Yourself or others;
70
+ # 4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
71
+ # 5. To override or circumvent the safety guardrails and safeguards We have put in place;
72
+ # 6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
73
+ # 7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
74
+ # 8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
75
+ # 9. To intentionally defame, disparage or otherwise harass others;
76
+ # 10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
77
+ # 11. To generate or disseminate personal identifiable information with the purpose of harming others;
78
+ # 12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
79
+ # 13. To impersonate another individual without consent, authorization, or legal right;
80
+ # 14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
81
+ # 15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
82
+ # 16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
83
+ # 17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
84
+ # 18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
85
+ # 19. For military purposes;
86
+ # 20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
87
+
88
+ # ---- End of Tencent Hunyuan Community License Agreement ----
89
+
90
+ # Please note that the use of this code is subject to the terms and conditions
91
+ # of the Tencent Hunyuan Community License Agreement, including the Acceptable Use Policy.
92
+
93
+ from typing import Any, Dict, Optional, Tuple, Union
94
+
95
+ import torch
96
+ import torch.utils.checkpoint
97
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
98
+ from diffusers.loaders import PeftAdapterMixin
99
+ from diffusers.models.attention import FeedForward
100
+ from diffusers.models.attention_processor import Attention, AttentionProcessor
101
+ from diffusers.models.embeddings import (
102
+ GaussianFourierProjection,
103
+ TimestepEmbedding,
104
+ Timesteps,
105
+ )
106
+ from diffusers.models.modeling_utils import ModelMixin
107
+ from diffusers.models.normalization import (
108
+ AdaLayerNormContinuous,
109
+ FP32LayerNorm,
110
+ LayerNorm,
111
+ )
112
+ from diffusers.utils import (
113
+ USE_PEFT_BACKEND,
114
+ is_torch_version,
115
+ logging,
116
+ scale_lora_layers,
117
+ unscale_lora_layers,
118
+ )
119
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
120
+ from torch import nn
121
+
122
+ from ..attention_processor import FusedTripoSGAttnProcessor2_0, TripoSGAttnProcessor2_0
123
+ from .modeling_outputs import Transformer1DModelOutput
124
+
125
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
126
+
127
+
128
+ @maybe_allow_in_graph
129
+ class DiTBlock(nn.Module):
130
+ r"""
131
+     Transformer block used in the Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allows a long skip
132
+     connection and QK normalization.
133
+
134
+ Parameters:
135
+ dim (`int`):
136
+ The number of channels in the input and output.
137
+ num_attention_heads (`int`):
138
+             The number of heads to use for multi-head attention.
139
+         cross_attention_dim (`int`, *optional*):
140
+             The size of the encoder_hidden_states vector for cross attention.
141
+         dropout (`float`, *optional*, defaults to 0.0):
142
+             The dropout probability to use.
143
+         activation_fn (`str`, *optional*, defaults to `"gelu"`):
144
+             Activation function to be used in the feed-forward block.
145
+         norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
146
+             Whether to use learnable elementwise affine parameters for normalization.
147
+         norm_eps (`float`, *optional*, defaults to 1e-5):
148
+ A small constant added to the denominator in normalization layers to prevent division by zero.
149
+ final_dropout (`bool` *optional*, defaults to False):
150
+ Whether to apply a final dropout after the last feed-forward layer.
151
+ ff_inner_dim (`int`, *optional*):
152
+ The size of the hidden layer in the feed-forward block. Defaults to `None`.
153
+ ff_bias (`bool`, *optional*, defaults to `True`):
154
+ Whether to use bias in the feed-forward block.
155
+ skip (`bool`, *optional*, defaults to `False`):
156
+ Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks.
157
+ qk_norm (`bool`, *optional*, defaults to `True`):
158
+ Whether to use normalization in QK calculation. Defaults to `True`.
159
+ """
160
+
161
+ def __init__(
162
+ self,
163
+ dim: int,
164
+ num_attention_heads: int,
165
+ use_self_attention: bool = True,
166
+ self_attention_norm_type: Optional[str] = None,
167
+ use_cross_attention: bool = True, # ada layer norm
168
+ cross_attention_dim: Optional[int] = None,
169
+ cross_attention_norm_type: Optional[str] = "fp32_layer_norm",
170
+ use_cross_attention_2: bool = False,
171
+ cross_attention_2_dim: Optional[int] = None,
172
+ cross_attention_2_norm_type: Optional[str] = None,
173
+ dropout=0.0,
174
+ activation_fn: str = "gelu",
175
+ norm_type: str = "fp32_layer_norm", # TODO
176
+ norm_elementwise_affine: bool = True,
177
+ norm_eps: float = 1e-5,
178
+ final_dropout: bool = False,
179
+ ff_inner_dim: Optional[int] = None, # int(dim * 4) if None
180
+ ff_bias: bool = True,
181
+ skip: bool = False,
182
+ skip_concat_front: bool = False, # [x, skip] or [skip, x]
183
+ skip_norm_last: bool = False, # this is an error
184
+ qk_norm: bool = True,
185
+ qkv_bias: bool = True,
186
+ ):
187
+ super().__init__()
188
+
189
+ self.use_self_attention = use_self_attention
190
+ self.use_cross_attention = use_cross_attention
191
+ self.use_cross_attention_2 = use_cross_attention_2
192
+ self.skip_concat_front = skip_concat_front
193
+ self.skip_norm_last = skip_norm_last
194
+ # Define 3 blocks. Each block has its own normalization layer.
195
+ # NOTE: when new version comes, check norm2 and norm 3
196
+ # 1. Self-Attn
197
+ if use_self_attention:
198
+ if (
199
+ self_attention_norm_type == "fp32_layer_norm"
200
+ or self_attention_norm_type is None
201
+ ):
202
+ self.norm1 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
203
+ else:
204
+ raise NotImplementedError
205
+
206
+ self.attn1 = Attention(
207
+ query_dim=dim,
208
+ cross_attention_dim=None,
209
+ dim_head=dim // num_attention_heads,
210
+ heads=num_attention_heads,
211
+ qk_norm="rms_norm" if qk_norm else None,
212
+ eps=1e-6,
213
+ bias=qkv_bias,
214
+ processor=TripoSGAttnProcessor2_0(),
215
+ )
216
+
217
+ # 2. Cross-Attn
218
+ if use_cross_attention:
219
+ assert cross_attention_dim is not None
220
+
221
+ self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
222
+
223
+ self.attn2 = Attention(
224
+ query_dim=dim,
225
+ cross_attention_dim=cross_attention_dim,
226
+ dim_head=dim // num_attention_heads,
227
+ heads=num_attention_heads,
228
+ qk_norm="rms_norm" if qk_norm else None,
229
+ cross_attention_norm=cross_attention_norm_type,
230
+ eps=1e-6,
231
+ bias=qkv_bias,
232
+ processor=TripoSGAttnProcessor2_0(),
233
+ )
234
+
235
+ if use_cross_attention_2:
236
+ assert cross_attention_2_dim is not None
237
+
238
+ self.norm2_2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
239
+
240
+ self.attn2_2 = Attention(
241
+ query_dim=dim,
242
+ cross_attention_dim=cross_attention_2_dim,
243
+ dim_head=dim // num_attention_heads,
244
+ heads=num_attention_heads,
245
+ qk_norm="rms_norm" if qk_norm else None,
246
+ cross_attention_norm=cross_attention_2_norm_type,
247
+ eps=1e-6,
248
+ bias=qkv_bias,
249
+ processor=TripoSGAttnProcessor2_0(),
250
+ )
251
+
252
+ # 3. Feed-forward
253
+ self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
254
+
255
+ self.ff = FeedForward(
256
+ dim,
257
+ dropout=dropout, ### 0.0
258
+ activation_fn=activation_fn, ### approx GeLU
259
+ final_dropout=final_dropout, ### 0.0
260
+ inner_dim=ff_inner_dim, ### int(dim * mlp_ratio)
261
+ bias=ff_bias,
262
+ )
263
+
264
+ # 4. Skip Connection
265
+ if skip:
266
+ self.skip_norm = FP32LayerNorm(dim, norm_eps, elementwise_affine=True)
267
+ self.skip_linear = nn.Linear(2 * dim, dim)
268
+ else:
269
+ self.skip_linear = None
270
+
271
+ # let chunk size default to None
272
+ self._chunk_size = None
273
+ self._chunk_dim = 0
274
+
275
+ def set_topk(self, topk):
276
+ self.flash_processor.topk = topk
277
+
278
+ def set_flash_processor(self, flash_processor):
279
+ self.flash_processor = flash_processor
280
+ self.attn2.processor = self.flash_processor
281
+
282
+ # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
283
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
284
+ # Sets chunk feed-forward
285
+ self._chunk_size = chunk_size
286
+ self._chunk_dim = dim
287
+
288
+ def forward(
289
+ self,
290
+ hidden_states: torch.Tensor,
291
+ encoder_hidden_states: Optional[torch.Tensor] = None,
292
+ encoder_hidden_states_2: Optional[torch.Tensor] = None,
293
+ temb: Optional[torch.Tensor] = None,
294
+ image_rotary_emb: Optional[torch.Tensor] = None,
295
+ skip: Optional[torch.Tensor] = None,
296
+ attention_kwargs: Optional[Dict[str, Any]] = None,
297
+ ) -> torch.Tensor:
298
+ # Prepare attention kwargs
299
+ attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {}
300
+ cross_attention_scale = attention_kwargs.pop("cross_attention_scale", 1.0)
301
+ cross_attention_2_scale = attention_kwargs.pop("cross_attention_2_scale", 1.0)
302
+
303
+ # Notice that normalization is always applied before the real computation in the following blocks.
304
+ # 0. Long Skip Connection
305
+ if self.skip_linear is not None:
306
+ cat = torch.cat(
307
+ (
308
+ [skip, hidden_states]
309
+ if self.skip_concat_front
310
+ else [hidden_states, skip]
311
+ ),
312
+ dim=-1,
313
+ )
314
+ if self.skip_norm_last:
315
+ # don't do this
316
+ hidden_states = self.skip_linear(cat)
317
+ hidden_states = self.skip_norm(hidden_states)
318
+ else:
319
+ cat = self.skip_norm(cat)
320
+ hidden_states = self.skip_linear(cat)
321
+
322
+ # 1. Self-Attention
323
+ if self.use_self_attention:
324
+ norm_hidden_states = self.norm1(hidden_states)
325
+ attn_output = self.attn1(
326
+ norm_hidden_states,
327
+ image_rotary_emb=image_rotary_emb,
328
+ **attention_kwargs,
329
+ )
330
+ hidden_states = hidden_states + attn_output
331
+
332
+ # 2. Cross-Attention
333
+ if self.use_cross_attention:
334
+ if self.use_cross_attention_2:
335
+ hidden_states = (
336
+ hidden_states
337
+ + self.attn2(
338
+ self.norm2(hidden_states),
339
+ encoder_hidden_states=encoder_hidden_states,
340
+ image_rotary_emb=image_rotary_emb,
341
+ **attention_kwargs,
342
+ ) * cross_attention_scale
343
+ + self.attn2_2(
344
+ self.norm2_2(hidden_states),
345
+ encoder_hidden_states=encoder_hidden_states_2,
346
+ image_rotary_emb=image_rotary_emb,
347
+ **attention_kwargs,
348
+ ) * cross_attention_2_scale
349
+ )
350
+ else:
351
+ hidden_states = hidden_states + self.attn2(
352
+ self.norm2(hidden_states),
353
+ encoder_hidden_states=encoder_hidden_states,
354
+ image_rotary_emb=image_rotary_emb,
355
+ **attention_kwargs,
356
+ ) * cross_attention_scale
357
+
358
+ # FFN Layer ### TODO: switch norm2 and norm3 in the state dict
359
+ mlp_inputs = self.norm3(hidden_states)
360
+ hidden_states = hidden_states + self.ff(mlp_inputs)
361
+
362
+ return hidden_states
363
+
364
+
365
+ class TripoSGDiTModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
366
+ """
367
+ TripoSG: Diffusion model with a Transformer backbone.
368
+
369
+     Inherits from `ModelMixin` and `ConfigMixin` so it can be saved/loaded and sampled with diffusers pipelines.
+ 
+     Parameters:
+         num_attention_heads (`int`, *optional*, defaults to 16):
+             The number of heads to use for multi-head attention.
+         width (`int`, *optional*, defaults to 2048):
+             The inner (hidden) dimension of the transformer.
+         in_channels (`int`, *optional*, defaults to 64):
+             The number of channels in the input and output latent tokens.
+         num_layers (`int`, *optional*, defaults to 21):
+             The number of DiT blocks to use.
+         cross_attention_dim (`int`, *optional*, defaults to 1024):
+             The dimension of the conditioning embeddings (`encoder_hidden_states`) used by the first cross-attention.
+         use_cross_attention_2 (`bool`, *optional*, defaults to `False`):
+             Whether each block adds a second cross-attention over `encoder_hidden_states_2`.
+         cross_attention_2_dim (`int`, *optional*):
+             The dimension of the conditioning embeddings used by the second cross-attention, if enabled.
407
+ """
408
+
409
+ _supports_gradient_checkpointing = True
410
+
411
+ @register_to_config
412
+ def __init__(
413
+ self,
414
+ num_attention_heads: int = 16,
415
+ width: int = 2048,
416
+ in_channels: int = 64,
417
+ num_layers: int = 21,
418
+ cross_attention_dim: int = 1024,
419
+ use_cross_attention_2: bool = False,
420
+ cross_attention_2_dim: Optional[int] = None
421
+ ):
422
+ super().__init__()
423
+ self.out_channels = in_channels
424
+ self.num_heads = num_attention_heads
425
+ self.inner_dim = width
426
+ self.mlp_ratio = 4.0
427
+
428
+ time_embed_dim, timestep_input_dim = self._set_time_proj(
429
+ "positional",
430
+ inner_dim=self.inner_dim,
431
+ flip_sin_to_cos=False,
432
+ freq_shift=0,
433
+ time_embedding_dim=None,
434
+ )
435
+ self.time_proj = TimestepEmbedding(
436
+ timestep_input_dim, time_embed_dim, act_fn="gelu", out_dim=self.inner_dim
437
+ )
438
+ self.proj_in = nn.Linear(self.config.in_channels, self.inner_dim, bias=True)
439
+
440
+ self.blocks = nn.ModuleList(
441
+ [
442
+ DiTBlock(
443
+ dim=self.inner_dim,
444
+ num_attention_heads=self.config.num_attention_heads,
445
+ use_self_attention=True,
446
+ self_attention_norm_type="fp32_layer_norm",
447
+ use_cross_attention=True,
448
+ cross_attention_dim=cross_attention_dim,
449
+ cross_attention_norm_type=None,
450
+ use_cross_attention_2=use_cross_attention_2,
451
+ cross_attention_2_dim=cross_attention_2_dim,
452
+ cross_attention_2_norm_type=None,
453
+ activation_fn="gelu",
454
+ norm_type="fp32_layer_norm", # TODO
455
+ norm_eps=1e-5,
456
+ ff_inner_dim=int(self.inner_dim * self.mlp_ratio),
457
+ skip=layer > num_layers // 2,
458
+ skip_concat_front=True,
459
+ skip_norm_last=True, # this is an error
460
+ qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
461
+ qkv_bias=False,
462
+ )
463
+ for layer in range(num_layers)
464
+ ]
465
+ )
466
+
467
+ self.norm_out = LayerNorm(self.inner_dim)
468
+ self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=True)
469
+
470
+ self.gradient_checkpointing = False
471
+
472
+ def _set_gradient_checkpointing(self, module, value=False):
473
+ self.gradient_checkpointing = value
474
+
475
+ def _set_time_proj(
476
+ self,
477
+ time_embedding_type: str,
478
+ inner_dim: int,
479
+ flip_sin_to_cos: bool,
480
+ freq_shift: float,
481
+ time_embedding_dim: int,
482
+ ) -> Tuple[int, int]:
483
+ if time_embedding_type == "fourier":
484
+ time_embed_dim = time_embedding_dim or inner_dim * 2
485
+ if time_embed_dim % 2 != 0:
486
+ raise ValueError(
487
+ f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
488
+ )
489
+ self.time_embed = GaussianFourierProjection(
490
+ time_embed_dim // 2,
491
+ set_W_to_weight=False,
492
+ log=False,
493
+ flip_sin_to_cos=flip_sin_to_cos,
494
+ )
495
+ timestep_input_dim = time_embed_dim
496
+ elif time_embedding_type == "positional":
497
+ time_embed_dim = time_embedding_dim or inner_dim * 4
498
+
499
+ self.time_embed = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
500
+ timestep_input_dim = inner_dim
501
+ else:
502
+ raise ValueError(
503
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
504
+ )
505
+
506
+ return time_embed_dim, timestep_input_dim
507
+
508
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedTripoSGAttnProcessor2_0
509
+ def fuse_qkv_projections(self):
510
+ """
511
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
512
+ are fused. For cross-attention modules, key and value projection matrices are fused.
513
+
514
+ <Tip warning={true}>
515
+
516
+ This API is 🧪 experimental.
517
+
518
+ </Tip>
519
+ """
520
+ self.original_attn_processors = None
521
+
522
+ for _, attn_processor in self.attn_processors.items():
523
+ if "Added" in str(attn_processor.__class__.__name__):
524
+ raise ValueError(
525
+ "`fuse_qkv_projections()` is not supported for models having added KV projections."
526
+ )
527
+
528
+ self.original_attn_processors = self.attn_processors
529
+
530
+ for module in self.modules():
531
+ if isinstance(module, Attention):
532
+ module.fuse_projections(fuse=True)
533
+
534
+ self.set_attn_processor(FusedTripoSGAttnProcessor2_0())
535
+
536
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
537
+ def unfuse_qkv_projections(self):
538
+ """Disables the fused QKV projection if enabled.
539
+
540
+ <Tip warning={true}>
541
+
542
+ This API is 🧪 experimental.
543
+
544
+ </Tip>
545
+
546
+ """
547
+ if self.original_attn_processors is not None:
548
+ self.set_attn_processor(self.original_attn_processors)
549
+
550
+ @property
551
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
552
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
553
+ r"""
554
+ Returns:
555
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
556
+ indexed by its weight name.
557
+ """
558
+ # set recursively
559
+ processors = {}
560
+
561
+ def fn_recursive_add_processors(
562
+ name: str,
563
+ module: torch.nn.Module,
564
+ processors: Dict[str, AttentionProcessor],
565
+ ):
566
+ if hasattr(module, "get_processor"):
567
+ processors[f"{name}.processor"] = module.get_processor()
568
+
569
+ for sub_name, child in module.named_children():
570
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
571
+
572
+ return processors
573
+
574
+ for name, module in self.named_children():
575
+ fn_recursive_add_processors(name, module, processors)
576
+
577
+ return processors
578
+
579
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
580
+ def set_attn_processor(
581
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
582
+ ):
583
+ r"""
584
+ Sets the attention processor to use to compute attention.
585
+
586
+ Parameters:
587
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
588
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
589
+ for **all** `Attention` layers.
590
+
591
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
592
+ processor. This is strongly recommended when setting trainable attention processors.
593
+
594
+ """
595
+ count = len(self.attn_processors.keys())
596
+
597
+ if isinstance(processor, dict) and len(processor) != count:
598
+ raise ValueError(
599
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
600
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
601
+ )
602
+
603
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
604
+ if hasattr(module, "set_processor"):
605
+ if not isinstance(processor, dict):
606
+ module.set_processor(processor)
607
+ else:
608
+ module.set_processor(processor.pop(f"{name}.processor"))
609
+
610
+ for sub_name, child in module.named_children():
611
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
612
+
613
+ for name, module in self.named_children():
614
+ fn_recursive_attn_processor(name, module, processor)
615
+
616
+ def set_default_attn_processor(self):
617
+ """
618
+ Disables custom attention processors and sets the default attention implementation.
619
+ """
620
+ self.set_attn_processor(TripoSGAttnProcessor2_0())
621
+
622
+ def forward(
623
+ self,
624
+ hidden_states: Optional[torch.Tensor],
625
+ timestep: Union[int, float, torch.LongTensor],
626
+ encoder_hidden_states: Optional[torch.Tensor] = None,
627
+ encoder_hidden_states_2: Optional[torch.Tensor] = None,
628
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
629
+ attention_kwargs: Optional[Dict[str, Any]] = None,
630
+ return_dict: bool = True,
631
+ ):
632
+ """
633
+         The [`TripoSGDiTModel`] forward method.
634
+
635
+ Args:
636
+             hidden_states (`torch.Tensor` of shape `(batch size, num tokens, in_channels)`):
637
+ The input tensor.
638
+ timestep ( `torch.LongTensor`, *optional*):
639
+ Used to indicate denoising step.
640
+ encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
641
+ Conditional embeddings for cross attention layer.
642
+ return_dict: bool
643
+ Whether to return a dictionary.
644
+ """
645
+
646
+ if attention_kwargs is not None:
647
+ attention_kwargs = attention_kwargs.copy()
648
+ lora_scale = attention_kwargs.pop("scale", 1.0)
649
+ else:
650
+ lora_scale = 1.0
651
+
652
+ if USE_PEFT_BACKEND:
653
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
654
+ scale_lora_layers(self, lora_scale)
655
+ else:
656
+ if (
657
+ attention_kwargs is not None
658
+ and attention_kwargs.get("scale", None) is not None
659
+ ):
660
+ logger.warning(
661
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
662
+ )
663
+
664
+ _, N, _ = hidden_states.shape
665
+
666
+ temb = self.time_embed(timestep).to(hidden_states.dtype)
667
+ temb = self.time_proj(temb)
668
+ temb = temb.unsqueeze(dim=1) # unsqueeze to concat with hidden_states
669
+
670
+ hidden_states = self.proj_in(hidden_states)
671
+
672
+ # N + 1 token
673
+ hidden_states = torch.cat([temb, hidden_states], dim=1)
674
+
675
+ skips = []
676
+ for layer, block in enumerate(self.blocks):
677
+ skip = None if layer <= self.config.num_layers // 2 else skips.pop()
678
+
679
+ if self.training and self.gradient_checkpointing:
680
+
681
+ def create_custom_forward(module):
682
+ def custom_forward(*inputs):
683
+ return module(*inputs)
684
+
685
+ return custom_forward
686
+
687
+ ckpt_kwargs: Dict[str, Any] = (
688
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
689
+ )
690
+ hidden_states = torch.utils.checkpoint.checkpoint(
691
+ create_custom_forward(block),
692
+ hidden_states,
693
+ encoder_hidden_states,
694
+ encoder_hidden_states_2,
695
+ temb,
696
+ image_rotary_emb,
697
+ skip,
698
+ attention_kwargs,
699
+ **ckpt_kwargs,
700
+ )
701
+ else:
702
+ hidden_states = block(
703
+ hidden_states,
704
+ encoder_hidden_states=encoder_hidden_states,
705
+ encoder_hidden_states_2=encoder_hidden_states_2,
706
+ temb=temb,
707
+ image_rotary_emb=image_rotary_emb,
708
+ skip=skip,
709
+ attention_kwargs=attention_kwargs,
710
+ ) # (N, L, D)
711
+
712
+ if layer < self.config.num_layers // 2:
713
+ skips.append(hidden_states)
714
+
715
+ # final layer
716
+ hidden_states = self.norm_out(hidden_states)
717
+ hidden_states = hidden_states[:, -N:]
718
+ hidden_states = self.proj_out(hidden_states)
719
+
720
+ if USE_PEFT_BACKEND:
721
+ # remove `lora_scale` from each PEFT layer
722
+ unscale_lora_layers(self, lora_scale)
723
+
724
+ if not return_dict:
725
+ return (hidden_states,)
726
+
727
+ return Transformer1DModelOutput(sample=hidden_states)
728
+
729
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
730
+ def enable_forward_chunking(
731
+ self, chunk_size: Optional[int] = None, dim: int = 0
732
+ ) -> None:
733
+ """
734
+ Sets the attention processor to use [feed forward
735
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
736
+
737
+ Parameters:
738
+ chunk_size (`int`, *optional*):
739
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
740
+ over each tensor of dim=`dim`.
741
+ dim (`int`, *optional*, defaults to `0`):
742
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
743
+ or dim=1 (sequence length).
744
+ """
745
+ if dim not in [0, 1]:
746
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
747
+
748
+ # By default chunk size is 1
749
+ chunk_size = chunk_size or 1
750
+
751
+ def fn_recursive_feed_forward(
752
+ module: torch.nn.Module, chunk_size: int, dim: int
753
+ ):
754
+ if hasattr(module, "set_chunk_feed_forward"):
755
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
756
+
757
+ for child in module.children():
758
+ fn_recursive_feed_forward(child, chunk_size, dim)
759
+
760
+ for module in self.children():
761
+ fn_recursive_feed_forward(module, chunk_size, dim)
762
+
763
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
764
+ def disable_forward_chunking(self):
765
+ def fn_recursive_feed_forward(
766
+ module: torch.nn.Module, chunk_size: int, dim: int
767
+ ):
768
+ if hasattr(module, "set_chunk_feed_forward"):
769
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
770
+
771
+ for child in module.children():
772
+ fn_recursive_feed_forward(child, chunk_size, dim)
773
+
774
+ for module in self.children():
775
+ fn_recursive_feed_forward(module, None, 0)
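A sketch (not in the commit) of a dummy forward pass through `TripoSGDiTModel`, using a deliberately tiny configuration; all sizes are illustrative, and the conditioning tokens merely stand in for the DINOv2 image features used by the pipeline:

```python
import torch
from triposg.models.transformers import TripoSGDiTModel

model = TripoSGDiTModel(
    num_attention_heads=4, width=256, in_channels=8, num_layers=5, cross_attention_dim=32
)

latents = torch.randn(2, 64, 8)    # (B, N, in_channels) latent tokens
timestep = torch.tensor([500, 500])
cond = torch.randn(2, 16, 32)      # stand-in for image-conditioning tokens

with torch.no_grad():
    out = model(latents, timestep, encoder_hidden_states=cond).sample

assert out.shape == latents.shape  # the prepended time token is dropped before proj_out
```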
triposg/pipelines/__pycache__/pipeline_triposg.cpython-310.pyc ADDED
Binary file (9.13 kB). View file
 
triposg/pipelines/__pycache__/pipeline_triposg_output.cpython-310.pyc ADDED
Binary file (734 Bytes). View file
 
triposg/pipelines/__pycache__/pipeline_utils.cpython-310.pyc ADDED
Binary file (3.65 kB). View file
 
triposg/pipelines/pipeline_triposg.py ADDED
@@ -0,0 +1,324 @@
1
+ import inspect
2
+ import math
3
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+ import PIL
7
+ import PIL.Image
8
+ import torch
9
+ import trimesh
10
+ from diffusers.image_processor import PipelineImageInput
11
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
12
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
13
+ from diffusers.utils import logging
14
+ from diffusers.utils.torch_utils import randn_tensor
15
+ from transformers import (
16
+ BitImageProcessor,
17
+ Dinov2Model,
18
+ )
19
+ from ..inference_utils import hierarchical_extract_geometry, flash_extract_geometry
20
+
21
+ from ..models.autoencoders import TripoSGVAEModel
22
+ from ..models.transformers import TripoSGDiTModel
23
+ from .pipeline_triposg_output import TripoSGPipelineOutput
24
+ from .pipeline_utils import TransformerDiffusionMixin
25
+
26
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
27
+
28
+
29
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
30
+ def retrieve_timesteps(
31
+ scheduler,
32
+ num_inference_steps: Optional[int] = None,
33
+ device: Optional[Union[str, torch.device]] = None,
34
+ timesteps: Optional[List[int]] = None,
35
+ sigmas: Optional[List[float]] = None,
36
+ **kwargs,
37
+ ):
38
+ """
39
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
40
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
41
+
42
+ Args:
43
+ scheduler (`SchedulerMixin`):
44
+ The scheduler to get timesteps from.
45
+ num_inference_steps (`int`):
46
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
47
+ must be `None`.
48
+ device (`str` or `torch.device`, *optional*):
49
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
50
+ timesteps (`List[int]`, *optional*):
51
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
52
+ `num_inference_steps` and `sigmas` must be `None`.
53
+ sigmas (`List[float]`, *optional*):
54
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
55
+ `num_inference_steps` and `timesteps` must be `None`.
56
+
57
+ Returns:
58
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
59
+ second element is the number of inference steps.
60
+ """
61
+ if timesteps is not None and sigmas is not None:
62
+ raise ValueError(
63
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
64
+ )
65
+ if timesteps is not None:
66
+ accepts_timesteps = "timesteps" in set(
67
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
68
+ )
69
+ if not accepts_timesteps:
70
+ raise ValueError(
71
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
72
+ f" timestep schedules. Please check whether you are using the correct scheduler."
73
+ )
74
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
75
+ timesteps = scheduler.timesteps
76
+ num_inference_steps = len(timesteps)
77
+ elif sigmas is not None:
78
+ accept_sigmas = "sigmas" in set(
79
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
80
+ )
81
+ if not accept_sigmas:
82
+ raise ValueError(
83
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
84
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
85
+ )
86
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
87
+ timesteps = scheduler.timesteps
88
+ num_inference_steps = len(timesteps)
89
+ else:
90
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
91
+ timesteps = scheduler.timesteps
92
+ return timesteps, num_inference_steps
93
+
94
+
95
+ class TripoSGPipeline(DiffusionPipeline, TransformerDiffusionMixin):
96
+ """
97
+ Pipeline for image-to-3D generation.
98
+ """
99
+
100
+ def __init__(
101
+ self,
102
+ vae: TripoSGVAEModel,
103
+ transformer: TripoSGDiTModel,
104
+ scheduler: FlowMatchEulerDiscreteScheduler,
105
+ image_encoder_dinov2: Dinov2Model,
106
+ feature_extractor_dinov2: BitImageProcessor,
107
+ ):
108
+ super().__init__()
109
+
110
+ self.register_modules(
111
+ vae=vae,
112
+ transformer=transformer,
113
+ scheduler=scheduler,
114
+ image_encoder_dinov2=image_encoder_dinov2,
115
+ feature_extractor_dinov2=feature_extractor_dinov2,
116
+ )
117
+
118
+ @property
119
+ def guidance_scale(self):
120
+ return self._guidance_scale
121
+
122
+ @property
123
+ def do_classifier_free_guidance(self):
124
+ return self._guidance_scale > 1
125
+
126
+ @property
127
+ def num_timesteps(self):
128
+ return self._num_timesteps
129
+
130
+ @property
131
+ def attention_kwargs(self):
132
+ return self._attention_kwargs
133
+
134
+ @property
135
+ def decode_progressive(self):
136
+ return self._decode_progressive
137
+
138
+ def encode_image(self, image, device, num_images_per_prompt):
139
+ dtype = next(self.image_encoder_dinov2.parameters()).dtype
140
+
141
+ if not isinstance(image, torch.Tensor):
142
+ image = self.feature_extractor_dinov2(image, return_tensors="pt").pixel_values
143
+
144
+ image = image.to(device=device, dtype=dtype)
145
+ image_embeds = self.image_encoder_dinov2(image).last_hidden_state
146
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
147
+ uncond_image_embeds = torch.zeros_like(image_embeds)
148
+
149
+ return image_embeds, uncond_image_embeds
150
+
151
+ def prepare_latents(
152
+ self,
153
+ batch_size,
154
+ num_tokens,
155
+ num_channels_latents,
156
+ dtype,
157
+ device,
158
+ generator,
159
+ latents: Optional[torch.Tensor] = None,
160
+ ):
161
+ if latents is not None:
162
+ return latents.to(device=device, dtype=dtype)
163
+
164
+ shape = (batch_size, num_tokens, num_channels_latents)
165
+
166
+ if isinstance(generator, list) and len(generator) != batch_size:
167
+ raise ValueError(
168
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
169
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
170
+ )
171
+
172
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
173
+
174
+ return latents
175
+
176
+
177
+ @torch.no_grad()
178
+ def __call__(
179
+ self,
180
+ image: PipelineImageInput,
181
+ num_inference_steps: int = 50,
182
+ num_tokens: int = 2048,
183
+ timesteps: List[int] = None,
184
+ guidance_scale: float = 7.0,
185
+ num_shapes_per_prompt: int = 1,
186
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
187
+ latents: Optional[torch.FloatTensor] = None,
188
+ attention_kwargs: Optional[Dict[str, Any]] = None,
189
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
190
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
191
+ bounds: Union[Tuple[float], List[float], float] = (-1.005, -1.005, -1.005, 1.005, 1.005, 1.005),
192
+ dense_octree_depth: int = 8,
193
+ hierarchical_octree_depth: int = 9,
194
+ flash_octree_depth: int = 9,
195
+ use_flash_decoder: bool = True,
196
+ return_dict: bool = True,
197
+ ):
198
+ # 1. Define call parameters
199
+ self._guidance_scale = guidance_scale
200
+ self._attention_kwargs = attention_kwargs
201
+
202
+ # 2. Define call parameters
203
+ if isinstance(image, PIL.Image.Image):
204
+ batch_size = 1
205
+ elif isinstance(image, list):
206
+ batch_size = len(image)
207
+ elif isinstance(image, torch.Tensor):
208
+ batch_size = image.shape[0]
209
+ else:
210
+ raise ValueError("Invalid input type for image")
211
+
212
+ device = self._execution_device
213
+
214
+ # 3. Encode condition
215
+ image_embeds, negative_image_embeds = self.encode_image(
216
+ image, device, num_shapes_per_prompt
217
+ )
218
+
219
+ if self.do_classifier_free_guidance:
220
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
221
+
222
+ # 4. Prepare timesteps
223
+ timesteps, num_inference_steps = retrieve_timesteps(
224
+ self.scheduler, num_inference_steps, device, timesteps
225
+ )
226
+ num_warmup_steps = max(
227
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
228
+ )
229
+ self._num_timesteps = len(timesteps)
230
+
231
+ # 5. Prepare latent variables
232
+ num_channels_latents = self.transformer.config.in_channels
233
+ latents = self.prepare_latents(
234
+ batch_size * num_shapes_per_prompt,
235
+ num_tokens,
236
+ num_channels_latents,
237
+ image_embeds.dtype,
238
+ device,
239
+ generator,
240
+ latents,
241
+ )
242
+
243
+ # 6. Denoising loop
244
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
245
+ for i, t in enumerate(timesteps):
246
+
247
+ # expand the latents if we are doing classifier free guidance
248
+ latent_model_input = (
249
+ torch.cat([latents] * 2)
250
+ if self.do_classifier_free_guidance
251
+ else latents
252
+ )
253
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
254
+ timestep = t.expand(latent_model_input.shape[0])
255
+
256
+ noise_pred = self.transformer(
257
+ latent_model_input,
258
+ timestep,
259
+ encoder_hidden_states=image_embeds,
260
+ attention_kwargs=attention_kwargs,
261
+ return_dict=False,
262
+ )[0]
263
+
264
+ # perform guidance
265
+ if self.do_classifier_free_guidance:
266
+ noise_pred_uncond, noise_pred_image = noise_pred.chunk(2)
267
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
268
+ noise_pred_image - noise_pred_uncond
269
+ )
270
+
271
+ # compute the previous noisy sample x_t -> x_t-1
272
+ latents_dtype = latents.dtype
273
+ latents = self.scheduler.step(
274
+ noise_pred, t, latents, return_dict=False
275
+ )[0]
276
+
277
+ if latents.dtype != latents_dtype:
278
+ if torch.backends.mps.is_available():
279
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
280
+ latents = latents.to(latents_dtype)
281
+
282
+ if callback_on_step_end is not None:
283
+ callback_kwargs = {}
284
+ for k in callback_on_step_end_tensor_inputs:
285
+ callback_kwargs[k] = locals()[k]
286
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
287
+
288
+ latents = callback_outputs.pop("latents", latents)
289
+
290
+ # call the callback, if provided
291
+ if i == len(timesteps) - 1 or (
292
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
293
+ ):
294
+ progress_bar.update()
295
+
296
+
297
+ # 7. decoder mesh
298
+ if not use_flash_decoder:
299
+ geometric_func = lambda x: self.vae.decode(latents, sampled_points=x).sample
300
+ output = hierarchical_extract_geometry(
301
+ geometric_func,
302
+ device,
303
+ bounds=bounds,
304
+ dense_octree_depth=dense_octree_depth,
305
+ hierarchical_octree_depth=hierarchical_octree_depth,
306
+ )
307
+ else:
308
+ self.vae.set_flash_decoder()
309
+ output = flash_extract_geometry(
310
+ latents,
311
+ self.vae,
312
+ bounds=bounds,
313
+ octree_depth=flash_octree_depth,
314
+ )
315
+ meshes = [trimesh.Trimesh(mesh_v_f[0].astype(np.float32), mesh_v_f[1]) for mesh_v_f in output]
316
+
317
+ # Offload all models
318
+ self.maybe_free_model_hooks()
319
+
320
+ if not return_dict:
321
+ return (output, meshes)
322
+
323
+ return TripoSGPipelineOutput(samples=output, meshes=meshes)
324
+
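For orientation, a minimal end-to-end sketch of how the image-to-3D pipeline above might be driven. The checkpoint path, input/output file names, and the CUDA device are placeholder assumptions, not verified identifiers:

```python
from PIL import Image
from triposg.pipelines.pipeline_triposg import TripoSGPipeline

# Placeholder checkpoint location; substitute the actual TripoSG weights.
pipe = TripoSGPipeline.from_pretrained("path/to/TripoSG").to("cuda")

out = pipe(
    image=Image.open("input.png").convert("RGB"),
    num_inference_steps=50,
    guidance_scale=7.0,
    num_tokens=2048,
    use_flash_decoder=False,  # falls back to hierarchical_extract_geometry
)
out.meshes[0].export("output.glb")  # the meshes are trimesh.Trimesh objects
```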
triposg/pipelines/pipeline_triposg_output.py ADDED
@@ -0,0 +1,17 @@
+ from dataclasses import dataclass
+ from typing import List, Optional, Union
+
+ import numpy as np
+ import torch
+ import trimesh
+ from diffusers.utils import BaseOutput
+
+
+ @dataclass
+ class TripoSGPipelineOutput(BaseOutput):
+     r"""
+     Output class for TripoSG pipelines.
+     """
+
+     samples: torch.Tensor
+     meshes: List[trimesh.Trimesh]
triposg/pipelines/pipeline_triposg_scribble.py ADDED
@@ -0,0 +1,338 @@
+ import inspect
+ import math
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+ import numpy as np
+ import PIL
+ import PIL.Image
+ import torch
+ import trimesh
+ from diffusers.image_processor import PipelineImageInput
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler  # not sure
+ from diffusers.utils import logging
+ from diffusers.utils.torch_utils import randn_tensor
+ from transformers import (
+     BitImageProcessor,
+     CLIPTokenizer,
+     CLIPTextModelWithProjection,
+     Dinov2Model,
+ )
+ from ..inference_utils import hierarchical_extract_geometry, flash_extract_geometry
+
+ from ..models.autoencoders import TripoSGVAEModel
+ from ..models.transformers import TripoSGDiTModel
+ from .pipeline_triposg_output import TripoSGPipelineOutput
+ from .pipeline_utils import TransformerDiffusionMixin
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+ def retrieve_timesteps(
+     scheduler,
+     num_inference_steps: Optional[int] = None,
+     device: Optional[Union[str, torch.device]] = None,
+     timesteps: Optional[List[int]] = None,
+     sigmas: Optional[List[float]] = None,
+     **kwargs,
+ ):
+     """
+     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+     Args:
+         scheduler (`SchedulerMixin`):
+             The scheduler to get timesteps from.
+         num_inference_steps (`int`):
+             The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+             must be `None`.
+         device (`str` or `torch.device`, *optional*):
+             The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+         timesteps (`List[int]`, *optional*):
+             Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+             `num_inference_steps` and `sigmas` must be `None`.
+         sigmas (`List[float]`, *optional*):
+             Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+             `num_inference_steps` and `timesteps` must be `None`.
+
+     Returns:
+         `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+         second element is the number of inference steps.
+     """
+     if timesteps is not None and sigmas is not None:
+         raise ValueError(
+             "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
+         )
+     if timesteps is not None:
+         accepts_timesteps = "timesteps" in set(
+             inspect.signature(scheduler.set_timesteps).parameters.keys()
+         )
+         if not accepts_timesteps:
+             raise ValueError(
+                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                 f" timestep schedules. Please check whether you are using the correct scheduler."
+             )
+         scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+         num_inference_steps = len(timesteps)
+     elif sigmas is not None:
+         accept_sigmas = "sigmas" in set(
+             inspect.signature(scheduler.set_timesteps).parameters.keys()
+         )
+         if not accept_sigmas:
+             raise ValueError(
+                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                 f" sigmas schedules. Please check whether you are using the correct scheduler."
+             )
+         scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+         num_inference_steps = len(timesteps)
+     else:
+         scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+     return timesteps, num_inference_steps
+
+
+ class TripoSGScribblePipeline(DiffusionPipeline, TransformerDiffusionMixin):
+     """
+     Pipeline for (scribble and text)-to-3D generation.
+     """
+
+     def __init__(
+         self,
+         vae: TripoSGVAEModel,
+         transformer: TripoSGDiTModel,
+         scheduler: FlowMatchEulerDiscreteScheduler,
+         tokenizer: CLIPTokenizer,
+         text_encoder: CLIPTextModelWithProjection,
+         image_encoder_dinov2: Dinov2Model,
+         feature_extractor_dinov2: BitImageProcessor,
+     ):
+         super().__init__()
+
+         self.register_modules(
+             vae=vae,
+             transformer=transformer,
+             scheduler=scheduler,
+             tokenizer=tokenizer,
+             text_encoder=text_encoder,
+             image_encoder_dinov2=image_encoder_dinov2,
+             feature_extractor_dinov2=feature_extractor_dinov2
+         )
+
+     @property
+     def guidance_scale(self):
+         return self._guidance_scale
+
+     @property
+     def do_classifier_free_guidance(self):
+         return self._guidance_scale > 1
+
+     @property
+     def num_timesteps(self):
+         return self._num_timesteps
+
+     @property
+     def attention_kwargs(self):
+         return self._attention_kwargs
+
+     def encode_text(self, prompt, device, num_shapes_per_prompt):
+         input_ids = self.tokenizer(
+             [prompt], max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+         )["input_ids"].to(device)
+         text_embeds = self.text_encoder(input_ids).last_hidden_state
+         text_embeds = text_embeds.repeat_interleave(num_shapes_per_prompt, dim=0)
+         uncond_text_embeds = torch.zeros_like(text_embeds)
+         return text_embeds, uncond_text_embeds
+
+     def encode_image(self, image, device, num_shapes_per_prompt):
+         dtype = next(self.image_encoder_dinov2.parameters()).dtype
+
+         if not isinstance(image, torch.Tensor):
+             image = self.feature_extractor_dinov2(image, return_tensors="pt").pixel_values
+
+         image = image.to(device=device, dtype=dtype)
+         image_embeds = self.image_encoder_dinov2(image).last_hidden_state
+         image_embeds = image_embeds.repeat_interleave(num_shapes_per_prompt, dim=0)
+         uncond_image_embeds = torch.zeros_like(image_embeds)
+
+         return image_embeds, uncond_image_embeds
+
+     def prepare_latents(
+         self,
+         batch_size,
+         num_tokens,
+         num_channels_latents,
+         dtype,
+         device,
+         generator,
+         latents: Optional[torch.Tensor] = None,
+     ):
+         if latents is not None:
+             return latents.to(device=device, dtype=dtype)
+
+         shape = (batch_size, num_tokens, num_channels_latents)
+
+         if isinstance(generator, list) and len(generator) != batch_size:
+             raise ValueError(
+                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+             )
+
+         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+         return latents
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         image: PipelineImageInput,
+         prompt: str,
+         num_tokens: int = 512,
+         num_inference_steps: int = 50,
+         timesteps: List[int] = None,
+         guidance_scale: float = 7.0,
+         num_shapes_per_prompt: int = 1,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.FloatTensor] = None,
+         attention_kwargs: Optional[Dict[str, Any]] = None,
+         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+         bounds: Union[Tuple[float], List[float], float] = (-1.005, -1.005, -1.005, 1.005, 1.005, 1.005),
+         dense_octree_depth: int = 8,
+         hierarchical_octree_depth: int = 9,
+         flash_octree_depth: int = 9,
+         use_flash_decoder: bool = True,
+         return_dict: bool = True,
+     ):
+         # 1. Define call parameters
+         self._guidance_scale = guidance_scale
+         self._attention_kwargs = attention_kwargs
+
+         # 2. Determine batch size from the input image(s)
+         if isinstance(image, PIL.Image.Image):
+             batch_size = 1
+         elif isinstance(image, list):
+             batch_size = len(image)
+         elif isinstance(image, torch.Tensor):
+             batch_size = image.shape[0]
+         else:
+             raise ValueError("Invalid input type for image")
+
+         device = self._execution_device
+
+         # 3. Encode condition
+         text_embeds, negative_text_embeds = self.encode_text(
+             prompt, device, num_shapes_per_prompt
+         )
+
+         image_embeds, negative_image_embeds = self.encode_image(
+             image, device, num_shapes_per_prompt
+         )
+
+         if self.do_classifier_free_guidance:
+             text_embeds = torch.cat([negative_text_embeds, text_embeds], dim=0)
+             image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+         # 4. Prepare timesteps
+         timesteps, num_inference_steps = retrieve_timesteps(
+             self.scheduler, num_inference_steps, device, timesteps
+         )
+         num_warmup_steps = max(
+             len(timesteps) - num_inference_steps * self.scheduler.order, 0
+         )
+         self._num_timesteps = len(timesteps)
+
+         # 5. Prepare latent variables
+         num_channels_latents = self.transformer.config.in_channels
+         latents = self.prepare_latents(
+             batch_size * num_shapes_per_prompt,
+             num_tokens,
+             num_channels_latents,
+             image_embeds.dtype,
+             device,
+             generator,
+             latents,
+         )
+
+         # 6. Denoising loop
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+
+                 # expand the latents if we are doing classifier free guidance
+                 latent_model_input = (
+                     torch.cat([latents] * 2)
+                     if self.do_classifier_free_guidance
+                     else latents
+                 )
+                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                 timestep = t.expand(latent_model_input.shape[0])
+
+                 noise_pred = self.transformer(
+                     latent_model_input,
+                     timestep.to(dtype=latents.dtype),
+                     encoder_hidden_states=text_embeds,
+                     encoder_hidden_states_2=image_embeds,
+                     attention_kwargs=attention_kwargs,
+                     return_dict=False,
+                 )[0]
+
+                 # perform guidance
+                 if self.do_classifier_free_guidance:
+                     noise_pred_uncond, noise_pred_image = noise_pred.chunk(2)
+                     noise_pred = noise_pred_uncond + self.guidance_scale * (
+                         noise_pred_image - noise_pred_uncond
+                     )
+
+                 # compute the previous noisy sample x_t -> x_t-1
+                 latents_dtype = latents.dtype
+                 latents = self.scheduler.step(
+                     noise_pred, t, latents, return_dict=False
+                 )[0]
+
+                 if latents.dtype != latents_dtype:
+                     if torch.backends.mps.is_available():
+                         # some platforms (e.g. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                         latents = latents.to(latents_dtype)
+
+                 if callback_on_step_end is not None:
+                     callback_kwargs = {}
+                     for k in callback_on_step_end_tensor_inputs:
+                         callback_kwargs[k] = locals()[k]
+                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                     latents = callback_outputs.pop("latents", latents)
+
+                 # update the progress bar
+                 if i == len(timesteps) - 1 or (
+                     (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+                 ):
+                     progress_bar.update()
+
+         # 7. Decode the mesh from the denoised latents
+         if not use_flash_decoder:
+             geometric_func = lambda x: self.vae.decode(latents, sampled_points=x).sample
+             output = hierarchical_extract_geometry(
+                 geometric_func,
+                 device,
+                 bounds=bounds,
+                 dense_octree_depth=dense_octree_depth,
+                 hierarchical_octree_depth=hierarchical_octree_depth,
+             )
+         else:
+             self.vae.set_flash_decoder()
+             output = flash_extract_geometry(
+                 latents,
+                 self.vae,
+                 bounds=bounds,
+                 octree_depth=flash_octree_depth,
+             )
+         meshes = [trimesh.Trimesh(mesh_v_f[0].astype(np.float32), mesh_v_f[1]) for mesh_v_f in output]
+
+         # Offload all models
+         self.maybe_free_model_hooks()
+
+         if not return_dict:
+             return (output, meshes)
+
+         return TripoSGPipelineOutput(samples=output, meshes=meshes)
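The scribble variant conditions the DiT on two streams: CLIP text features (`encoder_hidden_states`) and DINOv2 features of the scribble image (`encoder_hidden_states_2`), with a smaller default token count of 512. A hedged usage sketch; the checkpoint path, image file, and prompt are placeholders:

```python
from PIL import Image
from triposg.pipelines.pipeline_triposg_scribble import TripoSGScribblePipeline

pipe = TripoSGScribblePipeline.from_pretrained("path/to/TripoSG-scribble").to("cuda")

out = pipe(
    image=Image.open("scribble.png").convert("RGB"),  # scribble goes through DINOv2
    prompt="a cartoon house with a chimney",          # text goes through CLIP
    num_tokens=512,
    guidance_scale=7.0,
)
out.meshes[0].export("scribble_mesh.glb")
```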
triposg/pipelines/pipeline_utils.py ADDED
@@ -0,0 +1,96 @@
+ from diffusers.utils import logging
+
+ logger = logging.get_logger(__name__)
+
+
+ class TransformerDiffusionMixin:
+     r"""
+     Helper mixin for DiffusionPipelines built around a VAE and a transformer (mainly for DiT-style models).
+     """
+
+     def enable_vae_slicing(self):
+         r"""
+         Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+         compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+         """
+         self.vae.enable_slicing()
+
+     def disable_vae_slicing(self):
+         r"""
+         Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+         computing decoding in one step.
+         """
+         self.vae.disable_slicing()
+
+     def enable_vae_tiling(self):
+         r"""
+         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+         compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+         processing larger images.
+         """
+         self.vae.enable_tiling()
+
+     def disable_vae_tiling(self):
+         r"""
+         Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+         computing decoding in one step.
+         """
+         self.vae.disable_tiling()
+
+     def fuse_qkv_projections(self, transformer: bool = True, vae: bool = True):
+         """
+         Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+         are fused. For cross-attention modules, key and value projection matrices are fused.
+
+         <Tip warning={true}>
+
+         This API is 🧪 experimental.
+
+         </Tip>
+
+         Args:
+             transformer (`bool`, defaults to `True`): To apply fusion on the Transformer.
+             vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+         """
+         self.fusing_transformer = False
+         self.fusing_vae = False
+
+         if transformer:
+             self.fusing_transformer = True
+             self.transformer.fuse_qkv_projections()
+
+         if vae:
+             self.fusing_vae = True
+             self.vae.fuse_qkv_projections()
+
+     def unfuse_qkv_projections(self, transformer: bool = True, vae: bool = True):
+         """Disable QKV projection fusion if enabled.
+
+         <Tip warning={true}>
+
+         This API is 🧪 experimental.
+
+         </Tip>
+
+         Args:
+             transformer (`bool`, defaults to `True`): To apply fusion on the Transformer.
+             vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+
+         """
+         if transformer:
+             if not self.fusing_transformer:
+                 logger.warning(
+                     "The Transformer was not initially fused for QKV projections. Doing nothing."
+                 )
+             else:
+                 self.transformer.unfuse_qkv_projections()
+                 self.fusing_transformer = False
+
+         if vae:
+             if not self.fusing_vae:
+                 logger.warning(
+                     "The VAE was not initially fused for QKV projections. Doing nothing."
+                 )
+             else:
+                 self.vae.unfuse_qkv_projections()
+                 self.fusing_vae = False
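The mixin above only delegates to hooks on the registered `vae` and `transformer`; whether slicing, tiling, or QKV fusion actually takes effect depends on those modules implementing `enable_slicing`, `fuse_qkv_projections`, and friends. A sketch of how the toggles would be used on a pipeline instance, with a placeholder checkpoint path:

```python
from triposg.pipelines.pipeline_triposg import TripoSGPipeline

pipe = TripoSGPipeline.from_pretrained("path/to/TripoSG")  # placeholder path

pipe.fuse_qkv_projections(transformer=True, vae=False)  # fuse only the DiT's QKV projections
pipe.enable_vae_slicing()                               # decode latents in slices to save memory
# ... run pipe(...) here ...
pipe.disable_vae_slicing()
pipe.unfuse_qkv_projections(transformer=True, vae=False)
```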
triposg/schedulers/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .scheduling_rectified_flow import (
+     RectifiedFlowScheduler,
+     compute_density_for_timestep_sampling,
+     compute_loss_weighting,
+ )
triposg/schedulers/scheduling_rectified_flow.py ADDED
@@ -0,0 +1,327 @@
+ """
+ Adapted from https://github.com/huggingface/diffusers/blob/v0.30.3/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py.
+ """
+
+ import math
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
+ from diffusers.utils import BaseOutput, logging
+ from torch.distributions import LogisticNormal
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ # TODO: may move to training_utils.py
+ def compute_density_for_timestep_sampling(
+     weighting_scheme: str,
+     batch_size: int,
+     logit_mean: float = 0.0,
+     logit_std: float = 1.0,
+     mode_scale: float = None,
+ ):
+     if weighting_scheme == "logit_normal":
+         # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
+         u = torch.normal(
+             mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu"
+         )
+         u = torch.nn.functional.sigmoid(u)
+     elif weighting_scheme == "logit_normal_dist":
+         u = (
+             LogisticNormal(loc=logit_mean, scale=logit_std)
+             .sample((batch_size,))[:, 0]
+             .to("cpu")
+         )
+     elif weighting_scheme == "mode":
+         u = torch.rand(size=(batch_size,), device="cpu")
+         u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
+     else:
+         u = torch.rand(size=(batch_size,), device="cpu")
+     return u
+
+
+ def compute_loss_weighting(weighting_scheme: str, sigmas=None):
+     """
+     Computes loss weighting scheme for SD3 training.
+
+     Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
+
+     SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
+     """
+     if weighting_scheme == "sigma_sqrt":
+         weighting = (sigmas**-2.0).float()
+     elif weighting_scheme == "cosmap":
+         bot = 1 - 2 * sigmas + 2 * sigmas**2
+         weighting = 2 / (math.pi * bot)
+     else:
+         weighting = torch.ones_like(sigmas)
+     return weighting
+
+
+ @dataclass
+ class RectifiedFlowSchedulerOutput(BaseOutput):
+     """
+     Output class for the scheduler's `step` function output.
+
+     Args:
+         prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+             Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+             denoising loop.
+     """
+
+     prev_sample: torch.FloatTensor
+
+
+ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin):
+     """
+     A rectified-flow (flow matching) scheduler that propagates samples through the reverse diffusion process with
+     Euler updates.
+
+     This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+     methods the library implements for all schedulers such as loading and saving.
+
+     Args:
+         num_train_timesteps (`int`, defaults to 1000):
+             The number of diffusion steps to train the model.
+         shift (`float`, defaults to 1.0):
+             The shift value for the timestep schedule.
+         use_dynamic_shifting (`bool`, defaults to `False`):
+             Whether to apply the timestep shifting on the fly (via `mu` in `set_timesteps`) instead of the static
+             `shift` value.
+     """
+
+     _compatibles = []
+     order = 1
+
+     @register_to_config
+     def __init__(
+         self,
+         num_train_timesteps: int = 1000,
+         shift: float = 1.0,
+         use_dynamic_shifting: bool = False,
+     ):
+         # pre-compute timesteps and sigmas; these defaults are not actually used at inference time.
+         # NOTE that shape diffusion samples timesteps randomly or from a distribution,
+         # instead of sampling from the pre-defined linspace
+         timesteps = np.array(
+             [
+                 (1.0 - i / num_train_timesteps) * num_train_timesteps
+                 for i in range(num_train_timesteps)
+             ]
+         )
+         timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
+
+         sigmas = timesteps / num_train_timesteps
+         if not use_dynamic_shifting:
+             # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
+             sigmas = self.time_shift(sigmas)
+
+         self.timesteps = sigmas * num_train_timesteps
+
+         self._step_index = None
+         self._begin_index = None
+
+         self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
+
+     @property
+     def step_index(self):
+         """
+         The index counter for current timestep. It will increase 1 after each scheduler step.
+         """
+         return self._step_index
+
+     @property
+     def begin_index(self):
+         """
+         The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+         """
+         return self._begin_index
+
+     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+     def set_begin_index(self, begin_index: int = 0):
+         """
+         Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+         Args:
+             begin_index (`int`):
+                 The begin index for the scheduler.
+         """
+         self._begin_index = begin_index
+
+     def _sigma_to_t(self, sigma):
+         return sigma * self.config.num_train_timesteps
+
+     def _t_to_sigma(self, timestep):
+         return timestep / self.config.num_train_timesteps
+
+     def time_shift_dynamic(self, mu: float, sigma: float, t: torch.Tensor):
+         return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+     def time_shift(self, t: torch.Tensor):
+         return self.config.shift * t / (1 + (self.config.shift - 1) * t)
+
+     def set_timesteps(
+         self,
+         num_inference_steps: int = None,
+         device: Union[str, torch.device] = None,
+         sigmas: Optional[List[float]] = None,
+         mu: Optional[float] = None,
+     ):
+         """
+         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+         Args:
+             num_inference_steps (`int`):
+                 The number of diffusion steps used when generating samples with a pre-trained model.
+             device (`str` or `torch.device`, *optional*):
+                 The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+         """
+
+         if self.config.use_dynamic_shifting and mu is None:
+             raise ValueError(
+                 "you have to pass a value for `mu` when `use_dynamic_shifting` is set to `True`"
+             )
+
+         if sigmas is None:
+             self.num_inference_steps = num_inference_steps
+             timesteps = np.array(
+                 [
+                     (1.0 - i / num_inference_steps) * self.config.num_train_timesteps
+                     for i in range(num_inference_steps)
+                 ]
+             )  # different from the original code in SD3
+             sigmas = timesteps / self.config.num_train_timesteps
+
+         if self.config.use_dynamic_shifting:
+             sigmas = self.time_shift_dynamic(mu, 1.0, sigmas)
+         else:
+             sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
+
+         sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
+         timesteps = sigmas * self.config.num_train_timesteps
+
+         self.timesteps = timesteps.to(device=device)
+         self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+
+         self._step_index = None
+         self._begin_index = None
+
+     def index_for_timestep(self, timestep, schedule_timesteps=None):
+         if schedule_timesteps is None:
+             schedule_timesteps = self.timesteps
+
+         indices = (schedule_timesteps == timestep).nonzero()
+
+         # The sigma index that is taken for the **very** first `step`
+         # is always the second index (or the last index if there is only 1)
+         # This way we can ensure we don't accidentally skip a sigma in
+         # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+         pos = 1 if len(indices) > 1 else 0
+
+         return indices[pos].item()
+
+     def _init_step_index(self, timestep):
+         if self.begin_index is None:
+             if isinstance(timestep, torch.Tensor):
+                 timestep = timestep.to(self.timesteps.device)
+             self._step_index = self.index_for_timestep(timestep)
+         else:
+             self._step_index = self._begin_index
+
+     def step(
+         self,
+         model_output: torch.FloatTensor,
+         timestep: Union[float, torch.FloatTensor],
+         sample: torch.FloatTensor,
+         s_churn: float = 0.0,
+         s_tmin: float = 0.0,
+         s_tmax: float = float("inf"),
+         s_noise: float = 1.0,
+         generator: Optional[torch.Generator] = None,
+         return_dict: bool = True,
+     ) -> Union[RectifiedFlowSchedulerOutput, Tuple]:
+         """
+         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+         process from the learned model outputs (most often the predicted noise).
+
+         Args:
+             model_output (`torch.FloatTensor`):
+                 The direct output from learned diffusion model.
+             timestep (`float`):
+                 The current discrete timestep in the diffusion chain.
+             sample (`torch.FloatTensor`):
+                 A current instance of a sample created by the diffusion process.
+             s_churn (`float`):
+             s_tmin (`float`):
+             s_tmax (`float`):
+             s_noise (`float`, defaults to 1.0):
+                 Scaling factor for noise added to the sample.
+             generator (`torch.Generator`, *optional*):
+                 A random number generator.
+             return_dict (`bool`):
+                 Whether or not to return a [`RectifiedFlowSchedulerOutput`] or tuple.
+
+         Returns:
+             [`RectifiedFlowSchedulerOutput`] or `tuple`:
+                 If return_dict is `True`, [`RectifiedFlowSchedulerOutput`] is returned, otherwise a tuple is
+                 returned where the first element is the sample tensor.
+         """
+
+         if (
+             isinstance(timestep, int)
+             or isinstance(timestep, torch.IntTensor)
+             or isinstance(timestep, torch.LongTensor)
+         ):
+             raise ValueError(
+                 (
+                     "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+                     " `RectifiedFlowScheduler.step()` is not supported. Make sure to pass"
+                     " one of the `scheduler.timesteps` as a timestep."
+                 ),
+             )
+
+         if self.step_index is None:
+             self._init_step_index(timestep)
+
+         # Upcast to avoid precision issues when computing prev_sample
+         sample = sample.to(torch.float32)
+
+         sigma = self.sigmas[self.step_index]
+         sigma_next = self.sigmas[self.step_index + 1]
+
+         # Here different directions are used for the flow matching
+         prev_sample = sample + (sigma - sigma_next) * model_output
+
+         # Cast sample back to model compatible dtype
+         prev_sample = prev_sample.to(model_output.dtype)
+
+         # upon completion increase step index by one
+         self._step_index += 1
+
+         if not return_dict:
+             return (prev_sample,)
+
+         return RectifiedFlowSchedulerOutput(prev_sample=prev_sample)
+
+     def scale_noise(
+         self,
+         original_samples: torch.Tensor,
+         noise: torch.Tensor,
+         timesteps: torch.IntTensor,
+     ) -> torch.Tensor:
+         """
+         Forward function for the noise scaling in the flow matching.
+         """
+         sigmas = self._t_to_sigma(timesteps.to(dtype=torch.float32))
+
+         while len(sigmas.shape) < len(original_samples.shape):
+             sigmas = sigmas.unsqueeze(-1)
+
+         return (1.0 - sigmas) * original_samples + sigmas * noise
+
+     def __len__(self):
+         return self.config.num_train_timesteps
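The scheduler above implements the rectified-flow Euler update `x_prev = x + (sigma - sigma_next) * v` with `sigma` running from 1 down to 0, and `scale_noise` is the matching forward interpolation `(1 - sigma) * x0 + sigma * noise`. A small self-check sketch, assuming the repository root is importable as the `triposg` package:

```python
import torch
from triposg.schedulers import RectifiedFlowScheduler

sched = RectifiedFlowScheduler(num_train_timesteps=1000, shift=1.0)
sched.set_timesteps(num_inference_steps=4)
print(sched.sigmas)     # tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000]) for shift=1.0
print(sched.timesteps)  # the sigmas (without the trailing 0) * num_train_timesteps

x = torch.randn(1, 2048, 64)   # stand-in latent tokens
v = torch.randn_like(x)        # stand-in model output (predicted velocity)

prev = sched.step(v, sched.timesteps[0], x).prev_sample
# one Euler step of size sigma - sigma_next = 1.0 - 0.75
assert torch.allclose(prev, x + 0.25 * v)
```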
triposg/utils/__pycache__/typing.cpython-310.pyc ADDED
Binary file (2.04 kB). View file