Spaces: mashroo / Running on Zero

Zhengyi committed
Commit f4e8cf6 · 1 Parent(s): cdbec36
Note: this view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitignore +155 -0
  2. app.py +205 -0
  3. configs/nf7_v3_SNR_rd_size_stroke.yaml +21 -0
  4. configs/specs_objaverse_total.json +57 -0
  5. configs/stage2-v2-snr.yaml +25 -0
  6. imagedream/__init__.py +1 -0
  7. imagedream/camera_utils.py +99 -0
  8. imagedream/configs/sd_v2_base_ipmv.yaml +61 -0
  9. imagedream/configs/sd_v2_base_ipmv_ch8.yaml +61 -0
  10. imagedream/configs/sd_v2_base_ipmv_chin8.yaml +61 -0
  11. imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml +62 -0
  12. imagedream/configs/sd_v2_base_ipmv_local.yaml +62 -0
  13. imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml +62 -0
  14. imagedream/ldm/__init__.py +0 -0
  15. imagedream/ldm/interface.py +205 -0
  16. imagedream/ldm/models/__init__.py +0 -0
  17. imagedream/ldm/models/autoencoder.py +270 -0
  18. imagedream/ldm/models/diffusion/__init__.py +0 -0
  19. imagedream/ldm/models/diffusion/ddim.py +430 -0
  20. imagedream/ldm/modules/__init__.py +0 -0
  21. imagedream/ldm/modules/attention.py +456 -0
  22. imagedream/ldm/modules/diffusionmodules/__init__.py +0 -0
  23. imagedream/ldm/modules/diffusionmodules/adaptors.py +163 -0
  24. imagedream/ldm/modules/diffusionmodules/model.py +1018 -0
  25. imagedream/ldm/modules/diffusionmodules/openaimodel.py +1135 -0
  26. imagedream/ldm/modules/diffusionmodules/util.py +353 -0
  27. imagedream/ldm/modules/distributions/__init__.py +0 -0
  28. imagedream/ldm/modules/distributions/distributions.py +102 -0
  29. imagedream/ldm/modules/ema.py +86 -0
  30. imagedream/ldm/modules/encoders/__init__.py +0 -0
  31. imagedream/ldm/modules/encoders/modules.py +329 -0
  32. imagedream/ldm/util.py +226 -0
  33. imagedream/model_zoo.py +64 -0
  34. inference.py +91 -0
  35. libs/base_utils.py +84 -0
  36. libs/sample.py +380 -0
  37. mesh.py +845 -0
  38. model/__init__.py +1 -0
  39. model/archs/__init__.py +0 -0
  40. model/archs/decoders/__init__.py +1 -0
  41. model/archs/decoders/shape_texture_net.py +62 -0
  42. model/archs/mlp_head.py +40 -0
  43. model/archs/unet.py +53 -0
  44. model/crm/model.py +213 -0
  45. pipelines.py +170 -0
  46. requirements.txt +14 -0
  47. util/__init__.py +0 -0
  48. util/flexicubes.py +579 -0
  49. util/flexicubes_geometry.py +116 -0
  50. util/renderer.py +49 -0
.gitignore ADDED
@@ -0,0 +1,155 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+
app.py ADDED
@@ -0,0 +1,205 @@
1
+ # Not ready to use yet
2
+ import argparse
3
+ import numpy as np
4
+ import gradio as gr
5
+ from omegaconf import OmegaConf
6
+ import torch
7
+ from PIL import Image
8
+ import PIL
9
+ from pipelines import TwoStagePipeline
10
+ from huggingface_hub import hf_hub_download
11
+ import os
12
+ import rembg
13
+ from typing import Any
14
+ import json
15
+ import os
16
+ import json
17
+ import argparse
18
+
19
+ from model import CRM
20
+ from inference import generate3d
21
+
22
+ pipeline = None
23
+ rembg_session = rembg.new_session()
24
+
25
+
26
+ def check_input_image(input_image):
27
+ if input_image is None:
28
+ raise gr.Error("No image uploaded!")
29
+
30
+
31
+ def remove_background(
32
+ image: PIL.Image.Image,
33
+ rembg_session: Any = None,
34
+ force: bool = False,
35
+ **rembg_kwargs,
36
+ ) -> PIL.Image.Image:
37
+ do_remove = True
38
+ if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
39
+ # explain why we skip background removal in this case
40
+ print("alpha channel not empty, skipping background removal and using the alpha channel as the mask")
41
+ background = Image.new("RGBA", image.size, (0, 0, 0, 0))
42
+ image = Image.alpha_composite(background, image)
43
+ do_remove = False
44
+ do_remove = do_remove or force
45
+ if do_remove:
46
+ image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
47
+ return image
48
+
49
+ def do_resize_content(original_image: Image, scale_rate):
50
+ # resize the image content while retaining the original image size
51
+ if scale_rate != 1:
52
+ # Calculate the new size after rescaling
53
+ new_size = tuple(int(dim * scale_rate) for dim in original_image.size)
54
+ # Resize the image while maintaining the aspect ratio
55
+ resized_image = original_image.resize(new_size)
56
+ # Create a new image with the original size and black background
57
+ padded_image = Image.new("RGBA", original_image.size, (0, 0, 0, 0))
58
+ paste_position = ((original_image.width - resized_image.width) // 2, (original_image.height - resized_image.height) // 2)
59
+ padded_image.paste(resized_image, paste_position)
60
+ return padded_image
61
+ else:
62
+ return original_image
63
+
64
+ def add_background(image, bg_color=(255, 255, 255)):
65
+ # given an RGBA image, alpha channel is used as mask to add background color
66
+ background = Image.new("RGBA", image.size, bg_color)
67
+ return Image.alpha_composite(background, image)
68
+
69
+
70
+ def preprocess_image(input_image, do_remove_background, force_remove, foreground_ratio, backgroud_color):
71
+ """
72
+ the input image is a PIL image in RGBA; returns an RGB image
73
+ """
74
+ if do_remove_background:
75
+ image = remove_background(input_image, rembg_session, force_remove)
76
+ image = do_resize_content(image, foreground_ratio)
77
+ image = add_background(image, backgroud_color)
78
+ return image.convert("RGB")
79
+
80
+
81
+ def gen_image(input_image, seed, scale, step):
82
+ global pipeline, model, args
83
+ pipeline.set_seed(seed)
84
+ rt_dict = pipeline(input_image, scale=scale, step=step)
85
+ stage1_images = rt_dict["stage1_images"]
86
+ stage2_images = rt_dict["stage2_images"]
87
+ np_imgs = np.concatenate(stage1_images, 1)
88
+ np_xyzs = np.concatenate(stage2_images, 1)
89
+
90
+ glb_path, obj_path = generate3d(model, np_imgs, np_xyzs, args.device)
91
+ return Image.fromarray(np_imgs), Image.fromarray(np_xyzs), glb_path, obj_path
92
+
93
+
94
+ parser = argparse.ArgumentParser()
95
+ parser.add_argument(
96
+ "--stage1_config",
97
+ type=str,
98
+ default="configs/nf7_v3_SNR_rd_size_stroke.yaml",
99
+ help="config for stage1",
100
+ )
101
+ parser.add_argument(
102
+ "--stage2_config",
103
+ type=str,
104
+ default="configs/stage2-v2-snr.yaml",
105
+ help="config for stage2",
106
+ )
107
+
108
+ parser.add_argument("--device", type=str, default="cuda")
109
+ args = parser.parse_args()
110
+
111
+ crm_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="CRM.pth")
112
+ specs = json.load(open("configs/specs_objaverse_total.json"))
113
+ model = CRM(specs).to(args.device)
114
+ model.load_state_dict(torch.load(crm_path, map_location = args.device), strict=False)
115
+
116
+ stage1_config = OmegaConf.load(args.stage1_config).config
117
+ stage2_config = OmegaConf.load(args.stage2_config).config
118
+ stage2_sampler_config = stage2_config.sampler
119
+ stage1_sampler_config = stage1_config.sampler
120
+
121
+ stage1_model_config = stage1_config.models
122
+ stage2_model_config = stage2_config.models
123
+
124
+ xyz_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="ccm-diffusion.pth")
125
+ pixel_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="pixel-diffusion.pth")
126
+ stage1_model_config.resume = pixel_path
127
+ stage2_model_config.resume = xyz_path
128
+
129
+ pipeline = TwoStagePipeline(
130
+ stage1_model_config,
131
+ stage2_model_config,
132
+ stage1_sampler_config,
133
+ stage2_sampler_config,
134
+ device=args.device,
135
+ dtype=torch.float16
136
+ )
137
+
138
+ with gr.Blocks() as demo:
139
+ gr.Markdown("# CRM: Single Image to 3D Textured Mesh with Convolutional Reconstruction Model")
140
+ with gr.Row():
141
+ with gr.Column():
142
+ with gr.Row():
143
+ image_input = gr.Image(
144
+ label="Image input",
145
+ image_mode="RGBA",
146
+ sources="upload",
147
+ type="pil",
148
+ )
149
+ processed_image = gr.Image(label="Processed Image", interactive=False, type="pil", image_mode="RGB")
150
+ with gr.Row():
151
+ with gr.Column():
152
+ with gr.Row():
153
+ do_remove_background = gr.Checkbox(label="Remove Background", value=True)
154
+ force_remove = gr.Checkbox(label="Force Remove", value=False)
155
+ back_groud_color = gr.ColorPicker(label="Background Color", value="#7F7F7F", interactive=False)
156
+ foreground_ratio = gr.Slider(
157
+ label="Foreground Ratio",
158
+ minimum=0.5,
159
+ maximum=1.0,
160
+ value=1.0,
161
+ step=0.05,
162
+ )
163
+
164
+ with gr.Column():
165
+ seed = gr.Number(value=1234, label="seed", precision=0)
166
+ guidance_scale = gr.Number(value=5.5, minimum=0, maximum=20, label="guidance_scale")
167
+ step = gr.Number(value=50, minimum=1, maximum=100, label="sample steps", precision=0)
168
+ text_button = gr.Button("Generate Images")
169
+ with gr.Column():
170
+ image_output = gr.Image(interactive=False, label="Output RGB image")
171
+ xyz_ouput = gr.Image(interactive=False, label="Output CCM image")
172
+
173
+ output_model = gr.Model3D(
174
+ label="Output GLB",
175
+ interactive=False,
176
+ )
177
+ output_obj = gr.File(interactive=False, label="Output OBJ")
178
+
179
+ inputs = [
180
+ processed_image,
181
+ seed,
182
+ guidance_scale,
183
+ step,
184
+ ]
185
+ outputs = [
186
+ image_output,
187
+ xyz_ouput,
188
+ output_model,
189
+ output_obj,
190
+ ]
191
+ gr.Examples(
192
+ examples=[os.path.join("examples", i) for i in os.listdir("examples")],
193
+ inputs=[image_input],
194
+ )
195
+
196
+ text_button.click(fn=check_input_image, inputs=[image_input]).success(
197
+ fn=preprocess_image,
198
+ inputs=[image_input, do_remove_background, force_remove, foreground_ratio, back_groud_color],
199
+ outputs=[processed_image],
200
+ ).success(
201
+ fn=gen_image,
202
+ inputs=inputs,
203
+ outputs=outputs,
204
+ )
205
+ demo.queue().launch()
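The app.py above wires these steps into a Gradio UI, but the same pieces can be exercised headlessly. A minimal sketch, assuming the module-level setup in app.py (pipeline, model, args) has already run; the example file name is hypothetical, and the backgroud_color keyword spelling follows app.py:

    from PIL import Image
    import numpy as np

    img = Image.open("examples/sample.png")  # hypothetical example file
    rgb = preprocess_image(img, do_remove_background=True, force_remove=False,
                           foreground_ratio=1.0, backgroud_color="#7F7F7F")
    pipeline.set_seed(1234)
    rt_dict = pipeline(rgb, scale=5.5, step=50)          # stage 1: multi-view RGB; stage 2: CCM
    np_imgs = np.concatenate(rt_dict["stage1_images"], 1)
    np_xyzs = np.concatenate(rt_dict["stage2_images"], 1)
    glb_path, obj_path = generate3d(model, np_imgs, np_xyzs, args.device)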
configs/nf7_v3_SNR_rd_size_stroke.yaml ADDED
@@ -0,0 +1,21 @@
1
+ config:
2
+ # others
3
+ seed: 1234
4
+ num_frames: 7
5
+ mode: pixel
6
+ offset_noise: true
7
+ # model related
8
+ models:
9
+ config: imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml
10
+ resume: models/pixel.pth
11
+ # sampler related
12
+ sampler:
13
+ target: libs.sample.ImageDreamDiffusion
14
+ params:
15
+ mode: pixel
16
+ num_frames: 7
17
+ camera_views: [1, 2, 3, 4, 5, 0, 0]
18
+ ref_position: 6
19
+ random_background: false
20
+ offset_noise: true
21
+ resize_rate: 1.0
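As app.py above shows, this stage-1 config is read with OmegaConf and only its models and sampler sub-trees are used. A minimal sketch of that loading path, with paths as in this commit:

    from omegaconf import OmegaConf

    stage1_config = OmegaConf.load("configs/nf7_v3_SNR_rd_size_stroke.yaml").config
    print(stage1_config.sampler.target)   # libs.sample.ImageDreamDiffusion
    print(stage1_config.models.config)    # imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml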
configs/specs_objaverse_total.json ADDED
@@ -0,0 +1,57 @@
1
+ {
2
+ "Input": {
3
+ "img_num": 16,
4
+ "class": "all",
5
+ "camera_angle_num": 8,
6
+ "tet_grid_size": 80,
7
+ "validate_num": 16,
8
+ "scale": 0.95,
9
+ "radius": 3,
10
+ "resolution": [256, 256]
11
+ },
12
+
13
+ "Pretrain": {
14
+ "mode": null,
15
+ "sdf_threshold": 0.1,
16
+ "sdf_scale": 10,
17
+ "batch_infer": false,
18
+ "lr": 1e-4,
19
+ "radius": 0.5
20
+ },
21
+
22
+ "Train": {
23
+ "mode": "rnd",
24
+ "num_epochs": 500,
25
+ "grad_acc": 1,
26
+ "warm_up": 0,
27
+ "decay": 0.000,
28
+ "learning_rate": {
29
+ "init": 1e-4,
30
+ "sdf_decay": 1,
31
+ "rgb_decay": 1
32
+ },
33
+ "batch_size": 4,
34
+ "eva_iter": 80,
35
+ "eva_all_epoch": 10,
36
+ "tex_sup_mode": "blender",
37
+ "exp_uv_mesh": false,
38
+ "doub": false,
39
+ "random_bg": false,
40
+ "shift": 0,
41
+ "aug_shift": 0,
42
+ "geo_type": "flex"
43
+ },
44
+
45
+ "ArchSpecs": {
46
+ "unet_type": "diffusers",
47
+ "use_3D_aware": false,
48
+ "fea_concat": false,
49
+ "mlp_bias": true
50
+ },
51
+
52
+ "DecoderSpecs": {
53
+ "c_dim": 32,
54
+ "plane_resolution": 256
55
+ }
56
+ }
57
+
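This JSON is the spec dictionary that app.py feeds directly to the CRM constructor. A minimal sketch mirroring the loading code in app.py above, assuming the checkpoint download succeeds:

    import json
    import torch
    from huggingface_hub import hf_hub_download
    from model import CRM

    specs = json.load(open("configs/specs_objaverse_total.json"))
    crm_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="CRM.pth")
    model = CRM(specs).to("cuda")
    model.load_state_dict(torch.load(crm_path, map_location="cuda"), strict=False)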
configs/stage2-v2-snr.yaml ADDED
@@ -0,0 +1,25 @@
1
+ config:
2
+ # others
3
+ seed: 1234
4
+ num_frames: 6
5
+ mode: pixel
6
+ offset_noise: true
7
+ gd_type: xyz
8
+ # model related
9
+ models:
10
+ config: imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml
11
+ resume: models/xyz.pth
12
+
13
+ # eval related
14
+ sampler:
15
+ target: libs.sample.ImageDreamDiffusionStage2
16
+ params:
17
+ mode: pixel
18
+ num_frames: 6
19
+ camera_views: [1, 2, 3, 4, 5, 0]
20
+ ref_position: null
21
+ random_background: false
22
+ offset_noise: true
23
+ resize_rate: 1.0
24
+
25
+
imagedream/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .model_zoo import build_model
imagedream/camera_utils.py ADDED
@@ -0,0 +1,99 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+
5
+ def create_camera_to_world_matrix(elevation, azimuth):
6
+ elevation = np.radians(elevation)
7
+ azimuth = np.radians(azimuth)
8
+ # Convert elevation and azimuth angles to Cartesian coordinates on a unit sphere
9
+ x = np.cos(elevation) * np.sin(azimuth)
10
+ y = np.sin(elevation)
11
+ z = np.cos(elevation) * np.cos(azimuth)
12
+
13
+ # Calculate camera position, target, and up vectors
14
+ camera_pos = np.array([x, y, z])
15
+ target = np.array([0, 0, 0])
16
+ up = np.array([0, 1, 0])
17
+
18
+ # Construct view matrix
19
+ forward = target - camera_pos
20
+ forward /= np.linalg.norm(forward)
21
+ right = np.cross(forward, up)
22
+ right /= np.linalg.norm(right)
23
+ new_up = np.cross(right, forward)
24
+ new_up /= np.linalg.norm(new_up)
25
+ cam2world = np.eye(4)
26
+ cam2world[:3, :3] = np.array([right, new_up, -forward]).T
27
+ cam2world[:3, 3] = camera_pos
28
+ return cam2world
29
+
30
+
31
+ def convert_opengl_to_blender(camera_matrix):
32
+ if isinstance(camera_matrix, np.ndarray):
33
+ # Construct transformation matrix to convert from OpenGL space to Blender space
34
+ flip_yz = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
35
+ camera_matrix_blender = np.dot(flip_yz, camera_matrix)
36
+ else:
37
+ # Construct transformation matrix to convert from OpenGL space to Blender space
38
+ flip_yz = torch.tensor(
39
+ [[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]
40
+ )
41
+ if camera_matrix.ndim == 3:
42
+ flip_yz = flip_yz.unsqueeze(0)
43
+ camera_matrix_blender = torch.matmul(flip_yz.to(camera_matrix), camera_matrix)
44
+ return camera_matrix_blender
45
+
46
+
47
+ def normalize_camera(camera_matrix):
48
+ """normalize the camera location onto a unit-sphere"""
49
+ if isinstance(camera_matrix, np.ndarray):
50
+ camera_matrix = camera_matrix.reshape(-1, 4, 4)
51
+ translation = camera_matrix[:, :3, 3]
52
+ translation = translation / (
53
+ np.linalg.norm(translation, axis=1, keepdims=True) + 1e-8
54
+ )
55
+ camera_matrix[:, :3, 3] = translation
56
+ else:
57
+ camera_matrix = camera_matrix.reshape(-1, 4, 4)
58
+ translation = camera_matrix[:, :3, 3]
59
+ translation = translation / (
60
+ torch.norm(translation, dim=1, keepdim=True) + 1e-8
61
+ )
62
+ camera_matrix[:, :3, 3] = translation
63
+ return camera_matrix.reshape(-1, 16)
64
+
65
+
66
+ def get_camera(
67
+ num_frames,
68
+ elevation=15,
69
+ azimuth_start=0,
70
+ azimuth_span=360,
71
+ blender_coord=True,
72
+ extra_view=False,
73
+ ):
74
+ angle_gap = azimuth_span / num_frames
75
+ cameras = []
76
+ for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
77
+ camera_matrix = create_camera_to_world_matrix(elevation, azimuth)
78
+ if blender_coord:
79
+ camera_matrix = convert_opengl_to_blender(camera_matrix)
80
+ cameras.append(camera_matrix.flatten())
81
+
82
+ if extra_view:
83
+ dim = len(cameras[0])
84
+ cameras.append(np.zeros(dim))
85
+ return torch.tensor(np.stack(cameras, 0)).float()
86
+
87
+
88
+ def get_camera_for_index(data_index):
89
+ """
90
+ Following our current data format, with view 000 facing the viewer:
91
+ 000 is the front view, ev: 0, azimuth: 0
92
+ 001 is the left view, ev: 0, azimuth: -90
93
+ 002 is the bottom view, ev: -90, azimuth: 0
94
+ 003 is the back view, ev: 0, azimuth: 180
95
+ 004 is the right view, ev: 0, azimuth: 90
96
+ 005 is the top view, ev: 90, azimuth: 0
97
+ """
98
+ params = [(0, 0), (0, -90), (-90, 0), (0, 180), (0, 90), (90, 0)]
99
+ return get_camera(1, *params[data_index])
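A quick sanity check of the helpers above (the shapes follow from the code: each camera is a flattened 4x4 matrix, i.e. 16 values per view, and extra_view appends one all-zero camera):

    from imagedream.camera_utils import get_camera, get_camera_for_index

    cams = get_camera(4, elevation=15, extra_view=True)  # 4 views on a ring plus one zero "extra" view
    print(cams.shape)                                    # torch.Size([5, 16])
    front = get_camera_for_index(0)                      # front view: ev 0, azimuth 0
    print(front.shape)                                   # torch.Size([1, 16])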
imagedream/configs/sd_v2_base_ipmv.yaml ADDED
@@ -0,0 +1,61 @@
1
+ model:
2
+ target: imagedream.ldm.interface.LatentDiffusionInterface
3
+ params:
4
+ linear_start: 0.00085
5
+ linear_end: 0.0120
6
+ timesteps: 1000
7
+ scale_factor: 0.18215
8
+ parameterization: "eps"
9
+
10
+ unet_config:
11
+ target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel
12
+ params:
13
+ image_size: 32 # unused
14
+ in_channels: 4
15
+ out_channels: 4
16
+ model_channels: 320
17
+ attention_resolutions: [ 4, 2, 1 ]
18
+ num_res_blocks: 2
19
+ channel_mult: [ 1, 2, 4, 4 ]
20
+ num_head_channels: 64 # need to fix for flash-attn
21
+ use_spatial_transformer: True
22
+ use_linear_in_transformer: True
23
+ transformer_depth: 1
24
+ context_dim: 1024
25
+ use_checkpoint: False
26
+ legacy: False
27
+ camera_dim: 16
28
+ with_ip: True
29
+ ip_dim: 16 # ip token length
30
+ ip_mode: "local_resample"
31
+
32
+ vae_config:
33
+ target: imagedream.ldm.models.autoencoder.AutoencoderKL
34
+ params:
35
+ embed_dim: 4
36
+ monitor: val/rec_loss
37
+ ddconfig:
38
+ #attn_type: "vanilla-xformers"
39
+ double_z: true
40
+ z_channels: 4
41
+ resolution: 256
42
+ in_channels: 3
43
+ out_ch: 3
44
+ ch: 128
45
+ ch_mult:
46
+ - 1
47
+ - 2
48
+ - 4
49
+ - 4
50
+ num_res_blocks: 2
51
+ attn_resolutions: []
52
+ dropout: 0.0
53
+ lossconfig:
54
+ target: torch.nn.Identity
55
+
56
+ clip_config:
57
+ target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
58
+ params:
59
+ freeze: True
60
+ layer: "penultimate"
61
+ ip_mode: "local_resample"
imagedream/configs/sd_v2_base_ipmv_ch8.yaml ADDED
@@ -0,0 +1,61 @@
1
+ model:
2
+ target: imagedream.ldm.interface.LatentDiffusionInterface
3
+ params:
4
+ linear_start: 0.00085
5
+ linear_end: 0.0120
6
+ timesteps: 1000
7
+ scale_factor: 0.18215
8
+ parameterization: "eps"
9
+
10
+ unet_config:
11
+ target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel
12
+ params:
13
+ image_size: 32 # unused
14
+ in_channels: 8
15
+ out_channels: 8
16
+ model_channels: 320
17
+ attention_resolutions: [ 4, 2, 1 ]
18
+ num_res_blocks: 2
19
+ channel_mult: [ 1, 2, 4, 4 ]
20
+ num_head_channels: 64 # need to fix for flash-attn
21
+ use_spatial_transformer: True
22
+ use_linear_in_transformer: True
23
+ transformer_depth: 1
24
+ context_dim: 1024
25
+ use_checkpoint: False
26
+ legacy: False
27
+ camera_dim: 16
28
+ with_ip: True
29
+ ip_dim: 16 # ip token length
30
+ ip_mode: "local_resample"
31
+
32
+ vae_config:
33
+ target: imagedream.ldm.models.autoencoder.AutoencoderKL
34
+ params:
35
+ embed_dim: 4
36
+ monitor: val/rec_loss
37
+ ddconfig:
38
+ #attn_type: "vanilla-xformers"
39
+ double_z: true
40
+ z_channels: 4
41
+ resolution: 256
42
+ in_channels: 3
43
+ out_ch: 3
44
+ ch: 128
45
+ ch_mult:
46
+ - 1
47
+ - 2
48
+ - 4
49
+ - 4
50
+ num_res_blocks: 2
51
+ attn_resolutions: []
52
+ dropout: 0.0
53
+ lossconfig:
54
+ target: torch.nn.Identity
55
+
56
+ clip_config:
57
+ target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
58
+ params:
59
+ freeze: True
60
+ layer: "penultimate"
61
+ ip_mode: "local_resample"
imagedream/configs/sd_v2_base_ipmv_chin8.yaml ADDED
@@ -0,0 +1,61 @@
1
+ model:
2
+ target: imagedream.ldm.interface.LatentDiffusionInterface
3
+ params:
4
+ linear_start: 0.00085
5
+ linear_end: 0.0120
6
+ timesteps: 1000
7
+ scale_factor: 0.18215
8
+ parameterization: "eps"
9
+
10
+ unet_config:
11
+ target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModelStage2
12
+ params:
13
+ image_size: 32 # unused
14
+ in_channels: 8
15
+ out_channels: 4
16
+ model_channels: 320
17
+ attention_resolutions: [ 4, 2, 1 ]
18
+ num_res_blocks: 2
19
+ channel_mult: [ 1, 2, 4, 4 ]
20
+ num_head_channels: 64 # need to fix for flash-attn
21
+ use_spatial_transformer: True
22
+ use_linear_in_transformer: True
23
+ transformer_depth: 1
24
+ context_dim: 1024
25
+ use_checkpoint: False
26
+ legacy: False
27
+ camera_dim: 16
28
+ with_ip: True
29
+ ip_dim: 16 # ip token length
30
+ ip_mode: "local_resample"
31
+
32
+ vae_config:
33
+ target: imagedream.ldm.models.autoencoder.AutoencoderKL
34
+ params:
35
+ embed_dim: 4
36
+ monitor: val/rec_loss
37
+ ddconfig:
38
+ #attn_type: "vanilla-xformers"
39
+ double_z: true
40
+ z_channels: 4
41
+ resolution: 256
42
+ in_channels: 3
43
+ out_ch: 3
44
+ ch: 128
45
+ ch_mult:
46
+ - 1
47
+ - 2
48
+ - 4
49
+ - 4
50
+ num_res_blocks: 2
51
+ attn_resolutions: []
52
+ dropout: 0.0
53
+ lossconfig:
54
+ target: torch.nn.Identity
55
+
56
+ clip_config:
57
+ target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
58
+ params:
59
+ freeze: True
60
+ layer: "penultimate"
61
+ ip_mode: "local_resample"
imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml ADDED
@@ -0,0 +1,62 @@
1
+ model:
2
+ target: imagedream.ldm.interface.LatentDiffusionInterface
3
+ params:
4
+ linear_start: 0.00085
5
+ linear_end: 0.0120
6
+ timesteps: 1000
7
+ scale_factor: 0.18215
8
+ parameterization: "eps"
9
+ zero_snr: true
10
+
11
+ unet_config:
12
+ target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModelStage2
13
+ params:
14
+ image_size: 32 # unused
15
+ in_channels: 8
16
+ out_channels: 4
17
+ model_channels: 320
18
+ attention_resolutions: [ 4, 2, 1 ]
19
+ num_res_blocks: 2
20
+ channel_mult: [ 1, 2, 4, 4 ]
21
+ num_head_channels: 64 # need to fix for flash-attn
22
+ use_spatial_transformer: True
23
+ use_linear_in_transformer: True
24
+ transformer_depth: 1
25
+ context_dim: 1024
26
+ use_checkpoint: False
27
+ legacy: False
28
+ camera_dim: 16
29
+ with_ip: True
30
+ ip_dim: 16 # ip token length
31
+ ip_mode: "local_resample"
32
+
33
+ vae_config:
34
+ target: imagedream.ldm.models.autoencoder.AutoencoderKL
35
+ params:
36
+ embed_dim: 4
37
+ monitor: val/rec_loss
38
+ ddconfig:
39
+ #attn_type: "vanilla-xformers"
40
+ double_z: true
41
+ z_channels: 4
42
+ resolution: 256
43
+ in_channels: 3
44
+ out_ch: 3
45
+ ch: 128
46
+ ch_mult:
47
+ - 1
48
+ - 2
49
+ - 4
50
+ - 4
51
+ num_res_blocks: 2
52
+ attn_resolutions: []
53
+ dropout: 0.0
54
+ lossconfig:
55
+ target: torch.nn.Identity
56
+
57
+ clip_config:
58
+ target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
59
+ params:
60
+ freeze: True
61
+ layer: "penultimate"
62
+ ip_mode: "local_resample"
imagedream/configs/sd_v2_base_ipmv_local.yaml ADDED
@@ -0,0 +1,62 @@
1
+ model:
2
+ target: imagedream.ldm.interface.LatentDiffusionInterface
3
+ params:
4
+ linear_start: 0.00085
5
+ linear_end: 0.0120
6
+ timesteps: 1000
7
+ scale_factor: 0.18215
8
+ parameterization: "eps"
9
+
10
+ unet_config:
11
+ target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel
12
+ params:
13
+ image_size: 32 # unused
14
+ in_channels: 4
15
+ out_channels: 4
16
+ model_channels: 320
17
+ attention_resolutions: [ 4, 2, 1 ]
18
+ num_res_blocks: 2
19
+ channel_mult: [ 1, 2, 4, 4 ]
20
+ num_head_channels: 64 # need to fix for flash-attn
21
+ use_spatial_transformer: True
22
+ use_linear_in_transformer: True
23
+ transformer_depth: 1
24
+ context_dim: 1024
25
+ use_checkpoint: False
26
+ legacy: False
27
+ camera_dim: 16
28
+ with_ip: True
29
+ ip_dim: 16 # ip token length
30
+ ip_mode: "local_resample"
31
+ ip_weight: 1.0 # adjust for similarity to image
32
+
33
+ vae_config:
34
+ target: imagedream.ldm.models.autoencoder.AutoencoderKL
35
+ params:
36
+ embed_dim: 4
37
+ monitor: val/rec_loss
38
+ ddconfig:
39
+ #attn_type: "vanilla-xformers"
40
+ double_z: true
41
+ z_channels: 4
42
+ resolution: 256
43
+ in_channels: 3
44
+ out_ch: 3
45
+ ch: 128
46
+ ch_mult:
47
+ - 1
48
+ - 2
49
+ - 4
50
+ - 4
51
+ num_res_blocks: 2
52
+ attn_resolutions: []
53
+ dropout: 0.0
54
+ lossconfig:
55
+ target: torch.nn.Identity
56
+
57
+ clip_config:
58
+ target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
59
+ params:
60
+ freeze: True
61
+ layer: "penultimate"
62
+ ip_mode: "local_resample"
imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml ADDED
@@ -0,0 +1,62 @@
1
+ model:
2
+ target: imagedream.ldm.interface.LatentDiffusionInterface
3
+ params:
4
+ linear_start: 0.00085
5
+ linear_end: 0.0120
6
+ timesteps: 1000
7
+ scale_factor: 0.18215
8
+ parameterization: "eps"
9
+ zero_snr: true
10
+
11
+ unet_config:
12
+ target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel
13
+ params:
14
+ image_size: 32 # unused
15
+ in_channels: 4
16
+ out_channels: 4
17
+ model_channels: 320
18
+ attention_resolutions: [ 4, 2, 1 ]
19
+ num_res_blocks: 2
20
+ channel_mult: [ 1, 2, 4, 4 ]
21
+ num_head_channels: 64 # need to fix for flash-attn
22
+ use_spatial_transformer: True
23
+ use_linear_in_transformer: True
24
+ transformer_depth: 1
25
+ context_dim: 1024
26
+ use_checkpoint: False
27
+ legacy: False
28
+ camera_dim: 16
29
+ with_ip: True
30
+ ip_dim: 16 # ip token length
31
+ ip_mode: "local_resample"
32
+
33
+ vae_config:
34
+ target: imagedream.ldm.models.autoencoder.AutoencoderKL
35
+ params:
36
+ embed_dim: 4
37
+ monitor: val/rec_loss
38
+ ddconfig:
39
+ #attn_type: "vanilla-xformers"
40
+ double_z: true
41
+ z_channels: 4
42
+ resolution: 256
43
+ in_channels: 3
44
+ out_ch: 3
45
+ ch: 128
46
+ ch_mult:
47
+ - 1
48
+ - 2
49
+ - 4
50
+ - 4
51
+ num_res_blocks: 2
52
+ attn_resolutions: []
53
+ dropout: 0.0
54
+ lossconfig:
55
+ target: torch.nn.Identity
56
+
57
+ clip_config:
58
+ target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
59
+ params:
60
+ freeze: True
61
+ layer: "penultimate"
62
+ ip_mode: "local_resample"
imagedream/ldm/__init__.py ADDED
File without changes
imagedream/ldm/interface.py ADDED
@@ -0,0 +1,205 @@
1
+ from typing import List
2
+ from functools import partial
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .modules.diffusionmodules.util import (
9
+ make_beta_schedule,
10
+ extract_into_tensor,
11
+ enforce_zero_terminal_snr,
12
+ noise_like,
13
+ )
14
+ from .util import exists, default, instantiate_from_config
15
+ from .modules.distributions.distributions import DiagonalGaussianDistribution
16
+
17
+
18
+ class DiffusionWrapper(nn.Module):
19
+ def __init__(self, diffusion_model):
20
+ super().__init__()
21
+ self.diffusion_model = diffusion_model
22
+
23
+ def forward(self, *args, **kwargs):
24
+ return self.diffusion_model(*args, **kwargs)
25
+
26
+
27
+ class LatentDiffusionInterface(nn.Module):
28
+ """a simple interface class for LDM inference"""
29
+
30
+ def __init__(
31
+ self,
32
+ unet_config,
33
+ clip_config,
34
+ vae_config,
35
+ parameterization="eps",
36
+ scale_factor=0.18215,
37
+ beta_schedule="linear",
38
+ timesteps=1000,
39
+ linear_start=0.00085,
40
+ linear_end=0.0120,
41
+ cosine_s=8e-3,
42
+ given_betas=None,
43
+ zero_snr=False,
44
+ *args,
45
+ **kwargs,
46
+ ):
47
+ super().__init__()
48
+
49
+ unet = instantiate_from_config(unet_config)
50
+ self.model = DiffusionWrapper(unet)
51
+ self.clip_model = instantiate_from_config(clip_config)
52
+ self.vae_model = instantiate_from_config(vae_config)
53
+
54
+ self.parameterization = parameterization
55
+ self.scale_factor = scale_factor
56
+ self.register_schedule(
57
+ given_betas=given_betas,
58
+ beta_schedule=beta_schedule,
59
+ timesteps=timesteps,
60
+ linear_start=linear_start,
61
+ linear_end=linear_end,
62
+ cosine_s=cosine_s,
63
+ zero_snr=zero_snr
64
+ )
65
+
66
+ def register_schedule(
67
+ self,
68
+ given_betas=None,
69
+ beta_schedule="linear",
70
+ timesteps=1000,
71
+ linear_start=1e-4,
72
+ linear_end=2e-2,
73
+ cosine_s=8e-3,
74
+ zero_snr=False
75
+ ):
76
+ if exists(given_betas):
77
+ betas = given_betas
78
+ else:
79
+ betas = make_beta_schedule(
80
+ beta_schedule,
81
+ timesteps,
82
+ linear_start=linear_start,
83
+ linear_end=linear_end,
84
+ cosine_s=cosine_s,
85
+ )
86
+ if zero_snr:
87
+ print("--- using zero snr---")
88
+ betas = enforce_zero_terminal_snr(betas).numpy()
89
+ alphas = 1.0 - betas
90
+ alphas_cumprod = np.cumprod(alphas, axis=0)
91
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
92
+
93
+ (timesteps,) = betas.shape
94
+ self.num_timesteps = int(timesteps)
95
+ self.linear_start = linear_start
96
+ self.linear_end = linear_end
97
+ assert (
98
+ alphas_cumprod.shape[0] == self.num_timesteps
99
+ ), "alphas have to be defined for each timestep"
100
+
101
+ to_torch = partial(torch.tensor, dtype=torch.float32)
102
+
103
+ self.register_buffer("betas", to_torch(betas))
104
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
105
+ self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
106
+
107
+ # calculations for diffusion q(x_t | x_{t-1}) and others
108
+ self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
109
+ self.register_buffer(
110
+ "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
111
+ )
112
+ self.register_buffer(
113
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
114
+ )
115
+ self.register_buffer(
116
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
117
+ )
118
+ self.register_buffer(
119
+ "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
120
+ )
121
+
122
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
123
+ self.v_posterior = 0
124
+ posterior_variance = (1 - self.v_posterior) * betas * (
125
+ 1.0 - alphas_cumprod_prev
126
+ ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
127
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
128
+ self.register_buffer("posterior_variance", to_torch(posterior_variance))
129
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
130
+ self.register_buffer(
131
+ "posterior_log_variance_clipped",
132
+ to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
133
+ )
134
+ self.register_buffer(
135
+ "posterior_mean_coef1",
136
+ to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
137
+ )
138
+ self.register_buffer(
139
+ "posterior_mean_coef2",
140
+ to_torch(
141
+ (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
142
+ ),
143
+ )
144
+
145
+ def q_sample(self, x_start, t, noise=None):
146
+ noise = default(noise, lambda: torch.randn_like(x_start))
147
+ return (
148
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
149
+ + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
150
+ * noise
151
+ )
152
+
153
+ def get_v(self, x, noise, t):
154
+ return (
155
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
156
+ - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
157
+ )
158
+
159
+ def predict_start_from_noise(self, x_t, t, noise):
160
+ return (
161
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
162
+ - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
163
+ * noise
164
+ )
165
+
166
+ def predict_start_from_z_and_v(self, x_t, t, v):
167
+ return (
168
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
169
+ - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
170
+ )
171
+
172
+ def predict_eps_from_z_and_v(self, x_t, t, v):
173
+ return (
174
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v
175
+ + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
176
+ * x_t
177
+ )
178
+
179
+ def apply_model(self, x_noisy, t, cond, **kwargs):
180
+ assert isinstance(cond, dict), "cond has to be a dictionary"
181
+ return self.model(x_noisy, t, **cond, **kwargs)
182
+
183
+ def get_learned_conditioning(self, prompts: List[str]):
184
+ return self.clip_model(prompts)
185
+
186
+ def get_learned_image_conditioning(self, images):
187
+ return self.clip_model.forward_image(images)
188
+
189
+ def get_first_stage_encoding(self, encoder_posterior):
190
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
191
+ z = encoder_posterior.sample()
192
+ elif isinstance(encoder_posterior, torch.Tensor):
193
+ z = encoder_posterior
194
+ else:
195
+ raise NotImplementedError(
196
+ f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
197
+ )
198
+ return self.scale_factor * z
199
+
200
+ def encode_first_stage(self, x):
201
+ return self.vae_model.encode(x)
202
+
203
+ def decode_first_stage(self, z):
204
+ z = 1.0 / self.scale_factor * z
205
+ return self.vae_model.decode(z)
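For reference, the schedule buffers registered above implement the standard DDPM relations; q_sample and predict_start_from_noise in this interface are inverses of each other for a fixed noise sample (this is a restatement of the code, not new behaviour):

    x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon
    x_0 = \sqrt{1/\bar{\alpha}_t}\, x_t - \sqrt{1/\bar{\alpha}_t - 1}\, \epsilon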
imagedream/ldm/models/__init__.py ADDED
File without changes
imagedream/ldm/models/autoencoder.py ADDED
@@ -0,0 +1,270 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from contextlib import contextmanager
4
+
5
+ from ..modules.diffusionmodules.model import Encoder, Decoder
6
+ from ..modules.distributions.distributions import DiagonalGaussianDistribution
7
+
8
+ from ..util import instantiate_from_config
9
+ from ..modules.ema import LitEma
10
+
11
+
12
+ class AutoencoderKL(torch.nn.Module):
13
+ def __init__(
14
+ self,
15
+ ddconfig,
16
+ lossconfig,
17
+ embed_dim,
18
+ ckpt_path=None,
19
+ ignore_keys=[],
20
+ image_key="image",
21
+ colorize_nlabels=None,
22
+ monitor=None,
23
+ ema_decay=None,
24
+ learn_logvar=False,
25
+ ):
26
+ super().__init__()
27
+ self.learn_logvar = learn_logvar
28
+ self.image_key = image_key
29
+ self.encoder = Encoder(**ddconfig)
30
+ self.decoder = Decoder(**ddconfig)
31
+ self.loss = instantiate_from_config(lossconfig)
32
+ assert ddconfig["double_z"]
33
+ self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
34
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
35
+ self.embed_dim = embed_dim
36
+ if colorize_nlabels is not None:
37
+ assert type(colorize_nlabels) == int
38
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
39
+ if monitor is not None:
40
+ self.monitor = monitor
41
+
42
+ self.use_ema = ema_decay is not None
43
+ if self.use_ema:
44
+ self.ema_decay = ema_decay
45
+ assert 0.0 < ema_decay < 1.0
46
+ self.model_ema = LitEma(self, decay=ema_decay)
47
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
48
+
49
+ if ckpt_path is not None:
50
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
51
+
52
+ def init_from_ckpt(self, path, ignore_keys=list()):
53
+ sd = torch.load(path, map_location="cpu")["state_dict"]
54
+ keys = list(sd.keys())
55
+ for k in keys:
56
+ for ik in ignore_keys:
57
+ if k.startswith(ik):
58
+ print("Deleting key {} from state_dict.".format(k))
59
+ del sd[k]
60
+ self.load_state_dict(sd, strict=False)
61
+ print(f"Restored from {path}")
62
+
63
+ @contextmanager
64
+ def ema_scope(self, context=None):
65
+ if self.use_ema:
66
+ self.model_ema.store(self.parameters())
67
+ self.model_ema.copy_to(self)
68
+ if context is not None:
69
+ print(f"{context}: Switched to EMA weights")
70
+ try:
71
+ yield None
72
+ finally:
73
+ if self.use_ema:
74
+ self.model_ema.restore(self.parameters())
75
+ if context is not None:
76
+ print(f"{context}: Restored training weights")
77
+
78
+ def on_train_batch_end(self, *args, **kwargs):
79
+ if self.use_ema:
80
+ self.model_ema(self)
81
+
82
+ def encode(self, x):
83
+ h = self.encoder(x)
84
+ moments = self.quant_conv(h)
85
+ posterior = DiagonalGaussianDistribution(moments)
86
+ return posterior
87
+
88
+ def decode(self, z):
89
+ z = self.post_quant_conv(z)
90
+ dec = self.decoder(z)
91
+ return dec
92
+
93
+ def forward(self, input, sample_posterior=True):
94
+ posterior = self.encode(input)
95
+ if sample_posterior:
96
+ z = posterior.sample()
97
+ else:
98
+ z = posterior.mode()
99
+ dec = self.decode(z)
100
+ return dec, posterior
101
+
102
+ def get_input(self, batch, k):
103
+ x = batch[k]
104
+ if len(x.shape) == 3:
105
+ x = x[..., None]
106
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
107
+ return x
108
+
109
+ def training_step(self, batch, batch_idx, optimizer_idx):
110
+ inputs = self.get_input(batch, self.image_key)
111
+ reconstructions, posterior = self(inputs)
112
+
113
+ if optimizer_idx == 0:
114
+ # train encoder+decoder+logvar
115
+ aeloss, log_dict_ae = self.loss(
116
+ inputs,
117
+ reconstructions,
118
+ posterior,
119
+ optimizer_idx,
120
+ self.global_step,
121
+ last_layer=self.get_last_layer(),
122
+ split="train",
123
+ )
124
+ self.log(
125
+ "aeloss",
126
+ aeloss,
127
+ prog_bar=True,
128
+ logger=True,
129
+ on_step=True,
130
+ on_epoch=True,
131
+ )
132
+ self.log_dict(
133
+ log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False
134
+ )
135
+ return aeloss
136
+
137
+ if optimizer_idx == 1:
138
+ # train the discriminator
139
+ discloss, log_dict_disc = self.loss(
140
+ inputs,
141
+ reconstructions,
142
+ posterior,
143
+ optimizer_idx,
144
+ self.global_step,
145
+ last_layer=self.get_last_layer(),
146
+ split="train",
147
+ )
148
+
149
+ self.log(
150
+ "discloss",
151
+ discloss,
152
+ prog_bar=True,
153
+ logger=True,
154
+ on_step=True,
155
+ on_epoch=True,
156
+ )
157
+ self.log_dict(
158
+ log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
159
+ )
160
+ return discloss
161
+
162
+ def validation_step(self, batch, batch_idx):
163
+ log_dict = self._validation_step(batch, batch_idx)
164
+ with self.ema_scope():
165
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
166
+ return log_dict
167
+
168
+ def _validation_step(self, batch, batch_idx, postfix=""):
169
+ inputs = self.get_input(batch, self.image_key)
170
+ reconstructions, posterior = self(inputs)
171
+ aeloss, log_dict_ae = self.loss(
172
+ inputs,
173
+ reconstructions,
174
+ posterior,
175
+ 0,
176
+ self.global_step,
177
+ last_layer=self.get_last_layer(),
178
+ split="val" + postfix,
179
+ )
180
+
181
+ discloss, log_dict_disc = self.loss(
182
+ inputs,
183
+ reconstructions,
184
+ posterior,
185
+ 1,
186
+ self.global_step,
187
+ last_layer=self.get_last_layer(),
188
+ split="val" + postfix,
189
+ )
190
+
191
+ self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
192
+ self.log_dict(log_dict_ae)
193
+ self.log_dict(log_dict_disc)
194
+ return self.log_dict
195
+
196
+ def configure_optimizers(self):
197
+ lr = self.learning_rate
198
+ ae_params_list = (
199
+ list(self.encoder.parameters())
200
+ + list(self.decoder.parameters())
201
+ + list(self.quant_conv.parameters())
202
+ + list(self.post_quant_conv.parameters())
203
+ )
204
+ if self.learn_logvar:
205
+ print(f"{self.__class__.__name__}: Learning logvar")
206
+ ae_params_list.append(self.loss.logvar)
207
+ opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9))
208
+ opt_disc = torch.optim.Adam(
209
+ self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
210
+ )
211
+ return [opt_ae, opt_disc], []
212
+
213
+ def get_last_layer(self):
214
+ return self.decoder.conv_out.weight
215
+
216
+ @torch.no_grad()
217
+ def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
218
+ log = dict()
219
+ x = self.get_input(batch, self.image_key)
220
+ x = x.to(self.device)
221
+ if not only_inputs:
222
+ xrec, posterior = self(x)
223
+ if x.shape[1] > 3:
224
+ # colorize with random projection
225
+ assert xrec.shape[1] > 3
226
+ x = self.to_rgb(x)
227
+ xrec = self.to_rgb(xrec)
228
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
229
+ log["reconstructions"] = xrec
230
+ if log_ema or self.use_ema:
231
+ with self.ema_scope():
232
+ xrec_ema, posterior_ema = self(x)
233
+ if x.shape[1] > 3:
234
+ # colorize with random projection
235
+ assert xrec_ema.shape[1] > 3
236
+ xrec_ema = self.to_rgb(xrec_ema)
237
+ log["samples_ema"] = self.decode(
238
+ torch.randn_like(posterior_ema.sample())
239
+ )
240
+ log["reconstructions_ema"] = xrec_ema
241
+ log["inputs"] = x
242
+ return log
243
+
244
+ def to_rgb(self, x):
245
+ assert self.image_key == "segmentation"
246
+ if not hasattr(self, "colorize"):
247
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
248
+ x = F.conv2d(x, weight=self.colorize)
249
+ x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
250
+ return x
251
+
252
+
253
+ class IdentityFirstStage(torch.nn.Module):
254
+ def __init__(self, *args, vq_interface=False, **kwargs):
255
+ self.vq_interface = vq_interface
256
+ super().__init__()
257
+
258
+ def encode(self, x, *args, **kwargs):
259
+ return x
260
+
261
+ def decode(self, x, *args, **kwargs):
262
+ return x
263
+
264
+ def quantize(self, x, *args, **kwargs):
265
+ if self.vq_interface:
266
+ return x, None, [None, None, None]
267
+ return x
268
+
269
+ def forward(self, x, *args, **kwargs):
270
+ return x
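As imagedream/ldm/interface.py earlier in this diff shows, this AutoencoderKL is driven through encode/decode together with the 0.18215 latent scale factor. A minimal sketch, assuming vae is an already-instantiated AutoencoderKL with the ddconfig used in this repo and x is a (B, 3, 256, 256) image batch in [-1, 1]:

    posterior = vae.encode(x)            # DiagonalGaussianDistribution over the latent
    z = 0.18215 * posterior.sample()     # scaled latent, shape (B, 4, 32, 32)
    recon = vae.decode(z / 0.18215)      # decode expects the unscaled latent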
imagedream/ldm/models/diffusion/__init__.py ADDED
File without changes
imagedream/ldm/models/diffusion/ddim.py ADDED
@@ -0,0 +1,430 @@
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+
8
+ from ...modules.diffusionmodules.util import (
9
+ make_ddim_sampling_parameters,
10
+ make_ddim_timesteps,
11
+ noise_like,
12
+ extract_into_tensor,
13
+ )
14
+
15
+
16
+ class DDIMSampler(object):
17
+ def __init__(self, model, schedule="linear", **kwargs):
18
+ super().__init__()
19
+ self.model = model
20
+ self.ddpm_num_timesteps = model.num_timesteps
21
+ self.schedule = schedule
22
+
23
+ def register_buffer(self, name, attr):
24
+ if type(attr) == torch.Tensor:
25
+ if attr.device != torch.device("cuda"):
26
+ attr = attr.to(torch.device("cuda"))
27
+ setattr(self, name, attr)
28
+
29
+ def make_schedule(
30
+ self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
31
+ ):
32
+ self.ddim_timesteps = make_ddim_timesteps(
33
+ ddim_discr_method=ddim_discretize,
34
+ num_ddim_timesteps=ddim_num_steps,
35
+ num_ddpm_timesteps=self.ddpm_num_timesteps,
36
+ verbose=verbose,
37
+ )
38
+ alphas_cumprod = self.model.alphas_cumprod
39
+ assert (
40
+ alphas_cumprod.shape[0] == self.ddpm_num_timesteps
41
+ ), "alphas have to be defined for each timestep"
42
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
43
+
44
+ self.register_buffer("betas", to_torch(self.model.betas))
45
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
46
+ self.register_buffer(
47
+ "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
48
+ )
49
+
50
+ # calculations for diffusion q(x_t | x_{t-1}) and others
51
+ self.register_buffer(
52
+ "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
53
+ )
54
+ self.register_buffer(
55
+ "sqrt_one_minus_alphas_cumprod",
56
+ to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
57
+ )
58
+ self.register_buffer(
59
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
60
+ )
61
+ self.register_buffer(
62
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
63
+ )
64
+ self.register_buffer(
65
+ "sqrt_recipm1_alphas_cumprod",
66
+ to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
67
+ )
68
+
69
+ # ddim sampling parameters
70
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
71
+ alphacums=alphas_cumprod.cpu(),
72
+ ddim_timesteps=self.ddim_timesteps,
73
+ eta=ddim_eta,
74
+ verbose=verbose,
75
+ )
76
+ self.register_buffer("ddim_sigmas", ddim_sigmas)
77
+ self.register_buffer("ddim_alphas", ddim_alphas)
78
+ self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
79
+ self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
80
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
81
+ (1 - self.alphas_cumprod_prev)
82
+ / (1 - self.alphas_cumprod)
83
+ * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
84
+ )
85
+ self.register_buffer(
86
+ "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
87
+ )
88
+
89
+ @torch.no_grad()
90
+ def sample(
91
+ self,
92
+ S,
93
+ batch_size,
94
+ shape,
95
+ conditioning=None,
96
+ callback=None,
97
+ normals_sequence=None,
98
+ img_callback=None,
99
+ quantize_x0=False,
100
+ eta=0.0,
101
+ mask=None,
102
+ x0=None,
103
+ temperature=1.0,
104
+ noise_dropout=0.0,
105
+ score_corrector=None,
106
+ corrector_kwargs=None,
107
+ verbose=True,
108
+ x_T=None,
109
+ log_every_t=100,
110
+ unconditional_guidance_scale=1.0,
111
+ unconditional_conditioning=None,
112
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
113
+ **kwargs,
114
+ ):
115
+ if conditioning is not None:
116
+ if isinstance(conditioning, dict):
117
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
118
+ if cbs != batch_size:
119
+ print(
120
+ f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
121
+ )
122
+ else:
123
+ if conditioning.shape[0] != batch_size:
124
+ print(
125
+ f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
126
+ )
127
+
128
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
129
+ # sampling
130
+ C, H, W = shape
131
+ size = (batch_size, C, H, W)
132
+
133
+ samples, intermediates = self.ddim_sampling(
134
+ conditioning,
135
+ size,
136
+ callback=callback,
137
+ img_callback=img_callback,
138
+ quantize_denoised=quantize_x0,
139
+ mask=mask,
140
+ x0=x0,
141
+ ddim_use_original_steps=False,
142
+ noise_dropout=noise_dropout,
143
+ temperature=temperature,
144
+ score_corrector=score_corrector,
145
+ corrector_kwargs=corrector_kwargs,
146
+ x_T=x_T,
147
+ log_every_t=log_every_t,
148
+ unconditional_guidance_scale=unconditional_guidance_scale,
149
+ unconditional_conditioning=unconditional_conditioning,
150
+ **kwargs,
151
+ )
152
+ return samples, intermediates
153
+
154
+ @torch.no_grad()
155
+ def ddim_sampling(
156
+ self,
157
+ cond,
158
+ shape,
159
+ x_T=None,
160
+ ddim_use_original_steps=False,
161
+ callback=None,
162
+ timesteps=None,
163
+ quantize_denoised=False,
164
+ mask=None,
165
+ x0=None,
166
+ img_callback=None,
167
+ log_every_t=100,
168
+ temperature=1.0,
169
+ noise_dropout=0.0,
170
+ score_corrector=None,
171
+ corrector_kwargs=None,
172
+ unconditional_guidance_scale=1.0,
173
+ unconditional_conditioning=None,
174
+ **kwargs,
175
+ ):
176
+ """
177
+ when inference time: all values of parameter
178
+ cond.keys(): dict_keys(['context', 'camera', 'num_frames', 'ip', 'ip_img'])
179
+ shape: (5, 4, 32, 32)
180
+ x_T: None
181
+ ddim_use_original_steps: False
182
+ timesteps: None
183
+ callback: None
184
+ quantize_denoised: False
185
+ mask: None
186
+ image_callback: None
187
+ log_every_t: 100
188
+ temperature: 1.0
189
+ noise_dropout: 0.0
190
+ score_corrector: None
191
+ corrector_kwargs: None
192
+ unconditional_guidance_scale: 5
193
+ unconditional_conditioning.keys(): dict_keys(['context', 'camera', 'num_frames', 'ip', 'ip_img'])
194
+ kwargs: {}
195
+ """
196
+ device = self.model.betas.device
197
+ b = shape[0]
198
+ if x_T is None:
199
+ img = torch.randn(shape, device=device) # shape: torch.Size([5, 4, 32, 32]) mean: -0.00, std: 1.00, min: -3.64, max: 3.94
200
+ else:
201
+ img = x_T
202
+
203
+ if timesteps is None: # equal with set time step in hf
204
+ timesteps = (
205
+ self.ddpm_num_timesteps
206
+ if ddim_use_original_steps
207
+ else self.ddim_timesteps
208
+ )
209
+ elif timesteps is not None and not ddim_use_original_steps:
210
+ subset_end = (
211
+ int(
212
+ min(timesteps / self.ddim_timesteps.shape[0], 1)
213
+ * self.ddim_timesteps.shape[0]
214
+ )
215
+ - 1
216
+ )
217
+ timesteps = self.ddim_timesteps[:subset_end]
218
+
219
+ intermediates = {"x_inter": [img], "pred_x0": [img]}
220
+ time_range = ( # reversed timesteps
221
+ reversed(range(0, timesteps))
222
+ if ddim_use_original_steps
223
+ else np.flip(timesteps)
224
+ )
225
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
226
+ iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
227
+ for i, step in enumerate(iterator):
228
+ index = total_steps - i - 1
229
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
230
+
231
+ if mask is not None:
232
+ assert x0 is not None
233
+ img_orig = self.model.q_sample(
234
+ x0, ts
235
+ ) # TODO: deterministic forward pass?
236
+ img = img_orig * mask + (1.0 - mask) * img
237
+
238
+ outs = self.p_sample_ddim(
239
+ img,
240
+ cond,
241
+ ts,
242
+ index=index,
243
+ use_original_steps=ddim_use_original_steps,
244
+ quantize_denoised=quantize_denoised,
245
+ temperature=temperature,
246
+ noise_dropout=noise_dropout,
247
+ score_corrector=score_corrector,
248
+ corrector_kwargs=corrector_kwargs,
249
+ unconditional_guidance_scale=unconditional_guidance_scale,
250
+ unconditional_conditioning=unconditional_conditioning,
251
+ **kwargs,
252
+ )
253
+ img, pred_x0 = outs
254
+ if callback:
255
+ callback(i)
256
+ if img_callback:
257
+ img_callback(pred_x0, i)
258
+
259
+ if index % log_every_t == 0 or index == total_steps - 1:
260
+ intermediates["x_inter"].append(img)
261
+ intermediates["pred_x0"].append(pred_x0)
262
+
263
+ return img, intermediates
264
+
265
+ @torch.no_grad()
266
+ def p_sample_ddim(
267
+ self,
268
+ x,
269
+ c,
270
+ t,
271
+ index,
272
+ repeat_noise=False,
273
+ use_original_steps=False,
274
+ quantize_denoised=False,
275
+ temperature=1.0,
276
+ noise_dropout=0.0,
277
+ score_corrector=None,
278
+ corrector_kwargs=None,
279
+ unconditional_guidance_scale=1.0,
280
+ unconditional_conditioning=None,
281
+ dynamic_threshold=None,
282
+ **kwargs,
283
+ ):
284
+ b, *_, device = *x.shape, x.device
285
+
286
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
287
+ model_output = self.model.apply_model(x, t, c)
288
+ else:
289
+ x_in = torch.cat([x] * 2)
290
+ t_in = torch.cat([t] * 2)
291
+ if isinstance(c, dict):
292
+ assert isinstance(unconditional_conditioning, dict)
293
+ c_in = dict()
294
+ for k in c:
295
+ if isinstance(c[k], list):
296
+ c_in[k] = [
297
+ torch.cat([unconditional_conditioning[k][i], c[k][i]])
298
+ for i in range(len(c[k]))
299
+ ]
300
+ elif isinstance(c[k], torch.Tensor):
301
+ c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
302
+ else:
303
+ assert c[k] == unconditional_conditioning[k]
304
+ c_in[k] = c[k]
305
+ elif isinstance(c, list):
306
+ c_in = list()
307
+ assert isinstance(unconditional_conditioning, list)
308
+ for i in range(len(c)):
309
+ c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
310
+ else:
311
+ c_in = torch.cat([unconditional_conditioning, c])
312
+ model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
313
+ model_output = model_uncond + unconditional_guidance_scale * (
314
+ model_t - model_uncond
315
+ )
316
+
317
+
318
+ if self.model.parameterization == "v":
319
+ print("using v!")
320
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
321
+ else:
322
+ e_t = model_output
323
+
324
+ if score_corrector is not None:
325
+ assert self.model.parameterization == "eps", "not implemented"
326
+ e_t = score_corrector.modify_score(
327
+ self.model, e_t, x, t, c, **corrector_kwargs
328
+ )
329
+
330
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
331
+ alphas_prev = (
332
+ self.model.alphas_cumprod_prev
333
+ if use_original_steps
334
+ else self.ddim_alphas_prev
335
+ )
336
+ sqrt_one_minus_alphas = (
337
+ self.model.sqrt_one_minus_alphas_cumprod
338
+ if use_original_steps
339
+ else self.ddim_sqrt_one_minus_alphas
340
+ )
341
+ sigmas = (
342
+ self.model.ddim_sigmas_for_original_num_steps
343
+ if use_original_steps
344
+ else self.ddim_sigmas
345
+ )
346
+ # select parameters corresponding to the currently considered timestep
347
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
348
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
349
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
350
+ sqrt_one_minus_at = torch.full(
351
+ (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
352
+ )
353
+
354
+ # current prediction for x_0
355
+ if self.model.parameterization != "v":
356
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
357
+ else:
358
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
359
+
360
+ if quantize_denoised:
361
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
362
+
363
+ if dynamic_threshold is not None:
364
+ raise NotImplementedError()
365
+
366
+ # direction pointing to x_t
367
+ dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
368
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
369
+ if noise_dropout > 0.0:
370
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
371
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
372
+ return x_prev, pred_x0
373
+
374
+ @torch.no_grad()
375
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
376
+ # fast, but does not allow for exact reconstruction
377
+ # t serves as an index to gather the correct alphas
378
+ if use_original_steps:
379
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
380
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
381
+ else:
382
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
383
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
384
+
385
+ if noise is None:
386
+ noise = torch.randn_like(x0)
387
+ return (
388
+ extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
389
+ + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
390
+ )
391
+
392
+ @torch.no_grad()
393
+ def decode(
394
+ self,
395
+ x_latent,
396
+ cond,
397
+ t_start,
398
+ unconditional_guidance_scale=1.0,
399
+ unconditional_conditioning=None,
400
+ use_original_steps=False,
401
+ **kwargs,
402
+ ):
403
+ timesteps = (
404
+ np.arange(self.ddpm_num_timesteps)
405
+ if use_original_steps
406
+ else self.ddim_timesteps
407
+ )
408
+ timesteps = timesteps[:t_start]
409
+
410
+ time_range = np.flip(timesteps)
411
+ total_steps = timesteps.shape[0]
412
+
413
+ iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
414
+ x_dec = x_latent
415
+ for i, step in enumerate(iterator):
416
+ index = total_steps - i - 1
417
+ ts = torch.full(
418
+ (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
419
+ )
420
+ x_dec, _ = self.p_sample_ddim(
421
+ x_dec,
422
+ cond,
423
+ ts,
424
+ index=index,
425
+ use_original_steps=use_original_steps,
426
+ unconditional_guidance_scale=unconditional_guidance_scale,
427
+ unconditional_conditioning=unconditional_conditioning,
428
+ **kwargs,
429
+ )
430
+ return x_dec
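For readers skimming the sampler above, here is a minimal, self-contained sketch of one classifier-free-guided DDIM step in the eps-parameterization. It is illustrative only: guided_ddim_step and eps_model are hypothetical names standing in for p_sample_ddim and self.model.apply_model, and the alpha/sigma values would normally come from the schedule buffers selected above.

# One guided DDIM step, mirroring the update in p_sample_ddim (sketch, not the repo's API).
import torch

def guided_ddim_step(x, t, cond, uncond, eps_model, a_t, a_prev, sigma_t, scale=5.0):
    # a single batched forward pass covers the unconditional and conditional branches
    x_in = torch.cat([x] * 2)
    t_in = torch.cat([t] * 2)
    c_in = torch.cat([uncond, cond])
    e_uncond, e_cond = eps_model(x_in, t_in, c_in).chunk(2)
    e_t = e_uncond + scale * (e_cond - e_uncond)           # classifier-free guidance
    pred_x0 = (x - (1.0 - a_t).sqrt() * e_t) / a_t.sqrt()  # current estimate of x_0
    dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t    # direction pointing to x_t
    noise = sigma_t * torch.randn_like(x)                  # vanishes when eta == 0
    return a_prev.sqrt() * pred_x0 + dir_xt + noise, pred_x0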
imagedream/ldm/modules/__init__.py ADDED
File without changes
imagedream/ldm/modules/attention.py ADDED
@@ -0,0 +1,456 @@
1
+ from inspect import isfunction
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn, einsum
6
+ from einops import rearrange, repeat
7
+ from typing import Optional, Any
8
+
9
+ from .diffusionmodules.util import checkpoint
10
+
11
+
12
+ try:
13
+ import xformers
14
+ import xformers.ops
15
+
16
+ XFORMERS_IS_AVAILBLE = True
17
+ except ImportError:
18
+ XFORMERS_IS_AVAILBLE = False
19
+
20
+ # CrossAttn precision handling
21
+ import os
22
+
23
+ _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
24
+
25
+
26
+ def exists(val):
27
+ return val is not None
28
+
29
+
30
+ def uniq(arr):
31
+ return {el: True for el in arr}.keys()
32
+
33
+
34
+ def default(val, d):
35
+ if exists(val):
36
+ return val
37
+ return d() if isfunction(d) else d
38
+
39
+
40
+ def max_neg_value(t):
41
+ return -torch.finfo(t.dtype).max
42
+
43
+
44
+ def init_(tensor):
45
+ dim = tensor.shape[-1]
46
+ std = 1 / math.sqrt(dim)
47
+ tensor.uniform_(-std, std)
48
+ return tensor
49
+
50
+
51
+ # feedforward
52
+ class GEGLU(nn.Module):
53
+ def __init__(self, dim_in, dim_out):
54
+ super().__init__()
55
+ self.proj = nn.Linear(dim_in, dim_out * 2)
56
+
57
+ def forward(self, x):
58
+ x, gate = self.proj(x).chunk(2, dim=-1)
59
+ return x * F.gelu(gate)
60
+
61
+
62
+ class FeedForward(nn.Module):
63
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
64
+ super().__init__()
65
+ inner_dim = int(dim * mult)
66
+ dim_out = default(dim_out, dim)
67
+ project_in = (
68
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
69
+ if not glu
70
+ else GEGLU(dim, inner_dim)
71
+ )
72
+
73
+ self.net = nn.Sequential(
74
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
75
+ )
76
+
77
+ def forward(self, x):
78
+ return self.net(x)
79
+
80
+
81
+ def zero_module(module):
82
+ """
83
+ Zero out the parameters of a module and return it.
84
+ """
85
+ for p in module.parameters():
86
+ p.detach().zero_()
87
+ return module
88
+
89
+
90
+ def Normalize(in_channels):
91
+ return torch.nn.GroupNorm(
92
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
93
+ )
94
+
95
+
96
+ class SpatialSelfAttention(nn.Module):
97
+ def __init__(self, in_channels):
98
+ super().__init__()
99
+ self.in_channels = in_channels
100
+
101
+ self.norm = Normalize(in_channels)
102
+ self.q = torch.nn.Conv2d(
103
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
104
+ )
105
+ self.k = torch.nn.Conv2d(
106
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
107
+ )
108
+ self.v = torch.nn.Conv2d(
109
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
110
+ )
111
+ self.proj_out = torch.nn.Conv2d(
112
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
113
+ )
114
+
115
+ def forward(self, x):
116
+ h_ = x
117
+ h_ = self.norm(h_)
118
+ q = self.q(h_)
119
+ k = self.k(h_)
120
+ v = self.v(h_)
121
+
122
+ # compute attention
123
+ b, c, h, w = q.shape
124
+ q = rearrange(q, "b c h w -> b (h w) c")
125
+ k = rearrange(k, "b c h w -> b c (h w)")
126
+ w_ = torch.einsum("bij,bjk->bik", q, k)
127
+
128
+ w_ = w_ * (int(c) ** (-0.5))
129
+ w_ = torch.nn.functional.softmax(w_, dim=2)
130
+
131
+ # attend to values
132
+ v = rearrange(v, "b c h w -> b c (h w)")
133
+ w_ = rearrange(w_, "b i j -> b j i")
134
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
135
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
136
+ h_ = self.proj_out(h_)
137
+
138
+ return x + h_
139
+
140
+
141
+ class MemoryEfficientCrossAttention(nn.Module):
142
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
143
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs):
144
+ super().__init__()
145
+ print(
146
+ f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
147
+ f"{heads} heads."
148
+ )
149
+ inner_dim = dim_head * heads
150
+ context_dim = default(context_dim, query_dim)
151
+
152
+ self.heads = heads
153
+ self.dim_head = dim_head
154
+
155
+ self.with_ip = kwargs.get("with_ip", False)
156
+ if self.with_ip and (context_dim is not None):
157
+ self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
158
+ self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
159
+ self.ip_dim = kwargs.get("ip_dim", 16)
160
+ self.ip_weight = kwargs.get("ip_weight", 1.0)
161
+
162
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
163
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
164
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
165
+
166
+ self.to_out = nn.Sequential(
167
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
168
+ )
169
+ self.attention_op: Optional[Any] = None
170
+
171
+ def forward(self, x, context=None, mask=None):
172
+ q = self.to_q(x)
173
+
174
+ has_ip = self.with_ip and (context is not None)
175
+ if has_ip:
176
+ # context dim [(b frame_num), (77 + img_token), 1024]
177
+ token_len = context.shape[1]
178
+ context_ip = context[:, -self.ip_dim:, :]
179
+ k_ip = self.to_k_ip(context_ip)
180
+ v_ip = self.to_v_ip(context_ip)
181
+ context = context[:, :(token_len - self.ip_dim), :]
182
+
183
+ context = default(context, x)
184
+ k = self.to_k(context)
185
+ v = self.to_v(context)
186
+
187
+ b, _, _ = q.shape
188
+ q, k, v = map(
189
+ lambda t: t.unsqueeze(3)
190
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
191
+ .permute(0, 2, 1, 3)
192
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
193
+ .contiguous(),
194
+ (q, k, v),
195
+ )
196
+
197
+ # actually compute the attention, what we cannot get enough of
198
+ out = xformers.ops.memory_efficient_attention(
199
+ q, k, v, attn_bias=None, op=self.attention_op
200
+ )
201
+
202
+ if has_ip:
203
+ k_ip, v_ip = map(
204
+ lambda t: t.unsqueeze(3)
205
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
206
+ .permute(0, 2, 1, 3)
207
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
208
+ .contiguous(),
209
+ (k_ip, v_ip),
210
+ )
211
+ # actually compute the attention, what we cannot get enough of
212
+ out_ip = xformers.ops.memory_efficient_attention(
213
+ q, k_ip, v_ip, attn_bias=None, op=self.attention_op
214
+ )
215
+ out = out + self.ip_weight * out_ip
216
+
217
+ if exists(mask):
218
+ raise NotImplementedError
219
+ out = (
220
+ out.unsqueeze(0)
221
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
222
+ .permute(0, 2, 1, 3)
223
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
224
+ )
225
+ return self.to_out(out)
226
+
227
+
228
+ class BasicTransformerBlock(nn.Module):
229
+ def __init__(
230
+ self,
231
+ dim,
232
+ n_heads,
233
+ d_head,
234
+ dropout=0.0,
235
+ context_dim=None,
236
+ gated_ff=True,
237
+ checkpoint=True,
238
+ disable_self_attn=False,
239
+ **kwargs
240
+ ):
241
+ super().__init__()
242
+ assert XFORMERS_IS_AVAILBLE, "xformers is not available"
243
+ attn_cls = MemoryEfficientCrossAttention
244
+ self.disable_self_attn = disable_self_attn
245
+ self.attn1 = attn_cls(
246
+ query_dim=dim,
247
+ heads=n_heads,
248
+ dim_head=d_head,
249
+ dropout=dropout,
250
+ context_dim=context_dim if self.disable_self_attn else None,
251
+ ) # is a self-attention if not self.disable_self_attn
252
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
253
+ self.attn2 = attn_cls(
254
+ query_dim=dim,
255
+ context_dim=context_dim,
256
+ heads=n_heads,
257
+ dim_head=d_head,
258
+ dropout=dropout,
259
+ **kwargs
260
+ ) # is self-attn if context is none
261
+ self.norm1 = nn.LayerNorm(dim)
262
+ self.norm2 = nn.LayerNorm(dim)
263
+ self.norm3 = nn.LayerNorm(dim)
264
+ self.checkpoint = checkpoint
265
+
266
+ def forward(self, x, context=None):
267
+ return checkpoint(
268
+ self._forward, (x, context), self.parameters(), self.checkpoint
269
+ )
270
+
271
+ def _forward(self, x, context=None):
272
+ x = (
273
+ self.attn1(
274
+ self.norm1(x), context=context if self.disable_self_attn else None
275
+ )
276
+ + x
277
+ )
278
+ x = self.attn2(self.norm2(x), context=context) + x
279
+ x = self.ff(self.norm3(x)) + x
280
+ return x
281
+
282
+
283
+ class SpatialTransformer(nn.Module):
284
+ """
285
+ Transformer block for image-like data.
286
+ First, project the input (aka embedding)
287
+ and reshape to b, t, d.
288
+ Then apply standard transformer action.
289
+ Finally, reshape to image
290
+ NEW: use_linear for more efficiency instead of the 1x1 convs
291
+ """
292
+
293
+ def __init__(
294
+ self,
295
+ in_channels,
296
+ n_heads,
297
+ d_head,
298
+ depth=1,
299
+ dropout=0.0,
300
+ context_dim=None,
301
+ disable_self_attn=False,
302
+ use_linear=False,
303
+ use_checkpoint=True,
304
+ **kwargs
305
+ ):
306
+ super().__init__()
307
+ if exists(context_dim) and not isinstance(context_dim, list):
308
+ context_dim = [context_dim]
309
+ self.in_channels = in_channels
310
+ inner_dim = n_heads * d_head
311
+ self.norm = Normalize(in_channels)
312
+ if not use_linear:
313
+ self.proj_in = nn.Conv2d(
314
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
315
+ )
316
+ else:
317
+ self.proj_in = nn.Linear(in_channels, inner_dim)
318
+
319
+ self.transformer_blocks = nn.ModuleList(
320
+ [
321
+ BasicTransformerBlock(
322
+ inner_dim,
323
+ n_heads,
324
+ d_head,
325
+ dropout=dropout,
326
+ context_dim=context_dim[d],
327
+ disable_self_attn=disable_self_attn,
328
+ checkpoint=use_checkpoint,
329
+ **kwargs
330
+ )
331
+ for d in range(depth)
332
+ ]
333
+ )
334
+ if not use_linear:
335
+ self.proj_out = zero_module(
336
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
337
+ )
338
+ else:
339
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
340
+ self.use_linear = use_linear
341
+
342
+ def forward(self, x, context=None):
343
+ # note: if no context is given, cross-attention defaults to self-attention
344
+ if not isinstance(context, list):
345
+ context = [context]
346
+ b, c, h, w = x.shape
347
+ x_in = x
348
+ x = self.norm(x)
349
+ if not self.use_linear:
350
+ x = self.proj_in(x)
351
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
352
+ if self.use_linear:
353
+ x = self.proj_in(x)
354
+ for i, block in enumerate(self.transformer_blocks):
355
+ x = block(x, context=context[i])
356
+ if self.use_linear:
357
+ x = self.proj_out(x)
358
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
359
+ if not self.use_linear:
360
+ x = self.proj_out(x)
361
+ return x + x_in
362
+
363
+
364
+ class BasicTransformerBlock3D(BasicTransformerBlock):
365
+ def forward(self, x, context=None, num_frames=1):
366
+ return checkpoint(
367
+ self._forward, (x, context, num_frames), self.parameters(), self.checkpoint
368
+ )
369
+
370
+ def _forward(self, x, context=None, num_frames=1):
371
+ x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
372
+ x = (
373
+ self.attn1(
374
+ self.norm1(x),
375
+ context=context if self.disable_self_attn else None
376
+ )
377
+ + x
378
+ )
379
+ x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
380
+ x = self.attn2(self.norm2(x), context=context) + x
381
+ x = self.ff(self.norm3(x)) + x
382
+ return x
383
+
384
+
385
+ class SpatialTransformer3D(nn.Module):
386
+ """3D self-attention"""
387
+
388
+ def __init__(
389
+ self,
390
+ in_channels,
391
+ n_heads,
392
+ d_head,
393
+ depth=1,
394
+ dropout=0.0,
395
+ context_dim=None,
396
+ disable_self_attn=False,
397
+ use_linear=False,
398
+ use_checkpoint=True,
399
+ **kwargs
400
+ ):
401
+ super().__init__()
402
+ if exists(context_dim) and not isinstance(context_dim, list):
403
+ context_dim = [context_dim]
404
+ self.in_channels = in_channels
405
+ inner_dim = n_heads * d_head
406
+ self.norm = Normalize(in_channels)
407
+ if not use_linear:
408
+ self.proj_in = nn.Conv2d(
409
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
410
+ )
411
+ else:
412
+ self.proj_in = nn.Linear(in_channels, inner_dim)
413
+
414
+ self.transformer_blocks = nn.ModuleList(
415
+ [
416
+ BasicTransformerBlock3D(
417
+ inner_dim,
418
+ n_heads,
419
+ d_head,
420
+ dropout=dropout,
421
+ context_dim=context_dim[d],
422
+ disable_self_attn=disable_self_attn,
423
+ checkpoint=use_checkpoint,
424
+ **kwargs
425
+ )
426
+ for d in range(depth)
427
+ ]
428
+ )
429
+ if not use_linear:
430
+ self.proj_out = zero_module(
431
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
432
+ )
433
+ else:
434
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
435
+ self.use_linear = use_linear
436
+
437
+ def forward(self, x, context=None, num_frames=1):
438
+ # note: if no context is given, cross-attention defaults to self-attention
439
+ if not isinstance(context, list):
440
+ context = [context]
441
+ b, c, h, w = x.shape
442
+ x_in = x
443
+ x = self.norm(x)
444
+ if not self.use_linear:
445
+ x = self.proj_in(x)
446
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
447
+ if self.use_linear:
448
+ x = self.proj_in(x)
449
+ for i, block in enumerate(self.transformer_blocks):
450
+ x = block(x, context=context[i], num_frames=num_frames)
451
+ if self.use_linear:
452
+ x = self.proj_out(x)
453
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
454
+ if not self.use_linear:
455
+ x = self.proj_out(x)
456
+ return x + x_in
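Two details of the attention code above are easy to miss: when MemoryEfficientCrossAttention is built with with_ip=True, the last ip_dim tokens of the context are split off as image-prompt tokens, routed through to_k_ip/to_v_ip, and blended back with ip_weight; and BasicTransformerBlock3D folds the frame axis into the token axis so self-attention runs jointly across all views. A shape-only sketch of both manipulations follows; the tensor sizes are illustrative assumptions, not values taken from a config.

# Shape-only sketch of the IP-token split and the multi-view token fold used above.
import torch
from einops import rearrange

b, frames, ip_dim = 2, 4, 16

# context carries 77 text tokens plus ip_dim image-prompt tokens per (batch, frame)
context = torch.randn(b * frames, 77 + ip_dim, 1024)
context_ip = context[:, -ip_dim:, :]    # goes through to_k_ip / to_v_ip
context_txt = context[:, :-ip_dim, :]   # goes through to_k / to_v
print(context_txt.shape, context_ip.shape)  # (8, 77, 1024) (8, 16, 1024)

# 3D self-attention: merge frames into the sequence so the views attend to each other
x = torch.randn(b * frames, 1024, 320)      # (batch*frames, h*w, channels)
x = rearrange(x, "(b f) l c -> b (f l) c", f=frames)
print(x.shape)                              # (2, 4096, 320)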
imagedream/ldm/modules/diffusionmodules/__init__.py ADDED
File without changes
imagedream/ldm/modules/diffusionmodules/adaptors.py ADDED
@@ -0,0 +1,163 @@
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+
8
+ # FFN
9
+ def FeedForward(dim, mult=4):
10
+ inner_dim = int(dim * mult)
11
+ return nn.Sequential(
12
+ nn.LayerNorm(dim),
13
+ nn.Linear(dim, inner_dim, bias=False),
14
+ nn.GELU(),
15
+ nn.Linear(inner_dim, dim, bias=False),
16
+ )
17
+
18
+
19
+ def reshape_tensor(x, heads):
20
+ bs, length, width = x.shape
21
+ #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
22
+ x = x.view(bs, length, heads, -1)
23
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
24
+ x = x.transpose(1, 2)
25
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
26
+ x = x.reshape(bs, heads, length, -1)
27
+ return x
28
+
29
+
30
+ class PerceiverAttention(nn.Module):
31
+ def __init__(self, *, dim, dim_head=64, heads=8):
32
+ super().__init__()
33
+ self.scale = dim_head**-0.5
34
+ self.dim_head = dim_head
35
+ self.heads = heads
36
+ inner_dim = dim_head * heads
37
+
38
+ self.norm1 = nn.LayerNorm(dim)
39
+ self.norm2 = nn.LayerNorm(dim)
40
+
41
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
42
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
43
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
44
+
45
+
46
+ def forward(self, x, latents):
47
+ """
48
+ Args:
49
+ x (torch.Tensor): image features
50
+ shape (b, n1, D)
51
+ latent (torch.Tensor): latent features
52
+ shape (b, n2, D)
53
+ """
54
+ x = self.norm1(x)
55
+ latents = self.norm2(latents)
56
+
57
+ b, l, _ = latents.shape
58
+
59
+ q = self.to_q(latents)
60
+ kv_input = torch.cat((x, latents), dim=-2)
61
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
62
+
63
+ q = reshape_tensor(q, self.heads)
64
+ k = reshape_tensor(k, self.heads)
65
+ v = reshape_tensor(v, self.heads)
66
+
67
+ # attention
68
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
69
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
70
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
71
+ out = weight @ v
72
+
73
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
74
+
75
+ return self.to_out(out)
76
+
77
+
78
+ class ImageProjModel(torch.nn.Module):
79
+ """Projection Model"""
80
+ def __init__(self,
81
+ cross_attention_dim=1024,
82
+ clip_embeddings_dim=1024,
83
+ clip_extra_context_tokens=4):
84
+ super().__init__()
85
+ self.cross_attention_dim = cross_attention_dim
86
+ self.clip_extra_context_tokens = clip_extra_context_tokens
87
+
88
+ # from 1024 -> 4 * 1024
89
+ self.proj = torch.nn.Linear(
90
+ clip_embeddings_dim,
91
+ self.clip_extra_context_tokens * cross_attention_dim)
92
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
93
+
94
+ def forward(self, image_embeds):
95
+ embeds = image_embeds
96
+ clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
97
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
98
+ return clip_extra_context_tokens
99
+
100
+
101
+ class SimpleReSampler(nn.Module):
102
+ def __init__(self, embedding_dim=1280, output_dim=1024):
103
+ super().__init__()
104
+ self.proj_out = nn.Linear(embedding_dim, output_dim)
105
+ self.norm_out = nn.LayerNorm(output_dim)
106
+
107
+ def forward(self, latents):
108
+ """
109
+ latents: B 256 N
110
+ """
111
+ latents = self.proj_out(latents)
112
+ return self.norm_out(latents)
113
+
114
+
115
+ class Resampler(nn.Module):
116
+ def __init__(
117
+ self,
118
+ dim=1024,
119
+ depth=8,
120
+ dim_head=64,
121
+ heads=16,
122
+ num_queries=8,
123
+ embedding_dim=768,
124
+ output_dim=1024,
125
+ ff_mult=4,
126
+ ):
127
+ super().__init__()
128
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
129
+ self.proj_in = nn.Linear(embedding_dim, dim)
130
+ self.proj_out = nn.Linear(dim, output_dim)
131
+ self.norm_out = nn.LayerNorm(output_dim)
132
+
133
+ self.layers = nn.ModuleList([])
134
+ for _ in range(depth):
135
+ self.layers.append(
136
+ nn.ModuleList(
137
+ [
138
+ PerceiverAttention(dim=dim,
139
+ dim_head=dim_head,
140
+ heads=heads),
141
+ FeedForward(dim=dim, mult=ff_mult),
142
+ ]
143
+ )
144
+ )
145
+
146
+ def forward(self, x):
147
+ latents = self.latents.repeat(x.size(0), 1, 1)
148
+ x = self.proj_in(x)
149
+ for attn, ff in self.layers:
150
+ latents = attn(x, latents) + latents
151
+ latents = ff(latents) + latents
152
+
153
+ latents = self.proj_out(latents)
154
+ return self.norm_out(latents)
155
+
156
+
157
+ if __name__ == '__main__':
158
+ resampler = Resampler(embedding_dim=1280)
159
+ resampler = SimpleReSampler(embedding_dim=1280)
160
+ tensor = torch.rand(4, 257, 1280)
161
+ embed = resampler(tensor)
162
+ # embed = (tensor)
163
+ print(embed.shape)
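As a quick usage note for the adaptors above, ImageProjModel expands one pooled CLIP image embedding into a handful of extra cross-attention context tokens (the "from 1024 -> 4 * 1024" projection). The dimensions below are assumed defaults, not values read from this repository's configs.

# Minimal sketch using the ImageProjModel class defined above.
import torch

proj = ImageProjModel(cross_attention_dim=1024,
                      clip_embeddings_dim=1024,
                      clip_extra_context_tokens=4)
image_embeds = torch.randn(2, 1024)  # pooled CLIP image embedding per sample
tokens = proj(image_embeds)          # (2, 4, 1024): tokens appended to the text context
print(tokens.shape)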
imagedream/ldm/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,1018 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from einops import rearrange
7
+ from typing import Optional, Any
8
+
9
+ from ..attention import MemoryEfficientCrossAttention
10
+
11
+ try:
12
+ import xformers
13
+ import xformers.ops
14
+
15
+ XFORMERS_IS_AVAILBLE = True
16
+ except ImportError:
17
+ XFORMERS_IS_AVAILBLE = False
18
+ print("No module 'xformers'. Proceeding without it.")
19
+
20
+
21
+ def get_timestep_embedding(timesteps, embedding_dim):
22
+ """
23
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
24
+ From Fairseq.
25
+ Build sinusoidal embeddings.
26
+ This matches the implementation in tensor2tensor, but differs slightly
27
+ from the description in Section 3.5 of "Attention Is All You Need".
28
+ """
29
+ assert len(timesteps.shape) == 1
30
+
31
+ half_dim = embedding_dim // 2
32
+ emb = math.log(10000) / (half_dim - 1)
33
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
34
+ emb = emb.to(device=timesteps.device)
35
+ emb = timesteps.float()[:, None] * emb[None, :]
36
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
37
+ if embedding_dim % 2 == 1: # zero pad
38
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
39
+ return emb
40
+
41
+
42
+ def nonlinearity(x):
43
+ # swish
44
+ return x * torch.sigmoid(x)
45
+
46
+
47
+ def Normalize(in_channels, num_groups=32):
48
+ return torch.nn.GroupNorm(
49
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
50
+ )
51
+
52
+
53
+ class Upsample(nn.Module):
54
+ def __init__(self, in_channels, with_conv):
55
+ super().__init__()
56
+ self.with_conv = with_conv
57
+ if self.with_conv:
58
+ self.conv = torch.nn.Conv2d(
59
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
60
+ )
61
+
62
+ def forward(self, x):
63
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
64
+ if self.with_conv:
65
+ x = self.conv(x)
66
+ return x
67
+
68
+
69
+ class Downsample(nn.Module):
70
+ def __init__(self, in_channels, with_conv):
71
+ super().__init__()
72
+ self.with_conv = with_conv
73
+ if self.with_conv:
74
+ # no asymmetric padding in torch conv, must do it ourselves
75
+ self.conv = torch.nn.Conv2d(
76
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
77
+ )
78
+
79
+ def forward(self, x):
80
+ if self.with_conv:
81
+ pad = (0, 1, 0, 1)
82
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
83
+ x = self.conv(x)
84
+ else:
85
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
86
+ return x
87
+
88
+
89
+ class ResnetBlock(nn.Module):
90
+ def __init__(
91
+ self,
92
+ *,
93
+ in_channels,
94
+ out_channels=None,
95
+ conv_shortcut=False,
96
+ dropout,
97
+ temb_channels=512,
98
+ ):
99
+ super().__init__()
100
+ self.in_channels = in_channels
101
+ out_channels = in_channels if out_channels is None else out_channels
102
+ self.out_channels = out_channels
103
+ self.use_conv_shortcut = conv_shortcut
104
+
105
+ self.norm1 = Normalize(in_channels)
106
+ self.conv1 = torch.nn.Conv2d(
107
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
108
+ )
109
+ if temb_channels > 0:
110
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
111
+ self.norm2 = Normalize(out_channels)
112
+ self.dropout = torch.nn.Dropout(dropout)
113
+ self.conv2 = torch.nn.Conv2d(
114
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
115
+ )
116
+ if self.in_channels != self.out_channels:
117
+ if self.use_conv_shortcut:
118
+ self.conv_shortcut = torch.nn.Conv2d(
119
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
120
+ )
121
+ else:
122
+ self.nin_shortcut = torch.nn.Conv2d(
123
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
124
+ )
125
+
126
+ def forward(self, x, temb):
127
+ h = x
128
+ h = self.norm1(h)
129
+ h = nonlinearity(h)
130
+ h = self.conv1(h)
131
+
132
+ if temb is not None:
133
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
134
+
135
+ h = self.norm2(h)
136
+ h = nonlinearity(h)
137
+ h = self.dropout(h)
138
+ h = self.conv2(h)
139
+
140
+ if self.in_channels != self.out_channels:
141
+ if self.use_conv_shortcut:
142
+ x = self.conv_shortcut(x)
143
+ else:
144
+ x = self.nin_shortcut(x)
145
+
146
+ return x + h
147
+
148
+
149
+ class AttnBlock(nn.Module):
150
+ def __init__(self, in_channels):
151
+ super().__init__()
152
+ self.in_channels = in_channels
153
+
154
+ self.norm = Normalize(in_channels)
155
+ self.q = torch.nn.Conv2d(
156
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
157
+ )
158
+ self.k = torch.nn.Conv2d(
159
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
160
+ )
161
+ self.v = torch.nn.Conv2d(
162
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
163
+ )
164
+ self.proj_out = torch.nn.Conv2d(
165
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
166
+ )
167
+
168
+ def forward(self, x):
169
+ h_ = x
170
+ h_ = self.norm(h_)
171
+ q = self.q(h_)
172
+ k = self.k(h_)
173
+ v = self.v(h_)
174
+
175
+ # compute attention
176
+ b, c, h, w = q.shape
177
+ q = q.reshape(b, c, h * w)
178
+ q = q.permute(0, 2, 1) # b,hw,c
179
+ k = k.reshape(b, c, h * w) # b,c,hw
180
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
181
+ w_ = w_ * (int(c) ** (-0.5))
182
+ w_ = torch.nn.functional.softmax(w_, dim=2)
183
+
184
+ # attend to values
185
+ v = v.reshape(b, c, h * w)
186
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
187
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
188
+ h_ = h_.reshape(b, c, h, w)
189
+
190
+ h_ = self.proj_out(h_)
191
+
192
+ return x + h_
193
+
194
+
195
+ class MemoryEfficientAttnBlock(nn.Module):
196
+ """
197
+ Uses xformers efficient implementation,
198
+ see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
199
+ Note: this is a single-head self-attention operation
200
+ """
201
+
202
+ #
203
+ def __init__(self, in_channels):
204
+ super().__init__()
205
+ self.in_channels = in_channels
206
+
207
+ self.norm = Normalize(in_channels)
208
+ self.q = torch.nn.Conv2d(
209
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
210
+ )
211
+ self.k = torch.nn.Conv2d(
212
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
213
+ )
214
+ self.v = torch.nn.Conv2d(
215
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
216
+ )
217
+ self.proj_out = torch.nn.Conv2d(
218
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
219
+ )
220
+ self.attention_op: Optional[Any] = None
221
+
222
+ def forward(self, x):
223
+ h_ = x
224
+ h_ = self.norm(h_)
225
+ q = self.q(h_)
226
+ k = self.k(h_)
227
+ v = self.v(h_)
228
+
229
+ # compute attention
230
+ B, C, H, W = q.shape
231
+ q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
232
+
233
+ q, k, v = map(
234
+ lambda t: t.unsqueeze(3)
235
+ .reshape(B, t.shape[1], 1, C)
236
+ .permute(0, 2, 1, 3)
237
+ .reshape(B * 1, t.shape[1], C)
238
+ .contiguous(),
239
+ (q, k, v),
240
+ )
241
+ out = xformers.ops.memory_efficient_attention(
242
+ q, k, v, attn_bias=None, op=self.attention_op
243
+ )
244
+
245
+ out = (
246
+ out.unsqueeze(0)
247
+ .reshape(B, 1, out.shape[1], C)
248
+ .permute(0, 2, 1, 3)
249
+ .reshape(B, out.shape[1], C)
250
+ )
251
+ out = rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
252
+ out = self.proj_out(out)
253
+ return x + out
254
+
255
+
256
+ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
257
+ def forward(self, x, context=None, mask=None):
258
+ b, c, h, w = x.shape
259
+ x = rearrange(x, "b c h w -> b (h w) c")
260
+ out = super().forward(x, context=context, mask=mask)
261
+ out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
262
+ return x + out
263
+
264
+
265
+ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
266
+ assert attn_type in [
267
+ "vanilla",
268
+ "vanilla-xformers",
269
+ "memory-efficient-cross-attn",
270
+ "linear",
271
+ "none",
272
+ ], f"attn_type {attn_type} unknown"
273
+ if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
274
+ attn_type = "vanilla-xformers"
275
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
276
+ if attn_type == "vanilla":
277
+ assert attn_kwargs is None
278
+ return AttnBlock(in_channels)
279
+ elif attn_type == "vanilla-xformers":
280
+ print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
281
+ return MemoryEfficientAttnBlock(in_channels)
282
+ elif attn_type == "memory-efficient-cross-attn":
283
+ attn_kwargs["query_dim"] = in_channels
284
+ return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
285
+ elif attn_type == "none":
286
+ return nn.Identity(in_channels)
287
+ else:
288
+ raise NotImplementedError()
289
+
290
+
291
+ class Model(nn.Module):
292
+ def __init__(
293
+ self,
294
+ *,
295
+ ch,
296
+ out_ch,
297
+ ch_mult=(1, 2, 4, 8),
298
+ num_res_blocks,
299
+ attn_resolutions,
300
+ dropout=0.0,
301
+ resamp_with_conv=True,
302
+ in_channels,
303
+ resolution,
304
+ use_timestep=True,
305
+ use_linear_attn=False,
306
+ attn_type="vanilla",
307
+ ):
308
+ super().__init__()
309
+ if use_linear_attn:
310
+ attn_type = "linear"
311
+ self.ch = ch
312
+ self.temb_ch = self.ch * 4
313
+ self.num_resolutions = len(ch_mult)
314
+ self.num_res_blocks = num_res_blocks
315
+ self.resolution = resolution
316
+ self.in_channels = in_channels
317
+
318
+ self.use_timestep = use_timestep
319
+ if self.use_timestep:
320
+ # timestep embedding
321
+ self.temb = nn.Module()
322
+ self.temb.dense = nn.ModuleList(
323
+ [
324
+ torch.nn.Linear(self.ch, self.temb_ch),
325
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
326
+ ]
327
+ )
328
+
329
+ # downsampling
330
+ self.conv_in = torch.nn.Conv2d(
331
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
332
+ )
333
+
334
+ curr_res = resolution
335
+ in_ch_mult = (1,) + tuple(ch_mult)
336
+ self.down = nn.ModuleList()
337
+ for i_level in range(self.num_resolutions):
338
+ block = nn.ModuleList()
339
+ attn = nn.ModuleList()
340
+ block_in = ch * in_ch_mult[i_level]
341
+ block_out = ch * ch_mult[i_level]
342
+ for i_block in range(self.num_res_blocks):
343
+ block.append(
344
+ ResnetBlock(
345
+ in_channels=block_in,
346
+ out_channels=block_out,
347
+ temb_channels=self.temb_ch,
348
+ dropout=dropout,
349
+ )
350
+ )
351
+ block_in = block_out
352
+ if curr_res in attn_resolutions:
353
+ attn.append(make_attn(block_in, attn_type=attn_type))
354
+ down = nn.Module()
355
+ down.block = block
356
+ down.attn = attn
357
+ if i_level != self.num_resolutions - 1:
358
+ down.downsample = Downsample(block_in, resamp_with_conv)
359
+ curr_res = curr_res // 2
360
+ self.down.append(down)
361
+
362
+ # middle
363
+ self.mid = nn.Module()
364
+ self.mid.block_1 = ResnetBlock(
365
+ in_channels=block_in,
366
+ out_channels=block_in,
367
+ temb_channels=self.temb_ch,
368
+ dropout=dropout,
369
+ )
370
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
371
+ self.mid.block_2 = ResnetBlock(
372
+ in_channels=block_in,
373
+ out_channels=block_in,
374
+ temb_channels=self.temb_ch,
375
+ dropout=dropout,
376
+ )
377
+
378
+ # upsampling
379
+ self.up = nn.ModuleList()
380
+ for i_level in reversed(range(self.num_resolutions)):
381
+ block = nn.ModuleList()
382
+ attn = nn.ModuleList()
383
+ block_out = ch * ch_mult[i_level]
384
+ skip_in = ch * ch_mult[i_level]
385
+ for i_block in range(self.num_res_blocks + 1):
386
+ if i_block == self.num_res_blocks:
387
+ skip_in = ch * in_ch_mult[i_level]
388
+ block.append(
389
+ ResnetBlock(
390
+ in_channels=block_in + skip_in,
391
+ out_channels=block_out,
392
+ temb_channels=self.temb_ch,
393
+ dropout=dropout,
394
+ )
395
+ )
396
+ block_in = block_out
397
+ if curr_res in attn_resolutions:
398
+ attn.append(make_attn(block_in, attn_type=attn_type))
399
+ up = nn.Module()
400
+ up.block = block
401
+ up.attn = attn
402
+ if i_level != 0:
403
+ up.upsample = Upsample(block_in, resamp_with_conv)
404
+ curr_res = curr_res * 2
405
+ self.up.insert(0, up) # prepend to get consistent order
406
+
407
+ # end
408
+ self.norm_out = Normalize(block_in)
409
+ self.conv_out = torch.nn.Conv2d(
410
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
411
+ )
412
+
413
+ def forward(self, x, t=None, context=None):
414
+ # assert x.shape[2] == x.shape[3] == self.resolution
415
+ if context is not None:
416
+ # assume aligned context, cat along channel axis
417
+ x = torch.cat((x, context), dim=1)
418
+ if self.use_timestep:
419
+ # timestep embedding
420
+ assert t is not None
421
+ temb = get_timestep_embedding(t, self.ch)
422
+ temb = self.temb.dense[0](temb)
423
+ temb = nonlinearity(temb)
424
+ temb = self.temb.dense[1](temb)
425
+ else:
426
+ temb = None
427
+
428
+ # downsampling
429
+ hs = [self.conv_in(x)]
430
+ for i_level in range(self.num_resolutions):
431
+ for i_block in range(self.num_res_blocks):
432
+ h = self.down[i_level].block[i_block](hs[-1], temb)
433
+ if len(self.down[i_level].attn) > 0:
434
+ h = self.down[i_level].attn[i_block](h)
435
+ hs.append(h)
436
+ if i_level != self.num_resolutions - 1:
437
+ hs.append(self.down[i_level].downsample(hs[-1]))
438
+
439
+ # middle
440
+ h = hs[-1]
441
+ h = self.mid.block_1(h, temb)
442
+ h = self.mid.attn_1(h)
443
+ h = self.mid.block_2(h, temb)
444
+
445
+ # upsampling
446
+ for i_level in reversed(range(self.num_resolutions)):
447
+ for i_block in range(self.num_res_blocks + 1):
448
+ h = self.up[i_level].block[i_block](
449
+ torch.cat([h, hs.pop()], dim=1), temb
450
+ )
451
+ if len(self.up[i_level].attn) > 0:
452
+ h = self.up[i_level].attn[i_block](h)
453
+ if i_level != 0:
454
+ h = self.up[i_level].upsample(h)
455
+
456
+ # end
457
+ h = self.norm_out(h)
458
+ h = nonlinearity(h)
459
+ h = self.conv_out(h)
460
+ return h
461
+
462
+ def get_last_layer(self):
463
+ return self.conv_out.weight
464
+
465
+
466
+ class Encoder(nn.Module):
467
+ def __init__(
468
+ self,
469
+ *,
470
+ ch,
471
+ out_ch,
472
+ ch_mult=(1, 2, 4, 8),
473
+ num_res_blocks,
474
+ attn_resolutions,
475
+ dropout=0.0,
476
+ resamp_with_conv=True,
477
+ in_channels,
478
+ resolution,
479
+ z_channels,
480
+ double_z=True,
481
+ use_linear_attn=False,
482
+ attn_type="vanilla",
483
+ **ignore_kwargs,
484
+ ):
485
+ super().__init__()
486
+ if use_linear_attn:
487
+ attn_type = "linear"
488
+ self.ch = ch
489
+ self.temb_ch = 0
490
+ self.num_resolutions = len(ch_mult)
491
+ self.num_res_blocks = num_res_blocks
492
+ self.resolution = resolution
493
+ self.in_channels = in_channels
494
+
495
+ # downsampling
496
+ self.conv_in = torch.nn.Conv2d(
497
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
498
+ )
499
+
500
+ curr_res = resolution
501
+ in_ch_mult = (1,) + tuple(ch_mult)
502
+ self.in_ch_mult = in_ch_mult
503
+ self.down = nn.ModuleList()
504
+ for i_level in range(self.num_resolutions):
505
+ block = nn.ModuleList()
506
+ attn = nn.ModuleList()
507
+ block_in = ch * in_ch_mult[i_level]
508
+ block_out = ch * ch_mult[i_level]
509
+ for i_block in range(self.num_res_blocks):
510
+ block.append(
511
+ ResnetBlock(
512
+ in_channels=block_in,
513
+ out_channels=block_out,
514
+ temb_channels=self.temb_ch,
515
+ dropout=dropout,
516
+ )
517
+ )
518
+ block_in = block_out
519
+ if curr_res in attn_resolutions:
520
+ attn.append(make_attn(block_in, attn_type=attn_type))
521
+ down = nn.Module()
522
+ down.block = block
523
+ down.attn = attn
524
+ if i_level != self.num_resolutions - 1:
525
+ down.downsample = Downsample(block_in, resamp_with_conv)
526
+ curr_res = curr_res // 2
527
+ self.down.append(down)
528
+
529
+ # middle
530
+ self.mid = nn.Module()
531
+ self.mid.block_1 = ResnetBlock(
532
+ in_channels=block_in,
533
+ out_channels=block_in,
534
+ temb_channels=self.temb_ch,
535
+ dropout=dropout,
536
+ )
537
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
538
+ self.mid.block_2 = ResnetBlock(
539
+ in_channels=block_in,
540
+ out_channels=block_in,
541
+ temb_channels=self.temb_ch,
542
+ dropout=dropout,
543
+ )
544
+
545
+ # end
546
+ self.norm_out = Normalize(block_in)
547
+ self.conv_out = torch.nn.Conv2d(
548
+ block_in,
549
+ 2 * z_channels if double_z else z_channels,
550
+ kernel_size=3,
551
+ stride=1,
552
+ padding=1,
553
+ )
554
+
555
+ def forward(self, x):
556
+ # timestep embedding
557
+ temb = None
558
+
559
+ # downsampling
560
+ hs = [self.conv_in(x)]
561
+ for i_level in range(self.num_resolutions):
562
+ for i_block in range(self.num_res_blocks):
563
+ h = self.down[i_level].block[i_block](hs[-1], temb)
564
+ if len(self.down[i_level].attn) > 0:
565
+ h = self.down[i_level].attn[i_block](h)
566
+ hs.append(h)
567
+ if i_level != self.num_resolutions - 1:
568
+ hs.append(self.down[i_level].downsample(hs[-1]))
569
+
570
+ # middle
571
+ h = hs[-1]
572
+ h = self.mid.block_1(h, temb)
573
+ h = self.mid.attn_1(h)
574
+ h = self.mid.block_2(h, temb)
575
+
576
+ # end
577
+ h = self.norm_out(h)
578
+ h = nonlinearity(h)
579
+ h = self.conv_out(h)
580
+ return h
581
+
582
+
583
+ class Decoder(nn.Module):
584
+ def __init__(
585
+ self,
586
+ *,
587
+ ch,
588
+ out_ch,
589
+ ch_mult=(1, 2, 4, 8),
590
+ num_res_blocks,
591
+ attn_resolutions,
592
+ dropout=0.0,
593
+ resamp_with_conv=True,
594
+ in_channels,
595
+ resolution,
596
+ z_channels,
597
+ give_pre_end=False,
598
+ tanh_out=False,
599
+ use_linear_attn=False,
600
+ attn_type="vanilla",
601
+ **ignorekwargs,
602
+ ):
603
+ super().__init__()
604
+ if use_linear_attn:
605
+ attn_type = "linear"
606
+ self.ch = ch
607
+ self.temb_ch = 0
608
+ self.num_resolutions = len(ch_mult)
609
+ self.num_res_blocks = num_res_blocks
610
+ self.resolution = resolution
611
+ self.in_channels = in_channels
612
+ self.give_pre_end = give_pre_end
613
+ self.tanh_out = tanh_out
614
+
615
+ # compute in_ch_mult, block_in and curr_res at lowest res
616
+ in_ch_mult = (1,) + tuple(ch_mult)
617
+ block_in = ch * ch_mult[self.num_resolutions - 1]
618
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
619
+ self.z_shape = (1, z_channels, curr_res, curr_res)
620
+ print(
621
+ "Working with z of shape {} = {} dimensions.".format(
622
+ self.z_shape, np.prod(self.z_shape)
623
+ )
624
+ )
625
+
626
+ # z to block_in
627
+ self.conv_in = torch.nn.Conv2d(
628
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
629
+ )
630
+
631
+ # middle
632
+ self.mid = nn.Module()
633
+ self.mid.block_1 = ResnetBlock(
634
+ in_channels=block_in,
635
+ out_channels=block_in,
636
+ temb_channels=self.temb_ch,
637
+ dropout=dropout,
638
+ )
639
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
640
+ self.mid.block_2 = ResnetBlock(
641
+ in_channels=block_in,
642
+ out_channels=block_in,
643
+ temb_channels=self.temb_ch,
644
+ dropout=dropout,
645
+ )
646
+
647
+ # upsampling
648
+ self.up = nn.ModuleList()
649
+ for i_level in reversed(range(self.num_resolutions)):
650
+ block = nn.ModuleList()
651
+ attn = nn.ModuleList()
652
+ block_out = ch * ch_mult[i_level]
653
+ for i_block in range(self.num_res_blocks + 1):
654
+ block.append(
655
+ ResnetBlock(
656
+ in_channels=block_in,
657
+ out_channels=block_out,
658
+ temb_channels=self.temb_ch,
659
+ dropout=dropout,
660
+ )
661
+ )
662
+ block_in = block_out
663
+ if curr_res in attn_resolutions:
664
+ attn.append(make_attn(block_in, attn_type=attn_type))
665
+ up = nn.Module()
666
+ up.block = block
667
+ up.attn = attn
668
+ if i_level != 0:
669
+ up.upsample = Upsample(block_in, resamp_with_conv)
670
+ curr_res = curr_res * 2
671
+ self.up.insert(0, up) # prepend to get consistent order
672
+
673
+ # end
674
+ self.norm_out = Normalize(block_in)
675
+ self.conv_out = torch.nn.Conv2d(
676
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
677
+ )
678
+
679
+ def forward(self, z):
680
+ # assert z.shape[1:] == self.z_shape[1:]
681
+ self.last_z_shape = z.shape
682
+
683
+ # timestep embedding
684
+ temb = None
685
+
686
+ # z to block_in
687
+ h = self.conv_in(z)
688
+
689
+ # middle
690
+ h = self.mid.block_1(h, temb)
691
+ h = self.mid.attn_1(h)
692
+ h = self.mid.block_2(h, temb)
693
+
694
+ # upsampling
695
+ for i_level in reversed(range(self.num_resolutions)):
696
+ for i_block in range(self.num_res_blocks + 1):
697
+ h = self.up[i_level].block[i_block](h, temb)
698
+ if len(self.up[i_level].attn) > 0:
699
+ h = self.up[i_level].attn[i_block](h)
700
+ if i_level != 0:
701
+ h = self.up[i_level].upsample(h)
702
+
703
+ # end
704
+ if self.give_pre_end:
705
+ return h
706
+
707
+ h = self.norm_out(h)
708
+ h = nonlinearity(h)
709
+ h = self.conv_out(h)
710
+ if self.tanh_out:
711
+ h = torch.tanh(h)
712
+ return h
713
+
714
+
715
+ class SimpleDecoder(nn.Module):
716
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
717
+ super().__init__()
718
+ self.model = nn.ModuleList(
719
+ [
720
+ nn.Conv2d(in_channels, in_channels, 1),
721
+ ResnetBlock(
722
+ in_channels=in_channels,
723
+ out_channels=2 * in_channels,
724
+ temb_channels=0,
725
+ dropout=0.0,
726
+ ),
727
+ ResnetBlock(
728
+ in_channels=2 * in_channels,
729
+ out_channels=4 * in_channels,
730
+ temb_channels=0,
731
+ dropout=0.0,
732
+ ),
733
+ ResnetBlock(
734
+ in_channels=4 * in_channels,
735
+ out_channels=2 * in_channels,
736
+ temb_channels=0,
737
+ dropout=0.0,
738
+ ),
739
+ nn.Conv2d(2 * in_channels, in_channels, 1),
740
+ Upsample(in_channels, with_conv=True),
741
+ ]
742
+ )
743
+ # end
744
+ self.norm_out = Normalize(in_channels)
745
+ self.conv_out = torch.nn.Conv2d(
746
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
747
+ )
748
+
749
+ def forward(self, x):
750
+ for i, layer in enumerate(self.model):
751
+ if i in [1, 2, 3]:
752
+ x = layer(x, None)
753
+ else:
754
+ x = layer(x)
755
+
756
+ h = self.norm_out(x)
757
+ h = nonlinearity(h)
758
+ x = self.conv_out(h)
759
+ return x
760
+
761
+
762
+ class UpsampleDecoder(nn.Module):
763
+ def __init__(
764
+ self,
765
+ in_channels,
766
+ out_channels,
767
+ ch,
768
+ num_res_blocks,
769
+ resolution,
770
+ ch_mult=(2, 2),
771
+ dropout=0.0,
772
+ ):
773
+ super().__init__()
774
+ # upsampling
775
+ self.temb_ch = 0
776
+ self.num_resolutions = len(ch_mult)
777
+ self.num_res_blocks = num_res_blocks
778
+ block_in = in_channels
779
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
780
+ self.res_blocks = nn.ModuleList()
781
+ self.upsample_blocks = nn.ModuleList()
782
+ for i_level in range(self.num_resolutions):
783
+ res_block = []
784
+ block_out = ch * ch_mult[i_level]
785
+ for i_block in range(self.num_res_blocks + 1):
786
+ res_block.append(
787
+ ResnetBlock(
788
+ in_channels=block_in,
789
+ out_channels=block_out,
790
+ temb_channels=self.temb_ch,
791
+ dropout=dropout,
792
+ )
793
+ )
794
+ block_in = block_out
795
+ self.res_blocks.append(nn.ModuleList(res_block))
796
+ if i_level != self.num_resolutions - 1:
797
+ self.upsample_blocks.append(Upsample(block_in, True))
798
+ curr_res = curr_res * 2
799
+
800
+ # end
801
+ self.norm_out = Normalize(block_in)
802
+ self.conv_out = torch.nn.Conv2d(
803
+ block_in, out_channels, kernel_size=3, stride=1, padding=1
804
+ )
805
+
806
+ def forward(self, x):
807
+ # upsampling
808
+ h = x
809
+ for k, i_level in enumerate(range(self.num_resolutions)):
810
+ for i_block in range(self.num_res_blocks + 1):
811
+ h = self.res_blocks[i_level][i_block](h, None)
812
+ if i_level != self.num_resolutions - 1:
813
+ h = self.upsample_blocks[k](h)
814
+ h = self.norm_out(h)
815
+ h = nonlinearity(h)
816
+ h = self.conv_out(h)
817
+ return h
818
+
819
+
820
+ class LatentRescaler(nn.Module):
821
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
822
+ super().__init__()
823
+ # residual block, interpolate, residual block
824
+ self.factor = factor
825
+ self.conv_in = nn.Conv2d(
826
+ in_channels, mid_channels, kernel_size=3, stride=1, padding=1
827
+ )
828
+ self.res_block1 = nn.ModuleList(
829
+ [
830
+ ResnetBlock(
831
+ in_channels=mid_channels,
832
+ out_channels=mid_channels,
833
+ temb_channels=0,
834
+ dropout=0.0,
835
+ )
836
+ for _ in range(depth)
837
+ ]
838
+ )
839
+ self.attn = AttnBlock(mid_channels)
840
+ self.res_block2 = nn.ModuleList(
841
+ [
842
+ ResnetBlock(
843
+ in_channels=mid_channels,
844
+ out_channels=mid_channels,
845
+ temb_channels=0,
846
+ dropout=0.0,
847
+ )
848
+ for _ in range(depth)
849
+ ]
850
+ )
851
+
852
+ self.conv_out = nn.Conv2d(
853
+ mid_channels,
854
+ out_channels,
855
+ kernel_size=1,
856
+ )
857
+
858
+ def forward(self, x):
859
+ x = self.conv_in(x)
860
+ for block in self.res_block1:
861
+ x = block(x, None)
862
+ x = torch.nn.functional.interpolate(
863
+ x,
864
+ size=(
865
+ int(round(x.shape[2] * self.factor)),
866
+ int(round(x.shape[3] * self.factor)),
867
+ ),
868
+ )
869
+ x = self.attn(x)
870
+ for block in self.res_block2:
871
+ x = block(x, None)
872
+ x = self.conv_out(x)
873
+ return x
874
+
875
+
876
+ class MergedRescaleEncoder(nn.Module):
877
+ def __init__(
878
+ self,
879
+ in_channels,
880
+ ch,
881
+ resolution,
882
+ out_ch,
883
+ num_res_blocks,
884
+ attn_resolutions,
885
+ dropout=0.0,
886
+ resamp_with_conv=True,
887
+ ch_mult=(1, 2, 4, 8),
888
+ rescale_factor=1.0,
889
+ rescale_module_depth=1,
890
+ ):
891
+ super().__init__()
892
+ intermediate_chn = ch * ch_mult[-1]
893
+ self.encoder = Encoder(
894
+ in_channels=in_channels,
895
+ num_res_blocks=num_res_blocks,
896
+ ch=ch,
897
+ ch_mult=ch_mult,
898
+ z_channels=intermediate_chn,
899
+ double_z=False,
900
+ resolution=resolution,
901
+ attn_resolutions=attn_resolutions,
902
+ dropout=dropout,
903
+ resamp_with_conv=resamp_with_conv,
904
+ out_ch=None,
905
+ )
906
+ self.rescaler = LatentRescaler(
907
+ factor=rescale_factor,
908
+ in_channels=intermediate_chn,
909
+ mid_channels=intermediate_chn,
910
+ out_channels=out_ch,
911
+ depth=rescale_module_depth,
912
+ )
913
+
914
+ def forward(self, x):
915
+ x = self.encoder(x)
916
+ x = self.rescaler(x)
917
+ return x
918
+
919
+
920
+ class MergedRescaleDecoder(nn.Module):
921
+ def __init__(
922
+ self,
923
+ z_channels,
924
+ out_ch,
925
+ resolution,
926
+ num_res_blocks,
927
+ attn_resolutions,
928
+ ch,
929
+ ch_mult=(1, 2, 4, 8),
930
+ dropout=0.0,
931
+ resamp_with_conv=True,
932
+ rescale_factor=1.0,
933
+ rescale_module_depth=1,
934
+ ):
935
+ super().__init__()
936
+ tmp_chn = z_channels * ch_mult[-1]
937
+ self.decoder = Decoder(
938
+ out_ch=out_ch,
939
+ z_channels=tmp_chn,
940
+ attn_resolutions=attn_resolutions,
941
+ dropout=dropout,
942
+ resamp_with_conv=resamp_with_conv,
943
+ in_channels=None,
944
+ num_res_blocks=num_res_blocks,
945
+ ch_mult=ch_mult,
946
+ resolution=resolution,
947
+ ch=ch,
948
+ )
949
+ self.rescaler = LatentRescaler(
950
+ factor=rescale_factor,
951
+ in_channels=z_channels,
952
+ mid_channels=tmp_chn,
953
+ out_channels=tmp_chn,
954
+ depth=rescale_module_depth,
955
+ )
956
+
957
+ def forward(self, x):
958
+ x = self.rescaler(x)
959
+ x = self.decoder(x)
960
+ return x
961
+
962
+
963
+ class Upsampler(nn.Module):
964
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
965
+ super().__init__()
966
+ assert out_size >= in_size
967
+ num_blocks = int(np.log2(out_size // in_size)) + 1
968
+ factor_up = 1.0 + (out_size % in_size)
969
+ print(
970
+ f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
971
+ )
972
+ self.rescaler = LatentRescaler(
973
+ factor=factor_up,
974
+ in_channels=in_channels,
975
+ mid_channels=2 * in_channels,
976
+ out_channels=in_channels,
977
+ )
978
+ self.decoder = Decoder(
979
+ out_ch=out_channels,
980
+ resolution=out_size,
981
+ z_channels=in_channels,
982
+ num_res_blocks=2,
983
+ attn_resolutions=[],
984
+ in_channels=None,
985
+ ch=in_channels,
986
+ ch_mult=[ch_mult for _ in range(num_blocks)],
987
+ )
988
+
989
+ def forward(self, x):
990
+ x = self.rescaler(x)
991
+ x = self.decoder(x)
992
+ return x
993
+
994
+
995
+ class Resize(nn.Module):
996
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
997
+ super().__init__()
998
+ self.with_conv = learned
999
+ self.mode = mode
1000
+ if self.with_conv:
1001
+ print(
1002
+ f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode"
1003
+ )
1004
+ raise NotImplementedError()
1005
+ assert in_channels is not None
1006
+ # no asymmetric padding in torch conv, must do it ourselves
1007
+ self.conv = torch.nn.Conv2d(
1008
+ in_channels, in_channels, kernel_size=4, stride=2, padding=1
1009
+ )
1010
+
1011
+ def forward(self, x, scale_factor=1.0):
1012
+ if scale_factor == 1.0:
1013
+ return x
1014
+ else:
1015
+ x = torch.nn.functional.interpolate(
1016
+ x, mode=self.mode, align_corners=False, scale_factor=scale_factor
1017
+ )
1018
+ return x
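To orient readers on the autoencoder blocks above, the sketch below wires an Encoder/Decoder pair with a typical SD-style configuration and checks the latent and reconstruction shapes. The hyperparameters are illustrative assumptions, not values read from this repository's configs.

# Shape round-trip through the Encoder/Decoder classes defined above (illustrative config).
import torch

enc = Encoder(ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
              attn_resolutions=[], in_channels=3, resolution=256,
              z_channels=4, double_z=True)
dec = Decoder(ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
              attn_resolutions=[], in_channels=3, resolution=256, z_channels=4)

x = torch.randn(1, 3, 256, 256)
moments = enc(x)        # (1, 8, 32, 32): concatenated mean and logvar of the latent
z = moments[:, :4]      # take the mean as a deterministic latent code
print(dec(z).shape)     # torch.Size([1, 3, 256, 256])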
imagedream/ldm/modules/diffusionmodules/openaimodel.py ADDED
@@ -0,0 +1,1135 @@
1
+ from abc import abstractmethod
2
+ import math
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch as th
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, repeat
10
+
11
+ from imagedream.ldm.modules.diffusionmodules.util import (
12
+ checkpoint,
13
+ conv_nd,
14
+ linear,
15
+ avg_pool_nd,
16
+ zero_module,
17
+ normalization,
18
+ timestep_embedding,
19
+ convert_module_to_f16,
20
+ convert_module_to_f32
21
+ )
22
+ from imagedream.ldm.modules.attention import (
23
+ SpatialTransformer,
24
+ SpatialTransformer3D,
25
+ exists
26
+ )
27
+ from imagedream.ldm.modules.diffusionmodules.adaptors import (
28
+ Resampler,
29
+ ImageProjModel
30
+ )
31
+
32
+ ## go
33
+ class AttentionPool2d(nn.Module):
34
+ """
35
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ spacial_dim: int,
41
+ embed_dim: int,
42
+ num_heads_channels: int,
43
+ output_dim: int = None,
44
+ ):
45
+ super().__init__()
46
+ self.positional_embedding = nn.Parameter(
47
+ th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
48
+ )
49
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
50
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
51
+ self.num_heads = embed_dim // num_heads_channels
52
+ self.attention = QKVAttention(self.num_heads)
53
+
54
+ def forward(self, x):
55
+ b, c, *_spatial = x.shape
56
+ x = x.reshape(b, c, -1) # NC(HW)
57
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
58
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
59
+ x = self.qkv_proj(x)
60
+ x = self.attention(x)
61
+ x = self.c_proj(x)
62
+ return x[:, :, 0]
63
+
64
+
65
+ class TimestepBlock(nn.Module):
66
+ """
67
+ Any module where forward() takes timestep embeddings as a second argument.
68
+ """
69
+
70
+ @abstractmethod
71
+ def forward(self, x, emb):
72
+ """
73
+ Apply the module to `x` given `emb` timestep embeddings.
74
+ """
75
+
76
+
77
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
78
+ """
79
+ A sequential module that passes timestep embeddings to the children that
80
+ support it as an extra input.
81
+ """
82
+
83
+ def forward(self, x, emb, context=None, num_frames=1):
84
+ for layer in self:
85
+ if isinstance(layer, TimestepBlock):
86
+ x = layer(x, emb)
87
+ elif isinstance(layer, SpatialTransformer3D):
88
+ x = layer(x, context, num_frames=num_frames)
89
+ elif isinstance(layer, SpatialTransformer):
90
+ x = layer(x, context)
91
+ else:
92
+ x = layer(x)
93
+ return x
94
+
95
+
96
+ class Upsample(nn.Module):
97
+ """
98
+ An upsampling layer with an optional convolution.
99
+ :param channels: channels in the inputs and outputs.
100
+ :param use_conv: a bool determining if a convolution is applied.
101
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
102
+ upsampling occurs in the inner-two dimensions.
103
+ """
104
+
105
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
106
+ super().__init__()
107
+ self.channels = channels
108
+ self.out_channels = out_channels or channels
109
+ self.use_conv = use_conv
110
+ self.dims = dims
111
+ if use_conv:
112
+ self.conv = conv_nd(
113
+ dims, self.channels, self.out_channels, 3, padding=padding
114
+ )
115
+
116
+ def forward(self, x):
117
+ assert x.shape[1] == self.channels
118
+ if self.dims == 3:
119
+ x = F.interpolate(
120
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
121
+ )
122
+ else:
123
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
124
+ if self.use_conv:
125
+ x = self.conv(x)
126
+ return x
127
+
128
+
129
+ class TransposedUpsample(nn.Module):
130
+ "Learned 2x upsampling without padding"
131
+
132
+ def __init__(self, channels, out_channels=None, ks=5):
133
+ super().__init__()
134
+ self.channels = channels
135
+ self.out_channels = out_channels or channels
136
+
137
+ self.up = nn.ConvTranspose2d(
138
+ self.channels, self.out_channels, kernel_size=ks, stride=2
139
+ )
140
+
141
+ def forward(self, x):
142
+ return self.up(x)
143
+
144
+
145
+ class Downsample(nn.Module):
146
+ """
147
+ A downsampling layer with an optional convolution.
148
+ :param channels: channels in the inputs and outputs.
149
+ :param use_conv: a bool determining if a convolution is applied.
150
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
151
+ downsampling occurs in the inner-two dimensions.
152
+ """
153
+
154
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
155
+ super().__init__()
156
+ self.channels = channels
157
+ self.out_channels = out_channels or channels
158
+ self.use_conv = use_conv
159
+ self.dims = dims
160
+ stride = 2 if dims != 3 else (1, 2, 2)
161
+ if use_conv:
162
+ self.op = conv_nd(
163
+ dims,
164
+ self.channels,
165
+ self.out_channels,
166
+ 3,
167
+ stride=stride,
168
+ padding=padding,
169
+ )
170
+ else:
171
+ assert self.channels == self.out_channels
172
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
173
+
174
+ def forward(self, x):
175
+ assert x.shape[1] == self.channels
176
+ return self.op(x)
177
+
178
+
179
+ class ResBlock(TimestepBlock):
180
+ """
181
+ A residual block that can optionally change the number of channels.
182
+ :param channels: the number of input channels.
183
+ :param emb_channels: the number of timestep embedding channels.
184
+ :param dropout: the rate of dropout.
185
+ :param out_channels: if specified, the number of out channels.
186
+ :param use_conv: if True and out_channels is specified, use a spatial
187
+ convolution instead of a smaller 1x1 convolution to change the
188
+ channels in the skip connection.
189
+ :param dims: determines if the signal is 1D, 2D, or 3D.
190
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
191
+ :param up: if True, use this block for upsampling.
192
+ :param down: if True, use this block for downsampling.
193
+ """
194
+
195
+ def __init__(
196
+ self,
197
+ channels,
198
+ emb_channels,
199
+ dropout,
200
+ out_channels=None,
201
+ use_conv=False,
202
+ use_scale_shift_norm=False,
203
+ dims=2,
204
+ use_checkpoint=False,
205
+ up=False,
206
+ down=False,
207
+ ):
208
+ super().__init__()
209
+ self.channels = channels
210
+ self.emb_channels = emb_channels
211
+ self.dropout = dropout
212
+ self.out_channels = out_channels or channels
213
+ self.use_conv = use_conv
214
+ self.use_checkpoint = use_checkpoint
215
+ self.use_scale_shift_norm = use_scale_shift_norm
216
+
217
+ self.in_layers = nn.Sequential(
218
+ normalization(channels),
219
+ nn.SiLU(),
220
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
221
+ )
222
+
223
+ self.updown = up or down
224
+
225
+ if up:
226
+ self.h_upd = Upsample(channels, False, dims)
227
+ self.x_upd = Upsample(channels, False, dims)
228
+ elif down:
229
+ self.h_upd = Downsample(channels, False, dims)
230
+ self.x_upd = Downsample(channels, False, dims)
231
+ else:
232
+ self.h_upd = self.x_upd = nn.Identity()
233
+
234
+ self.emb_layers = nn.Sequential(
235
+ nn.SiLU(),
236
+ linear(
237
+ emb_channels,
238
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
239
+ ),
240
+ )
241
+ self.out_layers = nn.Sequential(
242
+ normalization(self.out_channels),
243
+ nn.SiLU(),
244
+ nn.Dropout(p=dropout),
245
+ zero_module(
246
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
247
+ ),
248
+ )
249
+
250
+ if self.out_channels == channels:
251
+ self.skip_connection = nn.Identity()
252
+ elif use_conv:
253
+ self.skip_connection = conv_nd(
254
+ dims, channels, self.out_channels, 3, padding=1
255
+ )
256
+ else:
257
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
258
+
259
+ def forward(self, x, emb):
260
+ """
261
+ Apply the block to a Tensor, conditioned on a timestep embedding.
262
+ :param x: an [N x C x ...] Tensor of features.
263
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
264
+ :return: an [N x C x ...] Tensor of outputs.
265
+ """
266
+ return checkpoint(
267
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
268
+ )
269
+
270
+ def _forward(self, x, emb):
271
+ if self.updown:
272
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
273
+ h = in_rest(x)
274
+ h = self.h_upd(h)
275
+ x = self.x_upd(x)
276
+ h = in_conv(h)
277
+ else:
278
+ h = self.in_layers(x)
279
+ emb_out = self.emb_layers(emb).type(h.dtype)
280
+ while len(emb_out.shape) < len(h.shape):
281
+ emb_out = emb_out[..., None]
282
+ if self.use_scale_shift_norm:
283
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
284
+ scale, shift = th.chunk(emb_out, 2, dim=1)
285
+ h = out_norm(h) * (1 + scale) + shift
286
+ h = out_rest(h)
287
+ else:
288
+ h = h + emb_out
289
+ h = self.out_layers(h)
290
+ return self.skip_connection(x) + h
291
+
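For reference, a standalone sketch of the scale-shift (FiLM-style) conditioning applied in ResBlock._forward when use_scale_shift_norm=True; the tensors below are illustrative stand-ins for the feature map and the projected timestep embedding:

import torch
import torch.nn.functional as F

h = torch.randn(2, 64, 32, 32)            # feature map with C = 64 channels
emb_out = torch.randn(2, 2 * 64, 1, 1)    # timestep embedding projected to 2*C channels
scale, shift = torch.chunk(emb_out, 2, dim=1)
h = F.group_norm(h, num_groups=32) * (1 + scale) + shift  # normalize, then modulate per channel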
292
+
293
+ class AttentionBlock(nn.Module):
294
+ """
295
+ An attention block that allows spatial positions to attend to each other.
296
+ Originally ported from here, but adapted to the N-d case.
297
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
298
+ """
299
+
300
+ def __init__(
301
+ self,
302
+ channels,
303
+ num_heads=1,
304
+ num_head_channels=-1,
305
+ use_checkpoint=False,
306
+ use_new_attention_order=False,
307
+ ):
308
+ super().__init__()
309
+ self.channels = channels
310
+ if num_head_channels == -1:
311
+ self.num_heads = num_heads
312
+ else:
313
+ assert (
314
+ channels % num_head_channels == 0
315
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
316
+ self.num_heads = channels // num_head_channels
317
+ self.use_checkpoint = use_checkpoint
318
+ self.norm = normalization(channels)
319
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
320
+ if use_new_attention_order:
321
+ # split qkv before split heads
322
+ self.attention = QKVAttention(self.num_heads)
323
+ else:
324
+ # split heads before split qkv
325
+ self.attention = QKVAttentionLegacy(self.num_heads)
326
+
327
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
328
+
329
+ def forward(self, x):
330
+ return checkpoint(
331
+ self._forward, (x,), self.parameters(), True
332
+ ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
333
+ # return pt_checkpoint(self._forward, x) # pytorch
334
+
335
+ def _forward(self, x):
336
+ b, c, *spatial = x.shape
337
+ x = x.reshape(b, c, -1)
338
+ qkv = self.qkv(self.norm(x))
339
+ h = self.attention(qkv)
340
+ h = self.proj_out(h)
341
+ return (x + h).reshape(b, c, *spatial)
342
+
343
+
344
+ def count_flops_attn(model, _x, y):
345
+ """
346
+ A counter for the `thop` package to count the operations in an
347
+ attention operation.
348
+ Meant to be used like:
349
+ macs, params = thop.profile(
350
+ model,
351
+ inputs=(inputs, timestamps),
352
+ custom_ops={QKVAttention: QKVAttention.count_flops},
353
+ )
354
+ """
355
+ b, c, *spatial = y[0].shape
356
+ num_spatial = int(np.prod(spatial))
357
+ # We perform two matmuls with the same number of ops.
358
+ # The first computes the weight matrix, the second computes
359
+ # the combination of the value vectors.
360
+ matmul_ops = 2 * b * (num_spatial**2) * c
361
+ model.total_ops += th.DoubleTensor([matmul_ops])
362
+
363
+
364
+ class QKVAttentionLegacy(nn.Module):
365
+ """
366
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
367
+ """
368
+
369
+ def __init__(self, n_heads):
370
+ super().__init__()
371
+ self.n_heads = n_heads
372
+
373
+ def forward(self, qkv):
374
+ """
375
+ Apply QKV attention.
376
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
377
+ :return: an [N x (H * C) x T] tensor after attention.
378
+ """
379
+ bs, width, length = qkv.shape
380
+ assert width % (3 * self.n_heads) == 0
381
+ ch = width // (3 * self.n_heads)
382
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
383
+ scale = 1 / math.sqrt(math.sqrt(ch))
384
+ weight = th.einsum(
385
+ "bct,bcs->bts", q * scale, k * scale
386
+ ) # More stable with f16 than dividing afterwards
387
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
388
+ a = th.einsum("bts,bcs->bct", weight, v)
389
+ return a.reshape(bs, -1, length)
390
+
391
+ @staticmethod
392
+ def count_flops(model, _x, y):
393
+ return count_flops_attn(model, _x, y)
394
+
395
+
396
+ class QKVAttention(nn.Module):
397
+ """
398
+ A module which performs QKV attention and splits in a different order.
399
+ """
400
+
401
+ def __init__(self, n_heads):
402
+ super().__init__()
403
+ self.n_heads = n_heads
404
+
405
+ def forward(self, qkv):
406
+ """
407
+ Apply QKV attention.
408
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
409
+ :return: an [N x (H * C) x T] tensor after attention.
410
+ """
411
+ bs, width, length = qkv.shape
412
+ assert width % (3 * self.n_heads) == 0
413
+ ch = width // (3 * self.n_heads)
414
+ q, k, v = qkv.chunk(3, dim=1)
415
+ scale = 1 / math.sqrt(math.sqrt(ch))
416
+ weight = th.einsum(
417
+ "bct,bcs->bts",
418
+ (q * scale).view(bs * self.n_heads, ch, length),
419
+ (k * scale).view(bs * self.n_heads, ch, length),
420
+ ) # More stable with f16 than dividing afterwards
421
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
422
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
423
+ return a.reshape(bs, -1, length)
424
+
425
+ @staticmethod
426
+ def count_flops(model, _x, y):
427
+ return count_flops_attn(model, _x, y)
428
+
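A quick shape check for QKVAttention, assuming 2 heads of 8 channels each over 16 tokens:

import torch

attn = QKVAttention(n_heads=2)
qkv = torch.randn(4, 3 * 2 * 8, 16)  # [N, 3 * H * C, T]
out = attn(qkv)                      # [N, H * C, T] -> torch.Size([4, 16, 16])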
429
+
430
+ class Timestep(nn.Module):
431
+ def __init__(self, dim):
432
+ super().__init__()
433
+ self.dim = dim
434
+
435
+ def forward(self, t):
436
+ return timestep_embedding(t, self.dim)
437
+
438
+
439
+ class MultiViewUNetModel(nn.Module):
440
+ """
441
+ The full multi-view UNet model with attention, timestep embedding and camera embedding.
442
+ :param in_channels: channels in the input Tensor.
443
+ :param model_channels: base channel count for the model.
444
+ :param out_channels: channels in the output Tensor.
445
+ :param num_res_blocks: number of residual blocks per downsample.
446
+ :param attention_resolutions: a collection of downsample rates at which
447
+ attention will take place. May be a set, list, or tuple.
448
+ For example, if this contains 4, then at 4x downsampling, attention
449
+ will be used.
450
+ :param dropout: the dropout probability.
451
+ :param channel_mult: channel multiplier for each level of the UNet.
452
+ :param conv_resample: if True, use learned convolutions for upsampling and
453
+ downsampling.
454
+ :param dims: determines if the signal is 1D, 2D, or 3D.
455
+ :param num_classes: if specified (as an int), then this model will be
456
+ class-conditional with `num_classes` classes.
457
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
458
+ :param num_heads: the number of attention heads in each attention layer.
459
+ :param num_head_channels: if specified, ignore num_heads and instead use
460
+ a fixed channel width per attention head.
461
+ :param num_heads_upsample: works with num_heads to set a different number
462
+ of heads for upsampling. Deprecated.
463
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
464
+ :param resblock_updown: use residual blocks for up/downsampling.
465
+ :param use_new_attention_order: use a different attention pattern for potentially
466
+ increased efficiency.
467
+ :param camera_dim: dimensionality of camera input.
468
+ """
469
+
470
+ def __init__(
471
+ self,
472
+ image_size,
473
+ in_channels,
474
+ model_channels,
475
+ out_channels,
476
+ num_res_blocks,
477
+ attention_resolutions,
478
+ dropout=0,
479
+ channel_mult=(1, 2, 4, 8),
480
+ conv_resample=True,
481
+ dims=2,
482
+ num_classes=None,
483
+ use_checkpoint=False,
484
+ use_fp16=False,
485
+ use_bf16=False,
486
+ num_heads=-1,
487
+ num_head_channels=-1,
488
+ num_heads_upsample=-1,
489
+ use_scale_shift_norm=False,
490
+ resblock_updown=False,
491
+ use_new_attention_order=False,
492
+ use_spatial_transformer=False, # custom transformer support
493
+ transformer_depth=1, # custom transformer support
494
+ context_dim=None, # custom transformer support
495
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
496
+ legacy=True,
497
+ disable_self_attentions=None,
498
+ num_attention_blocks=None,
499
+ disable_middle_self_attn=False,
500
+ use_linear_in_transformer=False,
501
+ adm_in_channels=None,
502
+ camera_dim=None,
503
+ with_ip=False, # whether to add image prompt images
504
+ ip_dim=0, # number of extra tokens, 4 for global, 16 for local
505
+ ip_weight=1.0, # weight for image prompt context
506
+ ip_mode="local_resample", # which mode of adaptor, global or local
507
+ ):
508
+ super().__init__()
509
+ if use_spatial_transformer:
510
+ assert (
511
+ context_dim is not None
512
+ ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
513
+
514
+ if context_dim is not None:
515
+ assert (
516
+ use_spatial_transformer
517
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
518
+ from omegaconf.listconfig import ListConfig
519
+
520
+ if type(context_dim) == ListConfig:
521
+ context_dim = list(context_dim)
522
+
523
+ if num_heads_upsample == -1:
524
+ num_heads_upsample = num_heads
525
+
526
+ if num_heads == -1:
527
+ assert (
528
+ num_head_channels != -1
529
+ ), "Either num_heads or num_head_channels has to be set"
530
+
531
+ if num_head_channels == -1:
532
+ assert (
533
+ num_heads != -1
534
+ ), "Either num_heads or num_head_channels has to be set"
535
+
536
+ self.image_size = image_size
537
+ self.in_channels = in_channels
538
+ self.model_channels = model_channels
539
+ self.out_channels = out_channels
540
+ if isinstance(num_res_blocks, int):
541
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
542
+ else:
543
+ if len(num_res_blocks) != len(channel_mult):
544
+ raise ValueError(
545
+ "provide num_res_blocks either as an int (globally constant) or "
546
+ "as a list/tuple (per-level) with the same length as channel_mult"
547
+ )
548
+ self.num_res_blocks = num_res_blocks
549
+ if disable_self_attentions is not None:
550
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
551
+ assert len(disable_self_attentions) == len(channel_mult)
552
+ if num_attention_blocks is not None:
553
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
554
+ assert all(
555
+ map(
556
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
557
+ range(len(num_attention_blocks)),
558
+ )
559
+ )
560
+ print(
561
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
562
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
563
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
564
+ f"attention will still not be set."
565
+ )
566
+
567
+ self.attention_resolutions = attention_resolutions
568
+ self.dropout = dropout
569
+ self.channel_mult = channel_mult
570
+ self.conv_resample = conv_resample
571
+ self.num_classes = num_classes
572
+ self.use_checkpoint = use_checkpoint
573
+ self.dtype = th.float16 if use_fp16 else th.float32
574
+ self.dtype = th.bfloat16 if use_bf16 else self.dtype
575
+ self.num_heads = num_heads
576
+ self.num_head_channels = num_head_channels
577
+ self.num_heads_upsample = num_heads_upsample
578
+ self.predict_codebook_ids = n_embed is not None
579
+
580
+ self.with_ip = with_ip # whether there is an image prompt
581
+ self.ip_dim = ip_dim # number of extra tokens, 4 for global, 16 for local
582
+ self.ip_weight = ip_weight
583
+ assert ip_mode in ["global", "local_resample"]
584
+ self.ip_mode = ip_mode # adaptor mode
585
+
586
+ time_embed_dim = model_channels * 4
587
+ self.time_embed = nn.Sequential(
588
+ linear(model_channels, time_embed_dim),
589
+ nn.SiLU(),
590
+ linear(time_embed_dim, time_embed_dim),
591
+ )
592
+
593
+ if camera_dim is not None:
594
+ time_embed_dim = model_channels * 4
595
+ self.camera_embed = nn.Sequential(
596
+ linear(camera_dim, time_embed_dim),
597
+ nn.SiLU(),
598
+ linear(time_embed_dim, time_embed_dim),
599
+ )
600
+
601
+ if self.num_classes is not None:
602
+ if isinstance(self.num_classes, int):
603
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
604
+ elif self.num_classes == "continuous":
605
+ print("setting up linear c_adm embedding layer")
606
+ self.label_emb = nn.Linear(1, time_embed_dim)
607
+ elif self.num_classes == "sequential":
608
+ assert adm_in_channels is not None
609
+ self.label_emb = nn.Sequential(
610
+ nn.Sequential(
611
+ linear(adm_in_channels, time_embed_dim),
612
+ nn.SiLU(),
613
+ linear(time_embed_dim, time_embed_dim),
614
+ )
615
+ )
616
+ else:
617
+ raise ValueError()
618
+
619
+ if self.with_ip and (context_dim is not None) and ip_dim > 0:
620
+ if self.ip_mode == "local_resample":
621
+ # ip-adapter-plus
622
+ hidden_dim = 1280
623
+ self.image_embed = Resampler(
624
+ dim=context_dim,
625
+ depth=4,
626
+ dim_head=64,
627
+ heads=12,
628
+ num_queries=ip_dim, # num token
629
+ embedding_dim=hidden_dim,
630
+ output_dim=context_dim,
631
+ ff_mult=4,
632
+ )
633
+ elif self.ip_mode == "global":
634
+ self.image_embed = ImageProjModel(
635
+ cross_attention_dim=context_dim,
636
+ clip_extra_context_tokens=ip_dim)
637
+ else:
638
+ raise ValueError(f"{self.ip_mode} is not supported")
639
+
640
+ self.input_blocks = nn.ModuleList(
641
+ [
642
+ TimestepEmbedSequential(
643
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
644
+ )
645
+ ]
646
+ )
647
+ self._feature_size = model_channels
648
+ input_block_chans = [model_channels]
649
+ ch = model_channels
650
+ ds = 1
651
+ for level, mult in enumerate(channel_mult):
652
+ for nr in range(self.num_res_blocks[level]):
653
+ layers = [
654
+ ResBlock(
655
+ ch,
656
+ time_embed_dim,
657
+ dropout,
658
+ out_channels=mult * model_channels,
659
+ dims=dims,
660
+ use_checkpoint=use_checkpoint,
661
+ use_scale_shift_norm=use_scale_shift_norm,
662
+ )
663
+ ]
664
+ ch = mult * model_channels
665
+ if ds in attention_resolutions:
666
+ if num_head_channels == -1:
667
+ dim_head = ch // num_heads
668
+ else:
669
+ num_heads = ch // num_head_channels
670
+ dim_head = num_head_channels
671
+ if legacy:
672
+ # num_heads = 1
673
+ dim_head = (
674
+ ch // num_heads
675
+ if use_spatial_transformer
676
+ else num_head_channels
677
+ )
678
+ if exists(disable_self_attentions):
679
+ disabled_sa = disable_self_attentions[level]
680
+ else:
681
+ disabled_sa = False
682
+
683
+ if (
684
+ not exists(num_attention_blocks)
685
+ or nr < num_attention_blocks[level]
686
+ ):
687
+ layers.append(
688
+ AttentionBlock(
689
+ ch,
690
+ use_checkpoint=use_checkpoint,
691
+ num_heads=num_heads,
692
+ num_head_channels=dim_head,
693
+ use_new_attention_order=use_new_attention_order,
694
+ )
695
+ if not use_spatial_transformer
696
+ else SpatialTransformer3D(
697
+ ch,
698
+ num_heads,
699
+ dim_head,
700
+ depth=transformer_depth,
701
+ context_dim=context_dim,
702
+ disable_self_attn=disabled_sa,
703
+ use_linear=use_linear_in_transformer,
704
+ use_checkpoint=use_checkpoint,
705
+ with_ip=self.with_ip,
706
+ ip_dim=self.ip_dim,
707
+ ip_weight=self.ip_weight
708
+ )
709
+ )
710
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
711
+ self._feature_size += ch
712
+ input_block_chans.append(ch)
713
+
714
+ if level != len(channel_mult) - 1:
715
+ out_ch = ch
716
+ self.input_blocks.append(
717
+ TimestepEmbedSequential(
718
+ ResBlock(
719
+ ch,
720
+ time_embed_dim,
721
+ dropout,
722
+ out_channels=out_ch,
723
+ dims=dims,
724
+ use_checkpoint=use_checkpoint,
725
+ use_scale_shift_norm=use_scale_shift_norm,
726
+ down=True,
727
+ )
728
+ if resblock_updown
729
+ else Downsample(
730
+ ch, conv_resample, dims=dims, out_channels=out_ch
731
+ )
732
+ )
733
+ )
734
+ ch = out_ch
735
+ input_block_chans.append(ch)
736
+ ds *= 2
737
+ self._feature_size += ch
738
+
739
+ if num_head_channels == -1:
740
+ dim_head = ch // num_heads
741
+ else:
742
+ num_heads = ch // num_head_channels
743
+ dim_head = num_head_channels
744
+ if legacy:
745
+ # num_heads = 1
746
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
747
+ self.middle_block = TimestepEmbedSequential(
748
+ ResBlock(
749
+ ch,
750
+ time_embed_dim,
751
+ dropout,
752
+ dims=dims,
753
+ use_checkpoint=use_checkpoint,
754
+ use_scale_shift_norm=use_scale_shift_norm,
755
+ ),
756
+ AttentionBlock(
757
+ ch,
758
+ use_checkpoint=use_checkpoint,
759
+ num_heads=num_heads,
760
+ num_head_channels=dim_head,
761
+ use_new_attention_order=use_new_attention_order,
762
+ )
763
+ if not use_spatial_transformer
764
+ else SpatialTransformer3D( # always uses a self-attn
765
+ ch,
766
+ num_heads,
767
+ dim_head,
768
+ depth=transformer_depth,
769
+ context_dim=context_dim,
770
+ disable_self_attn=disable_middle_self_attn,
771
+ use_linear=use_linear_in_transformer,
772
+ use_checkpoint=use_checkpoint,
773
+ with_ip=self.with_ip,
774
+ ip_dim=self.ip_dim,
775
+ ip_weight=self.ip_weight
776
+ ),
777
+ ResBlock(
778
+ ch,
779
+ time_embed_dim,
780
+ dropout,
781
+ dims=dims,
782
+ use_checkpoint=use_checkpoint,
783
+ use_scale_shift_norm=use_scale_shift_norm,
784
+ ),
785
+ )
786
+ self._feature_size += ch
787
+
788
+ self.output_blocks = nn.ModuleList([])
789
+ for level, mult in list(enumerate(channel_mult))[::-1]:
790
+ for i in range(self.num_res_blocks[level] + 1):
791
+ ich = input_block_chans.pop()
792
+ layers = [
793
+ ResBlock(
794
+ ch + ich,
795
+ time_embed_dim,
796
+ dropout,
797
+ out_channels=model_channels * mult,
798
+ dims=dims,
799
+ use_checkpoint=use_checkpoint,
800
+ use_scale_shift_norm=use_scale_shift_norm,
801
+ )
802
+ ]
803
+ ch = model_channels * mult
804
+ if ds in attention_resolutions:
805
+ if num_head_channels == -1:
806
+ dim_head = ch // num_heads
807
+ else:
808
+ num_heads = ch // num_head_channels
809
+ dim_head = num_head_channels
810
+ if legacy:
811
+ # num_heads = 1
812
+ dim_head = (
813
+ ch // num_heads
814
+ if use_spatial_transformer
815
+ else num_head_channels
816
+ )
817
+ if exists(disable_self_attentions):
818
+ disabled_sa = disable_self_attentions[level]
819
+ else:
820
+ disabled_sa = False
821
+
822
+ if (
823
+ not exists(num_attention_blocks)
824
+ or i < num_attention_blocks[level]
825
+ ):
826
+ layers.append(
827
+ AttentionBlock(
828
+ ch,
829
+ use_checkpoint=use_checkpoint,
830
+ num_heads=num_heads_upsample,
831
+ num_head_channels=dim_head,
832
+ use_new_attention_order=use_new_attention_order,
833
+ )
834
+ if not use_spatial_transformer
835
+ else SpatialTransformer3D(
836
+ ch,
837
+ num_heads,
838
+ dim_head,
839
+ depth=transformer_depth,
840
+ context_dim=context_dim,
841
+ disable_self_attn=disabled_sa,
842
+ use_linear=use_linear_in_transformer,
843
+ use_checkpoint=use_checkpoint,
844
+ with_ip=self.with_ip,
845
+ ip_dim=self.ip_dim,
846
+ ip_weight=self.ip_weight
847
+ )
848
+ )
849
+ if level and i == self.num_res_blocks[level]:
850
+ out_ch = ch
851
+ layers.append(
852
+ ResBlock(
853
+ ch,
854
+ time_embed_dim,
855
+ dropout,
856
+ out_channels=out_ch,
857
+ dims=dims,
858
+ use_checkpoint=use_checkpoint,
859
+ use_scale_shift_norm=use_scale_shift_norm,
860
+ up=True,
861
+ )
862
+ if resblock_updown
863
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
864
+ )
865
+ ds //= 2
866
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
867
+ self._feature_size += ch
868
+
869
+ self.out = nn.Sequential(
870
+ normalization(ch),
871
+ nn.SiLU(),
872
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
873
+ )
874
+ if self.predict_codebook_ids:
875
+ self.id_predictor = nn.Sequential(
876
+ normalization(ch),
877
+ conv_nd(dims, model_channels, n_embed, 1),
878
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
879
+ )
880
+
881
+ def convert_to_fp16(self):
882
+ """
883
+ Convert the torso of the model to float16.
884
+ """
885
+ self.input_blocks.apply(convert_module_to_f16)
886
+ self.middle_block.apply(convert_module_to_f16)
887
+ self.output_blocks.apply(convert_module_to_f16)
888
+
889
+ def convert_to_fp32(self):
890
+ """
891
+ Convert the torso of the model to float32.
892
+ """
893
+ self.input_blocks.apply(convert_module_to_f32)
894
+ self.middle_block.apply(convert_module_to_f32)
895
+ self.output_blocks.apply(convert_module_to_f32)
896
+
897
+ def forward(
898
+ self,
899
+ x,
900
+ timesteps=None,
901
+ context=None,
902
+ y=None,
903
+ camera=None,
904
+ num_frames=1,
905
+ **kwargs,
906
+ ):
907
+ """
908
+ Apply the model to an input batch.
909
+ :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
910
+ :param timesteps: a 1-D batch of timesteps.
911
+ :param context: conditioning plugged in via cross-attention.
912
+ :param y: an [N] Tensor of labels, if class-conditional, default None.
913
+ :param num_frames: an integer indicating the number of frames for tensor reshaping.
914
+ :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
915
+ """
916
+ assert (
917
+ x.shape[0] % num_frames == 0
918
+ ), "[UNet] input batch size must be dividable by num_frames!"
919
+ assert (y is not None) == (
920
+ self.num_classes is not None
921
+ ), "must specify y if and only if the model is class-conditional"
922
+
923
+ hs = []
924
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # shape: torch.Size([B, 320]) mean: 0.18, std: 0.68, min: -1.00, max: 1.00
925
+ emb = self.time_embed(t_emb) # shape: torch.Size([B, 1280]) mean: 0.12, std: 0.57, min: -5.73, max: 6.51
926
+
927
+ if self.num_classes is not None:
928
+ assert y.shape[0] == x.shape[0]
929
+ emb = emb + self.label_emb(y)
930
+
931
+ # Add camera embeddings
932
+ if camera is not None:
933
+ assert camera.shape[0] == emb.shape[0]
934
+ # camera embed: shape: torch.Size([B, 1280]) mean: -0.02, std: 0.27, min: -7.23, max: 2.04
935
+ emb = emb + self.camera_embed(camera)
936
+ ip = kwargs.get("ip", None)
937
+ ip_img = kwargs.get("ip_img", None)
938
+
939
+ if ip_img is not None:
940
+ x[(num_frames-1)::num_frames, :, :, :] = ip_img
941
+
942
+ if ip is not None:
943
+ ip_emb = self.image_embed(ip) # shape: torch.Size([B, 16, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31
944
+ context = torch.cat((context, ip_emb), 1) # shape: torch.Size([B, 93, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31
945
+
946
+ h = x.type(self.dtype)
947
+ for module in self.input_blocks:
948
+ h = module(h, emb, context, num_frames=num_frames)
949
+ hs.append(h)
950
+ h = self.middle_block(h, emb, context, num_frames=num_frames)
951
+ for module in self.output_blocks:
952
+ h = th.cat([h, hs.pop()], dim=1)
953
+ h = module(h, emb, context, num_frames=num_frames)
954
+ h = h.type(x.dtype) # shape: torch.Size([10, 320, 32, 32]) mean: -0.67, std: 3.96, min: -42.74, max: 25.58
955
+ if self.predict_codebook_ids: # False
956
+ return self.id_predictor(h)
957
+ else:
958
+ return self.out(h) # shape: torch.Size([10, 4, 32, 32]) mean: -0.00, std: 0.91, min: -3.65, max: 3.93
959
+
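An illustrative call pattern for MultiViewUNetModel.forward: views are interleaved along the batch dimension, so two objects with five views each give a batch of ten. The shapes below are hypothetical but consistent with the shape comments above (model_channels=320, context_dim=1024, camera_dim=16); the model construction itself is omitted.

import torch

num_frames = 5
x = torch.randn(2 * num_frames, 4, 32, 32)        # latents for all views of both objects
t = torch.randint(0, 1000, (2 * num_frames,))     # one timestep per view
context = torch.randn(2 * num_frames, 77, 1024)   # per-view text tokens
camera = torch.randn(2 * num_frames, 16)          # flattened 4x4 camera matrices
# eps = model(x, timesteps=t, context=context, camera=camera, num_frames=num_frames)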
960
+
961
+
962
+
963
+ class MultiViewUNetModelStage2(MultiViewUNetModel):
964
+ """
965
+ The full multi-view UNet model with attention, timestep embedding and camera embedding.
966
+ :param in_channels: channels in the input Tensor.
967
+ :param model_channels: base channel count for the model.
968
+ :param out_channels: channels in the output Tensor.
969
+ :param num_res_blocks: number of residual blocks per downsample.
970
+ :param attention_resolutions: a collection of downsample rates at which
971
+ attention will take place. May be a set, list, or tuple.
972
+ For example, if this contains 4, then at 4x downsampling, attention
973
+ will be used.
974
+ :param dropout: the dropout probability.
975
+ :param channel_mult: channel multiplier for each level of the UNet.
976
+ :param conv_resample: if True, use learned convolutions for upsampling and
977
+ downsampling.
978
+ :param dims: determines if the signal is 1D, 2D, or 3D.
979
+ :param num_classes: if specified (as an int), then this model will be
980
+ class-conditional with `num_classes` classes.
981
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
982
+ :param num_heads: the number of attention heads in each attention layer.
983
+ :param num_head_channels: if specified, ignore num_heads and instead use
984
+ a fixed channel width per attention head.
985
+ :param num_heads_upsample: works with num_heads to set a different number
986
+ of heads for upsampling. Deprecated.
987
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
988
+ :param resblock_updown: use residual blocks for up/downsampling.
989
+ :param use_new_attention_order: use a different attention pattern for potentially
990
+ increased efficiency.
991
+ :param camera_dim: dimensionality of camera input.
992
+ """
993
+
994
+ def __init__(
995
+ self,
996
+ image_size,
997
+ in_channels,
998
+ model_channels,
999
+ out_channels,
1000
+ num_res_blocks,
1001
+ attention_resolutions,
1002
+ dropout=0,
1003
+ channel_mult=(1, 2, 4, 8),
1004
+ conv_resample=True,
1005
+ dims=2,
1006
+ num_classes=None,
1007
+ use_checkpoint=False,
1008
+ use_fp16=False,
1009
+ use_bf16=False,
1010
+ num_heads=-1,
1011
+ num_head_channels=-1,
1012
+ num_heads_upsample=-1,
1013
+ use_scale_shift_norm=False,
1014
+ resblock_updown=False,
1015
+ use_new_attention_order=False,
1016
+ use_spatial_transformer=False, # custom transformer support
1017
+ transformer_depth=1, # custom transformer support
1018
+ context_dim=None, # custom transformer support
1019
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
1020
+ legacy=True,
1021
+ disable_self_attentions=None,
1022
+ num_attention_blocks=None,
1023
+ disable_middle_self_attn=False,
1024
+ use_linear_in_transformer=False,
1025
+ adm_in_channels=None,
1026
+ camera_dim=None,
1027
+ with_ip=False, # whether to add image prompt images
1028
+ ip_dim=0, # number of extra tokens, 4 for global, 16 for local
1029
+ ip_weight=1.0, # weight for image prompt context
1030
+ ip_mode="local_resample", # which mode of adaptor, global or local
1031
+ ):
1032
+ super().__init__(
1033
+ image_size,
1034
+ in_channels,
1035
+ model_channels,
1036
+ out_channels,
1037
+ num_res_blocks,
1038
+ attention_resolutions,
1039
+ dropout,
1040
+ channel_mult,
1041
+ conv_resample,
1042
+ dims,
1043
+ num_classes,
1044
+ use_checkpoint,
1045
+ use_fp16,
1046
+ use_bf16,
1047
+ num_heads,
1048
+ num_head_channels,
1049
+ num_heads_upsample,
1050
+ use_scale_shift_norm,
1051
+ resblock_updown,
1052
+ use_new_attention_order,
1053
+ use_spatial_transformer,
1054
+ transformer_depth,
1055
+ context_dim,
1056
+ n_embed,
1057
+ legacy,
1058
+ disable_self_attentions,
1059
+ num_attention_blocks,
1060
+ disable_middle_self_attn,
1061
+ use_linear_in_transformer,
1062
+ adm_in_channels,
1063
+ camera_dim,
1064
+ with_ip,
1065
+ ip_dim,
1066
+ ip_weight,
1067
+ ip_mode,
1068
+ )
1069
+
1070
+ def forward(
1071
+ self,
1072
+ x,
1073
+ timesteps=None,
1074
+ context=None,
1075
+ y=None,
1076
+ camera=None,
1077
+ num_frames=1,
1078
+ **kwargs,
1079
+ ):
1080
+ """
1081
+ Apply the model to an input batch.
1082
+ :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
1083
+ :param timesteps: a 1-D batch of timesteps.
1084
+ :param context: conditioning plugged in via cross-attention.
1085
+ :param y: an [N] Tensor of labels, if class-conditional, default None.
1086
+ :param num_frames: an integer indicating the number of frames for tensor reshaping.
1087
+ :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
1088
+ """
1089
+ assert (
1090
+ x.shape[0] % num_frames == 0
1091
+ ), "[UNet] input batch size must be dividable by num_frames!"
1092
+ assert (y is not None) == (
1093
+ self.num_classes is not None
1094
+ ), "must specify y if and only if the model is class-conditional"
1095
+
1096
+ hs = []
1097
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # shape: torch.Size([B, 320]) mean: 0.18, std: 0.68, min: -1.00, max: 1.00
1098
+ emb = self.time_embed(t_emb) # shape: torch.Size([B, 1280]) mean: 0.12, std: 0.57, min: -5.73, max: 6.51
1099
+
1100
+ if self.num_classes is not None:
1101
+ assert y.shape[0] == x.shape[0]
1102
+ emb = emb + self.label_emb(y)
1103
+
1104
+ # Add camera embeddings
1105
+ if camera is not None:
1106
+ assert camera.shape[0] == emb.shape[0]
1107
+ # camera embed: shape: torch.Size([B, 1280]) mean: -0.02, std: 0.27, min: -7.23, max: 2.04
1108
+ emb = emb + self.camera_embed(camera)
1109
+ ip = kwargs.get("ip", None)
1110
+ ip_img = kwargs.get("ip_img", None)
1111
+ pixel_images = kwargs.get("pixel_images", None)
1112
+
1113
+ if ip_img is not None:
1114
+ x[(num_frames-1)::num_frames, :, :, :] = ip_img
1115
+
1116
+ x = torch.cat((x, pixel_images), dim=1)
1117
+
1118
+ if ip is not None:
1119
+ ip_emb = self.image_embed(ip) # shape: torch.Size([B, 16, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31
1120
+ context = torch.cat((context, ip_emb), 1) # shape: torch.Size([B, 93, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31
1121
+
1122
+ h = x.type(self.dtype)
1123
+ for module in self.input_blocks:
1124
+ h = module(h, emb, context, num_frames=num_frames)
1125
+ hs.append(h)
1126
+ h = self.middle_block(h, emb, context, num_frames=num_frames)
1127
+ for module in self.output_blocks:
1128
+ h = th.cat([h, hs.pop()], dim=1)
1129
+ h = module(h, emb, context, num_frames=num_frames)
1130
+ h = h.type(x.dtype) # shape: torch.Size([10, 320, 32, 32]) mean: -0.67, std: 3.96, min: -42.74, max: 25.58
1131
+ if self.predict_codebook_ids: # False
1132
+ return self.id_predictor(h)
1133
+ else:
1134
+ return self.out(h) # shape: torch.Size([10, 4, 32, 32]) mean: -0.00, std: 0.91, min: -3.65, max: 3.93
1135
+
imagedream/ldm/modules/diffusionmodules/util.py ADDED
@@ -0,0 +1,353 @@
1
+ # adopted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import os
12
+ import math
13
+ import torch
14
+ import torch.nn as nn
15
+ import numpy as np
16
+ from einops import repeat
17
+ import importlib
18
+
19
+
20
+ def instantiate_from_config(config):
21
+ if not "target" in config:
22
+ if config == "__is_first_stage__":
23
+ return None
24
+ elif config == "__is_unconditional__":
25
+ return None
26
+ raise KeyError("Expected key `target` to instantiate.")
27
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
28
+
29
+
30
+ def get_obj_from_str(string, reload=False):
31
+ module, cls = string.rsplit(".", 1)
32
+ if reload:
33
+ module_imp = importlib.import_module(module)
34
+ importlib.reload(module_imp)
35
+ return getattr(importlib.import_module(module, package=None), cls)
36
+
37
+
38
+ def make_beta_schedule(
39
+ schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
40
+ ):
41
+ if schedule == "linear":
42
+ betas = (
43
+ torch.linspace(
44
+ linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
45
+ )
46
+ ** 2
47
+ )
48
+
49
+ elif schedule == "cosine":
50
+ timesteps = (
51
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
52
+ )
53
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
54
+ alphas = torch.cos(alphas).pow(2)
55
+ alphas = alphas / alphas[0]
56
+ betas = 1 - alphas[1:] / alphas[:-1]
57
+ betas = np.clip(betas, a_min=0, a_max=0.999)
58
+
59
+ elif schedule == "sqrt_linear":
60
+ betas = torch.linspace(
61
+ linear_start, linear_end, n_timestep, dtype=torch.float64
62
+ )
63
+ elif schedule == "sqrt":
64
+ betas = (
65
+ torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
66
+ ** 0.5
67
+ )
68
+ else:
69
+ raise ValueError(f"schedule '{schedule}' unknown.")
70
+ return betas.numpy()
71
+
72
+ def enforce_zero_terminal_snr(betas):
73
+ betas = torch.tensor(betas) if not isinstance(betas, torch.Tensor) else betas
74
+ # Convert betas to alphas_bar_sqrt
75
+ alphas = 1 - betas
76
+ alphas_bar = alphas.cumprod(0)
77
+ alphas_bar_sqrt = alphas_bar.sqrt()
78
+ # Store old values.
79
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
80
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
81
+ # Shift so last timestep is zero.
82
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
83
+ # Scale so first timestep is back to old value.
84
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
85
+ # Convert alphas_bar_sqrt to betas
86
+ alphas_bar = alphas_bar_sqrt ** 2
87
+ alphas = alphas_bar[1:] / alphas_bar[:-1]
88
+ alphas = torch.cat([alphas_bar[0:1], alphas])
89
+ betas = 1 - alphas
90
+ return betas
91
+
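A small sketch of how enforce_zero_terminal_snr is typically combined with make_beta_schedule; the schedule parameters are illustrative:

import torch

betas = make_beta_schedule("linear", n_timestep=1000)       # numpy array of betas
betas = enforce_zero_terminal_snr(torch.from_numpy(betas))  # rescale so the terminal SNR is zero
alphas_bar = (1.0 - betas).cumprod(dim=0)
assert float(alphas_bar[-1]) == 0.0  # the final timestep now carries pure noise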
92
+
93
+ def make_ddim_timesteps(
94
+ ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
95
+ ):
96
+ if ddim_discr_method == "uniform":
97
+ c = num_ddpm_timesteps // num_ddim_timesteps
98
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
99
+ elif ddim_discr_method == "quad":
100
+ ddim_timesteps = (
101
+ (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
102
+ ).astype(int)
103
+ else:
104
+ raise NotImplementedError(
105
+ f'There is no ddim discretization method called "{ddim_discr_method}"'
106
+ )
107
+
108
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
109
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
110
+ steps_out = ddim_timesteps + 1
111
+ if verbose:
112
+ print(f"Selected timesteps for ddim sampler: {steps_out}")
113
+ return steps_out
114
+
115
+
116
+ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
117
+ # select alphas for computing the variance schedule
118
+ alphas = alphacums[ddim_timesteps]
119
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
120
+
121
+ # according to the formula provided in https://arxiv.org/abs/2010.02502
122
+ sigmas = eta * np.sqrt(
123
+ (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
124
+ )
125
+ if verbose:
126
+ print(
127
+ f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
128
+ )
129
+ print(
130
+ f"For the chosen value of eta, which is {eta}, "
131
+ f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
132
+ )
133
+ return sigmas, alphas, alphas_prev
134
+
135
+
136
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
137
+ """
138
+ Create a beta schedule that discretizes the given alpha_t_bar function,
139
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
140
+ :param num_diffusion_timesteps: the number of betas to produce.
141
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
142
+ produces the cumulative product of (1-beta) up to that
143
+ part of the diffusion process.
144
+ :param max_beta: the maximum beta to use; use values lower than 1 to
145
+ prevent singularities.
146
+ """
147
+ betas = []
148
+ for i in range(num_diffusion_timesteps):
149
+ t1 = i / num_diffusion_timesteps
150
+ t2 = (i + 1) / num_diffusion_timesteps
151
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
152
+ return np.array(betas)
153
+
154
+
155
+ def extract_into_tensor(a, t, x_shape):
156
+ b, *_ = t.shape
157
+ out = a.gather(-1, t)
158
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
159
+
160
+
161
+ def checkpoint(func, inputs, params, flag):
162
+ """
163
+ Evaluate a function without caching intermediate activations, allowing for
164
+ reduced memory at the expense of extra compute in the backward pass.
165
+ :param func: the function to evaluate.
166
+ :param inputs: the argument sequence to pass to `func`.
167
+ :param params: a sequence of parameters `func` depends on but does not
168
+ explicitly take as arguments.
169
+ :param flag: if False, disable gradient checkpointing.
170
+ """
171
+ if flag:
172
+ args = tuple(inputs) + tuple(params)
173
+ return CheckpointFunction.apply(func, len(inputs), *args)
174
+ else:
175
+ return func(*inputs)
176
+
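A minimal usage sketch for the checkpoint wrapper above; the layer and inputs are hypothetical:

import torch
import torch.nn as nn

layer = nn.Sequential(nn.Conv2d(4, 4, 3, padding=1), nn.SiLU())
x = torch.randn(2, 4, 16, 16, requires_grad=True)
# with flag=True, the activations of `layer` are recomputed during backward instead of stored
y = checkpoint(lambda inp: layer(inp), (x,), layer.parameters(), True)
y.sum().backward()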
177
+
178
+ class CheckpointFunction(torch.autograd.Function):
179
+ @staticmethod
180
+ def forward(ctx, run_function, length, *args):
181
+ ctx.run_function = run_function
182
+ ctx.input_tensors = list(args[:length])
183
+ ctx.input_params = list(args[length:])
184
+
185
+ with torch.no_grad():
186
+ output_tensors = ctx.run_function(*ctx.input_tensors)
187
+ return output_tensors
188
+
189
+ @staticmethod
190
+ def backward(ctx, *output_grads):
191
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
192
+ with torch.enable_grad():
193
+ # Fixes a bug where the first op in run_function modifies the
194
+ # Tensor storage in place, which is not allowed for detach()'d
195
+ # Tensors.
196
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
197
+ output_tensors = ctx.run_function(*shallow_copies)
198
+ input_grads = torch.autograd.grad(
199
+ output_tensors,
200
+ ctx.input_tensors + ctx.input_params,
201
+ output_grads,
202
+ allow_unused=True,
203
+ )
204
+ del ctx.input_tensors
205
+ del ctx.input_params
206
+ del output_tensors
207
+ return (None, None) + input_grads
208
+
209
+
210
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
211
+ """
212
+ Create sinusoidal timestep embeddings.
213
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
214
+ These may be fractional.
215
+ :param dim: the dimension of the output.
216
+ :param max_period: controls the minimum frequency of the embeddings.
217
+ :return: an [N x dim] Tensor of positional embeddings.
218
+ """
219
+ if not repeat_only:
220
+ half = dim // 2
221
+ freqs = torch.exp(
222
+ -math.log(max_period)
223
+ * torch.arange(start=0, end=half, dtype=torch.float32)
224
+ / half
225
+ ).to(device=timesteps.device)
226
+ args = timesteps[:, None].float() * freqs[None]
227
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
228
+ if dim % 2:
229
+ embedding = torch.cat(
230
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
231
+ )
232
+ else:
233
+ embedding = repeat(timesteps, "b -> b d", d=dim)
234
+ # import pdb; pdb.set_trace()
235
+ return embedding
236
+
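For example, the sinusoidal embedding above maps a batch of timesteps to fixed-size vectors:

import torch

t = torch.tensor([0, 250, 999])
emb = timestep_embedding(t, dim=320)
print(emb.shape)  # torch.Size([3, 320]); cos/sin values lie in [-1, 1]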
237
+
238
+ def zero_module(module):
239
+ """
240
+ Zero out the parameters of a module and return it.
241
+ """
242
+ for p in module.parameters():
243
+ p.detach().zero_()
244
+ return module
245
+
246
+
247
+ def scale_module(module, scale):
248
+ """
249
+ Scale the parameters of a module and return it.
250
+ """
251
+ for p in module.parameters():
252
+ p.detach().mul_(scale)
253
+ return module
254
+
255
+
256
+ def mean_flat(tensor):
257
+ """
258
+ Take the mean over all non-batch dimensions.
259
+ """
260
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
261
+
262
+
263
+ def normalization(channels):
264
+ """
265
+ Make a standard normalization layer.
266
+ :param channels: number of input channels.
267
+ :return: an nn.Module for normalization.
268
+ """
269
+ return GroupNorm32(32, channels)
270
+
271
+
272
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
273
+ class SiLU(nn.Module):
274
+ def forward(self, x):
275
+ return x * torch.sigmoid(x)
276
+
277
+
278
+ class GroupNorm32(nn.GroupNorm):
279
+ def forward(self, x):
280
+ return super().forward(x.float()).type(x.dtype)
281
+
282
+
283
+ def conv_nd(dims, *args, **kwargs):
284
+ """
285
+ Create a 1D, 2D, or 3D convolution module.
286
+ """
287
+ if dims == 1:
288
+ return nn.Conv1d(*args, **kwargs)
289
+ elif dims == 2:
290
+ return nn.Conv2d(*args, **kwargs)
291
+ elif dims == 3:
292
+ return nn.Conv3d(*args, **kwargs)
293
+ raise ValueError(f"unsupported dimensions: {dims}")
294
+
295
+
296
+ def linear(*args, **kwargs):
297
+ """
298
+ Create a linear module.
299
+ """
300
+ return nn.Linear(*args, **kwargs)
301
+
302
+
303
+ def avg_pool_nd(dims, *args, **kwargs):
304
+ """
305
+ Create a 1D, 2D, or 3D average pooling module.
306
+ """
307
+ if dims == 1:
308
+ return nn.AvgPool1d(*args, **kwargs)
309
+ elif dims == 2:
310
+ return nn.AvgPool2d(*args, **kwargs)
311
+ elif dims == 3:
312
+ return nn.AvgPool3d(*args, **kwargs)
313
+ raise ValueError(f"unsupported dimensions: {dims}")
314
+
315
+
316
+ class HybridConditioner(nn.Module):
317
+ def __init__(self, c_concat_config, c_crossattn_config):
318
+ super().__init__()
319
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
320
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
321
+
322
+ def forward(self, c_concat, c_crossattn):
323
+ c_concat = self.concat_conditioner(c_concat)
324
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
325
+ return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}
326
+
327
+
328
+ def noise_like(shape, device, repeat=False):
329
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
330
+ shape[0], *((1,) * (len(shape) - 1))
331
+ )
332
+ noise = lambda: torch.randn(shape, device=device)
333
+ return repeat_noise() if repeat else noise()
334
+
335
+
336
+ # dummy replace
337
+ def convert_module_to_f16(l):
338
+ """
339
+ Convert primitive modules to float16.
340
+ """
341
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
342
+ l.weight.data = l.weight.data.half()
343
+ if l.bias is not None:
344
+ l.bias.data = l.bias.data.half()
345
+
346
+ def convert_module_to_f32(l):
347
+ """
348
+ Convert primitive modules to float32, undoing convert_module_to_f16().
349
+ """
350
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
351
+ l.weight.data = l.weight.data.float()
352
+ if l.bias is not None:
353
+ l.bias.data = l.bias.data.float()
imagedream/ldm/modules/distributions/__init__.py ADDED
File without changes
imagedream/ldm/modules/distributions/distributions.py ADDED
@@ -0,0 +1,102 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ class AbstractDistribution:
6
+ def sample(self):
7
+ raise NotImplementedError()
8
+
9
+ def mode(self):
10
+ raise NotImplementedError()
11
+
12
+
13
+ class DiracDistribution(AbstractDistribution):
14
+ def __init__(self, value):
15
+ self.value = value
16
+
17
+ def sample(self):
18
+ return self.value
19
+
20
+ def mode(self):
21
+ return self.value
22
+
23
+
24
+ class DiagonalGaussianDistribution(object):
25
+ def __init__(self, parameters, deterministic=False):
26
+ self.parameters = parameters
27
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29
+ self.deterministic = deterministic
30
+ self.std = torch.exp(0.5 * self.logvar)
31
+ self.var = torch.exp(self.logvar)
32
+ if self.deterministic:
33
+ self.var = self.std = torch.zeros_like(self.mean).to(
34
+ device=self.parameters.device
35
+ )
36
+
37
+ def sample(self):
38
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(
39
+ device=self.parameters.device
40
+ )
41
+ return x
42
+
43
+ def kl(self, other=None):
44
+ if self.deterministic:
45
+ return torch.Tensor([0.0])
46
+ else:
47
+ if other is None:
48
+ return 0.5 * torch.sum(
49
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
50
+ dim=[1, 2, 3],
51
+ )
52
+ else:
53
+ return 0.5 * torch.sum(
54
+ torch.pow(self.mean - other.mean, 2) / other.var
55
+ + self.var / other.var
56
+ - 1.0
57
+ - self.logvar
58
+ + other.logvar,
59
+ dim=[1, 2, 3],
60
+ )
61
+
62
+ def nll(self, sample, dims=[1, 2, 3]):
63
+ if self.deterministic:
64
+ return torch.Tensor([0.0])
65
+ logtwopi = np.log(2.0 * np.pi)
66
+ return 0.5 * torch.sum(
67
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
68
+ dim=dims,
69
+ )
70
+
71
+ def mode(self):
72
+ return self.mean
73
+
74
+
75
+ def normal_kl(mean1, logvar1, mean2, logvar2):
76
+ """
77
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
78
+ Compute the KL divergence between two gaussians.
79
+ Shapes are automatically broadcasted, so batches can be compared to
80
+ scalars, among other use cases.
81
+ """
82
+ tensor = None
83
+ for obj in (mean1, logvar1, mean2, logvar2):
84
+ if isinstance(obj, torch.Tensor):
85
+ tensor = obj
86
+ break
87
+ assert tensor is not None, "at least one argument must be a Tensor"
88
+
89
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
90
+ # Tensors, but it does not work for torch.exp().
91
+ logvar1, logvar2 = [
92
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
93
+ for x in (logvar1, logvar2)
94
+ ]
95
+
96
+ return 0.5 * (
97
+ -1.0
98
+ + logvar2
99
+ - logvar1
100
+ + torch.exp(logvar1 - logvar2)
101
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
102
+ )
imagedream/ldm/modules/ema.py ADDED
@@ -0,0 +1,86 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class LitEma(nn.Module):
6
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
7
+ super().__init__()
8
+ if decay < 0.0 or decay > 1.0:
9
+ raise ValueError("Decay must be between 0 and 1")
10
+
11
+ self.m_name2s_name = {}
12
+ self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
13
+ self.register_buffer(
14
+ "num_updates",
15
+ torch.tensor(0, dtype=torch.int)
16
+ if use_num_upates
17
+ else torch.tensor(-1, dtype=torch.int),
18
+ )
19
+
20
+ for name, p in model.named_parameters():
21
+ if p.requires_grad:
22
+ # remove as '.'-character is not allowed in buffers
23
+ s_name = name.replace(".", "")
24
+ self.m_name2s_name.update({name: s_name})
25
+ self.register_buffer(s_name, p.clone().detach().data)
26
+
27
+ self.collected_params = []
28
+
29
+ def reset_num_updates(self):
30
+ del self.num_updates
31
+ self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))
32
+
33
+ def forward(self, model):
34
+ decay = self.decay
35
+
36
+ if self.num_updates >= 0:
37
+ self.num_updates += 1
38
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
39
+
40
+ one_minus_decay = 1.0 - decay
41
+
42
+ with torch.no_grad():
43
+ m_param = dict(model.named_parameters())
44
+ shadow_params = dict(self.named_buffers())
45
+
46
+ for key in m_param:
47
+ if m_param[key].requires_grad:
48
+ sname = self.m_name2s_name[key]
49
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
50
+ shadow_params[sname].sub_(
51
+ one_minus_decay * (shadow_params[sname] - m_param[key])
52
+ )
53
+ else:
54
+ assert key not in self.m_name2s_name
55
+
56
+ def copy_to(self, model):
57
+ m_param = dict(model.named_parameters())
58
+ shadow_params = dict(self.named_buffers())
59
+ for key in m_param:
60
+ if m_param[key].requires_grad:
61
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
62
+ else:
63
+ assert key not in self.m_name2s_name
64
+
65
+ def store(self, parameters):
66
+ """
67
+ Save the current parameters for restoring later.
68
+ Args:
69
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
70
+ temporarily stored.
71
+ """
72
+ self.collected_params = [param.clone() for param in parameters]
73
+
74
+ def restore(self, parameters):
75
+ """
76
+ Restore the parameters stored with the `store` method.
77
+ Useful to validate the model with EMA parameters without affecting the
78
+ original optimization process. Store the parameters before the
79
+ `copy_to` method. After validation (or model saving), use this to
80
+ restore the former parameters.
81
+ Args:
82
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
83
+ updated with the stored parameters.
84
+ """
85
+ for c_param, param in zip(self.collected_params, parameters):
86
+ param.data.copy_(c_param.data)
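
Note: a minimal usage sketch for `LitEma` (not part of this commit): update the shadow weights once per optimizer step, then temporarily swap them in for evaluation with `store`/`copy_to`/`restore`. The linear model and training loop are toy placeholders.

```python
import torch
from torch import nn

from imagedream.ldm.modules.ema import LitEma

model = nn.Linear(8, 8)                  # toy stand-in for the diffusion UNet
ema = LitEma(model, decay=0.9999)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

for _ in range(10):                      # toy training loop
    loss = model(torch.randn(4, 8)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema(model)                           # update shadow weights after each step

ema.store(model.parameters())            # stash the raw training weights
ema.copy_to(model)                       # evaluate with the EMA weights
with torch.no_grad():
    _ = model(torch.randn(4, 8))
ema.restore(model.parameters())          # put the training weights back
```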
imagedream/ldm/modules/encoders/__init__.py ADDED
File without changes
imagedream/ldm/modules/encoders/modules.py ADDED
@@ -0,0 +1,329 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.checkpoint import checkpoint
4
+
5
+ from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
6
+
7
+ import numpy as np
8
+ import open_clip
9
+ from PIL import Image
10
+ from ...util import default, count_params
11
+
12
+
13
+ class AbstractEncoder(nn.Module):
14
+ def __init__(self):
15
+ super().__init__()
16
+
17
+ def encode(self, *args, **kwargs):
18
+ raise NotImplementedError
19
+
20
+
21
+ class IdentityEncoder(AbstractEncoder):
22
+ def encode(self, x):
23
+ return x
24
+
25
+
26
+ class ClassEmbedder(nn.Module):
27
+ def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1):
28
+ super().__init__()
29
+ self.key = key
30
+ self.embedding = nn.Embedding(n_classes, embed_dim)
31
+ self.n_classes = n_classes
32
+ self.ucg_rate = ucg_rate
33
+
34
+ def forward(self, batch, key=None, disable_dropout=False):
35
+ if key is None:
36
+ key = self.key
37
+ # this is for use in crossattn
38
+ c = batch[key][:, None]
39
+ if self.ucg_rate > 0.0 and not disable_dropout:
40
+ mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
41
+ c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
42
+ c = c.long()
43
+ c = self.embedding(c)
44
+ return c
45
+
46
+ def get_unconditional_conditioning(self, bs, device="cuda"):
47
+ uc_class = (
48
+ self.n_classes - 1
49
+ ) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
50
+ uc = torch.ones((bs,), device=device) * uc_class
51
+ uc = {self.key: uc}
52
+ return uc
53
+
54
+
55
+ def disabled_train(self, mode=True):
56
+ """Overwrite model.train with this function to make sure train/eval mode
57
+ does not change anymore."""
58
+ return self
59
+
60
+
61
+ class FrozenT5Embedder(AbstractEncoder):
62
+ """Uses the T5 transformer encoder for text"""
63
+
64
+ def __init__(
65
+ self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True
66
+ ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
67
+ super().__init__()
68
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
69
+ self.transformer = T5EncoderModel.from_pretrained(version)
70
+ self.device = device
71
+ self.max_length = max_length # TODO: typical value?
72
+ if freeze:
73
+ self.freeze()
74
+
75
+ def freeze(self):
76
+ self.transformer = self.transformer.eval()
77
+ # self.train = disabled_train
78
+ for param in self.parameters():
79
+ param.requires_grad = False
80
+
81
+ def forward(self, text):
82
+ batch_encoding = self.tokenizer(
83
+ text,
84
+ truncation=True,
85
+ max_length=self.max_length,
86
+ return_length=True,
87
+ return_overflowing_tokens=False,
88
+ padding="max_length",
89
+ return_tensors="pt",
90
+ )
91
+ tokens = batch_encoding["input_ids"].to(self.device)
92
+ outputs = self.transformer(input_ids=tokens)
93
+
94
+ z = outputs.last_hidden_state
95
+ return z
96
+
97
+ def encode(self, text):
98
+ return self(text)
99
+
100
+
101
+ class FrozenCLIPEmbedder(AbstractEncoder):
102
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
103
+
104
+ LAYERS = ["last", "pooled", "hidden"]
105
+
106
+ def __init__(
107
+ self,
108
+ version="openai/clip-vit-large-patch14",
109
+ device="cuda",
110
+ max_length=77,
111
+ freeze=True,
112
+ layer="last",
113
+ layer_idx=None,
114
+ ): # clip-vit-base-patch32
115
+ super().__init__()
116
+ assert layer in self.LAYERS
117
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
118
+ self.transformer = CLIPTextModel.from_pretrained(version)
119
+ self.device = device
120
+ self.max_length = max_length
121
+ if freeze:
122
+ self.freeze()
123
+ self.layer = layer
124
+ self.layer_idx = layer_idx
125
+ if layer == "hidden":
126
+ assert layer_idx is not None
127
+ assert 0 <= abs(layer_idx) <= 12
128
+
129
+ def freeze(self):
130
+ self.transformer = self.transformer.eval()
131
+ # self.train = disabled_train
132
+ for param in self.parameters():
133
+ param.requires_grad = False
134
+
135
+ def forward(self, text):
136
+ batch_encoding = self.tokenizer(
137
+ text,
138
+ truncation=True,
139
+ max_length=self.max_length,
140
+ return_length=True,
141
+ return_overflowing_tokens=False,
142
+ padding="max_length",
143
+ return_tensors="pt",
144
+ )
145
+ tokens = batch_encoding["input_ids"].to(self.device)
146
+ outputs = self.transformer(
147
+ input_ids=tokens, output_hidden_states=self.layer == "hidden"
148
+ )
149
+ if self.layer == "last":
150
+ z = outputs.last_hidden_state
151
+ elif self.layer == "pooled":
152
+ z = outputs.pooler_output[:, None, :]
153
+ else:
154
+ z = outputs.hidden_states[self.layer_idx]
155
+ return z
156
+
157
+ def encode(self, text):
158
+ return self(text)
159
+
160
+
161
+ class FrozenOpenCLIPEmbedder(AbstractEncoder, nn.Module):
162
+ """
163
+ Uses the OpenCLIP transformer encoder for text
164
+ """
165
+
166
+ LAYERS = [
167
+ # "pooled",
168
+ "last",
169
+ "penultimate",
170
+ ]
171
+
172
+ def __init__(
173
+ self,
174
+ arch="ViT-H-14",
175
+ version="laion2b_s32b_b79k",
176
+ device="cuda",
177
+ max_length=77,
178
+ freeze=True,
179
+ layer="last",
180
+ ip_mode=None
181
+ ):
182
+ """_summary_
183
+
184
+ Args:
185
+ ip_mode (str, optional): what is the image promcessing mode. Defaults to None.
186
+
187
+ """
188
+ super().__init__()
189
+ assert layer in self.LAYERS
190
+ model, _, preprocess = open_clip.create_model_and_transforms(
191
+ arch, device=torch.device("cpu"), pretrained=version
192
+ )
193
+ if ip_mode is None:
194
+ del model.visual
195
+
196
+ self.model = model
197
+ self.preprocess = preprocess
198
+ self.device = device
199
+ self.max_length = max_length
200
+ self.ip_mode = ip_mode
201
+ if freeze:
202
+ self.freeze()
203
+ self.layer = layer
204
+ if self.layer == "last":
205
+ self.layer_idx = 0
206
+ elif self.layer == "penultimate":
207
+ self.layer_idx = 1
208
+ else:
209
+ raise NotImplementedError()
210
+
211
+ def freeze(self):
212
+ self.model = self.model.eval()
213
+ for param in self.parameters():
214
+ param.requires_grad = False
215
+
216
+ def forward(self, text):
217
+ tokens = open_clip.tokenize(text)
218
+ z = self.encode_with_transformer(tokens.to(self.device))
219
+ return z
220
+
221
+ def forward_image(self, pil_image):
222
+ if isinstance(pil_image, Image.Image):
223
+ pil_image = [pil_image]
224
+ if isinstance(pil_image, torch.Tensor):
225
+ pil_image = pil_image.cpu().numpy()
226
+ if isinstance(pil_image, np.ndarray):
227
+ if pil_image.ndim == 3:
228
+ pil_image = pil_image[None, :, :, :]
229
+ pil_image = [Image.fromarray(x) for x in pil_image]
230
+
231
+ images = []
232
+ for image in pil_image:
233
+ images.append(self.preprocess(image).to(self.device))
234
+
235
+ image = torch.stack(images, 0) # to [b, 3, h, w]
236
+ if self.ip_mode == "global":
237
+ image_features = self.model.encode_image(image)
238
+ image_features /= image_features.norm(dim=-1, keepdim=True)
239
+ elif "local" in self.ip_mode:
240
+ image_features = self.encode_image_with_transformer(image)
241
+
242
+ return image_features # b, l
243
+
244
+ def encode_image_with_transformer(self, x):
245
+ visual = self.model.visual
246
+ x = visual.conv1(x) # shape = [*, width, grid, grid]
247
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
248
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
249
+
250
+ # class embeddings and positional embeddings
251
+ x = torch.cat(
252
+ [visual.class_embedding.to(x.dtype) + \
253
+ torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
254
+ x], dim=1) # shape = [*, grid ** 2 + 1, width]
255
+ x = x + visual.positional_embedding.to(x.dtype)
256
+
257
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
258
+ # x = visual.patch_dropout(x)
259
+ x = visual.ln_pre(x)
260
+
261
+ x = x.permute(1, 0, 2) # NLD -> LND
262
+ hidden = self.image_transformer_forward(x)
263
+ x = hidden[-2].permute(1, 0, 2) # LND -> NLD
264
+ return x
265
+
266
+ def image_transformer_forward(self, x):
267
+ encoder_states = ()
268
+ trans = self.model.visual.transformer
269
+ for r in trans.resblocks:
270
+ if trans.grad_checkpointing and not torch.jit.is_scripting():
271
+ # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
272
+ x = checkpoint(r, x, None, None, None)
273
+ else:
274
+ x = r(x, attn_mask=None)
275
+ encoder_states = encoder_states + (x, )
276
+ return encoder_states
277
+
278
+ def encode_with_transformer(self, text):
279
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
280
+ x = x + self.model.positional_embedding
281
+ x = x.permute(1, 0, 2) # NLD -> LND
282
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
283
+ x = x.permute(1, 0, 2) # LND -> NLD
284
+ x = self.model.ln_final(x)
285
+ return x
286
+
287
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
288
+ for i, r in enumerate(self.model.transformer.resblocks):
289
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
290
+ break
291
+ if (
292
+ self.model.transformer.grad_checkpointing
293
+ and not torch.jit.is_scripting()
294
+ ):
295
+ x = checkpoint(r, x, attn_mask)
296
+ else:
297
+ x = r(x, attn_mask=attn_mask)
298
+ return x
299
+
300
+ def encode(self, text):
301
+ return self(text)
302
+
303
+
304
+ class FrozenCLIPT5Encoder(AbstractEncoder):
305
+ def __init__(
306
+ self,
307
+ clip_version="openai/clip-vit-large-patch14",
308
+ t5_version="google/t5-v1_1-xl",
309
+ device="cuda",
310
+ clip_max_length=77,
311
+ t5_max_length=77,
312
+ ):
313
+ super().__init__()
314
+ self.clip_encoder = FrozenCLIPEmbedder(
315
+ clip_version, device, max_length=clip_max_length
316
+ )
317
+ self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
318
+ print(
319
+ f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
320
+ f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params."
321
+ )
322
+
323
+ def encode(self, text):
324
+ return self(text)
325
+
326
+ def forward(self, text):
327
+ clip_z = self.clip_encoder.encode(text)
328
+ t5_z = self.t5_encoder.encode(text)
329
+ return [clip_z, t5_z]
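
Note: a hypothetical snippet (not part of this commit) showing how the frozen text encoders above are used. It downloads the `openai/clip-vit-large-patch14` checkpoint on first run; the expected output shape assumes that checkpoint (hidden size 768).

```python
import torch

from imagedream.ldm.modules.encoders.modules import FrozenCLIPEmbedder

device = "cuda" if torch.cuda.is_available() else "cpu"
encoder = FrozenCLIPEmbedder(device=device).to(device)  # weights are frozen in __init__
with torch.no_grad():
    z = encoder.encode(["a DSLR photo of a wooden chair"])
print(z.shape)  # expected torch.Size([1, 77, 768]) for clip-vit-large-patch14
```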
imagedream/ldm/util.py ADDED
@@ -0,0 +1,226 @@
1
+ import importlib
2
+
3
+ import random
4
+ import torch
5
+ import numpy as np
6
+ from collections import abc
7
+
8
+ import multiprocessing as mp
9
+ from threading import Thread
10
+ from queue import Queue
11
+
12
+ from inspect import isfunction
13
+ from PIL import Image, ImageDraw, ImageFont
14
+
15
+
16
+ def log_txt_as_img(wh, xc, size=10):
17
+ # wh a tuple of (width, height)
18
+ # xc a list of captions to plot
19
+ b = len(xc)
20
+ txts = list()
21
+ for bi in range(b):
22
+ txt = Image.new("RGB", wh, color="white")
23
+ draw = ImageDraw.Draw(txt)
24
+ font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
25
+ nc = int(40 * (wh[0] / 256))
26
+ lines = "\n".join(
27
+ xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
28
+ )
29
+
30
+ try:
31
+ draw.text((0, 0), lines, fill="black", font=font)
32
+ except UnicodeEncodeError:
33
+ print("Cant encode string for logging. Skipping.")
34
+
35
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
36
+ txts.append(txt)
37
+ txts = np.stack(txts)
38
+ txts = torch.tensor(txts)
39
+ return txts
40
+
41
+
42
+ def ismap(x):
43
+ if not isinstance(x, torch.Tensor):
44
+ return False
45
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
46
+
47
+
48
+ def isimage(x):
49
+ if not isinstance(x, torch.Tensor):
50
+ return False
51
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
52
+
53
+
54
+ def exists(x):
55
+ return x is not None
56
+
57
+
58
+ def default(val, d):
59
+ if exists(val):
60
+ return val
61
+ return d() if isfunction(d) else d
62
+
63
+
64
+ def mean_flat(tensor):
65
+ """
66
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
67
+ Take the mean over all non-batch dimensions.
68
+ """
69
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
70
+
71
+
72
+ def count_params(model, verbose=False):
73
+ total_params = sum(p.numel() for p in model.parameters())
74
+ if verbose:
75
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
76
+ return total_params
77
+
78
+
79
+ def instantiate_from_config(config):
80
+ if not "target" in config:
81
+ if config == "__is_first_stage__":
82
+ return None
83
+ elif config == "__is_unconditional__":
84
+ return None
85
+ raise KeyError("Expected key `target` to instantiate.")
86
+ # import pdb; pdb.set_trace()
87
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
88
+
89
+
90
+ def get_obj_from_str(string, reload=False):
91
+ module, cls = string.rsplit(".", 1)
92
+ # import pdb; pdb.set_trace()
93
+ if reload:
94
+ module_imp = importlib.import_module(module)
95
+ importlib.reload(module_imp)
96
+ return getattr(importlib.import_module(module, package=None), cls)
97
+
98
+
99
+ def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
100
+ # create dummy dataset instance
101
+
102
+ # run prefetching
103
+ if idx_to_fn:
104
+ res = func(data, worker_id=idx)
105
+ else:
106
+ res = func(data)
107
+ Q.put([idx, res])
108
+ Q.put("Done")
109
+
110
+
111
+ def parallel_data_prefetch(
112
+ func: callable,
113
+ data,
114
+ n_proc,
115
+ target_data_type="ndarray",
116
+ cpu_intensive=True,
117
+ use_worker_id=False,
118
+ ):
119
+ # if target_data_type not in ["ndarray", "list"]:
120
+ # raise ValueError(
121
+ # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
122
+ # )
123
+ if isinstance(data, np.ndarray) and target_data_type == "list":
124
+ raise ValueError("list expected but function got ndarray.")
125
+ elif isinstance(data, abc.Iterable):
126
+ if isinstance(data, dict):
127
+ print(
128
+ f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
129
+ )
130
+ data = list(data.values())
131
+ if target_data_type == "ndarray":
132
+ data = np.asarray(data)
133
+ else:
134
+ data = list(data)
135
+ else:
136
+ raise TypeError(
137
+ f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
138
+ )
139
+
140
+ if cpu_intensive:
141
+ Q = mp.Queue(1000)
142
+ proc = mp.Process
143
+ else:
144
+ Q = Queue(1000)
145
+ proc = Thread
146
+ # spawn processes
147
+ if target_data_type == "ndarray":
148
+ arguments = [
149
+ [func, Q, part, i, use_worker_id]
150
+ for i, part in enumerate(np.array_split(data, n_proc))
151
+ ]
152
+ else:
153
+ step = (
154
+ int(len(data) / n_proc + 1)
155
+ if len(data) % n_proc != 0
156
+ else int(len(data) / n_proc)
157
+ )
158
+ arguments = [
159
+ [func, Q, part, i, use_worker_id]
160
+ for i, part in enumerate(
161
+ [data[i : i + step] for i in range(0, len(data), step)]
162
+ )
163
+ ]
164
+ processes = []
165
+ for i in range(n_proc):
166
+ p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
167
+ processes += [p]
168
+
169
+ # start processes
170
+ print(f"Start prefetching...")
171
+ import time
172
+
173
+ start = time.time()
174
+ gather_res = [[] for _ in range(n_proc)]
175
+ try:
176
+ for p in processes:
177
+ p.start()
178
+
179
+ k = 0
180
+ while k < n_proc:
181
+ # get result
182
+ res = Q.get()
183
+ if res == "Done":
184
+ k += 1
185
+ else:
186
+ gather_res[res[0]] = res[1]
187
+
188
+ except Exception as e:
189
+ print("Exception: ", e)
190
+ for p in processes:
191
+ p.terminate()
192
+
193
+ raise e
194
+ finally:
195
+ for p in processes:
196
+ p.join()
197
+ print(f"Prefetching complete. [{time.time() - start} sec.]")
198
+
199
+ if target_data_type == "ndarray":
200
+ if not isinstance(gather_res[0], np.ndarray):
201
+ return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
202
+
203
+ # order outputs
204
+ return np.concatenate(gather_res, axis=0)
205
+ elif target_data_type == "list":
206
+ out = []
207
+ for r in gather_res:
208
+ out.extend(r)
209
+ return out
210
+ else:
211
+ return gather_res
212
+
213
+ def set_seed(seed=None):
214
+ random.seed(seed)
215
+ np.random.seed(seed)
216
+ if seed is not None:
217
+ torch.manual_seed(seed)
218
+ torch.cuda.manual_seed_all(seed)
219
+
220
+ def add_random_background(image, bg_color=None):
221
+ bg_color = np.random.rand() * 255 if bg_color is None else bg_color
222
+ image = np.array(image)
223
+ rgb, alpha = image[..., :3], image[..., 3:]
224
+ alpha = alpha.astype(np.float32) / 255.0
225
+ image_new = rgb * alpha + bg_color * (1 - alpha)
226
+ return Image.fromarray(image_new.astype(np.uint8))
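
Note: `instantiate_from_config` is the glue used by the configs in this commit; the sketch below (not part of the repo) shows how a `target`/`params` dict is resolved. `torch.nn.Linear` is used purely as an illustrative target.

```python
from imagedream.ldm.util import instantiate_from_config, get_obj_from_str

config = {
    "target": "torch.nn.Linear",          # dotted import path, resolved by get_obj_from_str
    "params": {"in_features": 16, "out_features": 4},
}
layer = instantiate_from_config(config)   # equivalent to torch.nn.Linear(16, 4)
print(type(layer), get_obj_from_str(config["target"]))
```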
imagedream/model_zoo.py ADDED
@@ -0,0 +1,64 @@
1
+ """ Utiliy functions to load pre-trained models more easily """
2
+ import os
3
+ import pkg_resources
4
+ from omegaconf import OmegaConf
5
+
6
+ import torch
7
+ from huggingface_hub import hf_hub_download
8
+
9
+ from imagedream.ldm.util import instantiate_from_config
10
+
11
+
12
+ PRETRAINED_MODELS = {
13
+ "sd-v2.1-base-4view-ipmv": {
14
+ "config": "sd_v2_base_ipmv.yaml",
15
+ "repo_id": "Peng-Wang/ImageDream",
16
+ "filename": "sd-v2.1-base-4view-ipmv.pt",
17
+ },
18
+ "sd-v2.1-base-4view-ipmv-local": {
19
+ "config": "sd_v2_base_ipmv_local.yaml",
20
+ "repo_id": "Peng-Wang/ImageDream",
21
+ "filename": "sd-v2.1-base-4view-ipmv-local.pt",
22
+ },
23
+ }
24
+
25
+
26
+ def get_config_file(config_path):
27
+ cfg_file = pkg_resources.resource_filename(
28
+ "imagedream", os.path.join("configs", config_path)
29
+ )
30
+ if not os.path.exists(cfg_file):
31
+ raise RuntimeError(f"Config {config_path} not available!")
32
+ return cfg_file
33
+
34
+
35
+ def build_model(model_name, config_path=None, ckpt_path=None, cache_dir=None):
36
+ if (config_path is not None) and (ckpt_path is not None):
37
+ config = OmegaConf.load(config_path)
38
+ model = instantiate_from_config(config.model)
39
+ model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
40
+ return model
41
+
42
+ if model_name not in PRETRAINED_MODELS:
43
+ raise RuntimeError(
44
+ f"Model name {model_name} is not a pre-trained model. Available models are:\n- "
45
+ + "\n- ".join(PRETRAINED_MODELS.keys())
46
+ )
47
+ model_info = PRETRAINED_MODELS[model_name]
48
+
49
+ # Instantiate the model
50
+ print(f"Loading model from config: {model_info['config']}")
51
+ config_file = get_config_file(model_info["config"])
52
+ config = OmegaConf.load(config_file)
53
+ model = instantiate_from_config(config.model)
54
+
55
+ # Load pre-trained checkpoint from huggingface
56
+ if not ckpt_path:
57
+ ckpt_path = hf_hub_download(
58
+ repo_id=model_info["repo_id"],
59
+ filename=model_info["filename"],
60
+ cache_dir=cache_dir,
61
+ )
62
+ print(f"Loading model from cache file: {ckpt_path}")
63
+ model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
64
+ return model
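
Note: a hypothetical loading sketch (not part of this commit). The model name must be one of the keys in `PRETRAINED_MODELS`; the checkpoint (several GB) is fetched from the HuggingFace Hub on first use, so this is illustrative only.

```python
import torch

from imagedream.model_zoo import build_model

model = build_model("sd-v2.1-base-4view-ipmv", cache_dir="./ckpts")
model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()
```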
inference.py ADDED
@@ -0,0 +1,91 @@
1
+ import numpy as np
2
+ import torch
3
+ import time
4
+ import nvdiffrast.torch as dr
5
+ from util.utils import get_tri
6
+ import tempfile
7
+ from mesh import Mesh
8
+ import zipfile
9
+ def generate3d(model, rgb, ccm, device):
10
+
11
+ color_tri = torch.from_numpy(rgb)/255
12
+ xyz_tri = torch.from_numpy(ccm[:,:,(2,1,0)])/255
13
+ color = color_tri.permute(2,0,1)
14
+ xyz = xyz_tri.permute(2,0,1)
15
+
16
+
17
+ def get_imgs(color):
18
+ # color : [C, H, W*6]
19
+ color_list = []
20
+ color_list.append(color[:,:,256*5:256*(1+5)])
21
+ for i in range(0,5):
22
+ color_list.append(color[:,:,256*i:256*(1+i)])
23
+ return torch.stack(color_list, dim=0)# [6, C, H, W]
24
+
25
+ triplane_color = get_imgs(color).permute(0,2,3,1).unsqueeze(0).to(device)# [1, 6, H, W, C]
26
+
27
+ color = get_imgs(color)
28
+ xyz = get_imgs(xyz)
29
+
30
+ color = get_tri(color, dim=0, blender= True, scale = 1).unsqueeze(0)
31
+ xyz = get_tri(xyz, dim=0, blender= True, scale = 1, fix= True).unsqueeze(0)
32
+
33
+ triplane = torch.cat([color,xyz],dim=1).to(device)
34
+ # 3D visualize
35
+ model.eval()
36
+ glctx = dr.RasterizeCudaContext()
37
+
38
+ if model.denoising:
39
+ tnew = 20
40
+ tnew = torch.randint(tnew, tnew+1, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
41
+ noise_new = torch.randn_like(triplane) *0.5+0.5
42
+ triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
43
+ start_time = time.time()
44
+ with torch.no_grad():
45
+ triplane_feature2 = model.unet2(triplane,tnew)
46
+ end_time = time.time()
47
+ elapsed_time = end_time - start_time
48
+ print(f"unet takes {elapsed_time}s")
49
+ else:
50
+ triplane_feature2 = model.unet2(triplane)
51
+
52
+
53
+ with torch.no_grad():
54
+ data_config = {
55
+ 'resolution': [1024, 1024],
56
+ "triview_color": triplane_color.to(device),
57
+ }
58
+
59
+ verts, faces = model.decode(data_config, triplane_feature2)
60
+
61
+ data_config['verts'] = verts[0]
62
+ data_config['faces'] = faces
63
+
64
+
65
+ from kiui.mesh_utils import clean_mesh
66
+ verts, faces = clean_mesh(data_config['verts'].squeeze().cpu().numpy().astype(np.float32), data_config['faces'].squeeze().cpu().numpy().astype(np.int32), repair = False, remesh=False, remesh_size=0.005)
67
+ data_config['verts'] = torch.from_numpy(verts).cuda().contiguous()
68
+ data_config['faces'] = torch.from_numpy(faces).cuda().contiguous()
69
+
70
+ start_time = time.time()
71
+ with torch.no_grad():
72
+ mesh_path_obj = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name
73
+ model.export_mesh_wt_uv(glctx, data_config, mesh_path_obj, "", device, res=(1024,1024), tri_fea_2=triplane_feature2)
74
+
75
+ mesh = Mesh.load(mesh_path_obj+".obj", bound=0.9, front_dir="+z")
76
+ mesh_path_glb = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name
77
+ mesh.write(mesh_path_glb+".glb")
78
+
79
+ # mesh_obj2 = trimesh.load(mesh_path_glb+".glb", file_type='glb')
80
+ # mesh_path_obj2 = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name
81
+ # mesh_obj2.export(mesh_path_obj2+".obj")
82
+
83
+ with zipfile.ZipFile(mesh_path_obj+'.zip', 'w') as myzip:
84
+ myzip.write(mesh_path_obj+'.obj', mesh_path_obj.split("/")[-1]+'.obj')
85
+ myzip.write(mesh_path_obj+'.png', mesh_path_obj.split("/")[-1]+'.png')
86
+ myzip.write(mesh_path_obj+'.mtl', mesh_path_obj.split("/")[-1]+'.mtl')
87
+
88
+ end_time = time.time()
89
+ elapsed_time = end_time - start_time
90
+ print(f"uv takes {elapsed_time}s")
91
+ return mesh_path_glb+".glb", mesh_path_obj+'.zip'
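
Note: an illustrative call sketch for `generate3d` (not part of this commit). Judging from `get_imgs`, both inputs are uint8 arrays of six horizontally stitched 256x256 views, i.e. shape [256, 1536, 3]; a loaded CRM model, nvdiffrast, and a CUDA device are required, so the actual call is left commented out.

```python
import numpy as np
# from inference import generate3d  # needs a loaded CRM model plus nvdiffrast/CUDA

rgb = np.zeros((256, 256 * 6, 3), dtype=np.uint8)  # six stitched RGB views
ccm = np.zeros((256, 256 * 6, 3), dtype=np.uint8)  # six stitched canonical coordinate maps
# glb_path, zip_path = generate3d(crm_model, rgb, ccm, "cuda")
```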
libs/base_utils.py ADDED
@@ -0,0 +1,84 @@
1
+ import numpy as np
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+
8
+ def instantiate_from_config(config):
9
+ if not "target" in config:
10
+ raise KeyError("Expected key `target` to instantiate.")
11
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
12
+
13
+
14
+ def get_obj_from_str(string, reload=False):
15
+ import importlib
16
+ module, cls = string.rsplit(".", 1)
17
+ if reload:
18
+ module_imp = importlib.import_module(module)
19
+ importlib.reload(module_imp)
20
+ return getattr(importlib.import_module(module, package=None), cls)
21
+
22
+
23
+ def tensor_detail(t):
24
+ assert type(t) == torch.Tensor
25
+ print(f"shape: {t.shape} mean: {t.mean():.2f}, std: {t.std():.2f}, min: {t.min():.2f}, max: {t.max():.2f}")
26
+
27
+
28
+
29
+ def drawRoundRec(draw, color, x, y, w, h, r):
30
+ drawObject = draw
31
+
32
+ '''Rounds'''
33
+ drawObject.ellipse((x, y, x + r, y + r), fill=color)
34
+ drawObject.ellipse((x + w - r, y, x + w, y + r), fill=color)
35
+ drawObject.ellipse((x, y + h - r, x + r, y + h), fill=color)
36
+ drawObject.ellipse((x + w - r, y + h - r, x + w, y + h), fill=color)
37
+
38
+ '''rec.s'''
39
+ drawObject.rectangle((x + r / 2, y, x + w - (r / 2), y + h), fill=color)
40
+ drawObject.rectangle((x, y + r / 2, x + w, y + h - (r / 2)), fill=color)
41
+
42
+
43
+ def do_resize_content(original_image: Image, scale_rate):
44
+ # resize the image content while retaining the original image size
45
+ if scale_rate != 1:
46
+ # Calculate the new size after rescaling
47
+ new_size = tuple(int(dim * scale_rate) for dim in original_image.size)
48
+ # Resize the image while maintaining the aspect ratio
49
+ resized_image = original_image.resize(new_size)
50
+ # Create a new image with the original size and black background
51
+ padded_image = Image.new("RGBA", original_image.size, (0, 0, 0, 0))
52
+ paste_position = ((original_image.width - resized_image.width) // 2, (original_image.height - resized_image.height) // 2)
53
+ padded_image.paste(resized_image, paste_position)
54
+ return padded_image
55
+ else:
56
+ return original_image
57
+
58
+ def add_stroke(img, color=(255, 255, 255), stroke_radius=3):
59
+ # color in R, G, B format
60
+ if isinstance(img, Image.Image):
61
+ assert img.mode == "RGBA"
62
+ img = cv2.cvtColor(np.array(img), cv2.COLOR_RGBA2BGRA)
63
+ else:
64
+ assert img.shape[2] == 4
65
+ gray = img[:,:, 3]
66
+ ret, binary = cv2.threshold(gray,127,255,cv2.THRESH_BINARY)
67
+ contours, hierarchy = cv2.findContours(binary,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
68
+ res = cv2.drawContours(img, contours,-1, tuple(color)[::-1] + (255,), stroke_radius)
69
+ return Image.fromarray(cv2.cvtColor(res,cv2.COLOR_BGRA2RGBA))
70
+
71
+ def make_blob(image_size=(512, 512), sigma=0.2):
72
+ """
73
+ make 2D blob image with:
74
+ I(x, y) = 1 - \exp\left(-\frac{(x - W/2)^2 + (y - H/2)^2}{2 \sigma^2 H W}\right)
75
+ """
76
+ import numpy as np
77
+ H, W = image_size
78
+ x = np.arange(0, W, 1, float)
79
+ y = np.arange(0, H, 1, float)
80
+ x, y = np.meshgrid(x, y)
81
+ x0 = W // 2
82
+ y0 = H // 2
83
+ img = 1 - np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2 * H * W))
84
+ return (img * 255).astype(np.uint8)
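
Note: a small demo of the helpers above (not part of this commit): build a radial blob mask with `make_blob`, outline the opaque region of an RGBA image with `add_stroke`, and shrink image content in place with `do_resize_content`. The solid red square is a placeholder input.

```python
from PIL import Image

from libs.base_utils import add_stroke, do_resize_content, make_blob

blob = Image.fromarray(make_blob((256, 256), sigma=0.2))          # grayscale blob mask

rgba = Image.new("RGBA", (256, 256), (0, 0, 0, 0))
rgba.paste(Image.new("RGBA", (128, 128), (200, 30, 30, 255)), (64, 64))
outlined = add_stroke(rgba, color=(255, 255, 255), stroke_radius=3)
shrunk = do_resize_content(rgba, scale_rate=0.8)                  # content scaled, canvas kept
print(blob.size, outlined.size, shrunk.size)
```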
libs/sample.py ADDED
@@ -0,0 +1,380 @@
1
+ import numpy as np
2
+ import torch
3
+ from imagedream.camera_utils import get_camera_for_index
4
+ from imagedream.ldm.util import set_seed, add_random_background
5
+ from libs.base_utils import do_resize_content
6
+ from imagedream.ldm.models.diffusion.ddim import DDIMSampler
7
+ from torchvision import transforms as T
8
+
9
+
10
+ class ImageDreamDiffusion:
11
+ def __init__(
12
+ self,
13
+ model,
14
+ device,
15
+ dtype,
16
+ mode,
17
+ num_frames,
18
+ camera_views,
19
+ ref_position,
20
+ random_background=False,
21
+ offset_noise=False,
22
+ resize_rate=1,
23
+ image_size=256,
24
+ seed=1234,
25
+ ) -> None:
26
+ assert mode in ["pixel", "local"]
27
+ size = image_size
28
+ self.seed = seed
29
+ batch_size = max(4, num_frames)
30
+
31
+ neg_texts = "uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear."
32
+ uc = model.get_learned_conditioning([neg_texts]).to(device)
33
+ sampler = DDIMSampler(model)
34
+
35
+ # pre-compute camera matrices
36
+ camera = [get_camera_for_index(i).squeeze() for i in camera_views]
37
+ camera[ref_position] = torch.zeros_like(camera[ref_position]) # set ref camera to zero
38
+ camera = torch.stack(camera)
39
+ camera = camera.repeat(batch_size // num_frames, 1).to(device)
40
+
41
+ self.image_transform = T.Compose(
42
+ [
43
+ T.Resize((size, size)),
44
+ T.ToTensor(),
45
+ T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
46
+ ]
47
+ )
48
+ self.dtype = dtype
49
+ self.ref_position = ref_position
50
+ self.mode = mode
51
+ self.random_background = random_background
52
+ self.resize_rate = resize_rate
53
+ self.num_frames = num_frames
54
+ self.size = size
55
+ self.device = device
56
+ self.batch_size = batch_size
57
+ self.model = model
58
+ self.sampler = sampler
59
+ self.uc = uc
60
+ self.camera = camera
61
+ self.offset_noise = offset_noise
62
+
63
+ @staticmethod
64
+ def i2i(
65
+ model,
66
+ image_size,
67
+ prompt,
68
+ uc,
69
+ sampler,
70
+ ip=None,
71
+ step=20,
72
+ scale=5.0,
73
+ batch_size=8,
74
+ ddim_eta=0.0,
75
+ dtype=torch.float32,
76
+ device="cuda",
77
+ camera=None,
78
+ num_frames=4,
79
+ pixel_control=False,
80
+ transform=None,
81
+ offset_noise=False,
82
+ ):
83
+ """ The function supports additional image prompt.
84
+ Args:
85
+ model (_type_): the image dream model
86
+ image_size (_type_): size of diffusion output (standard 256)
87
+ prompt (_type_): text prompt for the image (prompt in type str)
88
+ uc (_type_): unconditional vector (tensor in shape [1, 77, 1024])
89
+ sampler (_type_): imagedream.ldm.models.diffusion.ddim.DDIMSampler
90
+ ip (Image, optional): the image prompt. Defaults to None.
91
+ step (int, optional): number of DDIM sampling steps. Defaults to 20.
92
+ scale (float, optional): classifier-free guidance scale. Defaults to 5.0.
93
+ batch_size (int, optional): number of views generated per call. Defaults to 8.
94
+ ddim_eta (float, optional): DDIM eta; 0.0 gives deterministic sampling. Defaults to 0.0.
95
+ dtype (torch.dtype, optional): autocast dtype. Defaults to torch.float32.
96
+ device (str, optional): torch device. Defaults to "cuda".
97
+ camera (Tensor, optional): camera info, shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00
98
+ num_frames (int, optional): number of frames (views) to generate. Defaults to 4.
99
+ pixel_control: whether to use pixel conditioning. Defaults to False, True when using pixel mode
100
+ transform: Compose(
101
+ Resize(size=(256, 256), interpolation=bilinear, max_size=None, antialias=warn)
102
+ ToTensor()
103
+ Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
104
+ )
105
+ """
106
+ ip_raw = ip
107
+ if type(prompt) != list:
108
+ prompt = [prompt]
109
+ with torch.no_grad(), torch.autocast(device_type=torch.device(device).type, dtype=dtype):
110
+ c = model.get_learned_conditioning(prompt).to(
111
+ device
112
+ ) # shape: torch.Size([1, 77, 1024]) mean: -0.17, std: 1.02, min: -7.50, max: 13.05
113
+ c_ = {"context": c.repeat(batch_size, 1, 1)} # batch_size
114
+ uc_ = {"context": uc.repeat(batch_size, 1, 1)}
115
+
116
+ if camera is not None:
117
+ c_["camera"] = uc_["camera"] = (
118
+ camera # shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00
119
+ )
120
+ c_["num_frames"] = uc_["num_frames"] = num_frames
121
+
122
+ if ip is not None:
123
+ ip_embed = model.get_learned_image_conditioning(ip).to(
124
+ device
125
+ ) # shape: torch.Size([1, 257, 1280]) mean: 0.06, std: 0.53, min: -6.83, max: 11.12
126
+ ip_ = ip_embed.repeat(batch_size, 1, 1)
127
+ c_["ip"] = ip_
128
+ uc_["ip"] = torch.zeros_like(ip_)
129
+
130
+ if pixel_control:
131
+ assert camera is not None
132
+ ip = transform(ip).to(
133
+ device
134
+ ) # shape: torch.Size([3, 256, 256]) mean: 0.33, std: 0.37, min: -1.00, max: 1.00
135
+ ip_img = model.get_first_stage_encoding(
136
+ model.encode_first_stage(ip[None, :, :, :])
137
+ ) # shape: torch.Size([1, 4, 32, 32]) mean: 0.23, std: 0.77, min: -4.42, max: 3.55
138
+ c_["ip_img"] = ip_img
139
+ uc_["ip_img"] = torch.zeros_like(ip_img)
140
+
141
+ shape = [4, image_size // 8, image_size // 8] # [4, 32, 32]
142
+ if offset_noise:
143
+ ref = transform(ip_raw).to(device)
144
+ ref_latent = model.get_first_stage_encoding(model.encode_first_stage(ref[None, :, :, :]))
145
+ ref_mean = ref_latent.mean(dim=(-1, -2), keepdim=True)
146
+ time_steps = torch.randint(model.num_timesteps - 1, model.num_timesteps, (batch_size,), device=device)
147
+ x_T = model.q_sample(torch.ones([batch_size] + shape, device=device) * ref_mean, time_steps)
148
+
149
+ samples_ddim, _ = (
150
+ sampler.sample( # shape: torch.Size([5, 4, 32, 32]) mean: 0.29, std: 0.85, min: -3.38, max: 4.43
151
+ S=step,
152
+ conditioning=c_,
153
+ batch_size=batch_size,
154
+ shape=shape,
155
+ verbose=False,
156
+ unconditional_guidance_scale=scale,
157
+ unconditional_conditioning=uc_,
158
+ eta=ddim_eta,
159
+ x_T=x_T if offset_noise else None,
160
+ )
161
+ )
162
+
163
+ x_sample = model.decode_first_stage(samples_ddim)
164
+ x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
165
+ x_sample = 255.0 * x_sample.permute(0, 2, 3, 1).cpu().numpy()
166
+
167
+ return list(x_sample.astype(np.uint8))
168
+
169
+ def diffuse(self, t, ip, n_test=2):
170
+ set_seed(self.seed)
171
+ ip = do_resize_content(ip, self.resize_rate)
172
+ if self.random_background:
173
+ ip = add_random_background(ip)
174
+
175
+ images = []
176
+ for _ in range(n_test):
177
+ img = self.i2i(
178
+ self.model,
179
+ self.size,
180
+ t,
181
+ self.uc,
182
+ self.sampler,
183
+ ip=ip,
184
+ step=50,
185
+ scale=5,
186
+ batch_size=self.batch_size,
187
+ ddim_eta=0.0,
188
+ dtype=self.dtype,
189
+ device=self.device,
190
+ camera=self.camera,
191
+ num_frames=self.num_frames,
192
+ pixel_control=(self.mode == "pixel"),
193
+ transform=self.image_transform,
194
+ offset_noise=self.offset_noise,
195
+ )
196
+ img = np.concatenate(img, 1)
197
+ img = np.concatenate((img, ip.resize((self.size, self.size))), axis=1)
198
+ images.append(img)
199
+ set_seed() # unset random and numpy seed
200
+ return images
201
+
202
+
203
+ class ImageDreamDiffusionStage2:
204
+ def __init__(
205
+ self,
206
+ model,
207
+ device,
208
+ dtype,
209
+ num_frames,
210
+ camera_views,
211
+ ref_position,
212
+ random_background=False,
213
+ offset_noise=False,
214
+ resize_rate=1,
215
+ mode="pixel",
216
+ image_size=256,
217
+ seed=1234,
218
+ ) -> None:
219
+ assert mode in ["pixel", "local"]
220
+
221
+ size = image_size
222
+ self.seed = seed
223
+ batch_size = max(4, num_frames)
224
+
225
+ neg_texts = "uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear."
226
+ uc = model.get_learned_conditioning([neg_texts]).to(device)
227
+ sampler = DDIMSampler(model)
228
+
229
+ # pre-compute camera matrices
230
+ camera = [get_camera_for_index(i).squeeze() for i in camera_views]
231
+ if ref_position is not None:
232
+ camera[ref_position] = torch.zeros_like(camera[ref_position]) # set ref camera to zero
233
+ camera = torch.stack(camera)
234
+ camera = camera.repeat(batch_size // num_frames, 1).to(device)
235
+
236
+ self.image_transform = T.Compose(
237
+ [
238
+ T.Resize((size, size)),
239
+ T.ToTensor(),
240
+ T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
241
+ ]
242
+ )
243
+
244
+ self.dtype = dtype
245
+ self.mode = mode
246
+ self.ref_position = ref_position
247
+ self.random_background = random_background
248
+ self.resize_rate = resize_rate
249
+ self.num_frames = num_frames
250
+ self.size = size
251
+ self.device = device
252
+ self.batch_size = batch_size
253
+ self.model = model
254
+ self.sampler = sampler
255
+ self.uc = uc
256
+ self.camera = camera
257
+ self.offset_noise = offset_noise
258
+
259
+ @staticmethod
260
+ def i2iStage2(
261
+ model,
262
+ image_size,
263
+ prompt,
264
+ uc,
265
+ sampler,
266
+ pixel_images,
267
+ ip=None,
268
+ step=20,
269
+ scale=5.0,
270
+ batch_size=8,
271
+ ddim_eta=0.0,
272
+ dtype=torch.float32,
273
+ device="cuda",
274
+ camera=None,
275
+ num_frames=4,
276
+ pixel_control=False,
277
+ transform=None,
278
+ offset_noise=False,
279
+ ):
280
+ ip_raw = ip
281
+ if type(prompt) != list:
282
+ prompt = [prompt]
283
+ with torch.no_grad(), torch.autocast(device_type=torch.device(device).type, dtype=dtype):
284
+ c = model.get_learned_conditioning(prompt).to(
285
+ device
286
+ ) # shape: torch.Size([1, 77, 1024]) mean: -0.17, std: 1.02, min: -7.50, max: 13.05
287
+ c_ = {"context": c.repeat(batch_size, 1, 1)} # batch_size
288
+ uc_ = {"context": uc.repeat(batch_size, 1, 1)}
289
+
290
+ if camera is not None:
291
+ c_["camera"] = uc_["camera"] = (
292
+ camera # shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00
293
+ )
294
+ c_["num_frames"] = uc_["num_frames"] = num_frames
295
+
296
+ if ip is not None:
297
+ ip_embed = model.get_learned_image_conditioning(ip).to(
298
+ device
299
+ ) # shape: torch.Size([1, 257, 1280]) mean: 0.06, std: 0.53, min: -6.83, max: 11.12
300
+ ip_ = ip_embed.repeat(batch_size, 1, 1)
301
+ c_["ip"] = ip_
302
+ uc_["ip"] = torch.zeros_like(ip_)
303
+
304
+ if pixel_control:
305
+ assert camera is not None
306
+
307
+ transed_pixel_images = torch.stack([transform(i).to(device) for i in pixel_images])
308
+ latent_pixel_images = model.get_first_stage_encoding(model.encode_first_stage(transed_pixel_images))
309
+
310
+ c_["pixel_images"] = latent_pixel_images
311
+ uc_["pixel_images"] = torch.zeros_like(latent_pixel_images)
312
+
313
+ shape = [4, image_size // 8, image_size // 8] # [4, 32, 32]
314
+ if offset_noise:
315
+ ref = transform(ip_raw).to(device)
316
+ ref_latent = model.get_first_stage_encoding(model.encode_first_stage(ref[None, :, :, :]))
317
+ ref_mean = ref_latent.mean(dim=(-1, -2), keepdim=True)
318
+ time_steps = torch.randint(model.num_timesteps - 1, model.num_timesteps, (batch_size,), device=device)
319
+ x_T = model.q_sample(torch.ones([batch_size] + shape, device=device) * ref_mean, time_steps)
320
+
321
+ samples_ddim, _ = (
322
+ sampler.sample( # shape: torch.Size([5, 4, 32, 32]) mean: 0.29, std: 0.85, min: -3.38, max: 4.43
323
+ S=step,
324
+ conditioning=c_,
325
+ batch_size=batch_size,
326
+ shape=shape,
327
+ verbose=False,
328
+ unconditional_guidance_scale=scale,
329
+ unconditional_conditioning=uc_,
330
+ eta=ddim_eta,
331
+ x_T=x_T if offset_noise else None,
332
+ )
333
+ )
334
+ x_sample = model.decode_first_stage(samples_ddim)
335
+ x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
336
+ x_sample = 255.0 * x_sample.permute(0, 2, 3, 1).cpu().numpy()
337
+
338
+ return list(x_sample.astype(np.uint8))
339
+
340
+ @torch.no_grad()
341
+ def diffuse(self, t, ip, pixel_images, n_test=2):
342
+ set_seed(self.seed)
343
+ ip = do_resize_content(ip, self.resize_rate)
344
+ pixel_images = [do_resize_content(i, self.resize_rate) for i in pixel_images]
345
+
346
+ if self.random_background:
347
+ bg_color = np.random.rand() * 255
348
+ ip = add_random_background(ip, bg_color)
349
+ pixel_images = [add_random_background(i, bg_color) for i in pixel_images]
350
+
351
+ images = []
352
+ for _ in range(n_test):
353
+ img = self.i2iStage2(
354
+ self.model,
355
+ self.size,
356
+ t,
357
+ self.uc,
358
+ self.sampler,
359
+ pixel_images=pixel_images,
360
+ ip=ip,
361
+ step=50,
362
+ scale=5,
363
+ batch_size=self.batch_size,
364
+ ddim_eta=0.0,
365
+ dtype=self.dtype,
366
+ device=self.device,
367
+ camera=self.camera,
368
+ num_frames=self.num_frames,
369
+ pixel_control=(self.mode == "pixel"),
370
+ transform=self.image_transform,
371
+ offset_noise=self.offset_noise,
372
+ )
373
+ img = np.concatenate(img, 1)
374
+ img = np.concatenate(
375
+ (img, ip.resize((self.size, self.size)), *[i.resize((self.size, self.size)) for i in pixel_images]),
376
+ axis=1,
377
+ )
378
+ images.append(img)
379
+ set_seed() # unset random and numpy seed
380
+ return images
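
Note: a hypothetical end-to-end sketch (not part of this commit) wiring the stage-1 sampler to a pixel-conditioned ImageDream model. The checkpoint name comes from `model_zoo.py`; the camera view indices, reference position, prompt, and input path are illustrative placeholders rather than the values used by the app.

```python
import torch
from PIL import Image

from imagedream.model_zoo import build_model
from libs.sample import ImageDreamDiffusion

device = "cuda"
model = build_model("sd-v2.1-base-4view-ipmv").to(device).eval()

sampler = ImageDreamDiffusion(
    model,
    device=device,
    dtype=torch.float16,
    mode="pixel",                 # pixel conditioning (assumed to match this checkpoint)
    num_frames=4,
    camera_views=[0, 1, 2, 3],    # illustrative view indices
    ref_position=0,               # the reference view gets a zeroed camera
    random_background=False,
)
views = sampler.diffuse("3D assets.", Image.open("input.png").convert("RGB"), n_test=1)
```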
mesh.py ADDED
@@ -0,0 +1,845 @@
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import trimesh
5
+ import numpy as np
6
+
7
+ from kiui.op import safe_normalize, dot
8
+ from kiui.typing import *
9
+
10
+ class Mesh:
11
+ """
12
+ A torch-native trimesh class, with support for ``ply/obj/glb`` formats.
13
+
14
+ Note:
15
+ This class only supports one mesh with a single set of textures (an albedo texture and an optional metallic-roughness texture).
16
+ """
17
+ def __init__(
18
+ self,
19
+ v: Optional[Tensor] = None,
20
+ f: Optional[Tensor] = None,
21
+ vn: Optional[Tensor] = None,
22
+ fn: Optional[Tensor] = None,
23
+ vt: Optional[Tensor] = None,
24
+ ft: Optional[Tensor] = None,
25
+ vc: Optional[Tensor] = None, # vertex color
26
+ albedo: Optional[Tensor] = None,
27
+ metallicRoughness: Optional[Tensor] = None,
28
+ device: Optional[torch.device] = None,
29
+ ):
30
+ """Init a mesh directly using all attributes.
31
+
32
+ Args:
33
+ v (Optional[Tensor]): vertices, float [N, 3]. Defaults to None.
34
+ f (Optional[Tensor]): faces, int [M, 3]. Defaults to None.
35
+ vn (Optional[Tensor]): vertex normals, float [N, 3]. Defaults to None.
36
+ fn (Optional[Tensor]): faces for normals, int [M, 3]. Defaults to None.
37
+ vt (Optional[Tensor]): vertex uv coordinates, float [N, 2]. Defaults to None.
38
+ ft (Optional[Tensor]): faces for uvs, int [M, 3]. Defaults to None.
39
+ vc (Optional[Tensor]): vertex colors, float [N, 3]. Defaults to None.
40
+ albedo (Optional[Tensor]): albedo texture, float [H, W, 3], RGB format. Defaults to None.
41
+ metallicRoughness (Optional[Tensor]): metallic-roughness texture, float [H, W, 3], metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1]. Defaults to None.
42
+ device (Optional[torch.device]): torch device. Defaults to None.
43
+ """
44
+ self.device = device
45
+ self.v = v
46
+ self.vn = vn
47
+ self.vt = vt
48
+ self.f = f
49
+ self.fn = fn
50
+ self.ft = ft
51
+ # will first see if there is vertex color to use
52
+ self.vc = vc
53
+ # only support a single albedo image
54
+ self.albedo = albedo
55
+ # pbr extension, metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1]
56
+ # ref: https://registry.khronos.org/glTF/specs/2.0/glTF-2.0.html
57
+ self.metallicRoughness = metallicRoughness
58
+
59
+ self.ori_center = 0
60
+ self.ori_scale = 1
61
+
62
+ @classmethod
63
+ def load(cls, path, resize=True, clean=False, renormal=True, retex=False, bound=0.9, front_dir='+z', **kwargs):
64
+ """load mesh from path.
65
+
66
+ Args:
67
+ path (str): path to mesh file, supports ply, obj, glb.
68
+ clean (bool, optional): perform mesh cleaning at load (e.g., merge close vertices). Defaults to False.
69
+ resize (bool, optional): auto resize the mesh using ``bound`` into [-bound, bound]^3. Defaults to True.
70
+ renormal (bool, optional): re-calc the vertex normals. Defaults to True.
71
+ retex (bool, optional): re-calc the uv coordinates, will overwrite the existing uv coordinates. Defaults to False.
72
+ bound (float, optional): bound to resize. Defaults to 0.9.
73
+ front_dir (str, optional): front-view direction of the mesh, should be [+-][xyz][ 123]. Defaults to '+z'.
74
+ device (torch.device, optional): torch device. Defaults to None.
75
+
76
+ Note:
77
+ a ``device`` keyword argument can be provided to specify the torch device.
78
+ If it's not provided, we will try to use ``'cuda'`` as the device if it's available.
79
+
80
+ Returns:
81
+ Mesh: the loaded Mesh object.
82
+ """
83
+ # obj supports face uv
84
+ if path.endswith(".obj"):
85
+ mesh = cls.load_obj(path, **kwargs)
86
+ # trimesh only supports vertex uv, but can load more formats
87
+ else:
88
+ mesh = cls.load_trimesh(path, **kwargs)
89
+
90
+ # clean
91
+ if clean:
92
+ from kiui.mesh_utils import clean_mesh
93
+ vertices = mesh.v.detach().cpu().numpy()
94
+ triangles = mesh.f.detach().cpu().numpy()
95
+ vertices, triangles = clean_mesh(vertices, triangles, remesh=False)
96
+ mesh.v = torch.from_numpy(vertices).contiguous().float().to(mesh.device)
97
+ mesh.f = torch.from_numpy(triangles).contiguous().int().to(mesh.device)
98
+
99
+ print(f"[Mesh loading] v: {mesh.v.shape}, f: {mesh.f.shape}")
100
+ # auto-normalize
101
+ if resize:
102
+ mesh.auto_size(bound=bound)
103
+ # auto-fix normal
104
+ if renormal or mesh.vn is None:
105
+ mesh.auto_normal()
106
+ print(f"[Mesh loading] vn: {mesh.vn.shape}, fn: {mesh.fn.shape}")
107
+ # auto-fix texcoords
108
+ if retex or (mesh.albedo is not None and mesh.vt is None):
109
+ mesh.auto_uv(cache_path=path)
110
+ print(f"[Mesh loading] vt: {mesh.vt.shape}, ft: {mesh.ft.shape}")
111
+
112
+ # rotate front dir to +z
113
+ if front_dir != "+z":
114
+ # axis switch
115
+ if "-z" in front_dir:
116
+ T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, -1]], device=mesh.device, dtype=torch.float32)
117
+ elif "+x" in front_dir:
118
+ T = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32)
119
+ elif "-x" in front_dir:
120
+ T = torch.tensor([[0, 0, -1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32)
121
+ elif "+y" in front_dir:
122
+ T = torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]], device=mesh.device, dtype=torch.float32)
123
+ elif "-y" in front_dir:
124
+ T = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=mesh.device, dtype=torch.float32)
125
+ else:
126
+ T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
127
+ # rotation (how many 90 degrees)
128
+ if '1' in front_dir:
129
+ T @= torch.tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
130
+ elif '2' in front_dir:
131
+ T @= torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
132
+ elif '3' in front_dir:
133
+ T @= torch.tensor([[0, 1, 0], [-1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
134
+ mesh.v @= T
135
+ mesh.vn @= T
136
+
137
+ return mesh
138
+
139
+ # load from obj file
140
+ @classmethod
141
+ def load_obj(cls, path, albedo_path=None, device=None):
142
+ """load an ``obj`` mesh.
143
+
144
+ Args:
145
+ path (str): path to mesh.
146
+ albedo_path (str, optional): path to the albedo texture image, will overwrite the existing texture path if specified in mtl. Defaults to None.
147
+ device (torch.device, optional): torch device. Defaults to None.
148
+
149
+ Note:
150
+ We will try to read `mtl` path from `obj`, else we assume the file name is the same as `obj` but with `mtl` extension.
151
+ The `usemtl` statement is ignored, and we only use the last material path in `mtl` file.
152
+
153
+ Returns:
154
+ Mesh: the loaded Mesh object.
155
+ """
156
+ assert os.path.splitext(path)[-1] == ".obj"
157
+
158
+ mesh = cls()
159
+
160
+ # device
161
+ if device is None:
162
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
163
+
164
+ mesh.device = device
165
+
166
+ # load obj
167
+ with open(path, "r") as f:
168
+ lines = f.readlines()
169
+
170
+ def parse_f_v(fv):
171
+ # pass in a vertex term of a face, return {v, vt, vn} (-1 if not provided)
172
+ # supported forms:
173
+ # f v1 v2 v3
174
+ # f v1/vt1 v2/vt2 v3/vt3
175
+ # f v1/vt1/vn1 v2/vt2/vn2 v3/vt3/vn3
176
+ # f v1//vn1 v2//vn2 v3//vn3
177
+ xs = [int(x) - 1 if x != "" else -1 for x in fv.split("/")]
178
+ xs.extend([-1] * (3 - len(xs)))
179
+ return xs[0], xs[1], xs[2]
180
+
181
+ vertices, texcoords, normals = [], [], []
182
+ faces, tfaces, nfaces = [], [], []
183
+ mtl_path = None
184
+
185
+ for line in lines:
186
+ split_line = line.split()
187
+ # empty line
188
+ if len(split_line) == 0:
189
+ continue
190
+ prefix = split_line[0].lower()
191
+ # mtllib
192
+ if prefix == "mtllib":
193
+ mtl_path = split_line[1]
194
+ # usemtl
195
+ elif prefix == "usemtl":
196
+ pass # ignored
197
+ # v/vn/vt
198
+ elif prefix == "v":
199
+ vertices.append([float(v) for v in split_line[1:]])
200
+ elif prefix == "vn":
201
+ normals.append([float(v) for v in split_line[1:]])
202
+ elif prefix == "vt":
203
+ val = [float(v) for v in split_line[1:]]
204
+ texcoords.append([val[0], 1.0 - val[1]])
205
+ elif prefix == "f":
206
+ vs = split_line[1:]
207
+ nv = len(vs)
208
+ v0, t0, n0 = parse_f_v(vs[0])
209
+ for i in range(nv - 2): # triangulate (assume vertices are ordered)
210
+ v1, t1, n1 = parse_f_v(vs[i + 1])
211
+ v2, t2, n2 = parse_f_v(vs[i + 2])
212
+ faces.append([v0, v1, v2])
213
+ tfaces.append([t0, t1, t2])
214
+ nfaces.append([n0, n1, n2])
215
+
216
+ mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)
217
+ mesh.vt = (
218
+ torch.tensor(texcoords, dtype=torch.float32, device=device)
219
+ if len(texcoords) > 0
220
+ else None
221
+ )
222
+ mesh.vn = (
223
+ torch.tensor(normals, dtype=torch.float32, device=device)
224
+ if len(normals) > 0
225
+ else None
226
+ )
227
+
228
+ mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)
229
+ mesh.ft = (
230
+ torch.tensor(tfaces, dtype=torch.int32, device=device)
231
+ if len(texcoords) > 0
232
+ else None
233
+ )
234
+ mesh.fn = (
235
+ torch.tensor(nfaces, dtype=torch.int32, device=device)
236
+ if len(normals) > 0
237
+ else None
238
+ )
239
+
240
+ # see if there is vertex color
241
+ use_vertex_color = False
242
+ if mesh.v.shape[1] == 6:
243
+ use_vertex_color = True
244
+ mesh.vc = mesh.v[:, 3:]
245
+ mesh.v = mesh.v[:, :3]
246
+ print(f"[load_obj] use vertex color: {mesh.vc.shape}")
247
+
248
+ # try to load texture image
249
+ if not use_vertex_color:
250
+ # try to retrieve mtl file
251
+ mtl_path_candidates = []
252
+ if mtl_path is not None:
253
+ mtl_path_candidates.append(mtl_path)
254
+ mtl_path_candidates.append(os.path.join(os.path.dirname(path), mtl_path))
255
+ mtl_path_candidates.append(path.replace(".obj", ".mtl"))
256
+
257
+ mtl_path = None
258
+ for candidate in mtl_path_candidates:
259
+ if os.path.exists(candidate):
260
+ mtl_path = candidate
261
+ break
262
+
263
+ # if albedo_path is not provided, try retrieve it from mtl
264
+ metallic_path = None
265
+ roughness_path = None
266
+ if mtl_path is not None and albedo_path is None:
267
+ with open(mtl_path, "r") as f:
268
+ lines = f.readlines()
269
+
270
+ for line in lines:
271
+ split_line = line.split()
272
+ # empty line
273
+ if len(split_line) == 0:
274
+ continue
275
+ prefix = split_line[0]
276
+
277
+ if "map_Kd" in prefix:
278
+ # assume relative path!
279
+ albedo_path = os.path.join(os.path.dirname(path), split_line[1])
280
+ print(f"[load_obj] use texture from: {albedo_path}")
281
+ elif "map_Pm" in prefix:
282
+ metallic_path = os.path.join(os.path.dirname(path), split_line[1])
283
+ elif "map_Pr" in prefix:
284
+ roughness_path = os.path.join(os.path.dirname(path), split_line[1])
285
+
286
+ # still not found albedo_path, or the path doesn't exist
287
+ if albedo_path is None or not os.path.exists(albedo_path):
288
+ # init an empty texture
289
+ print(f"[load_obj] init empty albedo!")
290
+ # albedo = np.random.rand(1024, 1024, 3).astype(np.float32)
291
+ albedo = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5]) # default color
292
+ else:
293
+ albedo = cv2.imread(albedo_path, cv2.IMREAD_UNCHANGED)
294
+ albedo = cv2.cvtColor(albedo, cv2.COLOR_BGR2RGB)
295
+ albedo = albedo.astype(np.float32) / 255
296
+ print(f"[load_obj] load texture: {albedo.shape}")
297
+
298
+ mesh.albedo = torch.tensor(albedo, dtype=torch.float32, device=device)
299
+
300
+ # try to load metallic and roughness
301
+ if metallic_path is not None and roughness_path is not None:
302
+ print(f"[load_obj] load metallicRoughness from: {metallic_path}, {roughness_path}")
303
+ metallic = cv2.imread(metallic_path, cv2.IMREAD_UNCHANGED)
304
+ metallic = metallic.astype(np.float32) / 255
305
+ roughness = cv2.imread(roughness_path, cv2.IMREAD_UNCHANGED)
306
+ roughness = roughness.astype(np.float32) / 255
307
+ metallicRoughness = np.stack([np.zeros_like(metallic), roughness, metallic], axis=-1)
308
+
309
+ mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous()
310
+
311
+ return mesh
312
+
313
+ @classmethod
314
+ def load_trimesh(cls, path, device=None):
315
+ """load a mesh using ``trimesh.load()``.
316
+
317
+ Can load various formats like ``glb`` and serves as a fallback.
318
+
319
+ Note:
320
+ We will try to merge all meshes if the glb contains more than one,
321
+ but **this may cause the texture to be lost**, since we only support one texture image!
322
+
323
+ Args:
324
+ path (str): path to the mesh file.
325
+ device (torch.device, optional): torch device. Defaults to None.
326
+
327
+ Returns:
328
+ Mesh: the loaded Mesh object.
329
+ """
330
+ mesh = cls()
331
+
332
+ # device
333
+ if device is None:
334
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
335
+
336
+ mesh.device = device
337
+
338
+ # use trimesh to load ply/glb
339
+ _data = trimesh.load(path)
340
+ if isinstance(_data, trimesh.Scene):
341
+ if len(_data.geometry) == 1:
342
+ _mesh = list(_data.geometry.values())[0]
343
+ else:
344
+ print(f"[load_trimesh] concatenating {len(_data.geometry)} meshes.")
345
+ _concat = []
346
+ # loop the scene graph and apply transform to each mesh
347
+ scene_graph = _data.graph.to_flattened() # dict {name: {transform: 4x4 mat, geometry: str}}
348
+ for k, v in scene_graph.items():
349
+ name = v['geometry']
350
+ if name in _data.geometry and isinstance(_data.geometry[name], trimesh.Trimesh):
351
+ transform = v['transform']
352
+ _concat.append(_data.geometry[name].apply_transform(transform))
353
+ _mesh = trimesh.util.concatenate(_concat)
354
+ else:
355
+ _mesh = _data
356
+
357
+ if _mesh.visual.kind == 'vertex':
358
+ vertex_colors = _mesh.visual.vertex_colors
359
+ vertex_colors = np.array(vertex_colors[..., :3]).astype(np.float32) / 255
360
+ mesh.vc = torch.tensor(vertex_colors, dtype=torch.float32, device=device)
361
+ print(f"[load_trimesh] use vertex color: {mesh.vc.shape}")
362
+ elif _mesh.visual.kind == 'texture':
363
+ _material = _mesh.visual.material
364
+ if isinstance(_material, trimesh.visual.material.PBRMaterial):
365
+ texture = np.array(_material.baseColorTexture).astype(np.float32) / 255
366
+ # load metallicRoughness if present
367
+ if _material.metallicRoughnessTexture is not None:
368
+ metallicRoughness = np.array(_material.metallicRoughnessTexture).astype(np.float32) / 255
369
+ mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous()
370
+ elif isinstance(_material, trimesh.visual.material.SimpleMaterial):
371
+ texture = np.array(_material.to_pbr().baseColorTexture).astype(np.float32) / 255
372
+ else:
373
+ raise NotImplementedError(f"material type {type(_material)} not supported!")
374
+ mesh.albedo = torch.tensor(texture[..., :3], dtype=torch.float32, device=device).contiguous()
375
+ print(f"[load_trimesh] load texture: {texture.shape}")
376
+ else:
377
+ texture = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5])
378
+ mesh.albedo = torch.tensor(texture, dtype=torch.float32, device=device)
379
+ print(f"[load_trimesh] failed to load texture.")
380
+
381
+ vertices = _mesh.vertices
382
+
383
+ try:
384
+ texcoords = _mesh.visual.uv
385
+ texcoords[:, 1] = 1 - texcoords[:, 1]
386
+ except Exception as e:
387
+ texcoords = None
388
+
389
+ try:
390
+ normals = _mesh.vertex_normals
391
+ except Exception as e:
392
+ normals = None
393
+
394
+ # trimesh only supports per-vertex uv...
395
+ faces = tfaces = nfaces = _mesh.faces
396
+
397
+ mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)
398
+ mesh.vt = (
399
+ torch.tensor(texcoords, dtype=torch.float32, device=device)
400
+ if texcoords is not None
401
+ else None
402
+ )
403
+ mesh.vn = (
404
+ torch.tensor(normals, dtype=torch.float32, device=device)
405
+ if normals is not None
406
+ else None
407
+ )
408
+
409
+ mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)
410
+ mesh.ft = (
411
+ torch.tensor(tfaces, dtype=torch.int32, device=device)
412
+ if texcoords is not None
413
+ else None
414
+ )
415
+ mesh.fn = (
416
+ torch.tensor(nfaces, dtype=torch.int32, device=device)
417
+ if normals is not None
418
+ else None
419
+ )
420
+
421
+ return mesh
422
+
423
+ # sample surface (using trimesh)
424
+ def sample_surface(self, count: int):
425
+ """sample points on the surface of the mesh.
426
+
427
+ Args:
428
+ count (int): number of points to sample.
429
+
430
+ Returns:
431
+ torch.Tensor: the sampled points, float [count, 3].
432
+ """
433
+ _mesh = trimesh.Trimesh(vertices=self.v.detach().cpu().numpy(), faces=self.f.detach().cpu().numpy())
434
+ points, face_idx = trimesh.sample.sample_surface(_mesh, count)
435
+ points = torch.from_numpy(points).float().to(self.device)
436
+ return points
437
+
438
+ # aabb
439
+ def aabb(self):
440
+ """get the axis-aligned bounding box of the mesh.
441
+
442
+ Returns:
443
+ Tuple[torch.Tensor]: the min xyz and max xyz of the mesh.
444
+ """
445
+ return torch.min(self.v, dim=0).values, torch.max(self.v, dim=0).values
446
+
447
+ # unit size
448
+ @torch.no_grad()
449
+ def auto_size(self, bound=0.9):
450
+ """auto resize the mesh.
451
+
452
+ Args:
453
+ bound (float, optional): resizing into ``[-bound, bound]^3``. Defaults to 0.9.
454
+ """
455
+ vmin, vmax = self.aabb()
456
+ self.ori_center = (vmax + vmin) / 2
457
+ self.ori_scale = 2 * bound / torch.max(vmax - vmin).item()
458
+ self.v = (self.v - self.ori_center) * self.ori_scale
459
+
460
+ def auto_normal(self):
461
+ """auto calculate the vertex normals.
462
+ """
463
+ i0, i1, i2 = self.f[:, 0].long(), self.f[:, 1].long(), self.f[:, 2].long()
464
+ v0, v1, v2 = self.v[i0, :], self.v[i1, :], self.v[i2, :]
465
+
466
+ face_normals = torch.cross(v1 - v0, v2 - v0)
467
+
468
+ # Splat face normals to vertices
469
+ vn = torch.zeros_like(self.v)
470
+ vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
471
+ vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
472
+ vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
473
+
474
+ # Normalize, replace zero (degenerated) normals with some default value
475
+ vn = torch.where(
476
+ dot(vn, vn) > 1e-20,
477
+ vn,
478
+ torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device),
479
+ )
480
+ vn = safe_normalize(vn)
481
+
482
+ self.vn = vn
483
+ self.fn = self.f
484
+
485
+ def auto_uv(self, cache_path=None, vmap=True):
486
+ """auto calculate the uv coordinates.
487
+
488
+ Args:
489
+ cache_path (str, optional): path to save/load the uv cache as a npz file, this can avoid calculating uv every time when loading the same mesh, which is time-consuming. Defaults to None.
490
+ vmap (bool, optional): remap vertices based on uv coordinates, so each v corresponds to a unique vt (necessary for formats like gltf).
491
+ Usually this will duplicate the vertices on the edge of uv atlas. Defaults to True.
492
+ """
493
+ # try to load cache
494
+ if cache_path is not None:
495
+ cache_path = os.path.splitext(cache_path)[0] + "_uv.npz"
496
+ if cache_path is not None and os.path.exists(cache_path):
497
+ data = np.load(cache_path)
498
+ vt_np, ft_np, vmapping = data["vt"], data["ft"], data["vmapping"]
499
+ else:
500
+ import xatlas
501
+
502
+ v_np = self.v.detach().cpu().numpy()
503
+ f_np = self.f.detach().int().cpu().numpy()
504
+ atlas = xatlas.Atlas()
505
+ atlas.add_mesh(v_np, f_np)
506
+ chart_options = xatlas.ChartOptions()
507
+ # chart_options.max_iterations = 4
508
+ atlas.generate(chart_options=chart_options)
509
+ vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2]
510
+
511
+ # save to cache
512
+ if cache_path is not None:
513
+ np.savez(cache_path, vt=vt_np, ft=ft_np, vmapping=vmapping)
514
+
515
+ vt = torch.from_numpy(vt_np.astype(np.float32)).to(self.device)
516
+ ft = torch.from_numpy(ft_np.astype(np.int32)).to(self.device)
517
+ self.vt = vt
518
+ self.ft = ft
519
+
520
+ if vmap:
521
+ vmapping = torch.from_numpy(vmapping.astype(np.int64)).long().to(self.device)
522
+ self.align_v_to_vt(vmapping)
523
+
524
+ def align_v_to_vt(self, vmapping=None):
525
+ """ remap v/f and vn/fn to vt/ft.
526
+
527
+ Args:
528
+ vmapping (np.ndarray, optional): the mapping relationship from f to ft. Defaults to None.
529
+ """
530
+ if vmapping is None:
531
+ ft = self.ft.view(-1).long()
532
+ f = self.f.view(-1).long()
533
+ vmapping = torch.zeros(self.vt.shape[0], dtype=torch.long, device=self.device)
534
+ vmapping[ft] = f # scatter, randomly choose one if index is not unique
535
+
536
+ self.v = self.v[vmapping]
537
+ self.f = self.ft
538
+
539
+ if self.vn is not None:
540
+ self.vn = self.vn[vmapping]
541
+ self.fn = self.ft
542
+
543
+ def to(self, device):
544
+ """move all tensor attributes to device.
545
+
546
+ Args:
547
+ device (torch.device): target device.
548
+
549
+ Returns:
550
+ Mesh: self.
551
+ """
552
+ self.device = device
553
+ for name in ["v", "f", "vn", "fn", "vt", "ft", "albedo", "vc", "metallicRoughness"]:
554
+ tensor = getattr(self, name)
555
+ if tensor is not None:
556
+ setattr(self, name, tensor.to(device))
557
+ return self
558
+
559
+ def write(self, path):
560
+ """write the mesh to a path.
561
+
562
+ Args:
563
+ path (str): path to write, supports ply, obj and glb.
564
+ """
565
+ if path.endswith(".ply"):
566
+ self.write_ply(path)
567
+ elif path.endswith(".obj"):
568
+ self.write_obj(path)
569
+ elif path.endswith(".glb") or path.endswith(".gltf"):
570
+ self.write_glb(path)
571
+ else:
572
+ raise NotImplementedError(f"format {path} not supported!")
573
+
574
+ def write_ply(self, path):
575
+ """write the mesh in ply format. Only for geometry!
576
+
577
+ Args:
578
+ path (str): path to write.
579
+ """
580
+
581
+ if self.albedo is not None:
582
+ print(f'[WARN] ply format does not support exporting texture, it will be ignored!')
583
+
584
+ v_np = self.v.detach().cpu().numpy()
585
+ f_np = self.f.detach().cpu().numpy()
586
+
587
+ _mesh = trimesh.Trimesh(vertices=v_np, faces=f_np)
588
+ _mesh.export(path)
589
+
590
+
591
+ def write_glb(self, path):
592
+ """write the mesh in glb/gltf format.
593
+ This will create a scene with a single mesh.
594
+
595
+ Args:
596
+ path (str): path to write.
597
+ """
598
+
599
+ # assert self.v.shape[0] == self.vn.shape[0] and self.v.shape[0] == self.vt.shape[0]
600
+ if self.vt is not None and self.v.shape[0] != self.vt.shape[0]:
601
+ self.align_v_to_vt()
602
+
603
+ import pygltflib
604
+
605
+ f_np = self.f.detach().cpu().numpy().astype(np.uint32)
606
+ f_np_blob = f_np.flatten().tobytes()
607
+
608
+ v_np = self.v.detach().cpu().numpy().astype(np.float32)
609
+ v_np_blob = v_np.tobytes()
610
+
611
+ blob = f_np_blob + v_np_blob
612
+ byteOffset = len(blob)
613
+
614
+ # base mesh
615
+ gltf = pygltflib.GLTF2(
616
+ scene=0,
617
+ scenes=[pygltflib.Scene(nodes=[0])],
618
+ nodes=[pygltflib.Node(mesh=0)],
619
+ meshes=[pygltflib.Mesh(primitives=[pygltflib.Primitive(
620
+ # indices to accessors (0 is triangles)
621
+ attributes=pygltflib.Attributes(
622
+ POSITION=1,
623
+ ),
624
+ indices=0,
625
+ )])],
626
+ buffers=[
627
+ pygltflib.Buffer(byteLength=len(f_np_blob) + len(v_np_blob))
628
+ ],
629
+ # buffer view (based on dtype)
630
+ bufferViews=[
631
+ # triangles; as flatten (element) array
632
+ pygltflib.BufferView(
633
+ buffer=0,
634
+ byteLength=len(f_np_blob),
635
+ target=pygltflib.ELEMENT_ARRAY_BUFFER, # GL_ELEMENT_ARRAY_BUFFER (34963)
636
+ ),
637
+ # positions; as vec3 array
638
+ pygltflib.BufferView(
639
+ buffer=0,
640
+ byteOffset=len(f_np_blob),
641
+ byteLength=len(v_np_blob),
642
+ byteStride=12, # vec3
643
+ target=pygltflib.ARRAY_BUFFER, # GL_ARRAY_BUFFER (34962)
644
+ ),
645
+ ],
646
+ accessors=[
647
+ # 0 = triangles
648
+ pygltflib.Accessor(
649
+ bufferView=0,
650
+ componentType=pygltflib.UNSIGNED_INT, # GL_UNSIGNED_INT (5125)
651
+ count=f_np.size,
652
+ type=pygltflib.SCALAR,
653
+ max=[int(f_np.max())],
654
+ min=[int(f_np.min())],
655
+ ),
656
+ # 1 = positions
657
+ pygltflib.Accessor(
658
+ bufferView=1,
659
+ componentType=pygltflib.FLOAT, # GL_FLOAT (5126)
660
+ count=len(v_np),
661
+ type=pygltflib.VEC3,
662
+ max=v_np.max(axis=0).tolist(),
663
+ min=v_np.min(axis=0).tolist(),
664
+ ),
665
+ ],
666
+ )
667
+
668
+ # append texture info
669
+ if self.vt is not None:
670
+
671
+ vt_np = self.vt.detach().cpu().numpy().astype(np.float32)
672
+ vt_np_blob = vt_np.tobytes()
673
+
674
+ albedo = self.albedo.detach().cpu().numpy()
675
+ albedo = (albedo * 255).astype(np.uint8)
676
+ albedo = cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR)
677
+ albedo_blob = cv2.imencode('.png', albedo)[1].tobytes()
678
+
679
+ # update primitive
680
+ gltf.meshes[0].primitives[0].attributes.TEXCOORD_0 = 2
681
+ gltf.meshes[0].primitives[0].material = 0
682
+
683
+ # update materials
684
+ gltf.materials.append(pygltflib.Material(
685
+ pbrMetallicRoughness=pygltflib.PbrMetallicRoughness(
686
+ baseColorTexture=pygltflib.TextureInfo(index=0, texCoord=0),
687
+ metallicFactor=0.0,
688
+ roughnessFactor=1.0,
689
+ ),
690
+ alphaMode=pygltflib.OPAQUE,
691
+ alphaCutoff=None,
692
+ doubleSided=True,
693
+ ))
694
+
695
+ gltf.textures.append(pygltflib.Texture(sampler=0, source=0))
696
+ gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT))
697
+ gltf.images.append(pygltflib.Image(bufferView=3, mimeType="image/png"))
698
+
699
+ # update buffers
700
+ gltf.bufferViews.append(
701
+ # index = 2, texcoords; as vec2 array
702
+ pygltflib.BufferView(
703
+ buffer=0,
704
+ byteOffset=byteOffset,
705
+ byteLength=len(vt_np_blob),
706
+ byteStride=8, # vec2
707
+ target=pygltflib.ARRAY_BUFFER,
708
+ )
709
+ )
710
+
711
+ gltf.accessors.append(
712
+ # 2 = texcoords
713
+ pygltflib.Accessor(
714
+ bufferView=2,
715
+ componentType=pygltflib.FLOAT,
716
+ count=len(vt_np),
717
+ type=pygltflib.VEC2,
718
+ max=vt_np.max(axis=0).tolist(),
719
+ min=vt_np.min(axis=0).tolist(),
720
+ )
721
+ )
722
+
723
+ blob += vt_np_blob
724
+ byteOffset += len(vt_np_blob)
725
+
726
+ gltf.bufferViews.append(
727
+ # index = 3, albedo texture; as none target
728
+ pygltflib.BufferView(
729
+ buffer=0,
730
+ byteOffset=byteOffset,
731
+ byteLength=len(albedo_blob),
732
+ )
733
+ )
734
+
735
+ blob += albedo_blob
736
+ byteOffset += len(albedo_blob)
737
+
738
+ gltf.buffers[0].byteLength = byteOffset
739
+
740
+ # append metallic roughness
741
+ if self.metallicRoughness is not None:
742
+ metallicRoughness = self.metallicRoughness.detach().cpu().numpy()
743
+ metallicRoughness = (metallicRoughness * 255).astype(np.uint8)
744
+ metallicRoughness = cv2.cvtColor(metallicRoughness, cv2.COLOR_RGB2BGR)
745
+ metallicRoughness_blob = cv2.imencode('.png', metallicRoughness)[1].tobytes()
746
+
747
+ # update texture definition
748
+ gltf.materials[0].pbrMetallicRoughness.metallicFactor = 1.0
749
+ gltf.materials[0].pbrMetallicRoughness.roughnessFactor = 1.0
750
+ gltf.materials[0].pbrMetallicRoughness.metallicRoughnessTexture = pygltflib.TextureInfo(index=1, texCoord=0)
751
+
752
+ gltf.textures.append(pygltflib.Texture(sampler=1, source=1))
753
+ gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT))
754
+ gltf.images.append(pygltflib.Image(bufferView=4, mimeType="image/png"))
755
+
756
+ # update buffers
757
+ gltf.bufferViews.append(
758
+ # index = 4, metallicRoughness texture; as none target
759
+ pygltflib.BufferView(
760
+ buffer=0,
761
+ byteOffset=byteOffset,
762
+ byteLength=len(metallicRoughness_blob),
763
+ )
764
+ )
765
+
766
+ blob += metallicRoughness_blob
767
+ byteOffset += len(metallicRoughness_blob)
768
+
769
+ gltf.buffers[0].byteLength = byteOffset
770
+
771
+
772
+ # set actual data
773
+ gltf.set_binary_blob(blob)
774
+
775
+ # glb = b"".join(gltf.save_to_bytes())
776
+ gltf.save(path)
777
+
778
+
779
+ def write_obj(self, path):
780
+ """write the mesh in obj format. Will also write the texture and mtl files.
781
+
782
+ Args:
783
+ path (str): path to write.
784
+ """
785
+
786
+ mtl_path = path.replace(".obj", ".mtl")
787
+ albedo_path = path.replace(".obj", "_albedo.png")
788
+ metallic_path = path.replace(".obj", "_metallic.png")
789
+ roughness_path = path.replace(".obj", "_roughness.png")
790
+
791
+ v_np = self.v.detach().cpu().numpy()
792
+ vt_np = self.vt.detach().cpu().numpy() if self.vt is not None else None
793
+ vn_np = self.vn.detach().cpu().numpy() if self.vn is not None else None
794
+ f_np = self.f.detach().cpu().numpy()
795
+ ft_np = self.ft.detach().cpu().numpy() if self.ft is not None else None
796
+ fn_np = self.fn.detach().cpu().numpy() if self.fn is not None else None
797
+
798
+ with open(path, "w") as fp:
799
+ fp.write(f"mtllib {os.path.basename(mtl_path)} \n")
800
+
801
+ for v in v_np:
802
+ fp.write(f"v {v[0]} {v[1]} {v[2]} \n")
803
+
804
+ if vt_np is not None:
805
+ for v in vt_np:
806
+ fp.write(f"vt {v[0]} {1 - v[1]} \n")
807
+
808
+ if vn_np is not None:
809
+ for v in vn_np:
810
+ fp.write(f"vn {v[0]} {v[1]} {v[2]} \n")
811
+
812
+ fp.write(f"usemtl defaultMat \n")
813
+ for i in range(len(f_np)):
814
+ fp.write(
815
+ f'f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1 if ft_np is not None else ""}/{fn_np[i, 0] + 1 if fn_np is not None else ""} \
816
+ {f_np[i, 1] + 1}/{ft_np[i, 1] + 1 if ft_np is not None else ""}/{fn_np[i, 1] + 1 if fn_np is not None else ""} \
817
+ {f_np[i, 2] + 1}/{ft_np[i, 2] + 1 if ft_np is not None else ""}/{fn_np[i, 2] + 1 if fn_np is not None else ""} \n'
818
+ )
819
+
820
+ with open(mtl_path, "w") as fp:
821
+ fp.write(f"newmtl defaultMat \n")
822
+ fp.write(f"Ka 1 1 1 \n")
823
+ fp.write(f"Kd 1 1 1 \n")
824
+ fp.write(f"Ks 0 0 0 \n")
825
+ fp.write(f"Tr 1 \n")
826
+ fp.write(f"illum 1 \n")
827
+ fp.write(f"Ns 0 \n")
828
+ if self.albedo is not None:
829
+ fp.write(f"map_Kd {os.path.basename(albedo_path)} \n")
830
+ if self.metallicRoughness is not None:
831
+ # ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering
832
+ fp.write(f"map_Pm {os.path.basename(metallic_path)} \n")
833
+ fp.write(f"map_Pr {os.path.basename(roughness_path)} \n")
834
+
835
+ if self.albedo is not None:
836
+ albedo = self.albedo.detach().cpu().numpy()
837
+ albedo = (albedo * 255).astype(np.uint8)
838
+ cv2.imwrite(albedo_path, cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR))
839
+
840
+ if self.metallicRoughness is not None:
841
+ metallicRoughness = self.metallicRoughness.detach().cpu().numpy()
842
+ metallicRoughness = (metallicRoughness * 255).astype(np.uint8)
843
+ cv2.imwrite(metallic_path, metallicRoughness[..., 2])
844
+ cv2.imwrite(roughness_path, metallicRoughness[..., 1])
845
+
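Taken together, the loaders and exporters above form a small conversion utility. A minimal usage sketch, assuming the class above is the Mesh class defined in mesh.py (the import path and file names below are placeholders):

import torch
from mesh import Mesh  # assumed import path; adjust to where mesh.py lives

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
m = Mesh.load_trimesh("input.glb", device=device)  # placeholder path; loads via trimesh
m.auto_size(bound=0.9)    # recenter and rescale into [-0.9, 0.9]^3
m.auto_normal()           # recompute per-vertex normals
if m.vt is None:
    m.auto_uv()           # unwrap uvs with xatlas when the mesh has none
m.write("output.obj")     # dispatches to write_obj, which also emits the mtl/texture files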
model/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from model.crm.model import CRM
model/archs/__init__.py ADDED
File without changes
model/archs/decoders/__init__.py ADDED
@@ -0,0 +1 @@
1
+
model/archs/decoders/shape_texture_net.py ADDED
@@ -0,0 +1,62 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class TetTexNet(nn.Module):
7
+ def __init__(self, plane_reso=64, padding=0.1, fea_concat=True):
8
+ super().__init__()
9
+ # self.c_dim = c_dim
10
+ self.plane_reso = plane_reso
11
+ self.padding = padding
12
+ self.fea_concat = fea_concat
13
+
14
+ def forward(self, rolled_out_feature, query):
15
+ # rolled_out_feature: rolled-out triplane feature
16
+ # query: queried xyz coordinates (should be scaled consistently with the point cloud)
17
+
18
+ plane_reso = self.plane_reso
19
+
20
+ triplane_feature = dict()
21
+ triplane_feature['xy'] = rolled_out_feature[:, :, :, 0: plane_reso]
22
+ triplane_feature['yz'] = rolled_out_feature[:, :, :, plane_reso: 2 * plane_reso]
23
+ triplane_feature['zx'] = rolled_out_feature[:, :, :, 2 * plane_reso:]
24
+
25
+ query_feature_xy = self.sample_plane_feature(query, triplane_feature['xy'], 'xy')
26
+ query_feature_yz = self.sample_plane_feature(query, triplane_feature['yz'], 'yz')
27
+ query_feature_zx = self.sample_plane_feature(query, triplane_feature['zx'], 'zx')
28
+
29
+ if self.fea_concat:
30
+ query_feature = torch.cat((query_feature_xy, query_feature_yz, query_feature_zx), dim=1)
31
+ else:
32
+ query_feature = query_feature_xy + query_feature_yz + query_feature_zx
33
+
34
+ output = query_feature.permute(0, 2, 1)
35
+
36
+ return output
37
+
38
+ # uses values from plane_feature and pixel locations from vgrid to interpolate feature
39
+ def sample_plane_feature(self, query, plane_feature, plane):
40
+ # CYF note:
41
+ # for pretraining, queries are uniformly sampled positions within [-scale, scale]
42
+ # for training, queries are essentially tetrahedral grid vertices, which are
43
+ # also within [-scale, scale] in the current version!
44
+ # xy range [-scale, scale]
45
+ if plane == 'xy':
46
+ xy = query[:, :, [0, 1]]
47
+ elif plane == 'yz':
48
+ xy = query[:, :, [1, 2]]
49
+ elif plane == 'zx':
50
+ xy = query[:, :, [2, 0]]
51
+ else:
52
+ raise ValueError("Error! Invalid plane type!")
53
+
54
+ xy = xy[:, :, None].float()
55
+ # it does not seem necessary to rescale the grid, because from
56
+ # https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html,
57
+ # it specifies sampling locations normalized by plane_feature's spatial dimension,
58
+ # which is within [-scale, scale] as specified by the encoder's call to coordinate2index()
59
+ vgrid = 1.0 * xy
60
+ sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True, mode='bilinear').squeeze(-1)
61
+
62
+ return sampled_feat
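The decoder above is essentially three bilinear lookups into a rolled-out triplane followed by concatenation. A quick shape check with dummy tensors (sizes are illustrative, not the training configuration):

import torch
from model.archs.decoders.shape_texture_net import TetTexNet

net = TetTexNet(plane_reso=64, fea_concat=True)
rolled_out = torch.randn(2, 32, 64, 64 * 3)  # [B, C, reso, reso * 3] rolled-out triplane
query = torch.rand(2, 1024, 3) * 2 - 1       # [B, N, 3] query points, roughly in [-1, 1]
feat = net(rolled_out, query)
print(feat.shape)                            # torch.Size([2, 1024, 96]): the three plane features are concatenated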
model/archs/mlp_head.py ADDED
@@ -0,0 +1,40 @@
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+
4
+
5
+ class SdfMlp(nn.Module):
6
+ def __init__(self, input_dim, hidden_dim=512, bias=True):
7
+ super().__init__()
8
+ self.input_dim = input_dim
9
+ self.hidden_dim = hidden_dim
10
+
11
+ self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias)
12
+ self.fc2 = nn.Linear(hidden_dim, hidden_dim, bias=bias)
13
+ self.fc3 = nn.Linear(hidden_dim, 4, bias=bias)
14
+
15
+
16
+ def forward(self, input):
17
+ x = F.relu(self.fc1(input))
18
+ x = F.relu(self.fc2(x))
19
+ out = self.fc3(x)
20
+ return out
21
+
22
+
23
+ class RgbMlp(nn.Module):
24
+ def __init__(self, input_dim, hidden_dim=512, bias=True):
25
+ super().__init__()
26
+ self.input_dim = input_dim
27
+ self.hidden_dim = hidden_dim
28
+
29
+ self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias)
30
+ self.fc2 = nn.Linear(hidden_dim, hidden_dim, bias=bias)
31
+ self.fc3 = nn.Linear(hidden_dim, 3, bias=bias)
32
+
33
+ def forward(self, input):
34
+ x = F.relu(self.fc1(input))
35
+ x = F.relu(self.fc2(x))
36
+ out = self.fc3(x)
37
+
38
+ return out
39
+
40
+
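These two heads are what the CRM model below registers as sdfMlp and rgbMlp: per-point triplane features go in, SDF plus deformation and color come out. A small sketch with placeholder sizes:

import torch
from model.archs.mlp_head import SdfMlp, RgbMlp

feat = torch.randn(2, 1024, 96)              # [B, N, C] features queried from the triplane
sdf_head = SdfMlp(input_dim=96, hidden_dim=512)
rgb_head = RgbMlp(input_dim=96, hidden_dim=512)

geo = sdf_head(feat)                         # [B, N, 4]: sdf value plus a 3D deformation
sdf, deform = geo[..., 0], geo[..., 1:]
rgb = rgb_head(feat)                         # [B, N, 3], interpreted downstream as colors in [-1, 1]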
model/archs/unet.py ADDED
@@ -0,0 +1,53 @@
1
+ '''
2
+ Code adapted from:
3
+ https://github.com/jaxony/unet-pytorch/blob/master/model.py
4
+ '''
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from diffusers import UNet2DModel
9
+ import einops
10
+ class UNetPP(nn.Module):
11
+ '''
12
+ Wrapper for UNet in diffusers
13
+ '''
14
+ def __init__(self, in_channels):
15
+ super(UNetPP, self).__init__()
16
+ self.in_channels = in_channels
17
+ self.unet = UNet2DModel(
18
+ sample_size=[256, 256*3],
19
+ in_channels=in_channels,
20
+ out_channels=32,
21
+ layers_per_block=2,
22
+ block_out_channels=(64, 128, 128, 128*2, 128*2, 128*4, 128*4),
23
+ down_block_types=(
24
+ "DownBlock2D",
25
+ "DownBlock2D",
26
+ "DownBlock2D",
27
+ "AttnDownBlock2D",
28
+ "AttnDownBlock2D",
29
+ "AttnDownBlock2D",
30
+ "DownBlock2D",
31
+ ),
32
+ up_block_types=(
33
+ "UpBlock2D",
34
+ "AttnUpBlock2D",
35
+ "AttnUpBlock2D",
36
+ "AttnUpBlock2D",
37
+ "UpBlock2D",
38
+ "UpBlock2D",
39
+ "UpBlock2D",
40
+ ),
41
+ )
42
+
43
+ self.unet.enable_xformers_memory_efficient_attention()
44
+ if in_channels > 12:
45
+ self.learned_plane = torch.nn.parameter.Parameter(torch.zeros([1,in_channels-12,256,256*3]))
46
+
47
+ def forward(self, x, t=256):
48
+ learned_plane = getattr(self, "learned_plane", None)  # only defined when in_channels > 12
49
+ if x.shape[1] < self.in_channels:
50
+ learned_plane = einops.repeat(learned_plane, '1 C H W -> B C H W', B=x.shape[0]).to(x.device)
51
+ x = torch.cat([x, learned_plane], dim = 1)
52
+ return self.unet(x, t).sample
53
+
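The wrapper keeps the diffusers UNet2DModel interface but treats the timestep as a fixed input. A rough shape check, assuming diffusers and xformers are installed and a CUDA device is available (the constructor enables xformers attention); the channel count is illustrative:

import torch
from model.archs.unet import UNetPP

unet = UNetPP(in_channels=32).cuda()
x = torch.randn(1, 32, 256, 256 * 3, device="cuda")  # rolled-out triplane, [B, C, 256, 768]
out = unet(x, t=256)                                 # t is the fixed timestep fed to the U-Net
print(out.shape)                                     # torch.Size([1, 32, 256, 768])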
model/crm/model.py ADDED
@@ -0,0 +1,213 @@
1
+ import torch.nn as nn
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+ import numpy as np
6
+
7
+
8
+ from pathlib import Path
9
+ import cv2
10
+ import trimesh
11
+ import nvdiffrast.torch as dr
12
+
13
+ from model.archs.decoders.shape_texture_net import TetTexNet
14
+ from model.archs.unet import UNetPP
15
+ from util.renderer import Renderer
16
+ from model.archs.mlp_head import SdfMlp, RgbMlp
17
+ import xatlas
18
+
19
+
20
+ class Dummy:
21
+ pass
22
+
23
+ class CRM(nn.Module):
24
+ def __init__(self, specs):
25
+ super(CRM, self).__init__()
26
+
27
+ self.specs = specs
28
+ # configs
29
+ input_specs = specs["Input"]
30
+ self.input = Dummy()
31
+ self.input.scale = input_specs['scale']
32
+ self.input.resolution = input_specs['resolution']
33
+ self.tet_grid_size = input_specs['tet_grid_size']
34
+ self.camera_angle_num = input_specs['camera_angle_num']
35
+
36
+ self.arch = Dummy()
37
+ self.arch.fea_concat = specs["ArchSpecs"]["fea_concat"]
38
+ self.arch.mlp_bias = specs["ArchSpecs"]["mlp_bias"]
39
+
40
+ self.dec = Dummy()
41
+ self.dec.c_dim = specs["DecoderSpecs"]["c_dim"]
42
+ self.dec.plane_resolution = specs["DecoderSpecs"]["plane_resolution"]
43
+
44
+ self.geo_type = specs["Train"].get("geo_type", "flex") # "dmtet" or "flex"
45
+
46
+ self.unet2 = UNetPP(in_channels=self.dec.c_dim)
47
+
48
+ mlp_chnl_s = 3 if self.arch.fea_concat else 1 # 3 for queried triplane feature concatenation
49
+ self.decoder = TetTexNet(plane_reso=self.dec.plane_resolution, fea_concat=self.arch.fea_concat)
50
+
51
+ if self.geo_type == "flex":
52
+ self.weightMlp = nn.Sequential(
53
+ nn.Linear(mlp_chnl_s * 32 * 8, 512),
54
+ nn.SiLU(),
55
+ nn.Linear(512, 21))
56
+
57
+ self.sdfMlp = SdfMlp(mlp_chnl_s * 32, 512, bias=self.arch.mlp_bias)
58
+ self.rgbMlp = RgbMlp(mlp_chnl_s * 32, 512, bias=self.arch.mlp_bias)
59
+ self.renderer = Renderer(tet_grid_size=self.tet_grid_size, camera_angle_num=self.camera_angle_num,
60
+ scale=self.input.scale, geo_type = self.geo_type)
61
+
62
+
63
+ self.spob = True if specs['Pretrain']['mode'] is None else False # whether to add sphere
64
+ self.radius = specs['Pretrain']['radius'] # used when spob
65
+
66
+ self.denoising = True
67
+ from diffusers import DDIMScheduler
68
+ self.scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="scheduler")
69
+
70
+ def decode(self, data, triplane_feature2):
71
+ if self.geo_type == "flex":
72
+ tet_verts = self.renderer.flexicubes.verts.unsqueeze(0)
73
+ tet_indices = self.renderer.flexicubes.indices
74
+
75
+ dec_verts = self.decoder(triplane_feature2, tet_verts)
76
+ out = self.sdfMlp(dec_verts)
77
+
78
+ weight = None
79
+ if self.geo_type == "flex":
80
+ grid_feat = torch.index_select(input=dec_verts, index=self.renderer.flexicubes.indices.reshape(-1),dim=1)
81
+ grid_feat = grid_feat.reshape(dec_verts.shape[0], self.renderer.flexicubes.indices.shape[0], self.renderer.flexicubes.indices.shape[1] * dec_verts.shape[-1])
82
+ weight = self.weightMlp(grid_feat)
83
+ weight = weight * 0.1
84
+
85
+ pred_sdf, deformation = out[..., 0], out[..., 1:]
86
+ if self.spob:
87
+ pred_sdf = pred_sdf + self.radius - torch.sqrt((tet_verts**2).sum(-1))
88
+
89
+ _, verts, faces = self.renderer(data, pred_sdf, deformation, tet_verts, tet_indices, weight= weight)
90
+ return verts[0].unsqueeze(0), faces[0].int()
91
+
92
+ def export_mesh(self, data, out_dir, ind, device=None, tri_fea_2 = None):
93
+ verts = data['verts']
94
+ faces = data['faces']
95
+
96
+ dec_verts = self.decoder(tri_fea_2, verts.unsqueeze(0))
97
+ colors = self.rgbMlp(dec_verts).squeeze().detach().cpu().numpy()
98
+ # Predicted color values are expected to lie in [-1, 1]
99
+ colors = (colors * 0.5 + 0.5).clip(0, 1)
100
+
101
+ verts = verts.squeeze().cpu().numpy()
102
+ faces = faces[..., [2, 1, 0]].squeeze().cpu().numpy()
103
+
104
+ # export the final mesh
105
+ with torch.no_grad():
106
+ mesh = trimesh.Trimesh(verts, faces, vertex_colors=colors, process=False) # important, process=True leads to seg fault...
107
+ mesh.export(out_dir / f'{ind}.obj')
108
+
109
+ def export_mesh_wt_uv(self, ctx, data, out_dir, ind, device, res, tri_fea_2=None):
110
+
111
+ mesh_v = data['verts'].squeeze().cpu().numpy()
112
+ mesh_pos_idx = data['faces'].squeeze().cpu().numpy()
113
+
114
+ def interpolate(attr, rast, attr_idx, rast_db=None):
115
+ return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db,
116
+ diff_attrs=None if rast_db is None else 'all')
117
+
118
+ vmapping, indices, uvs = xatlas.parametrize(mesh_v, mesh_pos_idx)
119
+
120
+ mesh_v = torch.tensor(mesh_v, dtype=torch.float32, device=device)
121
+ mesh_pos_idx = torch.tensor(mesh_pos_idx, dtype=torch.int64, device=device)
122
+
123
+ # Convert to tensors
124
+ indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)
125
+
126
+ uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device)
127
+ mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device)
128
+ # mesh_v_tex: map uv coordinates to clip space
129
+ uv_clip = uvs[None, ...] * 2.0 - 1.0
130
+
131
+ # pad to four component coordinate
132
+ uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1)
133
+
134
+ # rasterize
135
+ rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), res)
136
+
137
+ # Interpolate world space position
138
+ gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int())
139
+ mask = rast[..., 3:4] > 0
140
+
141
+ # return uvs, mesh_tex_idx, gb_pos, mask
142
+ gb_pos_unsqz = gb_pos.view(-1, 3)
143
+ mask_unsqz = mask.view(-1)
144
+ tex_unsqz = torch.zeros_like(gb_pos_unsqz) + 1
145
+
146
+ gb_mask_pos = gb_pos_unsqz[mask_unsqz]
147
+
148
+ gb_mask_pos = gb_mask_pos[None, ]
149
+
150
+ with torch.no_grad():
151
+
152
+ dec_verts = self.decoder(tri_fea_2, gb_mask_pos)
153
+ colors = self.rgbMlp(dec_verts).squeeze()
154
+
155
+ # Predicted color values are expected to lie in [-1, 1]
156
+ lo, hi = (-1, 1)
157
+ colors = (colors - lo) * (255 / (hi - lo))
158
+ colors = colors.clip(0, 255)
159
+
160
+ tex_unsqz[mask_unsqz] = colors
161
+
162
+ tex = tex_unsqz.view(res + (3,))
163
+
164
+ verts = mesh_v.squeeze().cpu().numpy()
165
+ faces = mesh_pos_idx[..., [2, 1, 0]].squeeze().cpu().numpy()
166
+ # faces = mesh_pos_idx
167
+ # faces = faces.detach().cpu().numpy()
168
+ # faces = faces[..., [2, 1, 0]]
169
+ indices = indices[..., [2, 1, 0]]
170
+
171
+ # xatlas.export(f"{out_dir}/{ind}.obj", verts[vmapping], indices, uvs)
172
+ matname = f'{out_dir}.mtl'
173
+ # matname = f'{out_dir}/{ind}.mtl'
174
+ fid = open(matname, 'w')
175
+ fid.write('newmtl material_0\n')
176
+ fid.write('Kd 1 1 1\n')
177
+ fid.write('Ka 1 1 1\n')
178
+ # fid.write('Ks 0 0 0\n')
179
+ fid.write('Ks 0.4 0.4 0.4\n')
180
+ fid.write('Ns 10\n')
181
+ fid.write('illum 2\n')
182
+ fid.write(f'map_Kd {out_dir.split("/")[-1]}.png\n')
183
+ fid.close()
184
+
185
+ fid = open(f'{out_dir}.obj', 'w')
186
+ # fid = open(f'{out_dir}/{ind}.obj', 'w')
187
+ fid.write('mtllib %s.mtl\n' % out_dir.split("/")[-1])
188
+
189
+ for pidx, p in enumerate(verts):
190
+ pp = p
191
+ fid.write('v %f %f %f\n' % (pp[0], pp[2], - pp[1]))
192
+
193
+ for pidx, p in enumerate(uvs):
194
+ pp = p
195
+ fid.write('vt %f %f\n' % (pp[0], 1 - pp[1]))
196
+
197
+ fid.write('usemtl material_0\n')
198
+ for i, f in enumerate(faces):
199
+ f1 = f + 1
200
+ f2 = indices[i] + 1
201
+ fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2]))
202
+ fid.close()
203
+
204
+ img = np.asarray(tex.data.cpu().numpy(), dtype=np.float32)
205
+ mask = np.sum(img.astype(float), axis=-1, keepdims=True)
206
+ mask = (mask <= 3.0).astype(float)
207
+ kernel = np.ones((3, 3), 'uint8')
208
+ dilate_img = cv2.dilate(img, kernel, iterations=1)
209
+ img = img * (1 - mask) + dilate_img * mask
210
+ img = img.clip(0, 255).astype(np.uint8)
211
+
212
+ cv2.imwrite(f'{out_dir}.png', img[..., [2, 1, 0]])
213
+ # cv2.imwrite(f'{out_dir}/{ind}.png', img[..., [2, 1, 0]])
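CRM.__init__ only reads a handful of keys from the specs dict. A hypothetical minimal skeleton matching those reads is sketched below; the values are placeholders, not the repository's actual configuration, and constructing the model additionally needs nvdiffrast/xformers plus network access for the scheduler config:

from model.crm.model import CRM

specs = {
    "Input": {"scale": 0.95, "resolution": 256, "tet_grid_size": 80, "camera_angle_num": 4},
    "ArchSpecs": {"fea_concat": True, "mlp_bias": True},
    "DecoderSpecs": {"c_dim": 32, "plane_resolution": 256},
    "Train": {"geo_type": "flex"},
    "Pretrain": {"mode": None, "radius": 0.5},
}
model = CRM(specs)  # builds UNetPP, the triplane decoder, both MLP heads and the renderer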
pipelines.py ADDED
@@ -0,0 +1,170 @@
1
+ import torch
2
+ from libs.base_utils import do_resize_content
3
+ from imagedream.ldm.util import (
4
+ instantiate_from_config,
5
+ get_obj_from_str,
6
+ )
7
+ from omegaconf import OmegaConf
8
+ from PIL import Image
9
+ import numpy as np
10
+
11
+
12
+ class TwoStagePipeline(object):
13
+ def __init__(
14
+ self,
15
+ stage1_model_config,
16
+ stage2_model_config,
17
+ stage1_sampler_config,
18
+ stage2_sampler_config,
19
+ device="cuda",
20
+ dtype=torch.float16,
21
+ resize_rate=1,
22
+ ) -> None:
23
+ """
24
+ only for two stage generate process.
25
+ - the first stage was condition on single pixel image, gererate multi-view pixel image, based on the v2pp config
26
+ - the second stage was condition on multiview pixel image generated by the first stage, generate the final image, based on the stage2-test config
27
+ """
28
+ self.resize_rate = resize_rate
29
+
30
+ self.stage1_model = instantiate_from_config(OmegaConf.load(stage1_model_config.config).model)
31
+ self.stage1_model.load_state_dict(torch.load(stage1_model_config.resume, map_location="cpu"), strict=False)
32
+ self.stage1_model = self.stage1_model.to(device).to(dtype)
33
+
34
+ self.stage2_model = instantiate_from_config(OmegaConf.load(stage2_model_config.config).model)
35
+ sd = torch.load(stage2_model_config.resume, map_location="cpu")
36
+ self.stage2_model.load_state_dict(sd, strict=False)
37
+ self.stage2_model = self.stage2_model.to(device).to(dtype)
38
+
39
+ self.stage1_model.device = device
40
+ self.stage2_model.device = device
41
+ self.device = device
42
+ self.dtype = dtype
43
+ self.stage1_sampler = get_obj_from_str(stage1_sampler_config.target)(
44
+ self.stage1_model, device=device, dtype=dtype, **stage1_sampler_config.params
45
+ )
46
+ self.stage2_sampler = get_obj_from_str(stage2_sampler_config.target)(
47
+ self.stage2_model, device=device, dtype=dtype, **stage2_sampler_config.params
48
+ )
49
+
50
+ def stage1_sample(
51
+ self,
52
+ pixel_img,
53
+ prompt="3D assets",
54
+ neg_texts="uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear.",
55
+ step=50,
56
+ scale=5,
57
+ ddim_eta=0.0,
58
+ ):
59
+ if type(pixel_img) == str:
60
+ pixel_img = Image.open(pixel_img)
61
+
62
+ if isinstance(pixel_img, Image.Image):
63
+ if pixel_img.mode == "RGBA":
64
+ background = Image.new('RGBA', pixel_img.size, (0, 0, 0, 0))
65
+ pixel_img = Image.alpha_composite(background, pixel_img).convert("RGB")
66
+ else:
67
+ pixel_img = pixel_img.convert("RGB")
68
+ else:
69
+ raise
70
+ uc = self.stage1_sampler.model.get_learned_conditioning([neg_texts]).to(self.device)
71
+ stage1_images = self.stage1_sampler.i2i(
72
+ self.stage1_sampler.model,
73
+ self.stage1_sampler.size,
74
+ prompt,
75
+ uc=uc,
76
+ sampler=self.stage1_sampler.sampler,
77
+ ip=pixel_img,
78
+ step=step,
79
+ scale=scale,
80
+ batch_size=self.stage1_sampler.batch_size,
81
+ ddim_eta=ddim_eta,
82
+ dtype=self.stage1_sampler.dtype,
83
+ device=self.stage1_sampler.device,
84
+ camera=self.stage1_sampler.camera,
85
+ num_frames=self.stage1_sampler.num_frames,
86
+ pixel_control=(self.stage1_sampler.mode == "pixel"),
87
+ transform=self.stage1_sampler.image_transform,
88
+ offset_noise=self.stage1_sampler.offset_noise,
89
+ )
90
+
91
+ stage1_images = [Image.fromarray(img) for img in stage1_images]
92
+ stage1_images.pop(self.stage1_sampler.ref_position)
93
+ return stage1_images
94
+
95
+ def stage2_sample(self, pixel_img, stage1_images):
96
+ if type(pixel_img) == str:
97
+ pixel_img = Image.open(pixel_img)
98
+
99
+ if isinstance(pixel_img, Image.Image):
100
+ if pixel_img.mode == "RGBA":
101
+ background = Image.new('RGBA', pixel_img.size, (0, 0, 0, 0))
102
+ pixel_img = Image.alpha_composite(background, pixel_img).convert("RGB")
103
+ else:
104
+ pixel_img = pixel_img.convert("RGB")
105
+ else:
106
+ raise
107
+ stage2_images = self.stage2_sampler.i2iStage2(
108
+ self.stage2_sampler.model,
109
+ self.stage2_sampler.size,
110
+ "3D assets",
111
+ self.stage2_sampler.uc,
112
+ self.stage2_sampler.sampler,
113
+ pixel_images=stage1_images,
114
+ ip=pixel_img,
115
+ step=50,
116
+ scale=5,
117
+ batch_size=self.stage2_sampler.batch_size,
118
+ ddim_eta=0.0,
119
+ dtype=self.stage2_sampler.dtype,
120
+ device=self.stage2_sampler.device,
121
+ camera=self.stage2_sampler.camera,
122
+ num_frames=self.stage2_sampler.num_frames,
123
+ pixel_control=(self.stage2_sampler.mode == "pixel"),
124
+ transform=self.stage2_sampler.image_transform,
125
+ offset_noise=self.stage2_sampler.offset_noise,
126
+ )
127
+ stage2_images = [Image.fromarray(img) for img in stage2_images]
128
+ return stage2_images
129
+
130
+ def set_seed(self, seed):
131
+ self.stage1_sampler.seed = seed
132
+ self.stage2_sampler.seed = seed
133
+
134
+ def __call__(self, pixel_img, prompt="3D assets", scale=5, step=50):
135
+ pixel_img = do_resize_content(pixel_img, self.resize_rate)
136
+ stage1_images = self.stage1_sample(pixel_img, prompt, scale=scale, step=step)
137
+ stage2_images = self.stage2_sample(pixel_img, stage1_images)
138
+
139
+ return {
140
+ "ref_img": pixel_img,
141
+ "stage1_images": stage1_images,
142
+ "stage2_images": stage2_images,
143
+ }
144
+
145
+
146
+ if __name__ == "__main__":
147
+
148
+ stage1_config = OmegaConf.load("configs/nf7_v3_SNR_rd_size_stroke.yaml").config
149
+ stage2_config = OmegaConf.load("configs/stage2-v2-snr.yaml").config
150
+ stage2_sampler_config = stage2_config.sampler
151
+ stage1_sampler_config = stage1_config.sampler
152
+
153
+ stage1_model_config = stage1_config.models
154
+ stage2_model_config = stage2_config.models
155
+
156
+ pipeline = TwoStagePipeline(
157
+ stage1_model_config,
158
+ stage2_model_config,
159
+ stage1_sampler_config,
160
+ stage2_sampler_config,
161
+ )
162
+
163
+ img = Image.open("assets/astronaut.png")
164
+ rt_dict = pipeline(img)
165
+ stage1_images = rt_dict["stage1_images"]
166
+ stage2_images = rt_dict["stage2_images"]
167
+ np_imgs = np.concatenate(stage1_images, 1)
168
+ np_xyzs = np.concatenate(stage2_images, 1)
169
+ Image.fromarray(np_imgs).save("pixel_images.png")
170
+ Image.fromarray(np_xyzs).save("xyz_images.png")
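Beyond the __main__ example above, the sampler seeds and sampling hyper-parameters can be set before calling the pipeline. A short sketch reusing the pipeline and img objects from that block:

pipeline.set_seed(1234)                      # seeds both stage samplers for reproducibility
rt_dict = pipeline(img, scale=5.5, step=30)  # guidance scale and DDIM step count exposed by __call__
stage1_images = rt_dict["stage1_images"]     # list of PIL images, reference view already removed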
requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ gradio
2
+ huggingface-hub
3
+ einops==0.7.0
4
+ Pillow==10.1.0
5
+ transformers==4.27.1
6
+ open-clip-torch==2.7.0
7
+ opencv-contrib-python-headless==4.9.0.80
8
+ opencv-python-headless==4.9.0.80
9
+ xformers
10
+ omegaconf
11
+ rembg
12
+ nvdiffrast
13
+ pygltflib
14
+ kiui
util/__init__.py ADDED
File without changes
util/flexicubes.py ADDED
@@ -0,0 +1,579 @@
1
+ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+ import torch
9
+ from util.tables import *
10
+
11
+ __all__ = [
12
+ 'FlexiCubes'
13
+ ]
14
+
15
+
16
+ class FlexiCubes:
17
+ """
18
+ This class implements the FlexiCubes method for extracting meshes from scalar fields.
19
+ It maintains a series of lookup tables and indices to support the mesh extraction process.
20
+ FlexiCubes, a differentiable variant of the Dual Marching Cubes (DMC) scheme, enhances
21
+ the geometric fidelity and mesh quality of reconstructed meshes by dynamically adjusting
22
+ the surface representation through gradient-based optimization.
23
+
24
+ During instantiation, the class loads DMC tables from a file and transforms them into
25
+ PyTorch tensors on the specified device.
26
+
27
+ Attributes:
28
+ device (str): Specifies the computational device (default is "cuda").
29
+ dmc_table (torch.Tensor): Dual Marching Cubes (DMC) table that encodes the edges
30
+ associated with each dual vertex in 256 Marching Cubes (MC) configurations.
31
+ num_vd_table (torch.Tensor): Table holding the number of dual vertices in each of
32
+ the 256 MC configurations.
33
+ check_table (torch.Tensor): Table resolving ambiguity in cases C16 and C19
34
+ of the DMC configurations.
35
+ tet_table (torch.Tensor): Lookup table used in tetrahedralizing the isosurface.
36
+ quad_split_1 (torch.Tensor): Indices for splitting a quad into two triangles
37
+ along one diagonal.
38
+ quad_split_2 (torch.Tensor): Alternative indices for splitting a quad into
39
+ two triangles along the other diagonal.
40
+ quad_split_train (torch.Tensor): Indices for splitting a quad into four triangles
41
+ during training by connecting all edges to their midpoints.
42
+ cube_corners (torch.Tensor): Defines the positions of a standard unit cube's
43
+ eight corners in 3D space, ordered starting from the origin (0,0,0),
44
+ moving along the x-axis, then y-axis, and finally z-axis.
45
+ Used as a blueprint for generating a voxel grid.
46
+ cube_corners_idx (torch.Tensor): Cube corners indexed as powers of 2, used
47
+ to retrieve the case id.
48
+ cube_edges (torch.Tensor): Edge connections in a cube, listed in pairs.
49
+ Used to retrieve edge vertices in DMC.
50
+ edge_dir_table (torch.Tensor): A mapping tensor that associates edge indices with
51
+ their corresponding axis. For instance, edge_dir_table[0] = 0 indicates that the
52
+ first edge is oriented along the x-axis.
53
+ dir_faces_table (torch.Tensor): A tensor that maps the corresponding axis of shared edges
54
+ across four adjacent cubes to the shared faces of these cubes. For instance,
55
+ dir_faces_table[0] = [5, 4] implies that for four cubes sharing an edge along
56
+ the x-axis, the first and second cubes share faces indexed as 5 and 4, respectively.
57
+ This tensor is only utilized during isosurface tetrahedralization.
58
+ adj_pairs (torch.Tensor):
59
+ A tensor containing index pairs that correspond to neighboring cubes that share the same edge.
60
+ qef_reg_scale (float):
61
+ The scaling factor applied to the regularization loss to prevent issues with singularity
62
+ when solving the QEF. This parameter is only used when a 'grad_func' is specified.
63
+ weight_scale (float):
64
+ The scale of weights in FlexiCubes. Should be between 0 and 1.
65
+ """
66
+
67
+ def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99):
68
+
69
+ self.device = device
70
+ self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False)
71
+ self.num_vd_table = torch.tensor(num_vd_table,
72
+ dtype=torch.long, device=device, requires_grad=False)
73
+ self.check_table = torch.tensor(
74
+ check_table,
75
+ dtype=torch.long, device=device, requires_grad=False)
76
+
77
+ self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False)
78
+ self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False)
79
+ self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False)
80
+ self.quad_split_train = torch.tensor(
81
+ [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False)
82
+
83
+ self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [
84
+ 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device)
85
+ self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False))
86
+ self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6,
87
+ 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False)
88
+
89
+ self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1],
90
+ dtype=torch.long, device=device)
91
+ self.dir_faces_table = torch.tensor([
92
+ [[5, 4], [3, 2], [4, 5], [2, 3]],
93
+ [[5, 4], [1, 0], [4, 5], [0, 1]],
94
+ [[3, 2], [1, 0], [2, 3], [0, 1]]
95
+ ], dtype=torch.long, device=device)
96
+ self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device)
97
+ self.qef_reg_scale = qef_reg_scale
98
+ self.weight_scale = weight_scale
99
+
100
+ def construct_voxel_grid(self, res):
101
+ """
102
+ Generates a voxel grid based on the specified resolution.
103
+
104
+ Args:
105
+ res (int or list[int]): The resolution of the voxel grid. If an integer
106
+ is provided, it is used for all three dimensions. If a list or tuple
107
+ of 3 integers is provided, they define the resolution for the x,
108
+ y, and z dimensions respectively.
109
+
110
+ Returns:
111
+ (torch.Tensor, torch.Tensor): Returns the vertices and the indices of the
112
+ cube corners (index into vertices) of the constructed voxel grid.
113
+ The vertices are centered at the origin, with the length of each
114
+ dimension in the grid being one.
115
+ """
116
+ base_cube_f = torch.arange(8).to(self.device)
117
+ if isinstance(res, int):
118
+ res = (res, res, res)
119
+ voxel_grid_template = torch.ones(res, device=self.device)
120
+
121
+ res = torch.tensor([res], dtype=torch.float, device=self.device)
122
+ coords = torch.nonzero(voxel_grid_template).float() / res # N, 3
123
+ verts = (self.cube_corners.unsqueeze(0) / res + coords.unsqueeze(1)).reshape(-1, 3)
124
+ cubes = (base_cube_f.unsqueeze(0) +
125
+ torch.arange(coords.shape[0], device=self.device).unsqueeze(1) * 8).reshape(-1)
126
+
127
+ verts_rounded = torch.round(verts * 10**5) / (10**5)
128
+ verts_unique, inverse_indices = torch.unique(verts_rounded, dim=0, return_inverse=True)
129
+ cubes = inverse_indices[cubes.reshape(-1)].reshape(-1, 8)
130
+
131
+ return verts_unique - 0.5, cubes
132
+
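For reference, construct_voxel_grid pairs naturally with __call__ below. A small sketch extracting a sphere from an analytic SDF (the resolution is illustrative, and the lookup tables imported from util.tables must be available):

import torch
from util.flexicubes import FlexiCubes

fc = FlexiCubes(device="cpu")                         # default device is "cuda"
grid_verts, cubes = fc.construct_voxel_grid(res=32)   # vertices in [-0.5, 0.5]^3 and cube corner indices
sdf = grid_verts.norm(dim=-1) - 0.35                  # negative inside a sphere of radius 0.35
mesh_v, mesh_f, L_dev = fc(grid_verts, sdf, cubes, res=32)
print(mesh_v.shape, mesh_f.shape)                     # triangle mesh approximating the sphere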
133
+ def __call__(self, x_nx3, s_n, cube_fx8, res, beta_fx12=None, alpha_fx8=None,
134
+ gamma_f=None, training=False, output_tetmesh=False, grad_func=None):
135
+ r"""
136
+ Main function for mesh extraction from scalar field using FlexiCubes. This function converts
137
+ discrete signed distance fields, encoded on voxel grids and additional per-cube parameters,
138
+ to triangle or tetrahedral meshes using a differentiable operation as described in
139
+ `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances
140
+ mesh quality and geometric fidelity by adjusting the surface representation based on gradient
141
+ optimization. The output surface is differentiable with respect to the input vertex positions,
142
+ scalar field values, and weight parameters.
143
+
144
+ If you intend to extract a surface mesh from a fixed Signed Distance Field without the
145
+ optimization of parameters, it is suggested to provide the "grad_func" which should
146
+ return the surface gradient at any given 3D position. When grad_func is provided, the process
147
+ to determine the dual vertex position adapts to solve a Quadratic Error Function (QEF), as
148
+ described in the `Manifold Dual Contouring`_ paper, and employs a smart splitting strategy.
149
+ Please note, this approach is non-differentiable.
150
+
151
+ For more details and example usage in optimization, refer to the
152
+ `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper.
153
+
154
+ Args:
155
+ x_nx3 (torch.Tensor): Coordinates of the voxel grid vertices, can be deformed.
156
+ s_n (torch.Tensor): Scalar field values at each vertex of the voxel grid. Negative values
157
+ denote that the corresponding vertex resides inside the isosurface. This affects
158
+ the directions of the extracted triangle faces and volume to be tetrahedralized.
159
+ cube_fx8 (torch.Tensor): Indices of 8 vertices for each cube in the voxel grid.
160
+ res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it
161
+ is used for all three dimensions. If a list or tuple of 3 integers is provided, they
162
+ specify the resolution for the x, y, and z dimensions respectively.
163
+ beta_fx12 (torch.Tensor, optional): Weight parameters for the cube edges to adjust dual
164
+ vertices positioning. Defaults to uniform value for all edges.
165
+ alpha_fx8 (torch.Tensor, optional): Weight parameters for the cube corners to adjust dual
166
+ vertices positioning. Defaults to uniform value for all vertices.
167
+ gamma_f (torch.Tensor, optional): Weight parameters to control the splitting of
168
+ quadrilaterals into triangles. Defaults to uniform value for all cubes.
169
+ training (bool, optional): If set to True, applies differentiable quad splitting for
170
+ training. Defaults to False.
171
+ output_tetmesh (bool, optional): If set to True, outputs a tetrahedral mesh, otherwise,
172
+ outputs a triangular mesh. Defaults to False.
173
+ grad_func (callable, optional): A function to compute the surface gradient at specified
174
+ 3D positions (input: Nx3 positions). The function should return gradients as an Nx3
175
+ tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None.
176
+
177
+ Returns:
178
+ (torch.Tensor, torch.LongTensor, torch.Tensor): Tuple containing:
179
+ - Vertices for the extracted triangular/tetrahedral mesh.
180
+ - Faces for the extracted triangular/tetrahedral mesh.
181
+ - Regularizer L_dev, computed per dual vertex.
182
+
183
+ .. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization:
184
+ https://research.nvidia.com/labs/toronto-ai/flexicubes/
185
+ .. _Manifold Dual Contouring:
186
+ https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf
187
+ """
188
+
189
+ surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8)
190
+ if surf_cubes.sum() == 0:
191
+ return torch.zeros(
192
+ (0, 3),
193
+ device=self.device), torch.zeros(
194
+ (0, 4),
195
+ dtype=torch.long, device=self.device) if output_tetmesh else torch.zeros(
196
+ (0, 3),
197
+ dtype=torch.long, device=self.device), torch.zeros(
198
+ (0),
199
+ device=self.device)
200
+ beta_fx12, alpha_fx8, gamma_f = self._normalize_weights(beta_fx12, alpha_fx8, gamma_f, surf_cubes)
201
+
202
+ case_ids = self._get_case_id(occ_fx8, surf_cubes, res)
203
+
204
+ surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(s_n, cube_fx8, surf_cubes)
205
+
206
+ vd, L_dev, vd_gamma, vd_idx_map = self._compute_vd(
207
+ x_nx3, cube_fx8[surf_cubes], surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func)
208
+ vertices, faces, s_edges, edge_indices = self._triangulate(
209
+ s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func)
210
+ if not output_tetmesh:
211
+ return vertices, faces, L_dev
212
+ else:
213
+ vertices, tets = self._tetrahedralize(
214
+ x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,
215
+ surf_cubes, training)
216
+ return vertices, tets, L_dev
217
+
218
+ def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges):
219
+ """
220
+ Regularizer L_dev as in Equation 8
221
+ """
222
+ dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1)
223
+ mean_l2 = torch.zeros_like(vd[:, 0])
224
+ mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float()
225
+ mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs()
226
+ return mad
227
+
228
+ def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes):
229
+ """
230
+ Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones.
231
+ """
232
+ n_cubes = surf_cubes.shape[0]
233
+
234
+ if beta_fx12 is not None:
235
+ beta_fx12 = (torch.tanh(beta_fx12) * self.weight_scale + 1)
236
+ else:
237
+ beta_fx12 = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device)
238
+
239
+ if alpha_fx8 is not None:
240
+ alpha_fx8 = (torch.tanh(alpha_fx8) * self.weight_scale + 1)
241
+ else:
242
+ alpha_fx8 = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device)
243
+
244
+ if gamma_f is not None:
245
+ gamma_f = torch.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale)/2
246
+ else:
247
+ gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device)
248
+
249
+ return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes]
250
+
251
+ @torch.no_grad()
252
+ def _get_case_id(self, occ_fx8, surf_cubes, res):
253
+ """
254
+ Obtains the ID of topology cases based on cell corner occupancy. This function resolves the
255
+ ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the
256
+ supplementary material. It should be noted that this function assumes a regular grid.
257
+ """
258
+ case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1)
259
+
260
+ problem_config = self.check_table.to(self.device)[case_ids]
261
+ to_check = problem_config[..., 0] == 1
262
+ problem_config = problem_config[to_check]
263
+ if not isinstance(res, (list, tuple)):
264
+ res = [res, res, res]
265
+
266
+ # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array,
267
+ # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes).
268
+ # This allows efficient checking on adjacent cubes.
269
+ problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long)
270
+ vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3
271
+ vol_idx_problem = vol_idx[surf_cubes][to_check]
272
+ problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config
273
+ vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4]
274
+
275
+ within_range = (
276
+ vol_idx_problem_adj[..., 0] >= 0) & (
277
+ vol_idx_problem_adj[..., 0] < res[0]) & (
278
+ vol_idx_problem_adj[..., 1] >= 0) & (
279
+ vol_idx_problem_adj[..., 1] < res[1]) & (
280
+ vol_idx_problem_adj[..., 2] >= 0) & (
281
+ vol_idx_problem_adj[..., 2] < res[2])
282
+
283
+ vol_idx_problem = vol_idx_problem[within_range]
284
+ vol_idx_problem_adj = vol_idx_problem_adj[within_range]
285
+ problem_config = problem_config[within_range]
286
+ problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0],
287
+ vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]]
288
+ # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted.
289
+ to_invert = (problem_config_adj[..., 0] == 1)
290
+ idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert]
291
+ case_ids.index_put_((idx,), problem_config[to_invert][..., -1])
292
+ return case_ids
293
+
294
+ @torch.no_grad()
295
+ def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes):
296
+ """
297
+ Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge
298
+ can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge
299
+ and marks the cube edges with this index.
300
+ """
301
+ occ_n = s_n < 0
302
+ all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2)
303
+ unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
304
+
305
+ unique_edges = unique_edges.long()
306
+ mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
307
+
308
+ surf_edges_mask = mask_edges[_idx_map]
309
+ counts = counts[_idx_map]
310
+
311
+ mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_fx8.device) * -1
312
+ mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_fx8.device)
313
+ # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index
314
+ # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1.
315
+ idx_map = mapping[_idx_map]
316
+ surf_edges = unique_edges[mask_edges]
317
+ return surf_edges, idx_map, counts, surf_edges_mask
318
+
319
+ @torch.no_grad()
320
+ def _identify_surf_cubes(self, s_n, cube_fx8):
321
+ """
322
+ Identifies grid cubes that intersect with the underlying surface by checking if the signs at
323
+ all corners are not identical.
324
+ """
325
+ occ_n = s_n < 0
326
+ occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
327
+ _occ_sum = torch.sum(occ_fx8, -1)
328
+ surf_cubes = (_occ_sum > 0) & (_occ_sum < 8)
329
+ return surf_cubes, occ_fx8
330
+
331
+ def _linear_interp(self, edges_weight, edges_x):
332
+ """
333
+ Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'.
334
+ """
335
+ edge_dim = edges_weight.dim() - 2
336
+ assert edges_weight.shape[edge_dim] == 2
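+ # For endpoint values (s0, s1) at positions (x0, x1), the zero crossing is (s1 * x0 - s0 * x1) / (s1 - s0); the cat below builds the weights (s1, -s0) so the weighted sum computes exactly that.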
337
+ edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim),
338
+ -torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)], edge_dim)
339
+ denominator = edges_weight.sum(edge_dim)
340
+ ue = (edges_x * edges_weight).sum(edge_dim) / denominator
341
+ return ue
342
+
343
+ def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None):
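+ """
+ Places each dual vertex by solving a regularized Quadric Error Function (QEF) in the least-squares sense: 'p_bxnx3' and 'norm_bxnx3' define up to 7 plane constraints per cell, and 'c_bx3' pulls the solution towards the cell centroid with weight 'self.qef_reg_scale'.
+ """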
344
+ p_bxnx3 = p_bxnx3.reshape(-1, 7, 3)
345
+ norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3)
346
+ c_bx3 = c_bx3.reshape(-1, 3)
347
+ A = norm_bxnx3
348
+ B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True)
349
+
350
+ A_reg = (torch.eye(3, device=p_bxnx3.device) * self.qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1)
351
+ B_reg = (self.qef_reg_scale * c_bx3).unsqueeze(-1)
352
+ A = torch.cat([A, A_reg], 1)
353
+ B = torch.cat([B, B_reg], 1)
354
+ dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1)
355
+ return dual_verts
356
+
357
+ def _compute_vd(self, x_nx3, surf_cubes_fx8, surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func):
358
+ """
359
+ Computes the location of dual vertices as described in Section 4.2
360
+ """
361
+ alpha_nx12x2 = torch.index_select(input=alpha_fx8, index=self.cube_edges, dim=1).reshape(-1, 12, 2)
362
+ surf_edges_x = torch.index_select(input=x_nx3, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3)
363
+ surf_edges_s = torch.index_select(input=s_n, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1)
364
+ zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x)
365
+
366
+ idx_map = idx_map.reshape(-1, 12)
367
+ num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0)
368
+ edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], []
369
+
370
+ total_num_vd = 0
371
+ vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False)
372
+ if grad_func is not None:
373
+ normals = torch.nn.functional.normalize(grad_func(zero_crossing), dim=-1)
374
+ vd = []
375
+ for num in torch.unique(num_vd):
376
+ cur_cubes = (num_vd == num) # group cubes that emit the same number of dual vertices (for batching)
377
+ curr_num_vd = cur_cubes.sum() * num
378
+ curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7)
379
+ curr_edge_group_to_vd = torch.arange(
380
+ curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd
381
+ total_num_vd += curr_num_vd
382
+ curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[
383
+ cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group)
384
+
385
+ curr_mask = (curr_edge_group != -1)
386
+ edge_group.append(torch.masked_select(curr_edge_group, curr_mask))
387
+ edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask))
388
+ edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask))
389
+ vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True))
390
+ vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1))
391
+
392
+ if grad_func is not None:
393
+ with torch.no_grad():
394
+ cube_e_verts_idx = idx_map[cur_cubes]
395
+ curr_edge_group[~curr_mask] = 0
396
+
397
+ verts_group_idx = torch.gather(input=cube_e_verts_idx, dim=1, index=curr_edge_group)
398
+ verts_group_idx[verts_group_idx == -1] = 0
399
+ verts_group_pos = torch.index_select(
400
+ input=zero_crossing, index=verts_group_idx.reshape(-1), dim=0).reshape(-1, num.item(), 7, 3)
401
+ v0 = x_nx3[surf_cubes_fx8[cur_cubes][:, 0]].reshape(-1, 1, 1, 3).repeat(1, num.item(), 1, 1)
402
+ curr_mask = curr_mask.reshape(-1, num.item(), 7, 1)
403
+ verts_centroid = (verts_group_pos * curr_mask).sum(2) / (curr_mask.sum(2))
404
+
405
+ normals_bx7x3 = torch.index_select(input=normals, index=verts_group_idx.reshape(-1), dim=0).reshape(
406
+ -1, num.item(), 7,
407
+ 3)
408
+ curr_mask = curr_mask.squeeze(2)
409
+ vd.append(self._solve_vd_QEF((verts_group_pos - v0) * curr_mask, normals_bx7x3 * curr_mask,
410
+ verts_centroid - v0.squeeze(2)) + v0.reshape(-1, 3))
411
+ edge_group = torch.cat(edge_group)
412
+ edge_group_to_vd = torch.cat(edge_group_to_vd)
413
+ edge_group_to_cube = torch.cat(edge_group_to_cube)
414
+ vd_num_edges = torch.cat(vd_num_edges)
415
+ vd_gamma = torch.cat(vd_gamma)
416
+
417
+ if grad_func is not None:
418
+ vd = torch.cat(vd)
419
+ L_dev = torch.zeros([1], device=self.device)
420
+ else:
421
+ vd = torch.zeros((total_num_vd, 3), device=self.device)
422
+ beta_sum = torch.zeros((total_num_vd, 1), device=self.device)
423
+
424
+ idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group)
425
+
426
+ x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3)
427
+ s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1)
428
+
429
+ zero_crossing_group = torch.index_select(
430
+ input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3)
431
+
432
+ alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0,
433
+ index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1)
434
+ ue_group = self._linear_interp(s_group * alpha_group, x_group)
435
+
436
+ beta_group = torch.gather(input=beta_fx12.reshape(-1), dim=0,
437
+ index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1)
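+ # Each dual vertex is the beta-weighted average of its alpha-adjusted edge crossings: accumulate the weighted sum and the weight total, then divide.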
438
+ beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group)
439
+ vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum
440
+ L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges)
441
+
442
+ v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd
443
+
444
+ vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube *
445
+ 12 + edge_group, src=v_idx[edge_group_to_vd])
446
+
447
+ return vd, L_dev, vd_gamma, vd_idx_map
448
+
449
+ def _triangulate(self, s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func):
450
+ """
451
+ Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into
452
+ triangles based on the gamma parameter, as described in Section 4.3.
453
+ """
454
+ with torch.no_grad():
455
+ group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes.
456
+ group = idx_map.reshape(-1)[group_mask]
457
+ vd_idx = vd_idx_map[group_mask]
458
+ edge_indices, indices = torch.sort(group, stable=True)
459
+ quad_vd_idx = vd_idx[indices].reshape(-1, 4)
460
+
461
+ # Ensure all face directions point towards the positive SDF to maintain consistent winding.
462
+ s_edges = s_n[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2)
463
+ flip_mask = s_edges[:, 0] > 0
464
+ quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]],
465
+ quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]]))
466
+ if grad_func is not None:
467
+ # when grad_func is given, split quadrilaterals along the diagonals with more consistent gradients.
468
+ with torch.no_grad():
469
+ vd_gamma = torch.nn.functional.normalize(grad_func(vd), dim=-1)
470
+ quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
471
+ gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True)
472
+ gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True)
473
+ else:
474
+ quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4)
475
+ gamma_02 = torch.index_select(input=quad_gamma, index=torch.tensor(
476
+ 0, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(2, device=self.device), dim=1)
477
+ gamma_13 = torch.index_select(input=quad_gamma, index=torch.tensor(
478
+ 1, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(3, device=self.device), dim=1)
479
+ if not training:
480
+ mask = (gamma_02 > gamma_13).squeeze(1)
481
+ faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device)
482
+ faces[mask] = quad_vd_idx[mask][:, self.quad_split_1]
483
+ faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2]
484
+ faces = faces.reshape(-1, 3)
485
+ else:
486
+ vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
487
+ vd_02 = (torch.index_select(input=vd_quad, index=torch.tensor(0, device=self.device), dim=1) +
488
+ torch.index_select(input=vd_quad, index=torch.tensor(2, device=self.device), dim=1)) / 2
489
+ vd_13 = (torch.index_select(input=vd_quad, index=torch.tensor(1, device=self.device), dim=1) +
490
+ torch.index_select(input=vd_quad, index=torch.tensor(3, device=self.device), dim=1)) / 2
491
+ weight_sum = (gamma_02 + gamma_13) + 1e-8
492
+ vd_center = ((vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) /
493
+ weight_sum.unsqueeze(-1)).squeeze(1)
494
+ vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0]
495
+ vd = torch.cat([vd, vd_center])
496
+ faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2)
497
+ faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3)
498
+ return vd, faces, s_edges, edge_indices
499
+
500
+ def _tetrahedralize(
501
+ self, x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,
502
+ surf_cubes, training):
503
+ """
504
+ Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5.
505
+ """
506
+ occ_n = s_n < 0
507
+ occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
508
+ occ_sum = torch.sum(occ_fx8, -1)
509
+
510
+ inside_verts = x_nx3[occ_n]
511
+ mapping_inside_verts = torch.ones((occ_n.shape[0]), dtype=torch.long, device=self.device) * -1
512
+ mapping_inside_verts[occ_n] = torch.arange(occ_n.sum(), device=self.device) + vertices.shape[0]
513
+ """
514
+ For each grid edge connecting two grid vertices with different
515
+ signs, we first form a four-sided pyramid by connecting one
516
+ of the grid vertices with four mesh vertices that correspond
517
+ to the grid edge and then subdivide the pyramid into two tetrahedra
518
+ """
519
+ inside_verts_idx = mapping_inside_verts[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[
520
+ s_edges < 0]]
521
+ if not training:
522
+ inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 2).reshape(-1)
523
+ else:
524
+ inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 4).reshape(-1)
525
+
526
+ tets_surface = torch.cat([faces, inside_verts_idx.unsqueeze(-1)], -1)
527
+ """
528
+ For each grid edge connecting two grid vertices with the
529
+ same sign, the tetrahedron is formed by the two grid vertices
530
+ and two vertices in consecutive adjacent cells
531
+ """
532
+ inside_cubes = (occ_sum == 8)
533
+ inside_cubes_center = x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1)
534
+ inside_cubes_center_idx = torch.arange(
535
+ inside_cubes_center.shape[0], device=inside_cubes.device) + vertices.shape[0] + inside_verts.shape[0]
536
+
537
+ surface_n_inside_cubes = surf_cubes | inside_cubes
538
+ edge_center_vertex_idx = torch.ones(((surface_n_inside_cubes).sum(), 13),
539
+ dtype=torch.long, device=x_nx3.device) * -1
540
+ surf_cubes = surf_cubes[surface_n_inside_cubes]
541
+ inside_cubes = inside_cubes[surface_n_inside_cubes]
542
+ edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12)
543
+ edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx
544
+
545
+ all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2)
546
+ unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
547
+ unique_edges = unique_edges.long()
548
+ mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2
549
+ mask = mask_edges[_idx_map]
550
+ counts = counts[_idx_map]
551
+ mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1
552
+ mapping[mask_edges] = torch.arange(mask_edges.sum(), device=self.device)
553
+ idx_map = mapping[_idx_map]
554
+
555
+ group_mask = (counts == 4) & mask
556
+ group = idx_map.reshape(-1)[group_mask]
557
+ edge_indices, indices = torch.sort(group)
558
+ cube_idx = torch.arange((_idx_map.shape[0] // 12), dtype=torch.long,
559
+ device=self.device).unsqueeze(1).expand(-1, 12).reshape(-1)[group_mask]
560
+ edge_idx = torch.arange((12), dtype=torch.long, device=self.device).unsqueeze(
561
+ 0).expand(_idx_map.shape[0] // 12, -1).reshape(-1)[group_mask]
562
+ # Identify the face shared by the adjacent cells.
563
+ cube_idx_4 = cube_idx[indices].reshape(-1, 4)
564
+ edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0]
565
+ shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1)
566
+ cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1)
567
+ # Identify an edge of the face with different signs and
568
+ # select the mesh vertex corresponding to the identified edge.
569
+ case_ids_expand = torch.ones((surface_n_inside_cubes).sum(), dtype=torch.long, device=x_nx3.device) * 255
570
+ case_ids_expand[surf_cubes] = case_ids
571
+ cases = case_ids_expand[cube_idx_4x2]
572
+ quad_edge = edge_center_vertex_idx[cube_idx_4x2, self.tet_table[cases, shared_faces_4x2]].reshape(-1, 2)
573
+ mask = (quad_edge == -1).sum(-1) == 0
574
+ inside_edge = mapping_inside_verts[unique_edges[mask_edges][edge_indices].reshape(-1)].reshape(-1, 2)
575
+ tets_inside = torch.cat([quad_edge, inside_edge], -1)[mask]
576
+
577
+ tets = torch.cat([tets_surface, tets_inside])
578
+ vertices = torch.cat([vertices, inside_verts, inside_cubes_center])
579
+ return vertices, tets
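A minimal sketch of how this extraction can be driven, mirroring the calls FlexiCubesGeometry makes in util/flexicubes_geometry.py below and assuming the optional weight arguments of the forward call default to None (as _normalize_weights suggests); the grid resolution and the sphere SDF are illustrative stand-ins, not values taken from this commit:

import torch
from util.flexicubes import FlexiCubes

device = 'cuda'
fc = FlexiCubes(device, weight_scale=0.5)

# Dense voxel grid: vertex positions (N, 3) and per-cube corner indices (n_cubes, 8).
res = 64
x_nx3, cube_fx8 = fc.construct_voxel_grid(res)

# Any SDF sampled at the grid vertices works; a sphere is used here purely as an example.
sdf_n = x_nx3.norm(dim=-1) - 0.35

# With beta_fx12 / alpha_fx8 / gamma_f left unset, _normalize_weights falls back to all-ones weights.
vertices, faces, L_dev = fc(x_nx3, sdf_n, cube_fx8, res)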
util/flexicubes_geometry.py ADDED
@@ -0,0 +1,116 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+
9
+ import torch
10
+ from util.flexicubes import FlexiCubes # replace later
11
+ # from dmtet import sdf_reg_loss_batch
12
+ import torch.nn.functional as F
13
+
14
+ def get_center_boundary_index(grid_res, device):
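+ # Returns the flat index of the vertex nearest the grid center and the flat indices of a two-layer shell of boundary vertices; FlexiCubesGeometry uses these to fix the boundary SDF.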
15
+ v = torch.zeros((grid_res + 1, grid_res + 1, grid_res + 1), dtype=torch.bool, device=device)
16
+ v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = True
17
+ center_indices = torch.nonzero(v.reshape(-1))
18
+
19
+ v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = False
20
+ v[:2, ...] = True
21
+ v[-2:, ...] = True
22
+ v[:, :2, ...] = True
23
+ v[:, -2:, ...] = True
24
+ v[:, :, :2] = True
25
+ v[:, :, -2:] = True
26
+ boundary_indices = torch.nonzero(v.reshape(-1))
27
+ return center_indices, boundary_indices
28
+
29
+ ###############################################################################
30
+ # Geometry interface
31
+ ###############################################################################
32
+ class FlexiCubesGeometry(object):
33
+ def __init__(
34
+ self, grid_res=64, scale=2.0, device='cuda', renderer=None,
35
+ render_type='neural_render', args=None):
36
+ super(FlexiCubesGeometry, self).__init__()
37
+ self.grid_res = grid_res
38
+ self.device = device
39
+ self.args = args
40
+ self.fc = FlexiCubes(device, weight_scale=0.5)
41
+ self.verts, self.indices = self.fc.construct_voxel_grid(grid_res)
42
+ if isinstance(scale, list):
43
+ self.verts[:, 0] = self.verts[:, 0] * scale[0]
44
+ self.verts[:, 1] = self.verts[:, 1] * scale[1]
45
+ self.verts[:, 2] = self.verts[:, 2] * scale[1]
46
+ else:
47
+ self.verts = self.verts * scale
48
+
49
+ all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2)
50
+ self.all_edges = torch.unique(all_edges, dim=0)
51
+
52
+ # Parameters used to fix the boundary SDF
53
+ self.center_indices, self.boundary_indices = get_center_boundary_index(self.grid_res, device)
54
+ self.renderer = renderer
55
+ self.render_type = render_type
56
+
57
+ def getAABB(self):
58
+ return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values
59
+
60
+ def get_mesh(self, v_deformed_nx3, sdf_n, weight_n=None, with_uv=False, indices=None, is_training=False):
61
+ if indices is None:
62
+ indices = self.indices
63
+
64
+ verts, faces, v_reg_loss = self.fc(v_deformed_nx3, sdf_n, indices, self.grid_res,
65
+ beta_fx12=weight_n[:, :12], alpha_fx8=weight_n[:, 12:20],
66
+ gamma_f=weight_n[:, 20], training=is_training
67
+ )
68
+ return verts, faces, v_reg_loss
69
+
70
+
71
+ def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False):
72
+ return_value = dict()
73
+ if self.render_type == 'neural_render':
74
+ tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth = self.renderer.render_mesh(
75
+ mesh_v_nx3.unsqueeze(dim=0),
76
+ mesh_f_fx3.int(),
77
+ camera_mv_bx4x4,
78
+ mesh_v_nx3.unsqueeze(dim=0),
79
+ resolution=resolution,
80
+ device=self.device,
81
+ hierarchical_mask=hierarchical_mask
82
+ )
83
+
84
+ return_value['tex_pos'] = tex_pos
85
+ return_value['mask'] = mask
86
+ return_value['hard_mask'] = hard_mask
87
+ return_value['rast'] = rast
88
+ return_value['v_pos_clip'] = v_pos_clip
89
+ return_value['mask_pyramid'] = mask_pyramid
90
+ return_value['depth'] = depth
91
+ else:
92
+ raise NotImplementedError
93
+
94
+ return return_value
95
+
96
+ def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256):
97
+ # Assumes a batch of meshes (each entry can be a different mesh and geometry); for the other shapes the batch size is 1
98
+ v_list = []
99
+ f_list = []
100
+ n_batch = v_deformed_bxnx3.shape[0]
101
+ all_render_output = []
102
+ for i_batch in range(n_batch):
103
+ verts_nx3, faces_fx3, _ = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch])
104
+ v_list.append(verts_nx3)
105
+ f_list.append(faces_fx3)
106
+ render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution)
107
+ all_render_output.append(render_output)
108
+
109
+ # Gather the render outputs from all batch entries, keyed by output name
110
+ return_keys = all_render_output[0].keys()
111
+ return_value = dict()
112
+ for k in return_keys:
113
+ value = [v[k] for v in all_render_output]
114
+ return_value[k] = value
115
+ # Concatenation across the batch can be done outside of this render call
116
+ return return_value
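A short usage sketch of the wrapper above. The 21-channel weight layout (12 beta + 8 alpha + 1 gamma per cube) follows the slicing inside get_mesh; the resolution, scale and toy SDF are illustrative, and a CUDA device is assumed as in the class defaults:

import torch
from util.flexicubes_geometry import FlexiCubesGeometry

geo = FlexiCubesGeometry(grid_res=64, scale=2.0, device='cuda')

# Toy SDF evaluated at the (already scaled) grid vertices.
sdf_n = geo.verts.norm(dim=-1) - 0.5

# One 21-dim weight vector per cube; zeros map to the neutral weights after the tanh/sigmoid remapping.
weight_n = torch.zeros((geo.indices.shape[0], 21), device=geo.verts.device)

verts, faces, reg_loss = geo.get_mesh(geo.verts, sdf_n, weight_n=weight_n, is_training=False)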
util/renderer.py ADDED
@@ -0,0 +1,49 @@
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import nvdiffrast.torch as dr
5
+ from util.flexicubes_geometry import FlexiCubesGeometry
6
+
7
+ class Renderer(nn.Module):
8
+ def __init__(self, tet_grid_size, camera_angle_num, scale, geo_type):
9
+ super().__init__()
10
+
11
+ self.tet_grid_size = tet_grid_size
12
+ self.camera_angle_num = camera_angle_num
13
+ self.scale = scale
14
+ self.geo_type = geo_type
15
+ self.glctx = dr.RasterizeCudaContext()
16
+
17
+ if self.geo_type == "flex":
18
+ self.flexicubes = FlexiCubesGeometry(grid_res=self.tet_grid_size)
19
+
20
+ def forward(self, data, sdf, deform, verts, tets, training=False, weight=None):
21
+
22
+ results = {}
23
+
24
+ deform = torch.tanh(deform) / self.tet_grid_size * self.scale / 0.95
25
+ if self.geo_type == "flex":
26
+ deform = deform * 0.5
27
+
28
+ v_deformed = verts + deform
29
+
30
+ verts_list = []
31
+ faces_list = []
32
+ reg_list = []
33
+ n_shape = verts.shape[0]
34
+ for i in range(n_shape):
35
+ verts_i, faces_i, reg_i = self.flexicubes.get_mesh(v_deformed[i], sdf[i].squeeze(dim=-1),
36
+ with_uv=False, indices=tets, weight_n=weight[i], is_training=training)
37
+
38
+ verts_list.append(verts_i)
39
+ faces_list.append(faces_i)
40
+ reg_list.append(reg_i)
41
+ verts = verts_list
42
+ faces = faces_list
43
+
44
+ flexicubes_surface_reg = torch.cat(reg_list).mean()
45
+ flexicubes_weight_reg = (weight ** 2).mean()
46
+ results["flex_surf_loss"] = flexicubes_surface_reg
47
+ results["flex_weight_loss"] = flexicubes_weight_reg
48
+
49
+ return results, verts, faces
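Finally, a sketch of how the Renderer above might be called. The tensor shapes are inferred from the slicing inside forward and from FlexiCubesGeometry.get_mesh (per-vertex SDF and deformation, 21 weights per cube); the constructor arguments and batch size are illustrative assumptions, not values read from this repository's configs, and a CUDA device with nvdiffrast is required:

import torch
from util.renderer import Renderer

device = 'cuda'
renderer = Renderer(tet_grid_size=80, camera_angle_num=4, scale=2.0, geo_type='flex')

grid_verts = renderer.flexicubes.verts   # (N, 3) FlexiCubes grid vertices
tets = renderer.flexicubes.indices       # (n_cubes, 8) per-cube corner indices

B = 1
sdf = (grid_verts.norm(dim=-1, keepdim=True) - 0.5).unsqueeze(0)  # (B, N, 1) toy SDF
deform = torch.zeros(B, grid_verts.shape[0], 3, device=device)    # (B, N, 3) raw deformations; tanh is applied inside forward
weight = torch.zeros(B, tets.shape[0], 21, device=device)         # (B, n_cubes, 21): 12 beta + 8 alpha + 1 gamma
verts = grid_verts.unsqueeze(0).expand(B, -1, -1)                 # the shared grid vertices for every batch entry

results, verts_list, faces_list = renderer(None, sdf, deform, verts, tets, training=False, weight=weight)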