diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..34cac59b2042e14f6188dd775b5c1059b7149b92 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,77 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/teaserv6.png filter=lfs diff=lfs merge=lfs -text
+assets/test/0.png filter=lfs diff=lfs merge=lfs -text
+assets/test/1.png filter=lfs diff=lfs merge=lfs -text
+assets/test/10.png filter=lfs diff=lfs merge=lfs -text
+assets/test/11.png filter=lfs diff=lfs merge=lfs -text
+assets/test/12.png filter=lfs diff=lfs merge=lfs -text
+assets/test/13.png filter=lfs diff=lfs merge=lfs -text
+assets/test/14.png filter=lfs diff=lfs merge=lfs -text
+assets/test/15.png filter=lfs diff=lfs merge=lfs -text
+assets/test/16.png filter=lfs diff=lfs merge=lfs -text
+assets/test/17.png filter=lfs diff=lfs merge=lfs -text
+assets/test/18.png filter=lfs diff=lfs merge=lfs -text
+assets/test/19.png filter=lfs diff=lfs merge=lfs -text
+assets/test/2.png filter=lfs diff=lfs merge=lfs -text
+assets/test/20.png filter=lfs diff=lfs merge=lfs -text
+assets/test/21.png filter=lfs diff=lfs merge=lfs -text
+assets/test/22.png filter=lfs diff=lfs merge=lfs -text
+assets/test/23.png filter=lfs diff=lfs merge=lfs -text
+assets/test/24.png filter=lfs diff=lfs merge=lfs -text
+assets/test/25.png filter=lfs diff=lfs merge=lfs -text
+assets/test/26.png filter=lfs diff=lfs merge=lfs -text
+assets/test/27.png filter=lfs diff=lfs merge=lfs -text
+assets/test/28.png filter=lfs diff=lfs merge=lfs -text
+assets/test/29.png filter=lfs diff=lfs merge=lfs -text
+assets/test/3.png filter=lfs diff=lfs merge=lfs -text
+assets/test/30.png filter=lfs diff=lfs merge=lfs -text
+assets/test/31.png filter=lfs diff=lfs merge=lfs -text
+assets/test/32.png filter=lfs diff=lfs merge=lfs -text
+assets/test/33.png filter=lfs diff=lfs merge=lfs -text
+assets/test/34.png filter=lfs diff=lfs merge=lfs -text
+assets/test/35.png filter=lfs diff=lfs merge=lfs -text
+assets/test/36.png filter=lfs diff=lfs merge=lfs -text
+assets/test/37.png filter=lfs diff=lfs merge=lfs -text
+assets/test/38.png filter=lfs diff=lfs merge=lfs -text
+assets/test/39.png filter=lfs diff=lfs merge=lfs -text
+assets/test/4.png filter=lfs diff=lfs merge=lfs -text
+assets/test/40.png filter=lfs diff=lfs merge=lfs -text
+assets/test/41.png filter=lfs diff=lfs merge=lfs -text
+assets/test/42.png filter=lfs diff=lfs merge=lfs -text
+assets/test/43.png filter=lfs diff=lfs merge=lfs -text
+assets/test/44.png filter=lfs diff=lfs merge=lfs -text
+assets/test/45.png filter=lfs diff=lfs merge=lfs -text
+assets/test/46.png filter=lfs diff=lfs merge=lfs -text
+assets/test/47.png filter=lfs diff=lfs merge=lfs -text
+assets/test/48.png filter=lfs diff=lfs merge=lfs -text
+assets/test/49.png filter=lfs diff=lfs merge=lfs -text
+assets/test/5.png filter=lfs diff=lfs merge=lfs -text
+assets/test/50.png filter=lfs diff=lfs merge=lfs -text
+assets/test/51.png filter=lfs diff=lfs merge=lfs -text
+assets/test/52.png filter=lfs diff=lfs merge=lfs -text
+assets/test/53.png filter=lfs diff=lfs merge=lfs -text
+assets/test/54.png filter=lfs diff=lfs merge=lfs -text
+assets/test/55.png filter=lfs diff=lfs merge=lfs -text
+assets/test/56.png filter=lfs diff=lfs merge=lfs -text
+assets/test/57.png filter=lfs diff=lfs merge=lfs -text
+assets/test/58.png filter=lfs diff=lfs merge=lfs -text
+assets/test/59.png filter=lfs diff=lfs merge=lfs -text
+assets/test/6.png filter=lfs diff=lfs merge=lfs -text
+assets/test/60.png filter=lfs diff=lfs merge=lfs -text
+assets/test/61.png filter=lfs diff=lfs merge=lfs -text
+assets/test/62.png filter=lfs diff=lfs merge=lfs -text
+assets/test/63.png filter=lfs diff=lfs merge=lfs -text
+assets/test/64.png filter=lfs diff=lfs merge=lfs -text
+assets/test/65.png filter=lfs diff=lfs merge=lfs -text
+assets/test/66.png filter=lfs diff=lfs merge=lfs -text
+assets/test/67.png filter=lfs diff=lfs merge=lfs -text
+assets/test/7.png filter=lfs diff=lfs merge=lfs -text
+assets/test/8.png filter=lfs diff=lfs merge=lfs -text
+assets/test/9.png filter=lfs diff=lfs merge=lfs -text
+third_party/voxelize/build/lib.linux-x86_64-3.10/udf_ext.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+third_party/voxelize/build/temp.linux-x86_64-3.10/.ninja_deps filter=lfs diff=lfs merge=lfs -text
+third_party/voxelize/build/temp.linux-x86_64-3.10/src/udf_cuda.o filter=lfs diff=lfs merge=lfs -text
+third_party/voxelize/build/temp.linux-x86_64-3.10/src/udf_kernel.o filter=lfs diff=lfs merge=lfs -text
+third_party/voxelize/udf_ext.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
diff --git a/LICENSE.txt b/LICENSE.txt
new file mode 100644
index 0000000000000000000000000000000000000000..225066c568a99b8715cdabebc8cf90ff1cf7ca73
--- /dev/null
+++ b/LICENSE.txt
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 DreamTechAI
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..0229b5814a82398f83552e047575f9dd8be45325
--- /dev/null
+++ b/app.py
@@ -0,0 +1,207 @@
+
+import torch
+import trimesh
+import datetime
+import argparse
+import numpy as np
+from torchvision import transforms
+from direct3d_s2.utils.rembg import BiRefNet
+from direct3d_s2.pipeline import Direct3DS2Pipeline
+from direct3d_s2.utils.fill_hole import postprocess_mesh
+
+import os
+from PIL import Image
+from typing import Any
+
+import gradio as gr
+from gradio.themes.utils import colors, fonts, sizes
+
+# -----------------------------------------------------------------------------
+# THEME ▸ a soft glass-like dark theme with a vibrant primary accent
+# -----------------------------------------------------------------------------
+class Glass(gr.themes.Soft):
+ def __init__(self):
+ super().__init__(
+ primary_hue=colors.emerald,
+ secondary_hue=colors.indigo,
+ neutral_hue=colors.zinc,
+ text_size=sizes.text_md,
+ spacing_size=sizes.spacing_md,
+ radius_size=sizes.radius_lg,
+ font=fonts.GoogleFont("Inter"),
+ )
+
+ def style(self):
+ super().style()
+ self.set(
+ background_fill="var(--neutral-950)",
+ border_color_primary="rgba(255,255,255,.12)",
+ border_width="1px",
+ shadow_drop="0 10px 38px -10px rgba(0,0,0,.65)",
+ shadow_drop_lg="0 10px 38px -10px rgba(0,0,0,.65)",
+ )
+ return self
+
+def check_input_image(input_image):
+ if input_image is None:
+ raise gr.Error("No image uploaded!")
+
+# -----------------------------------------------------------------------------
+# BACK-END ▸ image-to-mesh generation with the Direct3D-S2 pipeline
+# -----------------------------------------------------------------------------
+def image2mesh(
+ image: Any,
+ resolution: str = '1024',
+ simplify: bool = True,
+ simplify_ratio: float = 0.95,
+ output_path: str = 'outputs/web'
+):
+
+ torch.cuda.empty_cache()
+
+ if not os.path.exists(output_path):
+ os.makedirs(output_path)
+
+ uid = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
+ image.save(os.path.join(output_path, uid + '.png'))
+
+ pipe = Direct3DS2Pipeline.from_pretrained('wushuang98/Direct3D-S2', subfolder="direct3d-s2-v-1-1")
+ pipe.to("cuda:0")
+
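+    # A brief note on what the call below does (inferred from its arguments): predict a sparse SDF
+    # at the requested resolution, extract the surface at mc_threshold=0.2, and optionally
+    # remesh/simplify the result according to simplify_ratio.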
+ mesh = pipe(
+ image,
+ sdf_resolution=int(resolution),
+ mc_threshold=0.2,
+ remesh=simplify,
+ simplify_ratio=simplify_ratio,
+ )["mesh"]
+
+ mesh_path = os.path.join(output_path, f'{uid}.obj')
+ mesh.export(
+ mesh_path,
+ include_normals=True,
+ )
+ torch.cuda.empty_cache()
+
+ return mesh_path
+
+# -----------------------------------------------------------------------------
+# UI LAYOUT ▸ minimal glassmorphism, keyboard-first workflow
+# -----------------------------------------------------------------------------
+with gr.Blocks(theme=Glass(), css="""
+:root { --header-height:64px }
+body { background:linear-gradient(215deg,#101113 0%,#0b0c0d 60%,#0d1014 100%) }
+#header { height:var(--header-height);display:flex;align-items:center;justify-content:space-between;padding:0 1.5rem;backdrop-filter:blur(18px);background:rgba(17,17,17,.65);border-bottom:1px solid rgba(255,255,255,.08);position:sticky;top:0;z-index:999 }
+#header a { color:white;font-weight:500;text-decoration:none;margin-right:1.25rem;font-size:.925rem }
+#hero-title { font-size:1.35rem;font-weight:600;color:white;white-space:nowrap }
+#footer { text-align:center;font-size:.8rem;color:rgba(255,255,255,.55);margin-top:1.5rem }
+#mesh_viewport { aspect-ratio:1/1;width:100%;display:flex;align-items:center;justify-content:center;border:1px dashed rgba(255,255,255,.12);border-radius:12px;background:rgba(255,255,255,.03); }
+.gallery-item img { border-radius:10px }
+#examples_gallery { height:100%;flex:1;display:flex;flex-direction:column; }
+#examples_gallery img { width:800px;}
+#show_image img { height:260px;display:flex;align-items:center;justify-content:center; }
+#examples { height:100%;flex:1; }
+""") as demo:
+
+ # ▸ custom sticky header
+ with gr.Row(elem_id="header", variant="panel"):
+ gr.Markdown("Direct3D-S2 Studio", elem_id="hero-title")
+ gr.Markdown(
+ """
+ """,
+ elem_id="nav-links",
+ )
+
+ # ▸ main workspace
+ with gr.Row(equal_height=True):
+ # ---------- Controls ----------
+ with gr.Column(scale=3):
+ gr.Markdown("### Input", elem_classes="subtitle")
+ image_input = gr.Image(
+ label="Image Input",
+ image_mode="RGBA",
+ sources="upload",
+ type="pil",
+ height=260,
+ elem_id="show_image",
+ )
+            # gr.Markdown("Drag & drop or click to upload")
+ processed_image = gr.Image(
+ label="Processed Image",
+ image_mode="RGBA",
+ type="pil",
+ interactive=False,
+ height=260,
+ elem_id="show_image",
+ )
+ with gr.Accordion("Advanced Options", open=True):
+ resolution = gr.Radio(choices=["512", "1024"], label="SDF Resolution", value="1024")
+ simplify = gr.Checkbox(label="Simplify Mesh", value=True)
+ reduce_ratio = gr.Slider(0.1, 0.95, step=0.05, value=0.95, label="Faces Reduction Ratio")
+
+ gen_btn = gr.Button("Generate 3D ✨", variant="primary", interactive=True)
+
+ # ---------- Viewport ----------
+ with gr.Column(scale=6):
+ gr.Markdown("### Model Viewer", elem_classes="subtitle")
+            # mesh_html = gr.HTML("🌀 No mesh yet")
+ output_model_obj = gr.Model3D(
+ label="Output Model (OBJ Format)",
+ camera_position=(90.0, 90.0, 3.5),
+ interactive=False,
+ elem_id="mesh_viewport",
+ )
+
+ # ---------- Gallery / Examples ----------
+ with gr.Column(scale=3):
+ gr.Markdown("### Examples", elem_classes="subtitle")
+ # gr.Examples(
+ # examples=[os.path.join("assets/test", i) for i in os.listdir("assets/test")],
+ # inputs=[image_input],
+ # examples_per_page=8,
+ # label="Gallery",
+ # elem_id="examples_gallery",
+ # )
+ with gr.Tabs(selected='tab_img_gallery') as gallery:
+ with gr.Tab('Image to 3D Gallery', id='tab_img_gallery') as tab_gi:
+ with gr.Row():
+ gr.Examples(
+ examples=[os.path.join("assets/test", i) for i in os.listdir("assets/test")],
+ inputs=[image_input],
+ label=None,
+ examples_per_page=24
+ )
+ # gallery = gr.Gallery(
+ # [os.path.join("assets/test", i) for i in os.listdir("assets/test")],
+ # columns=2,
+ # object_fit="contain",
+ # elem_id="examples_gallery",
+ # allow_preview=False,
+ # )
+
+
+ # ▸ callbacks
+ outputs = [output_model_obj]
+ rmbg = BiRefNet(device="cuda:0")
+
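+    # Generation chain: validate the upload, strip the background with BiRefNet,
+    # then run image2mesh and load the exported OBJ into the Model3D viewer.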
+ gen_btn.click(
+ fn=check_input_image,
+ inputs=[image_input]
+ ).success(
+ fn=rmbg.run,
+ inputs=[image_input],
+ outputs=[processed_image]
+ ).success(
+ fn=image2mesh,
+ inputs=[processed_image, resolution, simplify, reduce_ratio],
+ outputs=outputs,
+ api_name="generate_img2obj"
+ )
+
+# -----------------------------------------------------------------------------
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--cached_dir", type=str, default="outputs/web")
+ args = parser.parse_args()
+
+ demo.queue().launch(share=True, allowed_paths=[args.cached_dir], server_port=7860)
diff --git a/assets/teaserv6.png b/assets/teaserv6.png
new file mode 100644
index 0000000000000000000000000000000000000000..9e4af78c005e2ce1e489b19d4f73125b973d5bba
--- /dev/null
+++ b/assets/teaserv6.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4e58eefff81cee52dbc11d65949fe99c3e8b23e2a0a5e677505aca8799a22564
+size 3282197
diff --git a/assets/test/0.png b/assets/test/0.png
new file mode 100644
index 0000000000000000000000000000000000000000..ce6173cb73e77c760c0812118fafeea74730f64b
--- /dev/null
+++ b/assets/test/0.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:11b98466291aaf59b56a5149a3acebd69d419876b8b5905041f92314889c90e4
+size 432420
diff --git a/assets/test/1.png b/assets/test/1.png
new file mode 100644
index 0000000000000000000000000000000000000000..3863da51b726f88d7fa9fec08ca4cc63a686d72e
--- /dev/null
+++ b/assets/test/1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c034e0e62e82a70194ecfd751c27b2f538cbf1eeb304029b40d069c92546033a
+size 428055
diff --git a/assets/test/10.png b/assets/test/10.png
new file mode 100644
index 0000000000000000000000000000000000000000..2d267096771f514299064ef873f9390c2cb41fe0
--- /dev/null
+++ b/assets/test/10.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9e1121ea0f9bc6ac7469a0d3db15c445f9694fb4375a3a49032cef209179d1e2
+size 1083093
diff --git a/assets/test/11.png b/assets/test/11.png
new file mode 100644
index 0000000000000000000000000000000000000000..9b9ce9d0c9cd2c5b44412fe924c9c8533eacbe57
--- /dev/null
+++ b/assets/test/11.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2284a6d065826ad4a4293a662781a8b506c4ec61890b52ee6855af5b029f4e8f
+size 1810882
diff --git a/assets/test/12.png b/assets/test/12.png
new file mode 100644
index 0000000000000000000000000000000000000000..a50c9eb3fe1496f6a3448ccba1862035227fa1f2
--- /dev/null
+++ b/assets/test/12.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a904d6b0ac11c75d7d0409ca46beb3e80bbdf72fd8272e268230fcc5e8deab7c
+size 953366
diff --git a/assets/test/13.png b/assets/test/13.png
new file mode 100644
index 0000000000000000000000000000000000000000..8cdbc86b61339c1b16f9c255633697d68c3ec5e5
--- /dev/null
+++ b/assets/test/13.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:600a57dcad4c4a1f06b0d52aac68e762c42ee86f17c6ce0c9a2f59b62df00b54
+size 1768898
diff --git a/assets/test/14.png b/assets/test/14.png
new file mode 100644
index 0000000000000000000000000000000000000000..9fe23cc235caabe20470240b962557f71bf21a8a
--- /dev/null
+++ b/assets/test/14.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dccf9a1727df51d81d7dd7f4168f8db4df40123f453474bc84dee3cd97bdde82
+size 1989606
diff --git a/assets/test/15.png b/assets/test/15.png
new file mode 100644
index 0000000000000000000000000000000000000000..0731bb00e1a8e7ce7349feccfcad8f98e0096f9c
--- /dev/null
+++ b/assets/test/15.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6025c32c9a0bf8d88c2f66985234a6a283ebd3cecabe16be1be0e1f9d89147f9
+size 1113661
diff --git a/assets/test/16.png b/assets/test/16.png
new file mode 100644
index 0000000000000000000000000000000000000000..3f48cf6cc141087bb3f3e611342bb823d4f05114
--- /dev/null
+++ b/assets/test/16.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:817ca70e9af0ad89af747fb389e2be833aefc3becbe9d86506bd53452cea12c8
+size 2265160
diff --git a/assets/test/17.png b/assets/test/17.png
new file mode 100644
index 0000000000000000000000000000000000000000..41bf8216b6f87a4ed2543d6dbbfd31da9c37a8f8
--- /dev/null
+++ b/assets/test/17.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4de91821f1e3d7db6834ba679aa1b7f78c7a52f5ce9c2cd0c44ccbf6d649bb7d
+size 603365
diff --git a/assets/test/18.png b/assets/test/18.png
new file mode 100644
index 0000000000000000000000000000000000000000..8504961d3657175523ffc7733bb816825139a67d
--- /dev/null
+++ b/assets/test/18.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3cad5e839a8c54334c28aa08aec2d185b2ce3e0a96c0f8dca901d7bc8535079c
+size 1425119
diff --git a/assets/test/19.png b/assets/test/19.png
new file mode 100644
index 0000000000000000000000000000000000000000..2bc3df887c72ff0cc7ac92f223e2c80cd2d5d1ff
--- /dev/null
+++ b/assets/test/19.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f69ec7094bdb98bf69d8345c4cc39e00648cc74721879d43c784ad774786e6af
+size 1171141
diff --git a/assets/test/2.png b/assets/test/2.png
new file mode 100644
index 0000000000000000000000000000000000000000..2c186cf08bf92cacfc5b54757bec340b5d3a0694
--- /dev/null
+++ b/assets/test/2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c01e935783d7a81eb9024c6b0ca450fe3d69cdf92a816ae1e03ff57c78024875
+size 848588
diff --git a/assets/test/20.png b/assets/test/20.png
new file mode 100644
index 0000000000000000000000000000000000000000..4d7e40882784eb7bafeae6b9f1ffa9adf68a967b
--- /dev/null
+++ b/assets/test/20.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:97ba5a7b42ec317aee9ee4d3cdbd62ac507c3554f4c1dca7ce0f635039cc16e3
+size 601441
diff --git a/assets/test/21.png b/assets/test/21.png
new file mode 100644
index 0000000000000000000000000000000000000000..a8b79f3a2a8ebabfad60821e6362b25bf035c395
--- /dev/null
+++ b/assets/test/21.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9ae71e62d5893669426d94226fe376f4bc4269dfdb83a465e0510b6d8002deb5
+size 808986
diff --git a/assets/test/22.png b/assets/test/22.png
new file mode 100644
index 0000000000000000000000000000000000000000..99907ec3d36814eb5e58ca9a0c2c9dc98dec1998
--- /dev/null
+++ b/assets/test/22.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ff910bc20816d8a55ff08635a0228abbe821f8cd8359c86ef3f044a16cebe6f7
+size 1175586
diff --git a/assets/test/23.png b/assets/test/23.png
new file mode 100644
index 0000000000000000000000000000000000000000..8da4acda1d51cc83eacb22c46a6fb9602940b51c
--- /dev/null
+++ b/assets/test/23.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1f4c57797cea4d29e465d556b63c2d367fa0f4469fd3adf31ff686665cce38a1
+size 834642
diff --git a/assets/test/24.png b/assets/test/24.png
new file mode 100644
index 0000000000000000000000000000000000000000..561190e7fd82ffc201824a4832f85521e7101c15
--- /dev/null
+++ b/assets/test/24.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ab1feb4c8b0e507f16248c65826961820b0367ea897705c8cc19a802824a2a06
+size 1805891
diff --git a/assets/test/25.png b/assets/test/25.png
new file mode 100644
index 0000000000000000000000000000000000000000..23ea0c33413ea8cab436d7d12c150782cdb243c1
--- /dev/null
+++ b/assets/test/25.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dddb4e689ef77d87c3849e021e322f9b30ce384ed83ba60c6b0e4b55ec5202f2
+size 1193255
diff --git a/assets/test/26.png b/assets/test/26.png
new file mode 100644
index 0000000000000000000000000000000000000000..ec3004593f76ff2c199756915c6a78cda03148ab
--- /dev/null
+++ b/assets/test/26.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:96e0987b6d3c0e7a883667dfb2e52461fad04a621a43a95a2ec1544c0e5b8683
+size 1037838
diff --git a/assets/test/27.png b/assets/test/27.png
new file mode 100644
index 0000000000000000000000000000000000000000..974a3047b590d8d67b7a02eb3d74d9b2fc791d74
--- /dev/null
+++ b/assets/test/27.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:711cd06fec16402e73c01bd80e202ae91b3f18c3d9372b57447c00d2a4491d03
+size 1489023
diff --git a/assets/test/28.png b/assets/test/28.png
new file mode 100644
index 0000000000000000000000000000000000000000..99a214e54a7098ddb87b15acfd802d495aea7493
--- /dev/null
+++ b/assets/test/28.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c87191606ad374fc506857050f44e8d8af0e8f1227391cd98bab6eb5bf7dc1c2
+size 1637020
diff --git a/assets/test/29.png b/assets/test/29.png
new file mode 100644
index 0000000000000000000000000000000000000000..843926b25a4eba8be295be228caacf38725b8eb6
--- /dev/null
+++ b/assets/test/29.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9b0f9ba7069e600b470fa8ba3cfcff890a9bd5d481eaab7d25fe190e69fdf4ac
+size 1579767
diff --git a/assets/test/3.png b/assets/test/3.png
new file mode 100644
index 0000000000000000000000000000000000000000..7f86106be5d6ae8f37f27d6d97e87a3651afba9b
--- /dev/null
+++ b/assets/test/3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:16af3c8ee909ef2a5e13f996a9b79029bc385c1ccbd9a8cf0613081ead844e63
+size 1384607
diff --git a/assets/test/30.png b/assets/test/30.png
new file mode 100644
index 0000000000000000000000000000000000000000..d5b33df498d2355533acf6b7a2d1536562e9c855
--- /dev/null
+++ b/assets/test/30.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f7c925eefcf85714df9fbf4d602e9ad604d7efc2a87b57d72741049737dc9747
+size 2476255
diff --git a/assets/test/31.png b/assets/test/31.png
new file mode 100644
index 0000000000000000000000000000000000000000..16b8e97073d4f77bb6eec6f810389612b37535f3
--- /dev/null
+++ b/assets/test/31.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:873a12cd691b851939e1558d203021462543f0cb8167d65fa468ec93793192b6
+size 1480380
diff --git a/assets/test/32.png b/assets/test/32.png
new file mode 100644
index 0000000000000000000000000000000000000000..07810ae63439da49dcf6d3abb0bc9acb348ba864
--- /dev/null
+++ b/assets/test/32.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aa8ff1fcc5b43d9382bb196665b624d19ce5a5204b618c9adb87520a460efee6
+size 1114250
diff --git a/assets/test/33.png b/assets/test/33.png
new file mode 100644
index 0000000000000000000000000000000000000000..44658a0bc404a3a2b0d42106b6b7e37c0f6dc42e
--- /dev/null
+++ b/assets/test/33.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7bac56a98e7ff510cfb628806c11abee90ea7f66dd6997a1e690a30aa2af7560
+size 1025378
diff --git a/assets/test/34.png b/assets/test/34.png
new file mode 100644
index 0000000000000000000000000000000000000000..d268b7332790b789c1efbc9603dd3b36801f77ef
--- /dev/null
+++ b/assets/test/34.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3ed6537fe32aa665fe6ff180cbd08ccc3dca87da382e4165e4e8865e26d80bf3
+size 2005294
diff --git a/assets/test/35.png b/assets/test/35.png
new file mode 100644
index 0000000000000000000000000000000000000000..d4f189675bc888238896eab93f4aefa6e2e32a9b
--- /dev/null
+++ b/assets/test/35.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:86a171e37a3d781e7215977f565cd63e813341c1f89e2c586fa61937e4ed6916
+size 481919
diff --git a/assets/test/36.png b/assets/test/36.png
new file mode 100644
index 0000000000000000000000000000000000000000..3dddee1020e052155d2e8f404d982f45786c5d07
--- /dev/null
+++ b/assets/test/36.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:27a418853eefa197f1e10ed944a7bb071413fd2bc1681804ee773a6ce3799c52
+size 712062
diff --git a/assets/test/37.png b/assets/test/37.png
new file mode 100644
index 0000000000000000000000000000000000000000..18e37dc2fa5cc9fd7cd4003828e85514f0b0f780
--- /dev/null
+++ b/assets/test/37.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aecbc5712f300ec67fb01d79cd758e7cb5da4c11c9e19d4a8e72d05275016766
+size 1890500
diff --git a/assets/test/38.png b/assets/test/38.png
new file mode 100644
index 0000000000000000000000000000000000000000..293dc7dc50f3fc6fb884c8a0ba64a86df178a1b8
--- /dev/null
+++ b/assets/test/38.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:98bab768078960656fab3755af69954055fa4c65561376085e7a80431dbc7b9d
+size 1792829
diff --git a/assets/test/39.png b/assets/test/39.png
new file mode 100644
index 0000000000000000000000000000000000000000..92dfce03092f39dda00ba86f9499c280bc26f506
--- /dev/null
+++ b/assets/test/39.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:66b568e2951c2db7c59a73b4498cac011535b414f7d7d4327235039467f5429f
+size 2004638
diff --git a/assets/test/4.png b/assets/test/4.png
new file mode 100644
index 0000000000000000000000000000000000000000..955991400960ab8c6ca8f6b876097823be39f940
--- /dev/null
+++ b/assets/test/4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5546b5268a51f20dc15eaba836ba09954ec98fb347f5c0b4baefccac0f596757
+size 1505047
diff --git a/assets/test/40.png b/assets/test/40.png
new file mode 100644
index 0000000000000000000000000000000000000000..a1a72358f3aa124cf5835b0460205524d3c27d0b
--- /dev/null
+++ b/assets/test/40.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e75bebfbefd8f67818c5e051cac3cb4853b418974a364fff8a14c9a4b7e7eba8
+size 686837
diff --git a/assets/test/41.png b/assets/test/41.png
new file mode 100644
index 0000000000000000000000000000000000000000..bdf9dc4534a26b05e13e42f339f65b9caa76df28
--- /dev/null
+++ b/assets/test/41.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a95ac40d2ae1c978baa64c4e41c1c68ee5742c71236ae330de6e6dbb148dc84
+size 611258
diff --git a/assets/test/42.png b/assets/test/42.png
new file mode 100644
index 0000000000000000000000000000000000000000..3a0e2fe912c7f5818a6efa34e8951e37820d24f0
--- /dev/null
+++ b/assets/test/42.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0449cf9ab9dbe9b044585da4b940b9ad2d2fda158a6eea3a8abf3555f9e4bc9d
+size 1687128
diff --git a/assets/test/43.png b/assets/test/43.png
new file mode 100644
index 0000000000000000000000000000000000000000..d1e9fd9fa970998582cd58e3b6e0b7544547022e
--- /dev/null
+++ b/assets/test/43.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5c6928cdeb8ee4dcef2d4c4e6b5dd0c69f2ecc3885daeb1cdf9bfbc40da0c01e
+size 1447356
diff --git a/assets/test/44.png b/assets/test/44.png
new file mode 100644
index 0000000000000000000000000000000000000000..d5b9eb33cb3aecab7070e13e6a543f52a8c3aef8
--- /dev/null
+++ b/assets/test/44.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:507abeb3f7698e5e929152b4878152626aa1270554bfa5a84ff74dc6114ea3c1
+size 1978541
diff --git a/assets/test/45.png b/assets/test/45.png
new file mode 100644
index 0000000000000000000000000000000000000000..a702921607bb16938c02c99766d49abf4fa880ab
--- /dev/null
+++ b/assets/test/45.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e89b245fbe25abe19d817866a2c38b3b2c7d130a6b21b1ec6391975671d4d06d
+size 1951359
diff --git a/assets/test/46.png b/assets/test/46.png
new file mode 100644
index 0000000000000000000000000000000000000000..a6f3bf5162ef0d3776be36a6d0a2ac061f56a75c
--- /dev/null
+++ b/assets/test/46.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ddddfdeb15b82d85b90684b2aba221705283894f2f69532fcf84512107b85f67
+size 1162177
diff --git a/assets/test/47.png b/assets/test/47.png
new file mode 100644
index 0000000000000000000000000000000000000000..aa56b73dbb206ba5dfc03b08f61d87110b64ee0b
--- /dev/null
+++ b/assets/test/47.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:060a286bffcf9ee71404b62f240253788ea4b15b026f3583b06111070027fa0d
+size 1896047
diff --git a/assets/test/48.png b/assets/test/48.png
new file mode 100644
index 0000000000000000000000000000000000000000..a599f773bcb21f82bccfa1a5d0fde5c4a324dff0
--- /dev/null
+++ b/assets/test/48.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b28719959fc116464429e1c677f6676f2b37e8ccd592c4f1d08400bc48f3adac
+size 2301834
diff --git a/assets/test/49.png b/assets/test/49.png
new file mode 100644
index 0000000000000000000000000000000000000000..66a1406029617fdf2557efb8a8011ed08a6a0810
--- /dev/null
+++ b/assets/test/49.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:94f36bdae81c4c7cf6b4a402b704bb29aa1a56ca6cee7671a41ca041c0560203
+size 2206909
diff --git a/assets/test/5.png b/assets/test/5.png
new file mode 100644
index 0000000000000000000000000000000000000000..61571a7ac4d2ce6e8e0d234cc6f0a8e13218ccab
--- /dev/null
+++ b/assets/test/5.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:412c5f398bab3dc08395a8140777d883b3291eda08e2d6346cb93b4aeb42407e
+size 1638078
diff --git a/assets/test/50.png b/assets/test/50.png
new file mode 100644
index 0000000000000000000000000000000000000000..88a413563ecb21ac5b83c541b89bb27d7772d5dd
--- /dev/null
+++ b/assets/test/50.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:45e7810d3358274a0f32a4ecaf57a04e88d3438c46cd0a6bd76a3e4bcf149e6b
+size 2293203
diff --git a/assets/test/51.png b/assets/test/51.png
new file mode 100644
index 0000000000000000000000000000000000000000..95f0d876e8380f7560d0d4aaab7d3db44fae7e15
--- /dev/null
+++ b/assets/test/51.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a9cea2e8947d4c57a95157836fc4a7c0d36335748f394af01a9f27f6b933ea82
+size 1359226
diff --git a/assets/test/52.png b/assets/test/52.png
new file mode 100644
index 0000000000000000000000000000000000000000..4bd172e1211c480e2e9b39613d897f0acfdc5678
--- /dev/null
+++ b/assets/test/52.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:402758b579b0d96043c613c93125a59e183b6a06acb9c0fad7dc30f1c1a3d40c
+size 862029
diff --git a/assets/test/53.png b/assets/test/53.png
new file mode 100644
index 0000000000000000000000000000000000000000..38fdc4c21ef6f7f5d8623948d43d81acfafd2708
--- /dev/null
+++ b/assets/test/53.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fdd82d60b7ec11e6d5699df29693d8ab538f9dab4b04e3f2abaa59ccd7b4709a
+size 226408
diff --git a/assets/test/54.png b/assets/test/54.png
new file mode 100644
index 0000000000000000000000000000000000000000..ff256723431c2729c3eff55f067fb6f7120d47ee
--- /dev/null
+++ b/assets/test/54.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5793b52a21507a5cdd661f8d87681b303e420057a8c65aaa16e0560409a7a34e
+size 724846
diff --git a/assets/test/55.png b/assets/test/55.png
new file mode 100644
index 0000000000000000000000000000000000000000..4e82000c43b029e236b85b135dbba66ea04d4871
--- /dev/null
+++ b/assets/test/55.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4a6ff25b573a607605c607815f99b566002f47dc5f69cfe145248b8c0a4855f5
+size 883754
diff --git a/assets/test/56.png b/assets/test/56.png
new file mode 100644
index 0000000000000000000000000000000000000000..7cf47bc68a3197540c765c8f8e246848c00e0e49
--- /dev/null
+++ b/assets/test/56.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:43a71e868ab4a5b174ceb5ceaa0a7e76f6a0a635acc246eaaa7b83da47357b82
+size 1041398
diff --git a/assets/test/57.png b/assets/test/57.png
new file mode 100644
index 0000000000000000000000000000000000000000..fb69455a8960874fba38700b2c9a495704d0a283
--- /dev/null
+++ b/assets/test/57.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5677117de95f646acd0f4455d32bae8cabb2fef7567a64bab9c8638e8c98cbbb
+size 862063
diff --git a/assets/test/58.png b/assets/test/58.png
new file mode 100644
index 0000000000000000000000000000000000000000..35285ba5c6f01bbcbee4a085e8c3e2379446ef19
--- /dev/null
+++ b/assets/test/58.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b235ed827731813445fd2413aaddb7bef9523d20d23ce1dff7d39fd2bc6829da
+size 248153
diff --git a/assets/test/59.png b/assets/test/59.png
new file mode 100644
index 0000000000000000000000000000000000000000..0784bcdfd94be087b1fbc441a25c302cf2b5f361
--- /dev/null
+++ b/assets/test/59.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:73d09b20fb4e7ab6513a3ac99f7c87166ec655a731bd1be9b14cec6aef36c4ef
+size 1509264
diff --git a/assets/test/6.png b/assets/test/6.png
new file mode 100644
index 0000000000000000000000000000000000000000..9238b926417cf808022822562f148dcdc4213e2f
--- /dev/null
+++ b/assets/test/6.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7c5472bb12b775c18b37fcb889069915943c4b6122fda668d3a1df8003e9e5da
+size 1986433
diff --git a/assets/test/60.png b/assets/test/60.png
new file mode 100644
index 0000000000000000000000000000000000000000..4542fab15508116ea9c49d764a997263757f5f73
--- /dev/null
+++ b/assets/test/60.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:adba4a0efaad6e22b667b516a2724d62a40b09726578cd71dc3b2abcfcb7c5a2
+size 437697
diff --git a/assets/test/61.png b/assets/test/61.png
new file mode 100644
index 0000000000000000000000000000000000000000..1546f322acc31caeb524a518979afbffe8de197f
--- /dev/null
+++ b/assets/test/61.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d7e716abe8f8895080f562d1dc26b14fa0e20a05aa5beb2770c6fb3b87b3476a
+size 594232
diff --git a/assets/test/62.png b/assets/test/62.png
new file mode 100644
index 0000000000000000000000000000000000000000..4c633d74618978d918461428d1d76df8805c7554
--- /dev/null
+++ b/assets/test/62.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:70f86ca18bf83b67e57d74ba01eb045047fe407353a69f9d1bc75015d1606346
+size 591829
diff --git a/assets/test/63.png b/assets/test/63.png
new file mode 100644
index 0000000000000000000000000000000000000000..c966e77ebf366c679ad749dba14a06616f3d2c89
--- /dev/null
+++ b/assets/test/63.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8799fd1b56df540f2c7471312be5959fbf49d33d2386e59b41cb7138ad4694b3
+size 646495
diff --git a/assets/test/64.png b/assets/test/64.png
new file mode 100644
index 0000000000000000000000000000000000000000..3536dd866f2332d60eb149cc1243ea5c5392e228
--- /dev/null
+++ b/assets/test/64.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:97ba7df8adf0e956c5cbf85fafb127b9b21a7873c33629a476af40bac91d5655
+size 397994
diff --git a/assets/test/65.png b/assets/test/65.png
new file mode 100644
index 0000000000000000000000000000000000000000..9f29b4f1fa5537ec16bbf9835cae5bb9e75640de
--- /dev/null
+++ b/assets/test/65.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:907cacfd8e66fd8424389e6560e66ecdd98544f31076ad1e0c9b6434a47a9747
+size 980850
diff --git a/assets/test/66.png b/assets/test/66.png
new file mode 100644
index 0000000000000000000000000000000000000000..51b786a1d5cd52c9e7207961693172800ca336b8
--- /dev/null
+++ b/assets/test/66.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b8236aeb0cf0923df29e3d76f3cbe6cc61c9da85ea4b43a4e7cf81614fe750cd
+size 1182506
diff --git a/assets/test/67.png b/assets/test/67.png
new file mode 100644
index 0000000000000000000000000000000000000000..43173a46412aec0a4109bdc408171e383262042e
--- /dev/null
+++ b/assets/test/67.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:134136dd4086cfc1b887ab0a134c4a2b906223762a0d5959a8b90cc68f11f4f0
+size 1490139
diff --git a/assets/test/7.png b/assets/test/7.png
new file mode 100644
index 0000000000000000000000000000000000000000..754a251ab29736f9cb564a2c43b7b4f47317f4e5
--- /dev/null
+++ b/assets/test/7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f820d9e9fe86192e6e1a70dc2430f3189bf620aff6e16aa817e6b99da286c424
+size 990354
diff --git a/assets/test/8.png b/assets/test/8.png
new file mode 100644
index 0000000000000000000000000000000000000000..76541834fc6e5e4120c225238e412596e9199ec2
--- /dev/null
+++ b/assets/test/8.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8bd5e913171ad456d607442af4ad89bd6e0da6cd819aedcd76d0fa1c2d4ff655
+size 1161291
diff --git a/assets/test/9.png b/assets/test/9.png
new file mode 100644
index 0000000000000000000000000000000000000000..0f16b1297ac470b4a444965eac7748723ee027dd
--- /dev/null
+++ b/assets/test/9.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c22e9e601fead00398634e083e2ed1e90d0c60f459757ae19f5909f3ee5481aa
+size 1029084
diff --git a/direct3d_s2/__pycache__/pipeline.cpython-310.pyc b/direct3d_s2/__pycache__/pipeline.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d7d119d458e52959e75cd802b7a58f1ee690fb8e
Binary files /dev/null and b/direct3d_s2/__pycache__/pipeline.cpython-310.pyc differ
diff --git a/direct3d_s2/models/__init__.py b/direct3d_s2/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/direct3d_s2/models/__pycache__/__init__.cpython-310.pyc b/direct3d_s2/models/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..35967be27d8f00a7776ebecaa3586e0988a4fd1f
Binary files /dev/null and b/direct3d_s2/models/__pycache__/__init__.cpython-310.pyc differ
diff --git a/direct3d_s2/models/__pycache__/conditioner.cpython-310.pyc b/direct3d_s2/models/__pycache__/conditioner.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5208e75587b28c6621a1973cb5ce3d90e4bdd2cb
Binary files /dev/null and b/direct3d_s2/models/__pycache__/conditioner.cpython-310.pyc differ
diff --git a/direct3d_s2/models/autoencoders/__pycache__/base.cpython-310.pyc b/direct3d_s2/models/autoencoders/__pycache__/base.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..28713e652176cfb9ced950d4f71f8dabf564205b
Binary files /dev/null and b/direct3d_s2/models/autoencoders/__pycache__/base.cpython-310.pyc differ
diff --git a/direct3d_s2/models/autoencoders/__pycache__/decoder.cpython-310.pyc b/direct3d_s2/models/autoencoders/__pycache__/decoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b8a4507a17cd98df6fb5f411b20d07c00f566f1b
Binary files /dev/null and b/direct3d_s2/models/autoencoders/__pycache__/decoder.cpython-310.pyc differ
diff --git a/direct3d_s2/models/autoencoders/__pycache__/dense_vae.cpython-310.pyc b/direct3d_s2/models/autoencoders/__pycache__/dense_vae.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ce375b2d5118ca2ecce50444ff96216e9d9bb744
Binary files /dev/null and b/direct3d_s2/models/autoencoders/__pycache__/dense_vae.cpython-310.pyc differ
diff --git a/direct3d_s2/models/autoencoders/__pycache__/distributions.cpython-310.pyc b/direct3d_s2/models/autoencoders/__pycache__/distributions.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fc13439a4a922a3ae606562e28746aa0999a762f
Binary files /dev/null and b/direct3d_s2/models/autoencoders/__pycache__/distributions.cpython-310.pyc differ
diff --git a/direct3d_s2/models/autoencoders/__pycache__/encoder.cpython-310.pyc b/direct3d_s2/models/autoencoders/__pycache__/encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5c8314f083a3dd0b61a39b35a9153c029136e1e0
Binary files /dev/null and b/direct3d_s2/models/autoencoders/__pycache__/encoder.cpython-310.pyc differ
diff --git a/direct3d_s2/models/autoencoders/__pycache__/ss_vae.cpython-310.pyc b/direct3d_s2/models/autoencoders/__pycache__/ss_vae.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1377751eb6d3e498ed44126e3822f693447a8fc6
Binary files /dev/null and b/direct3d_s2/models/autoencoders/__pycache__/ss_vae.cpython-310.pyc differ
diff --git a/direct3d_s2/models/autoencoders/base.py b/direct3d_s2/models/autoencoders/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4ffeea9be7a588d0de1de9aae07ba00dbca99e1
--- /dev/null
+++ b/direct3d_s2/models/autoencoders/base.py
@@ -0,0 +1,118 @@
+from typing import *
+import torch
+import torch.nn as nn
+
+from ...modules.utils import convert_module_to_f16, convert_module_to_f32
+from ...modules import sparse as sp
+from ...modules.transformer import AbsolutePositionEmbedder
+from ...modules.sparse.transformer import SparseTransformerBlock
+
+
+def block_attn_config(self):
+ """
+ Return the attention configuration of the model.
+ """
+ for i in range(self.num_blocks):
+ if self.attn_mode == "shift_window":
+ yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER
+ elif self.attn_mode == "shift_sequence":
+ yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER
+ elif self.attn_mode == "shift_order":
+ yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4]
+ elif self.attn_mode == "full":
+ yield "full", None, None, None, None
+ elif self.attn_mode == "swin":
+ yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None
+
+
+class SparseTransformerBase(nn.Module):
+ """
+ Sparse Transformer without output layers.
+    Serves as the base class for the encoder and decoder.
+ """
+ def __init__(
+ self,
+ in_channels: int,
+ model_channels: int,
+ num_blocks: int,
+ num_heads: Optional[int] = None,
+ num_head_channels: Optional[int] = 64,
+ mlp_ratio: float = 4.0,
+ attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
+ window_size: Optional[int] = None,
+ pe_mode: Literal["ape", "rope"] = "ape",
+ use_fp16: bool = False,
+ use_checkpoint: bool = False,
+ qk_rms_norm: bool = False,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.num_blocks = num_blocks
+ self.window_size = window_size
+ self.num_heads = num_heads or model_channels // num_head_channels
+ self.mlp_ratio = mlp_ratio
+ self.attn_mode = attn_mode
+ self.pe_mode = pe_mode
+ self.use_fp16 = use_fp16
+ self.use_checkpoint = use_checkpoint
+ self.qk_rms_norm = qk_rms_norm
+ self.dtype = torch.float16 if use_fp16 else torch.float32
+
+ if pe_mode == "ape":
+ self.pos_embedder = AbsolutePositionEmbedder(model_channels)
+ self.input_layer = sp.SparseLinear(in_channels, model_channels)
+ self.blocks = nn.ModuleList([
+ SparseTransformerBlock(
+ model_channels,
+ num_heads=self.num_heads,
+ mlp_ratio=self.mlp_ratio,
+ attn_mode=attn_mode,
+ window_size=window_size,
+ shift_sequence=shift_sequence,
+ shift_window=shift_window,
+ serialize_mode=serialize_mode,
+ use_checkpoint=self.use_checkpoint,
+ use_rope=(pe_mode == "rope"),
+ qk_rms_norm=self.qk_rms_norm,
+ )
+ for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self)
+ ])
+
+ @property
+ def device(self) -> torch.device:
+ """
+ Return the device of the model.
+ """
+ return next(self.parameters()).device
+
+ def convert_to_fp16(self) -> None:
+ """
+ Convert the torso of the model to float16.
+ """
+ # self.blocks.apply(convert_module_to_f16)
+ self.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self) -> None:
+ """
+ Convert the torso of the model to float32.
+ """
+ self.blocks.apply(convert_module_to_f32)
+
+ def initialize_weights(self) -> None:
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ self.apply(_basic_init)
+
+ def forward(self, x: sp.SparseTensor, factor: float = None) -> sp.SparseTensor:
+ h = self.input_layer(x)
+ if self.pe_mode == "ape":
+ h = h + self.pos_embedder(x.coords[:, 1:], factor)
+ h = h.type(self.dtype)
+ for block in self.blocks:
+ h = block(h)
+ return h
\ No newline at end of file
diff --git a/direct3d_s2/models/autoencoders/decoder.py b/direct3d_s2/models/autoencoders/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..209cccf24f80dfcd962610c31811dadfeaf50812
--- /dev/null
+++ b/direct3d_s2/models/autoencoders/decoder.py
@@ -0,0 +1,353 @@
+from typing import *
+import random
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
+from ...modules import sparse as sp
+from .base import SparseTransformerBase
+
+
+class SparseSubdivideBlock3d(nn.Module):
+
+ def __init__(
+ self,
+ channels: int,
+ out_channels: Optional[int] = None,
+ use_checkpoint: bool = False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_checkpoint = use_checkpoint
+
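+        # conv + SiLU, then SparseSubdivide (2x resolution per axis), then a second conv + SiLU on the finer grid.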
+ self.act_layers = nn.Sequential(
+ sp.SparseConv3d(channels, self.out_channels, 3, padding=1),
+ sp.SparseSiLU()
+ )
+
+ self.sub = sp.SparseSubdivide()
+
+ self.out_layers = nn.Sequential(
+ sp.SparseConv3d(self.out_channels, self.out_channels, 3, padding=1),
+ sp.SparseSiLU(),
+ )
+
+ def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
+ h = self.act_layers(x)
+ h = self.sub(h)
+ h = self.out_layers(h)
+ return h
+
+ def forward(self, x: torch.Tensor):
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
+ else:
+ return self._forward(x)
+
+
+class SparseSDFDecoder(SparseTransformerBase):
+ def __init__(
+ self,
+ resolution: int,
+ model_channels: int,
+ latent_channels: int,
+ num_blocks: int,
+ num_heads: Optional[int] = None,
+ num_head_channels: Optional[int] = 64,
+ mlp_ratio: float = 4,
+ attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
+ window_size: int = 8,
+ pe_mode: Literal["ape", "rope"] = "ape",
+ use_fp16: bool = False,
+ use_checkpoint: bool = False,
+ qk_rms_norm: bool = False,
+ representation_config: dict = None,
+ out_channels: int = 1,
+ chunk_size: int = 1,
+ ):
+ super().__init__(
+ in_channels=latent_channels,
+ model_channels=model_channels,
+ num_blocks=num_blocks,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ mlp_ratio=mlp_ratio,
+ attn_mode=attn_mode,
+ window_size=window_size,
+ pe_mode=pe_mode,
+ use_fp16=use_fp16,
+ use_checkpoint=use_checkpoint,
+ qk_rms_norm=qk_rms_norm,
+ )
+ self.resolution = resolution
+ self.rep_config = representation_config
+ self.out_channels = out_channels
+ self.chunk_size = chunk_size
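+        # The three subdivide blocks below each double the sparse resolution, for an overall
+        # 8x upsampling (matching the hard-coded upsample_ratio = 8 in the chunked meshing helpers).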
+ self.upsample = nn.ModuleList([
+ SparseSubdivideBlock3d(
+ channels=model_channels,
+ out_channels=model_channels // 4,
+ use_checkpoint=use_checkpoint,
+ ),
+ SparseSubdivideBlock3d(
+ channels=model_channels // 4,
+ out_channels=model_channels // 8,
+ use_checkpoint=use_checkpoint,
+ ),
+ SparseSubdivideBlock3d(
+ channels=model_channels // 8,
+ out_channels=model_channels // 16,
+ use_checkpoint=use_checkpoint,
+ )
+ ])
+
+ self.out_layer = sp.SparseLinear(model_channels // 16, self.out_channels)
+ self.out_active = sp.SparseTanh()
+
+ self.initialize_weights()
+ if use_fp16:
+ self.convert_to_fp16()
+
+ def initialize_weights(self) -> None:
+ super().initialize_weights()
+ # Zero-out output layers:
+ nn.init.constant_(self.out_layer.weight, 0)
+ nn.init.constant_(self.out_layer.bias, 0)
+
+ def convert_to_fp16(self) -> None:
+ """
+ Convert the torso of the model to float16.
+ """
+ super().convert_to_fp16()
+ self.upsample.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self) -> None:
+ """
+ Convert the torso of the model to float32.
+ """
+ super().convert_to_fp32()
+ self.upsample.apply(convert_module_to_f32)
+
+ @torch.no_grad()
+ def split_for_meshing(self, x: sp.SparseTensor, chunk_size=4, padding=4):
+
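+        # Split the latent volume into chunk_size**3 overlapping sub-volumes (padding voxels of
+        # overlap per side) so the subsequent 8x upsampling can run chunk-by-chunk within memory;
+        # each chunk stores its unpadded bounds and offsets so the overlap is cropped on merge.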
+ sub_resolution = self.resolution // chunk_size
+ upsample_ratio = 8 # hard-coded here
+ assert sub_resolution % padding == 0
+ out = []
+
+ for i in range(chunk_size):
+ for j in range(chunk_size):
+ for k in range(chunk_size):
+ # Calculate padded boundaries
+ start_x = max(0, i * sub_resolution - padding)
+ end_x = min((i + 1) * sub_resolution + padding, self.resolution)
+ start_y = max(0, j * sub_resolution - padding)
+ end_y = min((j + 1) * sub_resolution + padding, self.resolution)
+ start_z = max(0, k * sub_resolution - padding)
+ end_z = min((k + 1) * sub_resolution + padding, self.resolution)
+
+ # Store original (unpadded) boundaries for later cropping
+ orig_start_x = i * sub_resolution
+ orig_end_x = (i + 1) * sub_resolution
+ orig_start_y = j * sub_resolution
+ orig_end_y = (j + 1) * sub_resolution
+ orig_start_z = k * sub_resolution
+ orig_end_z = (k + 1) * sub_resolution
+
+ mask = torch.logical_and(
+ torch.logical_and(
+ torch.logical_and(x.coords[:, 1] >= start_x, x.coords[:, 1] < end_x),
+ torch.logical_and(x.coords[:, 2] >= start_y, x.coords[:, 2] < end_y)
+ ),
+ torch.logical_and(x.coords[:, 3] >= start_z, x.coords[:, 3] < end_z)
+ )
+
+ if mask.sum() > 0:
+ # Get the coordinates and shift them to local space
+ coords = x.coords[mask].clone()
+ # Shift to local coordinates
+ coords[:, 1:] = coords[:, 1:] - torch.tensor([start_x, start_y, start_z],
+ device=coords.device).view(1, 3)
+
+ chunk_tensor = sp.SparseTensor(x.feats[mask], coords)
+ # Store the boundaries and offsets as metadata for later reconstruction
+ chunk_tensor.bounds = {
+ 'original': (orig_start_x * upsample_ratio, orig_end_x * upsample_ratio + (upsample_ratio - 1), orig_start_y * upsample_ratio, orig_end_y * upsample_ratio + (upsample_ratio - 1), orig_start_z * upsample_ratio, orig_end_z * upsample_ratio + (upsample_ratio - 1)),
+ 'offsets': (start_x * upsample_ratio, start_y * upsample_ratio, start_z * upsample_ratio) # Store offsets for reconstruction
+ }
+ out.append(chunk_tensor)
+
+ del mask
+ torch.cuda.empty_cache()
+ return out
+
+ @torch.no_grad()
+ def split_single_chunk(self, x: sp.SparseTensor, chunk_size=4, padding=4):
+ sub_resolution = self.resolution // chunk_size
+ upsample_ratio = 8 # hard-coded here
+ assert sub_resolution % padding == 0
+
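+        # Randomly sample one padded sub-volume until it contains at least one occupied voxel;
+        # the training path decodes only this chunk instead of the full volume.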
+ mask_sum = -1
+ while mask_sum < 1:
+ orig_start_x = random.randint(0, self.resolution - sub_resolution)
+ orig_end_x = orig_start_x + sub_resolution
+ orig_start_y = random.randint(0, self.resolution - sub_resolution)
+ orig_end_y = orig_start_y + sub_resolution
+ orig_start_z = random.randint(0, self.resolution - sub_resolution)
+ orig_end_z = orig_start_z + sub_resolution
+ start_x = max(0, orig_start_x - padding)
+ end_x = min(orig_end_x + padding, self.resolution)
+ start_y = max(0, orig_start_y - padding)
+ end_y = min(orig_end_y + padding, self.resolution)
+ start_z = max(0, orig_start_z - padding)
+ end_z = min(orig_end_z + padding, self.resolution)
+
+ mask_ori = torch.logical_and(
+ torch.logical_and(
+ torch.logical_and(x.coords[:, 1] >= orig_start_x, x.coords[:, 1] < orig_end_x),
+ torch.logical_and(x.coords[:, 2] >= orig_start_y, x.coords[:, 2] < orig_end_y)
+ ),
+ torch.logical_and(x.coords[:, 3] >= orig_start_z, x.coords[:, 3] < orig_end_z)
+ )
+ mask_sum = mask_ori.sum()
+
+ # Store the boundaries and offsets as metadata for later reconstruction
+ bounds = {
+ 'original': (orig_start_x * upsample_ratio, orig_end_x * upsample_ratio + (upsample_ratio - 1), orig_start_y * upsample_ratio, orig_end_y * upsample_ratio + (upsample_ratio - 1), orig_start_z * upsample_ratio, orig_end_z * upsample_ratio + (upsample_ratio - 1)),
+ 'start': (start_x, end_x, start_y, end_y, start_z, end_z),
+ 'offsets': (start_x * upsample_ratio, start_y * upsample_ratio, start_z * upsample_ratio) # Store offsets for reconstruction
+ }
+ return bounds
+
+ def forward_single_chunk(self, x: sp.SparseTensor, padding=4):
+
+ bounds = self.split_single_chunk(x, self.chunk_size, padding=padding)
+
+ start_x, end_x, start_y, end_y, start_z, end_z = bounds['start']
+ mask = torch.logical_and(
+ torch.logical_and(
+ torch.logical_and(x.coords[:, 1] >= start_x, x.coords[:, 1] < end_x),
+ torch.logical_and(x.coords[:, 2] >= start_y, x.coords[:, 2] < end_y)
+ ),
+ torch.logical_and(x.coords[:, 3] >= start_z, x.coords[:, 3] < end_z)
+ )
+
+ # Shift to local coordinates
+ coords = x.coords.clone()
+ coords[:, 1:] = coords[:, 1:] - torch.tensor([start_x, start_y, start_z],
+ device=coords.device).view(1, 3)
+
+ chunk = sp.SparseTensor(x.feats[mask], coords[mask])
+
+ chunk_result = self.upsamples(chunk)
+
+ coords = chunk_result.coords.clone()
+
+ # Restore global coordinates
+ offsets = torch.tensor(bounds['offsets'],
+ device=coords.device).view(1, 3)
+ coords[:, 1:] = coords[:, 1:] + offsets
+
+ # Filter points within original bounds
+ original = bounds['original']
+ within_bounds = torch.logical_and(
+ torch.logical_and(
+ torch.logical_and(
+ coords[:, 1] >= original[0],
+ coords[:, 1] < original[1]
+ ),
+ torch.logical_and(
+ coords[:, 2] >= original[2],
+ coords[:, 2] < original[3]
+ )
+ ),
+ torch.logical_and(
+ coords[:, 3] >= original[4],
+ coords[:, 3] < original[5]
+ )
+ )
+
+ final_coords = coords[within_bounds]
+ final_feats = chunk_result.feats[within_bounds]
+
+ return sp.SparseTensor(final_feats, final_coords)
+
+ def upsamples(self, x, return_feat: bool = False):
+ dtype = x.dtype
+ for block in self.upsample:
+ x = block(x)
+ x = x.type(dtype)
+
+ output = self.out_active(self.out_layer(x))
+
+ if return_feat:
+ return output, x
+ else:
+ return output
+
+ def forward(self, x: sp.SparseTensor, factor: float = None, return_feat: bool = False):
+ h = super().forward(x, factor)
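+        # chunk_size <= 1: decode the whole volume in one pass. Otherwise decode chunk-by-chunk:
+        # a single random chunk during training, all overlapping chunks (cropped and merged) at inference.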
+ if self.chunk_size <= 1:
+ for block in self.upsample:
+ h = block(h)
+ h = h.type(x.dtype)
+
+ if return_feat:
+ return self.out_active(self.out_layer(h)), h
+
+ h = self.out_layer(h)
+ h = self.out_active(h)
+ return h
+ else:
+ if self.training:
+ return self.forward_single_chunk(h)
+ else:
+ batch_size = x.shape[0]
+ chunks = self.split_for_meshing(h, chunk_size=self.chunk_size)
+ all_coords, all_feats = [], []
+ for chunk_idx, chunk in enumerate(chunks):
+ chunk_result = self.upsamples(chunk)
+
+ for b in range(batch_size):
+ mask = torch.nonzero(chunk_result.coords[:, 0] == b).squeeze(-1)
+ if mask.numel() > 0:
+ coords = chunk_result.coords[mask].clone()
+
+ # Restore global coordinates
+ offsets = torch.tensor(chunk.bounds['offsets'],
+ device=coords.device).view(1, 3)
+ coords[:, 1:] = coords[:, 1:] + offsets
+
+ # Filter points within original bounds
+ bounds = chunk.bounds['original']
+ within_bounds = torch.logical_and(
+ torch.logical_and(
+ torch.logical_and(
+ coords[:, 1] >= bounds[0],
+ coords[:, 1] < bounds[1]
+ ),
+ torch.logical_and(
+ coords[:, 2] >= bounds[2],
+ coords[:, 2] < bounds[3]
+ )
+ ),
+ torch.logical_and(
+ coords[:, 3] >= bounds[4],
+ coords[:, 3] < bounds[5]
+ )
+ )
+
+ if within_bounds.any():
+ all_coords.append(coords[within_bounds])
+ all_feats.append(chunk_result.feats[mask][within_bounds])
+
+ if not self.training:
+ torch.cuda.empty_cache()
+
+ final_coords = torch.cat(all_coords)
+ final_feats = torch.cat(all_feats)
+
+ return sp.SparseTensor(final_feats, final_coords)
+
\ No newline at end of file
diff --git a/direct3d_s2/models/autoencoders/dense_vae.py b/direct3d_s2/models/autoencoders/dense_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7aa0cfdc8a307b65c3ed2746ee2e4a5f4537c45
--- /dev/null
+++ b/direct3d_s2/models/autoencoders/dense_vae.py
@@ -0,0 +1,401 @@
+from typing import *
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import trimesh
+from skimage import measure
+from ...modules.norm import GroupNorm32, ChannelLayerNorm32
+from ...modules.spatial import pixel_shuffle_3d
+from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
+from .distributions import DiagonalGaussianDistribution
+
+
+def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module:
+ """
+ Return a normalization layer.
+ """
+ if norm_type == "group":
+ return GroupNorm32(32, *args, **kwargs)
+ elif norm_type == "layer":
+ return ChannelLayerNorm32(*args, **kwargs)
+ else:
+ raise ValueError(f"Invalid norm type {norm_type}")
+
+
+class ResBlock3d(nn.Module):
+ def __init__(
+ self,
+ channels: int,
+ out_channels: Optional[int] = None,
+ norm_type: Literal["group", "layer"] = "layer",
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+
+ self.norm1 = norm_layer(norm_type, channels)
+ self.norm2 = norm_layer(norm_type, self.out_channels)
+ self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1)
+ self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1))
+ self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()
+
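+ # Pre-activation residual block: norm -> SiLU -> conv, twice, plus a skip
+ # path. `conv2` is zero-initialised, so at initialisation the block reduces
+ # to its skip connection. Spatial size is preserved; roughly,
+ # ResBlock3d(32, 64) maps (B, 32, D, H, W) -> (B, 64, D, H, W).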
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ h = self.norm1(x)
+ h = F.silu(h)
+ h = self.conv1(h)
+ h = self.norm2(h)
+ h = F.silu(h)
+ h = self.conv2(h)
+ h = h + self.skip_connection(x)
+ return h
+
+
+class DownsampleBlock3d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ mode: Literal["conv", "avgpool"] = "conv",
+ ):
+ assert mode in ["conv", "avgpool"], f"Invalid mode {mode}"
+
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ if mode == "conv":
+ self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2)
+ elif mode == "avgpool":
+ assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels"
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if hasattr(self, "conv"):
+ return self.conv(x)
+ else:
+ return F.avg_pool3d(x, 2)
+
+
+class UpsampleBlock3d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ mode: Literal["conv", "nearest"] = "conv",
+ ):
+ assert mode in ["conv", "nearest"], f"Invalid mode {mode}"
+
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ if mode == "conv":
+ self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1)
+ elif mode == "nearest":
+ assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels"
+
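+ # In "conv" mode the convolution expands channels by 8 (= 2**3) and
+ # `pixel_shuffle_3d(x, 2)` rearranges those channel groups into a 2x larger
+ # grid per axis, i.e. roughly (B, 8*C, D, H, W) -> (B, C, 2D, 2H, 2W).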
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if hasattr(self, "conv"):
+ x = self.conv(x)
+ return pixel_shuffle_3d(x, 2)
+ else:
+ return F.interpolate(x, scale_factor=2, mode="nearest")
+
+
+class SparseStructureEncoder(nn.Module):
+ """
+ Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3).
+
+ Args:
+ in_channels (int): Channels of the input.
+ latent_channels (int): Channels of the latent representation.
+ num_res_blocks (int): Number of residual blocks at each resolution.
+ channels (List[int]): Channels of the encoder blocks.
+ num_res_blocks_middle (int): Number of residual blocks in the middle.
+ norm_type (Literal["group", "layer"]): Type of normalization layer.
+ use_fp16 (bool): Whether to use FP16.
+ """
+ def __init__(
+ self,
+ in_channels: int,
+ latent_channels: int,
+ num_res_blocks: int,
+ channels: List[int],
+ num_res_blocks_middle: int = 2,
+ norm_type: Literal["group", "layer"] = "layer",
+ use_fp16: bool = False,
+ use_checkpoint: bool = False,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.latent_channels = latent_channels
+ self.num_res_blocks = num_res_blocks
+ self.channels = channels
+ self.num_res_blocks_middle = num_res_blocks_middle
+ self.norm_type = norm_type
+ self.use_fp16 = use_fp16
+ self.dtype = torch.float16 if use_fp16 else torch.float32
+ self.use_checkpoint = use_checkpoint
+
+ self.input_layer = nn.Conv3d(in_channels, channels[0], 3, padding=1)
+
+ self.blocks = nn.ModuleList([])
+ for i, ch in enumerate(channels):
+ self.blocks.extend([
+ ResBlock3d(ch, ch)
+ for _ in range(num_res_blocks)
+ ])
+ if i < len(channels) - 1:
+ self.blocks.append(
+ DownsampleBlock3d(ch, channels[i+1])
+ )
+
+ self.middle_block = nn.Sequential(*[
+ ResBlock3d(channels[-1], channels[-1])
+ for _ in range(num_res_blocks_middle)
+ ])
+
+ self.out_layer = nn.Sequential(
+ norm_layer(norm_type, channels[-1]),
+ nn.SiLU(),
+ nn.Conv3d(channels[-1], latent_channels*2, 3, padding=1)
+ )
+
+ if use_fp16:
+ self.convert_to_fp16()
+
+ @property
+ def device(self) -> torch.device:
+ """
+ Return the device of the model.
+ """
+ return next(self.parameters()).device
+
+ def convert_to_fp16(self) -> None:
+ """
+ Convert the torso of the model to float16.
+ """
+ self.use_fp16 = True
+ self.dtype = torch.float16
+ self.blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self) -> None:
+ """
+ Convert the torso of the model to float32.
+ """
+ self.use_fp16 = False
+ self.dtype = torch.float32
+ self.blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+
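+ # The final conv emits 2 * latent_channels channels; DenseShapeVAE.encode
+ # splits them into mean and log-variance via DiagonalGaussianDistribution.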
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ h = self.input_layer(x)
+
+ for block in self.blocks:
+ h = block(h)
+ h = self.middle_block(h)
+
+ h = self.out_layer(h)
+
+ return h
+
+
+class SparseStructureDecoder(nn.Module):
+ """
+ Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3).
+
+ Args:
+ out_channels (int): Channels of the output.
+ latent_channels (int): Channels of the latent representation.
+ num_res_blocks (int): Number of residual blocks at each resolution.
+ channels (List[int]): Channels of the decoder blocks.
+ num_res_blocks_middle (int): Number of residual blocks in the middle.
+ norm_type (Literal["group", "layer"]): Type of normalization layer.
+ use_fp16 (bool): Whether to use FP16.
+ """
+ def __init__(
+ self,
+ out_channels: int,
+ latent_channels: int,
+ num_res_blocks: int,
+ channels: List[int],
+ num_res_blocks_middle: int = 2,
+ norm_type: Literal["group", "layer"] = "layer",
+ use_fp16: bool = False,
+ use_checkpoint: bool = False,
+ ):
+ super().__init__()
+ self.out_channels = out_channels
+ self.latent_channels = latent_channels
+ self.num_res_blocks = num_res_blocks
+ self.channels = channels
+ self.num_res_blocks_middle = num_res_blocks_middle
+ self.norm_type = norm_type
+ self.use_fp16 = use_fp16
+ self.dtype = torch.float16 if use_fp16 else torch.float32
+ self.use_checkpoint = use_checkpoint
+
+ self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1)
+
+ self.middle_block = nn.Sequential(*[
+ ResBlock3d(channels[0], channels[0])
+ for _ in range(num_res_blocks_middle)
+ ])
+
+ self.blocks = nn.ModuleList([])
+ for i, ch in enumerate(channels):
+ self.blocks.extend([
+ ResBlock3d(ch, ch)
+ for _ in range(num_res_blocks)
+ ])
+ if i < len(channels) - 1:
+ self.blocks.append(
+ UpsampleBlock3d(ch, channels[i+1])
+ )
+
+ self.out_layer = nn.Sequential(
+ norm_layer(norm_type, channels[-1]),
+ nn.SiLU(),
+ nn.Conv3d(channels[-1], out_channels, 3, padding=1)
+ )
+
+ if use_fp16:
+ self.convert_to_fp16()
+
+ @property
+ def device(self) -> torch.device:
+ """
+ Return the device of the model.
+ """
+ return next(self.parameters()).device
+
+ def convert_to_fp16(self) -> None:
+ """
+ Convert the torso of the model to float16.
+ """
+ self.use_fp16 = True
+ self.dtype = torch.float16
+ # self.blocks.apply(convert_module_to_f16)
+ # self.middle_block.apply(convert_module_to_f16)
+ self.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self) -> None:
+ """
+ Convert the torso of the model to float32.
+ """
+ self.use_fp16 = False
+ self.dtype = torch.float32
+ self.blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ h = self.input_layer(x)
+
+ h = self.middle_block(h)
+ for block in self.blocks:
+ h = block(h)
+
+ h = self.out_layer(h)
+ return h
+
+
+class DenseShapeVAE(nn.Module):
+ def __init__(self,
+ embed_dim: int = 0,
+ model_channels_encoder: list = [32, 128, 512],
+ model_channels_decoder: list = [512, 128, 32],
+ num_res_blocks_encoder: int = 2,
+ num_res_blocks_middle_encoder: int = 2,
+ num_res_blocks_decoder: int = 2,
+ num_res_blocks_middle_decoder: int=2,
+ in_channels: int = 1,
+ out_channels: int = 1,
+ use_fp16: bool = False,
+ use_checkpoint: bool = False,
+ latents_scale: float = 1.0,
+ latents_shift: float = 0.0):
+
+ super().__init__()
+
+ self.use_checkpoint = use_checkpoint
+ self.latents_scale = latents_scale
+ self.latents_shift = latents_shift
+
+ self.encoder = SparseStructureEncoder(
+ in_channels=in_channels,
+ latent_channels=embed_dim,
+ num_res_blocks=num_res_blocks_encoder,
+ channels=model_channels_encoder,
+ num_res_blocks_middle=num_res_blocks_middle_encoder,
+ use_fp16=use_fp16,
+ use_checkpoint=use_checkpoint,
+ )
+
+ self.decoder = SparseStructureDecoder(
+ num_res_blocks=num_res_blocks_decoder,
+ num_res_blocks_middle=num_res_blocks_middle_decoder,
+ channels=model_channels_decoder,
+ latent_channels=embed_dim,
+ out_channels=out_channels,
+ use_fp16=use_fp16,
+ use_checkpoint=use_checkpoint,
+ )
+
+ self.embed_dim = embed_dim
+
+ def encode(self, batch, sample_posterior: bool = True):
+
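+ # Illustrative usage (key name follows this method; shapes are assumptions):
+ # vae = DenseShapeVAE(embed_dim=8)
+ # batch = {'dense_index': occ} # (B, 1, R, R, R) occupancy grid in {0, 1}
+ # z, posterior = vae.encode(batch) # z: (B, 8, R/4, R/4, R/4) with the
+ # # default three-level encoder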
+ x = batch['dense_index'] * 2.0 - 1.0
+ h = self.encoder(x)
+ posterior = DiagonalGaussianDistribution(h, feat_dim=1)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+
+ return z, posterior
+
+ def forward(self, batch):
+
+ z, posterior = self.encode(batch)
+ reconst_x = self.decoder(z)
+ outputs = {'reconst_x': reconst_x, 'posterior': posterior}
+
+ return outputs
+
+ def decode_mesh(self,
+ latents,
+ voxel_resolution: int = 64,
+ mc_threshold: float = 0.5,
+ return_index: bool = False):
+ x = self.decoder(latents)
+ if return_index:
+ outputs = []
+ for i in range(len(x)):
+ occ = x[i].sigmoid()
+ occ = (occ >= mc_threshold).float().squeeze(0)
+ index = occ.unsqueeze(0).nonzero()
+ outputs.append(index)
+ else:
+ outputs = self.dense2mesh(x, voxel_resolution=voxel_resolution, mc_threshold=mc_threshold)
+
+ return outputs
+
+ def dense2mesh(self,
+ x: torch.FloatTensor,
+ voxel_resolution: int = 64,
+ mc_threshold: float = 0.5):
+
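+ # The decoded logits are squashed with sigmoid, binarised at a fixed 0.1
+ # threshold, and meshed with marching cubes at `mc_threshold`; vertices are
+ # then rescaled from voxel units to the [-1, 1] cube.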
+ meshes = []
+ for i in range(len(x)):
+ occ = x[i].sigmoid()
+ occ = (occ >= 0.1).float().squeeze(0).cpu().detach().numpy()
+ vertices, faces, _, _ = measure.marching_cubes(
+ occ,
+ mc_threshold,
+ method="lewiner",
+ )
+ vertices = vertices / voxel_resolution * 2 - 1
+ meshes.append(trimesh.Trimesh(vertices, faces))
+
+ return meshes
diff --git a/direct3d_s2/models/autoencoders/distributions.py b/direct3d_s2/models/autoencoders/distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..a702359a8d2dedcb28c2e654bb60221af7f72c8f
--- /dev/null
+++ b/direct3d_s2/models/autoencoders/distributions.py
@@ -0,0 +1,51 @@
+import torch
+import numpy as np
+from typing import Union, List
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1):
+ self.feat_dim = feat_dim
+ self.parameters = parameters
+
+ if isinstance(parameters, list):
+ self.mean = parameters[0]
+ self.logvar = parameters[1]
+ else:
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim)
+
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean)
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn_like(self.mean)
+ return x
+
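+ # For the default `other=None` case, `kl` below is the closed-form KL of a
+ # diagonal Gaussian against a standard normal,
+ # 0.5 * (mu^2 + sigma^2 - 1 - log(sigma^2)),
+ # averaged (not summed) over the given dims.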
+ def kl(self, other=None, dims=(1, 2, 3)):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ else:
+ if other is None:
+ return 0.5 * torch.mean(torch.pow(self.mean, 2)
+ + self.var - 1.0 - self.logvar,
+ dim=dims)
+ else:
+ return 0.5 * torch.mean(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
+ dim=dims)
+
+ def nll(self, sample, dims=(1, 2, 3)):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims)
+
+ def mode(self):
+ return self.mean
diff --git a/direct3d_s2/models/autoencoders/encoder.py b/direct3d_s2/models/autoencoders/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..49fe567ff4dd1b62bae84c8627b27502ec4cdd09
--- /dev/null
+++ b/direct3d_s2/models/autoencoders/encoder.py
@@ -0,0 +1,133 @@
+from typing import *
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ...modules import sparse as sp
+from .base import SparseTransformerBase
+
+
+class SparseDownBlock3d(nn.Module):
+
+ def __init__(
+ self,
+ channels: int,
+ out_channels: Optional[int] = None,
+ num_groups: int = 32,
+ use_checkpoint: bool = False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+
+ self.act_layers = nn.Sequential(
+ sp.SparseGroupNorm32(num_groups, channels),
+ sp.SparseSiLU()
+ )
+
+ self.down = sp.SparseDownsample(2)
+ self.out_layers = nn.Sequential(
+ sp.SparseConv3d(channels, self.out_channels, 3, padding=1),
+ sp.SparseGroupNorm32(num_groups, self.out_channels),
+ sp.SparseSiLU(),
+ sp.SparseConv3d(self.out_channels, self.out_channels, 3, padding=1),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ else:
+ self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1)
+
+ self.use_checkpoint = use_checkpoint
+
+ def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
+ h = self.act_layers(x)
+ h = self.down(h)
+ x = self.down(x)
+ h = self.out_layers(h)
+ h = h + self.skip_connection(x)
+ return h
+
+ def forward(self, x: torch.Tensor):
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
+ else:
+ return self._forward(x)
+
+
+class SparseSDFEncoder(SparseTransformerBase):
+ def __init__(
+ self,
+ resolution: int,
+ in_channels: int,
+ model_channels: int,
+ latent_channels: int,
+ num_blocks: int,
+ num_heads: Optional[int] = None,
+ num_head_channels: Optional[int] = 64,
+ mlp_ratio: float = 4,
+ attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
+ window_size: int = 8,
+ pe_mode: Literal["ape", "rope"] = "ape",
+ use_fp16: bool = False,
+ use_checkpoint: bool = False,
+ qk_rms_norm: bool = False,
+ ):
+ super().__init__(
+ in_channels=in_channels,
+ model_channels=model_channels,
+ num_blocks=num_blocks,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ mlp_ratio=mlp_ratio,
+ attn_mode=attn_mode,
+ window_size=window_size,
+ pe_mode=pe_mode,
+ use_fp16=use_fp16,
+ use_checkpoint=use_checkpoint,
+ qk_rms_norm=qk_rms_norm,
+ )
+
+ self.input_layer1 = sp.SparseLinear(1, model_channels // 16)
+
+ self.downsample = nn.ModuleList([
+ SparseDownBlock3d(
+ channels=model_channels//16,
+ out_channels=model_channels // 8,
+ use_checkpoint=use_checkpoint,
+ ),
+ SparseDownBlock3d(
+ channels=model_channels // 8,
+ out_channels=model_channels // 4,
+ use_checkpoint=use_checkpoint,
+ ),
+ SparseDownBlock3d(
+ channels=model_channels // 4,
+ out_channels=model_channels,
+ use_checkpoint=use_checkpoint,
+ )
+ ])
+
+ self.resolution = resolution
+ self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels)
+
+ self.initialize_weights()
+ if use_fp16:
+ self.convert_to_fp16()
+
+ def initialize_weights(self) -> None:
+ super().initialize_weights()
+ # Zero-out output layers:
+ nn.init.constant_(self.out_layer.weight, 0)
+ nn.init.constant_(self.out_layer.bias, 0)
+
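+ # forward: lift the 1-channel sparse SDF to model_channels // 16, apply the
+ # three SparseDownBlock3d stages (8x coarser grid), run the sparse
+ # transformer from the base class, layer-norm the features, and project to
+ # 2 * latent_channels (mean and log-variance of the latent posterior).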
+ def forward(self, x: sp.SparseTensor, factor: float = None):
+
+ x = self.input_layer1(x)
+ for block in self.downsample:
+ x = block(x)
+ h = super().forward(x, factor)
+ h = h.type(x.dtype)
+ h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
+ h = self.out_layer(h)
+
+ return h
\ No newline at end of file
diff --git a/direct3d_s2/models/autoencoders/ss_vae.py b/direct3d_s2/models/autoencoders/ss_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..b32fd20767df4f90d7d5bc55b2b393676e8e47a3
--- /dev/null
+++ b/direct3d_s2/models/autoencoders/ss_vae.py
@@ -0,0 +1,129 @@
+# -*- coding: utf-8 -*-
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import trimesh
+from skimage import measure
+
+from ...modules import sparse as sp
+from .encoder import SparseSDFEncoder
+from .decoder import SparseSDFDecoder
+from .distributions import DiagonalGaussianDistribution
+
+
+class SparseSDFVAE(nn.Module):
+ def __init__(self, *,
+ embed_dim: int = 0,
+ resolution: int = 64,
+ model_channels_encoder: int = 512,
+ num_blocks_encoder: int = 4,
+ num_heads_encoder: int = 8,
+ num_head_channels_encoder: int = 64,
+ model_channels_decoder: int = 512,
+ num_blocks_decoder: int = 4,
+ num_heads_decoder: int = 8,
+ num_head_channels_decoder: int = 64,
+ out_channels: int = 1,
+ use_fp16: bool = False,
+ use_checkpoint: bool = False,
+ chunk_size: int = 1,
+ latents_scale: float = 1.0,
+ latents_shift: float = 0.0):
+
+ super().__init__()
+
+ self.use_checkpoint = use_checkpoint
+ self.resolution = resolution
+ self.latents_scale = latents_scale
+ self.latents_shift = latents_shift
+
+ self.encoder = SparseSDFEncoder(
+ resolution=resolution,
+ in_channels=model_channels_encoder,
+ model_channels=model_channels_encoder,
+ latent_channels=embed_dim,
+ num_blocks=num_blocks_encoder,
+ num_heads=num_heads_encoder,
+ num_head_channels=num_head_channels_encoder,
+ use_fp16=use_fp16,
+ use_checkpoint=use_checkpoint,
+ )
+
+ self.decoder = SparseSDFDecoder(
+ resolution=resolution,
+ model_channels=model_channels_decoder,
+ latent_channels=embed_dim,
+ num_blocks=num_blocks_decoder,
+ num_heads=num_heads_decoder,
+ num_head_channels=num_head_channels_decoder,
+ out_channels=out_channels,
+ use_fp16=use_fp16,
+ use_checkpoint=use_checkpoint,
+ chunk_size=chunk_size,
+ )
+ self.embed_dim = embed_dim
+
+ def forward(self, batch):
+
+ z, posterior = self.encode(batch)
+
+ reconst_x = self.decoder(z)
+ outputs = {'reconst_x': reconst_x, 'posterior': posterior}
+ return outputs
+
+ def encode(self, batch, sample_posterior: bool = True):
+
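+ # Sketch of the batch layout this method expects (shapes are assumptions
+ # inferred from how the keys are used below):
+ # batch = {
+ # 'sparse_sdf': sdf_vals, # (N,) or (N, 1) SDF value per active voxel
+ # 'sparse_index': xyz, # (N, 3) integer voxel coordinates
+ # 'batch_idx': b_idx, # (N,) sample index of each voxel
+ # }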
+ feat, xyz, batch_idx = batch['sparse_sdf'], batch['sparse_index'], batch['batch_idx']
+ if feat.ndim == 1:
+ feat = feat.unsqueeze(-1)
+ coords = torch.cat([batch_idx.unsqueeze(-1), xyz], dim=-1).int()
+
+ x = sp.SparseTensor(feat, coords)
+ h = self.encoder(x, batch.get('factor', None))
+ posterior = DiagonalGaussianDistribution(h.feats, feat_dim=1)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ z = h.replace(z)
+
+ return z, posterior
+
+ def decode_mesh(self,
+ latents,
+ voxel_resolution: int = 512,
+ mc_threshold: float = 0.2,
+ return_feat: bool = False,
+ factor: float = 1.0):
+ voxel_resolution = int(voxel_resolution / factor)
+ reconst_x = self.decoder(latents, factor=factor, return_feat=return_feat)
+ if return_feat:
+ return reconst_x
+ outputs = self.sparse2mesh(reconst_x, voxel_resolution=voxel_resolution, mc_threshold=mc_threshold)
+
+ return outputs
+
+ def sparse2mesh(self,
+ reconst_x: torch.FloatTensor,
+ voxel_resolution: int = 512,
+ mc_threshold: float = 0.0):
+
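+ # A dense SDF grid is initialised to +1 ("outside"), the sparse predictions
+ # are scattered into it, and marching cubes extracts the surface at
+ # `mc_threshold`; vertices are rescaled to the [-1, 1] cube.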
+ sparse_sdf, sparse_index = reconst_x.feats.float(), reconst_x.coords
+ batch_size = int(sparse_index[..., 0].max().cpu().numpy() + 1)
+
+ meshes = []
+ for i in range(batch_size):
+ idx = sparse_index[..., 0] == i
+ sparse_sdf_i, sparse_index_i = sparse_sdf[idx].squeeze(-1).cpu(), sparse_index[idx][..., 1:].detach().cpu()
+ sdf = torch.ones((voxel_resolution, voxel_resolution, voxel_resolution))
+ sdf[sparse_index_i[..., 0], sparse_index_i[..., 1], sparse_index_i[..., 2]] = sparse_sdf_i
+ vertices, faces, _, _ = measure.marching_cubes(
+ sdf.numpy(),
+ mc_threshold,
+ method="lewiner",
+ )
+ vertices = vertices / voxel_resolution * 2 - 1
+ meshes.append(trimesh.Trimesh(vertices, faces))
+
+ return meshes
diff --git a/direct3d_s2/models/conditioner.py b/direct3d_s2/models/conditioner.py
new file mode 100644
index 0000000000000000000000000000000000000000..55bcee7c5a737c8a5bfc54b34eb84b0da83b54b0
--- /dev/null
+++ b/direct3d_s2/models/conditioner.py
@@ -0,0 +1,40 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import transforms
+
+
+class DinoEncoder(nn.Module):
+
+ def __init__(
+ self,
+ model="facebookresearch/dinov2",
+ version="dinov2_vitl14_reg",
+ size=518,
+ ):
+ super().__init__()
+
+ dino_model = torch.hub.load(model, version, pretrained=True)
+ dino_model = dino_model.eval()
+ self.encoder = dino_model
+ self.transform = transforms.Compose(
+ [
+ transforms.Resize(size, transforms.InterpolationMode.BILINEAR, antialias=True),
+ transforms.CenterCrop(size),
+ transforms.Normalize(
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225],
+ ),
+ ]
+ )
+
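+ # forward: resize/normalise the image, take DINOv2's pre-norm patch tokens
+ # ('x_prenorm') and layer-norm them. When a foreground mask is given, it is
+ # max-pooled with the ViT patch size (14) into a boolean per-patch mask
+ # (e.g. 37 x 37 patches for the default 518 input).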
+ def forward(self, image, image_mask=None):
+
+ z = self.encoder(self.transform(image), is_training=True)['x_prenorm']
+ z = F.layer_norm(z, z.shape[-1:])
+
+ if image_mask is not None:
+ image_mask_patch = F.max_pool2d(image_mask, kernel_size=14, stride=14).squeeze(1) > 0
+ return z, image_mask_patch
+
+ return z
diff --git a/direct3d_s2/models/refiner/__pycache__/unet3d.cpython-310.pyc b/direct3d_s2/models/refiner/__pycache__/unet3d.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2e2fe6c6d663e72c4d7838040802d2f39100969b
Binary files /dev/null and b/direct3d_s2/models/refiner/__pycache__/unet3d.cpython-310.pyc differ
diff --git a/direct3d_s2/models/refiner/__pycache__/unet_refiner.cpython-310.pyc b/direct3d_s2/models/refiner/__pycache__/unet_refiner.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b7f2a773caa62efda28f97306284bb4536aa4bbb
Binary files /dev/null and b/direct3d_s2/models/refiner/__pycache__/unet_refiner.cpython-310.pyc differ
diff --git a/direct3d_s2/models/refiner/unet3d.py b/direct3d_s2/models/refiner/unet3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..42c7228576fec4fc1cf8f0688b44cef22ab16fb4
--- /dev/null
+++ b/direct3d_s2/models/refiner/unet3d.py
@@ -0,0 +1,640 @@
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+ACTIVATION_FUNCTIONS = {
+ "swish": nn.SiLU(),
+ "silu": nn.SiLU(),
+ "mish": nn.Mish(),
+ "gelu": nn.GELU(),
+ "relu": nn.ReLU(),
+}
+
+def get_activation(act_fn: str) -> nn.Module:
+
+ act_fn = act_fn.lower()
+ if act_fn in ACTIVATION_FUNCTIONS:
+ return ACTIVATION_FUNCTIONS[act_fn]
+ else:
+ raise ValueError(f"Unsupported activation function: {act_fn}")
+
+def get_down_block(
+ down_block_type: str,
+ num_layers: int,
+ in_channels: int,
+ out_channels: int,
+ add_downsample: bool,
+ resnet_eps: float,
+ resnet_act_fn: str,
+ resnet_groups: Optional[int] = None,
+ downsample_padding: Optional[int] = None,
+ dropout: float = 0.0,
+) -> Union[
+ "DownBlock3D",
+]:
+ if down_block_type == "DownBlock3D":
+ return DownBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ dropout=dropout,
+ )
+
+ raise ValueError(f"{down_block_type} does not exist.")
+
+def get_up_block(
+ up_block_type: str,
+ num_layers: int,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ add_upsample: bool,
+ resnet_eps: float,
+ resnet_act_fn: str,
+ resnet_groups: Optional[int] = None,
+ dropout: float = 0.0,
+) -> Union[
+ "UpBlock3D",
+]:
+ if up_block_type == "UpBlock3D":
+ return UpBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ dropout=dropout,
+ )
+ raise ValueError(f"{up_block_type} does not exist.")
+
+class Downsample3D(nn.Module):
+
+ def __init__(
+ self,
+ channels: int,
+ out_channels: Optional[int] = None,
+ kernel_size=2,
+ bias=True,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ stride = 2
+
+ self.conv = nn.Conv3d(
+ self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias
+ )
+
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ assert hidden_states.shape[1] == self.channels
+
+ hidden_states = self.conv(hidden_states)
+
+ return hidden_states
+
+class Upsample3D(nn.Module):
+
+ def __init__(
+ self,
+ channels: int,
+ use_conv: bool = False,
+ use_conv_transpose: bool = True,
+ out_channels: Optional[int] = None,
+ name: str = "conv",
+ kernel_size: Optional[int] = None,
+ padding=1,
+ bias=True,
+ interpolate=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_conv_transpose = use_conv_transpose
+ self.name = name
+ self.interpolate = interpolate
+
+ conv = None
+ if use_conv_transpose:
+ conv = nn.ConvTranspose3d(
+ channels, self.out_channels, kernel_size=2, stride=2, padding=0, bias=bias
+ )
+ elif use_conv:
+ if kernel_size is None:
+ kernel_size = 3
+ conv = nn.Conv3d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
+
+ if name == "conv":
+ self.conv = conv
+ else:
+ self.Conv2d_0 = conv
+
+ def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None) -> torch.Tensor:
+
+ assert hidden_states.shape[1] == self.channels
+
+ if self.use_conv_transpose:
+ return self.conv(hidden_states)
+
+ if hidden_states.shape[0] >= 64 or hidden_states.shape[-1] >= 64:
+ hidden_states = hidden_states.contiguous()
+
+ if self.interpolate:
+ if output_size is None:
+ hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+ else:
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
+
+ if self.use_conv:
+ if self.name == "conv":
+ hidden_states = self.conv(hidden_states)
+ else:
+ hidden_states = self.Conv2d_0(hidden_states)
+
+ return hidden_states
+
+class ResnetBlock3D(nn.Module):
+
+ def __init__(
+ self,
+ *,
+ in_channels: int,
+ out_channels: Optional[int] = None,
+ conv_shortcut: bool = False,
+ dropout: float = 0.0,
+ groups: int = 32,
+ groups_out: Optional[int] = None,
+ eps: float = 1e-6,
+ non_linearity: str = "swish",
+ output_scale_factor: float = 1.0,
+ use_in_shortcut: Optional[bool] = None,
+ up: bool = False,
+ down: bool = False,
+ conv_shortcut_bias: bool = True,
+ conv_2d_out_channels: Optional[int] = None,
+ ):
+ super().__init__()
+
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+ self.up = up
+ self.down = down
+ self.output_scale_factor = output_scale_factor
+
+ if groups_out is None:
+ groups_out = groups
+
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+
+ self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+ self.dropout = torch.nn.Dropout(dropout)
+ conv_2d_out_channels = conv_2d_out_channels or out_channels
+ self.conv2 = nn.Conv3d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
+
+ self.nonlinearity = get_activation(non_linearity)
+
+ self.upsample = self.downsample = None
+ if self.up:
+ self.upsample = Upsample3D(in_channels, use_conv=False)
+ elif self.down:
+ self.downsample = Downsample3D(in_channels)
+
+ self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
+
+ self.conv_shortcut = None
+ if self.use_in_shortcut:
+ self.conv_shortcut = nn.Conv3d(
+ in_channels,
+ conv_2d_out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=conv_shortcut_bias,
+ )
+
+ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
+
+ hidden_states = input_tensor
+ dtype = hidden_states.dtype
+ hidden_states = self.norm1(hidden_states.float()).to(dtype)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ if self.upsample is not None:
+ if hidden_states.shape[0] >= 64:
+ input_tensor = input_tensor.contiguous()
+ hidden_states = hidden_states.contiguous()
+ input_tensor = self.upsample(input_tensor)
+ hidden_states = self.upsample(hidden_states)
+ elif self.downsample is not None:
+ input_tensor = self.downsample(input_tensor)
+ hidden_states = self.downsample(hidden_states)
+
+ hidden_states = self.conv1(hidden_states)
+
+ hidden_states = self.norm2(hidden_states.float()).to(dtype)
+
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ input_tensor = self.conv_shortcut(input_tensor)
+
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+
+ return output_tensor
+
+class DownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ output_scale_factor: float = 1.0,
+ add_downsample: bool = True,
+ downsample_padding: int = 1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample3D(
+ out_channels,
+ out_channels=out_channels
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
+ output_states = ()
+
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states)
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+class UpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ output_scale_factor: float = 1.0,
+ add_upsample: bool = True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+ upsample_size: Optional[int] = None,
+ ) -> torch.Tensor:
+ for resnet in self.resnets:
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+ hidden_states = resnet(hidden_states)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+
+class UNetMidBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 2,
+ resnet_eps: float = 1e-6,
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ output_scale_factor: float = 1.0,
+ use_linear_projection: bool = True,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ resnets = [
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ ]
+
+ for _ in range(num_layers):
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor:
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states)
+
+ return hidden_states
+
+def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+class UNet3DModel(nn.Module):
+
+ def __init__(
+ self,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ use_conv_out: bool=True,
+ down_block_types: Tuple[str, ...] = (
+ "DownBlock3D",
+ "DownBlock3D",
+ "DownBlock3D",
+ "DownBlock3D",
+ ),
+ up_block_types: Tuple[str, ...] = (
+ "UpBlock3D",
+ "UpBlock3D",
+ "UpBlock3D",
+ "UpBlock3D",
+ ),
+ block_out_channels: Tuple[int, ...] = (4, 16, 64,256),
+ layers_per_block: int = 4,
+ layers_mid_block: int=4,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 4,
+ norm_eps: float = 1e-5,
+ use_checkpoint: bool = True,
+ ):
+ super().__init__()
+
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ conv_in_kernel = 3
+ conv_out_kernel = 3
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv3d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ self.down_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ downsample_padding=downsample_padding,
+ )
+ self.down_blocks.append(down_block)
+
+ self.mid_block = UNetMidBlock3D(
+ in_channels=block_out_channels[-1],
+ num_layers=layers_mid_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_groups=norm_num_groups,
+ )
+
+ self.num_upsamplers = 0
+
+ reversed_block_out_channels = list(reversed(block_out_channels))
+
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ if norm_num_groups is not None:
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
+ self.conv_act = get_activation("silu")
+ else:
+ self.conv_norm_out = None
+ self.conv_act = None
+
+ conv_out_padding = (conv_out_kernel - 1) // 2
+ if use_conv_out:
+ self.conv_out = nn.Conv3d(
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+ )
+ else:
+ self.conv_out = None
+
+ self.use_checkpoint = use_checkpoint
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ ):
+
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ forward_upsample_size = False
+ upsample_size = None
+
+ # check all three spatial dims of the 3D volume, not just the last two
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-3:]):
+ forward_upsample_size = True
+
+ sample = self.conv_in(sample)
+
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if self.use_checkpoint:
+ sample, res_samples = torch.utils.checkpoint.checkpoint(downsample_block, sample, use_reentrant=False)
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample)
+
+ down_block_res_samples += res_samples
+
+ if self.mid_block is not None:
+ if self.use_checkpoint:
+ sample = torch.utils.checkpoint.checkpoint(self.mid_block, sample, use_reentrant=False)
+ else:
+ sample = self.mid_block(sample)
+
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+ if self.use_checkpoint:
+ sample = torch.utils.checkpoint.checkpoint(upsample_block, sample, res_samples, upsample_size, use_reentrant=False)
+ else:
+ sample = upsample_block(
+ hidden_states=sample,
+ res_hidden_states_tuple=res_samples,
+ upsample_size=upsample_size,
+ )
+
+ if self.conv_norm_out:
+ dtype = sample.dtype
+ sample = self.conv_norm_out(sample.float()).to(dtype)
+ sample = self.conv_act(sample)
+ if self.conv_out is not None:
+ sample = self.conv_out(sample)
+
+ return F.tanh(sample)*2
+ else:
+ return sample
diff --git a/direct3d_s2/models/refiner/unet_refiner.py b/direct3d_s2/models/refiner/unet_refiner.py
new file mode 100644
index 0000000000000000000000000000000000000000..056ddfc7d4a987e02160a73060f60d27893a1cd3
--- /dev/null
+++ b/direct3d_s2/models/refiner/unet_refiner.py
@@ -0,0 +1,203 @@
+# -*- coding: utf-8 -*-
+import itertools
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .unet3d import UNet3DModel
+import trimesh
+from tqdm import tqdm
+from skimage import measure
+from ...modules.utils import convert_module_to_f16, convert_module_to_f32
+
+
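+# `adaptive_conv` applies a spatially varying 3x3x3 filter: `weights` carries
+# 27 channels, one per neighbourhood offset, shared across feature channels,
+# so roughly output[b, c, p] = sum_k weights[b, k, p] * input[b, c, p + d_k]
+# over the 27 offsets d_k. `adaptive_block` below L1-normalises the predicted
+# weights and applies this operation three times.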
+def adaptive_conv(inputs, weights):
+ padding = (1, 1, 1, 1, 1, 1)
+ padded_input = F.pad(inputs, padding, mode="constant", value=0)
+ output = torch.zeros_like(inputs)
+ size=inputs.shape[-1]
+ for i in range(3):
+ for j in range(3):
+ for k in range(3):
+ output=output+padded_input[:,:,i:i+size,j:j+size,k:k+size]*weights[:,i*9+j*3+k:i*9+j*3+k+1]
+ return output
+
+def adaptive_block(inputs, conv, weights_=None):
+ if weights_ is not None:
+ weights = conv(weights_)
+ else:
+ weights = conv(inputs)
+ weights = F.normalize(weights, dim=1, p=1)
+ for i in range(3):
+ inputs = adaptive_conv(inputs, weights)
+ return inputs
+
+class GeoDecoder(nn.Module):
+
+ def __init__(self,
+ n_features: int,
+ hidden_dim: int = 32,
+ num_layers: int = 4,
+ use_sdf: bool = False,
+ activation: nn.Module = nn.ReLU):
+ super().__init__()
+ self.use_sdf=use_sdf
+ self.net = nn.Sequential(
+ nn.Linear(n_features, hidden_dim),
+ activation(),
+ *itertools.chain(*[[
+ nn.Linear(hidden_dim, hidden_dim),
+ activation(),
+ ] for _ in range(num_layers - 2)]),
+ nn.Linear(hidden_dim, 8),
+ )
+
+ # init all bias to zero
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+ nn.init.zeros_(m.bias)
+
+ def forward(self, x):
+ x = self.net(x)
+ return x
+
+
+class Voxel_RefinerXL(nn.Module):
+ def __init__(self,
+ in_channels: int = 1,
+ out_channels: int = 1,
+ layers_per_block: int = 2,
+ layers_mid_block: int = 2,
+ patch_size: int = 192,
+ res: int = 512,
+ use_checkpoint: bool=False,
+ use_fp16: bool = False):
+
+ super().__init__()
+
+ self.unet3d1 = UNet3DModel(in_channels=16, out_channels=8, use_conv_out=False,
+ layers_per_block=layers_per_block, layers_mid_block=layers_mid_block,
+ block_out_channels=(8, 32, 128,512), norm_num_groups=4, use_checkpoint=use_checkpoint)
+ self.conv_in = nn.Conv3d(in_channels, 8, kernel_size=3, padding=1)
+ self.latent_mlp = GeoDecoder(32)
+ self.adaptive_conv1 = nn.Sequential(nn.Conv3d(8, 8, kernel_size=3, padding=1),
+ nn.ReLU(),
+ nn.Conv3d(8, 27, kernel_size=3, padding=1, bias=False))
+ self.adaptive_conv2 = nn.Sequential(nn.Conv3d(8, 8, kernel_size=3, padding=1),
+ nn.ReLU(),
+ nn.Conv3d(8, 27, kernel_size=3, padding=1, bias=False))
+ self.adaptive_conv3 = nn.Sequential(nn.Conv3d(8, 8, kernel_size=3, padding=1),
+ nn.ReLU(),
+ nn.Conv3d(8, 27, kernel_size=3, padding=1, bias=False))
+ self.mid_conv = nn.Conv3d(8, 8, kernel_size=3, padding=1)
+ self.conv_out = nn.Conv3d(8, out_channels, kernel_size=3, padding=1)
+ self.patch_size = patch_size
+ self.res = res
+
+ self.use_fp16 = use_fp16
+ self.dtype = torch.float16 if use_fp16 else torch.float32
+ if use_fp16:
+ self.convert_to_fp16()
+
+ def convert_to_fp16(self) -> None:
+ """
+ Convert the torso of the model to float16.
+ """
+ # self.blocks.apply(convert_module_to_f16)
+ self.apply(convert_module_to_f16)
+
+ def run(self,
+ reconst_x,
+ feat,
+ mc_threshold=0,
+ ):
+ batch_size = int(reconst_x.coords[..., 0].max()) + 1
+ sparse_sdf, sparse_index = reconst_x.feats, reconst_x.coords
+ sparse_feat = feat.feats
+ device = sparse_sdf.device
+ dtype = sparse_sdf.dtype
+ res = self.res
+
+ sdfs = []
+ for i in range(batch_size):
+ idx = sparse_index[..., 0] == i
+ sparse_sdf_i, sparse_index_i = sparse_sdf[idx].squeeze(-1), sparse_index[idx][..., 1:]
+ sdf = torch.ones((res, res, res)).to(device).to(dtype)
+ sdf[sparse_index_i[..., 0], sparse_index_i[..., 1], sparse_index_i[..., 2]] = sparse_sdf_i
+ sdfs.append(sdf.unsqueeze(0))
+
+ sdfs = torch.stack(sdfs, dim=0)
+ feats = torch.zeros((batch_size, sparse_feat.shape[-1], res, res, res),
+ device=device, dtype=dtype)
+ feats[sparse_index[...,0],:,sparse_index[...,1],sparse_index[...,2],sparse_index[...,3]] = sparse_feat
+
+ N = sdfs.shape[0]
+ outputs = torch.ones([N,1,res,res,res], dtype=dtype, device=device)
+ stride = 160
+ patch_size = self.patch_size
+ step = 3
+ sdfs = sdfs.to(dtype)
+ feats = feats.to(dtype)
+ patchs=[]
+ for i in range(step):
+ for j in range(step):
+ for k in tqdm(range(step)):
+ sdf = sdfs[:, :, stride * i: stride * i + patch_size,
+ stride * j: stride * j + patch_size,
+ stride * k: stride * k + patch_size]
+ crop_feats = feats[:, :, stride * i: stride * i + patch_size,
+ stride * j: stride * j + patch_size,
+ stride * k: stride * k + patch_size]
+ inputs = self.conv_in(sdf)
+ crop_feats = self.latent_mlp(crop_feats.permute(0,2,3,4,1)).permute(0,4,1,2,3)
+ inputs = torch.cat([inputs, crop_feats],dim=1)
+ mid_feat = self.unet3d1(inputs)
+ mid_feat = adaptive_block(mid_feat, self.adaptive_conv1)
+ mid_feat = self.mid_conv(mid_feat)
+ mid_feat = adaptive_block(mid_feat, self.adaptive_conv2)
+ final_feat = self.conv_out(mid_feat)
+ final_feat = adaptive_block(final_feat, self.adaptive_conv3, weights_=mid_feat)
+ output = F.tanh(final_feat)
+ patchs.append(output)
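+ # Blend the overlapping patches back into the full volume: 192-voxel patches
+ # are taken at stride 160 (positions 0/160/320 per axis), so neighbours
+ # overlap by 32 voxels; overlaps are linearly cross-faded, first along the
+ # last axis into "lines", then into "layers", then into `outputs`.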
+ weights = torch.linspace(0, 1, steps=32, device=device, dtype=dtype)
+ lines=[]
+ for i in range(9):
+ out1 = patchs[i * 3]
+ out2 = patchs[i * 3 + 1]
+ out3 = patchs[i * 3 + 2]
+ line = torch.ones([N, 1, 192, 192,res], dtype=dtype, device=device) * 2
+ line[:, :, :, :, :160] = out1[:, :, :, :, :160]
+ line[:, :, :, :, 192:320] = out2[:, :, :, :, 32:160]
+ line[:, :, :, :, 352:] = out3[:, :, :, :, 32:]
+
+ line[:,:,:,:,160:192] = out1[:,:,:,:,160:] * (1-weights.reshape(1,1,1,1,-1)) + out2[:,:,:,:,:32] * weights.reshape(1,1,1,1,-1)
+ line[:,:,:,:,320:352] = out2[:,:,:,:,160:] * (1-weights.reshape(1,1,1,1,-1)) + out3[:,:,:,:,:32] * weights.reshape(1,1,1,1,-1)
+ lines.append(line)
+ layers=[]
+ for i in range(3):
+ line1 = lines[i*3]
+ line2 = lines[i*3+1]
+ line3 = lines[i*3+2]
+ layer = torch.ones([N,1,192,res,res], device=device, dtype=dtype) * 2
+ layer[:,:,:,:160] = line1[:,:,:,:160]
+ layer[:,:,:,192:320] = line2[:,:,:,32:160]
+ layer[:,:,:,352:] = line3[:,:,:,32:]
+ layer[:,:,:,160:192] = line1[:,:,:,160:]*(1-weights.reshape(1,1,1,-1,1))+line2[:,:,:,:32]*weights.reshape(1,1,1,-1,1)
+ layer[:,:,:,320:352] = line2[:,:,:,160:]*(1-weights.reshape(1,1,1,-1,1))+line3[:,:,:,:32]*weights.reshape(1,1,1,-1,1)
+ layers.append(layer)
+ outputs[:,:,:160] = layers[0][:,:,:160]
+ outputs[:,:,192:320] = layers[1][:,:,32:160]
+ outputs[:,:,352:] = layers[2][:,:,32:]
+ outputs[:,:,160:192] = layers[0][:,:,160:]*(1-weights.reshape(1,1,-1,1,1))+layers[1][:,:,:32]*weights.reshape(1,1,-1,1,1)
+ outputs[:,:,320:352] = layers[1][:,:,160:]*(1-weights.reshape(1,1,-1,1,1))+layers[2][:,:,:32]*weights.reshape(1,1,-1,1,1)
+ # outputs = -outputs
+
+ meshes = []
+ for i in range(outputs.shape[0]):
+ vertices, faces, _, _ = measure.marching_cubes(outputs[i, 0].cpu().numpy(), level=mc_threshold, method='lewiner')
+ vertices = vertices / res * 2 - 1
+ meshes.append(trimesh.Trimesh(vertices, faces))
+
+ return meshes
+
+
diff --git a/direct3d_s2/models/transformers/__pycache__/dense_dit.cpython-310.pyc b/direct3d_s2/models/transformers/__pycache__/dense_dit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0299464b4e9143f416b13a762160da2a37543b00
Binary files /dev/null and b/direct3d_s2/models/transformers/__pycache__/dense_dit.cpython-310.pyc differ
diff --git a/direct3d_s2/models/transformers/__pycache__/sparse_dit.cpython-310.pyc b/direct3d_s2/models/transformers/__pycache__/sparse_dit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6d2eb4b0aab5a6dc2b6e5ba116fe4f0a5810d771
Binary files /dev/null and b/direct3d_s2/models/transformers/__pycache__/sparse_dit.cpython-310.pyc differ
diff --git a/direct3d_s2/models/transformers/dense_dit.py b/direct3d_s2/models/transformers/dense_dit.py
new file mode 100644
index 0000000000000000000000000000000000000000..f19ef8511dc0d3fe5c561cffe01376e9eed713fc
--- /dev/null
+++ b/direct3d_s2/models/transformers/dense_dit.py
@@ -0,0 +1,203 @@
+from typing import *
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from ...modules.utils import convert_module_to_f16, convert_module_to_f32
+from ...modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock
+from ...modules.spatial import patchify, unpatchify
+
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+
+ Args:
+ t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ dim: the dimension of the output.
+ max_period: controls the minimum frequency of the embeddings.
+
+ Returns:
+ an (N, D) Tensor of positional embeddings.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ half = dim // 2
+ freqs = torch.exp(
+ -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=t.device)
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_freq = t_freq.to(self.mlp[0].weight.dtype)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+class DenseDiT(nn.Module):
+ def __init__(
+ self,
+ resolution: int,
+ in_channels: int,
+ model_channels: int,
+ cond_channels: int,
+ out_channels: int,
+ num_blocks: int,
+ num_heads: Optional[int] = None,
+ num_head_channels: Optional[int] = 64,
+ mlp_ratio: float = 4,
+ patch_size: int = 2,
+ pe_mode: Literal["ape", "rope"] = "ape",
+ use_fp16: bool = False,
+ use_checkpoint: bool = False,
+ share_mod: bool = False,
+ qk_rms_norm: bool = False,
+ qk_rms_norm_cross: bool = False,
+ latent_shape: list = [8, 16, 16, 16],
+ ):
+ super().__init__()
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.cond_channels = cond_channels
+ self.out_channels = out_channels
+ self.num_blocks = num_blocks
+ self.num_heads = num_heads or model_channels // num_head_channels
+ self.mlp_ratio = mlp_ratio
+ self.patch_size = patch_size
+ self.pe_mode = pe_mode
+ self.use_fp16 = use_fp16
+ self.use_checkpoint = use_checkpoint
+ self.share_mod = share_mod
+ self.qk_rms_norm = qk_rms_norm
+ self.qk_rms_norm_cross = qk_rms_norm_cross
+ self.dtype = torch.float16 if use_fp16 else torch.float32
+ self.latent_shape = latent_shape
+
+ self.t_embedder = TimestepEmbedder(model_channels)
+ if share_mod:
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(model_channels, 6 * model_channels, bias=True)
+ )
+
+ if pe_mode == "ape":
+ pos_embedder = AbsolutePositionEmbedder(model_channels, 3)
+ coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution // patch_size] * 3], indexing='ij')
+ coords = torch.stack(coords, dim=-1).reshape(-1, 3)
+ pos_emb = pos_embedder(coords)
+ self.register_buffer("pos_emb", pos_emb)
+
+ self.input_layer = nn.Linear(in_channels * patch_size**3, model_channels)
+
+ self.blocks = nn.ModuleList([
+ ModulatedTransformerCrossBlock(
+ model_channels,
+ cond_channels,
+ num_heads=self.num_heads,
+ mlp_ratio=self.mlp_ratio,
+ attn_mode='full',
+ use_checkpoint=self.use_checkpoint,
+ use_rope=(pe_mode == "rope"),
+ share_mod=share_mod,
+ qk_rms_norm=self.qk_rms_norm,
+ qk_rms_norm_cross=self.qk_rms_norm_cross,
+ )
+ for _ in range(num_blocks)
+ ])
+
+ self.out_layer = nn.Linear(model_channels, out_channels * patch_size**3)
+
+ self.initialize_weights()
+ if use_fp16:
+ self.convert_to_fp16()
+
+ @property
+ def device(self) -> torch.device:
+ """
+ Return the device of the model.
+ """
+ return next(self.parameters()).device
+
+ def convert_to_fp16(self) -> None:
+ """
+ Convert the torso of the model to float16.
+ """
+ # self.blocks.apply(convert_module_to_f16)
+ self.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self) -> None:
+ """
+ Convert the torso of the model to float32.
+ """
+ self.blocks.apply(convert_module_to_f32)
+
+ def initialize_weights(self) -> None:
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ self.apply(_basic_init)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # Zero-out adaLN modulation layers in DiT blocks:
+ if self.share_mod:
+ nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
+ else:
+ for block in self.blocks:
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.out_layer.weight, 0)
+ nn.init.constant_(self.out_layer.bias, 0)
+
+ def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
+ assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
+ f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"
+
+ h = patchify(x, self.patch_size)
+ h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous()
+ h = self.input_layer(h)
+ h = h + self.pos_emb[None]
+ t_emb = self.t_embedder(t)
+ if self.share_mod:
+ t_emb = self.adaLN_modulation(t_emb)
+ t_emb = t_emb.type(self.dtype)
+ h = h.type(self.dtype)
+ cond = cond.type(self.dtype)
+ for block in self.blocks:
+ h = block(h, t_emb, cond)
+ h = h.type(x.dtype)
+ h = F.layer_norm(h, h.shape[-1:])
+ h = self.out_layer(h)
+
+ h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3)
+ h = unpatchify(h, self.patch_size).contiguous()
+
+ return h
\ No newline at end of file
diff --git a/direct3d_s2/models/transformers/sparse_dit.py b/direct3d_s2/models/transformers/sparse_dit.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bb6c2e32b54214b621581a268273c0960a9c940
--- /dev/null
+++ b/direct3d_s2/models/transformers/sparse_dit.py
@@ -0,0 +1,171 @@
+from typing import *
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
+from ...modules.transformer import AbsolutePositionEmbedder
+from ...modules import sparse as sp
+from ...modules.sparse.transformer.modulated import ModulatedSparseTransformerCrossBlock
+from .dense_dit import TimestepEmbedder
+
+
+class SparseDiT(nn.Module):
+ def __init__(
+ self,
+ resolution: int,
+ in_channels: int,
+ model_channels: int,
+ cond_channels: int,
+ out_channels: int,
+ num_blocks: int,
+ num_heads: Optional[int] = None,
+ num_head_channels: Optional[int] = 64,
+ num_kv_heads: Optional[int] = 2,
+ compression_block_size: int = 4,
+ selection_block_size: int = 8,
+ topk: int = 8,
+ compression_version: str = 'v2',
+ mlp_ratio: float = 4,
+ pe_mode: Literal["ape", "rope"] = "ape",
+ use_fp16: bool = False,
+ use_checkpoint: bool = False,
+ share_mod: bool = False,
+ qk_rms_norm: bool = False,
+ qk_rms_norm_cross: bool = False,
+ sparse_conditions: bool = False,
+ factor: float = 1.0,
+ window_size: Optional[int] = 8,
+ use_shift: bool = True,
+ ):
+ super().__init__()
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.cond_channels = cond_channels
+ self.out_channels = out_channels
+ self.num_blocks = num_blocks
+ self.num_heads = num_heads or model_channels // num_head_channels
+ self.mlp_ratio = mlp_ratio
+ self.pe_mode = pe_mode
+ self.use_fp16 = use_fp16
+ self.use_checkpoint = use_checkpoint
+ self.share_mod = share_mod
+ self.qk_rms_norm = qk_rms_norm
+ self.qk_rms_norm_cross = qk_rms_norm_cross
+ self.dtype = torch.float16 if use_fp16 else torch.float32
+ self.sparse_conditions = sparse_conditions
+ self.factor = factor
+ self.compression_block_size = compression_block_size
+ self.selection_block_size = selection_block_size
+
+ self.t_embedder = TimestepEmbedder(model_channels)
+ if share_mod:
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(model_channels, 6 * model_channels, bias=True)
+ )
+
+ if sparse_conditions:
+ self.cond_proj = sp.SparseLinear(cond_channels, cond_channels)
+ self.pos_embedder_cond = AbsolutePositionEmbedder(model_channels, in_channels=3)
+
+ if pe_mode == "ape":
+ self.pos_embedder = AbsolutePositionEmbedder(model_channels)
+
+ self.input_layer = sp.SparseLinear(in_channels, model_channels)
+
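+ # `shift_window` alternates between 0 and half the window size on successive
+ # blocks when `use_shift` is set (Swin-style alternation); otherwise every
+ # block receives the same half-window shift.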
+ self.blocks = nn.ModuleList([
+ ModulatedSparseTransformerCrossBlock(
+ model_channels,
+ cond_channels,
+ num_heads=self.num_heads,
+ num_kv_heads=num_kv_heads,
+ compression_block_size=compression_block_size,
+ selection_block_size=selection_block_size,
+ topk=topk,
+ mlp_ratio=self.mlp_ratio,
+ attn_mode='full',
+ compression_version=compression_version,
+ use_checkpoint=self.use_checkpoint,
+ use_rope=(pe_mode == "rope"),
+ share_mod=self.share_mod,
+ qk_rms_norm=self.qk_rms_norm,
+ qk_rms_norm_cross=self.qk_rms_norm_cross,
+ resolution=resolution,
+ window_size=window_size,
+ shift_window=window_size // 2 * (i % 2) if use_shift else window_size // 2,
+ )
+ for i in range(num_blocks)
+ ])
+
+ self.out_layer = sp.SparseLinear(model_channels, out_channels)
+
+ self.initialize_weights()
+ if use_fp16:
+ self.convert_to_fp16()
+
+ @property
+ def device(self) -> torch.device:
+ """
+ Return the device of the model.
+ """
+ return next(self.parameters()).device
+
+ def convert_to_fp16(self) -> None:
+ """
+ Convert the model (all submodules, not just the transformer blocks) to float16.
+ """
+ self.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self) -> None:
+ """
+ Convert the torso of the model to float32.
+ """
+ self.blocks.apply(convert_module_to_f32)
+
+ def initialize_weights(self) -> None:
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ self.apply(_basic_init)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # Zero-out adaLN modulation layers in DiT blocks:
+ if self.share_mod:
+ nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
+ else:
+ for block in self.blocks:
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.out_layer.weight, 0)
+ nn.init.constant_(self.out_layer.bias, 0)
+
+ def forward(self, x: sp.SparseTensor, t: torch.Tensor, cond: Union[torch.Tensor, sp.SparseTensor]) -> sp.SparseTensor:
+ h = self.input_layer(x).type(self.dtype)
+ t_emb = self.t_embedder(t)
+ if self.share_mod:
+ t_emb = self.adaLN_modulation(t_emb)
+ t_emb = t_emb.type(self.dtype)
+ cond = cond.type(self.dtype)
+
+ if self.sparse_conditions:
+ cond = self.cond_proj(cond)
+ cond = cond + self.pos_embedder_cond(cond.coords[:, 1:]).type(self.dtype)
+ if self.pe_mode == "ape":
+ h = h + self.pos_embedder(h.coords[:, 1:], factor=self.factor).type(self.dtype)
+ for block in self.blocks:
+ h = block(h, t_emb, cond)
+
+ h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
+ h = self.out_layer(h.type(x.dtype))
+ return h
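+
+
+# Construction sketch (comments only; the sizes below are illustrative assumptions,
+# not values tied to any released Direct3D-S2 checkpoint):
+#
+#   model = SparseDiT(
+#       resolution=64, in_channels=8, model_channels=768, cond_channels=1024,
+#       out_channels=8, num_blocks=12,
+#   )
+#   out = model(x, t, cond)  # x: sp.SparseTensor, t: [N] timesteps,
+#                            # cond: dense tensor or sp.SparseTensor of conditioning tokens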
diff --git a/direct3d_s2/modules/__pycache__/norm.cpython-310.pyc b/direct3d_s2/modules/__pycache__/norm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7499560d85c1c87edbe8fd1ebe8a2c81dfbb17f7
Binary files /dev/null and b/direct3d_s2/modules/__pycache__/norm.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/__pycache__/spatial.cpython-310.pyc b/direct3d_s2/modules/__pycache__/spatial.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..49064a2098551353971c619198fedf94458ad625
Binary files /dev/null and b/direct3d_s2/modules/__pycache__/spatial.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/__pycache__/utils.cpython-310.pyc b/direct3d_s2/modules/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2961e1cbf78cb5c2511fd69340e6be1928ef3cd6
Binary files /dev/null and b/direct3d_s2/modules/__pycache__/utils.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/attention/__init__.py b/direct3d_s2/modules/attention/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d500ecd8ce72fd4e072ecdd9c008d2ae030e0629
--- /dev/null
+++ b/direct3d_s2/modules/attention/__init__.py
@@ -0,0 +1,35 @@
+from typing import *
+BACKEND = 'flash_attn'
+DEBUG = False
+
+def __from_env():
+ import os
+
+ global BACKEND
+ global DEBUG
+
+ env_attn_backend = os.environ.get('ATTN_BACKEND')
+ env_attn_debug = os.environ.get('ATTN_DEBUG')
+
+ if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']:
+ BACKEND = env_attn_backend
+ if env_attn_debug is not None:
+ DEBUG = env_attn_debug == '1'
+
+ print(f"[ATTENTION] Using backend: {BACKEND}")
+
+
+__from_env()
+
+
+def set_backend(backend: Literal['xformers', 'flash_attn']):
+ global BACKEND
+ BACKEND = backend
+
+def set_debug(debug: bool):
+ global DEBUG
+ DEBUG = debug
+
+
+from .full_attn import *
+from .modules import *
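+
+
+# Usage note (not part of the original module): the backend is selected from the
+# ATTN_BACKEND environment variable at import time, or switched later, e.g.
+#
+#   import os; os.environ['ATTN_BACKEND'] = 'sdpa'   # set before importing this package
+#   set_backend('xformers')                          # or switch programmatically afterwards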
diff --git a/direct3d_s2/modules/attention/__pycache__/__init__.cpython-310.pyc b/direct3d_s2/modules/attention/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b78deb562dc65c50d91cbd6d09e74f5a4f390dac
Binary files /dev/null and b/direct3d_s2/modules/attention/__pycache__/__init__.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/attention/__pycache__/full_attn.cpython-310.pyc b/direct3d_s2/modules/attention/__pycache__/full_attn.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4a0450d41a7c226b03c5e7c9c476bea7c2b47b28
Binary files /dev/null and b/direct3d_s2/modules/attention/__pycache__/full_attn.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/attention/__pycache__/modules.cpython-310.pyc b/direct3d_s2/modules/attention/__pycache__/modules.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..82529b982c3e709e0822fe0f2b592edae451777d
Binary files /dev/null and b/direct3d_s2/modules/attention/__pycache__/modules.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/attention/full_attn.py b/direct3d_s2/modules/attention/full_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..f94cf46843412c4d349d2d5dcd7277fac938e507
--- /dev/null
+++ b/direct3d_s2/modules/attention/full_attn.py
@@ -0,0 +1,140 @@
+from typing import *
+import torch
+import math
+from . import DEBUG, BACKEND
+
+if BACKEND == 'xformers':
+ import xformers.ops as xops
+elif BACKEND == 'flash_attn':
+ import flash_attn
+elif BACKEND == 'sdpa':
+ from torch.nn.functional import scaled_dot_product_attention as sdpa
+elif BACKEND == 'naive':
+ pass
+else:
+ raise ValueError(f"Unknown attention backend: {BACKEND}")
+
+
+__all__ = [
+ 'scaled_dot_product_attention',
+]
+
+
+def _naive_sdpa(q, k, v):
+ """
+ Naive implementation of scaled dot product attention.
+ """
+ q = q.permute(0, 2, 1, 3) # [N, H, L, C]
+ k = k.permute(0, 2, 1, 3) # [N, H, L, C]
+ v = v.permute(0, 2, 1, 3) # [N, H, L, C]
+ scale_factor = 1 / math.sqrt(q.size(-1))
+ attn_weight = q @ k.transpose(-2, -1) * scale_factor
+ attn_weight = torch.softmax(attn_weight, dim=-1)
+ out = attn_weight @ v
+ out = out.permute(0, 2, 1, 3) # [N, L, H, C]
+ return out
+
+
+@overload
+def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor:
+ """
+ Apply scaled dot product attention.
+
+ Args:
+ qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs.
+ """
+ ...
+
+@overload
+def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
+ """
+ Apply scaled dot product attention.
+
+ Args:
+ q (torch.Tensor): A [N, L, H, C] tensor containing Qs.
+ kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs.
+ """
+ ...
+
+@overload
+def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+ """
+ Apply scaled dot product attention.
+
+ Args:
+ q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs.
+ k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks.
+ v (torch.Tensor): A [N, L, H, Co] tensor containing Vs.
+
+ Note:
+ k and v are assumed to have the same coordinate map.
+ """
+ ...
+
+def scaled_dot_product_attention(*args, **kwargs):
+ arg_names_dict = {
+ 1: ['qkv'],
+ 2: ['q', 'kv'],
+ 3: ['q', 'k', 'v']
+ }
+ num_all_args = len(args) + len(kwargs)
+ assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
+ for key in arg_names_dict[num_all_args][len(args):]:
+ assert key in kwargs, f"Missing argument {key}"
+
+ if num_all_args == 1:
+ qkv = args[0] if len(args) > 0 else kwargs['qkv']
+ assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]"
+ device = qkv.device
+
+ elif num_all_args == 2:
+ q = args[0] if len(args) > 0 else kwargs['q']
+ kv = args[1] if len(args) > 1 else kwargs['kv']
+ assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
+ assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
+ assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
+ device = q.device
+
+ elif num_all_args == 3:
+ q = args[0] if len(args) > 0 else kwargs['q']
+ k = args[1] if len(args) > 1 else kwargs['k']
+ v = args[2] if len(args) > 2 else kwargs['v']
+ assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
+ assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
+ assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
+ assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
+ device = q.device
+
+ if BACKEND == 'xformers':
+ if num_all_args == 1:
+ q, k, v = qkv.unbind(dim=2)
+ elif num_all_args == 2:
+ k, v = kv.unbind(dim=2)
+ out = xops.memory_efficient_attention(q, k, v)
+ elif BACKEND == 'flash_attn':
+ if num_all_args == 1:
+ out = flash_attn.flash_attn_qkvpacked_func(qkv)
+ elif num_all_args == 2:
+ out = flash_attn.flash_attn_kvpacked_func(q, kv)
+ elif num_all_args == 3:
+ out = flash_attn.flash_attn_func(q, k, v)
+ elif BACKEND == 'sdpa':
+ if num_all_args == 1:
+ q, k, v = qkv.unbind(dim=2)
+ elif num_all_args == 2:
+ k, v = kv.unbind(dim=2)
+ q = q.permute(0, 2, 1, 3) # [N, H, L, C]
+ k = k.permute(0, 2, 1, 3) # [N, H, L, C]
+ v = v.permute(0, 2, 1, 3) # [N, H, L, C]
+ out = sdpa(q, k, v) # [N, H, L, C]
+ out = out.permute(0, 2, 1, 3) # [N, L, H, C]
+ elif BACKEND == 'naive':
+ if num_all_args == 1:
+ q, k, v = qkv.unbind(dim=2)
+ elif num_all_args == 2:
+ k, v = kv.unbind(dim=2)
+ out = _naive_sdpa(q, k, v)
+ else:
+ raise ValueError(f"Unknown attention module: {BACKEND}")
+
+ return out
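+
+
+# Minimal usage sketch (comments only, not part of the original module); shapes follow
+# the docstrings above, and the flash_attn backend additionally expects half-precision
+# CUDA tensors:
+#
+#   q = torch.randn(2, 128, 8, 64)               # [N, L, H, C]
+#   k = torch.randn(2, 128, 8, 64)
+#   v = torch.randn(2, 128, 8, 64)
+#   out = scaled_dot_product_attention(q, k, v)  # [N, L, H, C]
+#
+#   qkv = torch.stack([q, k, v], dim=2)          # [N, L, 3, H, C]
+#   out = scaled_dot_product_attention(qkv)      # packed form, same result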
diff --git a/direct3d_s2/modules/attention/modules.py b/direct3d_s2/modules/attention/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbe6235c27134f0477e48d3e12de3068c6a500ef
--- /dev/null
+++ b/direct3d_s2/modules/attention/modules.py
@@ -0,0 +1,146 @@
+from typing import *
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .full_attn import scaled_dot_product_attention
+
+
+class MultiHeadRMSNorm(nn.Module):
+ def __init__(self, dim: int, heads: int):
+ super().__init__()
+ self.scale = dim ** 0.5
+ self.gamma = nn.Parameter(torch.ones(heads, dim))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype)
+
+
+class RotaryPositionEmbedder(nn.Module):
+ def __init__(self, hidden_size: int, in_channels: int = 3):
+ super().__init__()
+ assert hidden_size % 2 == 0, "Hidden size must be divisible by 2"
+ self.hidden_size = hidden_size
+ self.in_channels = in_channels
+ self.freq_dim = hidden_size // in_channels // 2
+ self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
+ self.freqs = 1.0 / (10000 ** self.freqs)
+
+ def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
+ self.freqs = self.freqs.to(indices.device)
+ phases = torch.outer(indices, self.freqs)
+ phases = torch.polar(torch.ones_like(phases), phases)
+ return phases
+
+ def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
+ x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+ x_rotated = x_complex * phases
+ x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
+ return x_embed
+
+ def forward(self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ q (torch.Tensor): [..., N, D] tensor of queries
+ k (torch.Tensor): [..., N, D] tensor of keys
+ indices (torch.Tensor): [..., N, C] tensor of spatial positions
+ """
+ if indices is None:
+ indices = torch.arange(q.shape[-2], device=q.device)
+ if len(q.shape) > 2:
+ indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,))
+
+ phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
+ if phases.shape[1] < self.hidden_size // 2:
+ phases = torch.cat([phases, torch.polar(
+ torch.ones(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device),
+ torch.zeros(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device)
+ )], dim=-1)
+ q_embed = self._rotary_embedding(q, phases)
+ k_embed = self._rotary_embedding(k, phases)
+ return q_embed, k_embed
+
+
+class MultiHeadAttention(nn.Module):
+ def __init__(
+ self,
+ channels: int,
+ num_heads: int,
+ ctx_channels: Optional[int]=None,
+ type: Literal["self", "cross"] = "self",
+ attn_mode: Literal["full", "windowed"] = "full",
+ window_size: Optional[int] = None,
+ shift_window: Optional[Tuple[int, int, int]] = None,
+ qkv_bias: bool = True,
+ use_rope: bool = False,
+ qk_rms_norm: bool = False,
+ ):
+ super().__init__()
+ assert channels % num_heads == 0
+ assert type in ["self", "cross"], f"Invalid attention type: {type}"
+ assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}"
+ assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
+
+ if attn_mode == "windowed":
+ raise NotImplementedError("Windowed attention is not yet implemented")
+
+ self.channels = channels
+ self.head_dim = channels // num_heads
+ self.ctx_channels = ctx_channels if ctx_channels is not None else channels
+ self.num_heads = num_heads
+ self._type = type
+ self.attn_mode = attn_mode
+ self.window_size = window_size
+ self.shift_window = shift_window
+ self.use_rope = use_rope
+ self.qk_rms_norm = qk_rms_norm
+
+ if self._type == "self":
+ self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
+ else:
+ self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
+ self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
+
+ if self.qk_rms_norm:
+ self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
+ self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
+
+ self.to_out = nn.Linear(channels, channels)
+
+ if use_rope:
+ self.rope = RotaryPositionEmbedder(channels)
+
+ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None) -> torch.Tensor:
+ B, L, C = x.shape
+ if self._type == "self":
+ qkv = self.to_qkv(x)
+ qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
+ if self.use_rope:
+ q, k, v = qkv.unbind(dim=2)
+ q, k = self.rope(q, k, indices)
+ qkv = torch.stack([q, k, v], dim=2)
+ if self.attn_mode == "full":
+ if self.qk_rms_norm:
+ q, k, v = qkv.unbind(dim=2)
+ q = self.q_rms_norm(q)
+ k = self.k_rms_norm(k)
+ h = scaled_dot_product_attention(q, k, v)
+ else:
+ h = scaled_dot_product_attention(qkv)
+ elif self.attn_mode == "windowed":
+ raise NotImplementedError("Windowed attention is not yet implemented")
+ else:
+ Lkv = context.shape[1]
+ q = self.to_q(x)
+ kv = self.to_kv(context)
+ q = q.reshape(B, L, self.num_heads, -1)
+ kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
+ if self.qk_rms_norm:
+ q = self.q_rms_norm(q)
+ k, v = kv.unbind(dim=2)
+ k = self.k_rms_norm(k)
+ h = scaled_dot_product_attention(q, k, v)
+ else:
+ h = scaled_dot_product_attention(q, kv)
+ h = h.reshape(B, L, -1)
+ h = self.to_out(h)
+ return h
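+
+
+# Usage sketch (illustrative, not part of the original module); sizes are arbitrary:
+#
+#   self_attn = MultiHeadAttention(channels=512, num_heads=8)
+#   x = torch.randn(2, 100, 512)
+#   y = self_attn(x)                       # [2, 100, 512]
+#
+#   cross_attn = MultiHeadAttention(channels=512, num_heads=8,
+#                                   ctx_channels=768, type="cross")
+#   ctx = torch.randn(2, 77, 768)
+#   y = cross_attn(x, context=ctx)         # [2, 100, 512]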
diff --git a/direct3d_s2/modules/norm.py b/direct3d_s2/modules/norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..09035726081fb7afda2c62504d5474cfa483c58f
--- /dev/null
+++ b/direct3d_s2/modules/norm.py
@@ -0,0 +1,25 @@
+import torch
+import torch.nn as nn
+
+
+class LayerNorm32(nn.LayerNorm):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return super().forward(x.float()).type(x.dtype)
+
+
+class GroupNorm32(nn.GroupNorm):
+ """
+ A GroupNorm layer that converts to float32 before the forward pass.
+ """
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return super().forward(x.float()).type(x.dtype)
+
+
+class ChannelLayerNorm32(LayerNorm32):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ DIM = x.dim()
+ x = x.permute(0, *range(2, DIM), 1).contiguous()
+ x = super().forward(x)
+ x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous()
+ return x
+
\ No newline at end of file
diff --git a/direct3d_s2/modules/sparse/.ipynb_checkpoints/basic-checkpoint.py b/direct3d_s2/modules/sparse/.ipynb_checkpoints/basic-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..8837f44052f6d573d09e3bfb897e659e10516bb5
--- /dev/null
+++ b/direct3d_s2/modules/sparse/.ipynb_checkpoints/basic-checkpoint.py
@@ -0,0 +1,459 @@
+from typing import *
+import torch
+import torch.nn as nn
+from . import BACKEND, DEBUG
+SparseTensorData = None # Lazy import
+
+
+__all__ = [
+ 'SparseTensor',
+ 'sparse_batch_broadcast',
+ 'sparse_batch_op',
+ 'sparse_cat',
+ 'sparse_unbind',
+]
+
+
+class SparseTensor:
+ """
+ Sparse tensor with support for both torchsparse and spconv backends.
+
+ Parameters:
+ - feats (torch.Tensor): Features of the sparse tensor.
+ - coords (torch.Tensor): Coordinates of the sparse tensor.
+ - shape (torch.Size): Shape of the sparse tensor.
+ - layout (List[slice]): Layout of the sparse tensor for each batch.
+ - data (SparseTensorData): Sparse tensor data used for convolution.
+
+ NOTE:
+ - Data corresponding to the same batch should be contiguous.
+ - Coords should be in [0, 1023]
+ """
+ @overload
+ def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...
+
+ @overload
+ def __init__(self, data, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...
+
+ def __init__(self, *args, **kwargs):
+ # Lazy import of sparse tensor backend
+ global SparseTensorData
+ if SparseTensorData is None:
+ import importlib
+ if BACKEND == 'torchsparse':
+ SparseTensorData = importlib.import_module('torchsparse').SparseTensor
+ elif BACKEND == 'spconv':
+ SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor
+
+ method_id = 0
+ if len(args) != 0:
+ method_id = 0 if isinstance(args[0], torch.Tensor) else 1
+ else:
+ method_id = 1 if 'data' in kwargs else 0
+
+ if method_id == 0:
+ feats, coords, shape, layout = args + (None,) * (4 - len(args))
+ if 'feats' in kwargs:
+ feats = kwargs['feats']
+ del kwargs['feats']
+ if 'coords' in kwargs:
+ coords = kwargs['coords']
+ del kwargs['coords']
+ if 'shape' in kwargs:
+ shape = kwargs['shape']
+ del kwargs['shape']
+ if 'layout' in kwargs:
+ layout = kwargs['layout']
+ del kwargs['layout']
+
+ if shape is None:
+ shape = self.__cal_shape(feats, coords)
+ if layout is None:
+ layout = self.__cal_layout(coords, shape[0])
+ if BACKEND == 'torchsparse':
+ self.data = SparseTensorData(feats, coords, **kwargs)
+ elif BACKEND == 'spconv':
+ spatial_shape = list(coords.max(0)[0] + 1)[1:]
+ self.data = SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape, shape[0], **kwargs)
+ self.data._features = feats
+ elif method_id == 1:
+ data, shape, layout = args + (None,) * (3 - len(args))
+ if 'data' in kwargs:
+ data = kwargs['data']
+ del kwargs['data']
+ if 'shape' in kwargs:
+ shape = kwargs['shape']
+ del kwargs['shape']
+ if 'layout' in kwargs:
+ layout = kwargs['layout']
+ del kwargs['layout']
+
+ self.data = data
+ if shape is None:
+ shape = self.__cal_shape(self.feats, self.coords)
+ if layout is None:
+ layout = self.__cal_layout(self.coords, shape[0])
+
+ self._shape = shape
+ self._layout = layout
+ self._scale = kwargs.get('scale', (1, 1, 1))
+ self._spatial_cache = kwargs.get('spatial_cache', {})
+
+ if DEBUG:
+ try:
+ assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}"
+ assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}"
+ assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}"
+ for i in range(self.shape[0]):
+ assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous"
+ except Exception as e:
+ print('Debugging information:')
+ print(f"- Shape: {self.shape}")
+ print(f"- Layout: {self.layout}")
+ print(f"- Scale: {self._scale}")
+ print(f"- Coords: {self.coords}")
+ raise e
+
+ def __cal_shape(self, feats, coords):
+ shape = []
+ shape.append(coords[:, 0].max().item() + 1)
+ shape.extend([*feats.shape[1:]])
+ return torch.Size(shape)
+
+ def __cal_layout(self, coords, batch_size):
+ seq_len = torch.bincount(coords[:, 0], minlength=batch_size)
+ offset = torch.cumsum(seq_len, dim=0)
+ layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)]
+ return layout
+
+ @property
+ def shape(self) -> torch.Size:
+ return self._shape
+
+ def dim(self) -> int:
+ return len(self.shape)
+
+ @property
+ def layout(self) -> List[slice]:
+ return self._layout
+
+ @property
+ def feats(self) -> torch.Tensor:
+ if BACKEND == 'torchsparse':
+ return self.data.F
+ elif BACKEND == 'spconv':
+ return self.data.features
+
+ @feats.setter
+ def feats(self, value: torch.Tensor):
+ if BACKEND == 'torchsparse':
+ self.data.F = value
+ elif BACKEND == 'spconv':
+ self.data.features = value
+
+ @property
+ def coords(self) -> torch.Tensor:
+ if BACKEND == 'torchsparse':
+ return self.data.C
+ elif BACKEND == 'spconv':
+ return self.data.indices
+
+ @coords.setter
+ def coords(self, value: torch.Tensor):
+ if BACKEND == 'torchsparse':
+ self.data.C = value
+ elif BACKEND == 'spconv':
+ self.data.indices = value
+
+ @property
+ def dtype(self):
+ return self.feats.dtype
+
+ @property
+ def device(self):
+ return self.feats.device
+
+ @overload
+ def to(self, dtype: torch.dtype) -> 'SparseTensor': ...
+
+ @overload
+ def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None) -> 'SparseTensor': ...
+
+ def to(self, *args, **kwargs) -> 'SparseTensor':
+ device = None
+ dtype = None
+ if len(args) == 2:
+ device, dtype = args
+ elif len(args) == 1:
+ if isinstance(args[0], torch.dtype):
+ dtype = args[0]
+ else:
+ device = args[0]
+ if 'dtype' in kwargs:
+ assert dtype is None, "to() received multiple values for argument 'dtype'"
+ dtype = kwargs['dtype']
+ if 'device' in kwargs:
+ assert device is None, "to() received multiple values for argument 'device'"
+ device = kwargs['device']
+
+ new_feats = self.feats.to(device=device, dtype=dtype)
+ new_coords = self.coords.to(device=device)
+ return self.replace(new_feats, new_coords)
+
+ def type(self, dtype):
+ new_feats = self.feats.type(dtype)
+ return self.replace(new_feats)
+
+ def cpu(self) -> 'SparseTensor':
+ new_feats = self.feats.cpu()
+ new_coords = self.coords.cpu()
+ return self.replace(new_feats, new_coords)
+
+ def cuda(self) -> 'SparseTensor':
+ new_feats = self.feats.cuda()
+ new_coords = self.coords.cuda()
+ return self.replace(new_feats, new_coords)
+
+ def half(self) -> 'SparseTensor':
+ new_feats = self.feats.half()
+ return self.replace(new_feats)
+
+ def float(self) -> 'SparseTensor':
+ new_feats = self.feats.float()
+ return self.replace(new_feats)
+
+ def detach(self) -> 'SparseTensor':
+ new_coords = self.coords.detach()
+ new_feats = self.feats.detach()
+ return self.replace(new_feats, new_coords)
+
+ def dense(self) -> torch.Tensor:
+ if BACKEND == 'torchsparse':
+ return self.data.dense()
+ elif BACKEND == 'spconv':
+ return self.data.dense()
+
+ def reshape(self, *shape) -> 'SparseTensor':
+ new_feats = self.feats.reshape(self.feats.shape[0], *shape)
+ return self.replace(new_feats)
+
+ def unbind(self, dim: int) -> List['SparseTensor']:
+ return sparse_unbind(self, dim)
+
+ def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor':
+ new_shape = [self.shape[0]]
+ new_shape.extend(feats.shape[1:])
+ if BACKEND == 'torchsparse':
+ new_data = SparseTensorData(
+ feats=feats,
+ coords=self.data.coords if coords is None else coords,
+ stride=self.data.stride,
+ spatial_range=self.data.spatial_range,
+ )
+ new_data._caches = self.data._caches
+ elif BACKEND == 'spconv':
+ new_data = SparseTensorData(
+ self.data.features.reshape(self.data.features.shape[0], -1),
+ self.data.indices,
+ self.data.spatial_shape,
+ self.data.batch_size,
+ self.data.grid,
+ self.data.voxel_num,
+ self.data.indice_dict
+ )
+ new_data._features = feats
+ new_data.benchmark = self.data.benchmark
+ new_data.benchmark_record = self.data.benchmark_record
+ new_data.thrust_allocator = self.data.thrust_allocator
+ new_data._timer = self.data._timer
+ new_data.force_algo = self.data.force_algo
+ new_data.int8_scale = self.data.int8_scale
+ if coords is not None:
+ new_data.indices = coords
+ new_tensor = SparseTensor(new_data, shape=torch.Size(new_shape), layout=self.layout, scale=self._scale, spatial_cache=self._spatial_cache)
+ return new_tensor
+
+ @staticmethod
+ def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor':
+ N, C = dim
+ x = torch.arange(aabb[0], aabb[3] + 1)
+ y = torch.arange(aabb[1], aabb[4] + 1)
+ z = torch.arange(aabb[2], aabb[5] + 1)
+ coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3)
+ coords = torch.cat([
+ torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1),
+ coords.repeat(N, 1),
+ ], dim=1).to(dtype=torch.int32, device=device)
+ feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device)
+ return SparseTensor(feats=feats, coords=coords)
+
+ def __merge_sparse_cache(self, other: 'SparseTensor') -> dict:
+ new_cache = {}
+ for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())):
+ if k in self._spatial_cache:
+ new_cache[k] = self._spatial_cache[k]
+ if k in other._spatial_cache:
+ if k not in new_cache:
+ new_cache[k] = other._spatial_cache[k]
+ else:
+ new_cache[k].update(other._spatial_cache[k])
+ return new_cache
+
+ def __neg__(self) -> 'SparseTensor':
+ return self.replace(-self.feats)
+
+ def __elemwise__(self, other: Union[torch.Tensor, 'SparseTensor'], op: callable) -> 'SparseTensor':
+ if isinstance(other, torch.Tensor):
+ try:
+ other = torch.broadcast_to(other, self.shape)
+ other = sparse_batch_broadcast(self, other)
+ except Exception:
+ pass
+ if isinstance(other, SparseTensor):
+ other = other.feats
+ new_feats = op(self.feats, other)
+ new_tensor = self.replace(new_feats)
+ if isinstance(other, SparseTensor):
+ new_tensor._spatial_cache = self.__merge_sparse_cache(other)
+ return new_tensor
+
+ def __add__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+ return self.__elemwise__(other, torch.add)
+
+ def __radd__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+ return self.__elemwise__(other, torch.add)
+
+ def __sub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+ return self.__elemwise__(other, torch.sub)
+
+ def __rsub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+ return self.__elemwise__(other, lambda x, y: torch.sub(y, x))
+
+ def __mul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+ return self.__elemwise__(other, torch.mul)
+
+ def __rmul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+ return self.__elemwise__(other, torch.mul)
+
+ def __truediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+ return self.__elemwise__(other, torch.div)
+
+ def __rtruediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+ return self.__elemwise__(other, lambda x, y: torch.div(y, x))
+
+ def __getitem__(self, idx):
+ if isinstance(idx, int):
+ idx = [idx]
+ elif isinstance(idx, slice):
+ idx = range(*idx.indices(self.shape[0]))
+ elif isinstance(idx, torch.Tensor):
+ if idx.dtype == torch.bool:
+ assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}"
+ idx = idx.nonzero().squeeze(1)
+ elif idx.dtype in [torch.int32, torch.int64]:
+ assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
+ else:
+ raise ValueError(f"Unknown index type: {idx.dtype}")
+ else:
+ raise ValueError(f"Unknown index type: {type(idx)}")
+
+ coords = []
+ feats = []
+ for new_idx, old_idx in enumerate(idx):
+ coords.append(self.coords[self.layout[old_idx]].clone())
+ coords[-1][:, 0] = new_idx
+ feats.append(self.feats[self.layout[old_idx]])
+ coords = torch.cat(coords, dim=0).contiguous()
+ feats = torch.cat(feats, dim=0).contiguous()
+ return SparseTensor(feats=feats, coords=coords)
+
+ def register_spatial_cache(self, key, value) -> None:
+ """
+ Register a spatial cache.
+ The spatial cache can be anything you want to cache.
+ Registration and retrieval of the cache are keyed on the current scale.
+ """
+ scale_key = str(self._scale)
+ if scale_key not in self._spatial_cache:
+ self._spatial_cache[scale_key] = {}
+ self._spatial_cache[scale_key][key] = value
+
+ def get_spatial_cache(self, key=None):
+ """
+ Get a spatial cache.
+ """
+ scale_key = str(self._scale)
+ cur_scale_cache = self._spatial_cache.get(scale_key, {})
+ if key is None:
+ return cur_scale_cache
+ return cur_scale_cache.get(key, None)
+
+
+def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor:
+ """
+ Broadcast a per-batch tensor to the rows of a sparse tensor along the batch dimension.
+
+ Args:
+ input (SparseTensor): Sparse tensor whose layout defines the broadcast.
+ other (torch.Tensor): Per-batch tensor of shape [N, ...] to broadcast.
+ """
+ coords, feats = input.coords, input.feats
+ broadcasted = torch.zeros_like(feats)
+ for k in range(input.shape[0]):
+ broadcasted[input.layout[k]] = other[k]
+ return broadcasted
+
+
+def sparse_batch_op(input: SparseTensor, other: torch.Tensor, op: callable = torch.add) -> SparseTensor:
+ """
+ Broadcast a per-batch tensor to a sparse tensor along the batch dimension, then apply an elementwise operation.
+
+ Args:
+ input (SparseTensor): Sparse tensor to operate on.
+ other (torch.Tensor): Per-batch tensor of shape [N, ...] to broadcast.
+ op (callable): Operation to apply after broadcasting. Defaults to torch.add.
+ """
+ return input.replace(op(input.feats, sparse_batch_broadcast(input, other)))
+
+
+def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor:
+ """
+ Concatenate a list of sparse tensors.
+
+ Args:
+ inputs (List[SparseTensor]): List of sparse tensors to concatenate.
+ """
+ if dim == 0:
+ start = 0
+ coords = []
+ for input in inputs:
+ coords.append(input.coords.clone())
+ coords[-1][:, 0] += start
+ start += input.shape[0]
+ coords = torch.cat(coords, dim=0)
+ feats = torch.cat([input.feats for input in inputs], dim=0)
+ output = SparseTensor(
+ coords=coords,
+ feats=feats,
+ )
+ else:
+ feats = torch.cat([input.feats for input in inputs], dim=dim)
+ output = inputs[0].replace(feats)
+
+ return output
+
+
+def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
+ """
+ Unbind a sparse tensor along a dimension.
+
+ Args:
+ input (SparseTensor): Sparse tensor to unbind.
+ dim (int): Dimension to unbind.
+ """
+ if dim == 0:
+ return [input[i] for i in range(input.shape[0])]
+ else:
+ feats = input.feats.unbind(dim)
+ return [input.replace(f) for f in feats]
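+
+
+# Construction sketch (comments only; values are made up): a batch of two point sets
+# packed into one SparseTensor. Coordinates are [batch_idx, x, y, z] and rows of the
+# same batch must be contiguous, as noted in the class docstring.
+#
+#   coords = torch.tensor([[0, 1, 2, 3],
+#                          [0, 4, 5, 6],
+#                          [1, 7, 8, 9]], dtype=torch.int32)
+#   feats = torch.randn(3, 16)
+#   st = SparseTensor(feats=feats, coords=coords)  # shape [2, 16], layout of 2 slices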
diff --git a/direct3d_s2/modules/sparse/__init__.py b/direct3d_s2/modules/sparse/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..036ff1998f29424abb720642cd83a83c6abf750f
--- /dev/null
+++ b/direct3d_s2/modules/sparse/__init__.py
@@ -0,0 +1,105 @@
+from typing import *
+
+BACKEND = 'torchsparse'
+DEBUG = False
+ATTN = 'flash_attn'
+
+def __from_env():
+ import os
+
+ global BACKEND
+ global DEBUG
+ global ATTN
+
+ env_sparse_backend = os.environ.get('SPARSE_BACKEND')
+ env_sparse_debug = os.environ.get('SPARSE_DEBUG')
+ env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND')
+ if env_sparse_attn is None:
+ env_sparse_attn = os.environ.get('ATTN_BACKEND')
+
+ if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']:
+ BACKEND = env_sparse_backend
+ if env_sparse_debug is not None:
+ DEBUG = env_sparse_debug == '1'
+ if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']:
+ ATTN = env_sparse_attn
+
+ print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")
+
+
+__from_env()
+
+
+def set_backend(backend: Literal['spconv', 'torchsparse']):
+ global BACKEND
+ BACKEND = backend
+
+def set_debug(debug: bool):
+ global DEBUG
+ DEBUG = debug
+
+def set_attn(attn: Literal['xformers', 'flash_attn']):
+ global ATTN
+ ATTN = attn
+
+
+import importlib
+
+__attributes = {
+ 'SparseTensor': 'basic',
+ 'sparse_batch_broadcast': 'basic',
+ 'sparse_batch_op': 'basic',
+ 'sparse_cat': 'basic',
+ 'sparse_unbind': 'basic',
+ 'SparseGroupNorm': 'norm',
+ 'SparseLayerNorm': 'norm',
+ 'SparseGroupNorm32': 'norm',
+ 'SparseLayerNorm32': 'norm',
+ 'SparseSigmoid': 'nonlinearity',
+ 'SparseReLU': 'nonlinearity',
+ 'SparseSiLU': 'nonlinearity',
+ 'SparseGELU': 'nonlinearity',
+ 'SparseTanh': 'nonlinearity',
+ 'SparseActivation': 'nonlinearity',
+ 'SparseLinear': 'linear',
+ 'sparse_scaled_dot_product_attention': 'attention',
+ 'SerializeMode': 'attention',
+ 'sparse_serialized_scaled_dot_product_self_attention': 'attention',
+ 'sparse_windowed_scaled_dot_product_self_attention': 'attention',
+ 'SparseMultiHeadAttention': 'attention',
+ 'SparseConv3d': 'conv',
+ 'SparseInverseConv3d': 'conv',
+ 'sparseconv3d_func': 'conv',
+ 'SparseDownsample': 'spatial',
+ 'SparseUpsample': 'spatial',
+ 'SparseSubdivide': 'spatial',
+}
+
+__submodules = ['transformer']
+
+__all__ = list(__attributes.keys()) + __submodules
+
+def __getattr__(name):
+ if name not in globals():
+ if name in __attributes:
+ module_name = __attributes[name]
+ module = importlib.import_module(f".{module_name}", __name__)
+ globals()[name] = getattr(module, name)
+ elif name in __submodules:
+ module = importlib.import_module(f".{name}", __name__)
+ globals()[name] = module
+ else:
+ raise AttributeError(f"module {__name__} has no attribute {name}")
+ return globals()[name]
+
+
+# For Pylance
+if __name__ == '__main__':
+ from .basic import *
+ from .norm import *
+ from .nonlinearity import *
+ from .linear import *
+ from .attention import *
+ from .conv import *
+ from .spatial import *
+ from . import transformer
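+
+
+# Usage note (not part of the original module): SPARSE_BACKEND and SPARSE_ATTN_BACKEND
+# are read once at import time, so set them before importing, e.g.
+#
+#   import os
+#   os.environ['SPARSE_BACKEND'] = 'spconv'
+#   os.environ['SPARSE_ATTN_BACKEND'] = 'xformers'
+#
+# or switch afterwards with set_backend('spconv') / set_attn('xformers').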
diff --git a/direct3d_s2/modules/sparse/__pycache__/__init__.cpython-310.pyc b/direct3d_s2/modules/sparse/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d8f04ec54c69aa7df0aeeaaef71775535bb39ec3
Binary files /dev/null and b/direct3d_s2/modules/sparse/__pycache__/__init__.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/sparse/__pycache__/basic.cpython-310.pyc b/direct3d_s2/modules/sparse/__pycache__/basic.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1f0ddea519a69f53a825d46e59a1160f63c83ff9
Binary files /dev/null and b/direct3d_s2/modules/sparse/__pycache__/basic.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/sparse/__pycache__/linear.cpython-310.pyc b/direct3d_s2/modules/sparse/__pycache__/linear.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..907fa3835a412ca530196c7fd18fa983f120a391
Binary files /dev/null and b/direct3d_s2/modules/sparse/__pycache__/linear.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc b/direct3d_s2/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7b74cae998582e393318ab29540a9978aa1db82a
Binary files /dev/null and b/direct3d_s2/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/sparse/__pycache__/norm.cpython-310.pyc b/direct3d_s2/modules/sparse/__pycache__/norm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9be3e9bb12a48fcb2ed9fc1bd077ae091d22815a
Binary files /dev/null and b/direct3d_s2/modules/sparse/__pycache__/norm.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/sparse/__pycache__/spatial.cpython-310.pyc b/direct3d_s2/modules/sparse/__pycache__/spatial.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e962d7fe5d46ec6b52ef1b440b257e0e09bbd7de
Binary files /dev/null and b/direct3d_s2/modules/sparse/__pycache__/spatial.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/sparse/attention/__init__.py b/direct3d_s2/modules/sparse/attention/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e2732f12d139e579fb27f224c523e27f1e8cefb
--- /dev/null
+++ b/direct3d_s2/modules/sparse/attention/__init__.py
@@ -0,0 +1,5 @@
+from .full_attn import *
+from .serialized_attn import *
+from .windowed_attn import *
+from .modules import *
+from .spatial_sparse_attention.module.spatial_sparse_attention import SpatialSparseAttention
diff --git a/direct3d_s2/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc b/direct3d_s2/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7bb9bf49687ac68fa570e049d22f87b21a5d0b9c
Binary files /dev/null and b/direct3d_s2/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc b/direct3d_s2/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..636e11586f86da80ac873a5538be2bb474098cdb
Binary files /dev/null and b/direct3d_s2/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/sparse/attention/__pycache__/modules.cpython-310.pyc b/direct3d_s2/modules/sparse/attention/__pycache__/modules.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1c6687e1746733deb74c38271108c42655893d5c
Binary files /dev/null and b/direct3d_s2/modules/sparse/attention/__pycache__/modules.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc b/direct3d_s2/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d102abe0ef2316b2492f990a9c20195485042eb5
Binary files /dev/null and b/direct3d_s2/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc b/direct3d_s2/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f697615079d97b18827305e1d37421078e963bd0
Binary files /dev/null and b/direct3d_s2/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/sparse/attention/full_attn.py b/direct3d_s2/modules/sparse/attention/full_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9e27aeb98419621f3f9999fd3b11eebf2b90a40
--- /dev/null
+++ b/direct3d_s2/modules/sparse/attention/full_attn.py
@@ -0,0 +1,215 @@
+from typing import *
+import torch
+from .. import SparseTensor
+from .. import DEBUG, ATTN
+
+if ATTN == 'xformers':
+ import xformers.ops as xops
+elif ATTN == 'flash_attn':
+ import flash_attn
+else:
+ raise ValueError(f"Unknown attention module: {ATTN}")
+
+
+__all__ = [
+ 'sparse_scaled_dot_product_attention',
+]
+
+
+@overload
+def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor:
+ """
+ Apply scaled dot product attention to a sparse tensor.
+
+ Args:
+ qkv (SparseTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
+ """
+ ...
+
+@overload
+def sparse_scaled_dot_product_attention(q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]) -> SparseTensor:
+ """
+ Apply scaled dot product attention to a sparse tensor.
+
+ Args:
+ q (SparseTensor): A [N, *, H, C] sparse tensor containing Qs.
+ kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs.
+ """
+ ...
+
+@overload
+def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: SparseTensor) -> torch.Tensor:
+ """
+ Apply scaled dot product attention to a sparse tensor.
+
+ Args:
+ q (torch.Tensor): A [N, L, H, C] dense tensor containing Qs.
+ kv (SparseTensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs.
+ """
+ ...
+
+@overload
+def sparse_scaled_dot_product_attention(q: SparseTensor, k: SparseTensor, v: SparseTensor) -> SparseTensor:
+ """
+ Apply scaled dot product attention to a sparse tensor.
+
+ Args:
+ q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
+ k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
+ v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.
+
+ Note:
+ k and v are assumed to have the same coordinate map.
+ """
+ ...
+
+@overload
+def sparse_scaled_dot_product_attention(q: SparseTensor, k: torch.Tensor, v: torch.Tensor) -> SparseTensor:
+ """
+ Apply scaled dot product attention to a sparse tensor.
+
+ Args:
+ q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
+ k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks.
+ v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs.
+ """
+ ...
+
+@overload
+def sparse_scaled_dot_product_attention(q: torch.Tensor, k: SparseTensor, v: SparseTensor) -> torch.Tensor:
+ """
+ Apply scaled dot product attention to a sparse tensor.
+
+ Args:
+ q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs.
+ k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
+ v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.
+ """
+ ...
+
+def sparse_scaled_dot_product_attention(*args, **kwargs):
+ arg_names_dict = {
+ 1: ['qkv'],
+ 2: ['q', 'kv'],
+ 3: ['q', 'k', 'v']
+ }
+ num_all_args = len(args) + len(kwargs)
+ assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
+ for key in arg_names_dict[num_all_args][len(args):]:
+ assert key in kwargs, f"Missing argument {key}"
+
+ if num_all_args == 1:
+ qkv = args[0] if len(args) > 0 else kwargs['qkv']
+ assert isinstance(qkv, SparseTensor), f"qkv must be a SparseTensor, got {type(qkv)}"
+ assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
+ device = qkv.device
+
+ s = qkv
+ q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])]
+ kv_seqlen = q_seqlen
+ qkv = qkv.feats # [T, 3, H, C]
+
+ elif num_all_args == 2:
+ q = args[0] if len(args) > 0 else kwargs['q']
+ kv = args[1] if len(args) > 1 else kwargs['kv']
+ assert isinstance(q, SparseTensor) and isinstance(kv, (SparseTensor, torch.Tensor)) or \
+ isinstance(q, torch.Tensor) and isinstance(kv, SparseTensor), \
+ f"Invalid types, got {type(q)} and {type(kv)}"
+ assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
+ device = q.device
+
+ if isinstance(q, SparseTensor):
+ assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]"
+ s = q
+ q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
+ q = q.feats # [T_Q, H, C]
+ else:
+ assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
+ s = None
+ N, L, H, C = q.shape
+ q_seqlen = [L] * N
+ q = q.reshape(N * L, H, C) # [T_Q, H, C]
+
+ if isinstance(kv, SparseTensor):
+ assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]"
+ kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])]
+ kv = kv.feats # [T_KV, 2, H, C]
+ else:
+ assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
+ N, L, _, H, C = kv.shape
+ kv_seqlen = [L] * N
+ kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C]
+
+ elif num_all_args == 3:
+ q = args[0] if len(args) > 0 else kwargs['q']
+ k = args[1] if len(args) > 1 else kwargs['k']
+ v = args[2] if len(args) > 2 else kwargs['v']
+ assert isinstance(q, SparseTensor) and isinstance(k, (SparseTensor, torch.Tensor)) and type(k) == type(v) or \
+ isinstance(q, torch.Tensor) and isinstance(k, SparseTensor) and isinstance(v, SparseTensor), \
+ f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}"
+ assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
+ device = q.device
+
+ if isinstance(q, SparseTensor):
+ assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]"
+ s = q
+ q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
+ q = q.feats # [T_Q, H, Ci]
+ else:
+ assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
+ s = None
+ N, L, H, CI = q.shape
+ q_seqlen = [L] * N
+ q = q.reshape(N * L, H, CI) # [T_Q, H, Ci]
+
+ if isinstance(k, SparseTensor):
+ assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]"
+ assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]"
+ kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])]
+ k = k.feats # [T_KV, H, Ci]
+ v = v.feats # [T_KV, H, Co]
+ else:
+ assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
+ assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
+ N, L, H, CI, CO = *k.shape, v.shape[-1]
+ kv_seqlen = [L] * N
+ k = k.reshape(N * L, H, CI) # [T_KV, H, Ci]
+ v = v.reshape(N * L, H, CO) # [T_KV, H, Co]
+
+ if DEBUG:
+ if s is not None:
+ for i in range(s.shape[0]):
+ assert (s.coords[s.layout[i]] == i).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch"
+ if num_all_args in [2, 3]:
+ assert q.shape[:2] == [1, sum(q_seqlen)], f"SparseScaledDotProductSelfAttention: q shape mismatch"
+ if num_all_args == 3:
+ assert k.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: k shape mismatch"
+ assert v.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: v shape mismatch"
+
+ if ATTN == 'xformers':
+ if num_all_args == 1:
+ q, k, v = qkv.unbind(dim=1)
+ elif num_all_args == 2:
+ k, v = kv.unbind(dim=1)
+ q = q.unsqueeze(0)
+ k = k.unsqueeze(0)
+ v = v.unsqueeze(0)
+ mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
+ out = xops.memory_efficient_attention(q, k, v, mask)[0]
+ elif ATTN == 'flash_attn':
+ cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
+ if num_all_args in [2, 3]:
+ cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
+ if num_all_args == 1:
+ out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
+ elif num_all_args == 2:
+ out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
+ elif num_all_args == 3:
+ out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
+ else:
+ raise ValueError(f"Unknown attention module: {ATTN}")
+
+ if s is not None:
+ return s.replace(out)
+ else:
+ return out.reshape(N, L, H, -1)
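+
+
+# Usage sketch (comments only, not part of the original module): packed sparse
+# self-attention, where qkv is a SparseTensor with features of shape [T, 3, H, C]
+# (T = total tokens across the batch):
+#
+#   h = sparse_scaled_dot_product_attention(qkv)    # SparseTensor, feats [T, H, C]
+#
+# Sparse-to-dense cross-attention, with kv a dense tensor of shape [N, L, 2, H, C]:
+#
+#   h = sparse_scaled_dot_product_attention(q, kv)  # q: SparseTensor, feats [T, H, C]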
diff --git a/direct3d_s2/modules/sparse/attention/modules.py b/direct3d_s2/modules/sparse/attention/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d2fe782b0947700e308e9ec0325e7e91c84e3c2
--- /dev/null
+++ b/direct3d_s2/modules/sparse/attention/modules.py
@@ -0,0 +1,139 @@
+from typing import *
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .. import SparseTensor
+from .full_attn import sparse_scaled_dot_product_attention
+from .serialized_attn import SerializeMode, sparse_serialized_scaled_dot_product_self_attention
+from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention
+from ...attention import RotaryPositionEmbedder
+
+
+class SparseMultiHeadRMSNorm(nn.Module):
+ def __init__(self, dim: int, heads: int):
+ super().__init__()
+ self.scale = dim ** 0.5
+ self.gamma = nn.Parameter(torch.ones(heads, dim))
+
+ def forward(self, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
+ x_type = x.dtype
+ x = x.float()
+ if isinstance(x, SparseTensor):
+ x = x.replace(F.normalize(x.feats, dim=-1))
+ else:
+ x = F.normalize(x, dim=-1)
+ return (x * self.gamma * self.scale).to(x_type)
+
+
+class SparseMultiHeadAttention(nn.Module):
+ def __init__(
+ self,
+ channels: int,
+ num_heads: int,
+ ctx_channels: Optional[int] = None,
+ type: Literal["self", "cross"] = "self",
+ attn_mode: Literal["full", "serialized", "windowed"] = "full",
+ window_size: Optional[int] = None,
+ shift_sequence: Optional[int] = None,
+ shift_window: Optional[Tuple[int, int, int]] = None,
+ serialize_mode: Optional[SerializeMode] = None,
+ qkv_bias: bool = True,
+ use_rope: bool = False,
+ qk_rms_norm: bool = False,
+ ):
+ super().__init__()
+ assert channels % num_heads == 0
+ assert type in ["self", "cross"], f"Invalid attention type: {type}"
+ assert attn_mode in ["full", "serialized", "windowed"], f"Invalid attention mode: {attn_mode}"
+ assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
+ assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention"
+ self.channels = channels
+ self.ctx_channels = ctx_channels if ctx_channels is not None else channels
+ self.num_heads = num_heads
+ self._type = type
+ self.attn_mode = attn_mode
+ self.window_size = window_size
+ self.shift_sequence = shift_sequence
+ self.shift_window = shift_window
+ self.serialize_mode = serialize_mode
+ self.use_rope = use_rope
+ self.qk_rms_norm = qk_rms_norm
+
+ if self._type == "self":
+ self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
+ else:
+ self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
+ self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
+
+ if self.qk_rms_norm:
+ self.q_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
+ self.k_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
+
+ self.to_out = nn.Linear(channels, channels)
+
+ if use_rope:
+ self.rope = RotaryPositionEmbedder(channels)
+
+ @staticmethod
+ def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
+ if isinstance(x, SparseTensor):
+ return x.replace(module(x.feats))
+ else:
+ return module(x)
+
+ @staticmethod
+ def _reshape_chs(x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[SparseTensor, torch.Tensor]:
+ if isinstance(x, SparseTensor):
+ return x.reshape(*shape)
+ else:
+ return x.reshape(*x.shape[:2], *shape)
+
+ def _fused_pre(self, x: Union[SparseTensor, torch.Tensor], num_fused: int) -> Union[SparseTensor, torch.Tensor]:
+ if isinstance(x, SparseTensor):
+ x_feats = x.feats.unsqueeze(0)
+ else:
+ x_feats = x
+ x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1)
+ return x.replace(x_feats.squeeze(0)) if isinstance(x, SparseTensor) else x_feats
+
+ def _rope(self, qkv: SparseTensor) -> SparseTensor:
+ q, k, v = qkv.feats.unbind(dim=1) # [T, H, C]
+ q, k = self.rope(q, k, qkv.coords[:, 1:])
+ qkv = qkv.replace(torch.stack([q, k, v], dim=1))
+ return qkv
+
+ def forward(self, x: Union[SparseTensor, torch.Tensor], context: Optional[Union[SparseTensor, torch.Tensor]] = None) -> Union[SparseTensor, torch.Tensor]:
+ if self._type == "self":
+ qkv = self._linear(self.to_qkv, x)
+ qkv = self._fused_pre(qkv, num_fused=3)
+ if self.use_rope:
+ qkv = self._rope(qkv)
+ if self.qk_rms_norm:
+ q, k, v = qkv.unbind(dim=1)
+ q = self.q_rms_norm(q)
+ k = self.k_rms_norm(k)
+ qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1))
+ if self.attn_mode == "full":
+ h = sparse_scaled_dot_product_attention(qkv)
+ elif self.attn_mode == "serialized":
+ h = sparse_serialized_scaled_dot_product_self_attention(
+ qkv, self.window_size, serialize_mode=self.serialize_mode, shift_sequence=self.shift_sequence, shift_window=self.shift_window
+ )
+ elif self.attn_mode == "windowed":
+ h = sparse_windowed_scaled_dot_product_self_attention(
+ qkv, self.window_size, shift_window=self.shift_window
+ )
+ else:
+ q = self._linear(self.to_q, x)
+ q = self._reshape_chs(q, (self.num_heads, -1))
+ kv = self._linear(self.to_kv, context)
+ kv = self._fused_pre(kv, num_fused=2)
+ if self.qk_rms_norm:
+ q = self.q_rms_norm(q)
+ k, v = kv.unbind(dim=1)
+ k = self.k_rms_norm(k)
+ kv = kv.replace(torch.stack([k.feats, v.feats], dim=1))
+ h = sparse_scaled_dot_product_attention(q, kv)
+ h = self._reshape_chs(h, (-1,))
+ h = self._linear(self.to_out, h)
+ return h
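+
+
+# Usage sketch (illustrative, not part of the original module): sparse self-attention
+# over a SparseTensor x with feature width 512, and cross-attention against dense
+# conditioning tokens of shape [N, L, 1024]:
+#
+#   attn = SparseMultiHeadAttention(channels=512, num_heads=8, attn_mode="full")
+#   h = attn(x)                  # SparseTensor with the same layout as x
+#
+#   cross = SparseMultiHeadAttention(512, 8, ctx_channels=1024, type="cross")
+#   h = cross(x, context)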
diff --git a/direct3d_s2/modules/sparse/attention/serialized_attn.py b/direct3d_s2/modules/sparse/attention/serialized_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5950b75b2f5a6d6e79ab6d472b8501aaa5ec4a26
--- /dev/null
+++ b/direct3d_s2/modules/sparse/attention/serialized_attn.py
@@ -0,0 +1,193 @@
+from typing import *
+from enum import Enum
+import torch
+import math
+from .. import SparseTensor
+from .. import DEBUG, ATTN
+
+if ATTN == 'xformers':
+ import xformers.ops as xops
+elif ATTN == 'flash_attn':
+ import flash_attn
+else:
+ raise ValueError(f"Unknown attention module: {ATTN}")
+
+
+__all__ = [
+ 'sparse_serialized_scaled_dot_product_self_attention',
+]
+
+
+class SerializeMode(Enum):
+ Z_ORDER = 0
+ Z_ORDER_TRANSPOSED = 1
+ HILBERT = 2
+ HILBERT_TRANSPOSED = 3
+
+
+SerializeModes = [
+ SerializeMode.Z_ORDER,
+ SerializeMode.Z_ORDER_TRANSPOSED,
+ SerializeMode.HILBERT,
+ SerializeMode.HILBERT_TRANSPOSED
+]
+
+
+def calc_serialization(
+ tensor: SparseTensor,
+ window_size: int,
+ serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
+ shift_sequence: int = 0,
+ shift_window: Tuple[int, int, int] = (0, 0, 0)
+) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
+ """
+ Calculate serialization and partitioning for a set of coordinates.
+
+ Args:
+ tensor (SparseTensor): The input tensor.
+ window_size (int): The window size to use.
+ serialize_mode (SerializeMode): The serialization mode to use.
+ shift_sequence (int): The shift of serialized sequence.
+ shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
+
+ Returns:
+ (torch.Tensor, torch.Tensor, List[int], List[int]): Forward indices, backward indices, sequence lengths, and sequence batch indices.
+ """
+ fwd_indices = []
+ bwd_indices = []
+ seq_lens = []
+ seq_batch_indices = []
+ offsets = [0]
+
+ if 'vox2seq' not in globals():
+ import vox2seq
+
+ # Serialize the input
+ serialize_coords = tensor.coords[:, 1:].clone()
+ serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3)
+ if serialize_mode == SerializeMode.Z_ORDER:
+ code = vox2seq.encode(serialize_coords, mode='z_order', permute=[0, 1, 2])
+ elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED:
+ code = vox2seq.encode(serialize_coords, mode='z_order', permute=[1, 0, 2])
+ elif serialize_mode == SerializeMode.HILBERT:
+ code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[0, 1, 2])
+ elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED:
+ code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[1, 0, 2])
+ else:
+ raise ValueError(f"Unknown serialize mode: {serialize_mode}")
+
+ for bi, s in enumerate(tensor.layout):
+ num_points = s.stop - s.start
+ num_windows = (num_points + window_size - 1) // window_size
+ valid_window_size = num_points / num_windows
+ to_ordered = torch.argsort(code[s.start:s.stop])
+ if num_windows == 1:
+ fwd_indices.append(to_ordered)
+ bwd_indices.append(torch.zeros_like(to_ordered).scatter_(0, to_ordered, torch.arange(num_points, device=tensor.device)))
+ fwd_indices[-1] += s.start
+ bwd_indices[-1] += offsets[-1]
+ seq_lens.append(num_points)
+ seq_batch_indices.append(bi)
+ offsets.append(offsets[-1] + seq_lens[-1])
+ else:
+ # Partition the input
+ offset = 0
+ mids = [(i + 0.5) * valid_window_size + shift_sequence for i in range(num_windows)]
+ split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)]
+ bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device)
+ for i in range(num_windows):
+ mid = mids[i]
+ valid_start = split[i]
+ valid_end = split[i + 1]
+ padded_start = math.floor(mid - 0.5 * window_size)
+ padded_end = padded_start + window_size
+ fwd_indices.append(to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points])
+ offset += valid_start - padded_start
+ bwd_index.scatter_(0, fwd_indices[-1][valid_start-padded_start:valid_end-padded_start], torch.arange(offset, offset + valid_end - valid_start, device=tensor.device))
+ offset += padded_end - valid_start
+ fwd_indices[-1] += s.start
+ seq_lens.extend([window_size] * num_windows)
+ seq_batch_indices.extend([bi] * num_windows)
+ bwd_indices.append(bwd_index + offsets[-1])
+ offsets.append(offsets[-1] + num_windows * window_size)
+
+ fwd_indices = torch.cat(fwd_indices)
+ bwd_indices = torch.cat(bwd_indices)
+
+ return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
+
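+# Note on the partitioning above: each window is padded to exactly `window_size`
+# tokens by wrapping around the serialized sequence (indexing modulo num_points),
+# so neighbouring windows may overlap; `split` marks the disjoint "valid" ranges,
+# and only the valid part of each window is written back through `bwd_indices`,
+# which keeps the de-serialization exact.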
+
+def sparse_serialized_scaled_dot_product_self_attention(
+ qkv: SparseTensor,
+ window_size: int,
+ serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
+ shift_sequence: int = 0,
+ shift_window: Tuple[int, int, int] = (0, 0, 0)
+) -> SparseTensor:
+ """
+ Apply serialized scaled dot product self attention to a sparse tensor.
+
+ Args:
+ qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
+ window_size (int): The window size to use.
+ serialize_mode (SerializeMode): The serialization mode to use.
+ shift_sequence (int): The shift of serialized sequence.
+ shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
+
+ Returns:
+ (SparseTensor): Sparse tensor with the attention output replacing the input features.
+ """
+ assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
+
+ serialization_spatial_cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}'
+ serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
+ if serialization_spatial_cache is None:
+ fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window)
+ qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
+ else:
+ fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache
+
+ M = fwd_indices.shape[0]
+ T = qkv.feats.shape[0]
+ H = qkv.feats.shape[2]
+ C = qkv.feats.shape[3]
+
+ qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
+
+ if DEBUG:
+ start = 0
+ qkv_coords = qkv.coords[fwd_indices]
+ for i in range(len(seq_lens)):
+ assert (qkv_coords[start:start+seq_lens[i], 0] == seq_batch_indices[i]).all(), "sparse_serialized_scaled_dot_product_self_attention: batch index mismatch"
+ start += seq_lens[i]
+
+ if all([seq_len == window_size for seq_len in seq_lens]):
+ B = len(seq_lens)
+ N = window_size
+ qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
+ if ATTN == 'xformers':
+ q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
+ out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
+ elif ATTN == 'flash_attn':
+ out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C]
+ else:
+ raise ValueError(f"Unknown attention module: {ATTN}")
+ out = out.reshape(B * N, H, C) # [M, H, C]
+ else:
+ if ATTN == 'xformers':
+ q, k, v = qkv_feats.unbind(dim=1) # [M, H, C]
+ q = q.unsqueeze(0) # [1, M, H, C]
+ k = k.unsqueeze(0) # [1, M, H, C]
+ v = v.unsqueeze(0) # [1, M, H, C]
+ mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
+ out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
+ elif ATTN == 'flash_attn':
+ cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
+ .to(qkv.device).int()
+ out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]
+
+ out = out[bwd_indices] # [T, H, C]
+
+ if DEBUG:
+ qkv_coords = qkv_coords[bwd_indices]
+ assert torch.equal(qkv_coords, qkv.coords), "sparse_serialized_scaled_dot_product_self_attention: coordinate mismatch"
+
+ return qkv.replace(out)
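+
+
+# Illustrative usage sketch (not exercised anywhere in this module); it assumes a
+# SparseTensor `qkv` of shape [N, 3, H, C] has already been assembled by the caller,
+# and the window size below is an arbitrary example value:
+#
+#   out = sparse_serialized_scaled_dot_product_self_attention(
+#       qkv,
+#       window_size=1024,
+#       serialize_mode=SerializeMode.HILBERT,
+#       shift_sequence=0,
+#       shift_window=(0, 0, 0),
+#   )
+#
+# `out` is a SparseTensor with the same coordinates as `qkv` and attended features.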
diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/__init__.py b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/__pycache__/__init__.cpython-310.pyc b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..53c0cd4d3e95a487a67f2c4a79227c5d2a852078
Binary files /dev/null and b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/__pycache__/__init__.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/__init__.py b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/__init__.py
@@ -0,0 +1 @@
+
diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/__pycache__/__init__.cpython-310.pyc b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a157b61e6024f1d194e5f5eea8deb290f86dbc8d
Binary files /dev/null and b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/__pycache__/__init__.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/__pycache__/compression_block.cpython-310.pyc b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/__pycache__/compression_block.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a742c89f7169dfd9747021ca3ceceda3ef353862
Binary files /dev/null and b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/__pycache__/compression_block.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/__pycache__/spatial_sparse_attention.cpython-310.pyc b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/__pycache__/spatial_sparse_attention.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..56b02b9d9dbfb5ed384762e4a59ffd509e1318d4
Binary files /dev/null and b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/__pycache__/spatial_sparse_attention.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/compression_block.py b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/compression_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..579a611862995abfdd3478dad22ba912fdb8c1bb
--- /dev/null
+++ b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/compression_block.py
@@ -0,0 +1,65 @@
+import torch.nn as nn
+import direct3d_s2.modules.sparse as sp
+
+
+class SparseDownBlock3d_v1(nn.Module):
+
+ def __init__(
+ self,
+ channels: int,
+ out_channels: int = None,
+ factor: int = 2,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+
+ self.act_layers = nn.Sequential(
+ sp.SparseConv3d(channels, self.out_channels, 1, padding=0),
+ sp.SparseSiLU()
+ )
+ self.down = sp.SparseDownsample(factor)
+
+ def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
+ h = self.act_layers(x)
+ h = self.down(h)
+ return h
+
+class SparseDownBlock3d_v2(nn.Module):
+
+ def __init__(
+ self,
+ channels: int,
+ out_channels: int = None,
+ num_groups: int = 32,
+ factor: int = 2,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+
+ self.act_layers = nn.Sequential(
+ sp.SparseGroupNorm32(num_groups, channels),
+ sp.SparseSiLU()
+ )
+
+ self.down = sp.SparseDownsample(factor)
+ self.out_layers = nn.Sequential(
+ sp.SparseConv3d(channels, self.out_channels, 3, padding=1),
+ sp.SparseGroupNorm32(num_groups, self.out_channels),
+ sp.SparseSiLU(),
+ sp.SparseConv3d(self.out_channels, self.out_channels, 3, padding=1),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ else:
+ self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1)
+
+ def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
+ h = self.act_layers(x)
+ h = self.down(h)
+ x = self.down(x)
+ h = self.out_layers(h)
+ h = h + self.skip_connection(x)
+ return h
\ No newline at end of file
diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/spatial_sparse_attention.py b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/spatial_sparse_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cb8977d4b7c4256d6e95cfb3d308449badfa8b9
--- /dev/null
+++ b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/spatial_sparse_attention.py
@@ -0,0 +1,155 @@
+import torch
+import torch.nn as nn
+from einops import rearrange
+from flash_attn import flash_attn_varlen_func
+from ..ops import (
+ spatial_selection_attention,
+ get_block_score,
+ sparse_window_attention,
+)
+from .compression_block import SparseDownBlock3d_v1, SparseDownBlock3d_v2
+import direct3d_s2.modules.sparse as sp
+
+
+class SpatialSparseAttention(torch.nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ num_q_heads: int,
+ num_kv_heads: int,
+ head_dim: int,
+ compression_block_size: int,
+ selection_block_size: int,
+ topk: int,
+ window_size: int,
+ shift_window: int,
+ resolution: int = 64,
+ compression_version: str = 'v2',
+ ):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.num_q_heads = num_q_heads
+ self.num_kv_heads = num_kv_heads
+ self.head_dim = head_dim
+ self.compression_block_size = compression_block_size
+ self.selection_block_size = selection_block_size
+ self.topk = topk
+ self.window_size = window_size
+ self.shift_window = shift_window
+ self.resolution = resolution
+
+ # qkv proj and o proj
+ self.proj_q = sp.SparseLinear(
+ hidden_size, num_q_heads * head_dim, bias=False
+ )
+ self.proj_k = sp.SparseLinear(
+ hidden_size, num_kv_heads * head_dim, bias=False
+ )
+ self.proj_v = sp.SparseLinear(
+ hidden_size, num_kv_heads * head_dim, bias=False
+ )
+ self.proj_o = torch.nn.Linear(
+ num_q_heads * head_dim, hidden_size, bias=False
+ )
+
+ # SSA (spatial sparse attention) parameters
+ if compression_version == 'v1':
+ compression_block = SparseDownBlock3d_v1
+ elif compression_version == 'v2':
+ compression_block = SparseDownBlock3d_v2
+ else:
+ raise NotImplementedError('compression_version must be "v1" or "v2"')
+ self.compression_key = compression_block(
+ num_kv_heads * head_dim, num_kv_heads * head_dim, factor=compression_block_size
+ )
+ self.compression_value = compression_block(
+ num_kv_heads * head_dim, num_kv_heads * head_dim, factor=compression_block_size
+ )
+ self.intra_block_pe = torch.nn.Parameter(
+ torch.zeros(compression_block_size,
+ compression_block_size,
+ compression_block_size,
+ num_kv_heads * head_dim,
+ )
+ )
+
+ # gate function
+ self.gate = torch.nn.Sequential(
+ sp.SparseLinear(hidden_size, 3, bias=False), sp.SparseSigmoid(),
+ )
+
+ def sparse3d_compression(self, x, key=True):
+ _, num_heads, num_dim = x.feats.shape
+ x = x.replace(x.feats.view(-1, num_heads * num_dim))
+ if key:
+ coords = x.coords
+ intra_block_coords = coords[..., 1:] % self.compression_block_size
+ intra_block_pos = self.intra_block_pe[intra_block_coords[:, 0], intra_block_coords[:, 1], intra_block_coords[:, 2]].to(x.dtype)
+ x = x.replace(x.feats + intra_block_pos)
+ y = self.compression_key(x)
+ else:
+ y = self.compression_value(x)
+ y = y.replace(y.feats.view(-1, num_heads, num_dim))
+ return y
+
+ def forward(self, x: sp.SparseTensor):
+ # dtype and shape check
+ assert x.shape[-1] == self.hidden_size
+ assert self.selection_block_size % self.compression_block_size == 0
+ # qkv proj
+ q = x.replace(self.proj_q(x).feats.view(-1, self.num_q_heads, self.head_dim))
+ k = x.replace(self.proj_k(x).feats.view(-1, self.num_kv_heads, self.head_dim))
+ v = x.replace(self.proj_v(x).feats.view(-1, self.num_kv_heads, self.head_dim))
+
+ # compression attention
+ compressed_k = self.sparse3d_compression(k, key=True)
+ compressed_v = self.sparse3d_compression(v, key=False)
+
+ compressed_cu_seqlens = torch.tensor([s.start for s in compressed_v.layout] + [s.stop for s in compressed_v.layout if s.stop not in [s.start for s in compressed_v.layout]]).to(compressed_v.device).to(torch.int32)
+ compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
+
+ cu_seqlens = torch.tensor([s.start for s in x.layout] + [s.stop for s in x.layout if s.stop not in [s.start for s in x.layout]]).to(x.device).to(torch.int32)
+ seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
+
+ compressed_attn_output, lse, _ = flash_attn_varlen_func(
+ q.feats,
+ compressed_k.feats,
+ compressed_v.feats,
+ cu_seqlens,
+ compressed_cu_seqlens,
+ seqlens.max().item(),
+ compressed_seqlens.max().item(),
+ causal=False,
+ return_attn_probs=True,
+ )
+
+ with torch.no_grad():
+ block_topk, cu_seqblocks, cu_block_include_tokens = get_block_score(
+ q, compressed_k, lse, self.resolution, self.compression_block_size,
+ self.selection_block_size, self.topk, cu_seqlens, compressed_cu_seqlens,
+ seqlens, compressed_seqlens, None)
+
+ # spatial selection attention
+ selection_attn_output = spatial_selection_attention(
+ q.feats, k.feats, v.feats, block_topk, cu_seqblocks,
+ cu_block_include_tokens, self.selection_block_size, cu_seqlens, None,
+ )
+
+ # window attention
+ window_attn_output = sparse_window_attention(
+ q, k, v, window_size=self.window_size, shift_window=self.shift_window,
+ ).feats
+
+ # gate average
+ gate = self.gate(x).feats
+ attn_output = (
+ gate[:, 0:1, None] * compressed_attn_output
+ + gate[:, 1:2, None] * selection_attn_output
+ + gate[:, 2:3, None] * window_attn_output
+ )
+
+ # rearrange and output proj
+ attn_output = rearrange(attn_output, "n h d -> n (h d)")
+ attn_output = self.proj_o(attn_output)
+
+ return x.replace(attn_output)
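+
+
+# Illustrative construction sketch; the hyper-parameter values below are placeholders,
+# not the configuration of any released model:
+#
+#   ssa = SpatialSparseAttention(
+#       hidden_size=1024, num_q_heads=16, num_kv_heads=16, head_dim=64,
+#       compression_block_size=4, selection_block_size=8, topk=8,
+#       window_size=8, shift_window=0, resolution=64, compression_version='v2',
+#   )
+#   y = ssa(x)  # x: sp.SparseTensor whose feature dim equals hidden_size
+#
+# Note that selection_block_size must be an integer multiple of compression_block_size
+# (asserted in forward).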
diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__init__.py b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef3bfe015052a8b72b5c9d1d7337621656bc7ce7
--- /dev/null
+++ b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__init__.py
@@ -0,0 +1,3 @@
+from .compressed_attention import get_block_score
+from .selection_attention import spatial_selection_attention
+from .window_attention import sparse_window_attention
diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/__init__.cpython-310.pyc b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bbef7c9a876454135c3ef88993cca9f0f93dac51
Binary files /dev/null and b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/__init__.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/compressed_attention.cpython-310.pyc b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/compressed_attention.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f93816c494979e04f17d8e3fb8423d82891122bd
Binary files /dev/null and b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/compressed_attention.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/selection_attention.cpython-310.pyc b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/selection_attention.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3bff8490d9552b66ad7d745146a9ec17ac594b44
Binary files /dev/null and b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/selection_attention.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/window_attention.cpython-310.pyc b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/window_attention.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..df310ed5db6d8d209b9b77b0d3464ed484881e44
Binary files /dev/null and b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/window_attention.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/compressed_attention.py b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/compressed_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..5edc6bfb55511f4f9376a27aac91d2eb6d04f7dd
--- /dev/null
+++ b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/compressed_attention.py
@@ -0,0 +1,275 @@
+# Copyright 2025 Xunhao Lai.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ---------------------------------------------------------------------
+# Copyright 2025 Shuang Wu
+# adapted from https://github.com/XunhaoLai/native-sparse-attention-triton/blob/main/native_sparse_attention/ops/triton/compressed_attention.py
+
+import math
+import torch
+from copy import deepcopy
+import triton
+import triton.language as tl
+import direct3d_s2.modules.sparse as sp
+
+
+@triton.jit
+def score_kernel(
+ q_ptr,
+ k_ptr,
+ lse_ptr,
+ s_ptr,
+ # seqlens
+ cu_seqlens_q,
+ cu_seqlens_k,
+ # shape
+ NUM_KV_HEADS,
+ NUM_SHARE_Q_HEADS,
+ HEAD_DIM,
+ # sm_scale
+ sm_scale,
+ # stride
+ stride_qn,
+ stride_qh,
+ stride_qd,
+ stride_kn,
+ stride_kh,
+ stride_kd,
+ stride_lh,
+ stride_ln,
+ stride_sh,
+ stride_sq,
+ stride_sk,
+ # META parameters
+ BLOCK_SIZE_Q: tl.constexpr, # q block size
+ BLOCK_SIZE_K: tl.constexpr, # k block size
+ BLOCK_SIZE_D: tl.constexpr,
+):
+ qk_scale = sm_scale * 1.44269504
+ # get batch id and head id
+ pid_bkh = tl.program_id(0)
+ pid_b = pid_bkh // NUM_KV_HEADS
+ pid_kh = pid_bkh % NUM_KV_HEADS
+ pid_q = tl.program_id(1)
+ pid_k = tl.program_id(2)
+ # get q k start and len after rmpad
+ q_start = tl.load(cu_seqlens_q + pid_b)
+ q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
+ k_start = tl.load(cu_seqlens_k + pid_b)
+ k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
+ if pid_q * BLOCK_SIZE_Q >= q_len or pid_k * BLOCK_SIZE_K >= k_len:
+ return
+ # init k pointer and load k
+ k_ptrs = tl.make_block_ptr(
+ base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
+ shape=(HEAD_DIM, k_len),
+ strides=(stride_kd, stride_kn),
+ offsets=(0, pid_k * BLOCK_SIZE_K),
+ block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
+ order=(0, 1),
+ )
+ k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
+ # init score
+ s = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
+ # loop over gqa heads
+ for h in range(NUM_SHARE_Q_HEADS):
+ pid_h = pid_kh * NUM_SHARE_Q_HEADS + h
+ q_ptrs = tl.make_block_ptr(
+ base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
+ shape=(q_len, HEAD_DIM),
+ strides=(stride_qn, stride_qd),
+ offsets=(pid_q * BLOCK_SIZE_Q, 0),
+ block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
+ order=(1, 0),
+ )
+ lse_ptrs = tl.make_block_ptr(
+ base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
+ shape=(q_len, 1),
+ strides=(stride_ln, stride_lh),
+ offsets=(pid_q * BLOCK_SIZE_Q, 0),
+ block_shape=(BLOCK_SIZE_Q, 1),
+ order=(0, 1),
+ )
+ # load q and lse
+ q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
+ lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
+ # compute qk
+ qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
+ qk += tl.dot(q, k) * qk_scale
+ # compute score
+ s += tl.exp2(qk - lse)
+ # save output
+ s_ptrs = tl.make_block_ptr(
+ base=s_ptr + pid_kh * stride_sh + q_start * stride_sq,
+ shape=(q_len, k_len),
+ strides=(stride_sq, stride_sk),
+ offsets=(pid_q * BLOCK_SIZE_Q, pid_k * BLOCK_SIZE_K),
+ block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_K),
+ order=(1, 0),
+ )
+ tl.store(s_ptrs, s.to(s_ptr.dtype.element_ty), boundary_check=(0, 1))
+
+
+def _get_attention_score(
+ q: torch.Tensor, # [total_query_len, num_q_heads, head_dim]
+ k: torch.Tensor, # [total_key_len, num_k_heads, head_dim]
+ lse: torch.Tensor, # [num_q_heads, total_query_len]
+ cu_seqlens_q: torch.Tensor,
+ cu_seqlens_k: torch.Tensor,
+ max_seqlen_q: int,
+ max_seqlen_k: int,
+ sm_scale: float,
+) -> torch.Tensor:
+ # dtype check
+ assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
+ assert q.dtype == k.dtype
+ assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
+ assert lse.dtype == torch.float32
+ # shape
+ q_len, num_q_heads, head_dim = q.shape
+ k_len, num_k_heads, head_dim = k.shape
+ batch_size = cu_seqlens_q.shape[0] - 1
+ assert q_len > k_len
+ if sm_scale is None:
+ sm_scale = 1 / math.sqrt(head_dim)
+ # gqa
+ assert num_q_heads % num_k_heads == 0
+ num_share_q_heads = num_q_heads // num_k_heads
+ # init score
+ score = torch.zeros(
+ num_k_heads, q_len, max_seqlen_k, dtype=torch.float32, device=q.device
+ )
+ # launch kernel
+ grid = lambda META: (
+ batch_size * num_k_heads,
+ triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
+ triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]),
+ )
+ num_warps = 4 if head_dim <= 64 else 8
+ num_stages = 3
+ BLOCK_SIZE_Q = 128
+ BLOCK_SIZE_K = 128
+ BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
+ score_kernel[grid](
+ q,
+ k,
+ lse,
+ score,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ num_k_heads,
+ num_share_q_heads,
+ head_dim,
+ sm_scale,
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ k.stride(0),
+ k.stride(1),
+ k.stride(2),
+ lse.stride(0),
+ lse.stride(1),
+ score.stride(0),
+ score.stride(1),
+ score.stride(2),
+ BLOCK_SIZE_Q=BLOCK_SIZE_Q,
+ BLOCK_SIZE_K=BLOCK_SIZE_K,
+ BLOCK_SIZE_D=BLOCK_SIZE_D,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+ return score
+
+
+def get_block_score(
+ q: sp.SparseTensor,
+ compressed_k: sp.SparseTensor,
+ lse: torch.Tensor,
+ resolution: int,
+ kernel_stride: int,
+ block_size: int,
+ topk: int,
+ cu_seqlens: torch.Tensor,
+ compressed_cu_seqlens: torch.Tensor,
+ seqlens: torch.Tensor,
+ compressed_seqlens: torch.Tensor,
+ sm_scale: float = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
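+ """
+ Score compressed-KV blocks for each query and select the top-k selection blocks.
+
+ Returns:
+ block_topk (torch.Tensor): [num_kv_heads, total_len, topk] int32 indices of the
+ selected blocks per query, padded with -1 where fewer blocks are available.
+ cu_seqblocks (torch.Tensor): [batch_size + 1] cumulative count of non-empty
+ selection blocks per batch item.
+ cu_block_include_tokens (torch.Tensor): cumulative count of tokens contained in
+ each non-empty selection block.
+ """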
+ attn_score = _get_attention_score(
+ q.feats,
+ compressed_k.feats,
+ lse.exp().log2(),
+ cu_seqlens,
+ compressed_cu_seqlens,
+ seqlens.max().item(),
+ compressed_seqlens.max().item(),
+ sm_scale,
+ )
+
+ batch_size = len(cu_seqlens) - 1
+ num_kv_head = attn_score.shape[0]
+ block_res = resolution // block_size
+ seqblocks, block_include_tokens = [], []
+ block_topk = torch.ones((num_kv_head, cu_seqlens[-1], topk), device=q.device, dtype=torch.int32) * -1
+
+ q_coords = deepcopy(q.coords)
+ for b in range(batch_size):
+ q_start, q_end, q_len = cu_seqlens[b], cu_seqlens[b + 1], seqlens[b]
+
+ compressed_k_start, compressed_k_end = compressed_cu_seqlens[b], compressed_cu_seqlens[b + 1]
+ attn_score_b = attn_score[:, q_start: q_end, :(compressed_k_end-compressed_k_start)]
+ compressed_block_coords_b = deepcopy(compressed_k.coords[compressed_k_start: compressed_k_end])
+ if block_size == kernel_stride:
+ score_block_b = attn_score_b
+ real_topk = min(topk, compressed_k_end - compressed_k_start)
+ block_topk_b = score_block_b.topk(real_topk, dim=-1).indices.sort(-1).values
+ block_topk[:, q_start: q_end, :real_topk] = block_topk_b
+ else:
+ compressed_block_coords_b[:, 1:] = compressed_block_coords_b[:, 1:] // (block_size//kernel_stride)
+ compressed_block_coords_flatten_b = compressed_block_coords_b[:, 1] * block_res**2 + compressed_block_coords_b[:, 2] * block_res + compressed_block_coords_b[:, 3]
+ score_block_b = torch.scatter_reduce(
+ torch.zeros((num_kv_head, q_len, block_res**3), device=attn_score_b.device, dtype=attn_score_b.dtype),
+ index=compressed_block_coords_flatten_b.long().unsqueeze(0).unsqueeze(0).expand_as(attn_score_b),
+ src=attn_score_b,
+ reduce="sum",
+ dim=2,
+ )
+ compressed_block_coords_flatten_unique_b = compressed_block_coords_flatten_b.unique()
+ score_block_b = score_block_b[..., compressed_block_coords_flatten_unique_b]
+ real_topk = min(topk, len(compressed_block_coords_flatten_unique_b))
+ block_topk_b = score_block_b.topk(real_topk, dim=-1).indices.sort(-1).values
+ block_topk[:, q_start: q_end, :real_topk] = block_topk_b
+
+ block_coords_b = q_coords[q_start: q_end]
+ block_coords_b[:, 1:] = block_coords_b[:, 1:] // block_size
+ block_coords_flatten_b = block_coords_b[:, 1] * block_res**2 + block_coords_b[:, 2] * block_res + block_coords_b[:, 3]
+ block_bins_b = torch.histc(block_coords_flatten_b, bins=block_res**3, min=0, max=block_res**3-1)
+ block_include_tokens.append(block_bins_b[block_bins_b > 0])
+ seqblocks.append(len(block_include_tokens[-1]))
+ seqblocks = torch.tensor(seqblocks, device=attn_score.device)
+ cu_seqblocks = torch.cat(
+ [
+ torch.zeros(1, dtype=torch.int32, device=attn_score.device),
+ torch.cumsum(seqblocks, dim=0),
+ ],
+ dim=0,
+ ).to(torch.int32)
+ block_include_tokens = torch.cat(block_include_tokens)
+ cu_block_include_tokens = torch.cat(
+ [
+ torch.zeros(1, dtype=torch.int32, device=attn_score.device),
+ torch.cumsum(block_include_tokens, dim=0),
+ ],
+ dim=0,
+ ).to(torch.int32)
+ return block_topk.to(torch.int32), cu_seqblocks, cu_block_include_tokens
diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/selection_attention.py b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/selection_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..62607d6f5a0a1c2b18180b5495a6413d9867402f
--- /dev/null
+++ b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/selection_attention.py
@@ -0,0 +1,1256 @@
+# Copyright 2025 Xunhao Lai & Jianqiao Lu.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ---------------------------------------------------------------------
+# Copyright 2025 Shuang Wu
+# adapted from https://github.com/XunhaoLai/native-sparse-attention-triton/blob/main/native_sparse_attention/ops/triton/topk_sparse_attention.py
+
+import math
+from typing import Any, Optional
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def forward_kernel(
+ q_ptr, # Q: n x h x d
+ k_ptr, # K: n x kh x d
+ v_ptr, # V: n x kh x d
+ t_ptr, # topk_idx: kh x n x k
+ o_ptr, # O: n x h x d
+ lse_ptr, # LSE: h x n
+ # seqlens
+ cu_seqlens_q,
+ cu_seqlens_k,
+ cu_seqblocks,
+ cu_block_include_tokens,
+ # shape
+ NUM_KV_HEADS,
+ NUM_SHARE_Q_HEADS,
+ HEAD_DIM,
+ TOPK,
+ # q loop num
+ num_q_loop,
+ # sm_scale
+ sm_scale,
+ # stride
+ stride_qn,
+ stride_qh,
+ stride_qd,
+ stride_kn,
+ stride_kh,
+ stride_kd,
+ stride_vn,
+ stride_vh,
+ stride_vd,
+ stride_th,
+ stride_tn,
+ stride_tk,
+ stride_on,
+ stride_oh,
+ stride_od,
+ stride_lh,
+ stride_ln,
+ # META parameters
+ BLOCK_SIZE_K: tl.constexpr, # k block size
+ BLOCK_SIZE_D: tl.constexpr,
+ BLOCK_SIZE_H: tl.constexpr,
+ BLOCK_SIZE_T: tl.constexpr,
+):
+ qk_scale = sm_scale * 1.44269504
+ # get batch id and head id
+ pid_b = tl.program_id(0)
+ pid_kh = tl.program_id(1)
+ pid_h = pid_kh * NUM_SHARE_Q_HEADS
+ pid_q = tl.program_id(2)
+ # get q k start and len after rmpad
+ q_start = tl.load(cu_seqlens_q + pid_b)
+ q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
+ k_start = tl.load(cu_seqlens_k + pid_b)
+ k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
+ block_start = tl.load(cu_seqblocks + pid_b)
+ block_len = tl.load(cu_seqblocks + pid_b + 1) - block_start
+ if pid_q * num_q_loop >= q_len:
+ return
+ num_q_loop_ = min(num_q_loop, q_len - pid_q * num_q_loop)
+ for j in range(num_q_loop_):
+ pid_q_j = pid_q * num_q_loop + j
+ # init topk idx pointer
+ off_t = tl.arange(0, BLOCK_SIZE_T)
+ t_ptr_j = t_ptr + (q_start + pid_q_j) * stride_tn + pid_kh * stride_th
+ topk_idx = tl.load(t_ptr_j + off_t * stride_tk, mask=off_t < TOPK, other=-1)
+ real_topk = tl.sum(
+ tl.where((topk_idx >= 0) & (topk_idx < block_len), 1, 0),
+ axis=0,
+ )
+ # init qkv pointer
+ q_ptrs = tl.make_block_ptr(
+ base=q_ptr + (q_start + pid_q_j) * stride_qn + pid_h * stride_qh,
+ shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
+ strides=(stride_qh, stride_qd),
+ offsets=(0, 0),
+ block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
+ order=(1, 0),
+ )
+ k_ptrs = tl.make_block_ptr(
+ base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
+ shape=(HEAD_DIM, k_len),
+ strides=(stride_kd, stride_kn),
+ offsets=(0, 0),
+ block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
+ order=(0, 1),
+ )
+ v_ptrs = tl.make_block_ptr(
+ base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
+ shape=(k_len, HEAD_DIM),
+ strides=(stride_vn, stride_vd),
+ offsets=(0, 0),
+ block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
+ order=(1, 0),
+ )
+ # load q
+ q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
+ # init statistics
+ off_h = tl.arange(0, BLOCK_SIZE_H)
+ off_k = tl.arange(0, BLOCK_SIZE_K)
+ m_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32)
+ lse_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32)
+ acc_o = tl.full((BLOCK_SIZE_H, BLOCK_SIZE_D), 0, dtype=tl.float32)
+ # sparse attention
+ for i in range(real_topk):
+ # get current block start index
+ cur_block_idx = tl.load(t_ptr_j).to(tl.int32)
+ cur_token_start = tl.load(cu_block_include_tokens + block_start + cur_block_idx).to(tl.int32)
+ cur_block_size = tl.load(cu_block_include_tokens + block_start + cur_block_idx + 1).to(tl.int32) - cur_token_start
+ c = cur_token_start - k_start
+ t_ptr_j = t_ptr_j + stride_tk
+ for b_j in range(0, cur_block_size, BLOCK_SIZE_K):
+ # load k
+ k = tl.load(
+ tl.advance(k_ptrs, (0, c + b_j)),
+ boundary_check=(1, 0), padding_option="zero",
+ )
+ # compute qk
+ qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32)
+ qk += tl.where(((c + b_j + off_k < k_len) & (b_j + off_k < cur_block_size))[None, :], 0, float("-inf"))
+ # [BLOCK_SIZE_H, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_K] -> [BLOCK_SIZE_H, BLOCK_SIZE_K]
+ qk += tl.dot(q, k) * qk_scale
+ # compute m_ij and l_ij
+ m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
+ p = tl.exp2(qk - m_ij[:, None])
+ l_ij = tl.sum(p, axis=1)
+ # scale acc_o
+ acc_o_scale = tl.exp2(m_i - m_ij)
+ acc_o = acc_o * acc_o_scale[:, None]
+ # load v and update acc_o
+ v = tl.load(
+ tl.advance(v_ptrs, (c + b_j, 0)),
+ boundary_check=(0, 1), padding_option="zero"
+ )
+ p = p.to(v.dtype)
+ acc_o += tl.dot(p, v)
+ # update statistics
+ m_i = m_ij
+ lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij)
+ # final scale
+ acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None]
+ # save output
+ o_ptrs = tl.make_block_ptr(
+ base=o_ptr + (q_start + pid_q_j) * stride_on + pid_h * stride_oh,
+ shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
+ strides=(stride_oh, stride_od),
+ offsets=(0, 0),
+ block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
+ order=(1, 0),
+ )
+ tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))
+ # save lse
+ lse_ptrs = (
+ lse_ptr + (q_start + pid_q_j) * stride_ln + (pid_h + off_h) * stride_lh
+ )
+ tl.store(lse_ptrs, lse_i, mask=off_h < NUM_SHARE_Q_HEADS)
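+
+# The forward kernel above gathers, for each query token, the tokens of its top-k
+# selected spatial blocks (block extents come from cu_block_include_tokens) and runs
+# an online-softmax attention over them, writing both the output O and the
+# log-sum-exp needed by the backward kernels below.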
+
+
+@triton.jit
+def backward_sum_o_do(
+ o_ptr, # O: n x h x d
+ do_ptr, # dO: n x h x d
+ delta_ptr, # D: h x n
+ o_len,
+ HEAD_DIM,
+ stride_on,
+ stride_oh,
+ stride_od,
+ stride_don,
+ stride_doh,
+ stride_dod,
+ stride_dh,
+ stride_dn,
+ BLOCK_SIZE_O: tl.constexpr,
+ BLOCK_SIZE_D: tl.constexpr,
+):
+ pid_n = tl.program_id(0)
+ pid_h = tl.program_id(1)
+ off_o = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O)
+ off_d = tl.arange(0, BLOCK_SIZE_D)
+ o = tl.load(
+ o_ptr
+ + off_o[:, None] * stride_on
+ + pid_h * stride_oh
+ + off_d[None, :] * stride_od,
+ mask=(off_o[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
+ other=0,
+ ).to(tl.float32)
+ do = tl.load(
+ do_ptr
+ + off_o[:, None] * stride_don
+ + pid_h * stride_doh
+ + off_d[None, :] * stride_dod,
+ mask=(off_o[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
+ other=0,
+ ).to(tl.float32)
+ delta = tl.sum(o * do, axis=1)
+ tl.store(
+ delta_ptr + pid_h * stride_dh + off_o * stride_dn, delta, mask=off_o < o_len
+ )
+
+
+@triton.jit
+def count_kernel(
+ x_ptr, # [num_kv_heads, total_len, topk]
+ y_ptr, # [num_kv_heads, total_blocks]
+ cu_seqlens, # [batch_size + 1]
+ cu_seqblocks, # [batch_size + 1]
+ topk,
+ stride_xh,
+ stride_xn,
+ stride_xk,
+ stride_yh,
+ stride_yn,
+ BLOCK_SIZE_N: tl.constexpr,
+ BLOCK_SIZE_K: tl.constexpr,
+ BLOCK_SIZE_R: tl.constexpr,
+):
+ pid_h = tl.program_id(0)
+ pid_b = tl.program_id(1)
+ # get start and len after rmpad
+ seq_start = tl.load(cu_seqlens + pid_b)
+ seq_len = tl.load(cu_seqlens + pid_b + 1) - seq_start
+ blocks_start = tl.load(cu_seqblocks + pid_b)
+ num_blocks = tl.load(cu_seqblocks + pid_b + 1) - blocks_start
+ # load x
+ off_k = tl.arange(0, BLOCK_SIZE_K)
+ off_n = tl.arange(0, BLOCK_SIZE_N)
+ x_ptr = x_ptr + pid_h * stride_xh + seq_start * stride_xn
+ x_ptrs = x_ptr + off_n[:, None] * stride_xn + off_k[None, :] * stride_xk
+ # init y
+ y = tl.zeros((BLOCK_SIZE_R,), dtype=tl.int32)
+ # loop
+ for i in range(0, seq_len, BLOCK_SIZE_N):
+ x = tl.load(
+ x_ptrs,
+ mask=(off_n < seq_len - i)[:, None] & (off_k < topk)[None, :],
+ other=-1,
+ )
+ x = tl.ravel(x)
+ y += tl.histogram(x, BLOCK_SIZE_R)
+ x_ptrs += BLOCK_SIZE_N * stride_xn
+ # store result
+ off_r = tl.arange(0, BLOCK_SIZE_R)
+ y_ptr = y_ptr + pid_h * stride_yh + blocks_start * stride_yn
+ y_ptrs = y_ptr + off_r * stride_yn
+ tl.store(y_ptrs, y.to(y_ptr.dtype.element_ty), mask=off_r < num_blocks)
+
+
+def count_query(
+ topk_idx: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ cu_seqblocks: torch.Tensor,
+ block_size: int,
+):
+ num_kv_heads, total_len, topk = topk_idx.shape
+ seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
+ seqblocks = cu_seqblocks[1:] - cu_seqblocks[:-1]
+ batch_size = seqlens.shape[0]
+ BLOCK_SIZE_K = triton.next_power_of_2(topk)
+ BLOCK_SIZE_N = triton.next_power_of_2(4096 // BLOCK_SIZE_K)
+ BLOCK_SIZE_R = triton.next_power_of_2(seqblocks.max().item() + 2)
+ active_query_count = torch.zeros(
+ num_kv_heads, cu_seqblocks[-1], dtype=torch.int32, device=topk_idx.device
+ )
+ grid = (num_kv_heads, batch_size)
+ count_kernel[grid](
+ topk_idx,
+ active_query_count,
+ cu_seqlens,
+ cu_seqblocks,
+ topk,
+ topk_idx.stride(0),
+ topk_idx.stride(1),
+ topk_idx.stride(2),
+ active_query_count.stride(0),
+ active_query_count.stride(1),
+ BLOCK_SIZE_N=BLOCK_SIZE_N,
+ BLOCK_SIZE_K=BLOCK_SIZE_K,
+ BLOCK_SIZE_R=BLOCK_SIZE_R,
+ num_warps=4,
+ num_stages=3,
+ )
+ return active_query_count
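+
+# count_query histograms the per-query top-k block indices for every KV head, i.e.
+# for each selection block it counts how many queries attend to it; the cumulative
+# form of this count drives the dK/dV backward pass.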
+
+
+@triton.jit
+def pad_topk_idx_kernel(
+ t_ptr,
+ p_ptr,
+ cu_seqlens,
+ topk,
+ stride_th,
+ stride_tn,
+ stride_tk,
+ stride_pb,
+ stride_ph,
+ stride_pn,
+ stride_pk,
+ BLOCK_SIZE_N: tl.constexpr,
+ BLOCK_SIZE_T: tl.constexpr,
+):
+ pid_b = tl.program_id(0)
+ pid_h = tl.program_id(1)
+ pid_n = tl.program_id(2)
+ # get q start and len after rmpad
+ q_start = tl.load(cu_seqlens + pid_b)
+ q_len = tl.load(cu_seqlens + pid_b + 1) - q_start
+ if BLOCK_SIZE_N * pid_n >= q_len:
+ return
+ # init prts
+ t_ptrs = tl.make_block_ptr(
+ base=t_ptr + pid_h * stride_th + q_start * stride_tn,
+ shape=(q_len, topk),
+ strides=(stride_tn, stride_tk),
+ offsets=(pid_n * BLOCK_SIZE_N, 0),
+ block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_T),
+ order=(1, 0),
+ )
+ p_ptrs = tl.make_block_ptr(
+ base=p_ptr + pid_b * stride_pb + pid_h * stride_ph,
+ shape=(q_len, topk),
+ strides=(stride_pn, stride_pk),
+ offsets=(pid_n * BLOCK_SIZE_N, 0),
+ block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_T),
+ order=(1, 0),
+ )
+ # load and save
+ idxs = tl.load(t_ptrs, boundary_check=(0, 1))
+ tl.store(p_ptrs, idxs, boundary_check=(0, 1))
+
+
+@triton.jit
+def save_topk_idx_kernel(
+ p_ptr,
+ t_ptr,
+ cu_seqblocks,
+ cu_topk_q_count,
+ n_len,
+ stride_pb,
+ stride_ph,
+ stride_pn,
+ stride_th,
+ stride_tn,
+ stride_ch,
+ stride_cn,
+ BLOCK_SIZE_N: tl.constexpr,
+):
+ pid_b = tl.program_id(0)
+ pid_h = tl.program_id(1)
+ pid_n = tl.program_id(2)
+ # get q start and len after rmpad
+ q_block_start = tl.load(cu_seqblocks + pid_b)
+ q_block_end = tl.load(cu_seqblocks + pid_b + 1)
+ c_start = tl.load(cu_topk_q_count + pid_h * stride_ch + q_block_start * stride_cn)
+ c_end = tl.load(cu_topk_q_count + pid_h * stride_ch + q_block_end * stride_cn)
+ c_len = c_end - c_start
+ if c_len <= 0:
+ return
+ if pid_n * BLOCK_SIZE_N >= c_len:
+ return
+ # init ptrs
+ p_ptrs = tl.make_block_ptr(
+ base=p_ptr
+ + pid_b * stride_pb
+ + pid_h * stride_ph
+ + (n_len - c_len) * stride_pn,
+ shape=(c_len,),
+ strides=(stride_pn,),
+ offsets=(pid_n * BLOCK_SIZE_N,),
+ block_shape=(BLOCK_SIZE_N,),
+ order=(0,),
+ )
+ t_ptrs = tl.make_block_ptr(
+ base=t_ptr + pid_h * stride_th + c_start * stride_tn,
+ shape=(c_len,),
+ strides=(stride_tn,),
+ offsets=(pid_n * BLOCK_SIZE_N,),
+ block_shape=(BLOCK_SIZE_N,),
+ order=(0,),
+ )
+ # load and save
+ idxs = tl.load(p_ptrs, boundary_check=(0,))
+ tl.store(t_ptrs, idxs, boundary_check=(0,))
+
+
+def reorder_topk_idx(
+ topk_idx: torch.Tensor,
+ cu_topk_q_count: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ cu_seqblocks: torch.Tensor,
+ block_size: int,
+):
+ num_kv_heads, total_len, topk = topk_idx.shape
+ batch_size = cu_seqlens.shape[0] - 1
+ seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
+ max_seqlen = seq_lens.max().item()
+ # pad shape [num_kv_heads, total_seqlen, topk] to [batch_size, num_kv_heads, max_seqlen, topk]
+ pad_topk_idx = torch.full(
+ (batch_size, num_kv_heads, max_seqlen, topk),
+ fill_value=-1,
+ device=topk_idx.device,
+ dtype=torch.int32,
+ )
+ BLOCK_SIZE_T = triton.next_power_of_2(topk)
+ BLOCK_SIZE_N = min(
+ triton.next_power_of_2(max_seqlen), triton.next_power_of_2(8192 // BLOCK_SIZE_T)
+ )
+ grid = (batch_size, num_kv_heads, triton.cdiv(max_seqlen, BLOCK_SIZE_N))
+ pad_topk_idx_kernel[grid](
+ topk_idx,
+ pad_topk_idx,
+ cu_seqlens,
+ topk,
+ topk_idx.stride(0),
+ topk_idx.stride(1),
+ topk_idx.stride(2),
+ pad_topk_idx.stride(0),
+ pad_topk_idx.stride(1),
+ pad_topk_idx.stride(2),
+ pad_topk_idx.stride(3),
+ BLOCK_SIZE_N=BLOCK_SIZE_N,
+ BLOCK_SIZE_T=BLOCK_SIZE_T,
+ )
+ # argsort
+ pad_topk_q_idx = pad_topk_idx.view(batch_size, num_kv_heads, -1).argsort(-1) // topk
+ pad_topk_q_idx = pad_topk_q_idx.to(torch.int32)
+ # save as remove pad version
+ topk_q_idx = torch.full(
+ (num_kv_heads, cu_topk_q_count[:, -1].max().item()),
+ fill_value=-1,
+ device=topk_idx.device,
+ dtype=torch.int32,
+ )
+ max_len = (
+ (
+ cu_topk_q_count[:, cu_seqblocks][:, 1:]
+ - cu_topk_q_count[:, cu_seqblocks][:, :-1]
+ )
+ .max()
+ .item()
+ )
+
+ BLOCK_SIZE_N = min(triton.next_power_of_2(max_len), 8192)
+ grid = (batch_size, num_kv_heads, triton.cdiv(max_len, BLOCK_SIZE_N))
+ save_topk_idx_kernel[grid](
+ pad_topk_q_idx,
+ topk_q_idx,
+ cu_seqblocks,
+ cu_topk_q_count,
+ pad_topk_q_idx.shape[-1],
+ pad_topk_q_idx.stride(0),
+ pad_topk_q_idx.stride(1),
+ pad_topk_q_idx.stride(2),
+ topk_q_idx.stride(0),
+ topk_q_idx.stride(1),
+ cu_topk_q_count.stride(0),
+ cu_topk_q_count.stride(1),
+ BLOCK_SIZE_N=BLOCK_SIZE_N,
+ )
+ return topk_q_idx
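+
+# reorder_topk_idx inverts the per-query top-k lists: topk_q_idx[h] stores, block by
+# block, the indices of the queries that selected that block, with the slice for
+# sequence b, head h, block i given by cu_topk_q_count (see the backward pass below).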
+
+
+@triton.jit
+def backward_dkdv(
+ q_ptr, # Q: n x qh x d
+ k_ptr, # K: n x kh x d
+ v_ptr, # V: n x kh x d
+ tq_ptr, # topk_q_idx: kh x N
+ lse_ptr, # LSE: qh x n
+ d_ptr, # Delta: qh x n
+ do_ptr,
+ dk_ptr, # DK: sh x n x kh x d
+ dv_ptr, # DK: sh x n x kh x d
+ # seqlens
+ cu_seqlens_q, # [batch_size + 1]
+ cu_seqlens_k, # [batch_size + 1]
+ cu_seqblocks, # [batch_size + 1]
+ cu_topk_q_count, # [kh, total_blocks]
+ cu_block_include_tokens, # [total_blocks + 1]
+ # shape
+ NUM_KV_HEADS,
+ NUM_SHARE_Q_HEADS,
+ HEAD_DIM,
+ TOPK,
+ max_seqblocks,
+ # sm_scale
+ sm_scale,
+ # stride
+ stride_qn,
+ stride_qh,
+ stride_qd,
+ stride_kn,
+ stride_kh,
+ stride_kd,
+ stride_vn,
+ stride_vh,
+ stride_vd,
+ stride_tqh,
+ stride_tqn,
+ stride_ctqh,
+ stride_ctqn,
+ stride_lh,
+ stride_ln,
+ stride_dh,
+ stride_dn,
+ stride_don,
+ stride_doh,
+ stride_dod,
+ stride_dks,
+ stride_dkn,
+ stride_dkh,
+ stride_dkd,
+ stride_dvs,
+ stride_dvn,
+ stride_dvh,
+ stride_dvd,
+ # META parameters
+ BLOCK_SIZE_Q: tl.constexpr, # q block size
+ BLOCK_SIZE_K: tl.constexpr, # k block size
+ BLOCK_SIZE_D: tl.constexpr,
+):
+ qk_scale = sm_scale * 1.44269504
+ # get batch id and head id
+ pid_b = tl.program_id(0)
+ pid_h = tl.program_id(1)
+ pid_kh = pid_h // NUM_SHARE_Q_HEADS
+ pid_sh = pid_h % NUM_SHARE_Q_HEADS
+ pid_kb = tl.program_id(2)
+ pid_k = pid_kb % max_seqblocks
+ pid_block = pid_kb // max_seqblocks
+
+ # get q k start and len after rmpad
+ q_start = tl.load(cu_seqlens_q + pid_b)
+ q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
+ k_start = tl.load(cu_seqlens_k + pid_b)
+ k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
+ # get topk_q_idx
+ b_start = tl.load(cu_seqblocks + pid_b) # how many blocks before current sequence
+ b_len = tl.load(cu_seqblocks + pid_b + 1) - b_start
+
+ if pid_k >= b_len:
+ return
+
+ cur_token_start = tl.load(cu_block_include_tokens + b_start + pid_k).to(tl.int32)
+ cur_block_size = tl.load(cu_block_include_tokens + b_start + pid_k + 1).to(tl.int32) - cur_token_start
+ cur_token_start_in_seq = cur_token_start - k_start
+
+ if pid_block * BLOCK_SIZE_K >= cur_block_size:
+ return
+
+ act_q_start = tl.load(
+ cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k) * stride_ctqn
+ )
+ act_q_end = tl.load(
+ cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k + 1) * stride_ctqn
+ )
+ act_q_len = act_q_end - act_q_start
+ tq_ptr = tq_ptr + pid_kh * stride_tqh + act_q_start * stride_tqn
+ # offsets
+ off_q = tl.arange(0, BLOCK_SIZE_Q)
+ off_k = tl.arange(0, BLOCK_SIZE_K)
+ off_d = tl.arange(0, BLOCK_SIZE_D)
+ # init pointers
+ k_ptrs = tl.make_block_ptr(
+ base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
+ shape=(k_len, HEAD_DIM),
+ strides=(stride_kn, stride_kd),
+ offsets=(cur_token_start_in_seq + pid_block * BLOCK_SIZE_K, 0),
+ block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
+ order=(1, 0),
+ )
+ dk_ptrs = dk_ptr + (cur_token_start + pid_block * BLOCK_SIZE_K + off_k[:, None]) * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks + off_d[None, :] * stride_dkd
+ v_ptrs = tl.make_block_ptr(
+ base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
+ shape=(k_len, HEAD_DIM),
+ strides=(stride_vn, stride_vd),
+ offsets=(cur_token_start_in_seq + pid_block * BLOCK_SIZE_K, 0),
+ block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
+ order=(1, 0),
+ )
+ dv_ptrs = dv_ptr + (cur_token_start + pid_block * BLOCK_SIZE_K + off_k[:, None]) * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs + off_d[None, :] * stride_dvd
+ # load k v and keep in SRAM
+ k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
+ v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
+ # init dk dv
+ dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
+ dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
+ # init ptrs
+ q_ptrs = (
+ q_ptr + q_start * stride_qn + pid_h * stride_qh + off_d[None, :] * stride_qd
+ )
+ do_ptrs = (
+ do_ptr + q_start * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod
+ )
+ d_ptrs = d_ptr + q_start * stride_dn + pid_h * stride_dh
+ lse_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh
+ # loop for q blocks
+ for i in range(0, act_q_len, BLOCK_SIZE_Q):
+ # load
+ idx_q = tl.load(tq_ptr + i + off_q, mask=off_q < act_q_len - i, other=0).to(
+ tl.int32
+ )
+ q = tl.load(
+ q_ptrs + idx_q[:, None] * stride_qn,
+ mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :],
+ other=0,
+ )
+ do = tl.load(
+ do_ptrs + idx_q[:, None] * stride_don,
+ mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :],
+ other=0,
+ )
+ lse = tl.load(
+ lse_ptrs + idx_q[:, None] * stride_ln,
+ mask=(off_q < act_q_len - i)[:, None],
+ other=0,
+ )
+ d = tl.load(
+ d_ptrs + idx_q[:, None] * stride_dn,
+ mask=(off_q < act_q_len - i)[:, None],
+ other=0,
+ )
+ # compute qk
+ qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
+ qk += tl.where(((pid_block * BLOCK_SIZE_K + off_k < cur_block_size)[None, :] & (off_q < act_q_len - i)[:, None]), float(0.0), float("-inf"))
+ qk += tl.dot(q, k.T) * qk_scale
+ # compute p, ds
+ p = tl.exp2(qk - lse)
+ dp = tl.dot(do, v.T)
+ ds = sm_scale * p * (dp - d)
+ # cast dtype
+ p = p.to(do.dtype)
+ ds = ds.to(q.dtype)
+ # update dk and dv
+ dk += tl.dot(ds.T, q)
+ dv += tl.dot(p.T, do)
+ # save dk dv
+ tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), mask=(pid_block * BLOCK_SIZE_K + off_k < cur_block_size)[:, None])
+ tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), mask=(pid_block * BLOCK_SIZE_K + off_k < cur_block_size)[:, None])
+
+
+@triton.jit
+def backward_dq(
+ q_ptr, # Q: n x qh x d
+ k_ptr, # K: n x kh x d
+ v_ptr, # V: n x kh x d
+ t_ptr, # topk_idx: kh x n x k
+ lse_ptr, # LSE: qh x n
+ d_ptr, # Delta: qh x n
+ do_ptr,
+ dq_ptr,
+ # seqlens
+ cu_seqlens_q,
+ cu_seqlens_k,
+ cu_seqblocks,
+ cu_block_include_tokens,
+ # shape
+ NUM_KV_HEADS,
+ NUM_SHARE_Q_HEADS,
+ HEAD_DIM,
+ TOPK,
+ # q loop num
+ num_q_loop,
+ # sm_scale
+ sm_scale,
+ # stride
+ stride_qn,
+ stride_qh,
+ stride_qd,
+ stride_kn,
+ stride_kh,
+ stride_kd,
+ stride_vn,
+ stride_vh,
+ stride_vd,
+ stride_th,
+ stride_tn,
+ stride_tk,
+ stride_lh,
+ stride_ln,
+ stride_dh,
+ stride_dn,
+ stride_don,
+ stride_doh,
+ stride_dod,
+ stride_dqn,
+ stride_dqh,
+ stride_dqd,
+ # META parameters
+ BLOCK_SIZE_K: tl.constexpr, # k block size
+ BLOCK_SIZE_D: tl.constexpr,
+ BLOCK_SIZE_H: tl.constexpr,
+ BLOCK_SIZE_T: tl.constexpr,
+):
+ qk_scale = sm_scale * 1.44269504
+ # get batch id and head id
+ pid_b = tl.program_id(0)
+ pid_kh = tl.program_id(1)
+ pid_q = tl.program_id(2)
+ pid_h = pid_kh * NUM_SHARE_Q_HEADS
+ # get q k start and len after rmpad
+ q_start = tl.load(cu_seqlens_q + pid_b)
+ q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
+ k_start = tl.load(cu_seqlens_k + pid_b)
+ k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
+ block_start = tl.load(cu_seqblocks + pid_b)
+ block_len = tl.load(cu_seqblocks + pid_b + 1) - block_start
+ if pid_q * num_q_loop >= q_len:
+ return
+ num_q_loop_ = min(num_q_loop, q_len - pid_q * num_q_loop)
+ for j in range(num_q_loop_):
+ pid_q_j = pid_q * num_q_loop + j
+ # init topk idx pointer
+ off_t = tl.arange(0, BLOCK_SIZE_T)
+ t_ptr_j = t_ptr + (q_start + pid_q_j) * stride_tn + pid_kh * stride_th
+ topk_idx = tl.load(t_ptr_j + off_t * stride_tk, mask=off_t < TOPK, other=-1)
+ real_topk = tl.sum(
+ tl.where((topk_idx >= 0) & (topk_idx < block_len), 1, 0),
+ axis=0,
+ )
+ # init pointers
+ q_ptrs = tl.make_block_ptr(
+ base=q_ptr + (q_start + pid_q_j) * stride_qn + pid_h * stride_qh,
+ shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
+ strides=(stride_qh, stride_qd),
+ offsets=(0, 0),
+ block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
+ order=(1, 0),
+ )
+ dq_ptrs = tl.make_block_ptr(
+ base=dq_ptr + (q_start + pid_q_j) * stride_dqn + pid_h * stride_dqh,
+ shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
+ strides=(stride_dqh, stride_dqd),
+ offsets=(0, 0),
+ block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
+ order=(1, 0),
+ )
+ k_ptrs = tl.make_block_ptr(
+ base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
+ shape=(k_len, HEAD_DIM),
+ strides=(stride_kn, stride_kd),
+ offsets=(0, 0),
+ block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
+ order=(1, 0),
+ )
+ v_ptrs = tl.make_block_ptr(
+ base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
+ shape=(k_len, HEAD_DIM),
+ strides=(stride_vn, stride_vd),
+ offsets=(0, 0),
+ block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
+ order=(1, 0),
+ )
+ do_ptrs = tl.make_block_ptr(
+ base=do_ptr + (q_start + pid_q_j) * stride_don + pid_h * stride_doh,
+ shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
+ strides=(stride_doh, stride_dod),
+ offsets=(0, 0),
+ block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
+ order=(1, 0),
+ )
+ d_ptrs = tl.make_block_ptr(
+ base=d_ptr + (q_start + pid_q_j) * stride_dn + pid_h * stride_dh,
+ shape=(NUM_SHARE_Q_HEADS, 1),
+ strides=(stride_dh, stride_dn),
+ offsets=(0, 0),
+ block_shape=(BLOCK_SIZE_H, 1),
+ order=(1, 0),
+ )
+ lse_ptrs = tl.make_block_ptr(
+ base=lse_ptr + (q_start + pid_q_j) * stride_ln + pid_h * stride_lh,
+ shape=(NUM_SHARE_Q_HEADS, 1),
+ strides=(stride_lh, stride_ln),
+ offsets=(0, 0),
+ block_shape=(BLOCK_SIZE_H, 1),
+ order=(1, 0),
+ )
+ # offsets
+ off_k = tl.arange(0, BLOCK_SIZE_K)
+ # load q, do, lse, delta, and keep in SRAM
+ q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero")
+ do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
+ lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
+ d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
+ # init dq
+ dq = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_D), dtype=tl.float32)
+ # sparse
+ for i in range(real_topk):
+ # get current block start index
+ cur_block_idx = tl.load(t_ptr_j).to(tl.int32)
+ cur_token_start = tl.load(cu_block_include_tokens + block_start + cur_block_idx).to(tl.int32)
+ cur_block_size = tl.load(cu_block_include_tokens + block_start + cur_block_idx + 1).to(tl.int32) - cur_token_start
+ c = cur_token_start - k_start
+ t_ptr_j = t_ptr_j + stride_tk
+
+ for b_j in range(0, cur_block_size, BLOCK_SIZE_K):
+ # load kv
+ k = tl.load(
+ tl.advance(k_ptrs, (c + b_j, 0)), boundary_check=(1, 0), padding_option="zero"
+ )
+ v = tl.load(
+ tl.advance(v_ptrs, (c + b_j, 0)), boundary_check=(0, 1), padding_option="zero"
+ )
+ # compute qk
+ qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32)
+ qk += tl.where((off_k + b_j < cur_block_size)[None, :], 0, float("-inf"))
+ # [BLOCK_SIZE_H, HEAD_DIM] @ [BLOCK_SIZE_K, HEAD_DIM].T -> [BLOCK_SIZE_H, BLOCK_SIZE_K]
+ qk += tl.dot(q, tl.trans(k)) * qk_scale
+ # compute p, ds
+ p = tl.exp2(qk - lse)
+ dp = tl.dot(do, tl.trans(v))
+ ds = sm_scale * p * (dp - d)
+ # cast dtype
+ ds = ds.to(q.dtype)
+ # update dq
+ dq += tl.dot(ds, k)
+
+ # save dq
+ tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1))
+
+
+def _spatial_selection_attention_fwd(
+ q: torch.Tensor, # [total_len, num_q_heads, head_dim]
+ k: torch.Tensor, # [total_len, num_k_heads, head_dim]
+ v: torch.Tensor, # [total_len, num_k_heads, head_dim]
+ topk_idx: torch.Tensor, # [num_kv_heads, total_len, topk]
+ cu_seqblocks: torch.Tensor,
+ cu_block_include_tokens: torch.Tensor,
+ block_size: int,
+ cu_seqlens_q: torch.Tensor,
+ cu_seqlens_k: torch.Tensor,
+ max_seqlen_q: int,
+ max_seqlen_k: int,
+ sm_scale: float,
+):
+ # dtype check
+ assert k.dtype == q.dtype and v.dtype == q.dtype
+ assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
+ # shape
+ q_len, num_q_heads, head_dim = q.shape
+ k_len, num_k_heads, head_dim = k.shape
+ v_len, num_v_heads, head_dim = v.shape
+ batch_size = cu_seqlens_q.shape[0] - 1
+ assert q_len == k_len and k_len == v_len
+ topk = topk_idx.shape[-1]
+ assert topk_idx.shape[0] == num_k_heads
+ assert topk_idx.shape[1] == q_len
+ # gqa
+ assert num_k_heads == num_v_heads
+ assert num_q_heads % num_k_heads == 0
+ num_share_q_heads = num_q_heads // num_k_heads
+ # output tensor
+ o = torch.zeros_like(q)
+ lse = torch.zeros(num_q_heads, q_len, dtype=torch.float32, device=q.device)
+ # launch kernel
+ num_q_loop = (
+ cu_seqlens_q[-1].item() // 32768 + 1
+ ) # calculate multiple querys in one kernel if seqlence length is too long
+ grid = (batch_size, num_k_heads, triton.cdiv(max_seqlen_q, num_q_loop))
+ num_warps = 4 if head_dim <= 64 else 8
+ num_stages = 3
+ BLOCK_SIZE_K = 64
+ BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
+ BLOCK_SIZE_H = max(16, triton.next_power_of_2(num_share_q_heads))
+ BLOCK_SIZE_T = triton.next_power_of_2(topk)
+ forward_kernel[grid](
+ q,
+ k,
+ v,
+ topk_idx,
+ o,
+ lse,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ cu_seqblocks,
+ cu_block_include_tokens,
+ num_k_heads,
+ num_share_q_heads,
+ head_dim,
+ topk,
+ num_q_loop,
+ sm_scale,
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ k.stride(0),
+ k.stride(1),
+ k.stride(2),
+ v.stride(0),
+ v.stride(1),
+ v.stride(2),
+ topk_idx.stride(0),
+ topk_idx.stride(1),
+ topk_idx.stride(2),
+ o.stride(0),
+ o.stride(1),
+ o.stride(2),
+ lse.stride(0),
+ lse.stride(1),
+ BLOCK_SIZE_K=BLOCK_SIZE_K,
+ BLOCK_SIZE_D=BLOCK_SIZE_D,
+ BLOCK_SIZE_H=BLOCK_SIZE_H,
+ BLOCK_SIZE_T=BLOCK_SIZE_T,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+ return o, lse
+
+
+def _spatial_selection_attention_bwd(
+ o: torch.Tensor,
+ do: torch.Tensor,
+ lse: torch.Tensor,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ topk_idx: torch.Tensor,
+ cu_seqblocks: torch.Tensor,
+ cu_block_include_tokens: torch.Tensor,
+ block_size: int,
+ cu_seqlens_q: torch.Tensor,
+ cu_seqlens_k: torch.Tensor,
+ max_seqlen_q: int,
+ max_seqlen_k: int,
+ sm_scale: float,
+):
+ q_len, num_q_heads, head_dim = q.shape
+ k_len, num_k_heads, head_dim = k.shape
+ v_len, num_v_heads, head_dim = v.shape
+ o_len, num_o_heads, head_dim = o.shape
+ num_share_q_heads = num_q_heads // num_k_heads
+ topk = topk_idx.shape[-1]
+ # compute D
+ delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32)
+ BLOCK_SIZE_O = 256
+ BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
+ num_warps = 4 if head_dim <= 64 else 8
+ num_stages = 3
+ grid = (triton.cdiv(o_len, BLOCK_SIZE_O), num_o_heads)
+ backward_sum_o_do[grid](
+ o,
+ do,
+ delta,
+ o_len,
+ head_dim,
+ o.stride(0),
+ o.stride(1),
+ o.stride(2),
+ do.stride(0),
+ do.stride(1),
+ do.stride(2),
+ delta.stride(0),
+ delta.stride(1),
+ BLOCK_SIZE_O=BLOCK_SIZE_O,
+ BLOCK_SIZE_D=BLOCK_SIZE_D,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+ # count active querys for each key block, shape: (num_k_heads, total_k_blocks)
+ topk_q_count = count_query(topk_idx, cu_seqlens_q, cu_seqblocks, block_size) # [num_kv_head, total_block]
+ cu_topk_q_count = torch.cat(
+ [
+ torch.zeros(
+ topk_q_count.shape[0], 1, dtype=torch.int32, device=topk_idx.device
+ ),
+ torch.cumsum(topk_q_count, dim=-1),
+ ],
+ dim=-1,
+ ).to(torch.int32) # [num_kv_head, cu_total_block + 1]
+ # active query idx for each key block
+ # how to get active query idx for sequence b, head h, kv block i?
+ # topk_q_idx[h, cu_topk_q_count[h, cu_seqblocks[b] + i]:cu_topk_q_count[h, cu_seqblocks[b] + i + 1]]
+ topk_q_idx = reorder_topk_idx(
+ topk_idx, cu_topk_q_count, cu_seqlens_q, cu_seqblocks, block_size
+ )
+ # compute dk dv
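+ # (a separate accumulation slice is kept per shared query head, indexed by pid_sh in
+ # backward_dkdv, so concurrent programs never write to the same rows; the slices are
+ # summed once the kernel has finished)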
+ dk = torch.zeros(
+ num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype
+ )
+ dv = torch.zeros(
+ num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype
+ )
+ batch_size = cu_seqlens_q.shape[0] - 1
+ num_warps = 4 if head_dim <= 64 else 8
+ num_stages = 3
+ max_include_block = (cu_block_include_tokens[..., 1:] - cu_block_include_tokens[..., :-1]).max().item()
+ BLOCK_SIZE_K = 64
+ BLOCK_SIZE_Q = 128 if BLOCK_SIZE_K <= 64 else 64
+ BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
+ max_seqblocks = (cu_seqblocks[1:] - cu_seqblocks[:-1]).max().item()
+ grid = (batch_size, num_q_heads, max_seqblocks * triton.cdiv(max_include_block, BLOCK_SIZE_K))
+
+ backward_dkdv[grid](
+ q,
+ k,
+ v,
+ topk_q_idx,
+ lse,
+ delta,
+ do,
+ dk,
+ dv,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ cu_seqblocks,
+ cu_topk_q_count,
+ cu_block_include_tokens,
+ num_k_heads,
+ num_share_q_heads,
+ head_dim,
+ topk,
+ max_seqblocks,
+ sm_scale,
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ k.stride(0),
+ k.stride(1),
+ k.stride(2),
+ v.stride(0),
+ v.stride(1),
+ v.stride(2),
+ topk_q_idx.stride(0),
+ topk_q_idx.stride(1),
+ cu_topk_q_count.stride(0),
+ cu_topk_q_count.stride(1),
+ lse.stride(0),
+ lse.stride(1),
+ delta.stride(0),
+ delta.stride(1),
+ do.stride(0),
+ do.stride(1),
+ do.stride(2),
+ dk.stride(0),
+ dk.stride(1),
+ dk.stride(2),
+ dk.stride(3),
+ dv.stride(0),
+ dv.stride(1),
+ dv.stride(2),
+ dv.stride(3),
+ BLOCK_SIZE_Q=BLOCK_SIZE_Q,
+ BLOCK_SIZE_K=BLOCK_SIZE_K,
+ BLOCK_SIZE_D=BLOCK_SIZE_D,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+ dk = dk.sum(0)
+ dv = dv.sum(0)
+ # compute dq
+ dq = torch.zeros_like(q)
+ num_q_loop = (
+ cu_seqlens_q[-1].item() // 32768 + 1
+ ) # process multiple queries in one kernel if the sequence length is too long
+ grid = (batch_size, num_k_heads, triton.cdiv(max_seqlen_q, num_q_loop))
+ num_warps = 4 if head_dim <= 64 else 8
+ num_stages = 3
+ BLOCK_SIZE_K = 64
+ BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
+ BLOCK_SIZE_H = max(16, triton.next_power_of_2(num_share_q_heads))
+ BLOCK_SIZE_T = triton.next_power_of_2(topk)
+ backward_dq[grid](
+ q,
+ k,
+ v,
+ topk_idx,
+ lse,
+ delta,
+ do,
+ dq,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ cu_seqblocks,
+ cu_block_include_tokens,
+ num_k_heads,
+ num_share_q_heads,
+ head_dim,
+ topk,
+ num_q_loop,
+ sm_scale,
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ k.stride(0),
+ k.stride(1),
+ k.stride(2),
+ v.stride(0),
+ v.stride(1),
+ v.stride(2),
+ topk_idx.stride(0),
+ topk_idx.stride(1),
+ topk_idx.stride(2),
+ lse.stride(0),
+ lse.stride(1),
+ delta.stride(0),
+ delta.stride(1),
+ do.stride(0),
+ do.stride(1),
+ do.stride(2),
+ dq.stride(0),
+ dq.stride(1),
+ dq.stride(2),
+ BLOCK_SIZE_K=BLOCK_SIZE_K,
+ BLOCK_SIZE_D=BLOCK_SIZE_D,
+ BLOCK_SIZE_H=BLOCK_SIZE_H,
+ BLOCK_SIZE_T=BLOCK_SIZE_T,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+ return dq, dk, dv
+
+
+class SpatialSelectionAttention(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ q: torch.Tensor, # [total_len, num_q_heads, head_dim]
+ k: torch.Tensor, # [total_len, num_k_heads, head_dim]
+ v: torch.Tensor, # [total_len, num_k_heads, head_dim]
+ topk_idx: torch.Tensor, # [num_kv_heads, total_len, topk]
+ cu_seqblocks: torch.Tensor, # [batch_size + 1]
+ cu_block_include_tokens: torch.Tensor, # [total_block_len]
+ block_size: int,
+ cu_seqlens_q: torch.Tensor,
+ cu_seqlens_k: torch.Tensor,
+ max_seqlen_q: int,
+ max_seqlen_k: int,
+ sm_scale=None,
+ ):
+ # dtype check
+ assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
+ assert q.dtype == k.dtype and k.dtype == v.dtype
+ assert topk_idx.dtype == torch.int32
+ assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
+ # softmax scale
+ if sm_scale is None:
+ sm_scale = 1 / math.sqrt(q.shape[-1])
+ o, lse = _spatial_selection_attention_fwd(
+ q,
+ k,
+ v,
+ topk_idx,
+ cu_seqblocks,
+ cu_block_include_tokens,
+ block_size,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ max_seqlen_q,
+ max_seqlen_k,
+ sm_scale,
+ )
+ ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx, cu_seqblocks, cu_block_include_tokens)
+ ctx.sm_scale = sm_scale
+ ctx.max_seqlen_q = max_seqlen_q
+ ctx.max_seqlen_k = max_seqlen_k
+ ctx.block_size = block_size
+ # return
+ return o
+
+ @staticmethod
+ def backward(ctx, do: torch.Tensor, *args) -> Any:
+ q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx, cu_seqblocks, cu_block_include_tokens = ctx.saved_tensors
+ max_seqlen_q = ctx.max_seqlen_q
+ max_seqlen_k = ctx.max_seqlen_k
+ sm_scale = ctx.sm_scale
+ block_size = ctx.block_size
+
+ dq, dk, dv = _spatial_selection_attention_bwd(
+ o,
+ do,
+ lse,
+ q,
+ k,
+ v,
+ topk_idx,
+ cu_seqblocks,
+ cu_block_include_tokens,
+ block_size,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ max_seqlen_q,
+ max_seqlen_k,
+ sm_scale,
+ )
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None
+
+
+def spatial_selection_attention(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ block_topk: torch.Tensor,
+ cu_seqblocks: torch.Tensor,
+ cu_block_include_tokens: torch.Tensor,
+ block_size: int,
+ cu_seqlens: torch.Tensor,
+ softmax_scale: Optional[float] = None,
+) -> torch.Tensor:
+ """Spatial selection attention implemented in triton.
+
+ Args:
+ q (torch.Tensor): shape [total_len, num_q_heads, head_dim]
+ k (torch.Tensor): shape [total_len, num_kv_heads, head_dim]
+ v (torch.Tensor): shape [total_len, num_kv_heads, head_dim]
+ block_topk (torch.Tensor): topk block idx for each query, shape [num_kv_heads, total_len, topk]. -1 means padding.
+ block_size (int): key value block size.
+ cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens in flash_attn_func_varlen.
+ cu_block_include_tokens (torch.Tensor) shape [total_block_len]: number of tokens within each block
+ softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim).
+
+ Returns:
+ torch.Tensor: attention output, shape [total_len, num_q_heads, head_dim]
+ """
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+ return SpatialSelectionAttention.apply(
+ q,
+ k,
+ v,
+ block_topk,
+ cu_seqblocks,
+ cu_block_include_tokens,
+ block_size,
+ cu_seqlens,
+ cu_seqlens,
+ max_seqlen,
+ max_seqlen,
+ softmax_scale,
+ )
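+
+
+# Usage sketch (illustrative only, not part of the kernel). The sizes below are
+# placeholders, and `block_topk`, `cu_seqblocks` and `cu_block_include_tokens` are
+# assumed to be produced by the surrounding spatial sparse attention module:
+#
+#     total_len, num_q_heads, num_kv_heads, head_dim = 1024, 16, 2, 64
+#     q = torch.randn(total_len, num_q_heads, head_dim, device="cuda", dtype=torch.bfloat16)
+#     k = torch.randn(total_len, num_kv_heads, head_dim, device="cuda", dtype=torch.bfloat16)
+#     v = torch.randn_like(k)
+#     cu_seqlens = torch.tensor([0, total_len], device="cuda", dtype=torch.int32)
+#     out = spatial_selection_attention(
+#         q, k, v, block_topk, cu_seqblocks, cu_block_include_tokens,
+#         block_size=8, cu_seqlens=cu_seqlens,
+#     )  # -> [total_len, num_q_heads, head_dim]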
diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/window_attention.py b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/window_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c8f8b420d6f3d533d1f00bcacebf37805089826
--- /dev/null
+++ b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/window_attention.py
@@ -0,0 +1,59 @@
+from typing import *
+import torch
+import flash_attn
+from direct3d_s2.modules.sparse import SparseTensor
+from direct3d_s2.modules.sparse.attention.windowed_attn import calc_window_partition
+
+
+def sparse_window_attention(
+ q: SparseTensor,
+ k: SparseTensor,
+ v: SparseTensor,
+ window_size: int,
+ shift_window: Tuple[int, int, int] = (0, 0, 0)
+) -> SparseTensor:
+ """
+ Apply windowed scaled dot product self attention to a sparse tensor.
+
+ Args:
+ q (SparseTensor): [N, *, H_q, C] sparse tensor containing query.
+ k (SparseTensor): [N, *, H_kv, C] sparse tensor containing key.
+ v (SparseTensor): [N, *, H_kv, C] sparse tensor containing value.
+ window_size (int): The window size to use.
+ shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
+ shift (int): The shift to use.
+ """
+
+ serialization_spatial_cache_name = f'window_partition_{window_size}_{shift_window}'
+ serialization_spatial_cache = q.get_spatial_cache(serialization_spatial_cache_name)
+ if serialization_spatial_cache is None:
+ fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition(q, window_size, shift_window)
+ q.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
+ else:
+ fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache
+
+ M = fwd_indices.shape[0]
+ T = q.feats.shape[0]
+ H = q.feats.shape[1]
+ H_kv = k.feats.shape[1]
+ C = q.feats.shape[2]
+ q_feats = q.feats[fwd_indices] # [M, H, C]
+ k_feats = k.feats[fwd_indices]
+ v_feats = v.feats[fwd_indices]
+
+ if all([seq_len == window_size for seq_len in seq_lens]):
+ B = len(seq_lens)
+ N = window_size
+ q_feats = q_feats.reshape(B, N, H, C)
+ k_feats = k_feats.reshape(B, N, H_kv, C)
+ v_feats = v_feats.reshape(B, N, H_kv, C)
+ out = flash_attn.flash_attn_func(q_feats, k_feats, v_feats)
+ out = out.reshape(B * N, H, C) # [M, H, C]
+ else:
+ cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
+ .to(q.device).int()
+ out = flash_attn.flash_attn_varlen_func(q_feats, k_feats, v_feats, cu_seqlens, cu_seqlens, max(seq_lens), max(seq_lens))
+
+ out = out[bwd_indices] # [T, H, C]
+
+ return q.replace(out)
\ No newline at end of file
diff --git a/direct3d_s2/modules/sparse/attention/windowed_attn.py b/direct3d_s2/modules/sparse/attention/windowed_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c450398c634ba314c0f2100bf4949207a40847b
--- /dev/null
+++ b/direct3d_s2/modules/sparse/attention/windowed_attn.py
@@ -0,0 +1,133 @@
+from typing import *
+import torch
+import math
+from .. import SparseTensor
+from .. import DEBUG, ATTN
+
+if ATTN == 'xformers':
+ import xformers.ops as xops
+elif ATTN == 'flash_attn':
+ import flash_attn
+else:
+ raise ValueError(f"Unknown attention module: {ATTN}")
+
+
+__all__ = [
+ 'sparse_windowed_scaled_dot_product_self_attention',
+]
+
+
+def calc_window_partition(
+ tensor: SparseTensor,
+ window_size: Union[int, Tuple[int, ...]],
+ shift_window: Union[int, Tuple[int, ...]] = 0
+) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
+ """
+ Calculate serialization and partitioning for a set of coordinates.
+
+ Args:
+ tensor (SparseTensor): The input tensor.
+ window_size (int): The window size to use.
+ shift_window (Tuple[int, ...]): The shift of serialized coordinates.
+
+ Returns:
+ (torch.Tensor): Forwards indices.
+ (torch.Tensor): Backwards indices.
+ (List[int]): Sequence lengths.
+ (List[int]): Sequence batch indices.
+ """
+ DIM = tensor.coords.shape[1] - 1
+ shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
+ window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size
+ shifted_coords = tensor.coords.clone().detach()
+ shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0)
+
+ MAX_COORDS = shifted_coords[:, 1:].max(dim=0).values.tolist()
+ NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)]
+ OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1]
+
+ shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0)
+ shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1)
+ fwd_indices = torch.argsort(shifted_indices)
+ bwd_indices = torch.empty_like(fwd_indices)
+ bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
+ seq_lens = torch.bincount(shifted_indices)
+ seq_batch_indices = torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32) // OFFSET[0]
+ mask = seq_lens != 0
+ seq_lens = seq_lens[mask].tolist()
+ seq_batch_indices = seq_batch_indices[mask].tolist()
+
+ return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
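+
+# Worked example of the index encoding above (assumed values, for illustration only):
+# with DIM = 3, window_size = (2, 2, 2) and maximum shifted coords (3, 3, 3), we get
+# NUM_WINDOWS = [2, 2, 2] and OFFSET = [8, 4, 2, 1]. A voxel at (batch=1, x=3, y=0, z=2)
+# lands in window coords (1, 1, 0, 1), so its serialized index is
+# 1*8 + 1*4 + 0*2 + 1*1 = 13. argsort over these indices groups all voxels of the same
+# window (and batch) contiguously; bwd_indices undoes that permutation.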
+
+
+def sparse_windowed_scaled_dot_product_self_attention(
+ qkv: SparseTensor,
+ window_size: int,
+ shift_window: Tuple[int, int, int] = (0, 0, 0)
+) -> SparseTensor:
+ """
+ Apply windowed scaled dot product self attention to a sparse tensor.
+
+ Args:
+ qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
+ window_size (int): The window size to use.
+ shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
+ """
+ assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
+
+ serialization_spatial_cache_name = f'window_partition_{window_size}_{shift_window}'
+ serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
+ if serialization_spatial_cache is None:
+ fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition(qkv, window_size, shift_window)
+ qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
+ else:
+ fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache
+
+ M = fwd_indices.shape[0]
+ T = qkv.feats.shape[0]
+ H = qkv.feats.shape[2]
+ C = qkv.feats.shape[3]
+ qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
+ if DEBUG:
+ start = 0
+ qkv_coords = qkv.coords[fwd_indices]
+ for i in range(len(seq_lens)):
+ seq_coords = qkv_coords[start:start+seq_lens[i]]
+ assert (seq_coords[:, 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
+ assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \
+ f"SparseWindowedScaledDotProductSelfAttention: window size exceeded"
+ start += seq_lens[i]
+
+ if all([seq_len == window_size for seq_len in seq_lens]):
+ B = len(seq_lens)
+ N = window_size
+ qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
+ if ATTN == 'xformers':
+ q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
+ out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
+ elif ATTN == 'flash_attn':
+ out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C]
+ else:
+ raise ValueError(f"Unknown attention module: {ATTN}")
+ out = out.reshape(B * N, H, C) # [M, H, C]
+ else:
+ if ATTN == 'xformers':
+ q, k, v = qkv_feats.unbind(dim=1) # [M, H, C]
+ q = q.unsqueeze(0) # [1, M, H, C]
+ k = k.unsqueeze(0) # [1, M, H, C]
+ v = v.unsqueeze(0) # [1, M, H, C]
+ mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
+ out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
+ elif ATTN == 'flash_attn':
+ cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
+ .to(qkv.device).int()
+ out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]
+
+ out = out[bwd_indices] # [T, H, C]
+
+ if DEBUG:
+ qkv_coords = qkv_coords[bwd_indices]
+ assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"
+
+ return qkv.replace(out)
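+
+
+# Usage sketch (illustrative only; building the SparseTensor assumes a sparse backend
+# is installed, and the random qkv/coords below are placeholders):
+#
+#     from direct3d_s2.modules import sparse as sp
+#     coords = torch.zeros(4096, 4, dtype=torch.int32, device="cuda")
+#     coords[:, 1:] = torch.randint(0, 64, (4096, 3), dtype=torch.int32, device="cuda")
+#     qkv = sp.SparseTensor(
+#         feats=torch.randn(4096, 3, 8, 64, device="cuda", dtype=torch.float16),
+#         coords=coords,
+#     )
+#     out = sparse_windowed_scaled_dot_product_self_attention(qkv, window_size=8)
+#     # out.feats: [4096, 8, 64]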
diff --git a/direct3d_s2/modules/sparse/basic.py b/direct3d_s2/modules/sparse/basic.py
new file mode 100644
index 0000000000000000000000000000000000000000..8837f44052f6d573d09e3bfb897e659e10516bb5
--- /dev/null
+++ b/direct3d_s2/modules/sparse/basic.py
@@ -0,0 +1,459 @@
+from typing import *
+import torch
+import torch.nn as nn
+from . import BACKEND, DEBUG
+SparseTensorData = None # Lazy import
+
+
+__all__ = [
+ 'SparseTensor',
+ 'sparse_batch_broadcast',
+ 'sparse_batch_op',
+ 'sparse_cat',
+ 'sparse_unbind',
+]
+
+
+class SparseTensor:
+ """
+ Sparse tensor with support for both torchsparse and spconv backends.
+
+ Parameters:
+ - feats (torch.Tensor): Features of the sparse tensor.
+ - coords (torch.Tensor): Coordinates of the sparse tensor.
+ - shape (torch.Size): Shape of the sparse tensor.
+ - layout (List[slice]): Layout of the sparse tensor for each batch.
+ - data (SparseTensorData): Sparse tensor data used for convolution.
+
+ NOTE:
+ - Data corresponding to the same batch should be contiguous.
+ - Coords should be in [0, 1023].
+ """
+ @overload
+ def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...
+
+ @overload
+ def __init__(self, data, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...
+
+ def __init__(self, *args, **kwargs):
+ # Lazy import of sparse tensor backend
+ global SparseTensorData
+ if SparseTensorData is None:
+ import importlib
+ if BACKEND == 'torchsparse':
+ SparseTensorData = importlib.import_module('torchsparse').SparseTensor
+ elif BACKEND == 'spconv':
+ SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor
+
+ method_id = 0
+ if len(args) != 0:
+ method_id = 0 if isinstance(args[0], torch.Tensor) else 1
+ else:
+ method_id = 1 if 'data' in kwargs else 0
+
+ if method_id == 0:
+ feats, coords, shape, layout = args + (None,) * (4 - len(args))
+ if 'feats' in kwargs:
+ feats = kwargs['feats']
+ del kwargs['feats']
+ if 'coords' in kwargs:
+ coords = kwargs['coords']
+ del kwargs['coords']
+ if 'shape' in kwargs:
+ shape = kwargs['shape']
+ del kwargs['shape']
+ if 'layout' in kwargs:
+ layout = kwargs['layout']
+ del kwargs['layout']
+
+ if shape is None:
+ shape = self.__cal_shape(feats, coords)
+ if layout is None:
+ layout = self.__cal_layout(coords, shape[0])
+ if BACKEND == 'torchsparse':
+ self.data = SparseTensorData(feats, coords, **kwargs)
+ elif BACKEND == 'spconv':
+ spatial_shape = list(coords.max(0)[0] + 1)[1:]
+ self.data = SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape, shape[0], **kwargs)
+ self.data._features = feats
+ elif method_id == 1:
+ data, shape, layout = args + (None,) * (3 - len(args))
+ if 'data' in kwargs:
+ data = kwargs['data']
+ del kwargs['data']
+ if 'shape' in kwargs:
+ shape = kwargs['shape']
+ del kwargs['shape']
+ if 'layout' in kwargs:
+ layout = kwargs['layout']
+ del kwargs['layout']
+
+ self.data = data
+ if shape is None:
+ shape = self.__cal_shape(self.feats, self.coords)
+ if layout is None:
+ layout = self.__cal_layout(self.coords, shape[0])
+
+ self._shape = shape
+ self._layout = layout
+ self._scale = kwargs.get('scale', (1, 1, 1))
+ self._spatial_cache = kwargs.get('spatial_cache', {})
+
+ if DEBUG:
+ try:
+ assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}"
+ assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}"
+ assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}"
+ for i in range(self.shape[0]):
+ assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous"
+ except Exception as e:
+ print('Debugging information:')
+ print(f"- Shape: {self.shape}")
+ print(f"- Layout: {self.layout}")
+ print(f"- Scale: {self._scale}")
+ print(f"- Coords: {self.coords}")
+ raise e
+
+ def __cal_shape(self, feats, coords):
+ shape = []
+ shape.append(coords[:, 0].max().item() + 1)
+ shape.extend([*feats.shape[1:]])
+ return torch.Size(shape)
+
+ def __cal_layout(self, coords, batch_size):
+ seq_len = torch.bincount(coords[:, 0], minlength=batch_size)
+ offset = torch.cumsum(seq_len, dim=0)
+ layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)]
+ return layout
+
+ @property
+ def shape(self) -> torch.Size:
+ return self._shape
+
+ def dim(self) -> int:
+ return len(self.shape)
+
+ @property
+ def layout(self) -> List[slice]:
+ return self._layout
+
+ @property
+ def feats(self) -> torch.Tensor:
+ if BACKEND == 'torchsparse':
+ return self.data.F
+ elif BACKEND == 'spconv':
+ return self.data.features
+
+ @feats.setter
+ def feats(self, value: torch.Tensor):
+ if BACKEND == 'torchsparse':
+ self.data.F = value
+ elif BACKEND == 'spconv':
+ self.data.features = value
+
+ @property
+ def coords(self) -> torch.Tensor:
+ if BACKEND == 'torchsparse':
+ return self.data.C
+ elif BACKEND == 'spconv':
+ return self.data.indices
+
+ @coords.setter
+ def coords(self, value: torch.Tensor):
+ if BACKEND == 'torchsparse':
+ self.data.C = value
+ elif BACKEND == 'spconv':
+ self.data.indices = value
+
+ @property
+ def dtype(self):
+ return self.feats.dtype
+
+ @property
+ def device(self):
+ return self.feats.device
+
+ @overload
+ def to(self, dtype: torch.dtype) -> 'SparseTensor': ...
+
+ @overload
+ def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None) -> 'SparseTensor': ...
+
+ def to(self, *args, **kwargs) -> 'SparseTensor':
+ device = None
+ dtype = None
+ if len(args) == 2:
+ device, dtype = args
+ elif len(args) == 1:
+ if isinstance(args[0], torch.dtype):
+ dtype = args[0]
+ else:
+ device = args[0]
+ if 'dtype' in kwargs:
+ assert dtype is None, "to() received multiple values for argument 'dtype'"
+ dtype = kwargs['dtype']
+ if 'device' in kwargs:
+ assert device is None, "to() received multiple values for argument 'device'"
+ device = kwargs['device']
+
+ new_feats = self.feats.to(device=device, dtype=dtype)
+ new_coords = self.coords.to(device=device)
+ return self.replace(new_feats, new_coords)
+
+ def type(self, dtype):
+ new_feats = self.feats.type(dtype)
+ return self.replace(new_feats)
+
+ def cpu(self) -> 'SparseTensor':
+ new_feats = self.feats.cpu()
+ new_coords = self.coords.cpu()
+ return self.replace(new_feats, new_coords)
+
+ def cuda(self) -> 'SparseTensor':
+ new_feats = self.feats.cuda()
+ new_coords = self.coords.cuda()
+ return self.replace(new_feats, new_coords)
+
+ def half(self) -> 'SparseTensor':
+ new_feats = self.feats.half()
+ return self.replace(new_feats)
+
+ def float(self) -> 'SparseTensor':
+ new_feats = self.feats.float()
+ return self.replace(new_feats)
+
+ def detach(self) -> 'SparseTensor':
+ new_coords = self.coords.detach()
+ new_feats = self.feats.detach()
+ return self.replace(new_feats, new_coords)
+
+ def dense(self) -> torch.Tensor:
+ if BACKEND == 'torchsparse':
+ return self.data.dense()
+ elif BACKEND == 'spconv':
+ return self.data.dense()
+
+ def reshape(self, *shape) -> 'SparseTensor':
+ new_feats = self.feats.reshape(self.feats.shape[0], *shape)
+ return self.replace(new_feats)
+
+ def unbind(self, dim: int) -> List['SparseTensor']:
+ return sparse_unbind(self, dim)
+
+ def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor':
+ new_shape = [self.shape[0]]
+ new_shape.extend(feats.shape[1:])
+ if BACKEND == 'torchsparse':
+ new_data = SparseTensorData(
+ feats=feats,
+ coords=self.data.coords if coords is None else coords,
+ stride=self.data.stride,
+ spatial_range=self.data.spatial_range,
+ )
+ new_data._caches = self.data._caches
+ elif BACKEND == 'spconv':
+ new_data = SparseTensorData(
+ self.data.features.reshape(self.data.features.shape[0], -1),
+ self.data.indices,
+ self.data.spatial_shape,
+ self.data.batch_size,
+ self.data.grid,
+ self.data.voxel_num,
+ self.data.indice_dict
+ )
+ new_data._features = feats
+ new_data.benchmark = self.data.benchmark
+ new_data.benchmark_record = self.data.benchmark_record
+ new_data.thrust_allocator = self.data.thrust_allocator
+ new_data._timer = self.data._timer
+ new_data.force_algo = self.data.force_algo
+ new_data.int8_scale = self.data.int8_scale
+ if coords is not None:
+ new_data.indices = coords
+ new_tensor = SparseTensor(new_data, shape=torch.Size(new_shape), layout=self.layout, scale=self._scale, spatial_cache=self._spatial_cache)
+ return new_tensor
+
+ @staticmethod
+ def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor':
+ N, C = dim
+ x = torch.arange(aabb[0], aabb[3] + 1)
+ y = torch.arange(aabb[1], aabb[4] + 1)
+ z = torch.arange(aabb[2], aabb[5] + 1)
+ coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3)
+ coords = torch.cat([
+ torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1),
+ coords.repeat(N, 1),
+ ], dim=1).to(dtype=torch.int32, device=device)
+ feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device)
+ return SparseTensor(feats=feats, coords=coords)
+
+ def __merge_sparse_cache(self, other: 'SparseTensor') -> dict:
+ new_cache = {}
+ for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())):
+ if k in self._spatial_cache:
+ new_cache[k] = self._spatial_cache[k]
+ if k in other._spatial_cache:
+ if k not in new_cache:
+ new_cache[k] = other._spatial_cache[k]
+ else:
+ new_cache[k].update(other._spatial_cache[k])
+ return new_cache
+
+ def __neg__(self) -> 'SparseTensor':
+ return self.replace(-self.feats)
+
+ def __elemwise__(self, other: Union[torch.Tensor, 'SparseTensor'], op: callable) -> 'SparseTensor':
+ if isinstance(other, torch.Tensor):
+ try:
+ other = torch.broadcast_to(other, self.shape)
+ other = sparse_batch_broadcast(self, other)
+ except:
+ pass
+ if isinstance(other, SparseTensor):
+ other = other.feats
+ new_feats = op(self.feats, other)
+ new_tensor = self.replace(new_feats)
+ if isinstance(other, SparseTensor):
+ new_tensor._spatial_cache = self.__merge_sparse_cache(other)
+ return new_tensor
+
+ def __add__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+ return self.__elemwise__(other, torch.add)
+
+ def __radd__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+ return self.__elemwise__(other, torch.add)
+
+ def __sub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+ return self.__elemwise__(other, torch.sub)
+
+ def __rsub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+ return self.__elemwise__(other, lambda x, y: torch.sub(y, x))
+
+ def __mul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+ return self.__elemwise__(other, torch.mul)
+
+ def __rmul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+ return self.__elemwise__(other, torch.mul)
+
+ def __truediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+ return self.__elemwise__(other, torch.div)
+
+ def __rtruediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+ return self.__elemwise__(other, lambda x, y: torch.div(y, x))
+
+ def __getitem__(self, idx):
+ if isinstance(idx, int):
+ idx = [idx]
+ elif isinstance(idx, slice):
+ idx = range(*idx.indices(self.shape[0]))
+ elif isinstance(idx, torch.Tensor):
+ if idx.dtype == torch.bool:
+ assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}"
+ idx = idx.nonzero().squeeze(1)
+ elif idx.dtype in [torch.int32, torch.int64]:
+ assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
+ else:
+ raise ValueError(f"Unknown index type: {idx.dtype}")
+ else:
+ raise ValueError(f"Unknown index type: {type(idx)}")
+
+ coords = []
+ feats = []
+ for new_idx, old_idx in enumerate(idx):
+ coords.append(self.coords[self.layout[old_idx]].clone())
+ coords[-1][:, 0] = new_idx
+ feats.append(self.feats[self.layout[old_idx]])
+ coords = torch.cat(coords, dim=0).contiguous()
+ feats = torch.cat(feats, dim=0).contiguous()
+ return SparseTensor(feats=feats, coords=coords)
+
+ def register_spatial_cache(self, key, value) -> None:
+ """
+ Register a spatial cache.
+ The spatial cache can be anything you want to cache.
+ The registration and retrieval of the cache are based on the current scale.
+ """
+ scale_key = str(self._scale)
+ if scale_key not in self._spatial_cache:
+ self._spatial_cache[scale_key] = {}
+ self._spatial_cache[scale_key][key] = value
+
+ def get_spatial_cache(self, key=None):
+ """
+ Get a spatial cache.
+ """
+ scale_key = str(self._scale)
+ cur_scale_cache = self._spatial_cache.get(scale_key, {})
+ if key is None:
+ return cur_scale_cache
+ return cur_scale_cache.get(key, None)
+
+
+def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor:
+ """
+ Broadcast a per-batch tensor to every token of a sparse tensor along the batch dimension.
+
+ Args:
+ input (SparseTensor): Sparse tensor to broadcast to.
+ other (torch.Tensor): Tensor of shape [batch_size, ...] to broadcast; row k is copied to every token of batch k.
+ """
+ coords, feats = input.coords, input.feats
+ broadcasted = torch.zeros_like(feats)
+ for k in range(input.shape[0]):
+ broadcasted[input.layout[k]] = other[k]
+ return broadcasted
+
+
+def sparse_batch_op(input: SparseTensor, other: torch.Tensor, op: callable = torch.add) -> SparseTensor:
+ """
+ Broadcast a per-batch tensor to a sparse tensor along the batch dimension, then perform an operation.
+
+ Args:
+ input (SparseTensor): Sparse tensor to operate on.
+ other (torch.Tensor): Tensor of shape [batch_size, ...] to broadcast.
+ op (callable): Operation to perform after broadcasting. Defaults to torch.add.
+ """
+ return input.replace(op(input.feats, sparse_batch_broadcast(input, other)))
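+
+
+# Minimal sketch (placeholder values; assumes a SparseTensor `x` with batch_size 2):
+#
+#     scale = torch.randn(2, x.feats.shape[1], device=x.device)
+#     per_token = sparse_batch_broadcast(x, scale)  # [num_tokens, C]; row k of `scale`
+#                                                   # is copied to every token of batch k
+#     y = sparse_batch_op(x, scale, torch.mul)      # rescales each batch independently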
+
+
+def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor:
+ """
+ Concatenate a list of sparse tensors.
+
+ Args:
+ inputs (List[SparseTensor]): List of sparse tensors to concatenate.
+ """
+ if dim == 0:
+ start = 0
+ coords = []
+ for input in inputs:
+ coords.append(input.coords.clone())
+ coords[-1][:, 0] += start
+ start += input.shape[0]
+ coords = torch.cat(coords, dim=0)
+ feats = torch.cat([input.feats for input in inputs], dim=0)
+ output = SparseTensor(
+ coords=coords,
+ feats=feats,
+ )
+ else:
+ feats = torch.cat([input.feats for input in inputs], dim=dim)
+ output = inputs[0].replace(feats)
+
+ return output
+
+
+def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
+ """
+ Unbind a sparse tensor along a dimension.
+
+ Args:
+ input (SparseTensor): Sparse tensor to unbind.
+ dim (int): Dimension to unbind.
+ """
+ if dim == 0:
+ return [input[i] for i in range(input.shape[0])]
+ else:
+ feats = input.feats.unbind(dim)
+ return [input.replace(f) for f in feats]
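+
+
+# Construction and basic-usage sketch (illustrative only; requires a torchsparse or
+# spconv backend, and the coordinates/features below are placeholders):
+#
+#     coords = torch.tensor([[0, 0, 0, 0], [0, 1, 2, 3], [1, 4, 5, 6]], dtype=torch.int32, device="cuda")
+#     feats = torch.randn(3, 8, device="cuda")
+#     x = SparseTensor(feats=feats, coords=coords)  # shape [2, 8]: batch 0 has 2 tokens, batch 1 has 1
+#     y = x * 2.0 + 1.0                             # element-wise ops keep the sparse structure
+#     both = sparse_cat([x, y])                     # stacks along the batch dimension -> batch_size 4
+#     parts = sparse_unbind(both, dim=0)            # list of four single-batch SparseTensors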
diff --git a/direct3d_s2/modules/sparse/conv/__init__.py b/direct3d_s2/modules/sparse/conv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..340a87126a8de574ee0276feb96b49824a2ce234
--- /dev/null
+++ b/direct3d_s2/modules/sparse/conv/__init__.py
@@ -0,0 +1,21 @@
+from .. import BACKEND
+
+
+SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native'
+
+def __from_env():
+ import os
+
+ global SPCONV_ALGO
+ env_spconv_algo = os.environ.get('SPCONV_ALGO')
+ if env_spconv_algo is not None and env_spconv_algo in ['auto', 'implicit_gemm', 'native']:
+ SPCONV_ALGO = env_spconv_algo
+ print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}")
+
+
+__from_env()
+
+if BACKEND == 'torchsparse':
+ from .conv_torchsparse import *
+elif BACKEND == 'spconv':
+ from .conv_spconv import *
diff --git a/direct3d_s2/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc b/direct3d_s2/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..70d75bb2732eb260f2c78ebe5d62e70335c2cfbc
Binary files /dev/null and b/direct3d_s2/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/sparse/conv/__pycache__/conv_spconv.cpython-310.pyc b/direct3d_s2/modules/sparse/conv/__pycache__/conv_spconv.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cbeac5a85eecb6501557b46c516dfba80fb6d4ae
Binary files /dev/null and b/direct3d_s2/modules/sparse/conv/__pycache__/conv_spconv.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/sparse/conv/__pycache__/conv_torchsparse.cpython-310.pyc b/direct3d_s2/modules/sparse/conv/__pycache__/conv_torchsparse.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ae86bfa9fa6dfe94672df2191848c09d88d17c5c
Binary files /dev/null and b/direct3d_s2/modules/sparse/conv/__pycache__/conv_torchsparse.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/sparse/conv/conv_spconv.py b/direct3d_s2/modules/sparse/conv/conv_spconv.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff058302f033f8f340c9f75efc869006cbf5b993
--- /dev/null
+++ b/direct3d_s2/modules/sparse/conv/conv_spconv.py
@@ -0,0 +1,86 @@
+import torch
+import torch.nn as nn
+from .. import SparseTensor
+from .. import DEBUG
+from . import SPCONV_ALGO
+
+class SparseConv3d(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
+ super(SparseConv3d, self).__init__()
+ if 'spconv' not in globals():
+ import spconv.pytorch as spconv
+ algo = None
+ if SPCONV_ALGO == 'native':
+ algo = spconv.ConvAlgo.Native
+ elif SPCONV_ALGO == 'implicit_gemm':
+ algo = spconv.ConvAlgo.MaskImplicitGemm
+ if stride == 1 and (padding is None):
+ self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo)
+ else:
+ self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo)
+ self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
+ self.padding = padding
+
+ def forward(self, x: SparseTensor) -> SparseTensor:
+
+ spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None)
+ new_data = self.conv(x.data)
+ new_shape = [x.shape[0], self.conv.out_channels]
+ new_layout = None if spatial_changed else x.layout
+
+ if spatial_changed and (x.shape[0] != 1):
+ # spconv with a non-1 stride breaks the batch-contiguity of the output tensor, so sort by the coords
+ fwd = new_data.indices[:, 0].argsort()
+ bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device))
+ sorted_feats = new_data.features #[fwd]
+ sorted_coords = new_data.indices #[fwd]
+ unsorted_data = new_data
+
+ indice_dict = new_data.indice_dict
+
+ if 'spconv' not in globals():
+ import spconv.pytorch as spconv
+ new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size, indice_dict=indice_dict) # type: ignore
+
+ out = SparseTensor(
+ new_data, shape=torch.Size(new_shape), layout=new_layout,
+ scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]),
+ spatial_cache=x._spatial_cache,
+ )
+
+ if spatial_changed and (x.shape[0] != 1):
+ out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data)
+ out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd)
+
+ return out
+
+
+class SparseInverseConv3d(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
+ super(SparseInverseConv3d, self).__init__()
+ if 'spconv' not in globals():
+ import spconv.pytorch as spconv
+ self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key)
+ self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
+
+ def forward(self, x: SparseTensor) -> SparseTensor:
+ spatial_changed = any(s != 1 for s in self.stride)
+ if spatial_changed:
+ # recover the original spconv order
+ data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data')
+ bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd')
+ data = data.replace_feature(x.feats[bwd])
+ if DEBUG:
+ assert torch.equal(data.indices, x.coords[bwd]), 'Recover the original order failed'
+ else:
+ data = x.data
+
+ new_data = self.conv(data)
+ new_shape = [x.shape[0], self.conv.out_channels]
+ new_layout = None if spatial_changed else x.layout
+ out = SparseTensor(
+ new_data, shape=torch.Size(new_shape), layout=new_layout,
+ scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]),
+ spatial_cache=x._spatial_cache,
+ )
+ return out
diff --git a/direct3d_s2/modules/sparse/conv/conv_torchsparse.py b/direct3d_s2/modules/sparse/conv/conv_torchsparse.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdc6561a2734d5c71794e24575741209b279de89
--- /dev/null
+++ b/direct3d_s2/modules/sparse/conv/conv_torchsparse.py
@@ -0,0 +1,68 @@
+import torch
+import torch.nn as nn
+from .. import SparseTensor
+from torchsparse.utils import make_ntuple
+
+
+def sparseconv3d_func(input: SparseTensor, weight: torch.Tensor, kernel_size: int, stride: int = 1, dilation: int = 1, padding: int = 0, bias: torch.Tensor = None, training: bool = True):
+ if 'torchsparse' not in globals():
+ import torchsparse
+ stride = make_ntuple(stride, ndim=3)
+ kernel_size = make_ntuple(kernel_size, ndim=3)
+ _padding = make_ntuple(padding, 3)
+ padding = ()
+ for i in range(3):
+ if kernel_size[i] % 2 == 1 and stride[i] == 1:
+ padding += ((kernel_size[i] - 1) // 2,)
+ else:
+ padding += (_padding[i],)
+ out = torchsparse.nn.functional.conv3d(input.data, weight, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, training=training)
+ spatial_range = out.spatial_range
+ new_shape = [input.shape[0], weight.shape[1]]
+ out = SparseTensor(out, shape=torch.Size(new_shape), layout=input.layout if all(s == 1 for s in stride) else None)
+ out._spatial_cache = input._spatial_cache
+ out._scale = tuple([s * stride for s, stride in zip(input._scale, stride)])
+ out.data.spatial_range = spatial_range
+ return out
+
+class SparseConv3d(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, bias=True, indice_key=None):
+ super(SparseConv3d, self).__init__()
+ if 'torchsparse' not in globals():
+ import torchsparse
+ self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias)
+
+ def forward(self, x: SparseTensor) -> SparseTensor:
+ out = self.conv(x.data)
+
+ spatial_range = out.spatial_range
+
+ new_shape = [x.shape[0], self.conv.out_channels]
+ out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
+ out._spatial_cache = x._spatial_cache
+ out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)])
+
+ out.data.spatial_range = spatial_range
+
+ return out
+
+
+class SparseInverseConv3d(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
+ super(SparseInverseConv3d, self).__init__()
+ if 'torchsparse' not in globals():
+ import torchsparse
+ self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True)
+
+ def forward(self, x: SparseTensor) -> SparseTensor:
+ out = self.conv(x.data)
+
+ new_shape = [x.shape[0], self.conv.out_channels]
+ out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
+ out._spatial_cache = x._spatial_cache
+ out._scale = tuple([s // stride for s, stride in zip(x._scale, self.conv.stride)])
+
+ return out
+
+
+
diff --git a/direct3d_s2/modules/sparse/linear.py b/direct3d_s2/modules/sparse/linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..a854e77ce87d1a190b9730d91f363a821ff250bd
--- /dev/null
+++ b/direct3d_s2/modules/sparse/linear.py
@@ -0,0 +1,15 @@
+import torch
+import torch.nn as nn
+from . import SparseTensor
+
+__all__ = [
+ 'SparseLinear'
+]
+
+
+class SparseLinear(nn.Linear):
+ def __init__(self, in_features, out_features, bias=True):
+ super(SparseLinear, self).__init__(in_features, out_features, bias)
+
+ def forward(self, input: SparseTensor) -> SparseTensor:
+ return input.replace(super().forward(input.feats))
diff --git a/direct3d_s2/modules/sparse/nonlinearity.py b/direct3d_s2/modules/sparse/nonlinearity.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e6bfd855271d238fe34ab8bec2744bf9db58b94
--- /dev/null
+++ b/direct3d_s2/modules/sparse/nonlinearity.py
@@ -0,0 +1,42 @@
+import torch
+import torch.nn as nn
+from . import SparseTensor
+
+__all__ = [
+ 'SparseReLU',
+ 'SparseSiLU',
+ 'SparseGELU',
+ 'SparseActivation',
+ 'SparseTanh',
+ 'SparseSigmoid',
+]
+
+class SparseSigmoid(nn.Sigmoid):
+ def forward(self, input: SparseTensor) -> SparseTensor:
+ return input.replace(super().forward(input.feats))
+
+class SparseReLU(nn.ReLU):
+ def forward(self, input: SparseTensor) -> SparseTensor:
+ return input.replace(super().forward(input.feats))
+
+
+class SparseSiLU(nn.SiLU):
+ def forward(self, input: SparseTensor) -> SparseTensor:
+ return input.replace(super().forward(input.feats))
+
+class SparseGELU(nn.GELU):
+ def forward(self, input: SparseTensor) -> SparseTensor:
+ return input.replace(super().forward(input.feats))
+
+class SparseTanh(nn.Tanh):
+ def forward(self, input: SparseTensor) -> SparseTensor:
+ return input.replace(super().forward(input.feats))
+
+class SparseActivation(nn.Module):
+ def __init__(self, activation: nn.Module):
+ super().__init__()
+ self.activation = activation
+
+ def forward(self, input: SparseTensor) -> SparseTensor:
+ return input.replace(self.activation(input.feats))
+
diff --git a/direct3d_s2/modules/sparse/norm.py b/direct3d_s2/modules/sparse/norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b38a36682c098210000dc31d68ddc31ccd2929d
--- /dev/null
+++ b/direct3d_s2/modules/sparse/norm.py
@@ -0,0 +1,58 @@
+import torch
+import torch.nn as nn
+from . import SparseTensor
+from . import DEBUG
+
+__all__ = [
+ 'SparseGroupNorm',
+ 'SparseLayerNorm',
+ 'SparseGroupNorm32',
+ 'SparseLayerNorm32',
+]
+
+
+class SparseGroupNorm(nn.GroupNorm):
+ def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
+ super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine)
+
+ def forward(self, input: SparseTensor) -> SparseTensor:
+ nfeats = torch.zeros_like(input.feats)
+ for k in range(input.shape[0]):
+ if DEBUG:
+ assert (input.coords[input.layout[k], 0] == k).all(), f"SparseGroupNorm: batch index mismatch"
+ bfeats = input.feats[input.layout[k]]
+ bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
+ bfeats = super().forward(bfeats)
+ bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
+ nfeats[input.layout[k]] = bfeats
+ return input.replace(nfeats)
+
+
+class SparseLayerNorm(nn.LayerNorm):
+ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
+ super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine)
+
+ def forward(self, input: SparseTensor) -> SparseTensor:
+ nfeats = torch.zeros_like(input.feats)
+ for k in range(input.shape[0]):
+ bfeats = input.feats[input.layout[k]]
+ bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
+ bfeats = super().forward(bfeats)
+ bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
+ nfeats[input.layout[k]] = bfeats
+ return input.replace(nfeats)
+
+
+class SparseGroupNorm32(SparseGroupNorm):
+ """
+ A GroupNorm layer that converts to float32 before the forward pass.
+ """
+ def forward(self, x: SparseTensor) -> SparseTensor:
+ return super().forward(x.float()).type(x.dtype)
+
+class SparseLayerNorm32(SparseLayerNorm):
+ """
+ A LayerNorm layer that converts to float32 before the forward pass.
+ """
+ def forward(self, x: SparseTensor) -> SparseTensor:
+ return super().forward(x.float()).type(x.dtype)
diff --git a/direct3d_s2/modules/sparse/spatial.py b/direct3d_s2/modules/sparse/spatial.py
new file mode 100644
index 0000000000000000000000000000000000000000..557438b50d9aaa99dee5d36c5f8a8a042ac70017
--- /dev/null
+++ b/direct3d_s2/modules/sparse/spatial.py
@@ -0,0 +1,115 @@
+from typing import *
+import torch
+import torch.nn as nn
+from . import SparseTensor
+
+__all__ = [
+ 'SparseDownsample',
+ 'SparseUpsample',
+ 'SparseSubdivide'
+]
+
+
+class SparseDownsample(nn.Module):
+ """
+ Downsample a sparse tensor by a factor of `factor`.
+ Implemented as average pooling.
+ """
+ def __init__(self, factor: Union[int, Tuple[int, ...], List[int]], mode="mean"):
+ super(SparseDownsample, self).__init__()
+ self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor
+ self.mode = mode
+
+ def forward(self, input: SparseTensor) -> SparseTensor:
+ DIM = input.coords.shape[-1] - 1
+ factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
+ assert DIM == len(factor), 'Input coordinates must have the same dimension as the downsample factor.'
+
+ coord = list(input.coords.unbind(dim=-1))
+ for i, f in enumerate(factor):
+ coord[i+1] = coord[i+1] // f
+
+ MAX = [coord[i+1].max().item() + 1 for i in range(DIM)]
+ OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
+ code = sum([c * o for c, o in zip(coord, OFFSET)])
+ code, idx = code.unique(return_inverse=True)
+
+ # Accumulate in float64: reducing in fp16 could overflow when the factor is large.
+ dtype = input.feats.dtype
+ new_feats = torch.scatter_reduce(
+ torch.zeros(code.shape[0], input.feats.shape[1], device=input.feats.device, dtype=torch.float64),
+ dim=0,
+ index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]),
+ src=input.feats.double(),
+ reduce=self.mode,
+ )
+ new_feats = new_feats.to(dtype)
+
+ new_coords = torch.stack(
+ [code // OFFSET[0]] +
+ [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)],
+ dim=-1
+ )
+ out = SparseTensor(new_feats, new_coords, input.shape,)
+ out._scale = tuple([s // f for s, f in zip(input._scale, factor)])
+ out._spatial_cache = input._spatial_cache
+
+ out.register_spatial_cache(f'upsample_{factor}_coords', input.coords)
+ out.register_spatial_cache(f'upsample_{factor}_layout', input.layout)
+ out.register_spatial_cache(f'upsample_{factor}_idx', idx)
+
+ return out
+
+
+class SparseUpsample(nn.Module):
+ """
+ Upsample a sparse tensor by a factor of `factor`.
+ Implemented as nearest neighbor interpolation.
+ """
+ def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]):
+ super(SparseUpsample, self).__init__()
+ self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor
+
+ def forward(self, input: SparseTensor) -> SparseTensor:
+ DIM = input.coords.shape[-1] - 1
+ factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
+ assert DIM == len(factor), 'Input coordinates must have the same dimension as the upsample factor.'
+
+ new_coords = input.get_spatial_cache(f'upsample_{factor}_coords')
+ new_layout = input.get_spatial_cache(f'upsample_{factor}_layout')
+ idx = input.get_spatial_cache(f'upsample_{factor}_idx')
+ if any([x is None for x in [new_coords, new_layout, idx]]):
+ raise ValueError('Upsample cache not found. SparseUpsample must be paired with SparseDownsample.')
+ new_feats = input.feats[idx]
+ out = SparseTensor(new_feats, new_coords, input.shape, new_layout)
+ out._scale = tuple([s * f for s, f in zip(input._scale, factor)])
+ out._spatial_cache = input._spatial_cache
+ return out
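+
+
+# Pairing sketch (illustrative; `x` stands for any SparseTensor produced upstream):
+#
+#     down = SparseDownsample(2)
+#     up = SparseUpsample(2)
+#     y = down(x)   # merges voxels that share the same 2x2x2 cell (mean-pooled feats)
+#     z = up(y)     # restores the original coordinates from the cached mapping
+#     assert torch.equal(z.coords, x.coords)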
+
+class SparseSubdivide(nn.Module):
+ """
+ Subdivide a sparse tensor: every voxel is split into 2^DIM children at doubled resolution, each inheriting its parent's features.
+ """
+ def __init__(self):
+ super(SparseSubdivide, self).__init__()
+
+ def forward(self, input: SparseTensor) -> SparseTensor:
+ DIM = input.coords.shape[-1] - 1
+ # upsample scale=2^DIM
+ n_cube = torch.ones([2] * DIM, device=input.device, dtype=torch.int)
+ n_coords = torch.nonzero(n_cube)
+ n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1)
+ factor = n_coords.shape[0]
+ assert factor == 2 ** DIM
+ # print(n_coords.shape)
+ new_coords = input.coords.clone()
+ new_coords[:, 1:] *= 2
+ new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(new_coords.dtype)
+
+ new_feats = input.feats.unsqueeze(1).expand(input.feats.shape[0], factor, *input.feats.shape[1:])
+ out = SparseTensor(new_feats.flatten(0, 1), new_coords.flatten(0, 1), input.shape)
+ out._scale = input._scale * 2
+ out._spatial_cache = input._spatial_cache
+ return out
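+
+# Minimal sketch (illustrative; requires a sparse backend, values are placeholders):
+#
+#     x = SparseTensor(
+#         feats=torch.randn(1, 4, device="cuda"),
+#         coords=torch.tensor([[0, 3, 1, 2]], dtype=torch.int32, device="cuda"),
+#     )
+#     y = SparseSubdivide()(x)
+#     # y has 8 tokens at coords (0, 6..7, 2..3, 4..5), all carrying x.feats[0]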
+
diff --git a/direct3d_s2/modules/sparse/transformer/__init__.py b/direct3d_s2/modules/sparse/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b08b0d4e5bc24060a2cdc8df75d06dce122972bd
--- /dev/null
+++ b/direct3d_s2/modules/sparse/transformer/__init__.py
@@ -0,0 +1,2 @@
+from .blocks import *
+from .modulated import *
\ No newline at end of file
diff --git a/direct3d_s2/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc b/direct3d_s2/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..969a1cb28304a96ecc20b610829b2e20e79d992b
Binary files /dev/null and b/direct3d_s2/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc b/direct3d_s2/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8ee096ac0e59a2a4530cfc7dd8014fc678962cf8
Binary files /dev/null and b/direct3d_s2/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc b/direct3d_s2/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b14f225257c4670a232ee5f135b4598065182d63
Binary files /dev/null and b/direct3d_s2/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/sparse/transformer/blocks.py b/direct3d_s2/modules/sparse/transformer/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d037a49bf83e1c2dfb2f8c4b23d2e9d6c51e9f0
--- /dev/null
+++ b/direct3d_s2/modules/sparse/transformer/blocks.py
@@ -0,0 +1,151 @@
+from typing import *
+import torch
+import torch.nn as nn
+from ..basic import SparseTensor
+from ..linear import SparseLinear
+from ..nonlinearity import SparseGELU
+from ..attention import SparseMultiHeadAttention, SerializeMode
+from ...norm import LayerNorm32
+
+
+class SparseFeedForwardNet(nn.Module):
+ def __init__(self, channels: int, mlp_ratio: float = 4.0):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ SparseLinear(channels, int(channels * mlp_ratio)),
+ SparseGELU(approximate="tanh"),
+ SparseLinear(int(channels * mlp_ratio), channels),
+ )
+
+ def forward(self, x: SparseTensor) -> SparseTensor:
+ return self.mlp(x)
+
+
+class SparseTransformerBlock(nn.Module):
+ """
+ Sparse Transformer block (MSA + FFN).
+ """
+ def __init__(
+ self,
+ channels: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
+ window_size: Optional[int] = None,
+ shift_sequence: Optional[int] = None,
+ shift_window: Optional[Tuple[int, int, int]] = None,
+ serialize_mode: Optional[SerializeMode] = None,
+ use_checkpoint: bool = False,
+ use_rope: bool = False,
+ qk_rms_norm: bool = False,
+ qkv_bias: bool = True,
+ ln_affine: bool = False,
+ ):
+ super().__init__()
+ self.use_checkpoint = use_checkpoint
+ self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+ self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+ self.attn = SparseMultiHeadAttention(
+ channels,
+ num_heads=num_heads,
+ attn_mode=attn_mode,
+ window_size=window_size,
+ shift_sequence=shift_sequence,
+ shift_window=shift_window,
+ serialize_mode=serialize_mode,
+ qkv_bias=qkv_bias,
+ use_rope=use_rope,
+ qk_rms_norm=qk_rms_norm,
+ )
+ self.mlp = SparseFeedForwardNet(
+ channels,
+ mlp_ratio=mlp_ratio,
+ )
+
+ def _forward(self, x: SparseTensor) -> SparseTensor:
+ h = x.replace(self.norm1(x.feats))
+ h = self.attn(h)
+ x = x + h
+ h = x.replace(self.norm2(x.feats))
+ h = self.mlp(h)
+ x = x + h
+ return x
+
+ def forward(self, x: SparseTensor) -> SparseTensor:
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
+ else:
+ return self._forward(x)
+
+
+class SparseTransformerCrossBlock(nn.Module):
+ """
+ Sparse Transformer cross-attention block (MSA + MCA + FFN).
+ """
+ def __init__(
+ self,
+ channels: int,
+ ctx_channels: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
+ window_size: Optional[int] = None,
+ shift_sequence: Optional[int] = None,
+ shift_window: Optional[Tuple[int, int, int]] = None,
+ serialize_mode: Optional[SerializeMode] = None,
+ use_checkpoint: bool = False,
+ use_rope: bool = False,
+ qk_rms_norm: bool = False,
+ qk_rms_norm_cross: bool = False,
+ qkv_bias: bool = True,
+ ln_affine: bool = False,
+ ):
+ super().__init__()
+ self.use_checkpoint = use_checkpoint
+ self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+ self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+ self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+ self.self_attn = SparseMultiHeadAttention(
+ channels,
+ num_heads=num_heads,
+ type="self",
+ attn_mode=attn_mode,
+ window_size=window_size,
+ shift_sequence=shift_sequence,
+ shift_window=shift_window,
+ serialize_mode=serialize_mode,
+ qkv_bias=qkv_bias,
+ use_rope=use_rope,
+ qk_rms_norm=qk_rms_norm,
+ )
+ self.cross_attn = SparseMultiHeadAttention(
+ channels,
+ ctx_channels=ctx_channels,
+ num_heads=num_heads,
+ type="cross",
+ attn_mode="full",
+ qkv_bias=qkv_bias,
+ qk_rms_norm=qk_rms_norm_cross,
+ )
+ self.mlp = SparseFeedForwardNet(
+ channels,
+ mlp_ratio=mlp_ratio,
+ )
+
+ def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor):
+ h = x.replace(self.norm1(x.feats))
+ h = self.self_attn(h)
+ x = x + h
+ h = x.replace(self.norm2(x.feats))
+ h = self.cross_attn(h, context)
+ x = x + h
+ h = x.replace(self.norm3(x.feats))
+ h = self.mlp(h)
+ x = x + h
+ return x
+
+ def forward(self, x: SparseTensor, context: torch.Tensor):
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
+ else:
+ return self._forward(x, context)
diff --git a/direct3d_s2/modules/sparse/transformer/modulated.py b/direct3d_s2/modules/sparse/transformer/modulated.py
new file mode 100644
index 0000000000000000000000000000000000000000..617af7338d7c5414a0666e9cf91f63cd71cc3447
--- /dev/null
+++ b/direct3d_s2/modules/sparse/transformer/modulated.py
@@ -0,0 +1,213 @@
+from typing import *
+import torch
+import torch.nn as nn
+from ..basic import SparseTensor
+from ..attention import SparseMultiHeadAttention, SerializeMode, SpatialSparseAttention
+from ...norm import LayerNorm32
+from .blocks import SparseFeedForwardNet
+
+
+class ModulatedSparseTransformerBlock(nn.Module):
+ """
+ Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning.
+ """
+ def __init__(
+ self,
+ channels: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
+ window_size: Optional[int] = None,
+ shift_sequence: Optional[int] = None,
+ shift_window: Optional[Tuple[int, int, int]] = None,
+ serialize_mode: Optional[SerializeMode] = None,
+ use_checkpoint: bool = False,
+ use_rope: bool = False,
+ qk_rms_norm: bool = False,
+ qkv_bias: bool = True,
+ share_mod: bool = False,
+ ):
+ super().__init__()
+ self.use_checkpoint = use_checkpoint
+ self.share_mod = share_mod
+ self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
+ self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
+ self.attn = SparseMultiHeadAttention(
+ channels,
+ num_heads=num_heads,
+ attn_mode=attn_mode,
+ window_size=window_size,
+ shift_sequence=shift_sequence,
+ shift_window=shift_window,
+ serialize_mode=serialize_mode,
+ qkv_bias=qkv_bias,
+ use_rope=use_rope,
+ qk_rms_norm=qk_rms_norm,
+ )
+ self.mlp = SparseFeedForwardNet(
+ channels,
+ mlp_ratio=mlp_ratio,
+ )
+ if not share_mod:
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(channels, 6 * channels, bias=True)
+ )
+
+ def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
+ if self.share_mod:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
+ else:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
+ h = x.replace(self.norm1(x.feats))
+ h = h * (1 + scale_msa) + shift_msa
+ h = self.attn(h)
+ h = h * gate_msa
+ x = x + h
+ h = x.replace(self.norm2(x.feats))
+ h = h * (1 + scale_mlp) + shift_mlp
+ h = self.mlp(h)
+ h = h * gate_mlp
+ x = x + h
+ return x
+
+ def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
+ else:
+ return self._forward(x, mod)
+
+
+class ModulatedSparseTransformerCrossBlock(nn.Module):
+ """
+ Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
+ """
+ def __init__(
+ self,
+ channels: int,
+ ctx_channels: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
+ window_size: Optional[int] = None,
+ shift_sequence: Optional[int] = None,
+ shift_window: Optional[Tuple[int, int, int]] = None,
+ serialize_mode: Optional[SerializeMode] = None,
+ use_checkpoint: bool = False,
+ compression_version: str = "v2",
+ use_rope: bool = False,
+ qk_rms_norm: bool = False,
+ qk_rms_norm_cross: bool = False,
+ qkv_bias: bool = True,
+ share_mod: bool = False,
+ use_ssa: bool = True,
+ num_kv_heads: int = 2,
+ compression_block_size: int = 4,
+ selection_block_size: int = 8,
+ topk: int = 8,
+ resolution: int = 64,
+ ):
+ super().__init__()
+ self.use_checkpoint = use_checkpoint
+ self.share_mod = share_mod
+ self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
+ self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
+ self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
+ if use_ssa:
+ self.self_attn = SpatialSparseAttention(
+ channels,
+ num_q_heads=num_heads,
+ num_kv_heads=num_kv_heads,
+ head_dim=channels//num_heads,
+ compression_block_size=compression_block_size,
+ compression_version=compression_version,
+ selection_block_size=selection_block_size,
+ topk=topk,
+ window_size=window_size,
+ shift_window=shift_window,
+ resolution=resolution,
+ )
+ else:
+ self.self_attn = SparseMultiHeadAttention(
+ channels,
+ num_heads=num_heads,
+ type="self",
+ attn_mode=attn_mode,
+ window_size=window_size,
+ shift_sequence=shift_sequence,
+ shift_window=shift_window,
+ serialize_mode=serialize_mode,
+ qkv_bias=qkv_bias,
+ use_rope=use_rope,
+ qk_rms_norm=qk_rms_norm,
+ )
+ self.cross_attn = SparseMultiHeadAttention(
+ channels,
+ ctx_channels=ctx_channels,
+ num_heads=num_heads,
+ type="cross",
+ attn_mode="full",
+ qkv_bias=qkv_bias,
+ qk_rms_norm=qk_rms_norm_cross,
+ )
+ self.mlp = SparseFeedForwardNet(
+ channels,
+ mlp_ratio=mlp_ratio,
+ )
+ if not share_mod:
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(channels, 6 * channels, bias=True)
+ )
+
+ def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor:
+ if self.share_mod:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
+ else:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
+ h = x.replace(self.norm1(x.feats))
+
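+        # Apply the adaLN scale/shift per sample: `h.layout` gives the rows of `h.feats`
+        # belonging to each batch element, so every sample gets its own modulation parameters.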
+ feats_h = h.feats
+ layouts = h.layout
+ ada_r1 = []
+ for i in range(len(layouts)):
+ ada_r1.append(feats_h[layouts[i]] * (1 + scale_msa[i:i+1]) + shift_msa[i:i+1])
+ h = h.replace(torch.cat(ada_r1, dim=0))
+ h = self.self_attn(h)
+
+ feats_h = h.feats
+ layouts = h.layout
+ ada_r2 = []
+ for i in range(len(layouts)):
+ ada_r2.append(feats_h[layouts[i]] * gate_msa[i:i+1])
+ h = h.replace(torch.cat(ada_r2, dim=0))
+
+ x = x + h
+ h = x.replace(self.norm2(x.feats))
+ h = self.cross_attn(h, context)
+ x = x + h
+ h = x.replace(self.norm3(x.feats))
+
+ feats_h = h.feats
+ layouts = h.layout
+ ada_r3 = []
+ for i in range(len(layouts)):
+ ada_r3.append(feats_h[layouts[i]] * (1 + scale_mlp[i:i+1]) + shift_mlp[i:i+1])
+ h = h.replace(torch.cat(ada_r3, dim=0))
+ h = self.mlp(h)
+
+ feats_h = h.feats
+ layouts = h.layout
+ ada_r4 = []
+ for i in range(len(layouts)):
+ ada_r4.append(feats_h[layouts[i]] * gate_mlp[i:i+1])
+ h = h.replace(torch.cat(ada_r4, dim=0))
+
+ x = x + h
+ return x
+
+ def forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor:
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
+ else:
+ return self._forward(x, mod, context)
diff --git a/direct3d_s2/modules/spatial.py b/direct3d_s2/modules/spatial.py
new file mode 100644
index 0000000000000000000000000000000000000000..79e268d36c2ba49b0275744022a1a1e19983dae3
--- /dev/null
+++ b/direct3d_s2/modules/spatial.py
@@ -0,0 +1,48 @@
+import torch
+
+
+def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
+    """
+    3D pixel shuffle: move channel groups into the spatial dimensions,
+    (B, C, H, W, D) -> (B, C // scale_factor**3, H*scale_factor, W*scale_factor, D*scale_factor).
+    """
+ B, C, H, W, D = x.shape
+ C_ = C // scale_factor**3
+ x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D)
+ x = x.permute(0, 1, 5, 2, 6, 3, 7, 4)
+ x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor)
+ return x
+
+
+def patchify(x: torch.Tensor, patch_size: int):
+ """
+ Patchify a tensor.
+
+ Args:
+ x (torch.Tensor): (N, C, *spatial) tensor
+ patch_size (int): Patch size
+ """
+ DIM = x.dim() - 2
+ for d in range(2, DIM + 2):
+ assert x.shape[d] % patch_size == 0, f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}"
+
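+    # Split every spatial dim into (size // patch_size, patch_size) and fold the patch_size
+    # factors into channels: (N, C, *S) -> (N, C * patch_size**DIM, *(S // patch_size)).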
+ x = x.reshape(*x.shape[:2], *sum([[x.shape[d] // patch_size, patch_size] for d in range(2, DIM + 2)], []))
+ x = x.permute(0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)]))
+ x = x.reshape(x.shape[0], x.shape[1] * (patch_size ** DIM), *(x.shape[-DIM:]))
+ return x
+
+
+def unpatchify(x: torch.Tensor, patch_size: int):
+ """
+ Unpatchify a tensor.
+
+ Args:
+ x (torch.Tensor): (N, C, *spatial) tensor
+ patch_size (int): Patch size
+ """
+ DIM = x.dim() - 2
+ assert x.shape[1] % (patch_size ** DIM) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}"
+
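+    # Inverse of patchify: pull the patch_size**DIM factor out of the channels and interleave
+    # it back into the spatial dims: (N, C, *S) -> (N, C // patch_size**DIM, *(S * patch_size)).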
+ x = x.reshape(x.shape[0], x.shape[1] // (patch_size ** DIM), *([patch_size] * DIM), *(x.shape[-DIM:]))
+ x = x.permute(0, 1, *(sum([[2 + DIM + i, 2 + i] for i in range(DIM)], [])))
+ x = x.reshape(x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)])
+ return x
diff --git a/direct3d_s2/modules/transformer/__init__.py b/direct3d_s2/modules/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b08b0d4e5bc24060a2cdc8df75d06dce122972bd
--- /dev/null
+++ b/direct3d_s2/modules/transformer/__init__.py
@@ -0,0 +1,2 @@
+from .blocks import *
+from .modulated import *
\ No newline at end of file
diff --git a/direct3d_s2/modules/transformer/__pycache__/__init__.cpython-310.pyc b/direct3d_s2/modules/transformer/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e8f2a4e028c3083fbd5787b393d7f1027d27b9f2
Binary files /dev/null and b/direct3d_s2/modules/transformer/__pycache__/__init__.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/transformer/__pycache__/blocks.cpython-310.pyc b/direct3d_s2/modules/transformer/__pycache__/blocks.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..38ad84b709f92193cecb88fba9300b54d6d2848c
Binary files /dev/null and b/direct3d_s2/modules/transformer/__pycache__/blocks.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/transformer/__pycache__/modulated.cpython-310.pyc b/direct3d_s2/modules/transformer/__pycache__/modulated.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..38d12b1952b57c97234e214a5385a947f7ed0a3d
Binary files /dev/null and b/direct3d_s2/modules/transformer/__pycache__/modulated.cpython-310.pyc differ
diff --git a/direct3d_s2/modules/transformer/blocks.py b/direct3d_s2/modules/transformer/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..605ae33bc44276f45f73789f547dc756ac3999da
--- /dev/null
+++ b/direct3d_s2/modules/transformer/blocks.py
@@ -0,0 +1,184 @@
+from typing import *
+import torch
+import torch.nn as nn
+from ..attention import MultiHeadAttention
+from ..norm import LayerNorm32
+
+
+class AbsolutePositionEmbedder(nn.Module):
+ """
+ Embeds spatial positions into vector representations.
+ """
+ def __init__(self, channels: int, in_channels: int = 3):
+ super().__init__()
+ self.channels = channels
+ self.in_channels = in_channels
+ self.freq_dim = channels // in_channels // 2
+ self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
+ self.freqs = 1.0 / (10000 ** self.freqs)
+
+ def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Create sinusoidal position embeddings.
+
+ Args:
+ x: a 1-D Tensor of N indices
+
+ Returns:
+ an (N, D) Tensor of positional embeddings.
+ """
+ self.freqs = self.freqs.to(x.device)
+ out = torch.outer(x, self.freqs)
+ out = torch.cat([torch.sin(out), torch.cos(out)], dim=-1)
+ return out
+
+ def forward(self, x: torch.Tensor, factor: float = None) -> torch.Tensor:
+ """
+ Args:
+ x (torch.Tensor): (N, D) tensor of spatial positions
+ """
+ if factor is not None:
+ x = x * factor
+ N, D = x.shape
+ assert D == self.in_channels, "Input dimension must match number of input channels"
+ embed = self._sin_cos_embedding(x.reshape(-1))
+ embed = embed.reshape(N, -1)
+ if embed.shape[1] < self.channels:
+ embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1], device=embed.device)], dim=-1)
+ return embed
+
+
+class FeedForwardNet(nn.Module):
+ def __init__(self, channels: int, mlp_ratio: float = 4.0):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(channels, int(channels * mlp_ratio)),
+ nn.GELU(approximate="tanh"),
+ nn.Linear(int(channels * mlp_ratio), channels),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.mlp(x)
+
+
+class TransformerBlock(nn.Module):
+ """
+ Transformer block (MSA + FFN).
+ """
+ def __init__(
+ self,
+ channels: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ attn_mode: Literal["full", "windowed"] = "full",
+ window_size: Optional[int] = None,
+ shift_window: Optional[int] = None,
+ use_checkpoint: bool = False,
+ use_rope: bool = False,
+ qk_rms_norm: bool = False,
+ qkv_bias: bool = True,
+ ln_affine: bool = False,
+ ):
+ super().__init__()
+ self.use_checkpoint = use_checkpoint
+ self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+ self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+ self.attn = MultiHeadAttention(
+ channels,
+ num_heads=num_heads,
+ attn_mode=attn_mode,
+ window_size=window_size,
+ shift_window=shift_window,
+ qkv_bias=qkv_bias,
+ use_rope=use_rope,
+ qk_rms_norm=qk_rms_norm,
+ )
+ self.mlp = FeedForwardNet(
+ channels,
+ mlp_ratio=mlp_ratio,
+ )
+
+ def _forward(self, x: torch.Tensor) -> torch.Tensor:
+ h = self.norm1(x)
+ h = self.attn(h)
+ x = x + h
+ h = self.norm2(x)
+ h = self.mlp(h)
+ x = x + h
+ return x
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
+ else:
+ return self._forward(x)
+
+
+class TransformerCrossBlock(nn.Module):
+ """
+ Transformer cross-attention block (MSA + MCA + FFN).
+ """
+ def __init__(
+ self,
+ channels: int,
+ ctx_channels: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ attn_mode: Literal["full", "windowed"] = "full",
+ window_size: Optional[int] = None,
+ shift_window: Optional[Tuple[int, int, int]] = None,
+ use_checkpoint: bool = False,
+ use_rope: bool = False,
+ qk_rms_norm: bool = False,
+ qk_rms_norm_cross: bool = False,
+ qkv_bias: bool = True,
+ ln_affine: bool = False,
+ ):
+ super().__init__()
+ self.use_checkpoint = use_checkpoint
+ self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+ self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+ self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+ self.self_attn = MultiHeadAttention(
+ channels,
+ num_heads=num_heads,
+ type="self",
+ attn_mode=attn_mode,
+ window_size=window_size,
+ shift_window=shift_window,
+ qkv_bias=qkv_bias,
+ use_rope=use_rope,
+ qk_rms_norm=qk_rms_norm,
+ )
+ self.cross_attn = MultiHeadAttention(
+ channels,
+ ctx_channels=ctx_channels,
+ num_heads=num_heads,
+ type="cross",
+ attn_mode="full",
+ qkv_bias=qkv_bias,
+ qk_rms_norm=qk_rms_norm_cross,
+ )
+ self.mlp = FeedForwardNet(
+ channels,
+ mlp_ratio=mlp_ratio,
+ )
+
+ def _forward(self, x: torch.Tensor, context: torch.Tensor):
+ h = self.norm1(x)
+ h = self.self_attn(h)
+ x = x + h
+ h = self.norm2(x)
+ h = self.cross_attn(h, context)
+ x = x + h
+ h = self.norm3(x)
+ h = self.mlp(h)
+ x = x + h
+ return x
+
+ def forward(self, x: torch.Tensor, context: torch.Tensor):
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
+ else:
+ return self._forward(x, context)
+
\ No newline at end of file
diff --git a/direct3d_s2/modules/transformer/modulated.py b/direct3d_s2/modules/transformer/modulated.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4aeca0689e68f656b08f7aa822b7be839aa727d
--- /dev/null
+++ b/direct3d_s2/modules/transformer/modulated.py
@@ -0,0 +1,157 @@
+from typing import *
+import torch
+import torch.nn as nn
+from ..attention import MultiHeadAttention
+from ..norm import LayerNorm32
+from .blocks import FeedForwardNet
+
+
+class ModulatedTransformerBlock(nn.Module):
+ """
+ Transformer block (MSA + FFN) with adaptive layer norm conditioning.
+ """
+ def __init__(
+ self,
+ channels: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ attn_mode: Literal["full", "windowed"] = "full",
+ window_size: Optional[int] = None,
+ shift_window: Optional[Tuple[int, int, int]] = None,
+ use_checkpoint: bool = False,
+ use_rope: bool = False,
+ qk_rms_norm: bool = False,
+ qkv_bias: bool = True,
+ share_mod: bool = False,
+ ):
+ super().__init__()
+ self.use_checkpoint = use_checkpoint
+ self.share_mod = share_mod
+ self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
+ self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
+ self.attn = MultiHeadAttention(
+ channels,
+ num_heads=num_heads,
+ attn_mode=attn_mode,
+ window_size=window_size,
+ shift_window=shift_window,
+ qkv_bias=qkv_bias,
+ use_rope=use_rope,
+ qk_rms_norm=qk_rms_norm,
+ )
+ self.mlp = FeedForwardNet(
+ channels,
+ mlp_ratio=mlp_ratio,
+ )
+ if not share_mod:
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(channels, 6 * channels, bias=True)
+ )
+
+ def _forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
+ if self.share_mod:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
+ else:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
+ h = self.norm1(x)
+ h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
+ h = self.attn(h)
+ h = h * gate_msa.unsqueeze(1)
+ x = x + h
+ h = self.norm2(x)
+ h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
+ h = self.mlp(h)
+ h = h * gate_mlp.unsqueeze(1)
+ x = x + h
+ return x
+
+ def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
+ else:
+ return self._forward(x, mod)
+
+
+class ModulatedTransformerCrossBlock(nn.Module):
+ """
+ Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
+ """
+ def __init__(
+ self,
+ channels: int,
+ ctx_channels: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ attn_mode: Literal["full", "windowed"] = "full",
+ window_size: Optional[int] = None,
+ shift_window: Optional[Tuple[int, int, int]] = None,
+ use_checkpoint: bool = False,
+ use_rope: bool = False,
+ qk_rms_norm: bool = False,
+ qk_rms_norm_cross: bool = False,
+ qkv_bias: bool = True,
+ share_mod: bool = False,
+ ):
+ super().__init__()
+ self.use_checkpoint = use_checkpoint
+ self.share_mod = share_mod
+ self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
+ self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
+ self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
+ self.self_attn = MultiHeadAttention(
+ channels,
+ num_heads=num_heads,
+ type="self",
+ attn_mode=attn_mode,
+ window_size=window_size,
+ shift_window=shift_window,
+ qkv_bias=qkv_bias,
+ use_rope=use_rope,
+ qk_rms_norm=qk_rms_norm,
+ )
+ self.cross_attn = MultiHeadAttention(
+ channels,
+ ctx_channels=ctx_channels,
+ num_heads=num_heads,
+ type="cross",
+ attn_mode="full",
+ qkv_bias=qkv_bias,
+ qk_rms_norm=qk_rms_norm_cross,
+ )
+ self.mlp = FeedForwardNet(
+ channels,
+ mlp_ratio=mlp_ratio,
+ )
+ if not share_mod:
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(channels, 6 * channels, bias=True)
+ )
+
+ def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor):
+ if self.share_mod:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
+ else:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
+ h = self.norm1(x)
+ h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
+ h = self.self_attn(h)
+ h = h * gate_msa.unsqueeze(1)
+ x = x + h
+ h = self.norm2(x)
+ h = self.cross_attn(h, context)
+ x = x + h
+ h = self.norm3(x)
+ h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
+ h = self.mlp(h)
+ h = h * gate_mlp.unsqueeze(1)
+ x = x + h
+ return x
+
+ def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor):
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
+ else:
+ return self._forward(x, mod, context)
+
\ No newline at end of file
diff --git a/direct3d_s2/modules/utils.py b/direct3d_s2/modules/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0afb1b6c767aa2ad00bad96649fb30315e696ea
--- /dev/null
+++ b/direct3d_s2/modules/utils.py
@@ -0,0 +1,54 @@
+import torch.nn as nn
+from ..modules import sparse as sp
+
+FP16_MODULES = (
+ nn.Conv1d,
+ nn.Conv2d,
+ nn.Conv3d,
+ nn.ConvTranspose1d,
+ nn.ConvTranspose2d,
+ nn.ConvTranspose3d,
+ nn.Linear,
+ sp.SparseConv3d,
+ sp.SparseInverseConv3d,
+ sp.SparseLinear,
+)
+
+def convert_module_to_f16(l):
+ """
+ Convert primitive modules to float16.
+ """
+ if isinstance(l, FP16_MODULES):
+ for p in l.parameters():
+ p.data = p.data.half()
+
+
+def convert_module_to_f32(l):
+ """
+ Convert primitive modules to float32, undoing convert_module_to_f16().
+ """
+ if isinstance(l, FP16_MODULES):
+ for p in l.parameters():
+ p.data = p.data.float()
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def modulate(x, shift, scale):
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
diff --git a/direct3d_s2/pipeline.py b/direct3d_s2/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce0e9e7c64329b4092654254235f38f011b05c29
--- /dev/null
+++ b/direct3d_s2/pipeline.py
@@ -0,0 +1,357 @@
+import os
+import torch
+import numpy as np
+from typing import Any
+from PIL import Image
+from tqdm import tqdm
+from omegaconf import OmegaConf
+from huggingface_hub import hf_hub_download
+from typing import Union, List, Optional
+from direct3d_s2.modules import sparse as sp
+from direct3d_s2.utils import (
+ instantiate_from_config,
+ preprocess_image,
+ sort_block,
+ extract_tokens_and_coords,
+ normalize_mesh,
+ mesh2index,
+)
+
+
+class Direct3DS2Pipeline(object):
+
+ def __init__(self,
+ dense_vae,
+ dense_dit,
+ sparse_vae_512,
+ sparse_dit_512,
+ sparse_vae_1024,
+ sparse_dit_1024,
+ refiner,
+ dense_image_encoder,
+ sparse_image_encoder,
+ dense_scheduler,
+ sparse_scheduler_512,
+ sparse_scheduler_1024,
+ dtype=torch.float16,
+ ):
+ self.dense_vae = dense_vae
+ self.dense_dit = dense_dit
+ self.sparse_vae_512 = sparse_vae_512
+ self.sparse_dit_512 = sparse_dit_512
+ self.sparse_vae_1024 = sparse_vae_1024
+ self.sparse_dit_1024 = sparse_dit_1024
+ self.refiner = refiner
+ self.dense_image_encoder = dense_image_encoder
+ self.sparse_image_encoder = sparse_image_encoder
+ self.dense_scheduler = dense_scheduler
+ self.sparse_scheduler_512 = sparse_scheduler_512
+ self.sparse_scheduler_1024 = sparse_scheduler_1024
+ self.dtype = dtype
+
+ def to(self, device):
+ self.device = torch.device(device)
+ self.dense_vae.to(device)
+ self.dense_dit.to(device)
+ self.sparse_vae_512.to(device)
+ self.sparse_dit_512.to(device)
+ self.sparse_vae_1024.to(device)
+ self.sparse_dit_1024.to(device)
+ self.refiner.to(device)
+ self.dense_image_encoder.to(device)
+ self.sparse_image_encoder.to(device)
+
+ @classmethod
+ def from_pretrained(cls, pipeline_path, subfolder="direct3d-s2-v-1-1"):
+
+ if os.path.isdir(pipeline_path):
+ config_path = os.path.join(pipeline_path, 'config.yaml')
+ model_dense_path = os.path.join(pipeline_path, 'model_dense.ckpt')
+ model_sparse_512_path = os.path.join(pipeline_path, 'model_sparse_512.ckpt')
+ model_sparse_1024_path = os.path.join(pipeline_path, 'model_sparse_1024.ckpt')
+ model_refiner_path = os.path.join(pipeline_path, 'model_refiner.ckpt')
+ else:
+ config_path = hf_hub_download(
+ repo_id=pipeline_path,
+ subfolder=subfolder,
+ filename="config.yaml",
+ repo_type="model"
+ )
+ model_dense_path = hf_hub_download(
+ repo_id=pipeline_path,
+ subfolder=subfolder,
+ filename="model_dense.ckpt",
+ repo_type="model"
+ )
+ model_sparse_512_path = hf_hub_download(
+ repo_id=pipeline_path,
+ subfolder=subfolder,
+ filename="model_sparse_512.ckpt",
+ repo_type="model"
+ )
+ model_sparse_1024_path = hf_hub_download(
+ repo_id=pipeline_path,
+ subfolder=subfolder,
+ filename="model_sparse_1024.ckpt",
+ repo_type="model"
+ )
+ model_refiner_path = hf_hub_download(
+ repo_id=pipeline_path,
+ subfolder=subfolder,
+ filename="model_refiner.ckpt",
+ repo_type="model"
+ )
+
+ cfg = OmegaConf.load(config_path)
+
+ state_dict_dense = torch.load(model_dense_path, map_location='cpu', weights_only=True)
+ dense_vae = instantiate_from_config(cfg.dense_vae)
+ dense_vae.load_state_dict(state_dict_dense["vae"], strict=True)
+ dense_vae.eval()
+ dense_dit = instantiate_from_config(cfg.dense_dit)
+ dense_dit.load_state_dict(state_dict_dense["dit"], strict=True)
+ dense_dit.eval()
+
+ state_dict_sparse_512 = torch.load(model_sparse_512_path, map_location='cpu', weights_only=True)
+ sparse_vae_512 = instantiate_from_config(cfg.sparse_vae_512)
+ sparse_vae_512.load_state_dict(state_dict_sparse_512["vae"], strict=True)
+ sparse_vae_512.eval()
+ sparse_dit_512 = instantiate_from_config(cfg.sparse_dit_512)
+ sparse_dit_512.load_state_dict(state_dict_sparse_512["dit"], strict=True)
+ sparse_dit_512.eval()
+
+ state_dict_sparse_1024 = torch.load(model_sparse_1024_path, map_location='cpu', weights_only=True)
+ sparse_vae_1024 = instantiate_from_config(cfg.sparse_vae_1024)
+ sparse_vae_1024.load_state_dict(state_dict_sparse_1024["vae"], strict=True)
+ sparse_vae_1024.eval()
+ sparse_dit_1024 = instantiate_from_config(cfg.sparse_dit_1024)
+ sparse_dit_1024.load_state_dict(state_dict_sparse_1024["dit"], strict=True)
+ sparse_dit_1024.eval()
+
+ state_dict_refiner = torch.load(model_refiner_path, map_location='cpu', weights_only=True)
+ refiner = instantiate_from_config(cfg.refiner)
+ refiner.load_state_dict(state_dict_refiner["refiner"], strict=True)
+ refiner.eval()
+
+ dense_image_encoder = instantiate_from_config(cfg.dense_image_encoder)
+ sparse_image_encoder = instantiate_from_config(cfg.sparse_image_encoder)
+
+ dense_scheduler = instantiate_from_config(cfg.dense_scheduler)
+ sparse_scheduler_512 = instantiate_from_config(cfg.sparse_scheduler_512)
+ sparse_scheduler_1024 = instantiate_from_config(cfg.sparse_scheduler_1024)
+
+ return cls(
+ dense_vae=dense_vae,
+ dense_dit=dense_dit,
+ sparse_vae_512=sparse_vae_512,
+ sparse_dit_512=sparse_dit_512,
+ sparse_vae_1024=sparse_vae_1024,
+ sparse_dit_1024=sparse_dit_1024,
+ dense_image_encoder=dense_image_encoder,
+ sparse_image_encoder=sparse_image_encoder,
+ dense_scheduler=dense_scheduler,
+ sparse_scheduler_512=sparse_scheduler_512,
+ sparse_scheduler_1024=sparse_scheduler_1024,
+ refiner=refiner,
+ )
+
+ def preprocess(self, image):
+ if image.mode == 'RGBA':
+ image = np.array(image)
+ else:
+ if getattr(self, 'birefnet_model', None) is None:
+ from direct3d_s2.utils import BiRefNet
+ self.birefnet_model = BiRefNet(self.device)
+ image = self.birefnet_model.run(image)
+ image = preprocess_image(image)
+ return image
+
+ def prepare_image(self, image: Union[str, List[str], Image.Image, List[Image.Image]]):
+ if not isinstance(image, list):
+ image = [image]
+ if isinstance(image[0], str):
+ image = [Image.open(img) for img in image]
+ image = [self.preprocess(img) for img in image]
+ image = torch.stack([img for img in image]).to(self.device)
+ return image
+
+ def encode_image(self, image: torch.Tensor, conditioner: Any,
+ do_classifier_free_guidance: bool = True, use_mask: bool = False):
+ if use_mask:
+ cond = conditioner(image[:, :3], image[:, 3:])
+ else:
+ cond = conditioner(image[:, :3])
+
+ if isinstance(cond, tuple):
+ cond, cond_mask = cond
+ cond, cond_coords = extract_tokens_and_coords(cond, cond_mask)
+ else:
+ cond_mask, cond_coords = None, None
+
+ if do_classifier_free_guidance:
+ uncond = torch.zeros_like(cond)
+ else:
+ uncond = None
+
+ if cond_coords is not None:
+ cond = sp.SparseTensor(cond, cond_coords.int())
+ if uncond is not None:
+ uncond = sp.SparseTensor(uncond, cond_coords.int())
+
+ return cond, uncond
+
+ def inference(
+ self,
+ image,
+ vae,
+ dit,
+ conditioner,
+ scheduler,
+ num_inference_steps: int = 30,
+        guidance_scale: float = 7.0,
+        generator: Optional[torch.Generator] = None,
+        latent_index: Optional[torch.Tensor] = None,
+        mode: str = 'dense',  # 'dense', 'sparse512' or 'sparse1024'
+ remove_interior: bool = False,
+ mc_threshold: float = 0.02):
+
+ do_classifier_free_guidance = guidance_scale > 0
+ if mode == 'dense':
+ sparse_conditions = False
+ else:
+ sparse_conditions = dit.sparse_conditions
+ cond, uncond = self.encode_image(image, conditioner,
+ do_classifier_free_guidance, sparse_conditions)
+ batch_size = cond.shape[0]
+
+ if mode == 'dense':
+ latent_shape = (batch_size, *dit.latent_shape)
+ else:
+ latent_shape = (len(latent_index), dit.out_channels)
+ latents = torch.randn(latent_shape, dtype=self.dtype, device=self.device, generator=generator)
+
+ scheduler.set_timesteps(num_inference_steps, device=self.device)
+ timesteps = scheduler.timesteps
+
+ extra_step_kwargs = {
+ "generator": generator
+ }
+
+ for i, t in enumerate(tqdm(timesteps, desc=f"{mode} Sampling:")):
+ latent_model_input = latents
+ timestep_tensor = torch.tensor([t], dtype=latent_model_input.dtype, device=self.device)
+
+ if mode == 'dense':
+ x_input = latent_model_input
+ elif mode in ['sparse512', 'sparse1024']:
+ x_input = sp.SparseTensor(latent_model_input, latent_index.int())
+
+ diffusion_inputs = {
+ "x": x_input,
+ "t": timestep_tensor,
+ "cond": cond,
+ }
+
+ noise_pred_cond = dit(**diffusion_inputs)
+ if mode != 'dense':
+ noise_pred_cond = noise_pred_cond.feats
+
+ if do_classifier_free_guidance:
+ diffusion_inputs["cond"] = uncond
+ noise_pred_uncond = dit(**diffusion_inputs)
+ if mode != 'dense':
+ noise_pred_uncond = noise_pred_uncond.feats
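+                # Classifier-free guidance: extrapolate from the unconditional prediction
+                # toward the conditional one by guidance_scale.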
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
+ else:
+ noise_pred = noise_pred_cond
+
+ latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
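+        # Undo the VAE latent normalization (scale/shift) before decoding.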
+ latents = 1. / vae.latents_scale * latents + vae.latents_shift
+
+ if mode != 'dense':
+ latents = sp.SparseTensor(latents, latent_index.int())
+
+ decoder_inputs = {
+ "latents": latents,
+ "mc_threshold": mc_threshold,
+ }
+ if mode == 'dense':
+ decoder_inputs['return_index'] = True
+ elif remove_interior:
+ decoder_inputs['return_feat'] = True
+ if mode == 'sparse1024':
+ decoder_inputs['voxel_resolution'] = 1024
+
+ outputs = vae.decode_mesh(**decoder_inputs)
+
+ if remove_interior:
+ del latents, noise_pred, noise_pred_cond, noise_pred_uncond, x_input, cond, uncond
+ torch.cuda.empty_cache()
+ outputs = self.refiner.run(*outputs, mc_threshold=mc_threshold*2.0)
+
+ return outputs
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ image: Union[str, List[str], Image.Image, List[Image.Image]] = None,
+ sdf_resolution: int = 1024,
+ dense_sampler_params: dict = {'num_inference_steps': 50, 'guidance_scale': 7.0},
+ sparse_512_sampler_params: dict = {'num_inference_steps': 30, 'guidance_scale': 7.0},
+ sparse_1024_sampler_params: dict = {'num_inference_steps': 15, 'guidance_scale': 7.0},
+ generator: Optional[torch.Generator] = None,
+ remesh: bool = False,
+ simplify_ratio: float = 0.95,
+ mc_threshold: float = 0.2):
+
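+        # Coarse-to-fine cascade: a dense diffusion stage predicts which latent blocks are
+        # occupied, a sparse 512^3 stage generates the mesh, and when sdf_resolution == 1024
+        # the mesh is re-voxelized and passed through a sparse 1024^3 stage.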
+ image = self.prepare_image(image)
+
+ latent_index = self.inference(image, self.dense_vae, self.dense_dit, self.dense_image_encoder,
+ self.dense_scheduler, generator=generator, mode='dense', mc_threshold=0.1, **dense_sampler_params)[0]
+
+ latent_index = sort_block(latent_index, self.sparse_dit_512.selection_block_size)
+
+ torch.cuda.empty_cache()
+
+ if sdf_resolution == 512:
+ remove_interior = False
+ else:
+ remove_interior = True
+
+ mesh = self.inference(image, self.sparse_vae_512, self.sparse_dit_512,
+ self.sparse_image_encoder, self.sparse_scheduler_512,
+ generator=generator, mode='sparse512',
+ mc_threshold=mc_threshold, latent_index=latent_index,
+ remove_interior=remove_interior, **sparse_512_sampler_params)[0]
+
+ if sdf_resolution == 1024:
+ del latent_index
+ torch.cuda.empty_cache()
+ mesh = normalize_mesh(mesh)
+ latent_index = mesh2index(mesh, size=1024, factor=8)
+ latent_index = sort_block(latent_index, self.sparse_dit_1024.selection_block_size)
+ print(f"number of latent tokens: {len(latent_index)}")
+
+ mesh = self.inference(image, self.sparse_vae_1024, self.sparse_dit_1024,
+ self.sparse_image_encoder, self.sparse_scheduler_1024,
+ generator=generator, mode='sparse1024',
+ mc_threshold=mc_threshold, latent_index=latent_index,
+ **sparse_1024_sampler_params)[0]
+
+ if remesh:
+ import trimesh
+ from direct3d_s2.utils import postprocess_mesh
+ filled_mesh = postprocess_mesh(
+ vertices=mesh.vertices,
+ faces=mesh.faces,
+ simplify=True,
+ simplify_ratio=simplify_ratio,
+ verbose=True,
+ )
+ mesh = trimesh.Trimesh(filled_mesh[0], filled_mesh[1])
+
+ outputs = {"mesh": mesh}
+
+ return outputs
+
\ No newline at end of file
diff --git a/direct3d_s2/utils/__init__.py b/direct3d_s2/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e46f0ec24c6f4da95eb3a9b8df5a28ae1377c036
--- /dev/null
+++ b/direct3d_s2/utils/__init__.py
@@ -0,0 +1,6 @@
+from .util import instantiate_from_config, get_obj_from_str
+from .image import preprocess_image
+from .rembg import BiRefNet
+from .sparse import sort_block, extract_tokens_and_coords
+from .mesh import mesh2index, normalize_mesh
+from .fill_hole import postprocess_mesh
\ No newline at end of file
diff --git a/direct3d_s2/utils/__pycache__/__init__.cpython-310.pyc b/direct3d_s2/utils/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7ce8ae171e4f700ec661f50952122d1fc6241a85
Binary files /dev/null and b/direct3d_s2/utils/__pycache__/__init__.cpython-310.pyc differ
diff --git a/direct3d_s2/utils/__pycache__/fill_hole.cpython-310.pyc b/direct3d_s2/utils/__pycache__/fill_hole.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bf3af72c632aac5aa27083aba8fbd70d58482cd6
Binary files /dev/null and b/direct3d_s2/utils/__pycache__/fill_hole.cpython-310.pyc differ
diff --git a/direct3d_s2/utils/__pycache__/fix_hole.cpython-310.pyc b/direct3d_s2/utils/__pycache__/fix_hole.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e96698412e82122026ca3c9a9d9a0c935ed3b30c
Binary files /dev/null and b/direct3d_s2/utils/__pycache__/fix_hole.cpython-310.pyc differ
diff --git a/direct3d_s2/utils/__pycache__/image.cpython-310.pyc b/direct3d_s2/utils/__pycache__/image.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..053a83a090808081b97dd683ac8a7a3be2766a8c
Binary files /dev/null and b/direct3d_s2/utils/__pycache__/image.cpython-310.pyc differ
diff --git a/direct3d_s2/utils/__pycache__/mesh.cpython-310.pyc b/direct3d_s2/utils/__pycache__/mesh.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d3f4aa988d21ca631a4b0587892bd50afae7c448
Binary files /dev/null and b/direct3d_s2/utils/__pycache__/mesh.cpython-310.pyc differ
diff --git a/direct3d_s2/utils/__pycache__/rembg.cpython-310.pyc b/direct3d_s2/utils/__pycache__/rembg.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5c89319caf9b74ccacd6ec6cdd352b85efc07ab8
Binary files /dev/null and b/direct3d_s2/utils/__pycache__/rembg.cpython-310.pyc differ
diff --git a/direct3d_s2/utils/__pycache__/sparse.cpython-310.pyc b/direct3d_s2/utils/__pycache__/sparse.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..390eca94ae38c2c4a39b84bb1b2d977b55eb8a50
Binary files /dev/null and b/direct3d_s2/utils/__pycache__/sparse.cpython-310.pyc differ
diff --git a/direct3d_s2/utils/__pycache__/util.cpython-310.pyc b/direct3d_s2/utils/__pycache__/util.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..66a7bb9d4838016184c6fe6b7aef44615d442415
Binary files /dev/null and b/direct3d_s2/utils/__pycache__/util.cpython-310.pyc differ
diff --git a/direct3d_s2/utils/fill_hole.py b/direct3d_s2/utils/fill_hole.py
new file mode 100644
index 0000000000000000000000000000000000000000..48a56d979a4b367ba1cddc834c7214e947b98adc
--- /dev/null
+++ b/direct3d_s2/utils/fill_hole.py
@@ -0,0 +1,272 @@
+import torch
+import numpy as np
+from tqdm import tqdm
+import utils3d
+from pymeshfix import _meshfix
+import igraph
+import pyvista as pv
+
+PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
+
+def radical_inverse(base, n):
+ val = 0
+ inv_base = 1.0 / base
+ inv_base_n = inv_base
+ while n > 0:
+ digit = n % base
+ val += digit * inv_base_n
+ n //= base
+ inv_base_n *= inv_base
+ return val
+
+def halton_sequence(dim, n):
+    return [radical_inverse(PRIMES[d], n) for d in range(dim)]
+
+def hammersley_sequence(dim, n, num_samples):
+ return [n / num_samples] + halton_sequence(dim - 1, n)
+
+def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False):
+ u, v = hammersley_sequence(2, n, num_samples)
+ u += offset[0] / num_samples
+ v += offset[1]
+ if remap:
+ u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
+ theta = np.arccos(1 - 2 * u) - np.pi / 2
+ phi = v * 2 * np.pi
+ return [phi, theta]
+
+@torch.no_grad()
+def _fill_holes(
+ verts,
+ faces,
+ max_hole_size=0.04,
+ max_hole_nbe=32,
+ resolution=128,
+ num_views=500,
+ debug=False,
+ verbose=False
+):
+ """
+ Rasterize a mesh from multiple views and remove invisible faces.
+ Also includes postprocessing to:
+    1. Remove connected components that have low visibility.
+    2. Run a mincut to remove interior faces that are connected to the outer surface only through small holes.
+
+ Args:
+ verts (torch.Tensor): Vertices of the mesh. Shape (V, 3).
+ faces (torch.Tensor): Faces of the mesh. Shape (F, 3).
+        max_hole_size (float): Maximum area of a hole to fill.
+        max_hole_nbe (int): Maximum number of boundary edges of a hole to fill.
+ resolution (int): Resolution of the rasterization.
+ num_views (int): Number of views to rasterize the mesh.
+ verbose (bool): Whether to print progress.
+ """
+ # Construct cameras
+ yaws = []
+ pitchs = []
+ for i in range(num_views):
+ y, p = sphere_hammersley_sequence(i, num_views)
+ yaws.append(y)
+ pitchs.append(p)
+ yaws = torch.tensor(yaws).cuda()
+ pitchs = torch.tensor(pitchs).cuda()
+ radius = 2.0
+ fov = torch.deg2rad(torch.tensor(40)).cuda()
+ projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3)
+ views = []
+ for (yaw, pitch) in zip(yaws, pitchs):
+ orig = torch.tensor([
+ torch.sin(yaw) * torch.cos(pitch),
+ torch.cos(yaw) * torch.cos(pitch),
+ torch.sin(pitch),
+ ]).cuda().float() * radius
+ view = utils3d.torch.view_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
+ views.append(view)
+ views = torch.stack(views, dim=0)
+
+ # Rasterize
+    visibility = torch.zeros(faces.shape[0], dtype=torch.int32, device=verts.device)
+ rastctx = utils3d.torch.RastContext(backend='cuda')
+ for i in tqdm(range(views.shape[0]), total=views.shape[0], disable=not verbose, desc='Rasterizing'):
+ view = views[i]
+ buffers = utils3d.torch.rasterize_triangle_faces(
+ rastctx, verts[None].float(), faces, resolution, resolution, view=view, projection=projection
+ )
+ face_id = buffers['face_id'][0][buffers['mask'][0] > 0.95] - 1
+ face_id = torch.unique(face_id).long()
+        visibility[face_id] += 1
+    visibility = visibility.float() / num_views
+
+ # Mincut
+ ## construct outer faces
+ edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces)
+ boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1)
+ connected_components = utils3d.torch.compute_connected_components(faces, edges, face2edge)
+ outer_face_indices = torch.zeros(faces.shape[0], dtype=torch.bool, device=faces.device)
+ for i in range(len(connected_components)):
+        outer_face_indices[connected_components[i]] = visibility[connected_components[i]] > min(max(visibility[connected_components[i]].quantile(0.75).item(), 0.25), 0.5)
+ outer_face_indices = outer_face_indices.nonzero().reshape(-1)
+
+ ## construct inner faces
+    inner_face_indices = torch.nonzero(visibility == 0).reshape(-1)
+ if verbose:
+ tqdm.write(f'Found {inner_face_indices.shape[0]} invisible faces')
+ if inner_face_indices.shape[0] == 0:
+ return verts, faces
+
+ ## Construct dual graph (faces as nodes, edges as edges)
+ dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(face2edge)
+ dual_edge2edge = edges[dual_edge2edge]
+ dual_edges_weights = torch.norm(verts[dual_edge2edge[:, 0]] - verts[dual_edge2edge[:, 1]], dim=1)
+ if verbose:
+ tqdm.write(f'Dual graph: {dual_edges.shape[0]} edges')
+
+ ## solve mincut problem
+ ### construct main graph
+ g = igraph.Graph()
+ g.add_vertices(faces.shape[0])
+ g.add_edges(dual_edges.cpu().numpy())
+ g.es['weight'] = dual_edges_weights.cpu().numpy()
+
+ ### source and target
+ g.add_vertex('s')
+ g.add_vertex('t')
+
+ ### connect invisible faces to source
+ g.add_edges([(f, 's') for f in inner_face_indices], attributes={'weight': torch.ones(inner_face_indices.shape[0], dtype=torch.float32).cpu().numpy()})
+
+ ### connect outer faces to target
+ g.add_edges([(f, 't') for f in outer_face_indices], attributes={'weight': torch.ones(outer_face_indices.shape[0], dtype=torch.float32).cpu().numpy()})
+
+ ### solve mincut
+ cut = g.mincut('s', 't', (np.array(g.es['weight']) * 1000).tolist())
+ remove_face_indices = torch.tensor([v for v in cut.partition[0] if v < faces.shape[0]], dtype=torch.long, device=faces.device)
+ if verbose:
+        tqdm.write('Mincut solved, checking the cut')
+
+ ### check if the cut is valid with each connected component
+ to_remove_cc = utils3d.torch.compute_connected_components(faces[remove_face_indices])
+ if debug:
+ tqdm.write(f'Number of connected components of the cut: {len(to_remove_cc)}')
+ valid_remove_cc = []
+ cutting_edges = []
+ for cc in to_remove_cc:
+ #### check if the connected component has low visibility
+        visibility_median = visibility[remove_face_indices[cc]].median()
+        if debug:
+            tqdm.write(f'visibility_median: {visibility_median}')
+        if visibility_median > 0.25:
+ continue
+
+        #### check if the cutting loop is small enough
+ cc_edge_indices, cc_edges_degree = torch.unique(face2edge[remove_face_indices[cc]], return_counts=True)
+ cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1]
+ cc_new_boundary_edge_indices = cc_boundary_edge_indices[~torch.isin(cc_boundary_edge_indices, boundary_edge_indices)]
+ if len(cc_new_boundary_edge_indices) > 0:
+ cc_new_boundary_edge_cc = utils3d.torch.compute_edge_connected_components(edges[cc_new_boundary_edge_indices])
+ cc_new_boundary_edges_cc_center = [verts[edges[cc_new_boundary_edge_indices[edge_cc]]].mean(dim=1).mean(dim=0) for edge_cc in cc_new_boundary_edge_cc]
+ cc_new_boundary_edges_cc_area = []
+ for i, edge_cc in enumerate(cc_new_boundary_edge_cc):
+ _e1 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 0]] - cc_new_boundary_edges_cc_center[i]
+ _e2 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 1]] - cc_new_boundary_edges_cc_center[i]
+ cc_new_boundary_edges_cc_area.append(torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum() * 0.5)
+ if debug:
+ cutting_edges.append(cc_new_boundary_edge_indices)
+ tqdm.write(f'Area of the cutting loop: {cc_new_boundary_edges_cc_area}')
+ if any([l > max_hole_size for l in cc_new_boundary_edges_cc_area]):
+ continue
+
+ valid_remove_cc.append(cc)
+
+ if debug:
+ face_v = verts[faces].mean(dim=1).cpu().numpy()
+ vis_dual_edges = dual_edges.cpu().numpy()
+ vis_colors = np.zeros((faces.shape[0], 3), dtype=np.uint8)
+ vis_colors[inner_face_indices.cpu().numpy()] = [0, 0, 255]
+ vis_colors[outer_face_indices.cpu().numpy()] = [0, 255, 0]
+ vis_colors[remove_face_indices.cpu().numpy()] = [255, 0, 255]
+ if len(valid_remove_cc) > 0:
+ vis_colors[remove_face_indices[torch.cat(valid_remove_cc)].cpu().numpy()] = [255, 0, 0]
+ utils3d.io.write_ply('dbg_dual.ply', face_v, edges=vis_dual_edges, vertex_colors=vis_colors)
+
+ vis_verts = verts.cpu().numpy()
+ vis_edges = edges[torch.cat(cutting_edges)].cpu().numpy()
+ utils3d.io.write_ply('dbg_cut.ply', vis_verts, edges=vis_edges)
+
+
+ if len(valid_remove_cc) > 0:
+ remove_face_indices = remove_face_indices[torch.cat(valid_remove_cc)]
+ mask = torch.ones(faces.shape[0], dtype=torch.bool, device=faces.device)
+ mask[remove_face_indices] = 0
+ faces = faces[mask]
+ faces, verts = utils3d.torch.remove_unreferenced_vertices(faces, verts)
+ if verbose:
+ tqdm.write(f'Removed {(~mask).sum()} faces by mincut')
+ else:
+ if verbose:
+ tqdm.write(f'Removed 0 faces by mincut')
+
+ mesh = _meshfix.PyTMesh()
+ mesh.load_array(verts.cpu().numpy(), faces.cpu().numpy())
+ mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True)
+ verts, faces = mesh.return_arrays()
+ verts, faces = torch.tensor(verts, device='cuda', dtype=torch.float32), torch.tensor(faces, device='cuda', dtype=torch.int32)
+
+ return verts, faces
+
+def postprocess_mesh(
+ vertices: np.array,
+ faces: np.array,
+ simplify: bool = False,
+ simplify_ratio: float = 0.9,
+ fill_holes: bool = False,
+ fill_holes_max_hole_size: float = 0.04,
+ fill_holes_max_hole_nbe: int = 32,
+ fill_holes_resolution: int = 1024,
+ fill_holes_num_views: int = 1000,
+ debug: bool = False,
+ verbose: bool = False,
+):
+ """
+ Postprocess a mesh by simplifying, removing invisible faces, and removing isolated pieces.
+
+ Args:
+ vertices (np.array): Vertices of the mesh. Shape (V, 3).
+ faces (np.array): Faces of the mesh. Shape (F, 3).
+ simplify (bool): Whether to simplify the mesh, using quadric edge collapse.
+        simplify_ratio (float): Fraction of faces to remove during simplification (passed to pyvista decimate as target_reduction).
+ fill_holes (bool): Whether to fill holes in the mesh.
+ fill_holes_max_hole_size (float): Maximum area of a hole to fill.
+ fill_holes_max_hole_nbe (int): Maximum number of boundary edges of a hole to fill.
+ fill_holes_resolution (int): Resolution of the rasterization.
+ fill_holes_num_views (int): Number of views to rasterize the mesh.
+ verbose (bool): Whether to print progress.
+ """
+
+ if verbose:
+ tqdm.write(f'Before postprocess: {vertices.shape[0]} vertices, {faces.shape[0]} faces')
+
+ # Simplify
+ if simplify and simplify_ratio > 0:
+ mesh = pv.PolyData(vertices, np.concatenate([np.full((faces.shape[0], 1), 3), faces], axis=1))
+ mesh = mesh.decimate(simplify_ratio, progress_bar=verbose)
+ vertices, faces = mesh.points, mesh.faces.reshape(-1, 4)[:, 1:]
+ if verbose:
+ tqdm.write(f'After decimate: {vertices.shape[0]} vertices, {faces.shape[0]} faces')
+
+ # Remove invisible faces
+ if fill_holes:
+ vertices, faces = torch.tensor(vertices).cuda(), torch.tensor(faces.astype(np.int32)).cuda()
+ vertices, faces = _fill_holes(
+ vertices, faces,
+ max_hole_size=fill_holes_max_hole_size,
+ max_hole_nbe=fill_holes_max_hole_nbe,
+ resolution=fill_holes_resolution,
+ num_views=fill_holes_num_views,
+ debug=debug,
+ verbose=verbose,
+ )
+ vertices, faces = vertices.cpu().numpy(), faces.cpu().numpy()
+ if verbose:
+ tqdm.write(f'After remove invisible faces: {vertices.shape[0]} vertices, {faces.shape[0]} faces')
+
+ return vertices, faces
\ No newline at end of file
diff --git a/direct3d_s2/utils/image.py b/direct3d_s2/utils/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2094ffa0f71d3c8d44df26cfde91371bffbad74
--- /dev/null
+++ b/direct3d_s2/utils/image.py
@@ -0,0 +1,109 @@
+import numpy as np
+from PIL import Image
+import torch
+import random
+from torchvision import transforms
+import torchvision.transforms.functional as TF
+
+
+def apply_joint_transforms(rgb, mask, img_size, img_aug=True, test=True):
+ if test:
+ extra_pad = 16
+ else:
+ extra_pad = random.randint(0, 32)
+ W_img, H_img = rgb.size[:2]
+ max_HW = max(H_img, W_img)
+ top_pad = (max_HW - H_img) // 2
+ bottom_pad = max_HW - H_img - top_pad
+ left_pad = (max_HW - W_img) // 2
+ right_pad = max_HW - W_img - left_pad
+
+ # 1. padding
+ rgb = TF.pad(rgb, (left_pad, top_pad, right_pad, bottom_pad), fill=255)
+ mask = TF.pad(mask, (left_pad, top_pad, right_pad, bottom_pad), fill=0)
+
+ if img_aug and (not test):
+ # 2. random rotate
+ if random.random() < 0.1:
+ angle = random.uniform(-10, 10)
+ rgb = TF.rotate(rgb, angle, fill=255)
+ mask = TF.rotate(mask, angle, fill=0)
+
+ # 3. random crop
+ if random.random() < 0.1:
+ crop_ratio = random.uniform(0.9, 1.0)
+ crop_size = int(max_HW * crop_ratio)
+ i, j, h, w = transforms.RandomCrop.get_params(rgb, (crop_size, crop_size))
+ rgb = TF.crop(rgb, i, j, h, w)
+ mask = TF.crop(mask, i, j, h, w)
+
+ # 4. resize
+ target_size = (img_size, img_size)
+ rgb = TF.resize(rgb, target_size, interpolation=TF.InterpolationMode.BILINEAR)
+ mask = TF.resize(mask, target_size, interpolation=TF.InterpolationMode.NEAREST)
+
+ # 5. extra padding
+ rgb = TF.pad(rgb, extra_pad, fill=255)
+ mask = TF.pad(mask, extra_pad, fill=0)
+ rgb = TF.resize(rgb, target_size, interpolation=TF.InterpolationMode.BILINEAR)
+ mask = TF.resize(mask, target_size, interpolation=TF.InterpolationMode.NEAREST)
+
+ # to tensor
+ rgb_tensor = TF.to_tensor(rgb)
+ mask_tensor = TF.to_tensor(mask)
+
+ return rgb_tensor, mask_tensor
+
+def crop_recenter(image_no_bg, threshold=100):
+    image_no_bg_np = np.array(image_no_bg)
+    mask = (image_no_bg_np[..., -1]).astype(np.uint8)
+    mask_bin = mask > threshold
+
+ H, W = image_no_bg_np.shape[:2]
+
+    valid_pixels = mask_bin.astype(np.float32).nonzero()  # tuple of (row indices, col indices)
+    if np.sum(mask_bin) < (H * W) * 0.001:
+        min_h = 0
+        max_h = H - 1
+        min_w = 0
+        max_w = W - 1
+    else:
+        min_h, max_h = valid_pixels[0].min(), valid_pixels[0].max()
+        min_w, max_w = valid_pixels[1].min(), valid_pixels[1].max()
+
+ if min_h < 0:
+ min_h = 0
+ if min_w < 0:
+ min_w = 0
+ if max_h > H:
+ max_h = H
+ if max_w > W:
+ max_w = W
+
+ image_no_bg_np = image_no_bg_np[min_h:max_h+1, min_w:max_w+1]
+ return image_no_bg_np
+
+def preprocess_image(img):
+
+ if isinstance(img, str):
+ img = Image.open(img)
+ img = np.array(img)
+ elif isinstance(img, Image.Image):
+ img = np.array(img)
+
+ if img.shape[-1] == 3:
+ mask = np.ones_like(img[..., 0:1])
+ img = np.concatenate([img, mask], axis=-1)
+
+    img = crop_recenter(img, threshold=0) / 255.
+
+ mask = img[..., 3]
+ img = img[..., :3] * img[..., 3:] + (1 - img[..., 3:])
+ img = Image.fromarray((img * 255).astype(np.uint8))
+ mask = Image.fromarray((mask * 255).astype(np.uint8))
+
+ img, mask = apply_joint_transforms(img, mask, img_size=518,
+ img_aug=False, test=True)
+ img = torch.cat([img, mask], dim=0)
+ return img
+
\ No newline at end of file
diff --git a/direct3d_s2/utils/mesh.py b/direct3d_s2/utils/mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..321172edfb7814e63222d2b44ff9f9ba2365323f
--- /dev/null
+++ b/direct3d_s2/utils/mesh.py
@@ -0,0 +1,34 @@
+import torch
+import numpy as np
+import udf_ext
+
+
+def compute_valid_udf(vertices, faces, dim=512, threshold=8.0):
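+    # The CUDA kernel writes distances into `udf` as integers scaled by 1e7; voxels it
+    # never touches keep the 10000000 sentinel (1.0 after the final division by 1e7).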
+ if not faces.is_cuda or not vertices.is_cuda:
+        raise ValueError("Both vertices and faces must be CUDA tensors")
+    udf = torch.zeros(dim**3, device=vertices.device).int() + 10000000
+ n_faces = faces.shape[0]
+ udf_ext.compute_valid_udf(vertices, faces, udf, n_faces, dim, threshold)
+ return udf.float()/10000000.
+
+def normalize_mesh(mesh, scale=0.95):
+ vertices = mesh.vertices
+ min_coords, max_coords = vertices.min(axis=0), vertices.max(axis=0)
+ dxyz = max_coords - min_coords
+ dist = max(dxyz)
+ mesh_scale = 2.0 * scale / dist
+ mesh_offset = -(min_coords + max_coords) / 2
+ vertices = (vertices + mesh_offset) * mesh_scale
+ mesh.vertices = vertices
+ return mesh
+
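+# Convert a normalized mesh into sparse latent-block indices: compute a UDF on a size^3
+# grid, keep voxels within 4/size of the surface, then downsample their coordinates by
+# `factor` and deduplicate.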
+def mesh2index(mesh, size=1024, factor=8):
+ vertices = torch.Tensor(mesh.vertices).float().cuda() * 0.5
+ faces = torch.Tensor(mesh.faces).int().cuda()
+ sdf = compute_valid_udf(vertices, faces, dim=size, threshold=4.0)
+ sdf = sdf.reshape(size, size, size).unsqueeze(0)
+
+ sparse_index = (sdf < 4/size).nonzero()
+ sparse_index[..., 1:] = sparse_index[..., 1:] // factor
+ latent_index = torch.unique(sparse_index, dim=0)
+ return latent_index
\ No newline at end of file
diff --git a/direct3d_s2/utils/rembg.py b/direct3d_s2/utils/rembg.py
new file mode 100644
index 0000000000000000000000000000000000000000..af64dccbf186afd4bfdf51e3a56cbcbae6ceef21
--- /dev/null
+++ b/direct3d_s2/utils/rembg.py
@@ -0,0 +1,35 @@
+import numpy as np
+import torch
+from torchvision import transforms
+
+
+class BiRefNet(object):
+ def __init__(self, device):
+ from transformers import AutoModelForImageSegmentation
+ self.birefnet_model = AutoModelForImageSegmentation.from_pretrained(
+ 'ZhengPeng7/BiRefNet',
+ trust_remote_code=True,
+ ).to(device)
+ self.birefnet_model.eval()
+ self.device = device
+
+ def run(self, image):
+ image = image.convert('RGB')
+ image_size = (1024, 1024)
+ transform_image = transforms.Compose([
+ transforms.Resize(image_size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+ ])
+
+ input_images = transform_image(image).unsqueeze(0).to(self.device)
+
+ with torch.no_grad():
+ preds = self.birefnet_model(input_images)[-1].sigmoid().cpu()
+
+ pred = preds[0].squeeze()
+ pred_pil = transforms.ToPILImage()(pred)
+ mask = pred_pil.resize(image.size)
+ mask = np.array(mask)
+ image = np.concatenate([np.array(image), mask[..., None]], axis=-1)
+ return image
\ No newline at end of file
diff --git a/direct3d_s2/utils/sparse.py b/direct3d_s2/utils/sparse.py
new file mode 100644
index 0000000000000000000000000000000000000000..9682ed73c7f281c6717185a1067ec0e05daafdaa
--- /dev/null
+++ b/direct3d_s2/utils/sparse.py
@@ -0,0 +1,68 @@
+import torch
+import numpy as np
+
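+# Sort sparse voxel coordinates block-major: primary keys are the block indices
+# (coordinates // block_size), secondary keys the position inside each block, so tokens
+# of the same selection block end up contiguous in memory.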
+def sort_block(latent_index, block_size):
+ device = latent_index.device
+    # .copy() ensures the in-place edits below never alias the original tensor's memory
+    latent_index_block = latent_index.cpu().numpy().copy()
+    latent_index_block[..., 1:] = latent_index_block[..., 1:] // block_size
+    latent_index_inblock = latent_index.cpu().numpy().copy()
+    latent_index_inblock[..., 1:] = latent_index_inblock[..., 1:] % block_size
+ sort_index = np.lexsort((
+ latent_index_inblock[..., 3],
+ latent_index_inblock[..., 2],
+ latent_index_inblock[..., 1],
+ latent_index_block[..., 3],
+ latent_index_block[..., 2],
+ latent_index_block[..., 1])
+ )
+ sort_index = torch.from_numpy(sort_index).to(device)
+ return latent_index[sort_index]
+
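+# Flatten conditioning tokens into a sparse-friendly layout: for every batch element, keep
+# the class/register tokens plus the patch tokens selected by `token_mask`, and emit a
+# (batch_idx, x, y, flag) coordinate for each kept token.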
+def extract_tokens_and_coords(conditions, token_mask, num_cls=1, num_reg=4):
+ device = conditions.device
+ B = conditions.size(0)
+ patch_size = token_mask.size(1)
+
+ class_tokens = conditions[:, 0:num_cls, :] # [B, 1, 1024]
+ register_tokens = conditions[:, num_cls:num_cls+num_reg, :] # [B, 4, 1024]
+ patch_tokens = conditions[:, num_cls+num_reg:, :] # [B, 1369, 1024]
+
+ selected_tokens_list = []
+ coords_list = []
+
+ for batch_idx in range(B):
+ cls_tokens = class_tokens[batch_idx] # [1, 1024]
+ reg_tokens = register_tokens[batch_idx] # [4, 1024]
+ cls_reg_tokens = torch.cat([cls_tokens, reg_tokens], dim=0) # [5, 1024]
+
+ cls_coord = torch.tensor([[batch_idx, 0, 0, 1]] * num_cls, device=device)
+ reg_coords = torch.tensor([[batch_idx, 0, 0, 1]] * num_reg, device=device)
+ cls_reg_coords = torch.cat([cls_coord, reg_coords], dim=0)
+
+ mask = token_mask[batch_idx]
+ pos = mask.nonzero(as_tuple=False)
+ K = pos.size(0)
+
+ if K > 0:
+ h, w = pos[:, 0], pos[:, 1]
+            indices = h * patch_size + w  # flatten (row, col) into a patch-token index
+ patches = patch_tokens[batch_idx][indices]
+
+ batch_ids = torch.full((K, 1), batch_idx, device=device)
+ x = w.unsqueeze(1)
+ y = h.unsqueeze(1)
+ patch_coords = torch.cat([batch_ids, x, y, torch.zeros((K, 1), device=device)], dim=1)
+
+ combined_tokens = torch.cat([cls_reg_tokens, patches], dim=0)
+ combined_coords = torch.cat([cls_reg_coords, patch_coords], dim=0)
+ else:
+ combined_tokens = cls_reg_tokens
+ combined_coords = cls_reg_coords
+
+ selected_tokens_list.append(combined_tokens)
+ coords_list.append(combined_coords)
+
+ selected_tokens = torch.cat(selected_tokens_list, dim=0)
+ coords = torch.cat(coords_list, dim=0)
+
+ return selected_tokens, coords
\ No newline at end of file
diff --git a/direct3d_s2/utils/util.py b/direct3d_s2/utils/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..51bd0eaa45994f0a3afe18040854288064535e7a
--- /dev/null
+++ b/direct3d_s2/utils/util.py
@@ -0,0 +1,19 @@
+import importlib
+
+
+def instantiate_from_config(config):
+    if "target" not in config:
+ if config == '__is_first_stage__':
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def get_obj_from_str(string, reload=False):
+ module, cls = string.rsplit(".", 1)
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3ad4f7d80bfe6190ac69ec3585246a9360587c04
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,17 @@
+scikit-image
+trimesh
+omegaconf
+tqdm
+huggingface_hub
+einops
+numpy
+transformers==4.40.2
+diffusers
+triton==3.1.0
+# flash-attn needs to be installed with: pip install flash-attn --no-build-isolation
+flash-attn
+pymeshfix
+git+https://github.com/EasternJournalist/utils3d.git#egg=utils3d
+pyvista
+igraph
+git+https://github.com/mit-han-lab/torchsparse.git
+third_party/voxelize/
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8db6251e3ed2b257ea34ef3c4ad056e4b14e515
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,18 @@
+from setuptools import setup, find_packages
+
+
+setup(
+ name="direct3d_s2",
+ version="1.0.0",
+ description="Direct3D-S2: Gigascale 3D Generation Made Easy with Spatial Sparse Attention",
+ packages=find_packages(),
+ python_requires=">=3.10",
+ install_requires=[
+ "torch",
+ "numpy",
+ "cython",
+ "trimesh",
+ "diffusers",
+ "triton",
+ ],
+)
\ No newline at end of file
diff --git a/third_party/voxelize/build/lib.linux-x86_64-3.10/udf_ext.cpython-310-x86_64-linux-gnu.so b/third_party/voxelize/build/lib.linux-x86_64-3.10/udf_ext.cpython-310-x86_64-linux-gnu.so
new file mode 100644
index 0000000000000000000000000000000000000000..02b864c83d160429ce5c8caab893df304a2a7316
--- /dev/null
+++ b/third_party/voxelize/build/lib.linux-x86_64-3.10/udf_ext.cpython-310-x86_64-linux-gnu.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dba5cb992519f31d01d20161f44fed88ec5bb7f4000a0efdf2f344222e54095d
+size 9334072
diff --git a/third_party/voxelize/build/temp.linux-x86_64-3.10/.ninja_deps b/third_party/voxelize/build/temp.linux-x86_64-3.10/.ninja_deps
new file mode 100644
index 0000000000000000000000000000000000000000..e7ebba6788c4ab705a00bb27ea9c9b5c4707ad2c
--- /dev/null
+++ b/third_party/voxelize/build/temp.linux-x86_64-3.10/.ninja_deps
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bfaad9b34926f6b0320aa5208865a236d2fabab01194c00662fa1f8c5c123474
+size 557356
diff --git a/third_party/voxelize/build/temp.linux-x86_64-3.10/.ninja_log b/third_party/voxelize/build/temp.linux-x86_64-3.10/.ninja_log
new file mode 100644
index 0000000000000000000000000000000000000000..5d2d9d37878743c6735597258366c8cedd790ec2
--- /dev/null
+++ b/third_party/voxelize/build/temp.linux-x86_64-3.10/.ninja_log
@@ -0,0 +1,7 @@
+# ninja log v5
+1 10231 1748533567318943487 /opt/ml/input/data/CFS/wushuang/research/TOG2025/open_source/Direct3D-S2/third_part/voxelize/build/temp.linux-x86_64-3.10/src/udf_kernel.o 974b8952d6695070
+1 23455 1748533580544026989 /opt/ml/input/data/CFS/wushuang/research/TOG2025/open_source/Direct3D-S2/third_part/voxelize/build/temp.linux-x86_64-3.10/src/udf_cuda.o 451cfe46e7f34448
+94 9916 1748614488097090921 /opt/ml/input/data/CFS/wushuang/research/TOG2025/open_source/Direct3D-S2/third_part/voxelize/build/temp.linux-x86_64-3.10/src/udf_kernel.o 974b8952d6695070
+94 22534 1748614500786335832 /opt/ml/input/data/CFS/wushuang/research/TOG2025/open_source/Direct3D-S2/third_part/voxelize/build/temp.linux-x86_64-3.10/src/udf_cuda.o 451cfe46e7f34448
+8 9872 1748636870244409619 /opt/ml/input/data/CFS/wushuang/research/TOG2025/open_source/Direct3D-S2/third_party/voxelize/build/temp.linux-x86_64-3.10/src/udf_kernel.o 52db95a09a5c3658
+8 21856 1748636882230425289 /opt/ml/input/data/CFS/wushuang/research/TOG2025/open_source/Direct3D-S2/third_party/voxelize/build/temp.linux-x86_64-3.10/src/udf_cuda.o e9d1e10a200931c1
diff --git a/third_party/voxelize/build/temp.linux-x86_64-3.10/build.ninja b/third_party/voxelize/build/temp.linux-x86_64-3.10/build.ninja
new file mode 100644
index 0000000000000000000000000000000000000000..64c09e77d7ad90229414c2242f6813329c9f3e56
--- /dev/null
+++ b/third_party/voxelize/build/temp.linux-x86_64-3.10/build.ninja
@@ -0,0 +1,33 @@
+ninja_required_version = 1.3
+cxx = c++
+nvcc = /usr/local/cuda/bin/nvcc
+
+cflags = -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -g -fstack-protector-strong -Wformat -Werror=format-security -Wdate-time -D_FORTIFY_SOURCE=2 -fPIC -I/usr/local/lib/python3.10/dist-packages/torch/include -I/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.10/dist-packages/torch/include/TH -I/usr/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/local/cuda/include -I/usr/include/python3.10 -c
+post_cflags = -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=udf_ext -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++17
+cuda_cflags = -I/usr/local/lib/python3.10/dist-packages/torch/include -I/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.10/dist-packages/torch/include/TH -I/usr/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/local/cuda/include -I/usr/include/python3.10 -c
+cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=udf_ext -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_52,code=sm_52 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_90,code=compute_90 -gencode=arch=compute_90,code=sm_90 -std=c++17
+cuda_dlink_post_cflags =
+ldflags =
+
+rule compile
+ command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
+ depfile = $out.d
+ deps = gcc
+
+rule cuda_compile
+ depfile = $out.d
+ deps = gcc
+ command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags
+
+
+
+
+
+build /opt/ml/input/data/CFS/wushuang/research/TOG2025/open_source/Direct3D-S2/third_party/voxelize/build/temp.linux-x86_64-3.10/src/udf_cuda.o: compile /opt/ml/input/data/CFS/wushuang/research/TOG2025/open_source/Direct3D-S2/third_party/voxelize/src/udf_cuda.cpp
+build /opt/ml/input/data/CFS/wushuang/research/TOG2025/open_source/Direct3D-S2/third_party/voxelize/build/temp.linux-x86_64-3.10/src/udf_kernel.o: cuda_compile /opt/ml/input/data/CFS/wushuang/research/TOG2025/open_source/Direct3D-S2/third_party/voxelize/src/udf_kernel.cu
+
+
+
+
+
+
diff --git a/third_party/voxelize/build/temp.linux-x86_64-3.10/src/udf_cuda.o b/third_party/voxelize/build/temp.linux-x86_64-3.10/src/udf_cuda.o
new file mode 100644
index 0000000000000000000000000000000000000000..e233bb97c7ecbccb7f9a9b86e1afd267089e6e51
--- /dev/null
+++ b/third_party/voxelize/build/temp.linux-x86_64-3.10/src/udf_cuda.o
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:36e81102ff162bc0cbbb09814def88484f0d4ff68fd6aec52c5e5ee10895a3c3
+size 13639720
diff --git a/third_party/voxelize/build/temp.linux-x86_64-3.10/src/udf_kernel.o b/third_party/voxelize/build/temp.linux-x86_64-3.10/src/udf_kernel.o
new file mode 100644
index 0000000000000000000000000000000000000000..46bfc650dc8c076f035aed2a36d4901b555954fe
--- /dev/null
+++ b/third_party/voxelize/build/temp.linux-x86_64-3.10/src/udf_kernel.o
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d90cd327d8506bd0db2654081633bb519f8367a80c8841a685e79f127018a2ca
+size 234112
diff --git a/third_party/voxelize/setup.py b/third_party/voxelize/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c68d1aa3b6feb7dae5b468b690c2d2f5a6a9a20
--- /dev/null
+++ b/third_party/voxelize/setup.py
@@ -0,0 +1,16 @@
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
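+# Build sketch (assumes a CUDA toolkit and a CUDA-enabled PyTorch install are available):
+#   pip install .                       # or: python setup.py build_ext --inplace
+# Either command compiles src/udf_kernel.cu and src/udf_cuda.cpp into the udf_ext extension.
+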
+setup(
+ name='udf_ext',
+ ext_modules=[
+ CUDAExtension('udf_ext', [
+ 'src/udf_kernel.cu',
+ 'src/udf_cuda.cpp'
+ ]),
+ ],
+ cmdclass={
+ 'build_ext': BuildExtension
+ }
+)
+
diff --git a/third_party/voxelize/src/udf_cuda.cpp b/third_party/voxelize/src/udf_cuda.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..877772fd6372be13e04c7ddf757175c0213e56d4
--- /dev/null
+++ b/third_party/voxelize/src/udf_cuda.cpp
@@ -0,0 +1,13 @@
+#include <torch/extension.h>
+
+void compute_valid_udf_cuda(float* vertices, int* faces, int* udf, const int numTriangles, const int DIM=512, const float threshold=8);
+
+extern "C"
+void compute_valid_udf_wrapper(torch::Tensor vertices, torch::Tensor faces, torch::Tensor udf, const int numTriangles, const int DIM=512, const float threshold=8.0) {
+ compute_valid_udf_cuda(vertices.data_ptr<float>(), faces.data_ptr<int>(), udf.data_ptr<int>(), numTriangles, DIM, threshold);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("compute_valid_udf", &compute_valid_udf_wrapper, "Compute UDF using CUDA");
+}
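+
+// Usage sketch from Python (illustrative only; tensor names and the sentinel value are assumptions):
+//   import torch, udf_ext
+//   udf = torch.full((DIM * DIM * DIM,), 2**30, dtype=torch.int32, device="cuda")
+//   udf_ext.compute_valid_udf(vertices_cuda, faces_cuda_int32, udf, num_triangles, DIM, 8.0)
+// Distances are written into `udf` as fixed-point integers (scaled by 1e7) via atomicMin,
+// so the tensor should be pre-filled with a value larger than any expected distance.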
+
diff --git a/third_party/voxelize/src/udf_kernel.cu b/third_party/voxelize/src/udf_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..2488606605b8185ec2ffdacbc59152d32767aa66
--- /dev/null
+++ b/third_party/voxelize/src/udf_kernel.cu
@@ -0,0 +1,214 @@
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <math.h>
+
+
+struct Point3D {
+ float x, y, z;
+};
+
+struct Triangle {
+ Point3D v0, v1, v2;
+};
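+
+// Compute the cross product of two vectors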
+__device__ Point3D cross(const Point3D& v1, const Point3D& v2) {
+ Point3D result;
+ result.x = v1.y * v2.z - v1.z * v2.y;
+ result.y = v1.z * v2.x - v1.x * v2.z;
+ result.z = v1.x * v2.y - v1.y * v2.x;
+ return result;
+}
+
+// Compute the dot product of two vectors
+__device__ float dot(const Point3D& v1, const Point3D& v2) {
+ return v1.x * v2.x + v1.y * v2.y + v1.z * v2.z;
+}
+
+// Subtract two 3D points (vector subtraction)
+__device__ Point3D subtract(const Point3D& p1, const Point3D& p2) {
+ Point3D result = {p1.x - p2.x, p1.y - p2.y, p1.z - p2.z};
+ return result;
+}
+__device__ float magnitude(const Point3D &v) {
+ return sqrtf(v.x * v.x + v.y * v.y + v.z * v.z);
+}
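+// Return true if two points coincide exactly (used to detect degenerate triangles)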
+__device__ bool is_identical(const Point3D & p1, const Point3D & p2){
+ Point3D check = subtract(p1, p2);
+ if(check.x==0 && check.y == 0 && check.z == 0)
+ return true;
+ return false;
+}
+
+// Compute the squared distance between two points
+__device__ float squaredDistance(const Point3D& p1, const Point3D& p2) {
+ return (p1.x - p2.x) * (p1.x - p2.x) +
+ (p1.y - p2.y) * (p1.y - p2.y) +
+ (p1.z - p2.z) * (p1.z - p2.z);
+}
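+// Normalize a vector; the zero vector is returned unchanged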
+__device__ Point3D normalize(Point3D v){
+ float len = sqrtf(dot(v, v));
+ if (len ==0)
+ return v;
+ float scale = 1 / len;
+ Point3D result = {v.x * scale, v.y * scale, v.z * scale};
+ return result;
+}
+
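+// Distance from a point to the segment [v0, v1] (projection clamped to the segment endpoints)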
+__device__ float point_to_line_distance(const Point3D &p, const Point3D &v0, const Point3D &v1) {
+ // Direction vector of the line
+ Point3D d = subtract(v1, v0);
+
+ // Vector from v0 to point p
+ Point3D v0_to_p = subtract(p, v0);
+
+ // Scalar projection of v0_to_p onto the direction vector d
+ float t = dot(v0_to_p, d) / dot(d, d);
+
+ Point3D closest_point;
+
+ // Check where the projection falls
+ if (t < 0) {
+ // Projection falls before v0, so the closest point is v0
+ closest_point = v0;
+ } else if (t > 1) {
+ // Projection falls beyond v1, so the closest point is v1
+ closest_point = v1;
+ } else {
+ // Projection falls within the segment, compute the projection point
+ closest_point.x = v0.x + t * d.x;
+ closest_point.y = v0.y + t * d.y;
+ closest_point.z = v0.z + t * d.z;
+ }
+
+ // Calculate the distance between p and the closest point
+ Point3D closest_to_p = subtract(p, closest_point);
+ return magnitude(closest_to_p);
+}
+
+// Compute the distance between a point and a triangle face
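+// (the optional `inverse` flag recomputes the normal with the opposite winding; the caller evaluates both orientations)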
+__device__ float pointToTriangleDistance(const Point3D& queryPoint, const Point3D& v0, const Point3D& v1, const Point3D& v2, bool inverse=false) {
+ // Edge vectors
+ Point3D edge0 = subtract(v1, v0);
+ Point3D edge1 = subtract(v2, v0);
+ if (is_identical(v0, v1) && is_identical(v0, v2))
+ return sqrtf(squaredDistance(queryPoint, v0));
+ if (is_identical(v0, v1))
+ return point_to_line_distance(queryPoint, v0, v2);
+ if (is_identical(v0, v2))
+ return point_to_line_distance(queryPoint, v0, v1);
+ if (is_identical(v1, v2))
+ return point_to_line_distance(queryPoint, v0, v1);
+ // Normal vector to the triangle plane
+ Point3D normal = cross(edge0, edge1);
+ if (inverse)
+ normal = cross(edge1, edge0);
+
+ // Vector from v0 to queryPoint
+ Point3D queryVec = subtract(queryPoint, v0);
+ if (dot(normal, normal)==0)
+ return sqrtf(dot(queryVec, queryVec));
+ normal = normalize(normal);
+ //return 1.0;
+
+ // Project the query point onto the triangle's plane
+ float distanceToPlane = dot(normal, queryVec); // / sqrtf(dot(normal, normal));
+
+// return fabsf(distanceToPlane);
+ Point3D projectionPoint = {
+ queryPoint.x - distanceToPlane * normal.x,
+ queryPoint.y - distanceToPlane * normal.y,
+ queryPoint.z - distanceToPlane * normal.z
+ };
+ // Check if the projection point is inside the triangle using barycentric coordinates
+ edge0 = subtract(v0, v1);
+ edge1 = subtract(v1, v2);
+ Point3D edge2 = subtract(v2, v0);
+ Point3D projVec0 = subtract(v0, projectionPoint);
+ Point3D projVec1 = subtract(v1, projectionPoint);
+ Point3D projVec2 = subtract(v2, projectionPoint);
+ Point3D c0 = cross(edge0, projVec0);
+ Point3D c1 = cross(edge1, projVec1);
+ Point3D c2 = cross(edge2, projVec2);
+ if (dot(c0, c1) > 0 && dot(c1, c2) > 0 && dot(c0, c2) > 0)
+ return fabsf(distanceToPlane);
+
+ // Otherwise, return the minimum distance to the triangle's edges
+ float minEdgeDistance = 1e6f;
+ minEdgeDistance = fmin(minEdgeDistance, point_to_line_distance(queryPoint, v0, v1));
+ minEdgeDistance = fmin(minEdgeDistance, point_to_line_distance(queryPoint, v0, v2));
+ minEdgeDistance = fmin(minEdgeDistance, point_to_line_distance(queryPoint, v1, v2));
+
+
+ return minEdgeDistance;
+}
+
+
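+// Rasterize one triangle into the UDF grid: its bounding box (dilated by `threshold` + 1 voxels)
+// is scanned, and every cell closer than threshold/DIM to the triangle has its entry reduced
+// via atomicMin. The grid mapping below assumes vertex coordinates in [-0.5, 0.5]^3.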
+__device__ void updateUDF(Triangle t, int* udf, const int DIM, const float threshold) {
+ // Compute the bounding box of the triangle
+ float minX = fminf(fminf(t.v0.x, t.v1.x), t.v2.x);
+ float minY = fminf(fminf(t.v0.y, t.v1.y), t.v2.y);
+ float minZ = fminf(fminf(t.v0.z, t.v1.z), t.v2.z);
+ float maxX = fmaxf(fmaxf(t.v0.x, t.v1.x), t.v2.x);
+ float maxY = fmaxf(fmaxf(t.v0.y, t.v1.y), t.v2.y);
+ float maxZ = fmaxf(fmaxf(t.v0.z, t.v1.z), t.v2.z);
+
+ // Convert bounding box to grid coordinates
+ int iMin = max(0, (int)floorf((minX + 0.5) * (DIM-1)));
+ int jMin = max(0, (int)floorf((minY + 0.5) * (DIM-1)));
+ int kMin = max(0, (int)floorf((minZ + 0.5) * (DIM-1)));
+ int iMax = min(DIM - 1, (int)floorf((maxX + 0.5) * (DIM-1)));
+ int jMax = min(DIM - 1, (int)floorf((maxY + 0.5) * (DIM-1)));
+ int kMax = min(DIM - 1, (int)floorf((maxZ + 0.5) * (DIM-1)));
+
+ int range = (int)(threshold + 1);
+
+ // Make the bounding box larger than the original
+ iMax = min(DIM - 1, iMax + range);
+ iMin = max(0, iMin - range);
+ jMax = min(DIM - 1, jMax + range);
+ jMin = max(0, jMin - range);
+ kMax = min(DIM - 1, kMax + range);
+ kMin = max(0, kMin - range);
+
+ // Update the valid grids within the bounding box
+ for (int i = iMin; i <= iMax; ++i) {
+ for (int j = jMin; j <= jMax; ++j) {
+ for (int k = kMin; k <= kMax; ++k) {
+ int idx = i * DIM * DIM + j * DIM + k;
+
+ // Compute the distance from the query point to the triangle
+ Point3D queryPoint = {(float)i/(DIM-1) - 0.5, (float)j/(DIM-1) - 0.5, (float)k/(DIM-1) -0.5};
+ float distance = pointToTriangleDistance(queryPoint, t.v0, t.v1, t.v2);
+ float distance2 = pointToTriangleDistance(queryPoint, t.v0, t.v1, t.v2, true);
+ if (distance < threshold / DIM || distance2 < threshold / DIM){
+ //distance = distance2;
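+ // atomicMin has no float overload, so the distance is stored as a fixed-point
+ // integer (scaled by 1e7) and the grid cell keeps the minimum atomically.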
+ int int_dist = (int)(distance * 10000000);
+ atomicMin(&udf[idx], int_dist);
+ }
+ }
+
+ }
+ }
+}
+
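+// One thread per triangle: gather the triangle's vertices and update nearby grid cells.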
+__global__ void compute_udf_kernel(float* vertices, int* faces, int * udf, int numTriangles, const int DIM, const float threshold) {
+ int t = blockIdx.x * blockDim.x + threadIdx.x;
+ if (t < numTriangles) {
+ int f0 = faces[t * 3 + 0];
+ int f1 = faces[t * 3 + 1];
+ int f2 = faces[t * 3 + 2];
+ Point3D v0 = {vertices[f0 * 3 + 0], vertices[f0 * 3 + 1], vertices[f0 * 3 + 2]};
+ Point3D v1 = {vertices[f1 * 3 + 0], vertices[f1 * 3 + 1], vertices[f1 * 3 + 2]};
+ Point3D v2 = {vertices[f2 * 3 + 0], vertices[f2 * 3 + 1], vertices[f2 * 3 + 2]};
+ Triangle triangle = {v0, v1, v2};
+ updateUDF(triangle, udf, DIM, threshold);
+ }
+}
+
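+// Host-side launcher: 256 threads per block, one thread per triangle.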
+void compute_valid_udf_cuda(float* vertices, int* faces, int* udf, int numTriangles, const int DIM=512, const float threshold=8) {
+ int blockSize = 256;
+ int gridSize = (numTriangles + blockSize - 1) / blockSize;
+
+ // Launch the kernel
+ compute_udf_kernel<<<gridSize, blockSize>>>(vertices, faces, udf, numTriangles, DIM, threshold);
+}
+
diff --git a/third_party/voxelize/udf_ext.cpython-310-x86_64-linux-gnu.so b/third_party/voxelize/udf_ext.cpython-310-x86_64-linux-gnu.so
new file mode 100644
index 0000000000000000000000000000000000000000..02b864c83d160429ce5c8caab893df304a2a7316
--- /dev/null
+++ b/third_party/voxelize/udf_ext.cpython-310-x86_64-linux-gnu.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dba5cb992519f31d01d20161f44fed88ec5bb7f4000a0efdf2f344222e54095d
+size 9334072
diff --git a/third_party/voxelize/udf_ext.egg-info/PKG-INFO b/third_party/voxelize/udf_ext.egg-info/PKG-INFO
new file mode 100644
index 0000000000000000000000000000000000000000..21cdbe2d582ff6f92783c01ac46ffd62f8c1a480
--- /dev/null
+++ b/third_party/voxelize/udf_ext.egg-info/PKG-INFO
@@ -0,0 +1,8 @@
+Metadata-Version: 2.1
+Name: udf-ext
+Version: 0.0.0
+Summary: UNKNOWN
+License: UNKNOWN
+Platform: UNKNOWN
+
+UNKNOWN
diff --git a/third_party/voxelize/udf_ext.egg-info/SOURCES.txt b/third_party/voxelize/udf_ext.egg-info/SOURCES.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c8f3f0e01b4cc56421b16b394c3ad796ef613412
--- /dev/null
+++ b/third_party/voxelize/udf_ext.egg-info/SOURCES.txt
@@ -0,0 +1,7 @@
+setup.py
+src/udf_cuda.cpp
+src/udf_kernel.cu
+udf_ext.egg-info/PKG-INFO
+udf_ext.egg-info/SOURCES.txt
+udf_ext.egg-info/dependency_links.txt
+udf_ext.egg-info/top_level.txt
\ No newline at end of file
diff --git a/third_party/voxelize/udf_ext.egg-info/dependency_links.txt b/third_party/voxelize/udf_ext.egg-info/dependency_links.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/third_party/voxelize/udf_ext.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+
diff --git a/third_party/voxelize/udf_ext.egg-info/top_level.txt b/third_party/voxelize/udf_ext.egg-info/top_level.txt
new file mode 100644
index 0000000000000000000000000000000000000000..07a52f59f2f35df1f9db761ec17412a7406b9802
--- /dev/null
+++ b/third_party/voxelize/udf_ext.egg-info/top_level.txt
@@ -0,0 +1 @@
+udf_ext