init
Note: this view is limited to 50 files because the commit contains too many changes.
- .gitignore +155 -0
- app.py +205 -0
- configs/nf7_v3_SNR_rd_size_stroke.yaml +21 -0
- configs/specs_objaverse_total.json +57 -0
- configs/stage2-v2-snr.yaml +25 -0
- imagedream/__init__.py +1 -0
- imagedream/camera_utils.py +99 -0
- imagedream/configs/sd_v2_base_ipmv.yaml +61 -0
- imagedream/configs/sd_v2_base_ipmv_ch8.yaml +61 -0
- imagedream/configs/sd_v2_base_ipmv_chin8.yaml +61 -0
- imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml +62 -0
- imagedream/configs/sd_v2_base_ipmv_local.yaml +62 -0
- imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml +62 -0
- imagedream/ldm/__init__.py +0 -0
- imagedream/ldm/interface.py +205 -0
- imagedream/ldm/models/__init__.py +0 -0
- imagedream/ldm/models/autoencoder.py +270 -0
- imagedream/ldm/models/diffusion/__init__.py +0 -0
- imagedream/ldm/models/diffusion/ddim.py +430 -0
- imagedream/ldm/modules/__init__.py +0 -0
- imagedream/ldm/modules/attention.py +456 -0
- imagedream/ldm/modules/diffusionmodules/__init__.py +0 -0
- imagedream/ldm/modules/diffusionmodules/adaptors.py +163 -0
- imagedream/ldm/modules/diffusionmodules/model.py +1018 -0
- imagedream/ldm/modules/diffusionmodules/openaimodel.py +1135 -0
- imagedream/ldm/modules/diffusionmodules/util.py +353 -0
- imagedream/ldm/modules/distributions/__init__.py +0 -0
- imagedream/ldm/modules/distributions/distributions.py +102 -0
- imagedream/ldm/modules/ema.py +86 -0
- imagedream/ldm/modules/encoders/__init__.py +0 -0
- imagedream/ldm/modules/encoders/modules.py +329 -0
- imagedream/ldm/util.py +226 -0
- imagedream/model_zoo.py +64 -0
- inference.py +91 -0
- libs/base_utils.py +84 -0
- libs/sample.py +380 -0
- mesh.py +845 -0
- model/__init__.py +1 -0
- model/archs/__init__.py +0 -0
- model/archs/decoders/__init__.py +1 -0
- model/archs/decoders/shape_texture_net.py +62 -0
- model/archs/mlp_head.py +40 -0
- model/archs/unet.py +53 -0
- model/crm/model.py +213 -0
- pipelines.py +170 -0
- requirements.txt +14 -0
- util/__init__.py +0 -0
- util/flexicubes.py +579 -0
- util/flexicubes_geometry.py +116 -0
- util/renderer.py +49 -0
.gitignore
ADDED
@@ -0,0 +1,155 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

app.py
ADDED
@@ -0,0 +1,205 @@
# Not ready to use yet
import argparse
import numpy as np
import gradio as gr
from omegaconf import OmegaConf
import torch
from PIL import Image
import PIL
from pipelines import TwoStagePipeline
from huggingface_hub import hf_hub_download
import os
import rembg
from typing import Any
import json

from model import CRM
from inference import generate3d

pipeline = None
rembg_session = rembg.new_session()


def check_input_image(input_image):
    if input_image is None:
        raise gr.Error("No image uploaded!")


def remove_background(
    image: PIL.Image.Image,
    rembg_session: Any = None,
    force: bool = False,
    **rembg_kwargs,
) -> PIL.Image.Image:
    do_remove = True
    if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
        # the alpha channel already carries a mask, so skip background removal
        print("alpha channel not empty, skip remove background, using alpha channel as mask")
        background = Image.new("RGBA", image.size, (0, 0, 0, 0))
        image = Image.alpha_composite(background, image)
        do_remove = False
    do_remove = do_remove or force
    if do_remove:
        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
    return image


def do_resize_content(original_image: Image, scale_rate):
    # resize the image content while retaining the original image size
    if scale_rate != 1:
        # Calculate the new size after rescaling
        new_size = tuple(int(dim * scale_rate) for dim in original_image.size)
        # Resize the image while maintaining the aspect ratio
        resized_image = original_image.resize(new_size)
        # Create a new image with the original size and a transparent background
        padded_image = Image.new("RGBA", original_image.size, (0, 0, 0, 0))
        paste_position = ((original_image.width - resized_image.width) // 2, (original_image.height - resized_image.height) // 2)
        padded_image.paste(resized_image, paste_position)
        return padded_image
    else:
        return original_image


def add_background(image, bg_color=(255, 255, 255)):
    # given an RGBA image, the alpha channel is used as a mask to add a background color
    background = Image.new("RGBA", image.size, bg_color)
    return Image.alpha_composite(background, image)


def preprocess_image(input_image, do_remove_background, force_remove, foreground_ratio, backgroud_color):
    """
    input image is a pil image in RGBA, return RGB image
    """
    if do_remove_background:
        image = remove_background(input_image, rembg_session, force_remove)
    else:
        image = input_image  # keep the uploaded image when background removal is disabled
    image = do_resize_content(image, foreground_ratio)
    image = add_background(image, backgroud_color)
    return image.convert("RGB")


def gen_image(input_image, seed, scale, step):
    global pipeline, model, args
    pipeline.set_seed(seed)
    rt_dict = pipeline(input_image, scale=scale, step=step)
    stage1_images = rt_dict["stage1_images"]
    stage2_images = rt_dict["stage2_images"]
    np_imgs = np.concatenate(stage1_images, 1)
    np_xyzs = np.concatenate(stage2_images, 1)

    glb_path, obj_path = generate3d(model, np_imgs, np_xyzs, args.device)
    return Image.fromarray(np_imgs), Image.fromarray(np_xyzs), glb_path, obj_path


parser = argparse.ArgumentParser()
parser.add_argument(
    "--stage1_config",
    type=str,
    default="configs/nf7_v3_SNR_rd_size_stroke.yaml",
    help="config for stage1",
)
parser.add_argument(
    "--stage2_config",
    type=str,
    default="configs/stage2-v2-snr.yaml",
    help="config for stage2",
)
parser.add_argument("--device", type=str, default="cuda")
args = parser.parse_args()

crm_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="CRM.pth")
specs = json.load(open("configs/specs_objaverse_total.json"))
model = CRM(specs).to(args.device)
model.load_state_dict(torch.load(crm_path, map_location=args.device), strict=False)

stage1_config = OmegaConf.load(args.stage1_config).config
stage2_config = OmegaConf.load(args.stage2_config).config
stage2_sampler_config = stage2_config.sampler
stage1_sampler_config = stage1_config.sampler

stage1_model_config = stage1_config.models
stage2_model_config = stage2_config.models

xyz_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="ccm-diffusion.pth")
pixel_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="pixel-diffusion.pth")
stage1_model_config.resume = pixel_path
stage2_model_config.resume = xyz_path

pipeline = TwoStagePipeline(
    stage1_model_config,
    stage2_model_config,
    stage1_sampler_config,
    stage2_sampler_config,
    device=args.device,
    dtype=torch.float16,
)

with gr.Blocks() as demo:
    gr.Markdown("# CRM: Single Image to 3D Textured Mesh with Convolutional Reconstruction Model")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                image_input = gr.Image(
                    label="Image input",
                    image_mode="RGBA",
                    sources="upload",
                    type="pil",
                )
                processed_image = gr.Image(label="Processed Image", interactive=False, type="pil", image_mode="RGB")
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        do_remove_background = gr.Checkbox(label="Remove Background", value=True)
                        force_remove = gr.Checkbox(label="Force Remove", value=False)
                    back_groud_color = gr.ColorPicker(label="Background Color", value="#7F7F7F", interactive=False)
                    foreground_ratio = gr.Slider(
                        label="Foreground Ratio",
                        minimum=0.5,
                        maximum=1.0,
                        value=1.0,
                        step=0.05,
                    )
                with gr.Column():
                    seed = gr.Number(value=1234, label="seed", precision=0)
                    guidance_scale = gr.Number(value=5.5, minimum=0, maximum=20, label="guidance_scale")
                    step = gr.Number(value=50, minimum=1, maximum=100, label="sample steps", precision=0)
            text_button = gr.Button("Generate Images")
        with gr.Column():
            image_output = gr.Image(interactive=False, label="Output RGB image")
            xyz_ouput = gr.Image(interactive=False, label="Output CCM image")

            output_model = gr.Model3D(
                label="Output GLB",
                interactive=False,
            )
            output_obj = gr.File(interactive=False, label="Output OBJ")

    inputs = [
        processed_image,
        seed,
        guidance_scale,
        step,
    ]
    outputs = [
        image_output,
        xyz_ouput,
        output_model,
        output_obj,
    ]
    gr.Examples(
        examples=[os.path.join("examples", i) for i in os.listdir("examples")],
        inputs=[image_input],
    )

    text_button.click(fn=check_input_image, inputs=[image_input]).success(
        fn=preprocess_image,
        inputs=[image_input, do_remove_background, force_remove, foreground_ratio, back_groud_color],
        outputs=[processed_image],
    ).success(
        fn=gen_image,
        inputs=inputs,
        outputs=outputs,
    )

demo.queue().launch()

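The Gradio demo above chains three callbacks: check_input_image, then preprocess_image, then gen_image. The same flow can be driven directly from Python. The snippet below is a minimal sketch, not part of the commit: it assumes the module-level setup in app.py has already run (checkpoints downloaded, pipeline and model built), and the input path is hypothetical.

# Minimal sketch (hypothetical): run the same preprocessing + generation the UI callbacks use.
# Assumes app.py's module-level setup has executed, so `pipeline`, `model`, and `args` exist.
from PIL import Image

raw = Image.open("examples/astronaut.png").convert("RGBA")   # hypothetical example image
rgb = preprocess_image(raw, do_remove_background=True, force_remove=False,
                       foreground_ratio=1.0, backgroud_color="#7F7F7F")
imgs, ccms, glb_path, obj_path = gen_image(rgb, seed=1234, scale=5.5, step=50)
print(glb_path, obj_path)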
configs/nf7_v3_SNR_rd_size_stroke.yaml
ADDED
@@ -0,0 +1,21 @@
config:
  # others
  seed: 1234
  num_frames: 7
  mode: pixel
  offset_noise: true
  # model related
  models:
    config: imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml
    resume: models/pixel.pth
  # sampler related
  sampler:
    target: libs.sample.ImageDreamDiffusion
    params:
      mode: pixel
      num_frames: 7
      camera_views: [1, 2, 3, 4, 5, 0, 0]
      ref_position: 6
      random_background: false
      offset_noise: true
      resize_rate: 1.0
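As a short sketch of how this stage-1 config is consumed (mirroring the loading code in app.py above): the config.models block names the ImageDream model YAML and checkpoint, and config.sampler is the diffusion sampler spec; the resume path is overridden at runtime with the downloaded checkpoint. The override path below is illustrative.

# Sketch of how app.py reads this file; the checkpoint path here is illustrative.
from omegaconf import OmegaConf

stage1 = OmegaConf.load("configs/nf7_v3_SNR_rd_size_stroke.yaml").config
print(stage1.sampler.target)              # libs.sample.ImageDreamDiffusion
print(stage1.sampler.params.camera_views) # [1, 2, 3, 4, 5, 0, 0]
stage1.models.resume = "/path/to/pixel-diffusion.pth"  # app.py overwrites this with hf_hub_download()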
configs/specs_objaverse_total.json
ADDED
@@ -0,0 +1,57 @@
{
    "Input": {
        "img_num": 16,
        "class": "all",
        "camera_angle_num": 8,
        "tet_grid_size": 80,
        "validate_num": 16,
        "scale": 0.95,
        "radius": 3,
        "resolution": [256, 256]
    },

    "Pretrain": {
        "mode": null,
        "sdf_threshold": 0.1,
        "sdf_scale": 10,
        "batch_infer": false,
        "lr": 1e-4,
        "radius": 0.5
    },

    "Train": {
        "mode": "rnd",
        "num_epochs": 500,
        "grad_acc": 1,
        "warm_up": 0,
        "decay": 0.000,
        "learning_rate": {
            "init": 1e-4,
            "sdf_decay": 1,
            "rgb_decay": 1
        },
        "batch_size": 4,
        "eva_iter": 80,
        "eva_all_epoch": 10,
        "tex_sup_mode": "blender",
        "exp_uv_mesh": false,
        "doub": false,
        "random_bg": false,
        "shift": 0,
        "aug_shift": 0,
        "geo_type": "flex"
    },

    "ArchSpecs": {
        "unet_type": "diffusers",
        "use_3D_aware": false,
        "fea_concat": false,
        "mlp_bias": true
    },

    "DecoderSpecs": {
        "c_dim": 32,
        "plane_resolution": 256
    }
}
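These specs are passed straight to the CRM constructor. A minimal sketch of the load path, following app.py:

# Minimal sketch, following app.py: load the reconstruction specs and build CRM.
import json
import torch
from model import CRM

with open("configs/specs_objaverse_total.json") as f:
    specs = json.load(f)

model = CRM(specs)  # tet_grid_size, c_dim, plane_resolution, etc. come from this JSON
model = model.to("cuda" if torch.cuda.is_available() else "cpu")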
configs/stage2-v2-snr.yaml
ADDED
@@ -0,0 +1,25 @@
config:
  # others
  seed: 1234
  num_frames: 6
  mode: pixel
  offset_noise: true
  gd_type: xyz
  # model related
  models:
    config: imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml
    resume: models/xyz.pth

  # eval related
  sampler:
    target: libs.sample.ImageDreamDiffusionStage2
    params:
      mode: pixel
      num_frames: 6
      camera_views: [1, 2, 3, 4, 5, 0]
      ref_position: null
      random_background: false
      offset_noise: true
      resize_rate: 1.0
imagedream/__init__.py
ADDED
@@ -0,0 +1 @@
from .model_zoo import build_model
imagedream/camera_utils.py
ADDED
@@ -0,0 +1,99 @@
import numpy as np
import torch


def create_camera_to_world_matrix(elevation, azimuth):
    elevation = np.radians(elevation)
    azimuth = np.radians(azimuth)
    # Convert elevation and azimuth angles to Cartesian coordinates on a unit sphere
    x = np.cos(elevation) * np.sin(azimuth)
    y = np.sin(elevation)
    z = np.cos(elevation) * np.cos(azimuth)

    # Calculate camera position, target, and up vectors
    camera_pos = np.array([x, y, z])
    target = np.array([0, 0, 0])
    up = np.array([0, 1, 0])

    # Construct view matrix
    forward = target - camera_pos
    forward /= np.linalg.norm(forward)
    right = np.cross(forward, up)
    right /= np.linalg.norm(right)
    new_up = np.cross(right, forward)
    new_up /= np.linalg.norm(new_up)
    cam2world = np.eye(4)
    cam2world[:3, :3] = np.array([right, new_up, -forward]).T
    cam2world[:3, 3] = camera_pos
    return cam2world


def convert_opengl_to_blender(camera_matrix):
    if isinstance(camera_matrix, np.ndarray):
        # Construct transformation matrix to convert from OpenGL space to Blender space
        flip_yz = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
        camera_matrix_blender = np.dot(flip_yz, camera_matrix)
    else:
        # Construct transformation matrix to convert from OpenGL space to Blender space
        flip_yz = torch.tensor(
            [[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]
        )
        if camera_matrix.ndim == 3:
            flip_yz = flip_yz.unsqueeze(0)
        camera_matrix_blender = torch.matmul(flip_yz.to(camera_matrix), camera_matrix)
    return camera_matrix_blender


def normalize_camera(camera_matrix):
    """normalize the camera location onto a unit-sphere"""
    if isinstance(camera_matrix, np.ndarray):
        camera_matrix = camera_matrix.reshape(-1, 4, 4)
        translation = camera_matrix[:, :3, 3]
        translation = translation / (
            np.linalg.norm(translation, axis=1, keepdims=True) + 1e-8
        )
        camera_matrix[:, :3, 3] = translation
    else:
        camera_matrix = camera_matrix.reshape(-1, 4, 4)
        translation = camera_matrix[:, :3, 3]
        translation = translation / (
            torch.norm(translation, dim=1, keepdim=True) + 1e-8
        )
        camera_matrix[:, :3, 3] = translation
    return camera_matrix.reshape(-1, 16)


def get_camera(
    num_frames,
    elevation=15,
    azimuth_start=0,
    azimuth_span=360,
    blender_coord=True,
    extra_view=False,
):
    angle_gap = azimuth_span / num_frames
    cameras = []
    for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
        camera_matrix = create_camera_to_world_matrix(elevation, azimuth)
        if blender_coord:
            camera_matrix = convert_opengl_to_blender(camera_matrix)
        cameras.append(camera_matrix.flatten())

    if extra_view:
        dim = len(cameras[0])
        cameras.append(np.zeros(dim))
    return torch.tensor(np.stack(cameras, 0)).float()


def get_camera_for_index(data_index):
    """
    In our current data format, with view 000 facing the viewer:
    000 is the front view,  elevation: 0,   azimuth: 0
    001 is the left view,   elevation: 0,   azimuth: -90
    002 is the bottom view, elevation: -90, azimuth: 0
    003 is the back view,   elevation: 0,   azimuth: 180
    004 is the right view,  elevation: 0,   azimuth: 90
    005 is the top view,    elevation: 90,  azimuth: 0
    """
    params = [(0, 0), (0, -90), (-90, 0), (0, 180), (0, 90), (90, 0)]
    return get_camera(1, *params[data_index])
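A short usage sketch for the helpers above: get_camera returns a batch of flattened 4x4 camera-to-world matrices evenly spaced in azimuth (plus an optional all-zero "extra view"), and get_camera_for_index returns the canonical pose of one of the six dataset views.

# Usage sketch for the camera helpers defined above.
from imagedream.camera_utils import get_camera, get_camera_for_index

cams = get_camera(4, elevation=15, azimuth_start=0, extra_view=True)
print(cams.shape)                # torch.Size([5, 16]): 4 orbit views + 1 zero extra view
front = get_camera_for_index(0)  # elevation 0, azimuth 0 (front view)
print(front.shape)               # torch.Size([1, 16])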
imagedream/configs/sd_v2_base_ipmv.yaml
ADDED
@@ -0,0 +1,61 @@
model:
  target: imagedream.ldm.interface.LatentDiffusionInterface
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    timesteps: 1000
    scale_factor: 0.18215
    parameterization: "eps"

    unet_config:
      target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        use_checkpoint: False
        legacy: False
        camera_dim: 16
        with_ip: True
        ip_dim: 16 # ip token length
        ip_mode: "local_resample"

    vae_config:
      target: imagedream.ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    clip_config:
      target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
        ip_mode: "local_resample"
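This YAML, and the five variants that follow (they differ only in UNet channel counts, the stage-2 UNet class, zero-SNR, and ip_weight), is a nested target/params spec. The sketch below is a hedged illustration of how such a spec becomes a model, assuming the instantiate_from_config helper in imagedream/ldm/util.py follows the usual latent-diffusion convention of importing `target` and calling it with `params`.

# Hedged sketch: build the LatentDiffusionInterface from a target/params YAML.
# Assumes imagedream.ldm.util.instantiate_from_config follows the standard
# latent-diffusion convention (import `target`, call it with **params).
from omegaconf import OmegaConf
from imagedream.ldm.util import instantiate_from_config

cfg = OmegaConf.load("imagedream/configs/sd_v2_base_ipmv.yaml")
ldm = instantiate_from_config(cfg.model)  # nested unet/vae/clip configs are built inside __init__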
imagedream/configs/sd_v2_base_ipmv_ch8.yaml
ADDED
@@ -0,0 +1,61 @@
model:
  target: imagedream.ldm.interface.LatentDiffusionInterface
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    timesteps: 1000
    scale_factor: 0.18215
    parameterization: "eps"

    unet_config:
      target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel
      params:
        image_size: 32 # unused
        in_channels: 8
        out_channels: 8
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        use_checkpoint: False
        legacy: False
        camera_dim: 16
        with_ip: True
        ip_dim: 16 # ip token length
        ip_mode: "local_resample"

    vae_config:
      target: imagedream.ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    clip_config:
      target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
        ip_mode: "local_resample"
imagedream/configs/sd_v2_base_ipmv_chin8.yaml
ADDED
@@ -0,0 +1,61 @@
model:
  target: imagedream.ldm.interface.LatentDiffusionInterface
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    timesteps: 1000
    scale_factor: 0.18215
    parameterization: "eps"

    unet_config:
      target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModelStage2
      params:
        image_size: 32 # unused
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        use_checkpoint: False
        legacy: False
        camera_dim: 16
        with_ip: True
        ip_dim: 16 # ip token length
        ip_mode: "local_resample"

    vae_config:
      target: imagedream.ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    clip_config:
      target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
        ip_mode: "local_resample"
imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml
ADDED
@@ -0,0 +1,62 @@
model:
  target: imagedream.ldm.interface.LatentDiffusionInterface
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    timesteps: 1000
    scale_factor: 0.18215
    parameterization: "eps"
    zero_snr: true

    unet_config:
      target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModelStage2
      params:
        image_size: 32 # unused
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        use_checkpoint: False
        legacy: False
        camera_dim: 16
        with_ip: True
        ip_dim: 16 # ip token length
        ip_mode: "local_resample"

    vae_config:
      target: imagedream.ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    clip_config:
      target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
        ip_mode: "local_resample"
imagedream/configs/sd_v2_base_ipmv_local.yaml
ADDED
@@ -0,0 +1,62 @@
model:
  target: imagedream.ldm.interface.LatentDiffusionInterface
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    timesteps: 1000
    scale_factor: 0.18215
    parameterization: "eps"

    unet_config:
      target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        use_checkpoint: False
        legacy: False
        camera_dim: 16
        with_ip: True
        ip_dim: 16 # ip token length
        ip_mode: "local_resample"
        ip_weight: 1.0 # adjust for similarity to image

    vae_config:
      target: imagedream.ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    clip_config:
      target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
        ip_mode: "local_resample"
imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml
ADDED
@@ -0,0 +1,62 @@
model:
  target: imagedream.ldm.interface.LatentDiffusionInterface
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    timesteps: 1000
    scale_factor: 0.18215
    parameterization: "eps"
    zero_snr: true

    unet_config:
      target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        use_checkpoint: False
        legacy: False
        camera_dim: 16
        with_ip: True
        ip_dim: 16 # ip token length
        ip_mode: "local_resample"

    vae_config:
      target: imagedream.ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    clip_config:
      target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
        ip_mode: "local_resample"
imagedream/ldm/__init__.py
ADDED
File without changes
imagedream/ldm/interface.py
ADDED
@@ -0,0 +1,205 @@
from typing import List
from functools import partial

import numpy as np
import torch
import torch.nn as nn

from .modules.diffusionmodules.util import (
    make_beta_schedule,
    extract_into_tensor,
    enforce_zero_terminal_snr,
    noise_like,
)
from .util import exists, default, instantiate_from_config
from .modules.distributions.distributions import DiagonalGaussianDistribution


class DiffusionWrapper(nn.Module):
    def __init__(self, diffusion_model):
        super().__init__()
        self.diffusion_model = diffusion_model

    def forward(self, *args, **kwargs):
        return self.diffusion_model(*args, **kwargs)


class LatentDiffusionInterface(nn.Module):
    """a simple interface class for LDM inference"""

    def __init__(
        self,
        unet_config,
        clip_config,
        vae_config,
        parameterization="eps",
        scale_factor=0.18215,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=0.00085,
        linear_end=0.0120,
        cosine_s=8e-3,
        given_betas=None,
        zero_snr=False,
        *args,
        **kwargs,
    ):
        super().__init__()

        unet = instantiate_from_config(unet_config)
        self.model = DiffusionWrapper(unet)
        self.clip_model = instantiate_from_config(clip_config)
        self.vae_model = instantiate_from_config(vae_config)

        self.parameterization = parameterization
        self.scale_factor = scale_factor
        self.register_schedule(
            given_betas=given_betas,
            beta_schedule=beta_schedule,
            timesteps=timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
            zero_snr=zero_snr,
        )

    def register_schedule(
        self,
        given_betas=None,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
        zero_snr=False,
    ):
        if exists(given_betas):
            betas = given_betas
        else:
            betas = make_beta_schedule(
                beta_schedule,
                timesteps,
                linear_start=linear_start,
                linear_end=linear_end,
                cosine_s=cosine_s,
            )
        if zero_snr:
            print("--- using zero snr---")
            betas = enforce_zero_terminal_snr(betas).numpy()
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert (
            alphas_cumprod.shape[0] == self.num_timesteps
        ), "alphas have to be defined for each timestep"

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer("betas", to_torch(betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
        )
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
        )

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.v_posterior = 0
        posterior_variance = (1 - self.v_posterior) * betas * (
            1.0 - alphas_cumprod_prev
        ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer("posterior_variance", to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer(
            "posterior_log_variance_clipped",
            to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
        )
        self.register_buffer(
            "posterior_mean_coef1",
            to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
        )
        self.register_buffer(
            "posterior_mean_coef2",
            to_torch(
                (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
            ),
        )

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    def get_v(self, x, noise, t):
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
        )

    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
            * noise
        )

    def predict_start_from_z_and_v(self, x_t, t, v):
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def predict_eps_from_z_and_v(self, x_t, t, v):
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v
            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
            * x_t
        )

    def apply_model(self, x_noisy, t, cond, **kwargs):
        assert isinstance(cond, dict), "cond has to be a dictionary"
        return self.model(x_noisy, t, **cond, **kwargs)

    def get_learned_conditioning(self, prompts: List[str]):
        return self.clip_model(prompts)

    def get_learned_image_conditioning(self, images):
        return self.clip_model.forward_image(images)

    def get_first_stage_encoding(self, encoder_posterior):
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(
                f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
            )
        return self.scale_factor * z

    def encode_first_stage(self, x):
        return self.vae_model.encode(x)

    def decode_first_stage(self, z):
        z = 1.0 / self.scale_factor * z
        return self.vae_model.decode(z)
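The schedule buffers registered above encode the standard DDPM identities: q_sample produces x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, and predict_start_from_noise inverts it using the sqrt_recip buffers. A standalone numpy check of that inversion, independent of the model classes and using an arbitrary beta schedule rather than the repo's make_beta_schedule:

# Standalone check of the identity behind q_sample / predict_start_from_noise.
# The linear betas here are arbitrary, not the repo's make_beta_schedule output.
import numpy as np

betas = np.linspace(0.00085, 0.0120, 1000)
alphas_cumprod = np.cumprod(1.0 - betas)

t = 500
x0 = np.random.randn(4)
eps = np.random.randn(4)

x_t = np.sqrt(alphas_cumprod[t]) * x0 + np.sqrt(1.0 - alphas_cumprod[t]) * eps
x0_rec = np.sqrt(1.0 / alphas_cumprod[t]) * x_t - np.sqrt(1.0 / alphas_cumprod[t] - 1.0) * eps
assert np.allclose(x0, x0_rec)  # recovering x0 from (x_t, eps) is exact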
imagedream/ldm/models/__init__.py
ADDED
File without changes
imagedream/ldm/models/autoencoder.py
ADDED
@@ -0,0 +1,270 @@
import torch
import torch.nn.functional as F
from contextlib import contextmanager

from ..modules.diffusionmodules.model import Encoder, Decoder
from ..modules.distributions.distributions import DiagonalGaussianDistribution

from ..util import instantiate_from_config
from ..modules.ema import LitEma


class AutoencoderKL(torch.nn.Module):
    def __init__(
        self,
        ddconfig,
        lossconfig,
        embed_dim,
        ckpt_path=None,
        ignore_keys=[],
        image_key="image",
        colorize_nlabels=None,
        monitor=None,
        ema_decay=None,
        learn_logvar=False,
    ):
        super().__init__()
        self.learn_logvar = learn_logvar
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

        self.use_ema = ema_decay is not None
        if self.use_ema:
            self.ema_decay = ema_decay
            assert 0.0 < ema_decay < 1.0
            self.model_ema = LitEma(self, decay=ema_decay)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(
                inputs,
                reconstructions,
                posterior,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split="train",
            )
            self.log(
                "aeloss",
                aeloss,
                prog_bar=True,
                logger=True,
                on_step=True,
                on_epoch=True,
            )
            self.log_dict(
                log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False
            )
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(
                inputs,
                reconstructions,
                posterior,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split="train",
            )

            self.log(
                "discloss",
                discloss,
                prog_bar=True,
                logger=True,
                on_step=True,
                on_epoch=True,
            )
            self.log_dict(
                log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
            )
            return discloss

    def validation_step(self, batch, batch_idx):
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, postfix=""):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(
            inputs,
            reconstructions,
            posterior,
            0,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val" + postfix,
        )

        discloss, log_dict_disc = self.loss(
            inputs,
            reconstructions,
            posterior,
            1,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val" + postfix,
        )

        self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        lr = self.learning_rate
        ae_params_list = (
            list(self.encoder.parameters())
            + list(self.decoder.parameters())
            + list(self.quant_conv.parameters())
            + list(self.post_quant_conv.parameters())
        )
        if self.learn_logvar:
            print(f"{self.__class__.__name__}: Learning logvar")
            ae_params_list.append(self.loss.logvar)
        opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(
            self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
        )
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
            if log_ema or self.use_ema:
                with self.ema_scope():
                    xrec_ema, posterior_ema = self(x)
                    if x.shape[1] > 3:
                        # colorize with random projection
                        assert xrec_ema.shape[1] > 3
                        xrec_ema = self.to_rgb(xrec_ema)
                    log["samples_ema"] = self.decode(
                        torch.randn_like(posterior_ema.sample())
                    )
                    log["reconstructions_ema"] = xrec_ema
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
        return x


class IdentityFirstStage(torch.nn.Module):
    def __init__(self, *args, vq_interface=False, **kwargs):
        self.vq_interface = vq_interface
        super().__init__()

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x
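AutoencoderKL.encode returns a DiagonalGaussianDistribution rather than a tensor, which is why LatentDiffusionInterface.get_first_stage_encoding samples from it and applies scale_factor before diffusion, and decode_first_stage divides it back out. A hedged round-trip sketch, assuming an already-constructed autoencoder named vae built with the ddconfig from the YAML configs above:

# Hedged sketch of the encode/decode round trip used by the LDM interface.
# `vae` is assumed to be an AutoencoderKL instance built elsewhere (e.g. from vae_config above).
import torch

scale_factor = 0.18215
x = torch.randn(1, 3, 256, 256)        # RGB image batch, values roughly in [-1, 1]

posterior = vae.encode(x)               # DiagonalGaussianDistribution
z = scale_factor * posterior.sample()   # what get_first_stage_encoding returns
x_rec = vae.decode(z / scale_factor)    # what decode_first_stage does
print(z.shape, x_rec.shape)             # latent is 4 x 32 x 32 for a 256 x 256 input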
imagedream/ldm/models/diffusion/__init__.py
ADDED
File without changes
imagedream/ldm/models/diffusion/ddim.py
ADDED
@@ -0,0 +1,430 @@
"""SAMPLING ONLY."""

import torch
import numpy as np
from tqdm import tqdm
from functools import partial

from ...modules.diffusionmodules.util import (
    make_ddim_sampling_parameters,
    make_ddim_timesteps,
    noise_like,
    extract_into_tensor,
)


class DDIMSampler(object):
    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        if type(attr) == torch.Tensor:
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    def make_schedule(
        self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
    ):
        self.ddim_timesteps = make_ddim_timesteps(
            ddim_discr_method=ddim_discretize,
            num_ddim_timesteps=ddim_num_steps,
            num_ddpm_timesteps=self.ddpm_num_timesteps,
            verbose=verbose,
        )
        alphas_cumprod = self.model.alphas_cumprod
        assert (
            alphas_cumprod.shape[0] == self.ddpm_num_timesteps
        ), "alphas have to be defined for each timestep"
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer("betas", to_torch(self.model.betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer(
            "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
        )

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer(
            "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod",
            to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod",
            to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
        )

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
            alphacums=alphas_cumprod.cpu(),
            ddim_timesteps=self.ddim_timesteps,
            eta=ddim_eta,
            verbose=verbose,
        )
        self.register_buffer("ddim_sigmas", ddim_sigmas)
        self.register_buffer("ddim_alphas", ddim_alphas)
        self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
        self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev)
            / (1 - self.alphas_cumprod)
            * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
        )
        self.register_buffer(
            "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
        )

    @torch.no_grad()
    def sample(
        self,
        S,
        batch_size,
        shape,
        conditioning=None,
        callback=None,
        normals_sequence=None,
        img_callback=None,
        quantize_x0=False,
        eta=0.0,
        mask=None,
        x0=None,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        verbose=True,
        x_T=None,
        log_every_t=100,
        unconditional_guidance_scale=1.0,
|
111 |
+
unconditional_conditioning=None,
|
112 |
+
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
113 |
+
**kwargs,
|
114 |
+
):
|
115 |
+
if conditioning is not None:
|
116 |
+
if isinstance(conditioning, dict):
|
117 |
+
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
118 |
+
if cbs != batch_size:
|
119 |
+
print(
|
120 |
+
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
|
121 |
+
)
|
122 |
+
else:
|
123 |
+
if conditioning.shape[0] != batch_size:
|
124 |
+
print(
|
125 |
+
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
|
126 |
+
)
|
127 |
+
|
128 |
+
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
129 |
+
# sampling
|
130 |
+
C, H, W = shape
|
131 |
+
size = (batch_size, C, H, W)
|
132 |
+
|
133 |
+
samples, intermediates = self.ddim_sampling(
|
134 |
+
conditioning,
|
135 |
+
size,
|
136 |
+
callback=callback,
|
137 |
+
img_callback=img_callback,
|
138 |
+
quantize_denoised=quantize_x0,
|
139 |
+
mask=mask,
|
140 |
+
x0=x0,
|
141 |
+
ddim_use_original_steps=False,
|
142 |
+
noise_dropout=noise_dropout,
|
143 |
+
temperature=temperature,
|
144 |
+
score_corrector=score_corrector,
|
145 |
+
corrector_kwargs=corrector_kwargs,
|
146 |
+
x_T=x_T,
|
147 |
+
log_every_t=log_every_t,
|
148 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
149 |
+
unconditional_conditioning=unconditional_conditioning,
|
150 |
+
**kwargs,
|
151 |
+
)
|
152 |
+
return samples, intermediates
|
153 |
+
|
154 |
+
@torch.no_grad()
|
155 |
+
def ddim_sampling(
|
156 |
+
self,
|
157 |
+
cond,
|
158 |
+
shape,
|
159 |
+
x_T=None,
|
160 |
+
ddim_use_original_steps=False,
|
161 |
+
callback=None,
|
162 |
+
timesteps=None,
|
163 |
+
quantize_denoised=False,
|
164 |
+
mask=None,
|
165 |
+
x0=None,
|
166 |
+
img_callback=None,
|
167 |
+
log_every_t=100,
|
168 |
+
temperature=1.0,
|
169 |
+
noise_dropout=0.0,
|
170 |
+
score_corrector=None,
|
171 |
+
corrector_kwargs=None,
|
172 |
+
unconditional_guidance_scale=1.0,
|
173 |
+
unconditional_conditioning=None,
|
174 |
+
**kwargs,
|
175 |
+
):
|
176 |
+
"""
|
177 |
+
Typical argument values at inference time:
|
178 |
+
cond.keys(): dict_keys(['context', 'camera', 'num_frames', 'ip', 'ip_img'])
|
179 |
+
shape: (5, 4, 32, 32)
|
180 |
+
x_T: None
|
181 |
+
ddim_use_original_steps: False
|
182 |
+
timesteps: None
|
183 |
+
callback: None
|
184 |
+
quantize_denoised: False
|
185 |
+
mask: None
|
186 |
+
image_callback: None
|
187 |
+
log_every_t: 100
|
188 |
+
temperature: 1.0
|
189 |
+
noise_dropout: 0.0
|
190 |
+
score_corrector: None
|
191 |
+
corrector_kwargs: None
|
192 |
+
unconditional_guidance_scale: 5
|
193 |
+
unconditional_conditioning.keys(): dict_keys(['context', 'camera', 'num_frames', 'ip', 'ip_img'])
|
194 |
+
kwargs: {}
|
195 |
+
"""
|
196 |
+
device = self.model.betas.device
|
197 |
+
b = shape[0]
|
198 |
+
if x_T is None:
|
199 |
+
img = torch.randn(shape, device=device) # shape: torch.Size([5, 4, 32, 32]) mean: -0.00, std: 1.00, min: -3.64, max: 3.94
|
200 |
+
else:
|
201 |
+
img = x_T
|
202 |
+
|
203 |
+
if timesteps is None:  # matches the set_timesteps behaviour in HF diffusers
|
204 |
+
timesteps = (
|
205 |
+
self.ddpm_num_timesteps
|
206 |
+
if ddim_use_original_steps
|
207 |
+
else self.ddim_timesteps
|
208 |
+
)
|
209 |
+
elif timesteps is not None and not ddim_use_original_steps:
|
210 |
+
subset_end = (
|
211 |
+
int(
|
212 |
+
min(timesteps / self.ddim_timesteps.shape[0], 1)
|
213 |
+
* self.ddim_timesteps.shape[0]
|
214 |
+
)
|
215 |
+
- 1
|
216 |
+
)
|
217 |
+
timesteps = self.ddim_timesteps[:subset_end]
|
218 |
+
|
219 |
+
intermediates = {"x_inter": [img], "pred_x0": [img]}
|
220 |
+
time_range = ( # reversed timesteps
|
221 |
+
reversed(range(0, timesteps))
|
222 |
+
if ddim_use_original_steps
|
223 |
+
else np.flip(timesteps)
|
224 |
+
)
|
225 |
+
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
226 |
+
iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
|
227 |
+
for i, step in enumerate(iterator):
|
228 |
+
index = total_steps - i - 1
|
229 |
+
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
230 |
+
|
231 |
+
if mask is not None:
|
232 |
+
assert x0 is not None
|
233 |
+
img_orig = self.model.q_sample(
|
234 |
+
x0, ts
|
235 |
+
) # TODO: deterministic forward pass?
|
236 |
+
img = img_orig * mask + (1.0 - mask) * img
|
237 |
+
|
238 |
+
outs = self.p_sample_ddim(
|
239 |
+
img,
|
240 |
+
cond,
|
241 |
+
ts,
|
242 |
+
index=index,
|
243 |
+
use_original_steps=ddim_use_original_steps,
|
244 |
+
quantize_denoised=quantize_denoised,
|
245 |
+
temperature=temperature,
|
246 |
+
noise_dropout=noise_dropout,
|
247 |
+
score_corrector=score_corrector,
|
248 |
+
corrector_kwargs=corrector_kwargs,
|
249 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
250 |
+
unconditional_conditioning=unconditional_conditioning,
|
251 |
+
**kwargs,
|
252 |
+
)
|
253 |
+
img, pred_x0 = outs
|
254 |
+
if callback:
|
255 |
+
callback(i)
|
256 |
+
if img_callback:
|
257 |
+
img_callback(pred_x0, i)
|
258 |
+
|
259 |
+
if index % log_every_t == 0 or index == total_steps - 1:
|
260 |
+
intermediates["x_inter"].append(img)
|
261 |
+
intermediates["pred_x0"].append(pred_x0)
|
262 |
+
|
263 |
+
return img, intermediates
|
264 |
+
|
265 |
+
@torch.no_grad()
|
266 |
+
def p_sample_ddim(
|
267 |
+
self,
|
268 |
+
x,
|
269 |
+
c,
|
270 |
+
t,
|
271 |
+
index,
|
272 |
+
repeat_noise=False,
|
273 |
+
use_original_steps=False,
|
274 |
+
quantize_denoised=False,
|
275 |
+
temperature=1.0,
|
276 |
+
noise_dropout=0.0,
|
277 |
+
score_corrector=None,
|
278 |
+
corrector_kwargs=None,
|
279 |
+
unconditional_guidance_scale=1.0,
|
280 |
+
unconditional_conditioning=None,
|
281 |
+
dynamic_threshold=None,
|
282 |
+
**kwargs,
|
283 |
+
):
|
284 |
+
b, *_, device = *x.shape, x.device
|
285 |
+
|
286 |
+
if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
|
287 |
+
model_output = self.model.apply_model(x, t, c)
|
288 |
+
else:
|
289 |
+
x_in = torch.cat([x] * 2)
|
290 |
+
t_in = torch.cat([t] * 2)
|
291 |
+
if isinstance(c, dict):
|
292 |
+
assert isinstance(unconditional_conditioning, dict)
|
293 |
+
c_in = dict()
|
294 |
+
for k in c:
|
295 |
+
if isinstance(c[k], list):
|
296 |
+
c_in[k] = [
|
297 |
+
torch.cat([unconditional_conditioning[k][i], c[k][i]])
|
298 |
+
for i in range(len(c[k]))
|
299 |
+
]
|
300 |
+
elif isinstance(c[k], torch.Tensor):
|
301 |
+
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
|
302 |
+
else:
|
303 |
+
assert c[k] == unconditional_conditioning[k]
|
304 |
+
c_in[k] = c[k]
|
305 |
+
elif isinstance(c, list):
|
306 |
+
c_in = list()
|
307 |
+
assert isinstance(unconditional_conditioning, list)
|
308 |
+
for i in range(len(c)):
|
309 |
+
c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
|
310 |
+
else:
|
311 |
+
c_in = torch.cat([unconditional_conditioning, c])
|
312 |
+
model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
313 |
+
model_output = model_uncond + unconditional_guidance_scale * (
|
314 |
+
model_t - model_uncond
|
315 |
+
)
|
316 |
+
|
317 |
+
|
318 |
+
if self.model.parameterization == "v":
|
319 |
+
print("using v!")
|
320 |
+
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
|
321 |
+
else:
|
322 |
+
e_t = model_output
|
323 |
+
|
324 |
+
if score_corrector is not None:
|
325 |
+
assert self.model.parameterization == "eps", "not implemented"
|
326 |
+
e_t = score_corrector.modify_score(
|
327 |
+
self.model, e_t, x, t, c, **corrector_kwargs
|
328 |
+
)
|
329 |
+
|
330 |
+
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
331 |
+
alphas_prev = (
|
332 |
+
self.model.alphas_cumprod_prev
|
333 |
+
if use_original_steps
|
334 |
+
else self.ddim_alphas_prev
|
335 |
+
)
|
336 |
+
sqrt_one_minus_alphas = (
|
337 |
+
self.model.sqrt_one_minus_alphas_cumprod
|
338 |
+
if use_original_steps
|
339 |
+
else self.ddim_sqrt_one_minus_alphas
|
340 |
+
)
|
341 |
+
sigmas = (
|
342 |
+
self.model.ddim_sigmas_for_original_num_steps
|
343 |
+
if use_original_steps
|
344 |
+
else self.ddim_sigmas
|
345 |
+
)
|
346 |
+
# select parameters corresponding to the currently considered timestep
|
347 |
+
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
348 |
+
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
349 |
+
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
350 |
+
sqrt_one_minus_at = torch.full(
|
351 |
+
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
|
352 |
+
)
|
353 |
+
|
354 |
+
# current prediction for x_0
|
355 |
+
if self.model.parameterization != "v":
|
356 |
+
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
357 |
+
else:
|
358 |
+
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
|
359 |
+
|
360 |
+
if quantize_denoised:
|
361 |
+
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
362 |
+
|
363 |
+
if dynamic_threshold is not None:
|
364 |
+
raise NotImplementedError()
|
365 |
+
|
366 |
+
# direction pointing to x_t
|
367 |
+
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
|
368 |
+
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
369 |
+
if noise_dropout > 0.0:
|
370 |
+
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
371 |
+
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
372 |
+
return x_prev, pred_x0
|
373 |
+
|
374 |
+
@torch.no_grad()
|
375 |
+
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
376 |
+
# fast, but does not allow for exact reconstruction
|
377 |
+
# t serves as an index to gather the correct alphas
|
378 |
+
if use_original_steps:
|
379 |
+
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
380 |
+
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
381 |
+
else:
|
382 |
+
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
383 |
+
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
384 |
+
|
385 |
+
if noise is None:
|
386 |
+
noise = torch.randn_like(x0)
|
387 |
+
return (
|
388 |
+
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
|
389 |
+
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
|
390 |
+
)
|
391 |
+
|
392 |
+
@torch.no_grad()
|
393 |
+
def decode(
|
394 |
+
self,
|
395 |
+
x_latent,
|
396 |
+
cond,
|
397 |
+
t_start,
|
398 |
+
unconditional_guidance_scale=1.0,
|
399 |
+
unconditional_conditioning=None,
|
400 |
+
use_original_steps=False,
|
401 |
+
**kwargs,
|
402 |
+
):
|
403 |
+
timesteps = (
|
404 |
+
np.arange(self.ddpm_num_timesteps)
|
405 |
+
if use_original_steps
|
406 |
+
else self.ddim_timesteps
|
407 |
+
)
|
408 |
+
timesteps = timesteps[:t_start]
|
409 |
+
|
410 |
+
time_range = np.flip(timesteps)
|
411 |
+
total_steps = timesteps.shape[0]
|
412 |
+
|
413 |
+
iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
|
414 |
+
x_dec = x_latent
|
415 |
+
for i, step in enumerate(iterator):
|
416 |
+
index = total_steps - i - 1
|
417 |
+
ts = torch.full(
|
418 |
+
(x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
|
419 |
+
)
|
420 |
+
x_dec, _ = self.p_sample_ddim(
|
421 |
+
x_dec,
|
422 |
+
cond,
|
423 |
+
ts,
|
424 |
+
index=index,
|
425 |
+
use_original_steps=use_original_steps,
|
426 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
427 |
+
unconditional_conditioning=unconditional_conditioning,
|
428 |
+
**kwargs,
|
429 |
+
)
|
430 |
+
return x_dec
|
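For orientation (editorial addition, not part of the commit): a hedged sketch of how `DDIMSampler.sample` is typically driven with classifier-free guidance. `model`, `cond`, and `uncond` below are assumed placeholders; in this codebase the conditioning dicts carry `context`, `camera`, `num_frames`, `ip`, and `ip_img` entries, as the `ddim_sampling` docstring above notes.

```python
# Sketch under assumptions: `model` is a trained LatentDiffusion-style wrapper,
# `cond` / `uncond` are conditioning dicts shaped as described in the docstring above.
sampler = DDIMSampler(model)

samples, intermediates = sampler.sample(
    S=50,                                # number of DDIM steps
    batch_size=5,
    shape=(4, 32, 32),                   # latent (C, H, W), matching the docstring example
    conditioning=cond,
    unconditional_guidance_scale=5.0,
    unconditional_conditioning=uncond,
    eta=0.0,                             # eta=0 gives deterministic DDIM updates
)
# `samples` are latents; decode them with the first-stage model afterwards.
```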
imagedream/ldm/modules/__init__.py
ADDED
File without changes
|
imagedream/ldm/modules/attention.py
ADDED
@@ -0,0 +1,456 @@
1 |
+
from inspect import isfunction
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch import nn, einsum
|
6 |
+
from einops import rearrange, repeat
|
7 |
+
from typing import Optional, Any
|
8 |
+
|
9 |
+
from .diffusionmodules.util import checkpoint
|
10 |
+
|
11 |
+
|
12 |
+
try:
|
13 |
+
import xformers
|
14 |
+
import xformers.ops
|
15 |
+
|
16 |
+
XFORMERS_IS_AVAILBLE = True
|
17 |
+
except:
|
18 |
+
XFORMERS_IS_AVAILBLE = False
|
19 |
+
|
20 |
+
# CrossAttn precision handling
|
21 |
+
import os
|
22 |
+
|
23 |
+
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
|
24 |
+
|
25 |
+
|
26 |
+
def exists(val):
|
27 |
+
return val is not None
|
28 |
+
|
29 |
+
|
30 |
+
def uniq(arr):
|
31 |
+
return {el: True for el in arr}.keys()
|
32 |
+
|
33 |
+
|
34 |
+
def default(val, d):
|
35 |
+
if exists(val):
|
36 |
+
return val
|
37 |
+
return d() if isfunction(d) else d
|
38 |
+
|
39 |
+
|
40 |
+
def max_neg_value(t):
|
41 |
+
return -torch.finfo(t.dtype).max
|
42 |
+
|
43 |
+
|
44 |
+
def init_(tensor):
|
45 |
+
dim = tensor.shape[-1]
|
46 |
+
std = 1 / math.sqrt(dim)
|
47 |
+
tensor.uniform_(-std, std)
|
48 |
+
return tensor
|
49 |
+
|
50 |
+
|
51 |
+
# feedforward
|
52 |
+
class GEGLU(nn.Module):
|
53 |
+
def __init__(self, dim_in, dim_out):
|
54 |
+
super().__init__()
|
55 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
56 |
+
|
57 |
+
def forward(self, x):
|
58 |
+
x, gate = self.proj(x).chunk(2, dim=-1)
|
59 |
+
return x * F.gelu(gate)
|
60 |
+
|
61 |
+
|
62 |
+
class FeedForward(nn.Module):
|
63 |
+
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
|
64 |
+
super().__init__()
|
65 |
+
inner_dim = int(dim * mult)
|
66 |
+
dim_out = default(dim_out, dim)
|
67 |
+
project_in = (
|
68 |
+
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
|
69 |
+
if not glu
|
70 |
+
else GEGLU(dim, inner_dim)
|
71 |
+
)
|
72 |
+
|
73 |
+
self.net = nn.Sequential(
|
74 |
+
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
|
75 |
+
)
|
76 |
+
|
77 |
+
def forward(self, x):
|
78 |
+
return self.net(x)
|
79 |
+
|
80 |
+
|
81 |
+
def zero_module(module):
|
82 |
+
"""
|
83 |
+
Zero out the parameters of a module and return it.
|
84 |
+
"""
|
85 |
+
for p in module.parameters():
|
86 |
+
p.detach().zero_()
|
87 |
+
return module
|
88 |
+
|
89 |
+
|
90 |
+
def Normalize(in_channels):
|
91 |
+
return torch.nn.GroupNorm(
|
92 |
+
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
93 |
+
)
|
94 |
+
|
95 |
+
|
96 |
+
class SpatialSelfAttention(nn.Module):
|
97 |
+
def __init__(self, in_channels):
|
98 |
+
super().__init__()
|
99 |
+
self.in_channels = in_channels
|
100 |
+
|
101 |
+
self.norm = Normalize(in_channels)
|
102 |
+
self.q = torch.nn.Conv2d(
|
103 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
104 |
+
)
|
105 |
+
self.k = torch.nn.Conv2d(
|
106 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
107 |
+
)
|
108 |
+
self.v = torch.nn.Conv2d(
|
109 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
110 |
+
)
|
111 |
+
self.proj_out = torch.nn.Conv2d(
|
112 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
113 |
+
)
|
114 |
+
|
115 |
+
def forward(self, x):
|
116 |
+
h_ = x
|
117 |
+
h_ = self.norm(h_)
|
118 |
+
q = self.q(h_)
|
119 |
+
k = self.k(h_)
|
120 |
+
v = self.v(h_)
|
121 |
+
|
122 |
+
# compute attention
|
123 |
+
b, c, h, w = q.shape
|
124 |
+
q = rearrange(q, "b c h w -> b (h w) c")
|
125 |
+
k = rearrange(k, "b c h w -> b c (h w)")
|
126 |
+
w_ = torch.einsum("bij,bjk->bik", q, k)
|
127 |
+
|
128 |
+
w_ = w_ * (int(c) ** (-0.5))
|
129 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
130 |
+
|
131 |
+
# attend to values
|
132 |
+
v = rearrange(v, "b c h w -> b c (h w)")
|
133 |
+
w_ = rearrange(w_, "b i j -> b j i")
|
134 |
+
h_ = torch.einsum("bij,bjk->bik", v, w_)
|
135 |
+
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
|
136 |
+
h_ = self.proj_out(h_)
|
137 |
+
|
138 |
+
return x + h_
|
139 |
+
|
140 |
+
|
141 |
+
class MemoryEfficientCrossAttention(nn.Module):
|
142 |
+
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
143 |
+
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs):
|
144 |
+
super().__init__()
|
145 |
+
print(
|
146 |
+
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
|
147 |
+
f"{heads} heads."
|
148 |
+
)
|
149 |
+
inner_dim = dim_head * heads
|
150 |
+
context_dim = default(context_dim, query_dim)
|
151 |
+
|
152 |
+
self.heads = heads
|
153 |
+
self.dim_head = dim_head
|
154 |
+
|
155 |
+
self.with_ip = kwargs.get("with_ip", False)
|
156 |
+
if self.with_ip and (context_dim is not None):
|
157 |
+
self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
|
158 |
+
self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
|
159 |
+
self.ip_dim= kwargs.get("ip_dim", 16)
|
160 |
+
self.ip_weight = kwargs.get("ip_weight", 1.0)
|
161 |
+
|
162 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
163 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
164 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
165 |
+
|
166 |
+
self.to_out = nn.Sequential(
|
167 |
+
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
168 |
+
)
|
169 |
+
self.attention_op: Optional[Any] = None
|
170 |
+
|
171 |
+
def forward(self, x, context=None, mask=None):
|
172 |
+
q = self.to_q(x)
|
173 |
+
|
174 |
+
has_ip = self.with_ip and (context is not None)
|
175 |
+
if has_ip:
|
176 |
+
# context dim [(b frame_num), (77 + img_token), 1024]
|
177 |
+
token_len = context.shape[1]
|
178 |
+
context_ip = context[:, -self.ip_dim:, :]
|
179 |
+
k_ip = self.to_k_ip(context_ip)
|
180 |
+
v_ip = self.to_v_ip(context_ip)
|
181 |
+
context = context[:, :(token_len - self.ip_dim), :]
|
182 |
+
|
183 |
+
context = default(context, x)
|
184 |
+
k = self.to_k(context)
|
185 |
+
v = self.to_v(context)
|
186 |
+
|
187 |
+
b, _, _ = q.shape
|
188 |
+
q, k, v = map(
|
189 |
+
lambda t: t.unsqueeze(3)
|
190 |
+
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
191 |
+
.permute(0, 2, 1, 3)
|
192 |
+
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
193 |
+
.contiguous(),
|
194 |
+
(q, k, v),
|
195 |
+
)
|
196 |
+
|
197 |
+
# actually compute the attention, what we cannot get enough of
|
198 |
+
out = xformers.ops.memory_efficient_attention(
|
199 |
+
q, k, v, attn_bias=None, op=self.attention_op
|
200 |
+
)
|
201 |
+
|
202 |
+
if has_ip:
|
203 |
+
k_ip, v_ip = map(
|
204 |
+
lambda t: t.unsqueeze(3)
|
205 |
+
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
206 |
+
.permute(0, 2, 1, 3)
|
207 |
+
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
208 |
+
.contiguous(),
|
209 |
+
(k_ip, v_ip),
|
210 |
+
)
|
211 |
+
# actually compute the attention, what we cannot get enough of
|
212 |
+
out_ip = xformers.ops.memory_efficient_attention(
|
213 |
+
q, k_ip, v_ip, attn_bias=None, op=self.attention_op
|
214 |
+
)
|
215 |
+
out = out + self.ip_weight * out_ip
|
216 |
+
|
217 |
+
if exists(mask):
|
218 |
+
raise NotImplementedError
|
219 |
+
out = (
|
220 |
+
out.unsqueeze(0)
|
221 |
+
.reshape(b, self.heads, out.shape[1], self.dim_head)
|
222 |
+
.permute(0, 2, 1, 3)
|
223 |
+
.reshape(b, out.shape[1], self.heads * self.dim_head)
|
224 |
+
)
|
225 |
+
return self.to_out(out)
|
226 |
+
|
227 |
+
|
228 |
+
class BasicTransformerBlock(nn.Module):
|
229 |
+
def __init__(
|
230 |
+
self,
|
231 |
+
dim,
|
232 |
+
n_heads,
|
233 |
+
d_head,
|
234 |
+
dropout=0.0,
|
235 |
+
context_dim=None,
|
236 |
+
gated_ff=True,
|
237 |
+
checkpoint=True,
|
238 |
+
disable_self_attn=False,
|
239 |
+
**kwargs
|
240 |
+
):
|
241 |
+
super().__init__()
|
242 |
+
assert XFORMERS_IS_AVAILBLE, "xformers is not available"
|
243 |
+
attn_cls = MemoryEfficientCrossAttention
|
244 |
+
self.disable_self_attn = disable_self_attn
|
245 |
+
self.attn1 = attn_cls(
|
246 |
+
query_dim=dim,
|
247 |
+
heads=n_heads,
|
248 |
+
dim_head=d_head,
|
249 |
+
dropout=dropout,
|
250 |
+
context_dim=context_dim if self.disable_self_attn else None,
|
251 |
+
) # is a self-attention if not self.disable_self_attn
|
252 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
253 |
+
self.attn2 = attn_cls(
|
254 |
+
query_dim=dim,
|
255 |
+
context_dim=context_dim,
|
256 |
+
heads=n_heads,
|
257 |
+
dim_head=d_head,
|
258 |
+
dropout=dropout,
|
259 |
+
**kwargs
|
260 |
+
) # is self-attn if context is none
|
261 |
+
self.norm1 = nn.LayerNorm(dim)
|
262 |
+
self.norm2 = nn.LayerNorm(dim)
|
263 |
+
self.norm3 = nn.LayerNorm(dim)
|
264 |
+
self.checkpoint = checkpoint
|
265 |
+
|
266 |
+
def forward(self, x, context=None):
|
267 |
+
return checkpoint(
|
268 |
+
self._forward, (x, context), self.parameters(), self.checkpoint
|
269 |
+
)
|
270 |
+
|
271 |
+
def _forward(self, x, context=None):
|
272 |
+
x = (
|
273 |
+
self.attn1(
|
274 |
+
self.norm1(x), context=context if self.disable_self_attn else None
|
275 |
+
)
|
276 |
+
+ x
|
277 |
+
)
|
278 |
+
x = self.attn2(self.norm2(x), context=context) + x
|
279 |
+
x = self.ff(self.norm3(x)) + x
|
280 |
+
return x
|
281 |
+
|
282 |
+
|
283 |
+
class SpatialTransformer(nn.Module):
|
284 |
+
"""
|
285 |
+
Transformer block for image-like data.
|
286 |
+
First, project the input (aka embedding)
|
287 |
+
and reshape to b, t, d.
|
288 |
+
Then apply standard transformer action.
|
289 |
+
Finally, reshape to image
|
290 |
+
NEW: use_linear for more efficiency instead of the 1x1 convs
|
291 |
+
"""
|
292 |
+
|
293 |
+
def __init__(
|
294 |
+
self,
|
295 |
+
in_channels,
|
296 |
+
n_heads,
|
297 |
+
d_head,
|
298 |
+
depth=1,
|
299 |
+
dropout=0.0,
|
300 |
+
context_dim=None,
|
301 |
+
disable_self_attn=False,
|
302 |
+
use_linear=False,
|
303 |
+
use_checkpoint=True,
|
304 |
+
**kwargs
|
305 |
+
):
|
306 |
+
super().__init__()
|
307 |
+
if exists(context_dim) and not isinstance(context_dim, list):
|
308 |
+
context_dim = [context_dim]
|
309 |
+
self.in_channels = in_channels
|
310 |
+
inner_dim = n_heads * d_head
|
311 |
+
self.norm = Normalize(in_channels)
|
312 |
+
if not use_linear:
|
313 |
+
self.proj_in = nn.Conv2d(
|
314 |
+
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
|
315 |
+
)
|
316 |
+
else:
|
317 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
318 |
+
|
319 |
+
self.transformer_blocks = nn.ModuleList(
|
320 |
+
[
|
321 |
+
BasicTransformerBlock(
|
322 |
+
inner_dim,
|
323 |
+
n_heads,
|
324 |
+
d_head,
|
325 |
+
dropout=dropout,
|
326 |
+
context_dim=context_dim[d],
|
327 |
+
disable_self_attn=disable_self_attn,
|
328 |
+
checkpoint=use_checkpoint,
|
329 |
+
**kwargs
|
330 |
+
)
|
331 |
+
for d in range(depth)
|
332 |
+
]
|
333 |
+
)
|
334 |
+
if not use_linear:
|
335 |
+
self.proj_out = zero_module(
|
336 |
+
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
337 |
+
)
|
338 |
+
else:
|
339 |
+
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
340 |
+
self.use_linear = use_linear
|
341 |
+
|
342 |
+
def forward(self, x, context=None):
|
343 |
+
# note: if no context is given, cross-attention defaults to self-attention
|
344 |
+
if not isinstance(context, list):
|
345 |
+
context = [context]
|
346 |
+
b, c, h, w = x.shape
|
347 |
+
x_in = x
|
348 |
+
x = self.norm(x)
|
349 |
+
if not self.use_linear:
|
350 |
+
x = self.proj_in(x)
|
351 |
+
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
|
352 |
+
if self.use_linear:
|
353 |
+
x = self.proj_in(x)
|
354 |
+
for i, block in enumerate(self.transformer_blocks):
|
355 |
+
x = block(x, context=context[i])
|
356 |
+
if self.use_linear:
|
357 |
+
x = self.proj_out(x)
|
358 |
+
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
|
359 |
+
if not self.use_linear:
|
360 |
+
x = self.proj_out(x)
|
361 |
+
return x + x_in
|
362 |
+
|
363 |
+
|
364 |
+
class BasicTransformerBlock3D(BasicTransformerBlock):
|
365 |
+
def forward(self, x, context=None, num_frames=1):
|
366 |
+
return checkpoint(
|
367 |
+
self._forward, (x, context, num_frames), self.parameters(), self.checkpoint
|
368 |
+
)
|
369 |
+
|
370 |
+
def _forward(self, x, context=None, num_frames=1):
|
371 |
+
x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
|
372 |
+
x = (
|
373 |
+
self.attn1(
|
374 |
+
self.norm1(x),
|
375 |
+
context=context if self.disable_self_attn else None
|
376 |
+
)
|
377 |
+
+ x
|
378 |
+
)
|
379 |
+
x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
|
380 |
+
x = self.attn2(self.norm2(x), context=context) + x
|
381 |
+
x = self.ff(self.norm3(x)) + x
|
382 |
+
return x
|
383 |
+
|
384 |
+
|
385 |
+
class SpatialTransformer3D(nn.Module):
|
386 |
+
"""3D self-attention"""
|
387 |
+
|
388 |
+
def __init__(
|
389 |
+
self,
|
390 |
+
in_channels,
|
391 |
+
n_heads,
|
392 |
+
d_head,
|
393 |
+
depth=1,
|
394 |
+
dropout=0.0,
|
395 |
+
context_dim=None,
|
396 |
+
disable_self_attn=False,
|
397 |
+
use_linear=False,
|
398 |
+
use_checkpoint=True,
|
399 |
+
**kwargs
|
400 |
+
):
|
401 |
+
super().__init__()
|
402 |
+
if exists(context_dim) and not isinstance(context_dim, list):
|
403 |
+
context_dim = [context_dim]
|
404 |
+
self.in_channels = in_channels
|
405 |
+
inner_dim = n_heads * d_head
|
406 |
+
self.norm = Normalize(in_channels)
|
407 |
+
if not use_linear:
|
408 |
+
self.proj_in = nn.Conv2d(
|
409 |
+
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
|
410 |
+
)
|
411 |
+
else:
|
412 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
413 |
+
|
414 |
+
self.transformer_blocks = nn.ModuleList(
|
415 |
+
[
|
416 |
+
BasicTransformerBlock3D(
|
417 |
+
inner_dim,
|
418 |
+
n_heads,
|
419 |
+
d_head,
|
420 |
+
dropout=dropout,
|
421 |
+
context_dim=context_dim[d],
|
422 |
+
disable_self_attn=disable_self_attn,
|
423 |
+
checkpoint=use_checkpoint,
|
424 |
+
**kwargs
|
425 |
+
)
|
426 |
+
for d in range(depth)
|
427 |
+
]
|
428 |
+
)
|
429 |
+
if not use_linear:
|
430 |
+
self.proj_out = zero_module(
|
431 |
+
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
432 |
+
)
|
433 |
+
else:
|
434 |
+
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
435 |
+
self.use_linear = use_linear
|
436 |
+
|
437 |
+
def forward(self, x, context=None, num_frames=1):
|
438 |
+
# note: if no context is given, cross-attention defaults to self-attention
|
439 |
+
if not isinstance(context, list):
|
440 |
+
context = [context]
|
441 |
+
b, c, h, w = x.shape
|
442 |
+
x_in = x
|
443 |
+
x = self.norm(x)
|
444 |
+
if not self.use_linear:
|
445 |
+
x = self.proj_in(x)
|
446 |
+
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
|
447 |
+
if self.use_linear:
|
448 |
+
x = self.proj_in(x)
|
449 |
+
for i, block in enumerate(self.transformer_blocks):
|
450 |
+
x = block(x, context=context[i], num_frames=num_frames)
|
451 |
+
if self.use_linear:
|
452 |
+
x = self.proj_out(x)
|
453 |
+
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
|
454 |
+
if not self.use_linear:
|
455 |
+
x = self.proj_out(x)
|
456 |
+
return x + x_in
|
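A small sketch (editorial, with assumed sizes) of the image-prompt branch in `MemoryEfficientCrossAttention`: when `with_ip=True`, the last `ip_dim` context tokens are routed through `to_k_ip`/`to_v_ip` and blended into the output with `ip_weight`, while the remaining tokens feed the ordinary cross-attention keys and values. xformers and a CUDA device are required in practice.

```python
# Sketch only; all dimensions are assumptions, not values fixed by this file.
import torch

attn = MemoryEfficientCrossAttention(
    query_dim=320, context_dim=1024, heads=8, dim_head=40,
    with_ip=True, ip_dim=16, ip_weight=1.0,
).cuda()

x = torch.randn(2, 4096, 320, device="cuda")            # (batch*frames, h*w, query_dim)
context = torch.randn(2, 77 + 16, 1024, device="cuda")  # 77 text tokens + 16 image-prompt tokens
out = attn(x, context=context)                          # same shape as x; IP tokens attended separately
```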
imagedream/ldm/modules/diffusionmodules/__init__.py
ADDED
File without changes
|
imagedream/ldm/modules/diffusionmodules/adaptors.py
ADDED
@@ -0,0 +1,163 @@
1 |
+
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
|
8 |
+
# FFN
|
9 |
+
def FeedForward(dim, mult=4):
|
10 |
+
inner_dim = int(dim * mult)
|
11 |
+
return nn.Sequential(
|
12 |
+
nn.LayerNorm(dim),
|
13 |
+
nn.Linear(dim, inner_dim, bias=False),
|
14 |
+
nn.GELU(),
|
15 |
+
nn.Linear(inner_dim, dim, bias=False),
|
16 |
+
)
|
17 |
+
|
18 |
+
|
19 |
+
def reshape_tensor(x, heads):
|
20 |
+
bs, length, width = x.shape
|
21 |
+
#(bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
22 |
+
x = x.view(bs, length, heads, -1)
|
23 |
+
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
24 |
+
x = x.transpose(1, 2)
|
25 |
+
# (bs, n_heads, length, dim_per_head) -- kept 4-D (heads stay in their own dim) for the attention below
|
26 |
+
x = x.reshape(bs, heads, length, -1)
|
27 |
+
return x
|
28 |
+
|
29 |
+
|
30 |
+
class PerceiverAttention(nn.Module):
|
31 |
+
def __init__(self, *, dim, dim_head=64, heads=8):
|
32 |
+
super().__init__()
|
33 |
+
self.scale = dim_head**-0.5
|
34 |
+
self.dim_head = dim_head
|
35 |
+
self.heads = heads
|
36 |
+
inner_dim = dim_head * heads
|
37 |
+
|
38 |
+
self.norm1 = nn.LayerNorm(dim)
|
39 |
+
self.norm2 = nn.LayerNorm(dim)
|
40 |
+
|
41 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
42 |
+
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
43 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
44 |
+
|
45 |
+
|
46 |
+
def forward(self, x, latents):
|
47 |
+
"""
|
48 |
+
Args:
|
49 |
+
x (torch.Tensor): image features
|
50 |
+
shape (b, n1, D)
|
51 |
+
latent (torch.Tensor): latent features
|
52 |
+
shape (b, n2, D)
|
53 |
+
"""
|
54 |
+
x = self.norm1(x)
|
55 |
+
latents = self.norm2(latents)
|
56 |
+
|
57 |
+
b, l, _ = latents.shape
|
58 |
+
|
59 |
+
q = self.to_q(latents)
|
60 |
+
kv_input = torch.cat((x, latents), dim=-2)
|
61 |
+
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
62 |
+
|
63 |
+
q = reshape_tensor(q, self.heads)
|
64 |
+
k = reshape_tensor(k, self.heads)
|
65 |
+
v = reshape_tensor(v, self.heads)
|
66 |
+
|
67 |
+
# attention
|
68 |
+
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
69 |
+
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
70 |
+
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
71 |
+
out = weight @ v
|
72 |
+
|
73 |
+
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
74 |
+
|
75 |
+
return self.to_out(out)
|
76 |
+
|
77 |
+
|
78 |
+
class ImageProjModel(torch.nn.Module):
|
79 |
+
"""Projection Model"""
|
80 |
+
def __init__(self,
|
81 |
+
cross_attention_dim=1024,
|
82 |
+
clip_embeddings_dim=1024,
|
83 |
+
clip_extra_context_tokens=4):
|
84 |
+
super().__init__()
|
85 |
+
self.cross_attention_dim = cross_attention_dim
|
86 |
+
self.clip_extra_context_tokens = clip_extra_context_tokens
|
87 |
+
|
88 |
+
# from 1024 -> 4 * 1024
|
89 |
+
self.proj = torch.nn.Linear(
|
90 |
+
clip_embeddings_dim,
|
91 |
+
self.clip_extra_context_tokens * cross_attention_dim)
|
92 |
+
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
93 |
+
|
94 |
+
def forward(self, image_embeds):
|
95 |
+
embeds = image_embeds
|
96 |
+
clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
|
97 |
+
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
|
98 |
+
return clip_extra_context_tokens
|
99 |
+
|
100 |
+
|
101 |
+
class SimpleReSampler(nn.Module):
|
102 |
+
def __init__(self, embedding_dim=1280, output_dim=1024):
|
103 |
+
super().__init__()
|
104 |
+
self.proj_out = nn.Linear(embedding_dim, output_dim)
|
105 |
+
self.norm_out = nn.LayerNorm(output_dim)
|
106 |
+
|
107 |
+
def forward(self, latents):
|
108 |
+
"""
|
109 |
+
latents: B 256 N
|
110 |
+
"""
|
111 |
+
latents = self.proj_out(latents)
|
112 |
+
return self.norm_out(latents)
|
113 |
+
|
114 |
+
|
115 |
+
class Resampler(nn.Module):
|
116 |
+
def __init__(
|
117 |
+
self,
|
118 |
+
dim=1024,
|
119 |
+
depth=8,
|
120 |
+
dim_head=64,
|
121 |
+
heads=16,
|
122 |
+
num_queries=8,
|
123 |
+
embedding_dim=768,
|
124 |
+
output_dim=1024,
|
125 |
+
ff_mult=4,
|
126 |
+
):
|
127 |
+
super().__init__()
|
128 |
+
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
|
129 |
+
self.proj_in = nn.Linear(embedding_dim, dim)
|
130 |
+
self.proj_out = nn.Linear(dim, output_dim)
|
131 |
+
self.norm_out = nn.LayerNorm(output_dim)
|
132 |
+
|
133 |
+
self.layers = nn.ModuleList([])
|
134 |
+
for _ in range(depth):
|
135 |
+
self.layers.append(
|
136 |
+
nn.ModuleList(
|
137 |
+
[
|
138 |
+
PerceiverAttention(dim=dim,
|
139 |
+
dim_head=dim_head,
|
140 |
+
heads=heads),
|
141 |
+
FeedForward(dim=dim, mult=ff_mult),
|
142 |
+
]
|
143 |
+
)
|
144 |
+
)
|
145 |
+
|
146 |
+
def forward(self, x):
|
147 |
+
latents = self.latents.repeat(x.size(0), 1, 1)
|
148 |
+
x = self.proj_in(x)
|
149 |
+
for attn, ff in self.layers:
|
150 |
+
latents = attn(x, latents) + latents
|
151 |
+
latents = ff(latents) + latents
|
152 |
+
|
153 |
+
latents = self.proj_out(latents)
|
154 |
+
return self.norm_out(latents)
|
155 |
+
|
156 |
+
|
157 |
+
if __name__ == '__main__':
|
158 |
+
resampler = Resampler(embedding_dim=1280)
|
159 |
+
resampler = SimpleReSampler(embedding_dim=1280)
|
160 |
+
tensor = torch.rand(4, 257, 1280)
|
161 |
+
embed = resampler(tensor)
|
162 |
+
# embed = (tensor)
|
163 |
+
print(embed.shape)
|
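Quick orientation (editorial note): `Resampler` condenses a variable-width image-feature sequence into a fixed number of query tokens via perceiver-style cross-attention, while `SimpleReSampler` is just a linear projection plus LayerNorm. A shape sketch, assuming CLIP ViT-H/14-style features (257 tokens of width 1280), mirroring the `__main__` demo above:

```python
# Shape sketch with assumed inputs.
import torch

resampler = Resampler(embedding_dim=1280, num_queries=16, output_dim=1024)
feats = torch.rand(4, 257, 1280)   # batch of image-encoder hidden states
tokens = resampler(feats)          # -> (4, 16, 1024): fixed-length image-prompt tokens
```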
imagedream/ldm/modules/diffusionmodules/model.py
ADDED
@@ -0,0 +1,1018 @@
1 |
+
# pytorch_diffusion + derived encoder decoder
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import numpy as np
|
6 |
+
from einops import rearrange
|
7 |
+
from typing import Optional, Any
|
8 |
+
|
9 |
+
from ..attention import MemoryEfficientCrossAttention
|
10 |
+
|
11 |
+
try:
|
12 |
+
import xformers
|
13 |
+
import xformers.ops
|
14 |
+
|
15 |
+
XFORMERS_IS_AVAILBLE = True
|
16 |
+
except:
|
17 |
+
XFORMERS_IS_AVAILBLE = False
|
18 |
+
print("No module 'xformers'. Proceeding without it.")
|
19 |
+
|
20 |
+
|
21 |
+
def get_timestep_embedding(timesteps, embedding_dim):
|
22 |
+
"""
|
23 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
24 |
+
From Fairseq.
|
25 |
+
Build sinusoidal embeddings.
|
26 |
+
This matches the implementation in tensor2tensor, but differs slightly
|
27 |
+
from the description in Section 3.5 of "Attention Is All You Need".
|
28 |
+
"""
|
29 |
+
assert len(timesteps.shape) == 1
|
30 |
+
|
31 |
+
half_dim = embedding_dim // 2
|
32 |
+
emb = math.log(10000) / (half_dim - 1)
|
33 |
+
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
34 |
+
emb = emb.to(device=timesteps.device)
|
35 |
+
emb = timesteps.float()[:, None] * emb[None, :]
|
36 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
37 |
+
if embedding_dim % 2 == 1: # zero pad
|
38 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
39 |
+
return emb
|
40 |
+
|
41 |
+
|
42 |
+
def nonlinearity(x):
|
43 |
+
# swish
|
44 |
+
return x * torch.sigmoid(x)
|
45 |
+
|
46 |
+
|
47 |
+
def Normalize(in_channels, num_groups=32):
|
48 |
+
return torch.nn.GroupNorm(
|
49 |
+
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
|
50 |
+
)
|
51 |
+
|
52 |
+
|
53 |
+
class Upsample(nn.Module):
|
54 |
+
def __init__(self, in_channels, with_conv):
|
55 |
+
super().__init__()
|
56 |
+
self.with_conv = with_conv
|
57 |
+
if self.with_conv:
|
58 |
+
self.conv = torch.nn.Conv2d(
|
59 |
+
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
60 |
+
)
|
61 |
+
|
62 |
+
def forward(self, x):
|
63 |
+
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
64 |
+
if self.with_conv:
|
65 |
+
x = self.conv(x)
|
66 |
+
return x
|
67 |
+
|
68 |
+
|
69 |
+
class Downsample(nn.Module):
|
70 |
+
def __init__(self, in_channels, with_conv):
|
71 |
+
super().__init__()
|
72 |
+
self.with_conv = with_conv
|
73 |
+
if self.with_conv:
|
74 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
75 |
+
self.conv = torch.nn.Conv2d(
|
76 |
+
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
77 |
+
)
|
78 |
+
|
79 |
+
def forward(self, x):
|
80 |
+
if self.with_conv:
|
81 |
+
pad = (0, 1, 0, 1)
|
82 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
83 |
+
x = self.conv(x)
|
84 |
+
else:
|
85 |
+
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
86 |
+
return x
|
87 |
+
|
88 |
+
|
89 |
+
class ResnetBlock(nn.Module):
|
90 |
+
def __init__(
|
91 |
+
self,
|
92 |
+
*,
|
93 |
+
in_channels,
|
94 |
+
out_channels=None,
|
95 |
+
conv_shortcut=False,
|
96 |
+
dropout,
|
97 |
+
temb_channels=512,
|
98 |
+
):
|
99 |
+
super().__init__()
|
100 |
+
self.in_channels = in_channels
|
101 |
+
out_channels = in_channels if out_channels is None else out_channels
|
102 |
+
self.out_channels = out_channels
|
103 |
+
self.use_conv_shortcut = conv_shortcut
|
104 |
+
|
105 |
+
self.norm1 = Normalize(in_channels)
|
106 |
+
self.conv1 = torch.nn.Conv2d(
|
107 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
108 |
+
)
|
109 |
+
if temb_channels > 0:
|
110 |
+
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
111 |
+
self.norm2 = Normalize(out_channels)
|
112 |
+
self.dropout = torch.nn.Dropout(dropout)
|
113 |
+
self.conv2 = torch.nn.Conv2d(
|
114 |
+
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
115 |
+
)
|
116 |
+
if self.in_channels != self.out_channels:
|
117 |
+
if self.use_conv_shortcut:
|
118 |
+
self.conv_shortcut = torch.nn.Conv2d(
|
119 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
120 |
+
)
|
121 |
+
else:
|
122 |
+
self.nin_shortcut = torch.nn.Conv2d(
|
123 |
+
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
124 |
+
)
|
125 |
+
|
126 |
+
def forward(self, x, temb):
|
127 |
+
h = x
|
128 |
+
h = self.norm1(h)
|
129 |
+
h = nonlinearity(h)
|
130 |
+
h = self.conv1(h)
|
131 |
+
|
132 |
+
if temb is not None:
|
133 |
+
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
134 |
+
|
135 |
+
h = self.norm2(h)
|
136 |
+
h = nonlinearity(h)
|
137 |
+
h = self.dropout(h)
|
138 |
+
h = self.conv2(h)
|
139 |
+
|
140 |
+
if self.in_channels != self.out_channels:
|
141 |
+
if self.use_conv_shortcut:
|
142 |
+
x = self.conv_shortcut(x)
|
143 |
+
else:
|
144 |
+
x = self.nin_shortcut(x)
|
145 |
+
|
146 |
+
return x + h
|
147 |
+
|
148 |
+
|
149 |
+
class AttnBlock(nn.Module):
|
150 |
+
def __init__(self, in_channels):
|
151 |
+
super().__init__()
|
152 |
+
self.in_channels = in_channels
|
153 |
+
|
154 |
+
self.norm = Normalize(in_channels)
|
155 |
+
self.q = torch.nn.Conv2d(
|
156 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
157 |
+
)
|
158 |
+
self.k = torch.nn.Conv2d(
|
159 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
160 |
+
)
|
161 |
+
self.v = torch.nn.Conv2d(
|
162 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
163 |
+
)
|
164 |
+
self.proj_out = torch.nn.Conv2d(
|
165 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
166 |
+
)
|
167 |
+
|
168 |
+
def forward(self, x):
|
169 |
+
h_ = x
|
170 |
+
h_ = self.norm(h_)
|
171 |
+
q = self.q(h_)
|
172 |
+
k = self.k(h_)
|
173 |
+
v = self.v(h_)
|
174 |
+
|
175 |
+
# compute attention
|
176 |
+
b, c, h, w = q.shape
|
177 |
+
q = q.reshape(b, c, h * w)
|
178 |
+
q = q.permute(0, 2, 1) # b,hw,c
|
179 |
+
k = k.reshape(b, c, h * w) # b,c,hw
|
180 |
+
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
181 |
+
w_ = w_ * (int(c) ** (-0.5))
|
182 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
183 |
+
|
184 |
+
# attend to values
|
185 |
+
v = v.reshape(b, c, h * w)
|
186 |
+
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
187 |
+
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
188 |
+
h_ = h_.reshape(b, c, h, w)
|
189 |
+
|
190 |
+
h_ = self.proj_out(h_)
|
191 |
+
|
192 |
+
return x + h_
|
193 |
+
|
194 |
+
|
195 |
+
class MemoryEfficientAttnBlock(nn.Module):
|
196 |
+
"""
|
197 |
+
Uses xformers efficient implementation,
|
198 |
+
see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
199 |
+
Note: this is a single-head self-attention operation
|
200 |
+
"""
|
201 |
+
|
202 |
+
#
|
203 |
+
def __init__(self, in_channels):
|
204 |
+
super().__init__()
|
205 |
+
self.in_channels = in_channels
|
206 |
+
|
207 |
+
self.norm = Normalize(in_channels)
|
208 |
+
self.q = torch.nn.Conv2d(
|
209 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
210 |
+
)
|
211 |
+
self.k = torch.nn.Conv2d(
|
212 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
213 |
+
)
|
214 |
+
self.v = torch.nn.Conv2d(
|
215 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
216 |
+
)
|
217 |
+
self.proj_out = torch.nn.Conv2d(
|
218 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
219 |
+
)
|
220 |
+
self.attention_op: Optional[Any] = None
|
221 |
+
|
222 |
+
def forward(self, x):
|
223 |
+
h_ = x
|
224 |
+
h_ = self.norm(h_)
|
225 |
+
q = self.q(h_)
|
226 |
+
k = self.k(h_)
|
227 |
+
v = self.v(h_)
|
228 |
+
|
229 |
+
# compute attention
|
230 |
+
B, C, H, W = q.shape
|
231 |
+
q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
|
232 |
+
|
233 |
+
q, k, v = map(
|
234 |
+
lambda t: t.unsqueeze(3)
|
235 |
+
.reshape(B, t.shape[1], 1, C)
|
236 |
+
.permute(0, 2, 1, 3)
|
237 |
+
.reshape(B * 1, t.shape[1], C)
|
238 |
+
.contiguous(),
|
239 |
+
(q, k, v),
|
240 |
+
)
|
241 |
+
out = xformers.ops.memory_efficient_attention(
|
242 |
+
q, k, v, attn_bias=None, op=self.attention_op
|
243 |
+
)
|
244 |
+
|
245 |
+
out = (
|
246 |
+
out.unsqueeze(0)
|
247 |
+
.reshape(B, 1, out.shape[1], C)
|
248 |
+
.permute(0, 2, 1, 3)
|
249 |
+
.reshape(B, out.shape[1], C)
|
250 |
+
)
|
251 |
+
out = rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
|
252 |
+
out = self.proj_out(out)
|
253 |
+
return x + out
|
254 |
+
|
255 |
+
|
256 |
+
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
|
257 |
+
def forward(self, x, context=None, mask=None):
|
258 |
+
b, c, h, w = x.shape
|
259 |
+
x = rearrange(x, "b c h w -> b (h w) c")
|
260 |
+
out = super().forward(x, context=context, mask=mask)
|
261 |
+
out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
|
262 |
+
return x + out
|
263 |
+
|
264 |
+
|
265 |
+
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
266 |
+
assert attn_type in [
|
267 |
+
"vanilla",
|
268 |
+
"vanilla-xformers",
|
269 |
+
"memory-efficient-cross-attn",
|
270 |
+
"linear",
|
271 |
+
"none",
|
272 |
+
], f"attn_type {attn_type} unknown"
|
273 |
+
if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
|
274 |
+
attn_type = "vanilla-xformers"
|
275 |
+
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
276 |
+
if attn_type == "vanilla":
|
277 |
+
assert attn_kwargs is None
|
278 |
+
return AttnBlock(in_channels)
|
279 |
+
elif attn_type == "vanilla-xformers":
|
280 |
+
print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
|
281 |
+
return MemoryEfficientAttnBlock(in_channels)
|
282 |
+
elif attn_type == "memory-efficient-cross-attn":  # fixed: `type` referred to the Python builtin, so this branch could never match
|
283 |
+
attn_kwargs["query_dim"] = in_channels
|
284 |
+
return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
|
285 |
+
elif attn_type == "none":
|
286 |
+
return nn.Identity(in_channels)
|
287 |
+
else:
|
288 |
+
raise NotImplementedError()
|
289 |
+
|
290 |
+
|
291 |
+
class Model(nn.Module):
|
292 |
+
def __init__(
|
293 |
+
self,
|
294 |
+
*,
|
295 |
+
ch,
|
296 |
+
out_ch,
|
297 |
+
ch_mult=(1, 2, 4, 8),
|
298 |
+
num_res_blocks,
|
299 |
+
attn_resolutions,
|
300 |
+
dropout=0.0,
|
301 |
+
resamp_with_conv=True,
|
302 |
+
in_channels,
|
303 |
+
resolution,
|
304 |
+
use_timestep=True,
|
305 |
+
use_linear_attn=False,
|
306 |
+
attn_type="vanilla",
|
307 |
+
):
|
308 |
+
super().__init__()
|
309 |
+
if use_linear_attn:
|
310 |
+
attn_type = "linear"
|
311 |
+
self.ch = ch
|
312 |
+
self.temb_ch = self.ch * 4
|
313 |
+
self.num_resolutions = len(ch_mult)
|
314 |
+
self.num_res_blocks = num_res_blocks
|
315 |
+
self.resolution = resolution
|
316 |
+
self.in_channels = in_channels
|
317 |
+
|
318 |
+
self.use_timestep = use_timestep
|
319 |
+
if self.use_timestep:
|
320 |
+
# timestep embedding
|
321 |
+
self.temb = nn.Module()
|
322 |
+
self.temb.dense = nn.ModuleList(
|
323 |
+
[
|
324 |
+
torch.nn.Linear(self.ch, self.temb_ch),
|
325 |
+
torch.nn.Linear(self.temb_ch, self.temb_ch),
|
326 |
+
]
|
327 |
+
)
|
328 |
+
|
329 |
+
# downsampling
|
330 |
+
self.conv_in = torch.nn.Conv2d(
|
331 |
+
in_channels, self.ch, kernel_size=3, stride=1, padding=1
|
332 |
+
)
|
333 |
+
|
334 |
+
curr_res = resolution
|
335 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
336 |
+
self.down = nn.ModuleList()
|
337 |
+
for i_level in range(self.num_resolutions):
|
338 |
+
block = nn.ModuleList()
|
339 |
+
attn = nn.ModuleList()
|
340 |
+
block_in = ch * in_ch_mult[i_level]
|
341 |
+
block_out = ch * ch_mult[i_level]
|
342 |
+
for i_block in range(self.num_res_blocks):
|
343 |
+
block.append(
|
344 |
+
ResnetBlock(
|
345 |
+
in_channels=block_in,
|
346 |
+
out_channels=block_out,
|
347 |
+
temb_channels=self.temb_ch,
|
348 |
+
dropout=dropout,
|
349 |
+
)
|
350 |
+
)
|
351 |
+
block_in = block_out
|
352 |
+
if curr_res in attn_resolutions:
|
353 |
+
attn.append(make_attn(block_in, attn_type=attn_type))
|
354 |
+
down = nn.Module()
|
355 |
+
down.block = block
|
356 |
+
down.attn = attn
|
357 |
+
if i_level != self.num_resolutions - 1:
|
358 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
359 |
+
curr_res = curr_res // 2
|
360 |
+
self.down.append(down)
|
361 |
+
|
362 |
+
# middle
|
363 |
+
self.mid = nn.Module()
|
364 |
+
self.mid.block_1 = ResnetBlock(
|
365 |
+
in_channels=block_in,
|
366 |
+
out_channels=block_in,
|
367 |
+
temb_channels=self.temb_ch,
|
368 |
+
dropout=dropout,
|
369 |
+
)
|
370 |
+
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
371 |
+
self.mid.block_2 = ResnetBlock(
|
372 |
+
in_channels=block_in,
|
373 |
+
out_channels=block_in,
|
374 |
+
temb_channels=self.temb_ch,
|
375 |
+
dropout=dropout,
|
376 |
+
)
|
377 |
+
|
378 |
+
# upsampling
|
379 |
+
self.up = nn.ModuleList()
|
380 |
+
for i_level in reversed(range(self.num_resolutions)):
|
381 |
+
block = nn.ModuleList()
|
382 |
+
attn = nn.ModuleList()
|
383 |
+
block_out = ch * ch_mult[i_level]
|
384 |
+
skip_in = ch * ch_mult[i_level]
|
385 |
+
for i_block in range(self.num_res_blocks + 1):
|
386 |
+
if i_block == self.num_res_blocks:
|
387 |
+
skip_in = ch * in_ch_mult[i_level]
|
388 |
+
block.append(
|
389 |
+
ResnetBlock(
|
390 |
+
in_channels=block_in + skip_in,
|
391 |
+
out_channels=block_out,
|
392 |
+
temb_channels=self.temb_ch,
|
393 |
+
dropout=dropout,
|
394 |
+
)
|
395 |
+
)
|
396 |
+
block_in = block_out
|
397 |
+
if curr_res in attn_resolutions:
|
398 |
+
attn.append(make_attn(block_in, attn_type=attn_type))
|
399 |
+
up = nn.Module()
|
400 |
+
up.block = block
|
401 |
+
up.attn = attn
|
402 |
+
if i_level != 0:
|
403 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
404 |
+
curr_res = curr_res * 2
|
405 |
+
self.up.insert(0, up) # prepend to get consistent order
|
406 |
+
|
407 |
+
# end
|
408 |
+
self.norm_out = Normalize(block_in)
|
409 |
+
self.conv_out = torch.nn.Conv2d(
|
410 |
+
block_in, out_ch, kernel_size=3, stride=1, padding=1
|
411 |
+
)
|
412 |
+
|
413 |
+
def forward(self, x, t=None, context=None):
|
414 |
+
# assert x.shape[2] == x.shape[3] == self.resolution
|
415 |
+
if context is not None:
|
416 |
+
# assume aligned context, cat along channel axis
|
417 |
+
x = torch.cat((x, context), dim=1)
|
418 |
+
if self.use_timestep:
|
419 |
+
# timestep embedding
|
420 |
+
assert t is not None
|
421 |
+
temb = get_timestep_embedding(t, self.ch)
|
422 |
+
temb = self.temb.dense[0](temb)
|
423 |
+
temb = nonlinearity(temb)
|
424 |
+
temb = self.temb.dense[1](temb)
|
425 |
+
else:
|
426 |
+
temb = None
|
427 |
+
|
428 |
+
# downsampling
|
429 |
+
hs = [self.conv_in(x)]
|
430 |
+
for i_level in range(self.num_resolutions):
|
431 |
+
for i_block in range(self.num_res_blocks):
|
432 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
433 |
+
if len(self.down[i_level].attn) > 0:
|
434 |
+
h = self.down[i_level].attn[i_block](h)
|
435 |
+
hs.append(h)
|
436 |
+
if i_level != self.num_resolutions - 1:
|
437 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
438 |
+
|
439 |
+
# middle
|
440 |
+
h = hs[-1]
|
441 |
+
h = self.mid.block_1(h, temb)
|
442 |
+
h = self.mid.attn_1(h)
|
443 |
+
h = self.mid.block_2(h, temb)
|
444 |
+
|
445 |
+
# upsampling
|
446 |
+
for i_level in reversed(range(self.num_resolutions)):
|
447 |
+
for i_block in range(self.num_res_blocks + 1):
|
448 |
+
h = self.up[i_level].block[i_block](
|
449 |
+
torch.cat([h, hs.pop()], dim=1), temb
|
450 |
+
)
|
451 |
+
if len(self.up[i_level].attn) > 0:
|
452 |
+
h = self.up[i_level].attn[i_block](h)
|
453 |
+
if i_level != 0:
|
454 |
+
h = self.up[i_level].upsample(h)
|
455 |
+
|
456 |
+
# end
|
457 |
+
h = self.norm_out(h)
|
458 |
+
h = nonlinearity(h)
|
459 |
+
h = self.conv_out(h)
|
460 |
+
return h
|
461 |
+
|
462 |
+
def get_last_layer(self):
|
463 |
+
return self.conv_out.weight
|
464 |
+
|
465 |
+
|
466 |
+
class Encoder(nn.Module):
|
467 |
+
def __init__(
|
468 |
+
self,
|
469 |
+
*,
|
470 |
+
ch,
|
471 |
+
out_ch,
|
472 |
+
ch_mult=(1, 2, 4, 8),
|
473 |
+
num_res_blocks,
|
474 |
+
attn_resolutions,
|
475 |
+
dropout=0.0,
|
476 |
+
resamp_with_conv=True,
|
477 |
+
in_channels,
|
478 |
+
resolution,
|
479 |
+
z_channels,
|
480 |
+
double_z=True,
|
481 |
+
use_linear_attn=False,
|
482 |
+
attn_type="vanilla",
|
483 |
+
**ignore_kwargs,
|
484 |
+
):
|
485 |
+
super().__init__()
|
486 |
+
if use_linear_attn:
|
487 |
+
attn_type = "linear"
|
488 |
+
self.ch = ch
|
489 |
+
self.temb_ch = 0
|
490 |
+
self.num_resolutions = len(ch_mult)
|
491 |
+
self.num_res_blocks = num_res_blocks
|
492 |
+
self.resolution = resolution
|
493 |
+
self.in_channels = in_channels
|
494 |
+
|
495 |
+
# downsampling
|
496 |
+
self.conv_in = torch.nn.Conv2d(
|
497 |
+
in_channels, self.ch, kernel_size=3, stride=1, padding=1
|
498 |
+
)
|
499 |
+
|
500 |
+
curr_res = resolution
|
501 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
502 |
+
self.in_ch_mult = in_ch_mult
|
503 |
+
self.down = nn.ModuleList()
|
504 |
+
for i_level in range(self.num_resolutions):
|
505 |
+
block = nn.ModuleList()
|
506 |
+
attn = nn.ModuleList()
|
507 |
+
block_in = ch * in_ch_mult[i_level]
|
508 |
+
block_out = ch * ch_mult[i_level]
|
509 |
+
for i_block in range(self.num_res_blocks):
|
510 |
+
block.append(
|
511 |
+
ResnetBlock(
|
512 |
+
in_channels=block_in,
|
513 |
+
out_channels=block_out,
|
514 |
+
temb_channels=self.temb_ch,
|
515 |
+
dropout=dropout,
|
516 |
+
)
|
517 |
+
)
|
518 |
+
block_in = block_out
|
519 |
+
if curr_res in attn_resolutions:
|
520 |
+
attn.append(make_attn(block_in, attn_type=attn_type))
|
521 |
+
down = nn.Module()
|
522 |
+
down.block = block
|
523 |
+
down.attn = attn
|
524 |
+
if i_level != self.num_resolutions - 1:
|
525 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
526 |
+
curr_res = curr_res // 2
|
527 |
+
self.down.append(down)
|
528 |
+
|
529 |
+
# middle
|
530 |
+
self.mid = nn.Module()
|
531 |
+
self.mid.block_1 = ResnetBlock(
|
532 |
+
in_channels=block_in,
|
533 |
+
out_channels=block_in,
|
534 |
+
temb_channels=self.temb_ch,
|
535 |
+
dropout=dropout,
|
536 |
+
)
|
537 |
+
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
538 |
+
self.mid.block_2 = ResnetBlock(
|
539 |
+
in_channels=block_in,
|
540 |
+
out_channels=block_in,
|
541 |
+
temb_channels=self.temb_ch,
|
542 |
+
dropout=dropout,
|
543 |
+
)
|
544 |
+
|
545 |
+
# end
|
546 |
+
self.norm_out = Normalize(block_in)
|
547 |
+
self.conv_out = torch.nn.Conv2d(
|
548 |
+
block_in,
|
549 |
+
2 * z_channels if double_z else z_channels,
|
550 |
+
kernel_size=3,
|
551 |
+
stride=1,
|
552 |
+
padding=1,
|
553 |
+
)
|
554 |
+
|
555 |
+
def forward(self, x):
|
556 |
+
# timestep embedding
|
557 |
+
temb = None
|
558 |
+
|
559 |
+
# downsampling
|
560 |
+
hs = [self.conv_in(x)]
|
561 |
+
for i_level in range(self.num_resolutions):
|
562 |
+
for i_block in range(self.num_res_blocks):
|
563 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
564 |
+
if len(self.down[i_level].attn) > 0:
|
565 |
+
h = self.down[i_level].attn[i_block](h)
|
566 |
+
hs.append(h)
|
567 |
+
if i_level != self.num_resolutions - 1:
|
568 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
569 |
+
|
570 |
+
# middle
|
571 |
+
h = hs[-1]
|
572 |
+
h = self.mid.block_1(h, temb)
|
573 |
+
h = self.mid.attn_1(h)
|
574 |
+
h = self.mid.block_2(h, temb)
|
575 |
+
|
576 |
+
# end
|
577 |
+
h = self.norm_out(h)
|
578 |
+
h = nonlinearity(h)
|
579 |
+
h = self.conv_out(h)
|
580 |
+
return h
|
581 |
+
|
582 |
+
|
583 |
+
class Decoder(nn.Module):
|
584 |
+
def __init__(
|
585 |
+
self,
|
586 |
+
*,
|
587 |
+
ch,
|
588 |
+
out_ch,
|
589 |
+
ch_mult=(1, 2, 4, 8),
|
590 |
+
num_res_blocks,
|
591 |
+
attn_resolutions,
|
592 |
+
dropout=0.0,
|
593 |
+
resamp_with_conv=True,
|
594 |
+
in_channels,
|
595 |
+
resolution,
|
596 |
+
z_channels,
|
597 |
+
give_pre_end=False,
|
598 |
+
tanh_out=False,
|
599 |
+
use_linear_attn=False,
|
600 |
+
attn_type="vanilla",
|
601 |
+
**ignorekwargs,
|
602 |
+
):
|
603 |
+
super().__init__()
|
604 |
+
if use_linear_attn:
|
605 |
+
attn_type = "linear"
|
606 |
+
self.ch = ch
|
607 |
+
self.temb_ch = 0
|
608 |
+
self.num_resolutions = len(ch_mult)
|
609 |
+
self.num_res_blocks = num_res_blocks
|
610 |
+
self.resolution = resolution
|
611 |
+
self.in_channels = in_channels
|
612 |
+
self.give_pre_end = give_pre_end
|
613 |
+
self.tanh_out = tanh_out
|
614 |
+
|
615 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
616 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
617 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
618 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
619 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
620 |
+
print(
|
621 |
+
"Working with z of shape {} = {} dimensions.".format(
|
622 |
+
self.z_shape, np.prod(self.z_shape)
|
623 |
+
)
|
624 |
+
)
|
625 |
+
|
626 |
+
# z to block_in
|
627 |
+
self.conv_in = torch.nn.Conv2d(
|
628 |
+
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
629 |
+
)
|
630 |
+
|
631 |
+
# middle
|
632 |
+
self.mid = nn.Module()
|
633 |
+
self.mid.block_1 = ResnetBlock(
|
634 |
+
in_channels=block_in,
|
635 |
+
out_channels=block_in,
|
636 |
+
temb_channels=self.temb_ch,
|
637 |
+
dropout=dropout,
|
638 |
+
)
|
639 |
+
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
640 |
+
self.mid.block_2 = ResnetBlock(
|
641 |
+
in_channels=block_in,
|
642 |
+
out_channels=block_in,
|
643 |
+
temb_channels=self.temb_ch,
|
644 |
+
dropout=dropout,
|
645 |
+
)
|
646 |
+
|
647 |
+
# upsampling
|
648 |
+
self.up = nn.ModuleList()
|
649 |
+
for i_level in reversed(range(self.num_resolutions)):
|
650 |
+
block = nn.ModuleList()
|
651 |
+
attn = nn.ModuleList()
|
652 |
+
block_out = ch * ch_mult[i_level]
|
653 |
+
for i_block in range(self.num_res_blocks + 1):
|
654 |
+
block.append(
|
655 |
+
ResnetBlock(
|
656 |
+
in_channels=block_in,
|
657 |
+
out_channels=block_out,
|
658 |
+
temb_channels=self.temb_ch,
|
659 |
+
dropout=dropout,
|
660 |
+
)
|
661 |
+
)
|
662 |
+
block_in = block_out
|
663 |
+
if curr_res in attn_resolutions:
|
664 |
+
attn.append(make_attn(block_in, attn_type=attn_type))
|
665 |
+
up = nn.Module()
|
666 |
+
up.block = block
|
667 |
+
up.attn = attn
|
668 |
+
if i_level != 0:
|
669 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
670 |
+
curr_res = curr_res * 2
|
671 |
+
self.up.insert(0, up) # prepend to get consistent order
|
672 |
+
|
673 |
+
# end
|
674 |
+
self.norm_out = Normalize(block_in)
|
675 |
+
self.conv_out = torch.nn.Conv2d(
|
676 |
+
block_in, out_ch, kernel_size=3, stride=1, padding=1
|
677 |
+
)
|
678 |
+
|
679 |
+
def forward(self, z):
|
680 |
+
# assert z.shape[1:] == self.z_shape[1:]
|
681 |
+
self.last_z_shape = z.shape
|
682 |
+
|
683 |
+
# timestep embedding
|
684 |
+
temb = None
|
685 |
+
|
686 |
+
# z to block_in
|
687 |
+
h = self.conv_in(z)
|
688 |
+
|
689 |
+
# middle
|
690 |
+
h = self.mid.block_1(h, temb)
|
691 |
+
h = self.mid.attn_1(h)
|
692 |
+
h = self.mid.block_2(h, temb)
|
693 |
+
|
694 |
+
# upsampling
|
695 |
+
for i_level in reversed(range(self.num_resolutions)):
|
696 |
+
for i_block in range(self.num_res_blocks + 1):
|
697 |
+
h = self.up[i_level].block[i_block](h, temb)
|
698 |
+
if len(self.up[i_level].attn) > 0:
|
699 |
+
h = self.up[i_level].attn[i_block](h)
|
700 |
+
if i_level != 0:
|
701 |
+
h = self.up[i_level].upsample(h)
|
702 |
+
|
703 |
+
# end
|
704 |
+
if self.give_pre_end:
|
705 |
+
return h
|
706 |
+
|
707 |
+
h = self.norm_out(h)
|
708 |
+
h = nonlinearity(h)
|
709 |
+
h = self.conv_out(h)
|
710 |
+
if self.tanh_out:
|
711 |
+
h = torch.tanh(h)
|
712 |
+
return h
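A minimal, hypothetical usage sketch of the Encoder/Decoder pair defined above. The hyperparameters are illustrative only (not taken from this repo's configs), and it assumes torch is imported at the top of this file; in the full pipeline a quant_conv and a DiagonalGaussianDistribution sit between the two modules.
# Hypothetical round trip through the autoencoder backbone defined above.
enc = Encoder(ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
              attn_resolutions=[], in_channels=3, resolution=256,
              z_channels=4, double_z=True)
dec = Decoder(ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
              attn_resolutions=[], in_channels=3, resolution=256, z_channels=4)
moments = enc(torch.randn(1, 3, 256, 256))  # [1, 8, 32, 32]: 2*z_channels (mean, logvar)
x_rec = dec(moments[:, :4])                 # [1, 3, 256, 256], using the mean half as a stand-in for z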
|
713 |
+
|
714 |
+
|
715 |
+
class SimpleDecoder(nn.Module):
|
716 |
+
def __init__(self, in_channels, out_channels, *args, **kwargs):
|
717 |
+
super().__init__()
|
718 |
+
self.model = nn.ModuleList(
|
719 |
+
[
|
720 |
+
nn.Conv2d(in_channels, in_channels, 1),
|
721 |
+
ResnetBlock(
|
722 |
+
in_channels=in_channels,
|
723 |
+
out_channels=2 * in_channels,
|
724 |
+
temb_channels=0,
|
725 |
+
dropout=0.0,
|
726 |
+
),
|
727 |
+
ResnetBlock(
|
728 |
+
in_channels=2 * in_channels,
|
729 |
+
out_channels=4 * in_channels,
|
730 |
+
temb_channels=0,
|
731 |
+
dropout=0.0,
|
732 |
+
),
|
733 |
+
ResnetBlock(
|
734 |
+
in_channels=4 * in_channels,
|
735 |
+
out_channels=2 * in_channels,
|
736 |
+
temb_channels=0,
|
737 |
+
dropout=0.0,
|
738 |
+
),
|
739 |
+
nn.Conv2d(2 * in_channels, in_channels, 1),
|
740 |
+
Upsample(in_channels, with_conv=True),
|
741 |
+
]
|
742 |
+
)
|
743 |
+
# end
|
744 |
+
self.norm_out = Normalize(in_channels)
|
745 |
+
self.conv_out = torch.nn.Conv2d(
|
746 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
747 |
+
)
|
748 |
+
|
749 |
+
def forward(self, x):
|
750 |
+
for i, layer in enumerate(self.model):
|
751 |
+
if i in [1, 2, 3]:
|
752 |
+
x = layer(x, None)
|
753 |
+
else:
|
754 |
+
x = layer(x)
|
755 |
+
|
756 |
+
h = self.norm_out(x)
|
757 |
+
h = nonlinearity(h)
|
758 |
+
x = self.conv_out(h)
|
759 |
+
return x
|
760 |
+
|
761 |
+
|
762 |
+
class UpsampleDecoder(nn.Module):
|
763 |
+
def __init__(
|
764 |
+
self,
|
765 |
+
in_channels,
|
766 |
+
out_channels,
|
767 |
+
ch,
|
768 |
+
num_res_blocks,
|
769 |
+
resolution,
|
770 |
+
ch_mult=(2, 2),
|
771 |
+
dropout=0.0,
|
772 |
+
):
|
773 |
+
super().__init__()
|
774 |
+
# upsampling
|
775 |
+
self.temb_ch = 0
|
776 |
+
self.num_resolutions = len(ch_mult)
|
777 |
+
self.num_res_blocks = num_res_blocks
|
778 |
+
block_in = in_channels
|
779 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
780 |
+
self.res_blocks = nn.ModuleList()
|
781 |
+
self.upsample_blocks = nn.ModuleList()
|
782 |
+
for i_level in range(self.num_resolutions):
|
783 |
+
res_block = []
|
784 |
+
block_out = ch * ch_mult[i_level]
|
785 |
+
for i_block in range(self.num_res_blocks + 1):
|
786 |
+
res_block.append(
|
787 |
+
ResnetBlock(
|
788 |
+
in_channels=block_in,
|
789 |
+
out_channels=block_out,
|
790 |
+
temb_channels=self.temb_ch,
|
791 |
+
dropout=dropout,
|
792 |
+
)
|
793 |
+
)
|
794 |
+
block_in = block_out
|
795 |
+
self.res_blocks.append(nn.ModuleList(res_block))
|
796 |
+
if i_level != self.num_resolutions - 1:
|
797 |
+
self.upsample_blocks.append(Upsample(block_in, True))
|
798 |
+
curr_res = curr_res * 2
|
799 |
+
|
800 |
+
# end
|
801 |
+
self.norm_out = Normalize(block_in)
|
802 |
+
self.conv_out = torch.nn.Conv2d(
|
803 |
+
block_in, out_channels, kernel_size=3, stride=1, padding=1
|
804 |
+
)
|
805 |
+
|
806 |
+
def forward(self, x):
|
807 |
+
# upsampling
|
808 |
+
h = x
|
809 |
+
for k, i_level in enumerate(range(self.num_resolutions)):
|
810 |
+
for i_block in range(self.num_res_blocks + 1):
|
811 |
+
h = self.res_blocks[i_level][i_block](h, None)
|
812 |
+
if i_level != self.num_resolutions - 1:
|
813 |
+
h = self.upsample_blocks[k](h)
|
814 |
+
h = self.norm_out(h)
|
815 |
+
h = nonlinearity(h)
|
816 |
+
h = self.conv_out(h)
|
817 |
+
return h
|
818 |
+
|
819 |
+
|
820 |
+
class LatentRescaler(nn.Module):
|
821 |
+
def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
|
822 |
+
super().__init__()
|
823 |
+
# residual block, interpolate, residual block
|
824 |
+
self.factor = factor
|
825 |
+
self.conv_in = nn.Conv2d(
|
826 |
+
in_channels, mid_channels, kernel_size=3, stride=1, padding=1
|
827 |
+
)
|
828 |
+
self.res_block1 = nn.ModuleList(
|
829 |
+
[
|
830 |
+
ResnetBlock(
|
831 |
+
in_channels=mid_channels,
|
832 |
+
out_channels=mid_channels,
|
833 |
+
temb_channels=0,
|
834 |
+
dropout=0.0,
|
835 |
+
)
|
836 |
+
for _ in range(depth)
|
837 |
+
]
|
838 |
+
)
|
839 |
+
self.attn = AttnBlock(mid_channels)
|
840 |
+
self.res_block2 = nn.ModuleList(
|
841 |
+
[
|
842 |
+
ResnetBlock(
|
843 |
+
in_channels=mid_channels,
|
844 |
+
out_channels=mid_channels,
|
845 |
+
temb_channels=0,
|
846 |
+
dropout=0.0,
|
847 |
+
)
|
848 |
+
for _ in range(depth)
|
849 |
+
]
|
850 |
+
)
|
851 |
+
|
852 |
+
self.conv_out = nn.Conv2d(
|
853 |
+
mid_channels,
|
854 |
+
out_channels,
|
855 |
+
kernel_size=1,
|
856 |
+
)
|
857 |
+
|
858 |
+
def forward(self, x):
|
859 |
+
x = self.conv_in(x)
|
860 |
+
for block in self.res_block1:
|
861 |
+
x = block(x, None)
|
862 |
+
x = torch.nn.functional.interpolate(
|
863 |
+
x,
|
864 |
+
size=(
|
865 |
+
int(round(x.shape[2] * self.factor)),
|
866 |
+
int(round(x.shape[3] * self.factor)),
|
867 |
+
),
|
868 |
+
)
|
869 |
+
x = self.attn(x)
|
870 |
+
for block in self.res_block2:
|
871 |
+
x = block(x, None)
|
872 |
+
x = self.conv_out(x)
|
873 |
+
return x
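A short hypothetical sketch of what LatentRescaler does (channel counts and sizes are illustrative): it resamples the spatial size of a latent by `factor` while mapping channels in_channels -> mid_channels -> out_channels.
rescaler = LatentRescaler(factor=0.5, in_channels=4, mid_channels=64, out_channels=4, depth=1)
z_small = rescaler(torch.randn(1, 4, 64, 64))  # -> [1, 4, 32, 32]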
|
874 |
+
|
875 |
+
|
876 |
+
class MergedRescaleEncoder(nn.Module):
|
877 |
+
def __init__(
|
878 |
+
self,
|
879 |
+
in_channels,
|
880 |
+
ch,
|
881 |
+
resolution,
|
882 |
+
out_ch,
|
883 |
+
num_res_blocks,
|
884 |
+
attn_resolutions,
|
885 |
+
dropout=0.0,
|
886 |
+
resamp_with_conv=True,
|
887 |
+
ch_mult=(1, 2, 4, 8),
|
888 |
+
rescale_factor=1.0,
|
889 |
+
rescale_module_depth=1,
|
890 |
+
):
|
891 |
+
super().__init__()
|
892 |
+
intermediate_chn = ch * ch_mult[-1]
|
893 |
+
self.encoder = Encoder(
|
894 |
+
in_channels=in_channels,
|
895 |
+
num_res_blocks=num_res_blocks,
|
896 |
+
ch=ch,
|
897 |
+
ch_mult=ch_mult,
|
898 |
+
z_channels=intermediate_chn,
|
899 |
+
double_z=False,
|
900 |
+
resolution=resolution,
|
901 |
+
attn_resolutions=attn_resolutions,
|
902 |
+
dropout=dropout,
|
903 |
+
resamp_with_conv=resamp_with_conv,
|
904 |
+
out_ch=None,
|
905 |
+
)
|
906 |
+
self.rescaler = LatentRescaler(
|
907 |
+
factor=rescale_factor,
|
908 |
+
in_channels=intermediate_chn,
|
909 |
+
mid_channels=intermediate_chn,
|
910 |
+
out_channels=out_ch,
|
911 |
+
depth=rescale_module_depth,
|
912 |
+
)
|
913 |
+
|
914 |
+
def forward(self, x):
|
915 |
+
x = self.encoder(x)
|
916 |
+
x = self.rescaler(x)
|
917 |
+
return x
|
918 |
+
|
919 |
+
|
920 |
+
class MergedRescaleDecoder(nn.Module):
|
921 |
+
def __init__(
|
922 |
+
self,
|
923 |
+
z_channels,
|
924 |
+
out_ch,
|
925 |
+
resolution,
|
926 |
+
num_res_blocks,
|
927 |
+
attn_resolutions,
|
928 |
+
ch,
|
929 |
+
ch_mult=(1, 2, 4, 8),
|
930 |
+
dropout=0.0,
|
931 |
+
resamp_with_conv=True,
|
932 |
+
rescale_factor=1.0,
|
933 |
+
rescale_module_depth=1,
|
934 |
+
):
|
935 |
+
super().__init__()
|
936 |
+
tmp_chn = z_channels * ch_mult[-1]
|
937 |
+
self.decoder = Decoder(
|
938 |
+
out_ch=out_ch,
|
939 |
+
z_channels=tmp_chn,
|
940 |
+
attn_resolutions=attn_resolutions,
|
941 |
+
dropout=dropout,
|
942 |
+
resamp_with_conv=resamp_with_conv,
|
943 |
+
in_channels=None,
|
944 |
+
num_res_blocks=num_res_blocks,
|
945 |
+
ch_mult=ch_mult,
|
946 |
+
resolution=resolution,
|
947 |
+
ch=ch,
|
948 |
+
)
|
949 |
+
self.rescaler = LatentRescaler(
|
950 |
+
factor=rescale_factor,
|
951 |
+
in_channels=z_channels,
|
952 |
+
mid_channels=tmp_chn,
|
953 |
+
out_channels=tmp_chn,
|
954 |
+
depth=rescale_module_depth,
|
955 |
+
)
|
956 |
+
|
957 |
+
def forward(self, x):
|
958 |
+
x = self.rescaler(x)
|
959 |
+
x = self.decoder(x)
|
960 |
+
return x
|
961 |
+
|
962 |
+
|
963 |
+
class Upsampler(nn.Module):
|
964 |
+
def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
|
965 |
+
super().__init__()
|
966 |
+
assert out_size >= in_size
|
967 |
+
num_blocks = int(np.log2(out_size // in_size)) + 1
|
968 |
+
factor_up = 1.0 + (out_size % in_size)
|
969 |
+
print(
|
970 |
+
f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
|
971 |
+
)
|
972 |
+
self.rescaler = LatentRescaler(
|
973 |
+
factor=factor_up,
|
974 |
+
in_channels=in_channels,
|
975 |
+
mid_channels=2 * in_channels,
|
976 |
+
out_channels=in_channels,
|
977 |
+
)
|
978 |
+
self.decoder = Decoder(
|
979 |
+
out_ch=out_channels,
|
980 |
+
resolution=out_size,
|
981 |
+
z_channels=in_channels,
|
982 |
+
num_res_blocks=2,
|
983 |
+
attn_resolutions=[],
|
984 |
+
in_channels=None,
|
985 |
+
ch=in_channels,
|
986 |
+
ch_mult=[ch_mult for _ in range(num_blocks)],
|
987 |
+
)
|
988 |
+
|
989 |
+
def forward(self, x):
|
990 |
+
x = self.rescaler(x)
|
991 |
+
x = self.decoder(x)
|
992 |
+
return x
|
993 |
+
|
994 |
+
|
995 |
+
class Resize(nn.Module):
|
996 |
+
def __init__(self, in_channels=None, learned=False, mode="bilinear"):
|
997 |
+
super().__init__()
|
998 |
+
self.with_conv = learned
|
999 |
+
self.mode = mode
|
1000 |
+
if self.with_conv:
|
1001 |
+
print(
|
1002 |
+
f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode"
|
1003 |
+
)
|
1004 |
+
raise NotImplementedError()
|
1005 |
+
assert in_channels is not None
|
1006 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
1007 |
+
self.conv = torch.nn.Conv2d(
|
1008 |
+
in_channels, in_channels, kernel_size=4, stride=2, padding=1
|
1009 |
+
)
|
1010 |
+
|
1011 |
+
def forward(self, x, scale_factor=1.0):
|
1012 |
+
if scale_factor == 1.0:
|
1013 |
+
return x
|
1014 |
+
else:
|
1015 |
+
x = torch.nn.functional.interpolate(
|
1016 |
+
x, mode=self.mode, align_corners=False, scale_factor=scale_factor
|
1017 |
+
)
|
1018 |
+
return x
|
imagedream/ldm/modules/diffusionmodules/openaimodel.py
ADDED
@@ -0,0 +1,1135 @@
1 |
+
from abc import abstractmethod
|
2 |
+
import math
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch as th
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
|
11 |
+
from imagedream.ldm.modules.diffusionmodules.util import (
|
12 |
+
checkpoint,
|
13 |
+
conv_nd,
|
14 |
+
linear,
|
15 |
+
avg_pool_nd,
|
16 |
+
zero_module,
|
17 |
+
normalization,
|
18 |
+
timestep_embedding,
|
19 |
+
convert_module_to_f16,
|
20 |
+
convert_module_to_f32
|
21 |
+
)
|
22 |
+
from imagedream.ldm.modules.attention import (
|
23 |
+
SpatialTransformer,
|
24 |
+
SpatialTransformer3D,
|
25 |
+
exists
|
26 |
+
)
|
27 |
+
from imagedream.ldm.modules.diffusionmodules.adaptors import (
|
28 |
+
Resampler,
|
29 |
+
ImageProjModel
|
30 |
+
)
|
31 |
+
|
32 |
+
## go
|
33 |
+
class AttentionPool2d(nn.Module):
|
34 |
+
"""
|
35 |
+
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
|
36 |
+
"""
|
37 |
+
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
spacial_dim: int,
|
41 |
+
embed_dim: int,
|
42 |
+
num_heads_channels: int,
|
43 |
+
output_dim: int = None,
|
44 |
+
):
|
45 |
+
super().__init__()
|
46 |
+
self.positional_embedding = nn.Parameter(
|
47 |
+
th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
|
48 |
+
)
|
49 |
+
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
|
50 |
+
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
|
51 |
+
self.num_heads = embed_dim // num_heads_channels
|
52 |
+
self.attention = QKVAttention(self.num_heads)
|
53 |
+
|
54 |
+
def forward(self, x):
|
55 |
+
b, c, *_spatial = x.shape
|
56 |
+
x = x.reshape(b, c, -1) # NC(HW)
|
57 |
+
x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
|
58 |
+
x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
|
59 |
+
x = self.qkv_proj(x)
|
60 |
+
x = self.attention(x)
|
61 |
+
x = self.c_proj(x)
|
62 |
+
return x[:, :, 0]
|
63 |
+
|
64 |
+
|
65 |
+
class TimestepBlock(nn.Module):
|
66 |
+
"""
|
67 |
+
Any module where forward() takes timestep embeddings as a second argument.
|
68 |
+
"""
|
69 |
+
|
70 |
+
@abstractmethod
|
71 |
+
def forward(self, x, emb):
|
72 |
+
"""
|
73 |
+
Apply the module to `x` given `emb` timestep embeddings.
|
74 |
+
"""
|
75 |
+
|
76 |
+
|
77 |
+
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
78 |
+
"""
|
79 |
+
A sequential module that passes timestep embeddings to the children that
|
80 |
+
support it as an extra input.
|
81 |
+
"""
|
82 |
+
|
83 |
+
def forward(self, x, emb, context=None, num_frames=1):
|
84 |
+
for layer in self:
|
85 |
+
if isinstance(layer, TimestepBlock):
|
86 |
+
x = layer(x, emb)
|
87 |
+
elif isinstance(layer, SpatialTransformer3D):
|
88 |
+
x = layer(x, context, num_frames=num_frames)
|
89 |
+
elif isinstance(layer, SpatialTransformer):
|
90 |
+
x = layer(x, context)
|
91 |
+
else:
|
92 |
+
x = layer(x)
|
93 |
+
return x
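A hypothetical sketch of the dispatch above, using the ResBlock and Downsample classes defined further down in this file (shapes are illustrative): TimestepBlock children receive `emb`, SpatialTransformer3D children receive `context` and `num_frames`, and every other layer only sees `x`.
block = TimestepEmbedSequential(ResBlock(64, 256, 0.0), Downsample(64, use_conv=True))
h = block(th.randn(2, 64, 32, 32), th.randn(2, 256))  # ResBlock consumes emb; output [2, 64, 16, 16]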
|
94 |
+
|
95 |
+
|
96 |
+
class Upsample(nn.Module):
|
97 |
+
"""
|
98 |
+
An upsampling layer with an optional convolution.
|
99 |
+
:param channels: channels in the inputs and outputs.
|
100 |
+
:param use_conv: a bool determining if a convolution is applied.
|
101 |
+
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
102 |
+
upsampling occurs in the inner-two dimensions.
|
103 |
+
"""
|
104 |
+
|
105 |
+
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
106 |
+
super().__init__()
|
107 |
+
self.channels = channels
|
108 |
+
self.out_channels = out_channels or channels
|
109 |
+
self.use_conv = use_conv
|
110 |
+
self.dims = dims
|
111 |
+
if use_conv:
|
112 |
+
self.conv = conv_nd(
|
113 |
+
dims, self.channels, self.out_channels, 3, padding=padding
|
114 |
+
)
|
115 |
+
|
116 |
+
def forward(self, x):
|
117 |
+
assert x.shape[1] == self.channels
|
118 |
+
if self.dims == 3:
|
119 |
+
x = F.interpolate(
|
120 |
+
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
|
121 |
+
)
|
122 |
+
else:
|
123 |
+
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
124 |
+
if self.use_conv:
|
125 |
+
x = self.conv(x)
|
126 |
+
return x
|
127 |
+
|
128 |
+
|
129 |
+
class TransposedUpsample(nn.Module):
|
130 |
+
"Learned 2x upsampling without padding"
|
131 |
+
|
132 |
+
def __init__(self, channels, out_channels=None, ks=5):
|
133 |
+
super().__init__()
|
134 |
+
self.channels = channels
|
135 |
+
self.out_channels = out_channels or channels
|
136 |
+
|
137 |
+
self.up = nn.ConvTranspose2d(
|
138 |
+
self.channels, self.out_channels, kernel_size=ks, stride=2
|
139 |
+
)
|
140 |
+
|
141 |
+
def forward(self, x):
|
142 |
+
return self.up(x)
|
143 |
+
|
144 |
+
|
145 |
+
class Downsample(nn.Module):
|
146 |
+
"""
|
147 |
+
A downsampling layer with an optional convolution.
|
148 |
+
:param channels: channels in the inputs and outputs.
|
149 |
+
:param use_conv: a bool determining if a convolution is applied.
|
150 |
+
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
151 |
+
downsampling occurs in the inner-two dimensions.
|
152 |
+
"""
|
153 |
+
|
154 |
+
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
155 |
+
super().__init__()
|
156 |
+
self.channels = channels
|
157 |
+
self.out_channels = out_channels or channels
|
158 |
+
self.use_conv = use_conv
|
159 |
+
self.dims = dims
|
160 |
+
stride = 2 if dims != 3 else (1, 2, 2)
|
161 |
+
if use_conv:
|
162 |
+
self.op = conv_nd(
|
163 |
+
dims,
|
164 |
+
self.channels,
|
165 |
+
self.out_channels,
|
166 |
+
3,
|
167 |
+
stride=stride,
|
168 |
+
padding=padding,
|
169 |
+
)
|
170 |
+
else:
|
171 |
+
assert self.channels == self.out_channels
|
172 |
+
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
|
173 |
+
|
174 |
+
def forward(self, x):
|
175 |
+
assert x.shape[1] == self.channels
|
176 |
+
return self.op(x)
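A hypothetical shape sketch (illustrative channel counts): with use_conv=True, Downsample is a stride-2 convolution that halves H and W, and Upsample doubles them via nearest-neighbour interpolation followed by a 3x3 convolution.
y = Downsample(64, use_conv=True)(th.randn(1, 64, 32, 32))  # -> [1, 64, 16, 16]
x2 = Upsample(64, use_conv=True)(y)                          # -> [1, 64, 32, 32]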
|
177 |
+
|
178 |
+
|
179 |
+
class ResBlock(TimestepBlock):
|
180 |
+
"""
|
181 |
+
A residual block that can optionally change the number of channels.
|
182 |
+
:param channels: the number of input channels.
|
183 |
+
:param emb_channels: the number of timestep embedding channels.
|
184 |
+
:param dropout: the rate of dropout.
|
185 |
+
:param out_channels: if specified, the number of out channels.
|
186 |
+
:param use_conv: if True and out_channels is specified, use a spatial
|
187 |
+
convolution instead of a smaller 1x1 convolution to change the
|
188 |
+
channels in the skip connection.
|
189 |
+
:param dims: determines if the signal is 1D, 2D, or 3D.
|
190 |
+
:param use_checkpoint: if True, use gradient checkpointing on this module.
|
191 |
+
:param up: if True, use this block for upsampling.
|
192 |
+
:param down: if True, use this block for downsampling.
|
193 |
+
"""
|
194 |
+
|
195 |
+
def __init__(
|
196 |
+
self,
|
197 |
+
channels,
|
198 |
+
emb_channels,
|
199 |
+
dropout,
|
200 |
+
out_channels=None,
|
201 |
+
use_conv=False,
|
202 |
+
use_scale_shift_norm=False,
|
203 |
+
dims=2,
|
204 |
+
use_checkpoint=False,
|
205 |
+
up=False,
|
206 |
+
down=False,
|
207 |
+
):
|
208 |
+
super().__init__()
|
209 |
+
self.channels = channels
|
210 |
+
self.emb_channels = emb_channels
|
211 |
+
self.dropout = dropout
|
212 |
+
self.out_channels = out_channels or channels
|
213 |
+
self.use_conv = use_conv
|
214 |
+
self.use_checkpoint = use_checkpoint
|
215 |
+
self.use_scale_shift_norm = use_scale_shift_norm
|
216 |
+
|
217 |
+
self.in_layers = nn.Sequential(
|
218 |
+
normalization(channels),
|
219 |
+
nn.SiLU(),
|
220 |
+
conv_nd(dims, channels, self.out_channels, 3, padding=1),
|
221 |
+
)
|
222 |
+
|
223 |
+
self.updown = up or down
|
224 |
+
|
225 |
+
if up:
|
226 |
+
self.h_upd = Upsample(channels, False, dims)
|
227 |
+
self.x_upd = Upsample(channels, False, dims)
|
228 |
+
elif down:
|
229 |
+
self.h_upd = Downsample(channels, False, dims)
|
230 |
+
self.x_upd = Downsample(channels, False, dims)
|
231 |
+
else:
|
232 |
+
self.h_upd = self.x_upd = nn.Identity()
|
233 |
+
|
234 |
+
self.emb_layers = nn.Sequential(
|
235 |
+
nn.SiLU(),
|
236 |
+
linear(
|
237 |
+
emb_channels,
|
238 |
+
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
239 |
+
),
|
240 |
+
)
|
241 |
+
self.out_layers = nn.Sequential(
|
242 |
+
normalization(self.out_channels),
|
243 |
+
nn.SiLU(),
|
244 |
+
nn.Dropout(p=dropout),
|
245 |
+
zero_module(
|
246 |
+
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
|
247 |
+
),
|
248 |
+
)
|
249 |
+
|
250 |
+
if self.out_channels == channels:
|
251 |
+
self.skip_connection = nn.Identity()
|
252 |
+
elif use_conv:
|
253 |
+
self.skip_connection = conv_nd(
|
254 |
+
dims, channels, self.out_channels, 3, padding=1
|
255 |
+
)
|
256 |
+
else:
|
257 |
+
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
258 |
+
|
259 |
+
def forward(self, x, emb):
|
260 |
+
"""
|
261 |
+
Apply the block to a Tensor, conditioned on a timestep embedding.
|
262 |
+
:param x: an [N x C x ...] Tensor of features.
|
263 |
+
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
|
264 |
+
:return: an [N x C x ...] Tensor of outputs.
|
265 |
+
"""
|
266 |
+
return checkpoint(
|
267 |
+
self._forward, (x, emb), self.parameters(), self.use_checkpoint
|
268 |
+
)
|
269 |
+
|
270 |
+
def _forward(self, x, emb):
|
271 |
+
if self.updown:
|
272 |
+
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
273 |
+
h = in_rest(x)
|
274 |
+
h = self.h_upd(h)
|
275 |
+
x = self.x_upd(x)
|
276 |
+
h = in_conv(h)
|
277 |
+
else:
|
278 |
+
h = self.in_layers(x)
|
279 |
+
emb_out = self.emb_layers(emb).type(h.dtype)
|
280 |
+
while len(emb_out.shape) < len(h.shape):
|
281 |
+
emb_out = emb_out[..., None]
|
282 |
+
if self.use_scale_shift_norm:
|
283 |
+
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
284 |
+
scale, shift = th.chunk(emb_out, 2, dim=1)
|
285 |
+
h = out_norm(h) * (1 + scale) + shift
|
286 |
+
h = out_rest(h)
|
287 |
+
else:
|
288 |
+
h = h + emb_out
|
289 |
+
h = self.out_layers(h)
|
290 |
+
return self.skip_connection(x) + h
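A hypothetical sketch of the two conditioning paths in ResBlock (sizes are illustrative): with use_scale_shift_norm=True the embedding is split into a FiLM-style scale and shift, h = norm(h) * (1 + scale) + shift, instead of simply being added to the features.
blk = ResBlock(32, 128, 0.0, use_scale_shift_norm=True)
out = blk(th.randn(1, 32, 16, 16), th.randn(1, 128))  # output keeps the input shape: [1, 32, 16, 16]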
|
291 |
+
|
292 |
+
|
293 |
+
class AttentionBlock(nn.Module):
|
294 |
+
"""
|
295 |
+
An attention block that allows spatial positions to attend to each other.
|
296 |
+
Originally ported from here, but adapted to the N-d case.
|
297 |
+
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
298 |
+
"""
|
299 |
+
|
300 |
+
def __init__(
|
301 |
+
self,
|
302 |
+
channels,
|
303 |
+
num_heads=1,
|
304 |
+
num_head_channels=-1,
|
305 |
+
use_checkpoint=False,
|
306 |
+
use_new_attention_order=False,
|
307 |
+
):
|
308 |
+
super().__init__()
|
309 |
+
self.channels = channels
|
310 |
+
if num_head_channels == -1:
|
311 |
+
self.num_heads = num_heads
|
312 |
+
else:
|
313 |
+
assert (
|
314 |
+
channels % num_head_channels == 0
|
315 |
+
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
316 |
+
self.num_heads = channels // num_head_channels
|
317 |
+
self.use_checkpoint = use_checkpoint
|
318 |
+
self.norm = normalization(channels)
|
319 |
+
self.qkv = conv_nd(1, channels, channels * 3, 1)
|
320 |
+
if use_new_attention_order:
|
321 |
+
# split qkv before split heads
|
322 |
+
self.attention = QKVAttention(self.num_heads)
|
323 |
+
else:
|
324 |
+
# split heads before split qkv
|
325 |
+
self.attention = QKVAttentionLegacy(self.num_heads)
|
326 |
+
|
327 |
+
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
|
328 |
+
|
329 |
+
def forward(self, x):
|
330 |
+
return checkpoint(
|
331 |
+
self._forward, (x,), self.parameters(), True
|
332 |
+
) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
|
333 |
+
# return pt_checkpoint(self._forward, x) # pytorch
|
334 |
+
|
335 |
+
def _forward(self, x):
|
336 |
+
b, c, *spatial = x.shape
|
337 |
+
x = x.reshape(b, c, -1)
|
338 |
+
qkv = self.qkv(self.norm(x))
|
339 |
+
h = self.attention(qkv)
|
340 |
+
h = self.proj_out(h)
|
341 |
+
return (x + h).reshape(b, c, *spatial)
|
342 |
+
|
343 |
+
|
344 |
+
def count_flops_attn(model, _x, y):
|
345 |
+
"""
|
346 |
+
A counter for the `thop` package to count the operations in an
|
347 |
+
attention operation.
|
348 |
+
Meant to be used like:
|
349 |
+
macs, params = thop.profile(
|
350 |
+
model,
|
351 |
+
inputs=(inputs, timestamps),
|
352 |
+
custom_ops={QKVAttention: QKVAttention.count_flops},
|
353 |
+
)
|
354 |
+
"""
|
355 |
+
b, c, *spatial = y[0].shape
|
356 |
+
num_spatial = int(np.prod(spatial))
|
357 |
+
# We perform two matmuls with the same number of ops.
|
358 |
+
# The first computes the weight matrix, the second computes
|
359 |
+
# the combination of the value vectors.
|
360 |
+
matmul_ops = 2 * b * (num_spatial**2) * c
|
361 |
+
model.total_ops += th.DoubleTensor([matmul_ops])
|
362 |
+
|
363 |
+
|
364 |
+
class QKVAttentionLegacy(nn.Module):
|
365 |
+
"""
|
366 |
+
A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
|
367 |
+
"""
|
368 |
+
|
369 |
+
def __init__(self, n_heads):
|
370 |
+
super().__init__()
|
371 |
+
self.n_heads = n_heads
|
372 |
+
|
373 |
+
def forward(self, qkv):
|
374 |
+
"""
|
375 |
+
Apply QKV attention.
|
376 |
+
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
|
377 |
+
:return: an [N x (H * C) x T] tensor after attention.
|
378 |
+
"""
|
379 |
+
bs, width, length = qkv.shape
|
380 |
+
assert width % (3 * self.n_heads) == 0
|
381 |
+
ch = width // (3 * self.n_heads)
|
382 |
+
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
383 |
+
scale = 1 / math.sqrt(math.sqrt(ch))
|
384 |
+
weight = th.einsum(
|
385 |
+
"bct,bcs->bts", q * scale, k * scale
|
386 |
+
) # More stable with f16 than dividing afterwards
|
387 |
+
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
388 |
+
a = th.einsum("bts,bcs->bct", weight, v)
|
389 |
+
return a.reshape(bs, -1, length)
|
390 |
+
|
391 |
+
@staticmethod
|
392 |
+
def count_flops(model, _x, y):
|
393 |
+
return count_flops_attn(model, _x, y)
|
394 |
+
|
395 |
+
|
396 |
+
class QKVAttention(nn.Module):
|
397 |
+
"""
|
398 |
+
A module which performs QKV attention and splits in a different order.
|
399 |
+
"""
|
400 |
+
|
401 |
+
def __init__(self, n_heads):
|
402 |
+
super().__init__()
|
403 |
+
self.n_heads = n_heads
|
404 |
+
|
405 |
+
def forward(self, qkv):
|
406 |
+
"""
|
407 |
+
Apply QKV attention.
|
408 |
+
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
|
409 |
+
:return: an [N x (H * C) x T] tensor after attention.
|
410 |
+
"""
|
411 |
+
bs, width, length = qkv.shape
|
412 |
+
assert width % (3 * self.n_heads) == 0
|
413 |
+
ch = width // (3 * self.n_heads)
|
414 |
+
q, k, v = qkv.chunk(3, dim=1)
|
415 |
+
scale = 1 / math.sqrt(math.sqrt(ch))
|
416 |
+
weight = th.einsum(
|
417 |
+
"bct,bcs->bts",
|
418 |
+
(q * scale).view(bs * self.n_heads, ch, length),
|
419 |
+
(k * scale).view(bs * self.n_heads, ch, length),
|
420 |
+
) # More stable with f16 than dividing afterwards
|
421 |
+
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
422 |
+
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
|
423 |
+
return a.reshape(bs, -1, length)
|
424 |
+
|
425 |
+
@staticmethod
|
426 |
+
def count_flops(model, _x, y):
|
427 |
+
return count_flops_attn(model, _x, y)
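A hypothetical shape sketch contrasting the two attention modules (N=2, H=4 heads, C=16 channels per head, T=64 tokens are assumptions): both map a packed qkv tensor of width 3*H*C to H*C, differing only in whether heads are split before or after the q/k/v split.
qkv = th.randn(2, 3 * 4 * 16, 64)
a_legacy = QKVAttentionLegacy(4)(qkv)  # -> [2, 64, 64]
a_new = QKVAttention(4)(qkv)           # -> [2, 64, 64]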
|
428 |
+
|
429 |
+
|
430 |
+
class Timestep(nn.Module):
|
431 |
+
def __init__(self, dim):
|
432 |
+
super().__init__()
|
433 |
+
self.dim = dim
|
434 |
+
|
435 |
+
def forward(self, t):
|
436 |
+
return timestep_embedding(t, self.dim)
|
437 |
+
|
438 |
+
|
439 |
+
class MultiViewUNetModel(nn.Module):
|
440 |
+
"""
|
441 |
+
The full multi-view UNet model with attention, timestep embedding and camera embedding.
|
442 |
+
:param in_channels: channels in the input Tensor.
|
443 |
+
:param model_channels: base channel count for the model.
|
444 |
+
:param out_channels: channels in the output Tensor.
|
445 |
+
:param num_res_blocks: number of residual blocks per downsample.
|
446 |
+
:param attention_resolutions: a collection of downsample rates at which
|
447 |
+
attention will take place. May be a set, list, or tuple.
|
448 |
+
For example, if this contains 4, then at 4x downsampling, attention
|
449 |
+
will be used.
|
450 |
+
:param dropout: the dropout probability.
|
451 |
+
:param channel_mult: channel multiplier for each level of the UNet.
|
452 |
+
:param conv_resample: if True, use learned convolutions for upsampling and
|
453 |
+
downsampling.
|
454 |
+
:param dims: determines if the signal is 1D, 2D, or 3D.
|
455 |
+
:param num_classes: if specified (as an int), then this model will be
|
456 |
+
class-conditional with `num_classes` classes.
|
457 |
+
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
|
458 |
+
:param num_heads: the number of attention heads in each attention layer.
|
459 |
+
:param num_heads_channels: if specified, ignore num_heads and instead use
|
460 |
+
a fixed channel width per attention head.
|
461 |
+
:param num_heads_upsample: works with num_heads to set a different number
|
462 |
+
of heads for upsampling. Deprecated.
|
463 |
+
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
|
464 |
+
:param resblock_updown: use residual blocks for up/downsampling.
|
465 |
+
:param use_new_attention_order: use a different attention pattern for potentially
|
466 |
+
increased efficiency.
|
467 |
+
:param camera_dim: dimensionality of camera input.
|
468 |
+
"""
|
469 |
+
|
470 |
+
def __init__(
|
471 |
+
self,
|
472 |
+
image_size,
|
473 |
+
in_channels,
|
474 |
+
model_channels,
|
475 |
+
out_channels,
|
476 |
+
num_res_blocks,
|
477 |
+
attention_resolutions,
|
478 |
+
dropout=0,
|
479 |
+
channel_mult=(1, 2, 4, 8),
|
480 |
+
conv_resample=True,
|
481 |
+
dims=2,
|
482 |
+
num_classes=None,
|
483 |
+
use_checkpoint=False,
|
484 |
+
use_fp16=False,
|
485 |
+
use_bf16=False,
|
486 |
+
num_heads=-1,
|
487 |
+
num_head_channels=-1,
|
488 |
+
num_heads_upsample=-1,
|
489 |
+
use_scale_shift_norm=False,
|
490 |
+
resblock_updown=False,
|
491 |
+
use_new_attention_order=False,
|
492 |
+
use_spatial_transformer=False, # custom transformer support
|
493 |
+
transformer_depth=1, # custom transformer support
|
494 |
+
context_dim=None, # custom transformer support
|
495 |
+
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
496 |
+
legacy=True,
|
497 |
+
disable_self_attentions=None,
|
498 |
+
num_attention_blocks=None,
|
499 |
+
disable_middle_self_attn=False,
|
500 |
+
use_linear_in_transformer=False,
|
501 |
+
adm_in_channels=None,
|
502 |
+
camera_dim=None,
|
503 |
+
with_ip=False, # whether to add image prompt images
|
504 |
+
ip_dim=0, # number of extra tokens: 4 for global, 16 for local
|
505 |
+
ip_weight=1.0, # weight for image prompt context
|
506 |
+
ip_mode="local_resample", # adaptor mode: "global" or "local_resample"
|
507 |
+
):
|
508 |
+
super().__init__()
|
509 |
+
if use_spatial_transformer:
|
510 |
+
assert (
|
511 |
+
context_dim is not None
|
512 |
+
), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
|
513 |
+
|
514 |
+
if context_dim is not None:
|
515 |
+
assert (
|
516 |
+
use_spatial_transformer
|
517 |
+
), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
|
518 |
+
from omegaconf.listconfig import ListConfig
|
519 |
+
|
520 |
+
if type(context_dim) == ListConfig:
|
521 |
+
context_dim = list(context_dim)
|
522 |
+
|
523 |
+
if num_heads_upsample == -1:
|
524 |
+
num_heads_upsample = num_heads
|
525 |
+
|
526 |
+
if num_heads == -1:
|
527 |
+
assert (
|
528 |
+
num_head_channels != -1
|
529 |
+
), "Either num_heads or num_head_channels has to be set"
|
530 |
+
|
531 |
+
if num_head_channels == -1:
|
532 |
+
assert (
|
533 |
+
num_heads != -1
|
534 |
+
), "Either num_heads or num_head_channels has to be set"
|
535 |
+
|
536 |
+
self.image_size = image_size
|
537 |
+
self.in_channels = in_channels
|
538 |
+
self.model_channels = model_channels
|
539 |
+
self.out_channels = out_channels
|
540 |
+
if isinstance(num_res_blocks, int):
|
541 |
+
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
542 |
+
else:
|
543 |
+
if len(num_res_blocks) != len(channel_mult):
|
544 |
+
raise ValueError(
|
545 |
+
"provide num_res_blocks either as an int (globally constant) or "
|
546 |
+
"as a list/tuple (per-level) with the same length as channel_mult"
|
547 |
+
)
|
548 |
+
self.num_res_blocks = num_res_blocks
|
549 |
+
if disable_self_attentions is not None:
|
550 |
+
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
551 |
+
assert len(disable_self_attentions) == len(channel_mult)
|
552 |
+
if num_attention_blocks is not None:
|
553 |
+
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
554 |
+
assert all(
|
555 |
+
map(
|
556 |
+
lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
|
557 |
+
range(len(num_attention_blocks)),
|
558 |
+
)
|
559 |
+
)
|
560 |
+
print(
|
561 |
+
f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
562 |
+
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
563 |
+
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
564 |
+
f"attention will still not be set."
|
565 |
+
)
|
566 |
+
|
567 |
+
self.attention_resolutions = attention_resolutions
|
568 |
+
self.dropout = dropout
|
569 |
+
self.channel_mult = channel_mult
|
570 |
+
self.conv_resample = conv_resample
|
571 |
+
self.num_classes = num_classes
|
572 |
+
self.use_checkpoint = use_checkpoint
|
573 |
+
self.dtype = th.float16 if use_fp16 else th.float32
|
574 |
+
self.dtype = th.bfloat16 if use_bf16 else self.dtype
|
575 |
+
self.num_heads = num_heads
|
576 |
+
self.num_head_channels = num_head_channels
|
577 |
+
self.num_heads_upsample = num_heads_upsample
|
578 |
+
self.predict_codebook_ids = n_embed is not None
|
579 |
+
|
580 |
+
self.with_ip = with_ip # whether an image prompt is provided
|
581 |
+
self.ip_dim = ip_dim # number of extra tokens: 4 for global, 16 for local
|
582 |
+
self.ip_weight = ip_weight
|
583 |
+
assert ip_mode in ["global", "local_resample"]
|
584 |
+
self.ip_mode = ip_mode # which mode of adaptor
|
585 |
+
|
586 |
+
time_embed_dim = model_channels * 4
|
587 |
+
self.time_embed = nn.Sequential(
|
588 |
+
linear(model_channels, time_embed_dim),
|
589 |
+
nn.SiLU(),
|
590 |
+
linear(time_embed_dim, time_embed_dim),
|
591 |
+
)
|
592 |
+
|
593 |
+
if camera_dim is not None:
|
594 |
+
time_embed_dim = model_channels * 4
|
595 |
+
self.camera_embed = nn.Sequential(
|
596 |
+
linear(camera_dim, time_embed_dim),
|
597 |
+
nn.SiLU(),
|
598 |
+
linear(time_embed_dim, time_embed_dim),
|
599 |
+
)
|
600 |
+
|
601 |
+
if self.num_classes is not None:
|
602 |
+
if isinstance(self.num_classes, int):
|
603 |
+
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
604 |
+
elif self.num_classes == "continuous":
|
605 |
+
print("setting up linear c_adm embedding layer")
|
606 |
+
self.label_emb = nn.Linear(1, time_embed_dim)
|
607 |
+
elif self.num_classes == "sequential":
|
608 |
+
assert adm_in_channels is not None
|
609 |
+
self.label_emb = nn.Sequential(
|
610 |
+
nn.Sequential(
|
611 |
+
linear(adm_in_channels, time_embed_dim),
|
612 |
+
nn.SiLU(),
|
613 |
+
linear(time_embed_dim, time_embed_dim),
|
614 |
+
)
|
615 |
+
)
|
616 |
+
else:
|
617 |
+
raise ValueError()
|
618 |
+
|
619 |
+
if self.with_ip and (context_dim is not None) and ip_dim > 0:
|
620 |
+
if self.ip_mode == "local_resample":
|
621 |
+
# ip-adapter-plus
|
622 |
+
hidden_dim = 1280
|
623 |
+
self.image_embed = Resampler(
|
624 |
+
dim=context_dim,
|
625 |
+
depth=4,
|
626 |
+
dim_head=64,
|
627 |
+
heads=12,
|
628 |
+
num_queries=ip_dim, # num token
|
629 |
+
embedding_dim=hidden_dim,
|
630 |
+
output_dim=context_dim,
|
631 |
+
ff_mult=4,
|
632 |
+
)
|
633 |
+
elif self.ip_mode == "global":
|
634 |
+
self.image_embed = ImageProjModel(
|
635 |
+
cross_attention_dim=context_dim,
|
636 |
+
clip_extra_context_tokens=ip_dim)
|
637 |
+
else:
|
638 |
+
raise ValueError(f"{self.ip_mode} is not supported")
|
639 |
+
|
640 |
+
self.input_blocks = nn.ModuleList(
|
641 |
+
[
|
642 |
+
TimestepEmbedSequential(
|
643 |
+
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
644 |
+
)
|
645 |
+
]
|
646 |
+
)
|
647 |
+
self._feature_size = model_channels
|
648 |
+
input_block_chans = [model_channels]
|
649 |
+
ch = model_channels
|
650 |
+
ds = 1
|
651 |
+
for level, mult in enumerate(channel_mult):
|
652 |
+
for nr in range(self.num_res_blocks[level]):
|
653 |
+
layers = [
|
654 |
+
ResBlock(
|
655 |
+
ch,
|
656 |
+
time_embed_dim,
|
657 |
+
dropout,
|
658 |
+
out_channels=mult * model_channels,
|
659 |
+
dims=dims,
|
660 |
+
use_checkpoint=use_checkpoint,
|
661 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
662 |
+
)
|
663 |
+
]
|
664 |
+
ch = mult * model_channels
|
665 |
+
if ds in attention_resolutions:
|
666 |
+
if num_head_channels == -1:
|
667 |
+
dim_head = ch // num_heads
|
668 |
+
else:
|
669 |
+
num_heads = ch // num_head_channels
|
670 |
+
dim_head = num_head_channels
|
671 |
+
if legacy:
|
672 |
+
# num_heads = 1
|
673 |
+
dim_head = (
|
674 |
+
ch // num_heads
|
675 |
+
if use_spatial_transformer
|
676 |
+
else num_head_channels
|
677 |
+
)
|
678 |
+
if exists(disable_self_attentions):
|
679 |
+
disabled_sa = disable_self_attentions[level]
|
680 |
+
else:
|
681 |
+
disabled_sa = False
|
682 |
+
|
683 |
+
if (
|
684 |
+
not exists(num_attention_blocks)
|
685 |
+
or nr < num_attention_blocks[level]
|
686 |
+
):
|
687 |
+
layers.append(
|
688 |
+
AttentionBlock(
|
689 |
+
ch,
|
690 |
+
use_checkpoint=use_checkpoint,
|
691 |
+
num_heads=num_heads,
|
692 |
+
num_head_channels=dim_head,
|
693 |
+
use_new_attention_order=use_new_attention_order,
|
694 |
+
)
|
695 |
+
if not use_spatial_transformer
|
696 |
+
else SpatialTransformer3D(
|
697 |
+
ch,
|
698 |
+
num_heads,
|
699 |
+
dim_head,
|
700 |
+
depth=transformer_depth,
|
701 |
+
context_dim=context_dim,
|
702 |
+
disable_self_attn=disabled_sa,
|
703 |
+
use_linear=use_linear_in_transformer,
|
704 |
+
use_checkpoint=use_checkpoint,
|
705 |
+
with_ip=self.with_ip,
|
706 |
+
ip_dim=self.ip_dim,
|
707 |
+
ip_weight=self.ip_weight
|
708 |
+
)
|
709 |
+
)
|
710 |
+
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
711 |
+
self._feature_size += ch
|
712 |
+
input_block_chans.append(ch)
|
713 |
+
|
714 |
+
if level != len(channel_mult) - 1:
|
715 |
+
out_ch = ch
|
716 |
+
self.input_blocks.append(
|
717 |
+
TimestepEmbedSequential(
|
718 |
+
ResBlock(
|
719 |
+
ch,
|
720 |
+
time_embed_dim,
|
721 |
+
dropout,
|
722 |
+
out_channels=out_ch,
|
723 |
+
dims=dims,
|
724 |
+
use_checkpoint=use_checkpoint,
|
725 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
726 |
+
down=True,
|
727 |
+
)
|
728 |
+
if resblock_updown
|
729 |
+
else Downsample(
|
730 |
+
ch, conv_resample, dims=dims, out_channels=out_ch
|
731 |
+
)
|
732 |
+
)
|
733 |
+
)
|
734 |
+
ch = out_ch
|
735 |
+
input_block_chans.append(ch)
|
736 |
+
ds *= 2
|
737 |
+
self._feature_size += ch
|
738 |
+
|
739 |
+
if num_head_channels == -1:
|
740 |
+
dim_head = ch // num_heads
|
741 |
+
else:
|
742 |
+
num_heads = ch // num_head_channels
|
743 |
+
dim_head = num_head_channels
|
744 |
+
if legacy:
|
745 |
+
# num_heads = 1
|
746 |
+
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
747 |
+
self.middle_block = TimestepEmbedSequential(
|
748 |
+
ResBlock(
|
749 |
+
ch,
|
750 |
+
time_embed_dim,
|
751 |
+
dropout,
|
752 |
+
dims=dims,
|
753 |
+
use_checkpoint=use_checkpoint,
|
754 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
755 |
+
),
|
756 |
+
AttentionBlock(
|
757 |
+
ch,
|
758 |
+
use_checkpoint=use_checkpoint,
|
759 |
+
num_heads=num_heads,
|
760 |
+
num_head_channels=dim_head,
|
761 |
+
use_new_attention_order=use_new_attention_order,
|
762 |
+
)
|
763 |
+
if not use_spatial_transformer
|
764 |
+
else SpatialTransformer3D( # always uses a self-attn
|
765 |
+
ch,
|
766 |
+
num_heads,
|
767 |
+
dim_head,
|
768 |
+
depth=transformer_depth,
|
769 |
+
context_dim=context_dim,
|
770 |
+
disable_self_attn=disable_middle_self_attn,
|
771 |
+
use_linear=use_linear_in_transformer,
|
772 |
+
use_checkpoint=use_checkpoint,
|
773 |
+
with_ip=self.with_ip,
|
774 |
+
ip_dim=self.ip_dim,
|
775 |
+
ip_weight=self.ip_weight
|
776 |
+
),
|
777 |
+
ResBlock(
|
778 |
+
ch,
|
779 |
+
time_embed_dim,
|
780 |
+
dropout,
|
781 |
+
dims=dims,
|
782 |
+
use_checkpoint=use_checkpoint,
|
783 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
784 |
+
),
|
785 |
+
)
|
786 |
+
self._feature_size += ch
|
787 |
+
|
788 |
+
self.output_blocks = nn.ModuleList([])
|
789 |
+
for level, mult in list(enumerate(channel_mult))[::-1]:
|
790 |
+
for i in range(self.num_res_blocks[level] + 1):
|
791 |
+
ich = input_block_chans.pop()
|
792 |
+
layers = [
|
793 |
+
ResBlock(
|
794 |
+
ch + ich,
|
795 |
+
time_embed_dim,
|
796 |
+
dropout,
|
797 |
+
out_channels=model_channels * mult,
|
798 |
+
dims=dims,
|
799 |
+
use_checkpoint=use_checkpoint,
|
800 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
801 |
+
)
|
802 |
+
]
|
803 |
+
ch = model_channels * mult
|
804 |
+
if ds in attention_resolutions:
|
805 |
+
if num_head_channels == -1:
|
806 |
+
dim_head = ch // num_heads
|
807 |
+
else:
|
808 |
+
num_heads = ch // num_head_channels
|
809 |
+
dim_head = num_head_channels
|
810 |
+
if legacy:
|
811 |
+
# num_heads = 1
|
812 |
+
dim_head = (
|
813 |
+
ch // num_heads
|
814 |
+
if use_spatial_transformer
|
815 |
+
else num_head_channels
|
816 |
+
)
|
817 |
+
if exists(disable_self_attentions):
|
818 |
+
disabled_sa = disable_self_attentions[level]
|
819 |
+
else:
|
820 |
+
disabled_sa = False
|
821 |
+
|
822 |
+
if (
|
823 |
+
not exists(num_attention_blocks)
|
824 |
+
or i < num_attention_blocks[level]
|
825 |
+
):
|
826 |
+
layers.append(
|
827 |
+
AttentionBlock(
|
828 |
+
ch,
|
829 |
+
use_checkpoint=use_checkpoint,
|
830 |
+
num_heads=num_heads_upsample,
|
831 |
+
num_head_channels=dim_head,
|
832 |
+
use_new_attention_order=use_new_attention_order,
|
833 |
+
)
|
834 |
+
if not use_spatial_transformer
|
835 |
+
else SpatialTransformer3D(
|
836 |
+
ch,
|
837 |
+
num_heads,
|
838 |
+
dim_head,
|
839 |
+
depth=transformer_depth,
|
840 |
+
context_dim=context_dim,
|
841 |
+
disable_self_attn=disabled_sa,
|
842 |
+
use_linear=use_linear_in_transformer,
|
843 |
+
use_checkpoint=use_checkpoint,
|
844 |
+
with_ip=self.with_ip,
|
845 |
+
ip_dim=self.ip_dim,
|
846 |
+
ip_weight=self.ip_weight
|
847 |
+
)
|
848 |
+
)
|
849 |
+
if level and i == self.num_res_blocks[level]:
|
850 |
+
out_ch = ch
|
851 |
+
layers.append(
|
852 |
+
ResBlock(
|
853 |
+
ch,
|
854 |
+
time_embed_dim,
|
855 |
+
dropout,
|
856 |
+
out_channels=out_ch,
|
857 |
+
dims=dims,
|
858 |
+
use_checkpoint=use_checkpoint,
|
859 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
860 |
+
up=True,
|
861 |
+
)
|
862 |
+
if resblock_updown
|
863 |
+
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
|
864 |
+
)
|
865 |
+
ds //= 2
|
866 |
+
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
867 |
+
self._feature_size += ch
|
868 |
+
|
869 |
+
self.out = nn.Sequential(
|
870 |
+
normalization(ch),
|
871 |
+
nn.SiLU(),
|
872 |
+
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
|
873 |
+
)
|
874 |
+
if self.predict_codebook_ids:
|
875 |
+
self.id_predictor = nn.Sequential(
|
876 |
+
normalization(ch),
|
877 |
+
conv_nd(dims, model_channels, n_embed, 1),
|
878 |
+
# nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
879 |
+
)
|
880 |
+
|
881 |
+
def convert_to_fp16(self):
|
882 |
+
"""
|
883 |
+
Convert the torso of the model to float16.
|
884 |
+
"""
|
885 |
+
self.input_blocks.apply(convert_module_to_f16)
|
886 |
+
self.middle_block.apply(convert_module_to_f16)
|
887 |
+
self.output_blocks.apply(convert_module_to_f16)
|
888 |
+
|
889 |
+
def convert_to_fp32(self):
|
890 |
+
"""
|
891 |
+
Convert the torso of the model to float32.
|
892 |
+
"""
|
893 |
+
self.input_blocks.apply(convert_module_to_f32)
|
894 |
+
self.middle_block.apply(convert_module_to_f32)
|
895 |
+
self.output_blocks.apply(convert_module_to_f32)
|
896 |
+
|
897 |
+
def forward(
|
898 |
+
self,
|
899 |
+
x,
|
900 |
+
timesteps=None,
|
901 |
+
context=None,
|
902 |
+
y=None,
|
903 |
+
camera=None,
|
904 |
+
num_frames=1,
|
905 |
+
**kwargs,
|
906 |
+
):
|
907 |
+
"""
|
908 |
+
Apply the model to an input batch.
|
909 |
+
:param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
|
910 |
+
:param timesteps: a 1-D batch of timesteps.
|
911 |
+
:param context: conditioning plugged in via cross-attention.
|
912 |
+
:param y: an [N] Tensor of labels, if class-conditional, default None.
|
913 |
+
:param num_frames: an integer indicating the number of frames for tensor reshaping.
|
914 |
+
:return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
|
915 |
+
"""
|
916 |
+
assert (
|
917 |
+
x.shape[0] % num_frames == 0
|
918 |
+
), "[UNet] input batch size must be divisible by num_frames!"
|
919 |
+
assert (y is not None) == (
|
920 |
+
self.num_classes is not None
|
921 |
+
), "must specify y if and only if the model is class-conditional"
|
922 |
+
|
923 |
+
hs = []
|
924 |
+
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # shape: torch.Size([B, 320]) mean: 0.18, std: 0.68, min: -1.00, max: 1.00
|
925 |
+
emb = self.time_embed(t_emb) # shape: torch.Size([B, 1280]) mean: 0.12, std: 0.57, min: -5.73, max: 6.51
|
926 |
+
|
927 |
+
if self.num_classes is not None:
|
928 |
+
assert y.shape[0] == x.shape[0]
|
929 |
+
emb = emb + self.label_emb(y)
|
930 |
+
|
931 |
+
# Add camera embeddings
|
932 |
+
if camera is not None:
|
933 |
+
assert camera.shape[0] == emb.shape[0]
|
934 |
+
# camera embed: shape: torch.Size([B, 1280]) mean: -0.02, std: 0.27, min: -7.23, max: 2.04
|
935 |
+
emb = emb + self.camera_embed(camera)
|
936 |
+
ip = kwargs.get("ip", None)
|
937 |
+
ip_img = kwargs.get("ip_img", None)
|
938 |
+
|
939 |
+
if ip_img is not None:
|
940 |
+
x[(num_frames-1)::num_frames, :, :, :] = ip_img
|
941 |
+
|
942 |
+
if ip is not None:
|
943 |
+
ip_emb = self.image_embed(ip) # shape: torch.Size([B, 16, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31
|
944 |
+
context = torch.cat((context, ip_emb), 1) # shape: torch.Size([B, 93, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31
|
945 |
+
|
946 |
+
h = x.type(self.dtype)
|
947 |
+
for module in self.input_blocks:
|
948 |
+
h = module(h, emb, context, num_frames=num_frames)
|
949 |
+
hs.append(h)
|
950 |
+
h = self.middle_block(h, emb, context, num_frames=num_frames)
|
951 |
+
for module in self.output_blocks:
|
952 |
+
h = th.cat([h, hs.pop()], dim=1)
|
953 |
+
h = module(h, emb, context, num_frames=num_frames)
|
954 |
+
h = h.type(x.dtype) # shape: torch.Size([10, 320, 32, 32]) mean: -0.67, std: 3.96, min: -42.74, max: 25.58
|
955 |
+
if self.predict_codebook_ids: # False
|
956 |
+
return self.id_predictor(h)
|
957 |
+
else:
|
958 |
+
return self.out(h) # shape: torch.Size([10, 4, 32, 32]) mean: -0.00, std: 0.91, min: -3.65, max: 3.93
|
959 |
+
|
960 |
+
|
961 |
+
|
962 |
+
|
963 |
+
class MultiViewUNetModelStage2(MultiViewUNetModel):
|
964 |
+
"""
|
965 |
+
The full multi-view UNet model with attention, timestep embedding and camera embedding.
|
966 |
+
:param in_channels: channels in the input Tensor.
|
967 |
+
:param model_channels: base channel count for the model.
|
968 |
+
:param out_channels: channels in the output Tensor.
|
969 |
+
:param num_res_blocks: number of residual blocks per downsample.
|
970 |
+
:param attention_resolutions: a collection of downsample rates at which
|
971 |
+
attention will take place. May be a set, list, or tuple.
|
972 |
+
For example, if this contains 4, then at 4x downsampling, attention
|
973 |
+
will be used.
|
974 |
+
:param dropout: the dropout probability.
|
975 |
+
:param channel_mult: channel multiplier for each level of the UNet.
|
976 |
+
:param conv_resample: if True, use learned convolutions for upsampling and
|
977 |
+
downsampling.
|
978 |
+
:param dims: determines if the signal is 1D, 2D, or 3D.
|
979 |
+
:param num_classes: if specified (as an int), then this model will be
|
980 |
+
class-conditional with `num_classes` classes.
|
981 |
+
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
|
982 |
+
:param num_heads: the number of attention heads in each attention layer.
|
983 |
+
:param num_heads_channels: if specified, ignore num_heads and instead use
|
984 |
+
a fixed channel width per attention head.
|
985 |
+
:param num_heads_upsample: works with num_heads to set a different number
|
986 |
+
of heads for upsampling. Deprecated.
|
987 |
+
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
|
988 |
+
:param resblock_updown: use residual blocks for up/downsampling.
|
989 |
+
:param use_new_attention_order: use a different attention pattern for potentially
|
990 |
+
increased efficiency.
|
991 |
+
:param camera_dim: dimensionality of camera input.
|
992 |
+
"""
|
993 |
+
|
994 |
+
def __init__(
|
995 |
+
self,
|
996 |
+
image_size,
|
997 |
+
in_channels,
|
998 |
+
model_channels,
|
999 |
+
out_channels,
|
1000 |
+
num_res_blocks,
|
1001 |
+
attention_resolutions,
|
1002 |
+
dropout=0,
|
1003 |
+
channel_mult=(1, 2, 4, 8),
|
1004 |
+
conv_resample=True,
|
1005 |
+
dims=2,
|
1006 |
+
num_classes=None,
|
1007 |
+
use_checkpoint=False,
|
1008 |
+
use_fp16=False,
|
1009 |
+
use_bf16=False,
|
1010 |
+
num_heads=-1,
|
1011 |
+
num_head_channels=-1,
|
1012 |
+
num_heads_upsample=-1,
|
1013 |
+
use_scale_shift_norm=False,
|
1014 |
+
resblock_updown=False,
|
1015 |
+
use_new_attention_order=False,
|
1016 |
+
use_spatial_transformer=False, # custom transformer support
|
1017 |
+
transformer_depth=1, # custom transformer support
|
1018 |
+
context_dim=None, # custom transformer support
|
1019 |
+
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
1020 |
+
legacy=True,
|
1021 |
+
disable_self_attentions=None,
|
1022 |
+
num_attention_blocks=None,
|
1023 |
+
disable_middle_self_attn=False,
|
1024 |
+
use_linear_in_transformer=False,
|
1025 |
+
adm_in_channels=None,
|
1026 |
+
camera_dim=None,
|
1027 |
+
with_ip=False, # wether add image prompt images
|
1028 |
+
ip_dim=0, # number of extra token, 4 for global 16 for local
|
1029 |
+
ip_weight=1.0, # weight for image prompt context
|
1030 |
+
ip_mode="local_resample", # which mode of adaptor, global or local
|
1031 |
+
):
|
1032 |
+
super().__init__(
|
1033 |
+
image_size,
|
1034 |
+
in_channels,
|
1035 |
+
model_channels,
|
1036 |
+
out_channels,
|
1037 |
+
num_res_blocks,
|
1038 |
+
attention_resolutions,
|
1039 |
+
dropout,
|
1040 |
+
channel_mult,
|
1041 |
+
conv_resample,
|
1042 |
+
dims,
|
1043 |
+
num_classes,
|
1044 |
+
use_checkpoint,
|
1045 |
+
use_fp16,
|
1046 |
+
use_bf16,
|
1047 |
+
num_heads,
|
1048 |
+
num_head_channels,
|
1049 |
+
num_heads_upsample,
|
1050 |
+
use_scale_shift_norm,
|
1051 |
+
resblock_updown,
|
1052 |
+
use_new_attention_order,
|
1053 |
+
use_spatial_transformer,
|
1054 |
+
transformer_depth,
|
1055 |
+
context_dim,
|
1056 |
+
n_embed,
|
1057 |
+
legacy,
|
1058 |
+
disable_self_attentions,
|
1059 |
+
num_attention_blocks,
|
1060 |
+
disable_middle_self_attn,
|
1061 |
+
use_linear_in_transformer,
|
1062 |
+
adm_in_channels,
|
1063 |
+
camera_dim,
|
1064 |
+
with_ip,
|
1065 |
+
ip_dim,
|
1066 |
+
ip_weight,
|
1067 |
+
ip_mode,
|
1068 |
+
)
|
1069 |
+
|
1070 |
+
def forward(
|
1071 |
+
self,
|
1072 |
+
x,
|
1073 |
+
timesteps=None,
|
1074 |
+
context=None,
|
1075 |
+
y=None,
|
1076 |
+
camera=None,
|
1077 |
+
num_frames=1,
|
1078 |
+
**kwargs,
|
1079 |
+
):
|
1080 |
+
"""
|
1081 |
+
Apply the model to an input batch.
|
1082 |
+
:param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
|
1083 |
+
:param timesteps: a 1-D batch of timesteps.
|
1084 |
+
:param context: a dict conditioning plugged in via crossattn
|
1085 |
+
:param y: an [N] Tensor of labels, if class-conditional, default None.
|
1086 |
+
:param num_frames: a integer indicating number of frames for tensor reshaping.
|
1087 |
+
:return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
|
1088 |
+
"""
|
1089 |
+
assert (
|
1090 |
+
x.shape[0] % num_frames == 0
|
1091 |
+
), "[UNet] input batch size must be dividable by num_frames!"
|
1092 |
+
assert (y is not None) == (
|
1093 |
+
self.num_classes is not None
|
1094 |
+
), "must specify y if and only if the model is class-conditional"
|
1095 |
+
|
1096 |
+
hs = []
|
1097 |
+
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # shape: torch.Size([B, 320]) mean: 0.18, std: 0.68, min: -1.00, max: 1.00
|
1098 |
+
emb = self.time_embed(t_emb) # shape: torch.Size([B, 1280]) mean: 0.12, std: 0.57, min: -5.73, max: 6.51
|
1099 |
+
|
1100 |
+
if self.num_classes is not None:
|
1101 |
+
assert y.shape[0] == x.shape[0]
|
1102 |
+
emb = emb + self.label_emb(y)
|
1103 |
+
|
1104 |
+
# Add camera embeddings
|
1105 |
+
if camera is not None:
|
1106 |
+
assert camera.shape[0] == emb.shape[0]
|
1107 |
+
# camera embed: shape: torch.Size([B, 1280]) mean: -0.02, std: 0.27, min: -7.23, max: 2.04
|
1108 |
+
emb = emb + self.camera_embed(camera)
|
1109 |
+
ip = kwargs.get("ip", None)
|
1110 |
+
ip_img = kwargs.get("ip_img", None)
|
1111 |
+
pixel_images = kwargs.get("pixel_images", None)
|
1112 |
+
|
1113 |
+
if ip_img is not None:
|
1114 |
+
x[(num_frames-1)::num_frames, :, :, :] = ip_img
|
1115 |
+
|
1116 |
+
x = torch.cat((x, pixel_images), dim=1)
|
1117 |
+
|
1118 |
+
if ip is not None:
|
1119 |
+
ip_emb = self.image_embed(ip) # shape: torch.Size([B, 16, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31
|
1120 |
+
context = torch.cat((context, ip_emb), 1) # shape: torch.Size([B, 93, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31
|
1121 |
+
|
1122 |
+
h = x.type(self.dtype)
|
1123 |
+
for module in self.input_blocks:
|
1124 |
+
h = module(h, emb, context, num_frames=num_frames)
|
1125 |
+
hs.append(h)
|
1126 |
+
h = self.middle_block(h, emb, context, num_frames=num_frames)
|
1127 |
+
for module in self.output_blocks:
|
1128 |
+
h = th.cat([h, hs.pop()], dim=1)
|
1129 |
+
h = module(h, emb, context, num_frames=num_frames)
|
1130 |
+
h = h.type(x.dtype) # shape: torch.Size([10, 320, 32, 32]) mean: -0.67, std: 3.96, min: -42.74, max: 25.58
|
1131 |
+
if self.predict_codebook_ids: # False
|
1132 |
+
return self.id_predictor(h)
|
1133 |
+
else:
|
1134 |
+
return self.out(h) # shape: torch.Size([10, 4, 32, 32]) mean: -0.00, std: 0.91, min: -3.65, max: 3.93
|
1135 |
+
|
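A minimal smoke-test sketch of the multi-view forward pass above, assuming an already constructed MultiViewUNetModel instance named `unet` with `model_channels=320`, `context_dim=1024` and `camera_dim=16` (these names and shapes are illustrative, not part of the diff):

    import torch

    # hypothetical batch: 2 objects x 4 views, 32x32 latents, 77 text tokens
    num_frames = 4
    x = torch.randn(2 * num_frames, 4, 32, 32)
    t = torch.randint(0, 1000, (2 * num_frames,))
    context = torch.randn(2 * num_frames, 77, 1024)
    camera = torch.randn(2 * num_frames, 16)  # flattened camera pose, matching camera_dim

    out = unet(x, timesteps=t, context=context, camera=camera, num_frames=num_frames)
    print(out.shape)  # expected: torch.Size([8, 4, 32, 32])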
imagedream/ldm/modules/diffusionmodules/util.py
ADDED
@@ -0,0 +1,353 @@
# adopted from
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# and
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
# and
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
#
# thanks!


import os
import math
import torch
import torch.nn as nn
import numpy as np
from einops import repeat
import importlib


def instantiate_from_config(config):
    if not "target" in config:
        if config == "__is_first_stage__":
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def make_beta_schedule(
    schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
    if schedule == "linear":
        betas = (
            torch.linspace(
                linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
            )
            ** 2
        )

    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        betas = np.clip(betas, a_min=0, a_max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(
            linear_start, linear_end, n_timestep, dtype=torch.float64
        )
    elif schedule == "sqrt":
        betas = (
            torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
            ** 0.5
        )
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()


def enforce_zero_terminal_snr(betas):
    betas = torch.tensor(betas) if not isinstance(betas, torch.Tensor) else betas
    # Convert betas to alphas_bar_sqrt
    alphas = 1 - betas
    alphas_bar = alphas.cumprod(0)
    alphas_bar_sqrt = alphas_bar.sqrt()
    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
    # Shift so last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T
    # Scale so first timestep is back to old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
    # Convert alphas_bar_sqrt to betas
    alphas_bar = alphas_bar_sqrt ** 2
    alphas = alphas_bar[1:] / alphas_bar[:-1]
    alphas = torch.cat([alphas_bar[0:1], alphas])
    betas = 1 - alphas
    return betas


def make_ddim_timesteps(
    ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
):
    if ddim_discr_method == "uniform":
        c = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
    elif ddim_discr_method == "quad":
        ddim_timesteps = (
            (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
        ).astype(int)
    else:
        raise NotImplementedError(
            f'There is no ddim discretization method called "{ddim_discr_method}"'
        )

    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1
    if verbose:
        print(f"Selected timesteps for ddim sampler: {steps_out}")
    return steps_out


def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    # select alphas for computing the variance schedule
    alphas = alphacums[ddim_timesteps]
    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())

    # according to the formula provided in https://arxiv.org/abs/2010.02502
    sigmas = eta * np.sqrt(
        (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
    )
    if verbose:
        print(
            f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
        )
        print(
            f"For the chosen value of eta, which is {eta}, "
            f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
        )
    return sigmas, alphas, alphas_prev


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)


def extract_into_tensor(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads


def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
    else:
        embedding = repeat(timesteps, "b -> b d", d=dim)
    # import pdb; pdb.set_trace()
    return embedding


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def normalization(channels):
    """
    Make a standard normalization layer.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)


# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


class HybridConditioner(nn.Module):
    def __init__(self, c_concat_config, c_crossattn_config):
        super().__init__()
        self.concat_conditioner = instantiate_from_config(c_concat_config)
        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)

    def forward(self, c_concat, c_crossattn):
        c_concat = self.concat_conditioner(c_concat)
        c_crossattn = self.crossattn_conditioner(c_crossattn)
        return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}


def noise_like(shape, device, repeat=False):
    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
        shape[0], *((1,) * (len(shape) - 1))
    )
    noise = lambda: torch.randn(shape, device=device)
    return repeat_noise() if repeat else noise()


# dummy replace
def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.half()
        if l.bias is not None:
            l.bias.data = l.bias.data.half()


def convert_module_to_f32(l):
    """
    Convert primitive modules to float32, undoing convert_module_to_f16().
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.float()
        if l.bias is not None:
            l.bias.data = l.bias.data.float()
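A small sanity-check sketch of the zero-terminal-SNR rescaling above, assuming the package is importable as `imagedream` (as the other modules in this diff do):

    from imagedream.ldm.modules.diffusionmodules.util import (
        make_beta_schedule,
        enforce_zero_terminal_snr,
    )

    betas = make_beta_schedule("linear", 1000)      # numpy array of 1000 betas
    betas_zsnr = enforce_zero_terminal_snr(betas)   # torch tensor, rescaled
    alphas_bar = (1 - betas_zsnr).cumprod(0)
    print(float(alphas_bar[-1]))                    # ~0.0: terminal SNR forced to zero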
imagedream/ldm/modules/distributions/__init__.py
ADDED
File without changes
|
imagedream/ldm/modules/distributions/distributions.py
ADDED
@@ -0,0 +1,102 @@
import torch
import numpy as np


class AbstractDistribution:
    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()


class DiracDistribution(AbstractDistribution):
    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(
                device=self.parameters.device
            )

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(
            device=self.parameters.device
        )
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.sum(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2, 3],
                )
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2, 3],
                )

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        return self.mean


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )
imagedream/ldm/modules/ema.py
ADDED
@@ -0,0 +1,86 @@
import torch
from torch import nn


class LitEma(nn.Module):
    def __init__(self, model, decay=0.9999, use_num_upates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        self.m_name2s_name = {}
        self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
        self.register_buffer(
            "num_updates",
            torch.tensor(0, dtype=torch.int)
            if use_num_upates
            else torch.tensor(-1, dtype=torch.int),
        )

        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                s_name = name.replace(".", "")
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def reset_num_updates(self):
        del self.num_updates
        self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))

    def forward(self, model):
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    shadow_params[sname].sub_(
                        one_minus_decay * (shadow_params[sname] - m_param[key])
                    )
                else:
                    assert not key in self.m_name2s_name

    def copy_to(self, model):
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert not key in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
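A minimal usage sketch for the LitEma module above, with a toy nn.Linear standing in for the diffusion model (illustrative only):

    import torch.nn as nn
    from imagedream.ldm.modules.ema import LitEma

    model = nn.Linear(8, 8)
    ema = LitEma(model, decay=0.9999)

    # inside the training loop, after optimizer.step():
    ema(model)                     # update the shadow (EMA) weights

    # evaluate with EMA weights, then restore the live weights
    ema.store(model.parameters())
    ema.copy_to(model)
    # ... run validation here ...
    ema.restore(model.parameters())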
imagedream/ldm/modules/encoders/__init__.py
ADDED
File without changes
|
imagedream/ldm/modules/encoders/modules.py
ADDED
@@ -0,0 +1,329 @@
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel

import numpy as np
import open_clip
from PIL import Image
from ...util import default, count_params


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class IdentityEncoder(AbstractEncoder):
    def encode(self, x):
        return x


class ClassEmbedder(nn.Module):
    def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)
        self.n_classes = n_classes
        self.ucg_rate = ucg_rate

    def forward(self, batch, key=None, disable_dropout=False):
        if key is None:
            key = self.key
        # this is for use in crossattn
        c = batch[key][:, None]
        if self.ucg_rate > 0.0 and not disable_dropout:
            mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
            c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
            c = c.long()
        c = self.embedding(c)
        return c

    def get_unconditional_conditioning(self, bs, device="cuda"):
        uc_class = (
            self.n_classes - 1
        )  # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
        uc = torch.ones((bs,), device=device) * uc_class
        uc = {self.key: uc}
        return uc


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


class FrozenT5Embedder(AbstractEncoder):
    """Uses the T5 transformer encoder for text"""

    def __init__(
        self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True
    ):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length  # TODO: typical value?
        if freeze:
            self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)

        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)


class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""

    LAYERS = ["last", "pooled", "hidden"]

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
        layer_idx=None,
    ):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        if layer == "hidden":
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(
            input_ids=tokens, output_hidden_states=self.layer == "hidden"
        )
        if self.layer == "last":
            z = outputs.last_hidden_state
        elif self.layer == "pooled":
            z = outputs.pooler_output[:, None, :]
        else:
            z = outputs.hidden_states[self.layer_idx]
        return z

    def encode(self, text):
        return self(text)


class FrozenOpenCLIPEmbedder(AbstractEncoder, nn.Module):
    """
    Uses the OpenCLIP transformer encoder for text
    """

    LAYERS = [
        # "pooled",
        "last",
        "penultimate",
    ]

    def __init__(
        self,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
        ip_mode=None
    ):
        """
        Args:
            ip_mode (str, optional): the image prompt processing mode. Defaults to None.
        """
        super().__init__()
        assert layer in self.LAYERS
        model, _, preprocess = open_clip.create_model_and_transforms(
            arch, device=torch.device("cpu"), pretrained=version
        )
        if ip_mode is None:
            del model.visual

        self.model = model
        self.preprocess = preprocess
        self.device = device
        self.max_length = max_length
        self.ip_mode = ip_mode
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        else:
            raise NotImplementedError()

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        tokens = open_clip.tokenize(text)
        z = self.encode_with_transformer(tokens.to(self.device))
        return z

    def forward_image(self, pil_image):
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        if isinstance(pil_image, torch.Tensor):
            pil_image = pil_image.cpu().numpy()
        if isinstance(pil_image, np.ndarray):
            if pil_image.ndim == 3:
                pil_image = pil_image[None, :, :, :]
            pil_image = [Image.fromarray(x) for x in pil_image]

        images = []
        for image in pil_image:
            images.append(self.preprocess(image).to(self.device))

        image = torch.stack(images, 0)  # to [b, 3, h, w]
        if self.ip_mode == "global":
            image_features = self.model.encode_image(image)
            image_features /= image_features.norm(dim=-1, keepdim=True)
        elif "local" in self.ip_mode:
            image_features = self.encode_image_with_transformer(image)

        return image_features  # b, l

    def encode_image_with_transformer(self, x):
        visual = self.model.visual
        x = visual.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

        # class embeddings and positional embeddings
        x = torch.cat(
            [visual.class_embedding.to(x.dtype)
             + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
             x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + visual.positional_embedding.to(x.dtype)

        # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
        # x = visual.patch_dropout(x)
        x = visual.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        hidden = self.image_transformer_forward(x)
        x = hidden[-2].permute(1, 0, 2)  # LND -> NLD
        return x

    def image_transformer_forward(self, x):
        encoder_states = ()
        trans = self.model.visual.transformer
        for r in trans.resblocks:
            if trans.grad_checkpointing and not torch.jit.is_scripting():
                # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
                x = checkpoint(r, x, None, None, None)
            else:
                x = r(x, attn_mask=None)
            encoder_states = encoder_states + (x,)
        return encoder_states

    def encode_with_transformer(self, text):
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        for i, r in enumerate(self.model.transformer.resblocks):
            if i == len(self.model.transformer.resblocks) - self.layer_idx:
                break
            if (
                self.model.transformer.grad_checkpointing
                and not torch.jit.is_scripting()
            ):
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        return self(text)


class FrozenCLIPT5Encoder(AbstractEncoder):
    def __init__(
        self,
        clip_version="openai/clip-vit-large-patch14",
        t5_version="google/t5-v1_1-xl",
        device="cuda",
        clip_max_length=77,
        t5_max_length=77,
    ):
        super().__init__()
        self.clip_encoder = FrozenCLIPEmbedder(
            clip_version, device, max_length=clip_max_length
        )
        self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
        print(
            f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
            f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params."
        )

    def encode(self, text):
        return self(text)

    def forward(self, text):
        clip_z = self.clip_encoder.encode(text)
        t5_z = self.t5_encoder.encode(text)
        return [clip_z, t5_z]
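A usage sketch for FrozenOpenCLIPEmbedder above. This assumes the pretrained ViT-H-14 / laion2b_s32b_b79k weights can be downloaded by open_clip; the shapes in the comments are what that architecture would produce and are illustrative:

    import torch
    from PIL import Image
    from imagedream.ldm.modules.encoders.modules import FrozenOpenCLIPEmbedder

    # text-only encoder: the visual tower is deleted when ip_mode is None
    encoder = FrozenOpenCLIPEmbedder(device="cpu", ip_mode=None)
    z = encoder.encode(["a photo of an astronaut"])   # roughly [1, 77, 1024] for ViT-H-14

    # "local" image-prompt mode keeps the visual tower and returns patch tokens
    ip_encoder = FrozenOpenCLIPEmbedder(device="cpu", ip_mode="local_resample")
    tokens = ip_encoder.forward_image(Image.new("RGB", (224, 224)))  # roughly [1, 257, 1280]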
imagedream/ldm/util.py
ADDED
@@ -0,0 +1,226 @@
import importlib

import random
import torch
import numpy as np
from collections import abc

import multiprocessing as mp
from threading import Thread
from queue import Queue

from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont


def log_txt_as_img(wh, xc, size=10):
    # wh a tuple of (width, height)
    # xc a list of captions to plot
    b = len(xc)
    txts = list()
    for bi in range(b):
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
        nc = int(40 * (wh[0] / 256))
        lines = "\n".join(
            xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
        )

        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            print("Can't encode string for logging. Skipping.")

        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
        txts.append(txt)
    txts = np.stack(txts)
    txts = torch.tensor(txts)
    return txts


def ismap(x):
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] > 3)


def isimage(x):
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def mean_flat(tensor):
    """
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def count_params(model, verbose=False):
    total_params = sum(p.numel() for p in model.parameters())
    if verbose:
        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
    return total_params


def instantiate_from_config(config):
    if not "target" in config:
        if config == "__is_first_stage__":
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    # import pdb; pdb.set_trace()
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    # import pdb; pdb.set_trace()
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
    # create dummy dataset instance

    # run prefetching
    if idx_to_fn:
        res = func(data, worker_id=idx)
    else:
        res = func(data)
    Q.put([idx, res])
    Q.put("Done")


def parallel_data_prefetch(
    func: callable,
    data,
    n_proc,
    target_data_type="ndarray",
    cpu_intensive=True,
    use_worker_id=False,
):
    # if target_data_type not in ["ndarray", "list"]:
    #     raise ValueError(
    #         "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
    #     )
    if isinstance(data, np.ndarray) and target_data_type == "list":
        raise ValueError("list expected but function got ndarray.")
    elif isinstance(data, abc.Iterable):
        if isinstance(data, dict):
            print(
                'WARNING: "data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
            )
            data = list(data.values())
        if target_data_type == "ndarray":
            data = np.asarray(data)
        else:
            data = list(data)
    else:
        raise TypeError(
            f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
        )

    if cpu_intensive:
        Q = mp.Queue(1000)
        proc = mp.Process
    else:
        Q = Queue(1000)
        proc = Thread
    # spawn processes
    if target_data_type == "ndarray":
        arguments = [
            [func, Q, part, i, use_worker_id]
            for i, part in enumerate(np.array_split(data, n_proc))
        ]
    else:
        step = (
            int(len(data) / n_proc + 1)
            if len(data) % n_proc != 0
            else int(len(data) / n_proc)
        )
        arguments = [
            [func, Q, part, i, use_worker_id]
            for i, part in enumerate(
                [data[i : i + step] for i in range(0, len(data), step)]
            )
        ]
    processes = []
    for i in range(n_proc):
        p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
        processes += [p]

    # start processes
    print("Start prefetching...")
    import time

    start = time.time()
    gather_res = [[] for _ in range(n_proc)]
    try:
        for p in processes:
            p.start()

        k = 0
        while k < n_proc:
            # get result
            res = Q.get()
            if res == "Done":
                k += 1
            else:
                gather_res[res[0]] = res[1]

    except Exception as e:
        print("Exception: ", e)
        for p in processes:
            p.terminate()

        raise e
    finally:
        for p in processes:
            p.join()
        print(f"Prefetching complete. [{time.time() - start} sec.]")

    if target_data_type == "ndarray":
        if not isinstance(gather_res[0], np.ndarray):
            return np.concatenate([np.asarray(r) for r in gather_res], axis=0)

        # order outputs
        return np.concatenate(gather_res, axis=0)
    elif target_data_type == "list":
        out = []
        for r in gather_res:
            out.extend(r)
        return out
    else:
        return gather_res


def set_seed(seed=None):
    random.seed(seed)
    np.random.seed(seed)
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)


def add_random_background(image, bg_color=None):
    bg_color = np.random.rand() * 255 if bg_color is None else bg_color
    image = np.array(image)
    rgb, alpha = image[..., :3], image[..., 3:]
    alpha = alpha.astype(np.float32) / 255.0
    image_new = rgb * alpha + bg_color * (1 - alpha)
    return Image.fromarray(image_new.astype(np.uint8))
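A short illustration of the instantiate_from_config helper above, using a stand-in torch.nn.Linear target; any importable class can be built from a {"target", "params"} dict, which is how the YAML configs in this repo wire models together:

    from imagedream.ldm.util import instantiate_from_config

    cfg = {"target": "torch.nn.Linear", "params": {"in_features": 4, "out_features": 2}}
    layer = instantiate_from_config(cfg)
    print(type(layer))  # <class 'torch.nn.Linear'>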
imagedream/model_zoo.py
ADDED
@@ -0,0 +1,64 @@
"""Utility functions to load pre-trained models more easily"""
import os
import pkg_resources
from omegaconf import OmegaConf

import torch
from huggingface_hub import hf_hub_download

from imagedream.ldm.util import instantiate_from_config


PRETRAINED_MODELS = {
    "sd-v2.1-base-4view-ipmv": {
        "config": "sd_v2_base_ipmv.yaml",
        "repo_id": "Peng-Wang/ImageDream",
        "filename": "sd-v2.1-base-4view-ipmv.pt",
    },
    "sd-v2.1-base-4view-ipmv-local": {
        "config": "sd_v2_base_ipmv_local.yaml",
        "repo_id": "Peng-Wang/ImageDream",
        "filename": "sd-v2.1-base-4view-ipmv-local.pt",
    },
}


def get_config_file(config_path):
    cfg_file = pkg_resources.resource_filename(
        "imagedream", os.path.join("configs", config_path)
    )
    if not os.path.exists(cfg_file):
        raise RuntimeError(f"Config {config_path} not available!")
    return cfg_file


def build_model(model_name, config_path=None, ckpt_path=None, cache_dir=None):
    if (config_path is not None) and (ckpt_path is not None):
        config = OmegaConf.load(config_path)
        model = instantiate_from_config(config.model)
        model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
        return model

    if not model_name in PRETRAINED_MODELS:
        raise RuntimeError(
            f"Model name {model_name} is not a pre-trained model. Available models are:\n- "
            + "\n- ".join(PRETRAINED_MODELS.keys())
        )
    model_info = PRETRAINED_MODELS[model_name]

    # Instantiate the model
    print(f"Loading model from config: {model_info['config']}")
    config_file = get_config_file(model_info["config"])
    config = OmegaConf.load(config_file)
    model = instantiate_from_config(config.model)

    # Load pre-trained checkpoint from huggingface
    if not ckpt_path:
        ckpt_path = hf_hub_download(
            repo_id=model_info["repo_id"],
            filename=model_info["filename"],
            cache_dir=cache_dir,
        )
        print(f"Loading model from cache file: {ckpt_path}")
    model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
    return model
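A usage sketch for build_model above; the registry name comes from PRETRAINED_MODELS, and the first call downloads the (large) checkpoint from the Hugging Face Hub:

    from imagedream.model_zoo import build_model

    model = build_model("sd-v2.1-base-4view-ipmv", cache_dir="./ckpts").eval()

    # or bypass the registry with an explicit config/checkpoint pair
    # model = build_model(None, config_path="sd_v2_base_ipmv.yaml", ckpt_path="model.pt")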
inference.py
ADDED
@@ -0,0 +1,91 @@
import numpy as np
import torch
import time
import nvdiffrast.torch as dr
from util.utils import get_tri
import tempfile
from mesh import Mesh
import zipfile


def generate3d(model, rgb, ccm, device):

    color_tri = torch.from_numpy(rgb) / 255
    xyz_tri = torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255
    color = color_tri.permute(2, 0, 1)
    xyz = xyz_tri.permute(2, 0, 1)

    def get_imgs(color):
        # color : [C, H, W*6]
        color_list = []
        color_list.append(color[:, :, 256 * 5 : 256 * (1 + 5)])
        for i in range(0, 5):
            color_list.append(color[:, :, 256 * i : 256 * (1 + i)])
        return torch.stack(color_list, dim=0)  # [6, C, H, W]

    triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)  # [1, 6, H, W, C]

    color = get_imgs(color)
    xyz = get_imgs(xyz)

    color = get_tri(color, dim=0, blender=True, scale=1).unsqueeze(0)
    xyz = get_tri(xyz, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)

    triplane = torch.cat([color, xyz], dim=1).to(device)
    # 3D visualize
    model.eval()
    glctx = dr.RasterizeCudaContext()

    if model.denoising:
        tnew = 20
        tnew = torch.randint(tnew, tnew + 1, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
        noise_new = torch.randn_like(triplane) * 0.5 + 0.5
        triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
        start_time = time.time()
        with torch.no_grad():
            triplane_feature2 = model.unet2(triplane, tnew)
        end_time = time.time()
        elapsed_time = end_time - start_time
        print(f"unet takes {elapsed_time}s")
    else:
        triplane_feature2 = model.unet2(triplane)

    with torch.no_grad():
        data_config = {
            'resolution': [1024, 1024],
            "triview_color": triplane_color.to(device),
        }

        verts, faces = model.decode(data_config, triplane_feature2)

        data_config['verts'] = verts[0]
        data_config['faces'] = faces

    from kiui.mesh_utils import clean_mesh
    verts, faces = clean_mesh(
        data_config['verts'].squeeze().cpu().numpy().astype(np.float32),
        data_config['faces'].squeeze().cpu().numpy().astype(np.int32),
        repair=False, remesh=False, remesh_size=0.005,
    )
    data_config['verts'] = torch.from_numpy(verts).cuda().contiguous()
    data_config['faces'] = torch.from_numpy(faces).cuda().contiguous()

    start_time = time.time()
    with torch.no_grad():
        mesh_path_obj = tempfile.NamedTemporaryFile(suffix="", delete=False).name
        model.export_mesh_wt_uv(glctx, data_config, mesh_path_obj, "", device, res=(1024, 1024), tri_fea_2=triplane_feature2)

        mesh = Mesh.load(mesh_path_obj + ".obj", bound=0.9, front_dir="+z")
        mesh_path_glb = tempfile.NamedTemporaryFile(suffix="", delete=False).name
        mesh.write(mesh_path_glb + ".glb")

        # mesh_obj2 = trimesh.load(mesh_path_glb + ".glb", file_type='glb')
        # mesh_path_obj2 = tempfile.NamedTemporaryFile(suffix="", delete=False).name
        # mesh_obj2.export(mesh_path_obj2 + ".obj")

        with zipfile.ZipFile(mesh_path_obj + '.zip', 'w') as myzip:
            myzip.write(mesh_path_obj + '.obj', mesh_path_obj.split("/")[-1] + '.obj')
            myzip.write(mesh_path_obj + '.png', mesh_path_obj.split("/")[-1] + '.png')
            myzip.write(mesh_path_obj + '.mtl', mesh_path_obj.split("/")[-1] + '.mtl')

    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"uv takes {elapsed_time}s")
    return mesh_path_glb + ".glb", mesh_path_obj + '.zip'

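A hedged call sketch for generate3d; the `crm_model` object and the two 256 x 1536 view strips below are placeholders for what the CRM pipeline produces upstream, so the actual call is left commented out:

import numpy as np
# from inference import generate3d

rgb_strip = np.zeros((256, 256 * 6, 3), dtype=np.uint8)  # six stitched RGB views (placeholder data)
ccm_strip = np.zeros((256, 256 * 6, 3), dtype=np.uint8)  # matching canonical coordinate maps (placeholder data)
# glb_path, zip_path = generate3d(crm_model, rgb_strip, ccm_strip, "cuda")
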
libs/base_utils.py
ADDED
@@ -0,0 +1,84 @@
import numpy as np
import cv2
import torch
from PIL import Image


def instantiate_from_config(config):
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def get_obj_from_str(string, reload=False):
    import importlib
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def tensor_detail(t):
    assert type(t) == torch.Tensor
    print(f"shape: {t.shape} mean: {t.mean():.2f}, std: {t.std():.2f}, min: {t.min():.2f}, max: {t.max():.2f}")


def drawRoundRec(draw, color, x, y, w, h, r):
    drawObject = draw

    '''Rounds'''
    drawObject.ellipse((x, y, x + r, y + r), fill=color)
    drawObject.ellipse((x + w - r, y, x + w, y + r), fill=color)
    drawObject.ellipse((x, y + h - r, x + r, y + h), fill=color)
    drawObject.ellipse((x + w - r, y + h - r, x + w, y + h), fill=color)

    '''rec.s'''
    drawObject.rectangle((x + r / 2, y, x + w - (r / 2), y + h), fill=color)
    drawObject.rectangle((x, y + r / 2, x + w, y + h - (r / 2)), fill=color)


def do_resize_content(original_image: Image, scale_rate):
    # resize the image content while retaining the original image size
    if scale_rate != 1:
        # Calculate the new size after rescaling
        new_size = tuple(int(dim * scale_rate) for dim in original_image.size)
        # Resize the image while maintaining the aspect ratio
        resized_image = original_image.resize(new_size)
        # Create a new image with the original size and a transparent background
        padded_image = Image.new("RGBA", original_image.size, (0, 0, 0, 0))
        paste_position = ((original_image.width - resized_image.width) // 2, (original_image.height - resized_image.height) // 2)
        padded_image.paste(resized_image, paste_position)
        return padded_image
    else:
        return original_image


def add_stroke(img, color=(255, 255, 255), stroke_radius=3):
    # color in R, G, B format
    if isinstance(img, Image.Image):
        assert img.mode == "RGBA"
        img = cv2.cvtColor(np.array(img), cv2.COLOR_RGBA2BGRA)
    else:
        assert img.shape[2] == 4
    gray = img[:, :, 3]
    ret, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
    contours, hierarchy = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    res = cv2.drawContours(img, contours, -1, tuple(color)[::-1] + (255,), stroke_radius)
    return Image.fromarray(cv2.cvtColor(res, cv2.COLOR_BGRA2RGBA))


def make_blob(image_size=(512, 512), sigma=0.2):
    """
    make a 2D blob image with:
    I(x, y) = 1 - \exp\left(-\frac{(x - W/2)^2 + (y - H/2)^2}{2 \sigma^2 H W}\right)
    """
    H, W = image_size
    x = np.arange(0, W, 1, float)
    y = np.arange(0, H, 1, float)
    x, y = np.meshgrid(x, y)
    x0 = W // 2
    y0 = H // 2
    img = 1 - np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2 * H * W))
    return (img * 255).astype(np.uint8)

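A short usage sketch for the two image helpers above; "input.png" and "output.png" are placeholder paths:

from PIL import Image
from libs.base_utils import do_resize_content, add_stroke

img = Image.open("input.png").convert("RGBA")
img = do_resize_content(img, scale_rate=0.8)  # shrink the content while keeping the canvas size
img = add_stroke(img, color=(255, 255, 255), stroke_radius=3)  # outline the alpha contour
img.save("output.png")
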
libs/sample.py
ADDED
@@ -0,0 +1,380 @@
import numpy as np
import torch
from imagedream.camera_utils import get_camera_for_index
from imagedream.ldm.util import set_seed, add_random_background
from libs.base_utils import do_resize_content
from imagedream.ldm.models.diffusion.ddim import DDIMSampler
from torchvision import transforms as T


class ImageDreamDiffusion:
    def __init__(
        self,
        model,
        device,
        dtype,
        mode,
        num_frames,
        camera_views,
        ref_position,
        random_background=False,
        offset_noise=False,
        resize_rate=1,
        image_size=256,
        seed=1234,
    ) -> None:
        assert mode in ["pixel", "local"]
        size = image_size
        self.seed = seed
        batch_size = max(4, num_frames)

        neg_texts = "uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear."
        uc = model.get_learned_conditioning([neg_texts]).to(device)
        sampler = DDIMSampler(model)

        # pre-compute camera matrices
        camera = [get_camera_for_index(i).squeeze() for i in camera_views]
        camera[ref_position] = torch.zeros_like(camera[ref_position])  # set ref camera to zero
        camera = torch.stack(camera)
        camera = camera.repeat(batch_size // num_frames, 1).to(device)

        self.image_transform = T.Compose(
            [
                T.Resize((size, size)),
                T.ToTensor(),
                T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        self.dtype = dtype
        self.ref_position = ref_position
        self.mode = mode
        self.random_background = random_background
        self.resize_rate = resize_rate
        self.num_frames = num_frames
        self.size = size
        self.device = device
        self.batch_size = batch_size
        self.model = model
        self.sampler = sampler
        self.uc = uc
        self.camera = camera
        self.offset_noise = offset_noise

    @staticmethod
    def i2i(
        model,
        image_size,
        prompt,
        uc,
        sampler,
        ip=None,
        step=20,
        scale=5.0,
        batch_size=8,
        ddim_eta=0.0,
        dtype=torch.float32,
        device="cuda",
        camera=None,
        num_frames=4,
        pixel_control=False,
        transform=None,
        offset_noise=False,
    ):
        """The function supports an additional image prompt.

        Args:
            model (_type_): the ImageDream model
            image_size (_type_): size of the diffusion output (standard 256)
            prompt (_type_): text prompt for the image (str)
            uc (_type_): unconditional vector (tensor in shape [1, 77, 1024])
            sampler (_type_): imagedream.ldm.models.diffusion.ddim.DDIMSampler
            ip (Image, optional): the image prompt. Defaults to None.
            step (int, optional): number of DDIM steps. Defaults to 20.
            scale (float, optional): classifier-free guidance scale. Defaults to 5.0.
            batch_size (int, optional): _description_. Defaults to 8.
            ddim_eta (float, optional): _description_. Defaults to 0.0.
            dtype (_type_, optional): _description_. Defaults to torch.float32.
            device (str, optional): _description_. Defaults to "cuda".
            camera (_type_, optional): camera info in tensor, shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00
            num_frames (int, optional): num of frames (views) to generate
            pixel_control: whether to use pixel conditioning. Defaults to False; True when using pixel mode.
            transform: Compose(
                Resize(size=(256, 256), interpolation=bilinear, max_size=None, antialias=warn)
                ToTensor()
                Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
            )
        """
        ip_raw = ip
        if not isinstance(prompt, list):
            prompt = [prompt]
        with torch.no_grad(), torch.autocast(device_type=torch.device(device).type, dtype=dtype):
            c = model.get_learned_conditioning(prompt).to(
                device
            )  # shape: torch.Size([1, 77, 1024]) mean: -0.17, std: 1.02, min: -7.50, max: 13.05
            c_ = {"context": c.repeat(batch_size, 1, 1)}  # batch_size
            uc_ = {"context": uc.repeat(batch_size, 1, 1)}

            if camera is not None:
                c_["camera"] = uc_["camera"] = (
                    camera  # shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00
                )
                c_["num_frames"] = uc_["num_frames"] = num_frames

            if ip is not None:
                ip_embed = model.get_learned_image_conditioning(ip).to(
                    device
                )  # shape: torch.Size([1, 257, 1280]) mean: 0.06, std: 0.53, min: -6.83, max: 11.12
                ip_ = ip_embed.repeat(batch_size, 1, 1)
                c_["ip"] = ip_
                uc_["ip"] = torch.zeros_like(ip_)

            if pixel_control:
                assert camera is not None
                ip = transform(ip).to(
                    device
                )  # shape: torch.Size([3, 256, 256]) mean: 0.33, std: 0.37, min: -1.00, max: 1.00
                ip_img = model.get_first_stage_encoding(
                    model.encode_first_stage(ip[None, :, :, :])
                )  # shape: torch.Size([1, 4, 32, 32]) mean: 0.23, std: 0.77, min: -4.42, max: 3.55
                c_["ip_img"] = ip_img
                uc_["ip_img"] = torch.zeros_like(ip_img)

            shape = [4, image_size // 8, image_size // 8]  # [4, 32, 32]
            if offset_noise:
                ref = transform(ip_raw).to(device)
                ref_latent = model.get_first_stage_encoding(model.encode_first_stage(ref[None, :, :, :]))
                ref_mean = ref_latent.mean(dim=(-1, -2), keepdim=True)
                time_steps = torch.randint(model.num_timesteps - 1, model.num_timesteps, (batch_size,), device=device)
                x_T = model.q_sample(torch.ones([batch_size] + shape, device=device) * ref_mean, time_steps)

            samples_ddim, _ = (
                sampler.sample(  # shape: torch.Size([5, 4, 32, 32]) mean: 0.29, std: 0.85, min: -3.38, max: 4.43
                    S=step,
                    conditioning=c_,
                    batch_size=batch_size,
                    shape=shape,
                    verbose=False,
                    unconditional_guidance_scale=scale,
                    unconditional_conditioning=uc_,
                    eta=ddim_eta,
                    x_T=x_T if offset_noise else None,
                )
            )

            x_sample = model.decode_first_stage(samples_ddim)
            x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
            x_sample = 255.0 * x_sample.permute(0, 2, 3, 1).cpu().numpy()

        return list(x_sample.astype(np.uint8))

    def diffuse(self, t, ip, n_test=2):
        set_seed(self.seed)
        ip = do_resize_content(ip, self.resize_rate)
        if self.random_background:
            ip = add_random_background(ip)

        images = []
        for _ in range(n_test):
            img = self.i2i(
                self.model,
                self.size,
                t,
                self.uc,
                self.sampler,
                ip=ip,
                step=50,
                scale=5,
                batch_size=self.batch_size,
                ddim_eta=0.0,
                dtype=self.dtype,
                device=self.device,
                camera=self.camera,
                num_frames=self.num_frames,
                pixel_control=(self.mode == "pixel"),
                transform=self.image_transform,
                offset_noise=self.offset_noise,
            )
            img = np.concatenate(img, 1)
            img = np.concatenate((img, ip.resize((self.size, self.size))), axis=1)
            images.append(img)
        set_seed()  # unset random and numpy seed
        return images


class ImageDreamDiffusionStage2:
    def __init__(
        self,
        model,
        device,
        dtype,
        num_frames,
        camera_views,
        ref_position,
        random_background=False,
        offset_noise=False,
        resize_rate=1,
        mode="pixel",
        image_size=256,
        seed=1234,
    ) -> None:
        assert mode in ["pixel", "local"]

        size = image_size
        self.seed = seed
        batch_size = max(4, num_frames)

        neg_texts = "uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear."
        uc = model.get_learned_conditioning([neg_texts]).to(device)
        sampler = DDIMSampler(model)

        # pre-compute camera matrices
        camera = [get_camera_for_index(i).squeeze() for i in camera_views]
        if ref_position is not None:
            camera[ref_position] = torch.zeros_like(camera[ref_position])  # set ref camera to zero
        camera = torch.stack(camera)
        camera = camera.repeat(batch_size // num_frames, 1).to(device)

        self.image_transform = T.Compose(
            [
                T.Resize((size, size)),
                T.ToTensor(),
                T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

        self.dtype = dtype
        self.mode = mode
        self.ref_position = ref_position
        self.random_background = random_background
        self.resize_rate = resize_rate
        self.num_frames = num_frames
        self.size = size
        self.device = device
        self.batch_size = batch_size
        self.model = model
        self.sampler = sampler
        self.uc = uc
        self.camera = camera
        self.offset_noise = offset_noise

    @staticmethod
    def i2iStage2(
        model,
        image_size,
        prompt,
        uc,
        sampler,
        pixel_images,
        ip=None,
        step=20,
        scale=5.0,
        batch_size=8,
        ddim_eta=0.0,
        dtype=torch.float32,
        device="cuda",
        camera=None,
        num_frames=4,
        pixel_control=False,
        transform=None,
        offset_noise=False,
    ):
        ip_raw = ip
        if not isinstance(prompt, list):
            prompt = [prompt]
        with torch.no_grad(), torch.autocast(device_type=torch.device(device).type, dtype=dtype):
            c = model.get_learned_conditioning(prompt).to(
                device
            )  # shape: torch.Size([1, 77, 1024]) mean: -0.17, std: 1.02, min: -7.50, max: 13.05
            c_ = {"context": c.repeat(batch_size, 1, 1)}  # batch_size
            uc_ = {"context": uc.repeat(batch_size, 1, 1)}

            if camera is not None:
                c_["camera"] = uc_["camera"] = (
                    camera  # shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00
                )
                c_["num_frames"] = uc_["num_frames"] = num_frames

            if ip is not None:
                ip_embed = model.get_learned_image_conditioning(ip).to(
                    device
                )  # shape: torch.Size([1, 257, 1280]) mean: 0.06, std: 0.53, min: -6.83, max: 11.12
                ip_ = ip_embed.repeat(batch_size, 1, 1)
                c_["ip"] = ip_
                uc_["ip"] = torch.zeros_like(ip_)

            if pixel_control:
                assert camera is not None

                transed_pixel_images = torch.stack([transform(i).to(device) for i in pixel_images])
                latent_pixel_images = model.get_first_stage_encoding(model.encode_first_stage(transed_pixel_images))

                c_["pixel_images"] = latent_pixel_images
                uc_["pixel_images"] = torch.zeros_like(latent_pixel_images)

            shape = [4, image_size // 8, image_size // 8]  # [4, 32, 32]
            if offset_noise:
                ref = transform(ip_raw).to(device)
                ref_latent = model.get_first_stage_encoding(model.encode_first_stage(ref[None, :, :, :]))
                ref_mean = ref_latent.mean(dim=(-1, -2), keepdim=True)
                time_steps = torch.randint(model.num_timesteps - 1, model.num_timesteps, (batch_size,), device=device)
                x_T = model.q_sample(torch.ones([batch_size] + shape, device=device) * ref_mean, time_steps)

            samples_ddim, _ = (
                sampler.sample(  # shape: torch.Size([5, 4, 32, 32]) mean: 0.29, std: 0.85, min: -3.38, max: 4.43
                    S=step,
                    conditioning=c_,
                    batch_size=batch_size,
                    shape=shape,
                    verbose=False,
                    unconditional_guidance_scale=scale,
                    unconditional_conditioning=uc_,
                    eta=ddim_eta,
                    x_T=x_T if offset_noise else None,
                )
            )
            x_sample = model.decode_first_stage(samples_ddim)
            x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
            x_sample = 255.0 * x_sample.permute(0, 2, 3, 1).cpu().numpy()

        return list(x_sample.astype(np.uint8))

    @torch.no_grad()
    def diffuse(self, t, ip, pixel_images, n_test=2):
        set_seed(self.seed)
        ip = do_resize_content(ip, self.resize_rate)
        pixel_images = [do_resize_content(i, self.resize_rate) for i in pixel_images]

        if self.random_background:
            bg_color = np.random.rand() * 255
            ip = add_random_background(ip, bg_color)
            pixel_images = [add_random_background(i, bg_color) for i in pixel_images]

        images = []
        for _ in range(n_test):
            img = self.i2iStage2(
                self.model,
                self.size,
                t,
                self.uc,
                self.sampler,
                pixel_images=pixel_images,
                ip=ip,
                step=50,
                scale=5,
                batch_size=self.batch_size,
                ddim_eta=0.0,
                dtype=self.dtype,
                device=self.device,
                camera=self.camera,
                num_frames=self.num_frames,
                pixel_control=(self.mode == "pixel"),
                transform=self.image_transform,
                offset_noise=self.offset_noise,
            )
            img = np.concatenate(img, 1)
            img = np.concatenate(
                (img, ip.resize((self.size, self.size)), *[i.resize((self.size, self.size)) for i in pixel_images]),
                axis=1,
            )
            images.append(img)
        set_seed()  # unset random and numpy seed
        return images

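An end-to-end sketch tying model_zoo and the stage-1 sampler together; the camera_views indices, ref_position, dtype and input path below are illustrative assumptions, not values taken from this file:

import torch
from PIL import Image
from imagedream.model_zoo import build_model
from libs.sample import ImageDreamDiffusion

device = "cuda"
model = build_model("sd-v2.1-base-4view-ipmv").to(device)
pipeline = ImageDreamDiffusion(
    model, device=device, dtype=torch.float16, mode="pixel",
    num_frames=5, camera_views=[0, 1, 2, 3, 4], ref_position=4,  # illustrative view layout
)
# views = pipeline.diffuse("a 3D asset", Image.open("input.png").convert("RGB"), n_test=1)
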
mesh.py
ADDED
@@ -0,0 +1,845 @@
import os
import cv2
import torch
import trimesh
import numpy as np

from kiui.op import safe_normalize, dot
from kiui.typing import *

class Mesh:
    """
    A torch-native trimesh class, with support for ``ply/obj/glb`` formats.

    Note:
        This class only supports one mesh with a single texture image (an albedo texture and a metallic-roughness texture).
    """
    def __init__(
        self,
        v: Optional[Tensor] = None,
        f: Optional[Tensor] = None,
        vn: Optional[Tensor] = None,
        fn: Optional[Tensor] = None,
        vt: Optional[Tensor] = None,
        ft: Optional[Tensor] = None,
        vc: Optional[Tensor] = None,  # vertex color
        albedo: Optional[Tensor] = None,
        metallicRoughness: Optional[Tensor] = None,
        device: Optional[torch.device] = None,
    ):
        """Init a mesh directly using all attributes.

        Args:
            v (Optional[Tensor]): vertices, float [N, 3]. Defaults to None.
            f (Optional[Tensor]): faces, int [M, 3]. Defaults to None.
            vn (Optional[Tensor]): vertex normals, float [N, 3]. Defaults to None.
            fn (Optional[Tensor]): faces for normals, int [M, 3]. Defaults to None.
            vt (Optional[Tensor]): vertex uv coordinates, float [N, 2]. Defaults to None.
            ft (Optional[Tensor]): faces for uvs, int [M, 3]. Defaults to None.
            vc (Optional[Tensor]): vertex colors, float [N, 3]. Defaults to None.
            albedo (Optional[Tensor]): albedo texture, float [H, W, 3], RGB format. Defaults to None.
            metallicRoughness (Optional[Tensor]): metallic-roughness texture, float [H, W, 3], metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1]. Defaults to None.
            device (Optional[torch.device]): torch device. Defaults to None.
        """
        self.device = device
        self.v = v
        self.vn = vn
        self.vt = vt
        self.f = f
        self.fn = fn
        self.ft = ft
        # will first see if there is vertex color to use
        self.vc = vc
        # only support a single albedo image
        self.albedo = albedo
        # pbr extension, metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1]
        # ref: https://registry.khronos.org/glTF/specs/2.0/glTF-2.0.html
        self.metallicRoughness = metallicRoughness

        self.ori_center = 0
        self.ori_scale = 1

    @classmethod
    def load(cls, path, resize=True, clean=False, renormal=True, retex=False, bound=0.9, front_dir='+z', **kwargs):
        """load mesh from path.

        Args:
            path (str): path to mesh file, supports ply, obj, glb.
            clean (bool, optional): perform mesh cleaning at load (e.g., merge close vertices). Defaults to False.
            resize (bool, optional): auto resize the mesh using ``bound`` into [-bound, bound]^3. Defaults to True.
            renormal (bool, optional): re-calc the vertex normals. Defaults to True.
            retex (bool, optional): re-calc the uv coordinates, will overwrite the existing uv coordinates. Defaults to False.
            bound (float, optional): bound to resize. Defaults to 0.9.
            front_dir (str, optional): front-view direction of the mesh, should be [+-][xyz][ 123]. Defaults to '+z'.
            device (torch.device, optional): torch device. Defaults to None.

        Note:
            a ``device`` keyword argument can be provided to specify the torch device.
            If it's not provided, we will try to use ``'cuda'`` as the device if it's available.

        Returns:
            Mesh: the loaded Mesh object.
        """
        # obj supports face uv
        if path.endswith(".obj"):
            mesh = cls.load_obj(path, **kwargs)
        # trimesh only supports vertex uv, but can load more formats
        else:
            mesh = cls.load_trimesh(path, **kwargs)

        # clean
        if clean:
            from kiui.mesh_utils import clean_mesh
            vertices = mesh.v.detach().cpu().numpy()
            triangles = mesh.f.detach().cpu().numpy()
            vertices, triangles = clean_mesh(vertices, triangles, remesh=False)
            mesh.v = torch.from_numpy(vertices).contiguous().float().to(mesh.device)
            mesh.f = torch.from_numpy(triangles).contiguous().int().to(mesh.device)

        print(f"[Mesh loading] v: {mesh.v.shape}, f: {mesh.f.shape}")
        # auto-normalize
        if resize:
            mesh.auto_size(bound=bound)
        # auto-fix normal
        if renormal or mesh.vn is None:
            mesh.auto_normal()
            print(f"[Mesh loading] vn: {mesh.vn.shape}, fn: {mesh.fn.shape}")
        # auto-fix texcoords
        if retex or (mesh.albedo is not None and mesh.vt is None):
            mesh.auto_uv(cache_path=path)
            print(f"[Mesh loading] vt: {mesh.vt.shape}, ft: {mesh.ft.shape}")

        # rotate front dir to +z
        if front_dir != "+z":
            # axis switch
            if "-z" in front_dir:
                T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, -1]], device=mesh.device, dtype=torch.float32)
            elif "+x" in front_dir:
                T = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32)
            elif "-x" in front_dir:
                T = torch.tensor([[0, 0, -1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32)
            elif "+y" in front_dir:
                T = torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]], device=mesh.device, dtype=torch.float32)
            elif "-y" in front_dir:
                T = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=mesh.device, dtype=torch.float32)
            else:
                T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
            # rotation (how many 90 degrees)
            if '1' in front_dir:
                T @= torch.tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
            elif '2' in front_dir:
                T @= torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
            elif '3' in front_dir:
                T @= torch.tensor([[0, 1, 0], [-1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
            mesh.v @= T
            mesh.vn @= T

        return mesh

    # load from obj file
    @classmethod
    def load_obj(cls, path, albedo_path=None, device=None):
        """load an ``obj`` mesh.

        Args:
            path (str): path to mesh.
            albedo_path (str, optional): path to the albedo texture image, will overwrite the existing texture path if specified in mtl. Defaults to None.
            device (torch.device, optional): torch device. Defaults to None.

        Note:
            We will try to read `mtl` path from `obj`, else we assume the file name is the same as `obj` but with `mtl` extension.
            The `usemtl` statement is ignored, and we only use the last material path in `mtl` file.

        Returns:
            Mesh: the loaded Mesh object.
        """
        assert os.path.splitext(path)[-1] == ".obj"

        mesh = cls()

        # device
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        mesh.device = device

        # load obj
        with open(path, "r") as f:
            lines = f.readlines()

        def parse_f_v(fv):
            # pass in a vertex term of a face, return {v, vt, vn} (-1 if not provided)
            # supported forms:
            # f v1 v2 v3
            # f v1/vt1 v2/vt2 v3/vt3
            # f v1/vt1/vn1 v2/vt2/vn2 v3/vt3/vn3
            # f v1//vn1 v2//vn2 v3//vn3
            xs = [int(x) - 1 if x != "" else -1 for x in fv.split("/")]
            xs.extend([-1] * (3 - len(xs)))
            return xs[0], xs[1], xs[2]

        vertices, texcoords, normals = [], [], []
        faces, tfaces, nfaces = [], [], []
        mtl_path = None

        for line in lines:
            split_line = line.split()
            # empty line
            if len(split_line) == 0:
                continue
            prefix = split_line[0].lower()
            # mtllib
            if prefix == "mtllib":
                mtl_path = split_line[1]
            # usemtl
            elif prefix == "usemtl":
                pass  # ignored
            # v/vn/vt
            elif prefix == "v":
                vertices.append([float(v) for v in split_line[1:]])
            elif prefix == "vn":
                normals.append([float(v) for v in split_line[1:]])
            elif prefix == "vt":
                val = [float(v) for v in split_line[1:]]
                texcoords.append([val[0], 1.0 - val[1]])
            elif prefix == "f":
                vs = split_line[1:]
                nv = len(vs)
                v0, t0, n0 = parse_f_v(vs[0])
                for i in range(nv - 2):  # triangulate (assume vertices are ordered)
                    v1, t1, n1 = parse_f_v(vs[i + 1])
                    v2, t2, n2 = parse_f_v(vs[i + 2])
                    faces.append([v0, v1, v2])
                    tfaces.append([t0, t1, t2])
                    nfaces.append([n0, n1, n2])

        mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)
        mesh.vt = (
            torch.tensor(texcoords, dtype=torch.float32, device=device)
            if len(texcoords) > 0
            else None
        )
        mesh.vn = (
            torch.tensor(normals, dtype=torch.float32, device=device)
            if len(normals) > 0
            else None
        )

        mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)
        mesh.ft = (
            torch.tensor(tfaces, dtype=torch.int32, device=device)
            if len(texcoords) > 0
            else None
        )
        mesh.fn = (
            torch.tensor(nfaces, dtype=torch.int32, device=device)
            if len(normals) > 0
            else None
        )

        # see if there is vertex color
        use_vertex_color = False
        if mesh.v.shape[1] == 6:
            use_vertex_color = True
            mesh.vc = mesh.v[:, 3:]
            mesh.v = mesh.v[:, :3]
            print(f"[load_obj] use vertex color: {mesh.vc.shape}")

        # try to load texture image
        if not use_vertex_color:
            # try to retrieve mtl file
            mtl_path_candidates = []
            if mtl_path is not None:
                mtl_path_candidates.append(mtl_path)
                mtl_path_candidates.append(os.path.join(os.path.dirname(path), mtl_path))
            mtl_path_candidates.append(path.replace(".obj", ".mtl"))

            mtl_path = None
            for candidate in mtl_path_candidates:
                if os.path.exists(candidate):
                    mtl_path = candidate
                    break

            # if albedo_path is not provided, try retrieve it from mtl
            metallic_path = None
            roughness_path = None
            if mtl_path is not None and albedo_path is None:
                with open(mtl_path, "r") as f:
                    lines = f.readlines()

                for line in lines:
                    split_line = line.split()
                    # empty line
                    if len(split_line) == 0:
                        continue
                    prefix = split_line[0]

                    if "map_Kd" in prefix:
                        # assume relative path!
                        albedo_path = os.path.join(os.path.dirname(path), split_line[1])
                        print(f"[load_obj] use texture from: {albedo_path}")
                    elif "map_Pm" in prefix:
                        metallic_path = os.path.join(os.path.dirname(path), split_line[1])
                    elif "map_Pr" in prefix:
                        roughness_path = os.path.join(os.path.dirname(path), split_line[1])

            # still not found albedo_path, or the path doesn't exist
            if albedo_path is None or not os.path.exists(albedo_path):
                # init an empty texture
                print(f"[load_obj] init empty albedo!")
                # albedo = np.random.rand(1024, 1024, 3).astype(np.float32)
                albedo = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5])  # default color
            else:
                albedo = cv2.imread(albedo_path, cv2.IMREAD_UNCHANGED)
                albedo = cv2.cvtColor(albedo, cv2.COLOR_BGR2RGB)
                albedo = albedo.astype(np.float32) / 255
                print(f"[load_obj] load texture: {albedo.shape}")

            mesh.albedo = torch.tensor(albedo, dtype=torch.float32, device=device)

            # try to load metallic and roughness
            if metallic_path is not None and roughness_path is not None:
                print(f"[load_obj] load metallicRoughness from: {metallic_path}, {roughness_path}")
                metallic = cv2.imread(metallic_path, cv2.IMREAD_UNCHANGED)
                metallic = metallic.astype(np.float32) / 255
                roughness = cv2.imread(roughness_path, cv2.IMREAD_UNCHANGED)
                roughness = roughness.astype(np.float32) / 255
                metallicRoughness = np.stack([np.zeros_like(metallic), roughness, metallic], axis=-1)

                mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous()

        return mesh

    @classmethod
    def load_trimesh(cls, path, device=None):
        """load a mesh using ``trimesh.load()``.

        Can load various formats like ``glb`` and serves as a fallback.

        Note:
            We will try to merge all meshes if the glb contains more than one,
            but **this may cause the texture to be lost**, since we only support one texture image!

        Args:
            path (str): path to the mesh file.
            device (torch.device, optional): torch device. Defaults to None.

        Returns:
            Mesh: the loaded Mesh object.
        """
        mesh = cls()

        # device
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        mesh.device = device

        # use trimesh to load ply/glb
        _data = trimesh.load(path)
        if isinstance(_data, trimesh.Scene):
            if len(_data.geometry) == 1:
                _mesh = list(_data.geometry.values())[0]
            else:
                print(f"[load_trimesh] concatenating {len(_data.geometry)} meshes.")
                _concat = []
                # loop the scene graph and apply transform to each mesh
                scene_graph = _data.graph.to_flattened()  # dict {name: {transform: 4x4 mat, geometry: str}}
                for k, v in scene_graph.items():
                    name = v['geometry']
                    if name in _data.geometry and isinstance(_data.geometry[name], trimesh.Trimesh):
                        transform = v['transform']
                        _concat.append(_data.geometry[name].apply_transform(transform))
                _mesh = trimesh.util.concatenate(_concat)
        else:
            _mesh = _data

        if _mesh.visual.kind == 'vertex':
            vertex_colors = _mesh.visual.vertex_colors
            vertex_colors = np.array(vertex_colors[..., :3]).astype(np.float32) / 255
            mesh.vc = torch.tensor(vertex_colors, dtype=torch.float32, device=device)
            print(f"[load_trimesh] use vertex color: {mesh.vc.shape}")
        elif _mesh.visual.kind == 'texture':
            _material = _mesh.visual.material
            if isinstance(_material, trimesh.visual.material.PBRMaterial):
                texture = np.array(_material.baseColorTexture).astype(np.float32) / 255
                # load metallicRoughness if present
                if _material.metallicRoughnessTexture is not None:
                    metallicRoughness = np.array(_material.metallicRoughnessTexture).astype(np.float32) / 255
                    mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous()
            elif isinstance(_material, trimesh.visual.material.SimpleMaterial):
                texture = np.array(_material.to_pbr().baseColorTexture).astype(np.float32) / 255
            else:
                raise NotImplementedError(f"material type {type(_material)} not supported!")
            mesh.albedo = torch.tensor(texture[..., :3], dtype=torch.float32, device=device).contiguous()
            print(f"[load_trimesh] load texture: {texture.shape}")
        else:
            texture = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5])
            mesh.albedo = torch.tensor(texture, dtype=torch.float32, device=device)
            print(f"[load_trimesh] failed to load texture.")

        vertices = _mesh.vertices

        try:
            texcoords = _mesh.visual.uv
            texcoords[:, 1] = 1 - texcoords[:, 1]
        except Exception as e:
            texcoords = None

        try:
            normals = _mesh.vertex_normals
        except Exception as e:
            normals = None

        # trimesh only support vertex uv...
        faces = tfaces = nfaces = _mesh.faces

        mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)
        mesh.vt = (
            torch.tensor(texcoords, dtype=torch.float32, device=device)
            if texcoords is not None
            else None
        )
        mesh.vn = (
            torch.tensor(normals, dtype=torch.float32, device=device)
            if normals is not None
            else None
        )

        mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)
        mesh.ft = (
            torch.tensor(tfaces, dtype=torch.int32, device=device)
            if texcoords is not None
            else None
        )
        mesh.fn = (
            torch.tensor(nfaces, dtype=torch.int32, device=device)
            if normals is not None
            else None
        )

        return mesh

    # sample surface (using trimesh)
    def sample_surface(self, count: int):
        """sample points on the surface of the mesh.

        Args:
            count (int): number of points to sample.

        Returns:
            torch.Tensor: the sampled points, float [count, 3].
        """
        _mesh = trimesh.Trimesh(vertices=self.v.detach().cpu().numpy(), faces=self.f.detach().cpu().numpy())
        points, face_idx = trimesh.sample.sample_surface(_mesh, count)
        points = torch.from_numpy(points).float().to(self.device)
        return points

    # aabb
    def aabb(self):
        """get the axis-aligned bounding box of the mesh.

        Returns:
            Tuple[torch.Tensor]: the min xyz and max xyz of the mesh.
        """
        return torch.min(self.v, dim=0).values, torch.max(self.v, dim=0).values

    # unit size
    @torch.no_grad()
    def auto_size(self, bound=0.9):
        """auto resize the mesh.

        Args:
            bound (float, optional): resizing into ``[-bound, bound]^3``. Defaults to 0.9.
        """
        vmin, vmax = self.aabb()
        self.ori_center = (vmax + vmin) / 2
        self.ori_scale = 2 * bound / torch.max(vmax - vmin).item()
        self.v = (self.v - self.ori_center) * self.ori_scale

    def auto_normal(self):
        """auto calculate the vertex normals.
        """
        i0, i1, i2 = self.f[:, 0].long(), self.f[:, 1].long(), self.f[:, 2].long()
        v0, v1, v2 = self.v[i0, :], self.v[i1, :], self.v[i2, :]

        face_normals = torch.cross(v1 - v0, v2 - v0)

        # Splat face normals to vertices
        vn = torch.zeros_like(self.v)
        vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
        vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
        vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)

        # Normalize, replace zero (degenerated) normals with some default value
        vn = torch.where(
            dot(vn, vn) > 1e-20,
            vn,
            torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device),
        )
        vn = safe_normalize(vn)

        self.vn = vn
        self.fn = self.f

    def auto_uv(self, cache_path=None, vmap=True):
        """auto calculate the uv coordinates.

        Args:
            cache_path (str, optional): path to save/load the uv cache as a npz file, this can avoid calculating uv every time when loading the same mesh, which is time-consuming. Defaults to None.
            vmap (bool, optional): remap vertices based on uv coordinates, so each v correspond to a unique vt (necessary for formats like gltf).
                Usually this will duplicate the vertices on the edge of uv atlas. Defaults to True.
        """
        # try to load cache
        if cache_path is not None:
            cache_path = os.path.splitext(cache_path)[0] + "_uv.npz"
        if cache_path is not None and os.path.exists(cache_path):
            data = np.load(cache_path)
            vt_np, ft_np, vmapping = data["vt"], data["ft"], data["vmapping"]
        else:
            import xatlas

            v_np = self.v.detach().cpu().numpy()
            f_np = self.f.detach().int().cpu().numpy()
            atlas = xatlas.Atlas()
            atlas.add_mesh(v_np, f_np)
            chart_options = xatlas.ChartOptions()
            # chart_options.max_iterations = 4
            atlas.generate(chart_options=chart_options)
            vmapping, ft_np, vt_np = atlas[0]  # [N], [M, 3], [N, 2]

            # save to cache
            if cache_path is not None:
                np.savez(cache_path, vt=vt_np, ft=ft_np, vmapping=vmapping)

        vt = torch.from_numpy(vt_np.astype(np.float32)).to(self.device)
        ft = torch.from_numpy(ft_np.astype(np.int32)).to(self.device)
        self.vt = vt
        self.ft = ft

        if vmap:
            vmapping = torch.from_numpy(vmapping.astype(np.int64)).long().to(self.device)
            self.align_v_to_vt(vmapping)

    def align_v_to_vt(self, vmapping=None):
        """remap v/f and vn/fn to vt/ft.

        Args:
            vmapping (np.ndarray, optional): the mapping relationship from f to ft. Defaults to None.
        """
        if vmapping is None:
            ft = self.ft.view(-1).long()
            f = self.f.view(-1).long()
            vmapping = torch.zeros(self.vt.shape[0], dtype=torch.long, device=self.device)
            vmapping[ft] = f  # scatter, randomly choose one if index is not unique

        self.v = self.v[vmapping]
        self.f = self.ft

        if self.vn is not None:
            self.vn = self.vn[vmapping]
            self.fn = self.ft

    def to(self, device):
        """move all tensor attributes to device.

        Args:
            device (torch.device): target device.

        Returns:
            Mesh: self.
        """
        self.device = device
        for name in ["v", "f", "vn", "fn", "vt", "ft", "albedo", "vc", "metallicRoughness"]:
            tensor = getattr(self, name)
            if tensor is not None:
                setattr(self, name, tensor.to(device))
        return self

    def write(self, path):
        """write the mesh to a path.

        Args:
            path (str): path to write, supports ply, obj and glb.
        """
        if path.endswith(".ply"):
            self.write_ply(path)
        elif path.endswith(".obj"):
            self.write_obj(path)
        elif path.endswith(".glb") or path.endswith(".gltf"):
            self.write_glb(path)
        else:
            raise NotImplementedError(f"format {path} not supported!")

    def write_ply(self, path):
        """write the mesh in ply format. Only for geometry!

        Args:
            path (str): path to write.
        """

        if self.albedo is not None:
            print(f'[WARN] ply format does not support exporting texture, will ignore!')

        v_np = self.v.detach().cpu().numpy()
        f_np = self.f.detach().cpu().numpy()

        _mesh = trimesh.Trimesh(vertices=v_np, faces=f_np)
        _mesh.export(path)


    def write_glb(self, path):
        """write the mesh in glb/gltf format.
        This will create a scene with a single mesh.

        Args:
            path (str): path to write.
        """

        # assert self.v.shape[0] == self.vn.shape[0] and self.v.shape[0] == self.vt.shape[0]
        if self.vt is not None and self.v.shape[0] != self.vt.shape[0]:
            self.align_v_to_vt()

        import pygltflib

        f_np = self.f.detach().cpu().numpy().astype(np.uint32)
        f_np_blob = f_np.flatten().tobytes()

        v_np = self.v.detach().cpu().numpy().astype(np.float32)
        v_np_blob = v_np.tobytes()

        blob = f_np_blob + v_np_blob
        byteOffset = len(blob)

        # base mesh
        gltf = pygltflib.GLTF2(
            scene=0,
            scenes=[pygltflib.Scene(nodes=[0])],
            nodes=[pygltflib.Node(mesh=0)],
            meshes=[pygltflib.Mesh(primitives=[pygltflib.Primitive(
                # indices to accessors (0 is triangles)
                attributes=pygltflib.Attributes(
                    POSITION=1,
                ),
                indices=0,
            )])],
            buffers=[
                pygltflib.Buffer(byteLength=len(f_np_blob) + len(v_np_blob))
            ],
            # buffer view (based on dtype)
            bufferViews=[
                # triangles; as flatten (element) array
                pygltflib.BufferView(
                    buffer=0,
                    byteLength=len(f_np_blob),
                    target=pygltflib.ELEMENT_ARRAY_BUFFER,  # GL_ELEMENT_ARRAY_BUFFER (34963)
                ),
                # positions; as vec3 array
                pygltflib.BufferView(
                    buffer=0,
                    byteOffset=len(f_np_blob),
                    byteLength=len(v_np_blob),
                    byteStride=12,  # vec3
                    target=pygltflib.ARRAY_BUFFER,  # GL_ARRAY_BUFFER (34962)
                ),
            ],
            accessors=[
                # 0 = triangles
                pygltflib.Accessor(
                    bufferView=0,
                    componentType=pygltflib.UNSIGNED_INT,  # GL_UNSIGNED_INT (5125)
                    count=f_np.size,
                    type=pygltflib.SCALAR,
                    max=[int(f_np.max())],
                    min=[int(f_np.min())],
                ),
                # 1 = positions
                pygltflib.Accessor(
                    bufferView=1,
                    componentType=pygltflib.FLOAT,  # GL_FLOAT (5126)
                    count=len(v_np),
                    type=pygltflib.VEC3,
                    max=v_np.max(axis=0).tolist(),
                    min=v_np.min(axis=0).tolist(),
                ),
            ],
        )

        # append texture info
        if self.vt is not None:

            vt_np = self.vt.detach().cpu().numpy().astype(np.float32)
            vt_np_blob = vt_np.tobytes()

            albedo = self.albedo.detach().cpu().numpy()
            albedo = (albedo * 255).astype(np.uint8)
            albedo = cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR)
            albedo_blob = cv2.imencode('.png', albedo)[1].tobytes()

            # update primitive
            gltf.meshes[0].primitives[0].attributes.TEXCOORD_0 = 2
            gltf.meshes[0].primitives[0].material = 0

            # update materials
            gltf.materials.append(pygltflib.Material(
                pbrMetallicRoughness=pygltflib.PbrMetallicRoughness(
                    baseColorTexture=pygltflib.TextureInfo(index=0, texCoord=0),
                    metallicFactor=0.0,
                    roughnessFactor=1.0,
                ),
                alphaMode=pygltflib.OPAQUE,
                alphaCutoff=None,
                doubleSided=True,
            ))

            gltf.textures.append(pygltflib.Texture(sampler=0, source=0))
            gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT))
            gltf.images.append(pygltflib.Image(bufferView=3, mimeType="image/png"))

            # update buffers
            gltf.bufferViews.append(
                # index = 2, texcoords; as vec2 array
                pygltflib.BufferView(
                    buffer=0,
                    byteOffset=byteOffset,
                    byteLength=len(vt_np_blob),
                    byteStride=8,  # vec2
                    target=pygltflib.ARRAY_BUFFER,
                )
            )

            gltf.accessors.append(
                # 2 = texcoords
                pygltflib.Accessor(
                    bufferView=2,
                    componentType=pygltflib.FLOAT,
                    count=len(vt_np),
                    type=pygltflib.VEC2,
                    max=vt_np.max(axis=0).tolist(),
                    min=vt_np.min(axis=0).tolist(),
                )
            )

            blob += vt_np_blob
            byteOffset += len(vt_np_blob)

            gltf.bufferViews.append(
                # index = 3, albedo texture; as none target
                pygltflib.BufferView(
                    buffer=0,
                    byteOffset=byteOffset,
                    byteLength=len(albedo_blob),
                )
            )

            blob += albedo_blob
            byteOffset += len(albedo_blob)

            gltf.buffers[0].byteLength = byteOffset

            # append metallic roughness
            if self.metallicRoughness is not None:
                metallicRoughness = self.metallicRoughness.detach().cpu().numpy()
                metallicRoughness = (metallicRoughness * 255).astype(np.uint8)
                metallicRoughness = cv2.cvtColor(metallicRoughness, cv2.COLOR_RGB2BGR)
                metallicRoughness_blob = cv2.imencode('.png', metallicRoughness)[1].tobytes()

                # update texture definition
                gltf.materials[0].pbrMetallicRoughness.metallicFactor = 1.0
                gltf.materials[0].pbrMetallicRoughness.roughnessFactor = 1.0
                gltf.materials[0].pbrMetallicRoughness.metallicRoughnessTexture = pygltflib.TextureInfo(index=1, texCoord=0)

                gltf.textures.append(pygltflib.Texture(sampler=1, source=1))
                gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT))
                gltf.images.append(pygltflib.Image(bufferView=4, mimeType="image/png"))

                # update buffers
                gltf.bufferViews.append(
                    # index = 4, metallicRoughness texture; as none target
                    pygltflib.BufferView(
                        buffer=0,
                        byteOffset=byteOffset,
                        byteLength=len(metallicRoughness_blob),
                    )
                )

                blob += metallicRoughness_blob
                byteOffset += len(metallicRoughness_blob)

                gltf.buffers[0].byteLength = byteOffset


        # set actual data
        gltf.set_binary_blob(blob)

        # glb = b"".join(gltf.save_to_bytes())
        gltf.save(path)


    def write_obj(self, path):
        """write the mesh in obj format. Will also write the texture and mtl files.

        Args:
            path (str): path to write.
        """

        mtl_path = path.replace(".obj", ".mtl")
        albedo_path = path.replace(".obj", "_albedo.png")
        metallic_path = path.replace(".obj", "_metallic.png")
        roughness_path = path.replace(".obj", "_roughness.png")

        v_np = self.v.detach().cpu().numpy()
        vt_np = self.vt.detach().cpu().numpy() if self.vt is not None else None
        vn_np = self.vn.detach().cpu().numpy() if self.vn is not None else None
        f_np = self.f.detach().cpu().numpy()
        ft_np = self.ft.detach().cpu().numpy() if self.ft is not None else None
        fn_np = self.fn.detach().cpu().numpy() if self.fn is not None else None

        with open(path, "w") as fp:
            fp.write(f"mtllib {os.path.basename(mtl_path)} \n")

            for v in v_np:
                fp.write(f"v {v[0]} {v[1]} {v[2]} \n")

            if vt_np is not None:
                for v in vt_np:
                    fp.write(f"vt {v[0]} {1 - v[1]} \n")

            if vn_np is not None:
                for v in vn_np:
                    fp.write(f"vn {v[0]} {v[1]} {v[2]} \n")

            fp.write(f"usemtl defaultMat \n")
            for i in range(len(f_np)):
                fp.write(
                    f'f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1 if ft_np is not None else ""}/{fn_np[i, 0] + 1 if fn_np is not None else ""} \
                             {f_np[i, 1] + 1}/{ft_np[i, 1] + 1 if ft_np is not None else ""}/{fn_np[i, 1] + 1 if fn_np is not None else ""} \
                             {f_np[i, 2] + 1}/{ft_np[i, 2] + 1 if ft_np is not None else ""}/{fn_np[i, 2] + 1 if fn_np is not None else ""} \n'
                )

        with open(mtl_path, "w") as fp:
            fp.write(f"newmtl defaultMat \n")
            fp.write(f"Ka 1 1 1 \n")
            fp.write(f"Kd 1 1 1 \n")
            fp.write(f"Ks 0 0 0 \n")
            fp.write(f"Tr 1 \n")
            fp.write(f"illum 1 \n")
            fp.write(f"Ns 0 \n")
            if self.albedo is not None:
                fp.write(f"map_Kd {os.path.basename(albedo_path)} \n")
            if self.metallicRoughness is not None:
                # ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering
                fp.write(f"map_Pm {os.path.basename(metallic_path)} \n")
                fp.write(f"map_Pr {os.path.basename(roughness_path)} \n")

        if self.albedo is not None:
            albedo = self.albedo.detach().cpu().numpy()
|
837 |
+
albedo = (albedo * 255).astype(np.uint8)
|
838 |
+
cv2.imwrite(albedo_path, cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR))
|
839 |
+
|
840 |
+
if self.metallicRoughness is not None:
|
841 |
+
metallicRoughness = self.metallicRoughness.detach().cpu().numpy()
|
842 |
+
metallicRoughness = (metallicRoughness * 255).astype(np.uint8)
|
843 |
+
cv2.imwrite(metallic_path, metallicRoughness[..., 2])
|
844 |
+
cv2.imwrite(roughness_path, metallicRoughness[..., 1])
|
845 |
+
|
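A minimal usage sketch for the exporters above, not part of the diff: it assumes `mesh` is an instance of the Mesh class defined earlier in mesh.py with its tensor attributes (v, f and optionally vt, ft, vn, fn, albedo, metallicRoughness) already populated, and the output path is hypothetical.

# Hypothetical call site: writes scene.obj plus scene.mtl and, when the texture
# attributes are set, scene_albedo.png / scene_metallic.png / scene_roughness.png.
mesh.write_obj("out/scene.obj")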
model/__init__.py
ADDED
@@ -0,0 +1 @@
from model.crm.model import CRM
model/archs/__init__.py
ADDED
File without changes
model/archs/decoders/__init__.py
ADDED
@@ -0,0 +1 @@
model/archs/decoders/shape_texture_net.py
ADDED
@@ -0,0 +1,62 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class TetTexNet(nn.Module):
    def __init__(self, plane_reso=64, padding=0.1, fea_concat=True):
        super().__init__()
        # self.c_dim = c_dim
        self.plane_reso = plane_reso
        self.padding = padding
        self.fea_concat = fea_concat

    def forward(self, rolled_out_feature, query):
        # rolled_out_feature: rolled-out triplane feature
        # query: queried xyz coordinates (should be scaled consistently to ptr cloud)

        plane_reso = self.plane_reso

        triplane_feature = dict()
        triplane_feature['xy'] = rolled_out_feature[:, :, :, 0: plane_reso]
        triplane_feature['yz'] = rolled_out_feature[:, :, :, plane_reso: 2 * plane_reso]
        triplane_feature['zx'] = rolled_out_feature[:, :, :, 2 * plane_reso:]

        query_feature_xy = self.sample_plane_feature(query, triplane_feature['xy'], 'xy')
        query_feature_yz = self.sample_plane_feature(query, triplane_feature['yz'], 'yz')
        query_feature_zx = self.sample_plane_feature(query, triplane_feature['zx'], 'zx')

        if self.fea_concat:
            query_feature = torch.cat((query_feature_xy, query_feature_yz, query_feature_zx), dim=1)
        else:
            query_feature = query_feature_xy + query_feature_yz + query_feature_zx

        output = query_feature.permute(0, 2, 1)

        return output

    # uses values from plane_feature and pixel locations from vgrid to interpolate feature
    def sample_plane_feature(self, query, plane_feature, plane):
        # CYF note:
        # for pretraining, query are uniformly sampled positions within [-scale, scale]
        # for training, query are essentially tetrahedra grid vertices, which are
        # also within [-scale, scale] in the current version!
        # xy range [-scale, scale]
        if plane == 'xy':
            xy = query[:, :, [0, 1]]
        elif plane == 'yz':
            xy = query[:, :, [1, 2]]
        elif plane == 'zx':
            xy = query[:, :, [2, 0]]
        else:
            raise ValueError("Error! Invalid plane type!")

        xy = xy[:, :, None].float()
        # not seem necessary to rescale the grid, because from
        # https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html,
        # it specifies sampling locations normalized by plane_feature's spatial dimension,
        # which is within [-scale, scale] as specified by encoder's calling of coordinate2index()
        vgrid = 1.0 * xy
        sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True, mode='bilinear').squeeze(-1)

        return sampled_feat
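For orientation, a small shape sketch of how TetTexNet is queried. It is not part of the diff; the batch size, channel count, and plane resolution are placeholder values, chosen so the rolled-out triplane has shape (B, C, reso, 3*reso), which is what the slicing in forward() expects.

import torch
from model.archs.decoders.shape_texture_net import TetTexNet

decoder = TetTexNet(plane_reso=64, fea_concat=True)
rolled_out = torch.randn(2, 32, 64, 64 * 3)   # (B, C, reso, 3*reso): xy | yz | zx planes side by side
query = torch.rand(2, 1024, 3) * 2 - 1        # (B, N, 3) query coordinates, roughly in [-1, 1]
feat = decoder(rolled_out, query)             # (B, N, 3*C) when fea_concat=True, else (B, N, C)
print(feat.shape)                             # torch.Size([2, 1024, 96])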
model/archs/mlp_head.py
ADDED
@@ -0,0 +1,40 @@
import torch.nn as nn
import torch.nn.functional as F


class SdfMlp(nn.Module):
    def __init__(self, input_dim, hidden_dim=512, bias=True):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim, bias=bias)
        self.fc3 = nn.Linear(hidden_dim, 4, bias=bias)

    def forward(self, input):
        x = F.relu(self.fc1(input))
        x = F.relu(self.fc2(x))
        out = self.fc3(x)
        return out


class RgbMlp(nn.Module):
    def __init__(self, input_dim, hidden_dim=512, bias=True):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim, bias=bias)
        self.fc3 = nn.Linear(hidden_dim, 3, bias=bias)

    def forward(self, input):
        x = F.relu(self.fc1(input))
        x = F.relu(self.fc2(x))
        out = self.fc3(x)

        return out
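A brief sketch, not from the diff, of how these heads are used downstream (model/crm/model.py splits the 4-channel SdfMlp output into an SDF value and a 3-channel deformation, and treats RgbMlp output as colors in roughly [-1, 1]). The feature width of 96 is an illustrative placeholder.

import torch
from model.archs.mlp_head import SdfMlp, RgbMlp

feat = torch.randn(1, 4096, 96)          # (B, N, C) per-vertex features from the decoder
sdf_head = SdfMlp(input_dim=96)          # 4 output channels: sdf + 3D deformation
rgb_head = RgbMlp(input_dim=96)          # 3 output channels: rgb, roughly in [-1, 1]

out = sdf_head(feat)                     # (1, 4096, 4)
pred_sdf, deformation = out[..., 0], out[..., 1:]
colors = rgb_head(feat)                  # (1, 4096, 3)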
model/archs/unet.py
ADDED
@@ -0,0 +1,53 @@
'''
Codes are from:
https://github.com/jaxony/unet-pytorch/blob/master/model.py
'''

import torch
import torch.nn as nn
from diffusers import UNet2DModel
import einops


class UNetPP(nn.Module):
    '''
    Wrapper for UNet in diffusers
    '''
    def __init__(self, in_channels):
        super(UNetPP, self).__init__()
        self.in_channels = in_channels
        self.unet = UNet2DModel(
            sample_size=[256, 256*3],
            in_channels=in_channels,
            out_channels=32,
            layers_per_block=2,
            block_out_channels=(64, 128, 128, 128*2, 128*2, 128*4, 128*4),
            down_block_types=(
                "DownBlock2D",
                "DownBlock2D",
                "DownBlock2D",
                "AttnDownBlock2D",
                "AttnDownBlock2D",
                "AttnDownBlock2D",
                "DownBlock2D",
            ),
            up_block_types=(
                "UpBlock2D",
                "AttnUpBlock2D",
                "AttnUpBlock2D",
                "AttnUpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
            ),
        )

        self.unet.enable_xformers_memory_efficient_attention()
        if in_channels > 12:
            self.learned_plane = torch.nn.parameter.Parameter(torch.zeros([1, in_channels - 12, 256, 256*3]))

    def forward(self, x, t=256):
        learned_plane = self.learned_plane
        if x.shape[1] < self.in_channels:
            learned_plane = einops.repeat(learned_plane, '1 C H W -> B C H W', B=x.shape[0]).to(x.device)
            x = torch.cat([x, learned_plane], dim=1)
        return self.unet(x, t).sample
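For reference, a hedged sketch of the tensor layout UNetPP expects. It is not part of the diff; because __init__ calls enable_xformers_memory_efficient_attention(), actually running it needs xformers and a CUDA device, and the shapes below are only illustrative.

import torch
from model.archs.unet import UNetPP

# in_channels > 12 adds a learned plane that forward() concatenates whenever the
# input carries fewer channels than the model expects.
unet = UNetPP(in_channels=20).cuda()
x = torch.randn(2, 12, 256, 256 * 3, device="cuda")  # rolled-out triplane input
y = unet(x)                                          # (2, 32, 256, 768)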
model/crm/model.py
ADDED
@@ -0,0 +1,213 @@
import torch.nn as nn
import torch
import torch.nn.functional as F

import numpy as np


from pathlib import Path
import cv2
import trimesh
import nvdiffrast.torch as dr

from model.archs.decoders.shape_texture_net import TetTexNet
from model.archs.unet import UNetPP
from util.renderer import Renderer
from model.archs.mlp_head import SdfMlp, RgbMlp
import xatlas


class Dummy:
    pass


class CRM(nn.Module):
    def __init__(self, specs):
        super(CRM, self).__init__()

        self.specs = specs
        # configs
        input_specs = specs["Input"]
        self.input = Dummy()
        self.input.scale = input_specs['scale']
        self.input.resolution = input_specs['resolution']
        self.tet_grid_size = input_specs['tet_grid_size']
        self.camera_angle_num = input_specs['camera_angle_num']

        self.arch = Dummy()
        self.arch.fea_concat = specs["ArchSpecs"]["fea_concat"]
        self.arch.mlp_bias = specs["ArchSpecs"]["mlp_bias"]

        self.dec = Dummy()
        self.dec.c_dim = specs["DecoderSpecs"]["c_dim"]
        self.dec.plane_resolution = specs["DecoderSpecs"]["plane_resolution"]

        self.geo_type = specs["Train"].get("geo_type", "flex")  # "dmtet" or "flex"

        self.unet2 = UNetPP(in_channels=self.dec.c_dim)

        mlp_chnl_s = 3 if self.arch.fea_concat else 1  # 3 for queried triplane feature concatenation
        self.decoder = TetTexNet(plane_reso=self.dec.plane_resolution, fea_concat=self.arch.fea_concat)

        if self.geo_type == "flex":
            self.weightMlp = nn.Sequential(
                nn.Linear(mlp_chnl_s * 32 * 8, 512),
                nn.SiLU(),
                nn.Linear(512, 21))

        self.sdfMlp = SdfMlp(mlp_chnl_s * 32, 512, bias=self.arch.mlp_bias)
        self.rgbMlp = RgbMlp(mlp_chnl_s * 32, 512, bias=self.arch.mlp_bias)
        self.renderer = Renderer(tet_grid_size=self.tet_grid_size, camera_angle_num=self.camera_angle_num,
                                 scale=self.input.scale, geo_type=self.geo_type)

        self.spob = True if specs['Pretrain']['mode'] is None else False  # whether to add sphere
        self.radius = specs['Pretrain']['radius']  # used when spob

        self.denoising = True
        from diffusers import DDIMScheduler
        self.scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="scheduler")

    def decode(self, data, triplane_feature2):
        if self.geo_type == "flex":
            tet_verts = self.renderer.flexicubes.verts.unsqueeze(0)
            tet_indices = self.renderer.flexicubes.indices

        dec_verts = self.decoder(triplane_feature2, tet_verts)
        out = self.sdfMlp(dec_verts)

        weight = None
        if self.geo_type == "flex":
            grid_feat = torch.index_select(input=dec_verts, index=self.renderer.flexicubes.indices.reshape(-1), dim=1)
            grid_feat = grid_feat.reshape(dec_verts.shape[0], self.renderer.flexicubes.indices.shape[0], self.renderer.flexicubes.indices.shape[1] * dec_verts.shape[-1])
            weight = self.weightMlp(grid_feat)
            weight = weight * 0.1

        pred_sdf, deformation = out[..., 0], out[..., 1:]
        if self.spob:
            pred_sdf = pred_sdf + self.radius - torch.sqrt((tet_verts**2).sum(-1))

        _, verts, faces = self.renderer(data, pred_sdf, deformation, tet_verts, tet_indices, weight=weight)
        return verts[0].unsqueeze(0), faces[0].int()

    def export_mesh(self, data, out_dir, ind, device=None, tri_fea_2=None):
        verts = data['verts']
        faces = data['faces']

        dec_verts = self.decoder(tri_fea_2, verts.unsqueeze(0))
        colors = self.rgbMlp(dec_verts).squeeze().detach().cpu().numpy()
        # Expect predicted colors value range from [-1, 1]
        colors = (colors * 0.5 + 0.5).clip(0, 1)

        verts = verts.squeeze().cpu().numpy()
        faces = faces[..., [2, 1, 0]].squeeze().cpu().numpy()

        # export the final mesh
        with torch.no_grad():
            mesh = trimesh.Trimesh(verts, faces, vertex_colors=colors, process=False)  # important, process=True leads to seg fault...
            mesh.export(out_dir / f'{ind}.obj')

    def export_mesh_wt_uv(self, ctx, data, out_dir, ind, device, res, tri_fea_2=None):

        mesh_v = data['verts'].squeeze().cpu().numpy()
        mesh_pos_idx = data['faces'].squeeze().cpu().numpy()

        def interpolate(attr, rast, attr_idx, rast_db=None):
            return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db,
                                  diff_attrs=None if rast_db is None else 'all')

        vmapping, indices, uvs = xatlas.parametrize(mesh_v, mesh_pos_idx)

        mesh_v = torch.tensor(mesh_v, dtype=torch.float32, device=device)
        mesh_pos_idx = torch.tensor(mesh_pos_idx, dtype=torch.int64, device=device)

        # Convert to tensors
        indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)

        uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device)
        mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device)
        # map uv coordinates from [0, 1] to clip space [-1, 1]
        uv_clip = uvs[None, ...] * 2.0 - 1.0

        # pad to four component coordinate
        uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1)

        # rasterize
        rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), res)

        # Interpolate world space position
        gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int())
        mask = rast[..., 3:4] > 0

        # return uvs, mesh_tex_idx, gb_pos, mask
        gb_pos_unsqz = gb_pos.view(-1, 3)
        mask_unsqz = mask.view(-1)
        tex_unsqz = torch.zeros_like(gb_pos_unsqz) + 1

        gb_mask_pos = gb_pos_unsqz[mask_unsqz]

        gb_mask_pos = gb_mask_pos[None, ]

        with torch.no_grad():

            dec_verts = self.decoder(tri_fea_2, gb_mask_pos)
            colors = self.rgbMlp(dec_verts).squeeze()

        # Expect predicted colors value range from [-1, 1]
        lo, hi = (-1, 1)
        colors = (colors - lo) * (255 / (hi - lo))
        colors = colors.clip(0, 255)

        tex_unsqz[mask_unsqz] = colors

        tex = tex_unsqz.view(res + (3,))

        verts = mesh_v.squeeze().cpu().numpy()
        faces = mesh_pos_idx[..., [2, 1, 0]].squeeze().cpu().numpy()
        # faces = mesh_pos_idx
        # faces = faces.detach().cpu().numpy()
        # faces = faces[..., [2, 1, 0]]
        indices = indices[..., [2, 1, 0]]

        # xatlas.export(f"{out_dir}/{ind}.obj", verts[vmapping], indices, uvs)
        matname = f'{out_dir}.mtl'
        # matname = f'{out_dir}/{ind}.mtl'
        fid = open(matname, 'w')
        fid.write('newmtl material_0\n')
        fid.write('Kd 1 1 1\n')
        fid.write('Ka 1 1 1\n')
        # fid.write('Ks 0 0 0\n')
        fid.write('Ks 0.4 0.4 0.4\n')
        fid.write('Ns 10\n')
        fid.write('illum 2\n')
        fid.write(f'map_Kd {out_dir.split("/")[-1]}.png\n')
        fid.close()

        fid = open(f'{out_dir}.obj', 'w')
        # fid = open(f'{out_dir}/{ind}.obj', 'w')
        fid.write('mtllib %s.mtl\n' % out_dir.split("/")[-1])

        for pidx, p in enumerate(verts):
            pp = p
            fid.write('v %f %f %f\n' % (pp[0], pp[2], - pp[1]))

        for pidx, p in enumerate(uvs):
            pp = p
            fid.write('vt %f %f\n' % (pp[0], 1 - pp[1]))

        fid.write('usemtl material_0\n')
        for i, f in enumerate(faces):
            f1 = f + 1
            f2 = indices[i] + 1
            fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2]))
        fid.close()

        img = np.asarray(tex.data.cpu().numpy(), dtype=np.float32)
        mask = np.sum(img.astype(float), axis=-1, keepdims=True)
        mask = (mask <= 3.0).astype(float)
        kernel = np.ones((3, 3), 'uint8')
        dilate_img = cv2.dilate(img, kernel, iterations=1)
        img = img * (1 - mask) + dilate_img * mask
        img = img.clip(0, 255).astype(np.uint8)

        cv2.imwrite(f'{out_dir}.png', img[..., [2, 1, 0]])
        # cv2.imwrite(f'{out_dir}/{ind}.png', img[..., [2, 1, 0]])
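CRM.__init__ above reads only a few keys from specs. The sketch below is a placeholder illustration of that structure (the real values ship in configs/specs_objaverse_total.json); only the key names are implied by the code, the numbers are made up.

# Placeholder values; only the keys are implied by CRM.__init__ above.
specs = {
    "Input": {"scale": 0.95, "resolution": 256, "tet_grid_size": 80, "camera_angle_num": 4},
    "ArchSpecs": {"fea_concat": True, "mlp_bias": True},
    "DecoderSpecs": {"c_dim": 32, "plane_resolution": 256},
    "Train": {"geo_type": "flex"},              # "dmtet" or "flex"
    "Pretrain": {"mode": None, "radius": 0.5},  # mode=None enables the sphere prior (spob)
}
# model = CRM(specs)  # requires nvdiffrast and util.renderer.Renderer to be importable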
pipelines.py
ADDED
@@ -0,0 +1,170 @@
import torch
from libs.base_utils import do_resize_content
from imagedream.ldm.util import (
    instantiate_from_config,
    get_obj_from_str,
)
from omegaconf import OmegaConf
from PIL import Image
import numpy as np


class TwoStagePipeline(object):
    def __init__(
        self,
        stage1_model_config,
        stage2_model_config,
        stage1_sampler_config,
        stage2_sampler_config,
        device="cuda",
        dtype=torch.float16,
        resize_rate=1,
    ) -> None:
        """
        Only for the two-stage generation process.
        - the first stage is conditioned on a single pixel image and generates multi-view pixel images, based on the v2pp config
        - the second stage is conditioned on the multi-view pixel images generated by the first stage and generates the final images, based on the stage2-test config
        """
        self.resize_rate = resize_rate

        self.stage1_model = instantiate_from_config(OmegaConf.load(stage1_model_config.config).model)
        self.stage1_model.load_state_dict(torch.load(stage1_model_config.resume, map_location="cpu"), strict=False)
        self.stage1_model = self.stage1_model.to(device).to(dtype)

        self.stage2_model = instantiate_from_config(OmegaConf.load(stage2_model_config.config).model)
        sd = torch.load(stage2_model_config.resume, map_location="cpu")
        self.stage2_model.load_state_dict(sd, strict=False)
        self.stage2_model = self.stage2_model.to(device).to(dtype)

        self.stage1_model.device = device
        self.stage2_model.device = device
        self.device = device
        self.dtype = dtype
        self.stage1_sampler = get_obj_from_str(stage1_sampler_config.target)(
            self.stage1_model, device=device, dtype=dtype, **stage1_sampler_config.params
        )
        self.stage2_sampler = get_obj_from_str(stage2_sampler_config.target)(
            self.stage2_model, device=device, dtype=dtype, **stage2_sampler_config.params
        )

    def stage1_sample(
        self,
        pixel_img,
        prompt="3D assets",
        neg_texts="uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear.",
        step=50,
        scale=5,
        ddim_eta=0.0,
    ):
        if type(pixel_img) == str:
            pixel_img = Image.open(pixel_img)

        if isinstance(pixel_img, Image.Image):
            if pixel_img.mode == "RGBA":
                background = Image.new('RGBA', pixel_img.size, (0, 0, 0, 0))
                pixel_img = Image.alpha_composite(background, pixel_img).convert("RGB")
            else:
                pixel_img = pixel_img.convert("RGB")
        else:
            raise TypeError("pixel_img must be a file path or a PIL.Image.Image")
        uc = self.stage1_sampler.model.get_learned_conditioning([neg_texts]).to(self.device)
        stage1_images = self.stage1_sampler.i2i(
            self.stage1_sampler.model,
            self.stage1_sampler.size,
            prompt,
            uc=uc,
            sampler=self.stage1_sampler.sampler,
            ip=pixel_img,
            step=step,
            scale=scale,
            batch_size=self.stage1_sampler.batch_size,
            ddim_eta=ddim_eta,
            dtype=self.stage1_sampler.dtype,
            device=self.stage1_sampler.device,
            camera=self.stage1_sampler.camera,
            num_frames=self.stage1_sampler.num_frames,
            pixel_control=(self.stage1_sampler.mode == "pixel"),
            transform=self.stage1_sampler.image_transform,
            offset_noise=self.stage1_sampler.offset_noise,
        )

        stage1_images = [Image.fromarray(img) for img in stage1_images]
        stage1_images.pop(self.stage1_sampler.ref_position)
        return stage1_images

    def stage2_sample(self, pixel_img, stage1_images):
        if type(pixel_img) == str:
            pixel_img = Image.open(pixel_img)

        if isinstance(pixel_img, Image.Image):
            if pixel_img.mode == "RGBA":
                background = Image.new('RGBA', pixel_img.size, (0, 0, 0, 0))
                pixel_img = Image.alpha_composite(background, pixel_img).convert("RGB")
            else:
                pixel_img = pixel_img.convert("RGB")
        else:
            raise TypeError("pixel_img must be a file path or a PIL.Image.Image")
        stage2_images = self.stage2_sampler.i2iStage2(
            self.stage2_sampler.model,
            self.stage2_sampler.size,
            "3D assets",
            self.stage2_sampler.uc,
            self.stage2_sampler.sampler,
            pixel_images=stage1_images,
            ip=pixel_img,
            step=50,
            scale=5,
            batch_size=self.stage2_sampler.batch_size,
            ddim_eta=0.0,
            dtype=self.stage2_sampler.dtype,
            device=self.stage2_sampler.device,
            camera=self.stage2_sampler.camera,
            num_frames=self.stage2_sampler.num_frames,
            pixel_control=(self.stage2_sampler.mode == "pixel"),
            transform=self.stage2_sampler.image_transform,
            offset_noise=self.stage2_sampler.offset_noise,
        )
        stage2_images = [Image.fromarray(img) for img in stage2_images]
        return stage2_images

    def set_seed(self, seed):
        self.stage1_sampler.seed = seed
        self.stage2_sampler.seed = seed

    def __call__(self, pixel_img, prompt="3D assets", scale=5, step=50):
        pixel_img = do_resize_content(pixel_img, self.resize_rate)
        stage1_images = self.stage1_sample(pixel_img, prompt, scale=scale, step=step)
        stage2_images = self.stage2_sample(pixel_img, stage1_images)

        return {
            "ref_img": pixel_img,
            "stage1_images": stage1_images,
            "stage2_images": stage2_images,
        }


if __name__ == "__main__":

    stage1_config = OmegaConf.load("configs/nf7_v3_SNR_rd_size_stroke.yaml").config
    stage2_config = OmegaConf.load("configs/stage2-v2-snr.yaml").config
    stage2_sampler_config = stage2_config.sampler
    stage1_sampler_config = stage1_config.sampler

    stage1_model_config = stage1_config.models
    stage2_model_config = stage2_config.models

    pipeline = TwoStagePipeline(
        stage1_model_config,
        stage2_model_config,
        stage1_sampler_config,
        stage2_sampler_config,
    )

    img = Image.open("assets/astronaut.png")
    rt_dict = pipeline(img)
    stage1_images = rt_dict["stage1_images"]
    stage2_images = rt_dict["stage2_images"]
    np_imgs = np.concatenate(stage1_images, 1)
    np_xyzs = np.concatenate(stage2_images, 1)
    Image.fromarray(np_imgs).save("pixel_images.png")
    Image.fromarray(np_xyzs).save("xyz_images.png")
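Beyond the __main__ example above, a constructed pipeline can be reused across inputs. This short sketch is not part of the diff; it reuses the `pipeline` object and image path from the example and only illustrates the seeding call and the returned dictionary.

# Reusing an already-constructed TwoStagePipeline (see __main__ above).
pipeline.set_seed(1234)                      # seeds both stage samplers
rt_dict = pipeline(Image.open("assets/astronaut.png"), scale=5, step=50)
multiview = rt_dict["stage1_images"]         # list of PIL images (reference view removed)
xyz_maps = rt_dict["stage2_images"]          # list of PIL images from stage 2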
requirements.txt
ADDED
@@ -0,0 +1,14 @@
gradio
huggingface-hub
einops==0.7.0
Pillow==10.1.0
transformers==4.27.1
open-clip-torch==2.7.0
opencv-contrib-python-headless==4.9.0.80
opencv-python-headless==4.9.0.80
xformers
omegaconf
rembg
nvdiffrast
pygltflib
kiui
util/__init__.py
ADDED
File without changes
util/flexicubes.py
ADDED
@@ -0,0 +1,579 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
import torch
from util.tables import *

__all__ = [
    'FlexiCubes'
]


class FlexiCubes:
    """
    This class implements the FlexiCubes method for extracting meshes from scalar fields.
    It maintains a series of lookup tables and indices to support the mesh extraction process.
    FlexiCubes, a differentiable variant of the Dual Marching Cubes (DMC) scheme, enhances
    the geometric fidelity and mesh quality of reconstructed meshes by dynamically adjusting
    the surface representation through gradient-based optimization.

    During instantiation, the class loads DMC tables from a file and transforms them into
    PyTorch tensors on the specified device.

    Attributes:
        device (str): Specifies the computational device (default is "cuda").
        dmc_table (torch.Tensor): Dual Marching Cubes (DMC) table that encodes the edges
            associated with each dual vertex in 256 Marching Cubes (MC) configurations.
        num_vd_table (torch.Tensor): Table holding the number of dual vertices in each of
            the 256 MC configurations.
        check_table (torch.Tensor): Table resolving ambiguity in cases C16 and C19
            of the DMC configurations.
        tet_table (torch.Tensor): Lookup table used in tetrahedralizing the isosurface.
        quad_split_1 (torch.Tensor): Indices for splitting a quad into two triangles
            along one diagonal.
        quad_split_2 (torch.Tensor): Alternative indices for splitting a quad into
            two triangles along the other diagonal.
        quad_split_train (torch.Tensor): Indices for splitting a quad into four triangles
            during training by connecting all edges to their midpoints.
        cube_corners (torch.Tensor): Defines the positions of a standard unit cube's
            eight corners in 3D space, ordered starting from the origin (0,0,0),
            moving along the x-axis, then y-axis, and finally z-axis.
            Used as a blueprint for generating a voxel grid.
        cube_corners_idx (torch.Tensor): Cube corners indexed as powers of 2, used
            to retrieve the case id.
        cube_edges (torch.Tensor): Edge connections in a cube, listed in pairs.
            Used to retrieve edge vertices in DMC.
        edge_dir_table (torch.Tensor): A mapping tensor that associates edge indices with
            their corresponding axis. For instance, edge_dir_table[0] = 0 indicates that the
            first edge is oriented along the x-axis.
        dir_faces_table (torch.Tensor): A tensor that maps the corresponding axis of shared edges
            across four adjacent cubes to the shared faces of these cubes. For instance,
            dir_faces_table[0] = [5, 4] implies that for four cubes sharing an edge along
            the x-axis, the first and second cubes share faces indexed as 5 and 4, respectively.
            This tensor is only utilized during isosurface tetrahedralization.
        adj_pairs (torch.Tensor):
            A tensor containing index pairs that correspond to neighboring cubes that share the same edge.
        qef_reg_scale (float):
            The scaling factor applied to the regularization loss to prevent issues with singularity
            when solving the QEF. This parameter is only used when a 'grad_func' is specified.
        weight_scale (float):
            The scale of weights in FlexiCubes. Should be between 0 and 1.
    """

    def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99):

        self.device = device
        self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False)
        self.num_vd_table = torch.tensor(num_vd_table,
                                         dtype=torch.long, device=device, requires_grad=False)
        self.check_table = torch.tensor(
            check_table,
            dtype=torch.long, device=device, requires_grad=False)

        self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False)
        self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False)
        self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False)
        self.quad_split_train = torch.tensor(
            [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False)

        self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [
            1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device)
        self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False))
        self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6,
                                        2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False)

        self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1],
                                           dtype=torch.long, device=device)
        self.dir_faces_table = torch.tensor([
            [[5, 4], [3, 2], [4, 5], [2, 3]],
            [[5, 4], [1, 0], [4, 5], [0, 1]],
            [[3, 2], [1, 0], [2, 3], [0, 1]]
        ], dtype=torch.long, device=device)
        self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device)
        self.qef_reg_scale = qef_reg_scale
        self.weight_scale = weight_scale

    def construct_voxel_grid(self, res):
        """
        Generates a voxel grid based on the specified resolution.

        Args:
            res (int or list[int]): The resolution of the voxel grid. If an integer
                is provided, it is used for all three dimensions. If a list or tuple
                of 3 integers is provided, they define the resolution for the x,
                y, and z dimensions respectively.

        Returns:
            (torch.Tensor, torch.Tensor): Returns the vertices and the indices of the
                cube corners (index into vertices) of the constructed voxel grid.
                The vertices are centered at the origin, with the length of each
                dimension in the grid being one.
        """
        base_cube_f = torch.arange(8).to(self.device)
        if isinstance(res, int):
            res = (res, res, res)
        voxel_grid_template = torch.ones(res, device=self.device)

        res = torch.tensor([res], dtype=torch.float, device=self.device)
        coords = torch.nonzero(voxel_grid_template).float() / res  # N, 3
        verts = (self.cube_corners.unsqueeze(0) / res + coords.unsqueeze(1)).reshape(-1, 3)
        cubes = (base_cube_f.unsqueeze(0) +
                 torch.arange(coords.shape[0], device=self.device).unsqueeze(1) * 8).reshape(-1)

        verts_rounded = torch.round(verts * 10**5) / (10**5)
        verts_unique, inverse_indices = torch.unique(verts_rounded, dim=0, return_inverse=True)
        cubes = inverse_indices[cubes.reshape(-1)].reshape(-1, 8)

        return verts_unique - 0.5, cubes

    def __call__(self, x_nx3, s_n, cube_fx8, res, beta_fx12=None, alpha_fx8=None,
                 gamma_f=None, training=False, output_tetmesh=False, grad_func=None):
        r"""
        Main function for mesh extraction from scalar field using FlexiCubes. This function converts
        discrete signed distance fields, encoded on voxel grids and additional per-cube parameters,
        to triangle or tetrahedral meshes using a differentiable operation as described in
        `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances
        mesh quality and geometric fidelity by adjusting the surface representation based on gradient
        optimization. The output surface is differentiable with respect to the input vertex positions,
        scalar field values, and weight parameters.

        If you intend to extract a surface mesh from a fixed Signed Distance Field without the
        optimization of parameters, it is suggested to provide the "grad_func" which should
        return the surface gradient at any given 3D position. When grad_func is provided, the process
        to determine the dual vertex position adapts to solve a Quadratic Error Function (QEF), as
        described in the `Manifold Dual Contouring`_ paper, and employs a smart splitting strategy.
        Please note, this approach is non-differentiable.

        For more details and example usage in optimization, refer to the
        `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper.

        Args:
            x_nx3 (torch.Tensor): Coordinates of the voxel grid vertices, can be deformed.
            s_n (torch.Tensor): Scalar field values at each vertex of the voxel grid. Negative values
                denote that the corresponding vertex resides inside the isosurface. This affects
                the directions of the extracted triangle faces and volume to be tetrahedralized.
            cube_fx8 (torch.Tensor): Indices of 8 vertices for each cube in the voxel grid.
            res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it
                is used for all three dimensions. If a list or tuple of 3 integers is provided, they
                specify the resolution for the x, y, and z dimensions respectively.
            beta_fx12 (torch.Tensor, optional): Weight parameters for the cube edges to adjust dual
                vertices positioning. Defaults to uniform value for all edges.
            alpha_fx8 (torch.Tensor, optional): Weight parameters for the cube corners to adjust dual
                vertices positioning. Defaults to uniform value for all vertices.
            gamma_f (torch.Tensor, optional): Weight parameters to control the splitting of
                quadrilaterals into triangles. Defaults to uniform value for all cubes.
            training (bool, optional): If set to True, applies differentiable quad splitting for
                training. Defaults to False.
            output_tetmesh (bool, optional): If set to True, outputs a tetrahedral mesh, otherwise,
                outputs a triangular mesh. Defaults to False.
            grad_func (callable, optional): A function to compute the surface gradient at specified
                3D positions (input: Nx3 positions). The function should return gradients as an Nx3
                tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None.

        Returns:
            (torch.Tensor, torch.LongTensor, torch.Tensor): Tuple containing:
                - Vertices for the extracted triangular/tetrahedral mesh.
                - Faces for the extracted triangular/tetrahedral mesh.
                - Regularizer L_dev, computed per dual vertex.

        .. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization:
            https://research.nvidia.com/labs/toronto-ai/flexicubes/
        .. _Manifold Dual Contouring:
            https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf
        """

        surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8)
        if surf_cubes.sum() == 0:
            return torch.zeros(
                (0, 3),
                device=self.device), torch.zeros(
                (0, 4),
                dtype=torch.long, device=self.device) if output_tetmesh else torch.zeros(
                (0, 3),
                dtype=torch.long, device=self.device), torch.zeros(
                (0),
                device=self.device)
        beta_fx12, alpha_fx8, gamma_f = self._normalize_weights(beta_fx12, alpha_fx8, gamma_f, surf_cubes)

        case_ids = self._get_case_id(occ_fx8, surf_cubes, res)

        surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(s_n, cube_fx8, surf_cubes)

        vd, L_dev, vd_gamma, vd_idx_map = self._compute_vd(
            x_nx3, cube_fx8[surf_cubes], surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func)
        vertices, faces, s_edges, edge_indices = self._triangulate(
            s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func)
        if not output_tetmesh:
            return vertices, faces, L_dev
        else:
            vertices, tets = self._tetrahedralize(
                x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,
                surf_cubes, training)
            return vertices, tets, L_dev

    def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges):
        """
        Regularizer L_dev as in Equation 8
        """
        dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1)
        mean_l2 = torch.zeros_like(vd[:, 0])
        mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float()
        mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs()
        return mad

    def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes):
        """
        Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones.
        """
        n_cubes = surf_cubes.shape[0]

        if beta_fx12 is not None:
            beta_fx12 = (torch.tanh(beta_fx12) * self.weight_scale + 1)
        else:
            beta_fx12 = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device)

        if alpha_fx8 is not None:
            alpha_fx8 = (torch.tanh(alpha_fx8) * self.weight_scale + 1)
        else:
            alpha_fx8 = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device)

        if gamma_f is not None:
            gamma_f = torch.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale)/2
        else:
            gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device)

        return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes]

    @torch.no_grad()
    def _get_case_id(self, occ_fx8, surf_cubes, res):
        """
        Obtains the ID of topology cases based on cell corner occupancy. This function resolves the
        ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the
        supplementary material. It should be noted that this function assumes a regular grid.
        """
        case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1)

        problem_config = self.check_table.to(self.device)[case_ids]
        to_check = problem_config[..., 0] == 1
        problem_config = problem_config[to_check]
        if not isinstance(res, (list, tuple)):
            res = [res, res, res]

        # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array,
        # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes).
        # This allows efficient checking on adjacent cubes.
        problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long)
        vol_idx = torch.nonzero(problem_config_full[..., 0] == 0)  # N, 3
        vol_idx_problem = vol_idx[surf_cubes][to_check]
        problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config
        vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4]

        within_range = (
            vol_idx_problem_adj[..., 0] >= 0) & (
            vol_idx_problem_adj[..., 0] < res[0]) & (
            vol_idx_problem_adj[..., 1] >= 0) & (
            vol_idx_problem_adj[..., 1] < res[1]) & (
            vol_idx_problem_adj[..., 2] >= 0) & (
            vol_idx_problem_adj[..., 2] < res[2])

        vol_idx_problem = vol_idx_problem[within_range]
        vol_idx_problem_adj = vol_idx_problem_adj[within_range]
        problem_config = problem_config[within_range]
        problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0],
                                                 vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]]
        # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted.
        to_invert = (problem_config_adj[..., 0] == 1)
        idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert]
        case_ids.index_put_((idx,), problem_config[to_invert][..., -1])
        return case_ids

    @torch.no_grad()
    def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes):
        """
        Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge
        can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge
        and marks the cube edges with this index.
        """
        occ_n = s_n < 0
        all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2)
        unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)

        unique_edges = unique_edges.long()
        mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1

        surf_edges_mask = mask_edges[_idx_map]
        counts = counts[_idx_map]

        mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_fx8.device) * -1
        mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_fx8.device)
        # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index
        # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1.
        idx_map = mapping[_idx_map]
        surf_edges = unique_edges[mask_edges]
        return surf_edges, idx_map, counts, surf_edges_mask

    @torch.no_grad()
    def _identify_surf_cubes(self, s_n, cube_fx8):
        """
        Identifies grid cubes that intersect with the underlying surface by checking if the signs at
        all corners are not identical.
        """
        occ_n = s_n < 0
        occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
        _occ_sum = torch.sum(occ_fx8, -1)
        surf_cubes = (_occ_sum > 0) & (_occ_sum < 8)
        return surf_cubes, occ_fx8

    def _linear_interp(self, edges_weight, edges_x):
        """
        Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'.
        """
        edge_dim = edges_weight.dim() - 2
        assert edges_weight.shape[edge_dim] == 2
        edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), -
                                  torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)], edge_dim)
        denominator = edges_weight.sum(edge_dim)
        ue = (edges_x * edges_weight).sum(edge_dim) / denominator
        return ue

    def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None):
        p_bxnx3 = p_bxnx3.reshape(-1, 7, 3)
        norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3)
        c_bx3 = c_bx3.reshape(-1, 3)
        A = norm_bxnx3
        B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True)

        A_reg = (torch.eye(3, device=p_bxnx3.device) * self.qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1)
        B_reg = (self.qef_reg_scale * c_bx3).unsqueeze(-1)
        A = torch.cat([A, A_reg], 1)
        B = torch.cat([B, B_reg], 1)
        dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1)
        return dual_verts

    def _compute_vd(self, x_nx3, surf_cubes_fx8, surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func):
        """
        Computes the location of dual vertices as described in Section 4.2
        """
        alpha_nx12x2 = torch.index_select(input=alpha_fx8, index=self.cube_edges, dim=1).reshape(-1, 12, 2)
        surf_edges_x = torch.index_select(input=x_nx3, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3)
        surf_edges_s = torch.index_select(input=s_n, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1)
        zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x)

        idx_map = idx_map.reshape(-1, 12)
        num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0)
        edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], []

        total_num_vd = 0
        vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False)
        if grad_func is not None:
            normals = torch.nn.functional.normalize(grad_func(zero_crossing), dim=-1)
            vd = []
        for num in torch.unique(num_vd):
            cur_cubes = (num_vd == num)  # consider cubes with the same numbers of vd emitted (for batching)
            curr_num_vd = cur_cubes.sum() * num
            curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7)
            curr_edge_group_to_vd = torch.arange(
                curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd
            total_num_vd += curr_num_vd
            curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[
                cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group)

            curr_mask = (curr_edge_group != -1)
            edge_group.append(torch.masked_select(curr_edge_group, curr_mask))
            edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask))
            edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask))
            vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True))
            vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1))

            if grad_func is not None:
                with torch.no_grad():
                    cube_e_verts_idx = idx_map[cur_cubes]
                    curr_edge_group[~curr_mask] = 0

                    verts_group_idx = torch.gather(input=cube_e_verts_idx, dim=1, index=curr_edge_group)
                    verts_group_idx[verts_group_idx == -1] = 0
                    verts_group_pos = torch.index_select(
                        input=zero_crossing, index=verts_group_idx.reshape(-1), dim=0).reshape(-1, num.item(), 7, 3)
                    v0 = x_nx3[surf_cubes_fx8[cur_cubes][:, 0]].reshape(-1, 1, 1, 3).repeat(1, num.item(), 1, 1)
                    curr_mask = curr_mask.reshape(-1, num.item(), 7, 1)
                    verts_centroid = (verts_group_pos * curr_mask).sum(2) / (curr_mask.sum(2))

                    normals_bx7x3 = torch.index_select(input=normals, index=verts_group_idx.reshape(-1), dim=0).reshape(
                        -1, num.item(), 7,
                        3)
                    curr_mask = curr_mask.squeeze(2)
                    vd.append(self._solve_vd_QEF((verts_group_pos - v0) * curr_mask, normals_bx7x3 * curr_mask,
                                                 verts_centroid - v0.squeeze(2)) + v0.reshape(-1, 3))
        edge_group = torch.cat(edge_group)
        edge_group_to_vd = torch.cat(edge_group_to_vd)
        edge_group_to_cube = torch.cat(edge_group_to_cube)
        vd_num_edges = torch.cat(vd_num_edges)
        vd_gamma = torch.cat(vd_gamma)

        if grad_func is not None:
            vd = torch.cat(vd)
            L_dev = torch.zeros([1], device=self.device)
        else:
            vd = torch.zeros((total_num_vd, 3), device=self.device)
            beta_sum = torch.zeros((total_num_vd, 1), device=self.device)

            idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group)

            x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3)
            s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1)

            zero_crossing_group = torch.index_select(
                input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3)

            alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0,
                                             index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1)
            ue_group = self._linear_interp(s_group * alpha_group, x_group)

            beta_group = torch.gather(input=beta_fx12.reshape(-1), dim=0,
                                      index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1)
            beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group)
            vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum
            L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges)

        v_idx = torch.arange(vd.shape[0], device=self.device)  # + total_num_vd

        vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube *
                                                      12 + edge_group, src=v_idx[edge_group_to_vd])

        return vd, L_dev, vd_gamma, vd_idx_map

    def _triangulate(self, s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func):
        """
        Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into
        triangles based on the gamma parameter, as described in Section 4.3.
        """
        with torch.no_grad():
            group_mask = (edge_counts == 4) & surf_edges_mask  # surface edges shared by 4 cubes.
            group = idx_map.reshape(-1)[group_mask]
            vd_idx = vd_idx_map[group_mask]
            edge_indices, indices = torch.sort(group, stable=True)
            quad_vd_idx = vd_idx[indices].reshape(-1, 4)

            # Ensure all face directions point towards the positive SDF to maintain consistent winding.
            s_edges = s_n[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2)
            flip_mask = s_edges[:, 0] > 0
            quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]],
                                     quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]]))
        if grad_func is not None:
            # when grad_func is given, split quadrilaterals along the diagonals with more consistent gradients.
            with torch.no_grad():
                vd_gamma = torch.nn.functional.normalize(grad_func(vd), dim=-1)
                quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
                gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True)
                gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True)
        else:
            quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4)
            gamma_02 = torch.index_select(input=quad_gamma, index=torch.tensor(
                0, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(2, device=self.device), dim=1)
            gamma_13 = torch.index_select(input=quad_gamma, index=torch.tensor(
                1, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(3, device=self.device), dim=1)
        if not training:
            mask = (gamma_02 > gamma_13).squeeze(1)
            faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device)
            faces[mask] = quad_vd_idx[mask][:, self.quad_split_1]
            faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2]
            faces = faces.reshape(-1, 3)
        else:
            vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
            vd_02 = (torch.index_select(input=vd_quad, index=torch.tensor(0, device=self.device), dim=1) +
                     torch.index_select(input=vd_quad, index=torch.tensor(2, device=self.device), dim=1)) / 2
            vd_13 = (torch.index_select(input=vd_quad, index=torch.tensor(1, device=self.device), dim=1) +
                     torch.index_select(input=vd_quad, index=torch.tensor(3, device=self.device), dim=1)) / 2
            weight_sum = (gamma_02 + gamma_13) + 1e-8
            vd_center = ((vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) /
                         weight_sum.unsqueeze(-1)).squeeze(1)
            vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0]
            vd = torch.cat([vd, vd_center])
            faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2)
497 |
+
faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3)
|
498 |
+
return vd, faces, s_edges, edge_indices
|
499 |
+
|
500 |
+
def _tetrahedralize(
|
501 |
+
self, x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,
|
502 |
+
surf_cubes, training):
|
503 |
+
"""
|
504 |
+
Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5.
|
505 |
+
"""
|
506 |
+
occ_n = s_n < 0
|
507 |
+
occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
|
508 |
+
occ_sum = torch.sum(occ_fx8, -1)
|
509 |
+
|
510 |
+
inside_verts = x_nx3[occ_n]
|
511 |
+
mapping_inside_verts = torch.ones((occ_n.shape[0]), dtype=torch.long, device=self.device) * -1
|
512 |
+
mapping_inside_verts[occ_n] = torch.arange(occ_n.sum(), device=self.device) + vertices.shape[0]
|
513 |
+
"""
|
514 |
+
For each grid edge connecting two grid vertices with different
|
515 |
+
signs, we first form a four-sided pyramid by connecting one
|
516 |
+
of the grid vertices with four mesh vertices that correspond
|
517 |
+
to the grid edge and then subdivide the pyramid into two tetrahedra
|
518 |
+
"""
|
519 |
+
inside_verts_idx = mapping_inside_verts[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[
|
520 |
+
s_edges < 0]]
|
521 |
+
if not training:
|
522 |
+
inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 2).reshape(-1)
|
523 |
+
else:
|
524 |
+
inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 4).reshape(-1)
|
525 |
+
|
526 |
+
tets_surface = torch.cat([faces, inside_verts_idx.unsqueeze(-1)], -1)
|
527 |
+
"""
|
528 |
+
For each grid edge connecting two grid vertices with the
|
529 |
+
same sign, the tetrahedron is formed by the two grid vertices
|
530 |
+
and two vertices in consecutive adjacent cells
|
531 |
+
"""
|
532 |
+
inside_cubes = (occ_sum == 8)
|
533 |
+
inside_cubes_center = x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1)
|
534 |
+
inside_cubes_center_idx = torch.arange(
|
535 |
+
inside_cubes_center.shape[0], device=inside_cubes.device) + vertices.shape[0] + inside_verts.shape[0]
|
536 |
+
|
537 |
+
surface_n_inside_cubes = surf_cubes | inside_cubes
|
538 |
+
edge_center_vertex_idx = torch.ones(((surface_n_inside_cubes).sum(), 13),
|
539 |
+
dtype=torch.long, device=x_nx3.device) * -1
|
540 |
+
surf_cubes = surf_cubes[surface_n_inside_cubes]
|
541 |
+
inside_cubes = inside_cubes[surface_n_inside_cubes]
|
542 |
+
edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12)
|
543 |
+
edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx
|
544 |
+
|
545 |
+
all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2)
|
546 |
+
unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
|
547 |
+
unique_edges = unique_edges.long()
|
548 |
+
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2
|
549 |
+
mask = mask_edges[_idx_map]
|
550 |
+
counts = counts[_idx_map]
|
551 |
+
mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1
|
552 |
+
mapping[mask_edges] = torch.arange(mask_edges.sum(), device=self.device)
|
553 |
+
idx_map = mapping[_idx_map]
|
554 |
+
|
555 |
+
group_mask = (counts == 4) & mask
|
556 |
+
group = idx_map.reshape(-1)[group_mask]
|
557 |
+
edge_indices, indices = torch.sort(group)
|
558 |
+
cube_idx = torch.arange((_idx_map.shape[0] // 12), dtype=torch.long,
|
559 |
+
device=self.device).unsqueeze(1).expand(-1, 12).reshape(-1)[group_mask]
|
560 |
+
edge_idx = torch.arange((12), dtype=torch.long, device=self.device).unsqueeze(
|
561 |
+
0).expand(_idx_map.shape[0] // 12, -1).reshape(-1)[group_mask]
|
562 |
+
# Identify the face shared by the adjacent cells.
|
563 |
+
cube_idx_4 = cube_idx[indices].reshape(-1, 4)
|
564 |
+
edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0]
|
565 |
+
shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1)
|
566 |
+
cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1)
|
567 |
+
# Identify an edge of the face with different signs and
|
568 |
+
# select the mesh vertex corresponding to the identified edge.
|
569 |
+
case_ids_expand = torch.ones((surface_n_inside_cubes).sum(), dtype=torch.long, device=x_nx3.device) * 255
|
570 |
+
case_ids_expand[surf_cubes] = case_ids
|
571 |
+
cases = case_ids_expand[cube_idx_4x2]
|
572 |
+
quad_edge = edge_center_vertex_idx[cube_idx_4x2, self.tet_table[cases, shared_faces_4x2]].reshape(-1, 2)
|
573 |
+
mask = (quad_edge == -1).sum(-1) == 0
|
574 |
+
inside_edge = mapping_inside_verts[unique_edges[mask_edges][edge_indices].reshape(-1)].reshape(-1, 2)
|
575 |
+
tets_inside = torch.cat([quad_edge, inside_edge], -1)[mask]
|
576 |
+
|
577 |
+
tets = torch.cat([tets_surface, tets_inside])
|
578 |
+
vertices = torch.cat([vertices, inside_verts, inside_cubes_center])
|
579 |
+
return vertices, tets
|
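Usage sketch (not part of this commit): the methods above rely on the rest of the FlexiCubes class, including `construct_voxel_grid` and the `__call__` interface that `FlexiCubesGeometry.get_mesh` below invokes. The snippet shows one plausible way to extract a mesh from an analytic SDF; the sphere SDF, the grid resolution, and the assumption that the optional beta/alpha/gamma weights may simply be omitted are illustrative, not taken from this repository.

import torch
from util.flexicubes import FlexiCubes

device = 'cuda'
res = 32                                              # illustrative grid resolution (assumption)
fc = FlexiCubes(device, weight_scale=0.5)
grid_verts, cube_idx = fc.construct_voxel_grid(res)   # grid vertices and per-cube corner indices
sdf = grid_verts.norm(dim=-1) - 0.4                   # assumed analytic SDF: sphere of radius 0.4
# The optional per-cube weights (beta/alpha/gamma) are left unset here, assuming
# the class falls back to neutral defaults when they are not provided.
mesh_v, mesh_f, L_dev = fc(grid_verts, sdf, cube_idx, res, training=False)
print(mesh_v.shape, mesh_f.shape)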
util/flexicubes_geometry.py
ADDED
@@ -0,0 +1,116 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

import torch
from util.flexicubes import FlexiCubes  # replace later
# from dmtet import sdf_reg_loss_batch
import torch.nn.functional as F

def get_center_boundary_index(grid_res, device):
    v = torch.zeros((grid_res + 1, grid_res + 1, grid_res + 1), dtype=torch.bool, device=device)
    v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = True
    center_indices = torch.nonzero(v.reshape(-1))

    v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = False
    v[:2, ...] = True
    v[-2:, ...] = True
    v[:, :2, ...] = True
    v[:, -2:, ...] = True
    v[:, :, :2] = True
    v[:, :, -2:] = True
    boundary_indices = torch.nonzero(v.reshape(-1))
    return center_indices, boundary_indices

###############################################################################
#  Geometry interface
###############################################################################
class FlexiCubesGeometry(object):
    def __init__(
            self, grid_res=64, scale=2.0, device='cuda', renderer=None,
            render_type='neural_render', args=None):
        super(FlexiCubesGeometry, self).__init__()
        self.grid_res = grid_res
        self.device = device
        self.args = args
        self.fc = FlexiCubes(device, weight_scale=0.5)
        self.verts, self.indices = self.fc.construct_voxel_grid(grid_res)
        if isinstance(scale, list):
            self.verts[:, 0] = self.verts[:, 0] * scale[0]
            self.verts[:, 1] = self.verts[:, 1] * scale[1]
            self.verts[:, 2] = self.verts[:, 2] * scale[1]
        else:
            self.verts = self.verts * scale

        all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2)
        self.all_edges = torch.unique(all_edges, dim=0)

        # Parameters used to fix the boundary SDF
        self.center_indices, self.boundary_indices = get_center_boundary_index(self.grid_res, device)
        self.renderer = renderer
        self.render_type = render_type

    def getAABB(self):
        return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values

    def get_mesh(self, v_deformed_nx3, sdf_n, weight_n=None, with_uv=False, indices=None, is_training=False):
        if indices is None:
            indices = self.indices

        verts, faces, v_reg_loss = self.fc(v_deformed_nx3, sdf_n, indices, self.grid_res,
                                           beta_fx12=weight_n[:, :12], alpha_fx8=weight_n[:, 12:20],
                                           gamma_f=weight_n[:, 20], training=is_training
                                           )
        return verts, faces, v_reg_loss


    def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False):
        return_value = dict()
        if self.render_type == 'neural_render':
            tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth = self.renderer.render_mesh(
                mesh_v_nx3.unsqueeze(dim=0),
                mesh_f_fx3.int(),
                camera_mv_bx4x4,
                mesh_v_nx3.unsqueeze(dim=0),
                resolution=resolution,
                device=self.device,
                hierarchical_mask=hierarchical_mask
            )

            return_value['tex_pos'] = tex_pos
            return_value['mask'] = mask
            return_value['hard_mask'] = hard_mask
            return_value['rast'] = rast
            return_value['v_pos_clip'] = v_pos_clip
            return_value['mask_pyramid'] = mask_pyramid
            return_value['depth'] = depth
        else:
            raise NotImplementedError

        return return_value

    def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256):
        # Assumes a batch of meshes (each entry can be a different mesh and geometry); for the other shapes the batch size is 1.
        v_list = []
        f_list = []
        n_batch = v_deformed_bxnx3.shape[0]
        all_render_output = []
        for i_batch in range(n_batch):
            verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch])
            v_list.append(verts_nx3)
            f_list.append(faces_fx3)
            render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution)
            all_render_output.append(render_output)

        # Concatenate all render output
        return_keys = all_render_output[0].keys()
        return_value = dict()
        for k in return_keys:
            value = [v[k] for v in all_render_output]
            return_value[k] = value
        # We can do concatenation outside of the render
        return return_value
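Usage sketch (not part of this commit): `get_mesh` slices a single per-cube weight tensor into 12 beta (edge), 8 alpha (corner), and 1 gamma value, i.e. 21 channels per cube. The snippet below sketches how such a tensor could be built and passed; the sphere SDF and the zero initialization are illustrative assumptions (the FlexiCubes class is assumed to normalize the raw weights internally).

import torch
from util.flexicubes_geometry import FlexiCubesGeometry

geo = FlexiCubesGeometry(grid_res=64, scale=2.0, device='cuda')
sdf = geo.verts.norm(dim=-1) - 0.5                    # assumed SDF sampled at the grid vertices
n_cubes = geo.indices.shape[0]
# 21 raw weights per cube: 12 betas, 8 alphas, 1 gamma, matching the slicing in get_mesh.
weight = torch.zeros(n_cubes, 21, device=geo.device)  # zeros as a neutral starting point (assumption)
verts, faces, reg_loss = geo.get_mesh(geo.verts, sdf, weight_n=weight, is_training=False)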
util/renderer.py
ADDED
@@ -0,0 +1,49 @@
import torch
import torch.nn as nn
import nvdiffrast.torch as dr
from util.flexicubes_geometry import FlexiCubesGeometry

class Renderer(nn.Module):
    def __init__(self, tet_grid_size, camera_angle_num, scale, geo_type):
        super().__init__()

        self.tet_grid_size = tet_grid_size
        self.camera_angle_num = camera_angle_num
        self.scale = scale
        self.geo_type = geo_type
        self.glctx = dr.RasterizeCudaContext()

        if self.geo_type == "flex":
            self.flexicubes = FlexiCubesGeometry(grid_res=self.tet_grid_size)

    def forward(self, data, sdf, deform, verts, tets, training=False, weight=None):

        results = {}

        deform = torch.tanh(deform) / self.tet_grid_size * self.scale / 0.95
        if self.geo_type == "flex":
            deform = deform * 0.5

        v_deformed = verts + deform

        verts_list = []
        faces_list = []
        reg_list = []
        n_shape = verts.shape[0]
        for i in range(n_shape):
            verts_i, faces_i, reg_i = self.flexicubes.get_mesh(v_deformed[i], sdf[i].squeeze(dim=-1),
                                                               with_uv=False, indices=tets, weight_n=weight[i], is_training=training)

            verts_list.append(verts_i)
            faces_list.append(faces_i)
            reg_list.append(reg_i)
        verts = verts_list
        faces = faces_list

        flexicubes_surface_reg = torch.cat(reg_list).mean()
        flexicubes_weight_reg = (weight ** 2).mean()
        results["flex_surf_loss"] = flexicubes_surface_reg
        results["flex_weight_loss"] = flexicubes_weight_reg

        return results, verts, faces
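Usage sketch (not part of this commit): `Renderer.forward` expects batched tensors, one SDF grid, deformation field, and 21-channel weight tensor per shape, and returns the regularization losses together with per-shape vertex and face lists. The batch size of 2, the sphere SDF, and the zero-initialized deformation and weights below are illustrative assumptions; running it requires a CUDA device with nvdiffrast installed.

import torch
from util.renderer import Renderer

renderer = Renderer(tet_grid_size=64, camera_angle_num=4, scale=2.0, geo_type="flex")
verts = renderer.flexicubes.verts.unsqueeze(0).repeat(2, 1, 1)   # (B, N, 3) grid vertices, B = 2
tets = renderer.flexicubes.indices                               # per-cube corner indices
sdf = verts.norm(dim=-1, keepdim=True) - 0.5                     # (B, N, 1); squeezed inside forward
deform = torch.zeros_like(verts)                                 # raw deformation, tanh-squashed inside forward
weight = torch.zeros(2, tets.shape[0], 21, device=verts.device)  # 12 beta + 8 alpha + 1 gamma per cube
results, verts_list, faces_list = renderer(None, sdf, deform, verts, tets, training=False, weight=weight)
print(results["flex_surf_loss"], len(verts_list))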