diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..34cac59b2042e14f6188dd775b5c1059b7149b92 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,77 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/teaserv6.png filter=lfs diff=lfs merge=lfs -text +assets/test/0.png filter=lfs diff=lfs merge=lfs -text +assets/test/1.png filter=lfs diff=lfs merge=lfs -text +assets/test/10.png filter=lfs diff=lfs merge=lfs -text +assets/test/11.png filter=lfs diff=lfs merge=lfs -text +assets/test/12.png filter=lfs diff=lfs merge=lfs -text +assets/test/13.png filter=lfs diff=lfs merge=lfs -text +assets/test/14.png filter=lfs diff=lfs merge=lfs -text +assets/test/15.png filter=lfs diff=lfs merge=lfs -text +assets/test/16.png filter=lfs diff=lfs merge=lfs -text +assets/test/17.png filter=lfs diff=lfs merge=lfs -text +assets/test/18.png filter=lfs diff=lfs merge=lfs -text +assets/test/19.png filter=lfs diff=lfs merge=lfs -text +assets/test/2.png filter=lfs diff=lfs merge=lfs -text +assets/test/20.png filter=lfs diff=lfs merge=lfs -text +assets/test/21.png filter=lfs diff=lfs merge=lfs -text +assets/test/22.png filter=lfs diff=lfs merge=lfs -text +assets/test/23.png filter=lfs diff=lfs merge=lfs -text +assets/test/24.png filter=lfs diff=lfs merge=lfs -text +assets/test/25.png filter=lfs diff=lfs merge=lfs -text +assets/test/26.png filter=lfs diff=lfs merge=lfs -text +assets/test/27.png filter=lfs diff=lfs merge=lfs -text +assets/test/28.png filter=lfs diff=lfs merge=lfs -text +assets/test/29.png filter=lfs diff=lfs merge=lfs -text +assets/test/3.png filter=lfs diff=lfs merge=lfs -text +assets/test/30.png filter=lfs diff=lfs merge=lfs -text +assets/test/31.png filter=lfs diff=lfs merge=lfs -text +assets/test/32.png filter=lfs diff=lfs merge=lfs -text +assets/test/33.png filter=lfs diff=lfs merge=lfs -text +assets/test/34.png filter=lfs diff=lfs merge=lfs -text +assets/test/35.png filter=lfs diff=lfs merge=lfs -text +assets/test/36.png filter=lfs diff=lfs merge=lfs -text +assets/test/37.png filter=lfs diff=lfs merge=lfs -text +assets/test/38.png filter=lfs diff=lfs merge=lfs -text +assets/test/39.png filter=lfs diff=lfs merge=lfs -text +assets/test/4.png filter=lfs diff=lfs merge=lfs -text +assets/test/40.png filter=lfs diff=lfs merge=lfs -text +assets/test/41.png filter=lfs diff=lfs merge=lfs -text +assets/test/42.png filter=lfs diff=lfs merge=lfs -text +assets/test/43.png filter=lfs diff=lfs merge=lfs -text +assets/test/44.png filter=lfs diff=lfs merge=lfs -text +assets/test/45.png filter=lfs diff=lfs merge=lfs -text +assets/test/46.png filter=lfs diff=lfs merge=lfs -text +assets/test/47.png filter=lfs diff=lfs merge=lfs -text +assets/test/48.png filter=lfs diff=lfs merge=lfs -text +assets/test/49.png filter=lfs diff=lfs merge=lfs -text +assets/test/5.png filter=lfs diff=lfs merge=lfs -text +assets/test/50.png filter=lfs diff=lfs merge=lfs -text +assets/test/51.png filter=lfs diff=lfs merge=lfs -text +assets/test/52.png filter=lfs diff=lfs merge=lfs -text +assets/test/53.png filter=lfs diff=lfs merge=lfs -text +assets/test/54.png filter=lfs diff=lfs merge=lfs -text +assets/test/55.png filter=lfs diff=lfs merge=lfs -text +assets/test/56.png filter=lfs diff=lfs merge=lfs -text +assets/test/57.png filter=lfs diff=lfs merge=lfs -text +assets/test/58.png filter=lfs diff=lfs merge=lfs -text +assets/test/59.png filter=lfs 
diff=lfs merge=lfs -text +assets/test/6.png filter=lfs diff=lfs merge=lfs -text +assets/test/60.png filter=lfs diff=lfs merge=lfs -text +assets/test/61.png filter=lfs diff=lfs merge=lfs -text +assets/test/62.png filter=lfs diff=lfs merge=lfs -text +assets/test/63.png filter=lfs diff=lfs merge=lfs -text +assets/test/64.png filter=lfs diff=lfs merge=lfs -text +assets/test/65.png filter=lfs diff=lfs merge=lfs -text +assets/test/66.png filter=lfs diff=lfs merge=lfs -text +assets/test/67.png filter=lfs diff=lfs merge=lfs -text +assets/test/7.png filter=lfs diff=lfs merge=lfs -text +assets/test/8.png filter=lfs diff=lfs merge=lfs -text +assets/test/9.png filter=lfs diff=lfs merge=lfs -text +third_party/voxelize/build/lib.linux-x86_64-3.10/udf_ext.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +third_party/voxelize/build/temp.linux-x86_64-3.10/.ninja_deps filter=lfs diff=lfs merge=lfs -text +third_party/voxelize/build/temp.linux-x86_64-3.10/src/udf_cuda.o filter=lfs diff=lfs merge=lfs -text +third_party/voxelize/build/temp.linux-x86_64-3.10/src/udf_kernel.o filter=lfs diff=lfs merge=lfs -text +third_party/voxelize/udf_ext.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..225066c568a99b8715cdabebc8cf90ff1cf7ca73 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 DreamTechAI + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..0229b5814a82398f83552e047575f9dd8be45325 --- /dev/null +++ b/app.py @@ -0,0 +1,207 @@ + +import torch +import trimesh +import datetime +import argparse +import numpy as np +from torchvision import transforms +from direct3d_s2.utils.rembg import BiRefNet +from direct3d_s2.pipeline import Direct3DS2Pipeline +from direct3d_s2.utils.fill_hole import postprocess_mesh + +import os +from PIL import Image +from typing import Any + +import gradio as gr +from gradio.themes.utils import colors, fonts, sizes + +# ----------------------------------------------------------------------------- +# THEME ▸ a soft glass-like dark theme with a vibrant primary accent +# ----------------------------------------------------------------------------- +class Glass(gr.themes.Soft): + def __init__(self): + super().__init__( + primary_hue=colors.emerald, + secondary_hue=colors.indigo, + neutral_hue=colors.zinc, + text_size=sizes.text_md, + spacing_size=sizes.spacing_md, + radius_size=sizes.radius_lg, + font=fonts.GoogleFont("Inter"), + ) + + def style(self): + super().style() + self.set( + background_fill="var(--neutral-950)", + border_color_primary="rgba(255,255,255,.12)", + border_width="1px", + shadow_drop="0 10px 38px -10px rgba(0,0,0,.65)", + shadow_drop_lg="0 10px 38px -10px rgba(0,0,0,.65)", + ) + return self + +def check_input_image(input_image): + if input_image is None: + raise gr.Error("No image uploaded!") + +# ----------------------------------------------------------------------------- +# PLACEHOLDER BACK-END HOOKS ▸ replace with your real logic +# ----------------------------------------------------------------------------- +def image2mesh( + image: Any, + resolution: str = '1024', + simplify: bool = True, + simplify_ratio: float = 0.95, + output_path: str = 'outputs/web' +): + + torch.cuda.empty_cache() + + if not os.path.exists(output_path): + os.makedirs(output_path) + + uid = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + image.save(os.path.join(output_path, uid + '.png')) + + pipe = Direct3DS2Pipeline.from_pretrained('wushuang98/Direct3D-S2', subfolder="direct3d-s2-v-1-1") + pipe.to("cuda:0") + + mesh = pipe( + image, + sdf_resolution=int(resolution), + mc_threshold=0.2, + remesh=simplify, + simplify_ratio=simplify_ratio, + )["mesh"] + + mesh_path = os.path.join(output_path, f'{uid}.obj') + mesh.export( + mesh_path, + include_normals=True, + ) + torch.cuda.empty_cache() + + return mesh_path + +# ----------------------------------------------------------------------------- +# UI LAYOUT ▸ minimal glassmorphism, keyboard-first workflow +# ----------------------------------------------------------------------------- +with gr.Blocks(theme=Glass(), css=""" +:root { --header-height:64px } +body { background:linear-gradient(215deg,#101113 0%,#0b0c0d 60%,#0d1014 100%) } +#header { height:var(--header-height);display:flex;align-items:center;justify-content:space-between;padding:0 1.5rem;backdrop-filter:blur(18px);background:rgba(17,17,17,.65);border-bottom:1px solid rgba(255,255,255,.08);position:sticky;top:0;z-index:999 } +#header a { color:white;font-weight:500;text-decoration:none;margin-right:1.25rem;font-size:.925rem } +#hero-title { font-size:1.35rem;font-weight:600;color:white;white-space:nowrap } +#footer { text-align:center;font-size:.8rem;color:rgba(255,255,255,.55);margin-top:1.5rem } +#mesh_viewport { aspect-ratio:1/1;width:100%;display:flex;align-items:center;justify-content:center;border:1px 
dashed rgba(255,255,255,.12);border-radius:12px;background:rgba(255,255,255,.03); } +.gallery-item img { border-radius:10px } +#examples_gallery { height:100%;flex:1;display:flex;flex-direction:column; } +#examples_gallery img { width:800px;} +#show_image img { height:260px;display:flex;align-items:center;justify-content:center; } +#examples { height:100%;flex:1; } +""") as demo: + + # ▸ custom sticky header + with gr.Row(elem_id="header", variant="panel"): + gr.Markdown("Direct3D-S2 Studio", elem_id="hero-title") + gr.Markdown( + """ + """, + elem_id="nav-links", + ) + + # ▸ main workspace + with gr.Row(equal_height=True): + # ---------- Controls ---------- + with gr.Column(scale=3): + gr.Markdown("### Input", elem_classes="subtitle") + image_input = gr.Image( + label="Image Input", + image_mode="RGBA", + sources="upload", + type="pil", + height=260, + elem_id="show_image", + ) + # gr.Markdown("
Drag & drop or click to upload
") + processed_image = gr.Image( + label="Processed Image", + image_mode="RGBA", + type="pil", + interactive=False, + height=260, + elem_id="show_image", + ) + with gr.Accordion("Advanced Options", open=True): + resolution = gr.Radio(choices=["512", "1024"], label="SDF Resolution", value="1024") + simplify = gr.Checkbox(label="Simplify Mesh", value=True) + reduce_ratio = gr.Slider(0.1, 0.95, step=0.05, value=0.95, label="Faces Reduction Ratio") + + gen_btn = gr.Button("Generate 3D ✨", variant="primary", interactive=True) + + # ---------- Viewport ---------- + with gr.Column(scale=6): + gr.Markdown("### Model Viewer", elem_classes="subtitle") + # mesh_html = gr.HTML("
🌀 No mesh yet
") + output_model_obj = gr.Model3D( + label="Output Model (OBJ Format)", + camera_position=(90.0, 90.0, 3.5), + interactive=False, + elem_id="mesh_viewport", + ) + + # ---------- Gallery / Examples ---------- + with gr.Column(scale=3): + gr.Markdown("### Examples", elem_classes="subtitle") + # gr.Examples( + # examples=[os.path.join("assets/test", i) for i in os.listdir("assets/test")], + # inputs=[image_input], + # examples_per_page=8, + # label="Gallery", + # elem_id="examples_gallery", + # ) + with gr.Tabs(selected='tab_img_gallery') as gallery: + with gr.Tab('Image to 3D Gallery', id='tab_img_gallery') as tab_gi: + with gr.Row(): + gr.Examples( + examples=[os.path.join("assets/test", i) for i in os.listdir("assets/test")], + inputs=[image_input], + label=None, + examples_per_page=24 + ) + # gallery = gr.Gallery( + # [os.path.join("assets/test", i) for i in os.listdir("assets/test")], + # columns=2, + # object_fit="contain", + # elem_id="examples_gallery", + # allow_preview=False, + # ) + + + # ▸ callbacks + outputs = [output_model_obj] + rmbg = BiRefNet(device="cuda:0") + + gen_btn.click( + fn=check_input_image, + inputs=[image_input] + ).success( + fn=rmbg.run, + inputs=[image_input], + outputs=[processed_image] + ).success( + fn=image2mesh, + inputs=[processed_image, resolution, simplify, reduce_ratio], + outputs=outputs, + api_name="generate_img2obj" + ) + +# ----------------------------------------------------------------------------- +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--cached_dir", type=str, default="outputs/web") + args = parser.parse_args() + + demo.queue().launch(share=True, allowed_paths=[args.cached_dir], server_port=7860) diff --git a/assets/teaserv6.png b/assets/teaserv6.png new file mode 100644 index 0000000000000000000000000000000000000000..9e4af78c005e2ce1e489b19d4f73125b973d5bba --- /dev/null +++ b/assets/teaserv6.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4e58eefff81cee52dbc11d65949fe99c3e8b23e2a0a5e677505aca8799a22564 +size 3282197 diff --git a/assets/test/0.png b/assets/test/0.png new file mode 100644 index 0000000000000000000000000000000000000000..ce6173cb73e77c760c0812118fafeea74730f64b --- /dev/null +++ b/assets/test/0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:11b98466291aaf59b56a5149a3acebd69d419876b8b5905041f92314889c90e4 +size 432420 diff --git a/assets/test/1.png b/assets/test/1.png new file mode 100644 index 0000000000000000000000000000000000000000..3863da51b726f88d7fa9fec08ca4cc63a686d72e --- /dev/null +++ b/assets/test/1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c034e0e62e82a70194ecfd751c27b2f538cbf1eeb304029b40d069c92546033a +size 428055 diff --git a/assets/test/10.png b/assets/test/10.png new file mode 100644 index 0000000000000000000000000000000000000000..2d267096771f514299064ef873f9390c2cb41fe0 --- /dev/null +++ b/assets/test/10.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9e1121ea0f9bc6ac7469a0d3db15c445f9694fb4375a3a49032cef209179d1e2 +size 1083093 diff --git a/assets/test/11.png b/assets/test/11.png new file mode 100644 index 0000000000000000000000000000000000000000..9b9ce9d0c9cd2c5b44412fe924c9c8533eacbe57 --- /dev/null +++ b/assets/test/11.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2284a6d065826ad4a4293a662781a8b506c4ec61890b52ee6855af5b029f4e8f +size 1810882 diff --git a/assets/test/12.png b/assets/test/12.png new file mode 100644 
index 0000000000000000000000000000000000000000..a50c9eb3fe1496f6a3448ccba1862035227fa1f2 --- /dev/null +++ b/assets/test/12.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a904d6b0ac11c75d7d0409ca46beb3e80bbdf72fd8272e268230fcc5e8deab7c +size 953366 diff --git a/assets/test/13.png b/assets/test/13.png new file mode 100644 index 0000000000000000000000000000000000000000..8cdbc86b61339c1b16f9c255633697d68c3ec5e5 --- /dev/null +++ b/assets/test/13.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:600a57dcad4c4a1f06b0d52aac68e762c42ee86f17c6ce0c9a2f59b62df00b54 +size 1768898 diff --git a/assets/test/14.png b/assets/test/14.png new file mode 100644 index 0000000000000000000000000000000000000000..9fe23cc235caabe20470240b962557f71bf21a8a --- /dev/null +++ b/assets/test/14.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dccf9a1727df51d81d7dd7f4168f8db4df40123f453474bc84dee3cd97bdde82 +size 1989606 diff --git a/assets/test/15.png b/assets/test/15.png new file mode 100644 index 0000000000000000000000000000000000000000..0731bb00e1a8e7ce7349feccfcad8f98e0096f9c --- /dev/null +++ b/assets/test/15.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6025c32c9a0bf8d88c2f66985234a6a283ebd3cecabe16be1be0e1f9d89147f9 +size 1113661 diff --git a/assets/test/16.png b/assets/test/16.png new file mode 100644 index 0000000000000000000000000000000000000000..3f48cf6cc141087bb3f3e611342bb823d4f05114 --- /dev/null +++ b/assets/test/16.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:817ca70e9af0ad89af747fb389e2be833aefc3becbe9d86506bd53452cea12c8 +size 2265160 diff --git a/assets/test/17.png b/assets/test/17.png new file mode 100644 index 0000000000000000000000000000000000000000..41bf8216b6f87a4ed2543d6dbbfd31da9c37a8f8 --- /dev/null +++ b/assets/test/17.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4de91821f1e3d7db6834ba679aa1b7f78c7a52f5ce9c2cd0c44ccbf6d649bb7d +size 603365 diff --git a/assets/test/18.png b/assets/test/18.png new file mode 100644 index 0000000000000000000000000000000000000000..8504961d3657175523ffc7733bb816825139a67d --- /dev/null +++ b/assets/test/18.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3cad5e839a8c54334c28aa08aec2d185b2ce3e0a96c0f8dca901d7bc8535079c +size 1425119 diff --git a/assets/test/19.png b/assets/test/19.png new file mode 100644 index 0000000000000000000000000000000000000000..2bc3df887c72ff0cc7ac92f223e2c80cd2d5d1ff --- /dev/null +++ b/assets/test/19.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f69ec7094bdb98bf69d8345c4cc39e00648cc74721879d43c784ad774786e6af +size 1171141 diff --git a/assets/test/2.png b/assets/test/2.png new file mode 100644 index 0000000000000000000000000000000000000000..2c186cf08bf92cacfc5b54757bec340b5d3a0694 --- /dev/null +++ b/assets/test/2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c01e935783d7a81eb9024c6b0ca450fe3d69cdf92a816ae1e03ff57c78024875 +size 848588 diff --git a/assets/test/20.png b/assets/test/20.png new file mode 100644 index 0000000000000000000000000000000000000000..4d7e40882784eb7bafeae6b9f1ffa9adf68a967b --- /dev/null +++ b/assets/test/20.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:97ba5a7b42ec317aee9ee4d3cdbd62ac507c3554f4c1dca7ce0f635039cc16e3 +size 601441 diff --git a/assets/test/21.png b/assets/test/21.png new file mode 100644 index 
0000000000000000000000000000000000000000..a8b79f3a2a8ebabfad60821e6362b25bf035c395 --- /dev/null +++ b/assets/test/21.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9ae71e62d5893669426d94226fe376f4bc4269dfdb83a465e0510b6d8002deb5 +size 808986 diff --git a/assets/test/22.png b/assets/test/22.png new file mode 100644 index 0000000000000000000000000000000000000000..99907ec3d36814eb5e58ca9a0c2c9dc98dec1998 --- /dev/null +++ b/assets/test/22.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ff910bc20816d8a55ff08635a0228abbe821f8cd8359c86ef3f044a16cebe6f7 +size 1175586 diff --git a/assets/test/23.png b/assets/test/23.png new file mode 100644 index 0000000000000000000000000000000000000000..8da4acda1d51cc83eacb22c46a6fb9602940b51c --- /dev/null +++ b/assets/test/23.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1f4c57797cea4d29e465d556b63c2d367fa0f4469fd3adf31ff686665cce38a1 +size 834642 diff --git a/assets/test/24.png b/assets/test/24.png new file mode 100644 index 0000000000000000000000000000000000000000..561190e7fd82ffc201824a4832f85521e7101c15 --- /dev/null +++ b/assets/test/24.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ab1feb4c8b0e507f16248c65826961820b0367ea897705c8cc19a802824a2a06 +size 1805891 diff --git a/assets/test/25.png b/assets/test/25.png new file mode 100644 index 0000000000000000000000000000000000000000..23ea0c33413ea8cab436d7d12c150782cdb243c1 --- /dev/null +++ b/assets/test/25.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dddb4e689ef77d87c3849e021e322f9b30ce384ed83ba60c6b0e4b55ec5202f2 +size 1193255 diff --git a/assets/test/26.png b/assets/test/26.png new file mode 100644 index 0000000000000000000000000000000000000000..ec3004593f76ff2c199756915c6a78cda03148ab --- /dev/null +++ b/assets/test/26.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:96e0987b6d3c0e7a883667dfb2e52461fad04a621a43a95a2ec1544c0e5b8683 +size 1037838 diff --git a/assets/test/27.png b/assets/test/27.png new file mode 100644 index 0000000000000000000000000000000000000000..974a3047b590d8d67b7a02eb3d74d9b2fc791d74 --- /dev/null +++ b/assets/test/27.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:711cd06fec16402e73c01bd80e202ae91b3f18c3d9372b57447c00d2a4491d03 +size 1489023 diff --git a/assets/test/28.png b/assets/test/28.png new file mode 100644 index 0000000000000000000000000000000000000000..99a214e54a7098ddb87b15acfd802d495aea7493 --- /dev/null +++ b/assets/test/28.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c87191606ad374fc506857050f44e8d8af0e8f1227391cd98bab6eb5bf7dc1c2 +size 1637020 diff --git a/assets/test/29.png b/assets/test/29.png new file mode 100644 index 0000000000000000000000000000000000000000..843926b25a4eba8be295be228caacf38725b8eb6 --- /dev/null +++ b/assets/test/29.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9b0f9ba7069e600b470fa8ba3cfcff890a9bd5d481eaab7d25fe190e69fdf4ac +size 1579767 diff --git a/assets/test/3.png b/assets/test/3.png new file mode 100644 index 0000000000000000000000000000000000000000..7f86106be5d6ae8f37f27d6d97e87a3651afba9b --- /dev/null +++ b/assets/test/3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:16af3c8ee909ef2a5e13f996a9b79029bc385c1ccbd9a8cf0613081ead844e63 +size 1384607 diff --git a/assets/test/30.png b/assets/test/30.png new file mode 100644 index 
0000000000000000000000000000000000000000..d5b33df498d2355533acf6b7a2d1536562e9c855 --- /dev/null +++ b/assets/test/30.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f7c925eefcf85714df9fbf4d602e9ad604d7efc2a87b57d72741049737dc9747 +size 2476255 diff --git a/assets/test/31.png b/assets/test/31.png new file mode 100644 index 0000000000000000000000000000000000000000..16b8e97073d4f77bb6eec6f810389612b37535f3 --- /dev/null +++ b/assets/test/31.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:873a12cd691b851939e1558d203021462543f0cb8167d65fa468ec93793192b6 +size 1480380 diff --git a/assets/test/32.png b/assets/test/32.png new file mode 100644 index 0000000000000000000000000000000000000000..07810ae63439da49dcf6d3abb0bc9acb348ba864 --- /dev/null +++ b/assets/test/32.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aa8ff1fcc5b43d9382bb196665b624d19ce5a5204b618c9adb87520a460efee6 +size 1114250 diff --git a/assets/test/33.png b/assets/test/33.png new file mode 100644 index 0000000000000000000000000000000000000000..44658a0bc404a3a2b0d42106b6b7e37c0f6dc42e --- /dev/null +++ b/assets/test/33.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7bac56a98e7ff510cfb628806c11abee90ea7f66dd6997a1e690a30aa2af7560 +size 1025378 diff --git a/assets/test/34.png b/assets/test/34.png new file mode 100644 index 0000000000000000000000000000000000000000..d268b7332790b789c1efbc9603dd3b36801f77ef --- /dev/null +++ b/assets/test/34.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ed6537fe32aa665fe6ff180cbd08ccc3dca87da382e4165e4e8865e26d80bf3 +size 2005294 diff --git a/assets/test/35.png b/assets/test/35.png new file mode 100644 index 0000000000000000000000000000000000000000..d4f189675bc888238896eab93f4aefa6e2e32a9b --- /dev/null +++ b/assets/test/35.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:86a171e37a3d781e7215977f565cd63e813341c1f89e2c586fa61937e4ed6916 +size 481919 diff --git a/assets/test/36.png b/assets/test/36.png new file mode 100644 index 0000000000000000000000000000000000000000..3dddee1020e052155d2e8f404d982f45786c5d07 --- /dev/null +++ b/assets/test/36.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:27a418853eefa197f1e10ed944a7bb071413fd2bc1681804ee773a6ce3799c52 +size 712062 diff --git a/assets/test/37.png b/assets/test/37.png new file mode 100644 index 0000000000000000000000000000000000000000..18e37dc2fa5cc9fd7cd4003828e85514f0b0f780 --- /dev/null +++ b/assets/test/37.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aecbc5712f300ec67fb01d79cd758e7cb5da4c11c9e19d4a8e72d05275016766 +size 1890500 diff --git a/assets/test/38.png b/assets/test/38.png new file mode 100644 index 0000000000000000000000000000000000000000..293dc7dc50f3fc6fb884c8a0ba64a86df178a1b8 --- /dev/null +++ b/assets/test/38.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98bab768078960656fab3755af69954055fa4c65561376085e7a80431dbc7b9d +size 1792829 diff --git a/assets/test/39.png b/assets/test/39.png new file mode 100644 index 0000000000000000000000000000000000000000..92dfce03092f39dda00ba86f9499c280bc26f506 --- /dev/null +++ b/assets/test/39.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:66b568e2951c2db7c59a73b4498cac011535b414f7d7d4327235039467f5429f +size 2004638 diff --git a/assets/test/4.png b/assets/test/4.png new file mode 100644 index 
0000000000000000000000000000000000000000..955991400960ab8c6ca8f6b876097823be39f940 --- /dev/null +++ b/assets/test/4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5546b5268a51f20dc15eaba836ba09954ec98fb347f5c0b4baefccac0f596757 +size 1505047 diff --git a/assets/test/40.png b/assets/test/40.png new file mode 100644 index 0000000000000000000000000000000000000000..a1a72358f3aa124cf5835b0460205524d3c27d0b --- /dev/null +++ b/assets/test/40.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e75bebfbefd8f67818c5e051cac3cb4853b418974a364fff8a14c9a4b7e7eba8 +size 686837 diff --git a/assets/test/41.png b/assets/test/41.png new file mode 100644 index 0000000000000000000000000000000000000000..bdf9dc4534a26b05e13e42f339f65b9caa76df28 --- /dev/null +++ b/assets/test/41.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a95ac40d2ae1c978baa64c4e41c1c68ee5742c71236ae330de6e6dbb148dc84 +size 611258 diff --git a/assets/test/42.png b/assets/test/42.png new file mode 100644 index 0000000000000000000000000000000000000000..3a0e2fe912c7f5818a6efa34e8951e37820d24f0 --- /dev/null +++ b/assets/test/42.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0449cf9ab9dbe9b044585da4b940b9ad2d2fda158a6eea3a8abf3555f9e4bc9d +size 1687128 diff --git a/assets/test/43.png b/assets/test/43.png new file mode 100644 index 0000000000000000000000000000000000000000..d1e9fd9fa970998582cd58e3b6e0b7544547022e --- /dev/null +++ b/assets/test/43.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5c6928cdeb8ee4dcef2d4c4e6b5dd0c69f2ecc3885daeb1cdf9bfbc40da0c01e +size 1447356 diff --git a/assets/test/44.png b/assets/test/44.png new file mode 100644 index 0000000000000000000000000000000000000000..d5b9eb33cb3aecab7070e13e6a543f52a8c3aef8 --- /dev/null +++ b/assets/test/44.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:507abeb3f7698e5e929152b4878152626aa1270554bfa5a84ff74dc6114ea3c1 +size 1978541 diff --git a/assets/test/45.png b/assets/test/45.png new file mode 100644 index 0000000000000000000000000000000000000000..a702921607bb16938c02c99766d49abf4fa880ab --- /dev/null +++ b/assets/test/45.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e89b245fbe25abe19d817866a2c38b3b2c7d130a6b21b1ec6391975671d4d06d +size 1951359 diff --git a/assets/test/46.png b/assets/test/46.png new file mode 100644 index 0000000000000000000000000000000000000000..a6f3bf5162ef0d3776be36a6d0a2ac061f56a75c --- /dev/null +++ b/assets/test/46.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ddddfdeb15b82d85b90684b2aba221705283894f2f69532fcf84512107b85f67 +size 1162177 diff --git a/assets/test/47.png b/assets/test/47.png new file mode 100644 index 0000000000000000000000000000000000000000..aa56b73dbb206ba5dfc03b08f61d87110b64ee0b --- /dev/null +++ b/assets/test/47.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:060a286bffcf9ee71404b62f240253788ea4b15b026f3583b06111070027fa0d +size 1896047 diff --git a/assets/test/48.png b/assets/test/48.png new file mode 100644 index 0000000000000000000000000000000000000000..a599f773bcb21f82bccfa1a5d0fde5c4a324dff0 --- /dev/null +++ b/assets/test/48.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b28719959fc116464429e1c677f6676f2b37e8ccd592c4f1d08400bc48f3adac +size 2301834 diff --git a/assets/test/49.png b/assets/test/49.png new file mode 100644 index 
0000000000000000000000000000000000000000..66a1406029617fdf2557efb8a8011ed08a6a0810 --- /dev/null +++ b/assets/test/49.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:94f36bdae81c4c7cf6b4a402b704bb29aa1a56ca6cee7671a41ca041c0560203 +size 2206909 diff --git a/assets/test/5.png b/assets/test/5.png new file mode 100644 index 0000000000000000000000000000000000000000..61571a7ac4d2ce6e8e0d234cc6f0a8e13218ccab --- /dev/null +++ b/assets/test/5.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:412c5f398bab3dc08395a8140777d883b3291eda08e2d6346cb93b4aeb42407e +size 1638078 diff --git a/assets/test/50.png b/assets/test/50.png new file mode 100644 index 0000000000000000000000000000000000000000..88a413563ecb21ac5b83c541b89bb27d7772d5dd --- /dev/null +++ b/assets/test/50.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:45e7810d3358274a0f32a4ecaf57a04e88d3438c46cd0a6bd76a3e4bcf149e6b +size 2293203 diff --git a/assets/test/51.png b/assets/test/51.png new file mode 100644 index 0000000000000000000000000000000000000000..95f0d876e8380f7560d0d4aaab7d3db44fae7e15 --- /dev/null +++ b/assets/test/51.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a9cea2e8947d4c57a95157836fc4a7c0d36335748f394af01a9f27f6b933ea82 +size 1359226 diff --git a/assets/test/52.png b/assets/test/52.png new file mode 100644 index 0000000000000000000000000000000000000000..4bd172e1211c480e2e9b39613d897f0acfdc5678 --- /dev/null +++ b/assets/test/52.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:402758b579b0d96043c613c93125a59e183b6a06acb9c0fad7dc30f1c1a3d40c +size 862029 diff --git a/assets/test/53.png b/assets/test/53.png new file mode 100644 index 0000000000000000000000000000000000000000..38fdc4c21ef6f7f5d8623948d43d81acfafd2708 --- /dev/null +++ b/assets/test/53.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fdd82d60b7ec11e6d5699df29693d8ab538f9dab4b04e3f2abaa59ccd7b4709a +size 226408 diff --git a/assets/test/54.png b/assets/test/54.png new file mode 100644 index 0000000000000000000000000000000000000000..ff256723431c2729c3eff55f067fb6f7120d47ee --- /dev/null +++ b/assets/test/54.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5793b52a21507a5cdd661f8d87681b303e420057a8c65aaa16e0560409a7a34e +size 724846 diff --git a/assets/test/55.png b/assets/test/55.png new file mode 100644 index 0000000000000000000000000000000000000000..4e82000c43b029e236b85b135dbba66ea04d4871 --- /dev/null +++ b/assets/test/55.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4a6ff25b573a607605c607815f99b566002f47dc5f69cfe145248b8c0a4855f5 +size 883754 diff --git a/assets/test/56.png b/assets/test/56.png new file mode 100644 index 0000000000000000000000000000000000000000..7cf47bc68a3197540c765c8f8e246848c00e0e49 --- /dev/null +++ b/assets/test/56.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:43a71e868ab4a5b174ceb5ceaa0a7e76f6a0a635acc246eaaa7b83da47357b82 +size 1041398 diff --git a/assets/test/57.png b/assets/test/57.png new file mode 100644 index 0000000000000000000000000000000000000000..fb69455a8960874fba38700b2c9a495704d0a283 --- /dev/null +++ b/assets/test/57.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5677117de95f646acd0f4455d32bae8cabb2fef7567a64bab9c8638e8c98cbbb +size 862063 diff --git a/assets/test/58.png b/assets/test/58.png new file mode 100644 index 
0000000000000000000000000000000000000000..35285ba5c6f01bbcbee4a085e8c3e2379446ef19 --- /dev/null +++ b/assets/test/58.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b235ed827731813445fd2413aaddb7bef9523d20d23ce1dff7d39fd2bc6829da +size 248153 diff --git a/assets/test/59.png b/assets/test/59.png new file mode 100644 index 0000000000000000000000000000000000000000..0784bcdfd94be087b1fbc441a25c302cf2b5f361 --- /dev/null +++ b/assets/test/59.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:73d09b20fb4e7ab6513a3ac99f7c87166ec655a731bd1be9b14cec6aef36c4ef +size 1509264 diff --git a/assets/test/6.png b/assets/test/6.png new file mode 100644 index 0000000000000000000000000000000000000000..9238b926417cf808022822562f148dcdc4213e2f --- /dev/null +++ b/assets/test/6.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7c5472bb12b775c18b37fcb889069915943c4b6122fda668d3a1df8003e9e5da +size 1986433 diff --git a/assets/test/60.png b/assets/test/60.png new file mode 100644 index 0000000000000000000000000000000000000000..4542fab15508116ea9c49d764a997263757f5f73 --- /dev/null +++ b/assets/test/60.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:adba4a0efaad6e22b667b516a2724d62a40b09726578cd71dc3b2abcfcb7c5a2 +size 437697 diff --git a/assets/test/61.png b/assets/test/61.png new file mode 100644 index 0000000000000000000000000000000000000000..1546f322acc31caeb524a518979afbffe8de197f --- /dev/null +++ b/assets/test/61.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d7e716abe8f8895080f562d1dc26b14fa0e20a05aa5beb2770c6fb3b87b3476a +size 594232 diff --git a/assets/test/62.png b/assets/test/62.png new file mode 100644 index 0000000000000000000000000000000000000000..4c633d74618978d918461428d1d76df8805c7554 --- /dev/null +++ b/assets/test/62.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:70f86ca18bf83b67e57d74ba01eb045047fe407353a69f9d1bc75015d1606346 +size 591829 diff --git a/assets/test/63.png b/assets/test/63.png new file mode 100644 index 0000000000000000000000000000000000000000..c966e77ebf366c679ad749dba14a06616f3d2c89 --- /dev/null +++ b/assets/test/63.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8799fd1b56df540f2c7471312be5959fbf49d33d2386e59b41cb7138ad4694b3 +size 646495 diff --git a/assets/test/64.png b/assets/test/64.png new file mode 100644 index 0000000000000000000000000000000000000000..3536dd866f2332d60eb149cc1243ea5c5392e228 --- /dev/null +++ b/assets/test/64.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:97ba7df8adf0e956c5cbf85fafb127b9b21a7873c33629a476af40bac91d5655 +size 397994 diff --git a/assets/test/65.png b/assets/test/65.png new file mode 100644 index 0000000000000000000000000000000000000000..9f29b4f1fa5537ec16bbf9835cae5bb9e75640de --- /dev/null +++ b/assets/test/65.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:907cacfd8e66fd8424389e6560e66ecdd98544f31076ad1e0c9b6434a47a9747 +size 980850 diff --git a/assets/test/66.png b/assets/test/66.png new file mode 100644 index 0000000000000000000000000000000000000000..51b786a1d5cd52c9e7207961693172800ca336b8 --- /dev/null +++ b/assets/test/66.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b8236aeb0cf0923df29e3d76f3cbe6cc61c9da85ea4b43a4e7cf81614fe750cd +size 1182506 diff --git a/assets/test/67.png b/assets/test/67.png new file mode 100644 index 
0000000000000000000000000000000000000000..43173a46412aec0a4109bdc408171e383262042e --- /dev/null +++ b/assets/test/67.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:134136dd4086cfc1b887ab0a134c4a2b906223762a0d5959a8b90cc68f11f4f0 +size 1490139 diff --git a/assets/test/7.png b/assets/test/7.png new file mode 100644 index 0000000000000000000000000000000000000000..754a251ab29736f9cb564a2c43b7b4f47317f4e5 --- /dev/null +++ b/assets/test/7.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f820d9e9fe86192e6e1a70dc2430f3189bf620aff6e16aa817e6b99da286c424 +size 990354 diff --git a/assets/test/8.png b/assets/test/8.png new file mode 100644 index 0000000000000000000000000000000000000000..76541834fc6e5e4120c225238e412596e9199ec2 --- /dev/null +++ b/assets/test/8.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8bd5e913171ad456d607442af4ad89bd6e0da6cd819aedcd76d0fa1c2d4ff655 +size 1161291 diff --git a/assets/test/9.png b/assets/test/9.png new file mode 100644 index 0000000000000000000000000000000000000000..0f16b1297ac470b4a444965eac7748723ee027dd --- /dev/null +++ b/assets/test/9.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c22e9e601fead00398634e083e2ed1e90d0c60f459757ae19f5909f3ee5481aa +size 1029084 diff --git a/direct3d_s2/__pycache__/pipeline.cpython-310.pyc b/direct3d_s2/__pycache__/pipeline.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7d119d458e52959e75cd802b7a58f1ee690fb8e Binary files /dev/null and b/direct3d_s2/__pycache__/pipeline.cpython-310.pyc differ diff --git a/direct3d_s2/models/__init__.py b/direct3d_s2/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/direct3d_s2/models/__pycache__/__init__.cpython-310.pyc b/direct3d_s2/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35967be27d8f00a7776ebecaa3586e0988a4fd1f Binary files /dev/null and b/direct3d_s2/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/direct3d_s2/models/__pycache__/conditioner.cpython-310.pyc b/direct3d_s2/models/__pycache__/conditioner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5208e75587b28c6621a1973cb5ce3d90e4bdd2cb Binary files /dev/null and b/direct3d_s2/models/__pycache__/conditioner.cpython-310.pyc differ diff --git a/direct3d_s2/models/autoencoders/__pycache__/base.cpython-310.pyc b/direct3d_s2/models/autoencoders/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28713e652176cfb9ced950d4f71f8dabf564205b Binary files /dev/null and b/direct3d_s2/models/autoencoders/__pycache__/base.cpython-310.pyc differ diff --git a/direct3d_s2/models/autoencoders/__pycache__/decoder.cpython-310.pyc b/direct3d_s2/models/autoencoders/__pycache__/decoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8a4507a17cd98df6fb5f411b20d07c00f566f1b Binary files /dev/null and b/direct3d_s2/models/autoencoders/__pycache__/decoder.cpython-310.pyc differ diff --git a/direct3d_s2/models/autoencoders/__pycache__/dense_vae.cpython-310.pyc b/direct3d_s2/models/autoencoders/__pycache__/dense_vae.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce375b2d5118ca2ecce50444ff96216e9d9bb744 Binary files /dev/null and 
b/direct3d_s2/models/autoencoders/__pycache__/dense_vae.cpython-310.pyc differ diff --git a/direct3d_s2/models/autoencoders/__pycache__/distributions.cpython-310.pyc b/direct3d_s2/models/autoencoders/__pycache__/distributions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc13439a4a922a3ae606562e28746aa0999a762f Binary files /dev/null and b/direct3d_s2/models/autoencoders/__pycache__/distributions.cpython-310.pyc differ diff --git a/direct3d_s2/models/autoencoders/__pycache__/encoder.cpython-310.pyc b/direct3d_s2/models/autoencoders/__pycache__/encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c8314f083a3dd0b61a39b35a9153c029136e1e0 Binary files /dev/null and b/direct3d_s2/models/autoencoders/__pycache__/encoder.cpython-310.pyc differ diff --git a/direct3d_s2/models/autoencoders/__pycache__/ss_vae.cpython-310.pyc b/direct3d_s2/models/autoencoders/__pycache__/ss_vae.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1377751eb6d3e498ed44126e3822f693447a8fc6 Binary files /dev/null and b/direct3d_s2/models/autoencoders/__pycache__/ss_vae.cpython-310.pyc differ diff --git a/direct3d_s2/models/autoencoders/base.py b/direct3d_s2/models/autoencoders/base.py new file mode 100644 index 0000000000000000000000000000000000000000..f4ffeea9be7a588d0de1de9aae07ba00dbca99e1 --- /dev/null +++ b/direct3d_s2/models/autoencoders/base.py @@ -0,0 +1,118 @@ +from typing import * +import torch +import torch.nn as nn + +from ...modules.utils import convert_module_to_f16, convert_module_to_f32 +from ...modules import sparse as sp +from ...modules.transformer import AbsolutePositionEmbedder +from ...modules.sparse.transformer import SparseTransformerBlock + + +def block_attn_config(self): + """ + Return the attention configuration of the model. + """ + for i in range(self.num_blocks): + if self.attn_mode == "shift_window": + yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER + elif self.attn_mode == "shift_sequence": + yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER + elif self.attn_mode == "shift_order": + yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4] + elif self.attn_mode == "full": + yield "full", None, None, None, None + elif self.attn_mode == "swin": + yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None + + +class SparseTransformerBase(nn.Module): + """ + Sparse Transformer without output layers. + Serve as the base class for encoder and decoder. 
+ """ + def __init__( + self, + in_channels: int, + model_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.model_channels = model_channels + self.num_blocks = num_blocks + self.window_size = window_size + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.attn_mode = attn_mode + self.pe_mode = pe_mode + self.use_fp16 = use_fp16 + self.use_checkpoint = use_checkpoint + self.qk_rms_norm = qk_rms_norm + self.dtype = torch.float16 if use_fp16 else torch.float32 + + if pe_mode == "ape": + self.pos_embedder = AbsolutePositionEmbedder(model_channels) + self.input_layer = sp.SparseLinear(in_channels, model_channels) + self.blocks = nn.ModuleList([ + SparseTransformerBlock( + model_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + qk_rms_norm=self.qk_rms_norm, + ) + for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self) + ]) + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + # self.blocks.apply(convert_module_to_f16) + self.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. 
+ """ + self.blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + def forward(self, x: sp.SparseTensor, factor: float = None) -> sp.SparseTensor: + h = self.input_layer(x) + if self.pe_mode == "ape": + h = h + self.pos_embedder(x.coords[:, 1:], factor) + h = h.type(self.dtype) + for block in self.blocks: + h = block(h) + return h \ No newline at end of file diff --git a/direct3d_s2/models/autoencoders/decoder.py b/direct3d_s2/models/autoencoders/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..209cccf24f80dfcd962610c31811dadfeaf50812 --- /dev/null +++ b/direct3d_s2/models/autoencoders/decoder.py @@ -0,0 +1,353 @@ +from typing import * +import random +import torch +import torch.nn as nn +import torch.nn.functional as F +from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 +from ...modules import sparse as sp +from .base import SparseTransformerBase + + +class SparseSubdivideBlock3d(nn.Module): + + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + use_checkpoint: bool = False, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_checkpoint = use_checkpoint + + self.act_layers = nn.Sequential( + sp.SparseConv3d(channels, self.out_channels, 3, padding=1), + sp.SparseSiLU() + ) + + self.sub = sp.SparseSubdivide() + + self.out_layers = nn.Sequential( + sp.SparseConv3d(self.out_channels, self.out_channels, 3, padding=1), + sp.SparseSiLU(), + ) + + def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + h = self.act_layers(x) + h = self.sub(h) + h = self.out_layers(h) + return h + + def forward(self, x: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseSDFDecoder(SparseTransformerBase): + def __init__( + self, + resolution: int, + model_channels: int, + latent_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", + window_size: int = 8, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + representation_config: dict = None, + out_channels: int = 1, + chunk_size: int = 1, + ): + super().__init__( + in_channels=latent_channels, + model_channels=model_channels, + num_blocks=num_blocks, + num_heads=num_heads, + num_head_channels=num_head_channels, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + pe_mode=pe_mode, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + qk_rms_norm=qk_rms_norm, + ) + self.resolution = resolution + self.rep_config = representation_config + self.out_channels = out_channels + self.chunk_size = chunk_size + self.upsample = nn.ModuleList([ + SparseSubdivideBlock3d( + channels=model_channels, + out_channels=model_channels // 4, + use_checkpoint=use_checkpoint, + ), + SparseSubdivideBlock3d( + channels=model_channels // 4, + out_channels=model_channels // 8, + use_checkpoint=use_checkpoint, + ), + SparseSubdivideBlock3d( + channels=model_channels // 8, + 
out_channels=model_channels // 16, + use_checkpoint=use_checkpoint, + ) + ]) + + self.out_layer = sp.SparseLinear(model_channels // 16, self.out_channels) + self.out_active = sp.SparseTanh() + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + def initialize_weights(self) -> None: + super().initialize_weights() + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + super().convert_to_fp16() + self.upsample.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + super().convert_to_fp32() + self.upsample.apply(convert_module_to_f32) + + @torch.no_grad() + def split_for_meshing(self, x: sp.SparseTensor, chunk_size=4, padding=4): + + sub_resolution = self.resolution // chunk_size + upsample_ratio = 8 # hard-coded here + assert sub_resolution % padding == 0 + out = [] + + for i in range(chunk_size): + for j in range(chunk_size): + for k in range(chunk_size): + # Calculate padded boundaries + start_x = max(0, i * sub_resolution - padding) + end_x = min((i + 1) * sub_resolution + padding, self.resolution) + start_y = max(0, j * sub_resolution - padding) + end_y = min((j + 1) * sub_resolution + padding, self.resolution) + start_z = max(0, k * sub_resolution - padding) + end_z = min((k + 1) * sub_resolution + padding, self.resolution) + + # Store original (unpadded) boundaries for later cropping + orig_start_x = i * sub_resolution + orig_end_x = (i + 1) * sub_resolution + orig_start_y = j * sub_resolution + orig_end_y = (j + 1) * sub_resolution + orig_start_z = k * sub_resolution + orig_end_z = (k + 1) * sub_resolution + + mask = torch.logical_and( + torch.logical_and( + torch.logical_and(x.coords[:, 1] >= start_x, x.coords[:, 1] < end_x), + torch.logical_and(x.coords[:, 2] >= start_y, x.coords[:, 2] < end_y) + ), + torch.logical_and(x.coords[:, 3] >= start_z, x.coords[:, 3] < end_z) + ) + + if mask.sum() > 0: + # Get the coordinates and shift them to local space + coords = x.coords[mask].clone() + # Shift to local coordinates + coords[:, 1:] = coords[:, 1:] - torch.tensor([start_x, start_y, start_z], + device=coords.device).view(1, 3) + + chunk_tensor = sp.SparseTensor(x.feats[mask], coords) + # Store the boundaries and offsets as metadata for later reconstruction + chunk_tensor.bounds = { + 'original': (orig_start_x * upsample_ratio, orig_end_x * upsample_ratio + (upsample_ratio - 1), orig_start_y * upsample_ratio, orig_end_y * upsample_ratio + (upsample_ratio - 1), orig_start_z * upsample_ratio, orig_end_z * upsample_ratio + (upsample_ratio - 1)), + 'offsets': (start_x * upsample_ratio, start_y * upsample_ratio, start_z * upsample_ratio) # Store offsets for reconstruction + } + out.append(chunk_tensor) + + del mask + torch.cuda.empty_cache() + return out + + @torch.no_grad() + def split_single_chunk(self, x: sp.SparseTensor, chunk_size=4, padding=4): + sub_resolution = self.resolution // chunk_size + upsample_ratio = 8 # hard-coded here + assert sub_resolution % padding == 0 + + mask_sum = -1 + while mask_sum < 1: + orig_start_x = random.randint(0, self.resolution - sub_resolution) + orig_end_x = orig_start_x + sub_resolution + orig_start_y = random.randint(0, self.resolution - sub_resolution) + orig_end_y = orig_start_y + sub_resolution + orig_start_z = random.randint(0, self.resolution - sub_resolution) + orig_end_z = orig_start_z + 
sub_resolution + start_x = max(0, orig_start_x - padding) + end_x = min(orig_end_x + padding, self.resolution) + start_y = max(0, orig_start_y - padding) + end_y = min(orig_end_y + padding, self.resolution) + start_z = max(0, orig_start_z - padding) + end_z = min(orig_end_z + padding, self.resolution) + + mask_ori = torch.logical_and( + torch.logical_and( + torch.logical_and(x.coords[:, 1] >= orig_start_x, x.coords[:, 1] < orig_end_x), + torch.logical_and(x.coords[:, 2] >= orig_start_y, x.coords[:, 2] < orig_end_y) + ), + torch.logical_and(x.coords[:, 3] >= orig_start_z, x.coords[:, 3] < orig_end_z) + ) + mask_sum = mask_ori.sum() + + # Store the boundaries and offsets as metadata for later reconstruction + bounds = { + 'original': (orig_start_x * upsample_ratio, orig_end_x * upsample_ratio + (upsample_ratio - 1), orig_start_y * upsample_ratio, orig_end_y * upsample_ratio + (upsample_ratio - 1), orig_start_z * upsample_ratio, orig_end_z * upsample_ratio + (upsample_ratio - 1)), + 'start': (start_x, end_x, start_y, end_y, start_z, end_z), + 'offsets': (start_x * upsample_ratio, start_y * upsample_ratio, start_z * upsample_ratio) # Store offsets for reconstruction + } + return bounds + + def forward_single_chunk(self, x: sp.SparseTensor, padding=4): + + bounds = self.split_single_chunk(x, self.chunk_size, padding=padding) + + start_x, end_x, start_y, end_y, start_z, end_z = bounds['start'] + mask = torch.logical_and( + torch.logical_and( + torch.logical_and(x.coords[:, 1] >= start_x, x.coords[:, 1] < end_x), + torch.logical_and(x.coords[:, 2] >= start_y, x.coords[:, 2] < end_y) + ), + torch.logical_and(x.coords[:, 3] >= start_z, x.coords[:, 3] < end_z) + ) + + # Shift to local coordinates + coords = x.coords.clone() + coords[:, 1:] = coords[:, 1:] - torch.tensor([start_x, start_y, start_z], + device=coords.device).view(1, 3) + + chunk = sp.SparseTensor(x.feats[mask], coords[mask]) + + chunk_result = self.upsamples(chunk) + + coords = chunk_result.coords.clone() + + # Restore global coordinates + offsets = torch.tensor(bounds['offsets'], + device=coords.device).view(1, 3) + coords[:, 1:] = coords[:, 1:] + offsets + + # Filter points within original bounds + original = bounds['original'] + within_bounds = torch.logical_and( + torch.logical_and( + torch.logical_and( + coords[:, 1] >= original[0], + coords[:, 1] < original[1] + ), + torch.logical_and( + coords[:, 2] >= original[2], + coords[:, 2] < original[3] + ) + ), + torch.logical_and( + coords[:, 3] >= original[4], + coords[:, 3] < original[5] + ) + ) + + final_coords = coords[within_bounds] + final_feats = chunk_result.feats[within_bounds] + + return sp.SparseTensor(final_feats, final_coords) + + def upsamples(self, x, return_feat: bool = False): + dtype = x.dtype + for block in self.upsample: + x = block(x) + x = x.type(dtype) + + output = self.out_active(self.out_layer(x)) + + if return_feat: + return output, x + else: + return output + + def forward(self, x: sp.SparseTensor, factor: float = None, return_feat: bool = False): + h = super().forward(x, factor) + if self.chunk_size <= 1: + for block in self.upsample: + h = block(h) + h = h.type(x.dtype) + + if return_feat: + return self.out_active(self.out_layer(h)), h + + h = self.out_layer(h) + h = self.out_active(h) + return h + else: + if self.training: + return self.forward_single_chunk(h) + else: + batch_size = x.shape[0] + chunks = self.split_for_meshing(h, chunk_size=self.chunk_size) + all_coords, all_feats = [], [] + for chunk_idx, chunk in enumerate(chunks): + chunk_result = 
self.upsamples(chunk) + + for b in range(batch_size): + mask = torch.nonzero(chunk_result.coords[:, 0] == b).squeeze(-1) + if mask.numel() > 0: + coords = chunk_result.coords[mask].clone() + + # Restore global coordinates + offsets = torch.tensor(chunk.bounds['offsets'], + device=coords.device).view(1, 3) + coords[:, 1:] = coords[:, 1:] + offsets + + # Filter points within original bounds + bounds = chunk.bounds['original'] + within_bounds = torch.logical_and( + torch.logical_and( + torch.logical_and( + coords[:, 1] >= bounds[0], + coords[:, 1] < bounds[1] + ), + torch.logical_and( + coords[:, 2] >= bounds[2], + coords[:, 2] < bounds[3] + ) + ), + torch.logical_and( + coords[:, 3] >= bounds[4], + coords[:, 3] < bounds[5] + ) + ) + + if within_bounds.any(): + all_coords.append(coords[within_bounds]) + all_feats.append(chunk_result.feats[mask][within_bounds]) + + if not self.training: + torch.cuda.empty_cache() + + final_coords = torch.cat(all_coords) + final_feats = torch.cat(all_feats) + + return sp.SparseTensor(final_feats, final_coords) + \ No newline at end of file diff --git a/direct3d_s2/models/autoencoders/dense_vae.py b/direct3d_s2/models/autoencoders/dense_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..f7aa0cfdc8a307b65c3ed2746ee2e4a5f4537c45 --- /dev/null +++ b/direct3d_s2/models/autoencoders/dense_vae.py @@ -0,0 +1,401 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import trimesh +from skimage import measure +from ...modules.norm import GroupNorm32, ChannelLayerNorm32 +from ...modules.spatial import pixel_shuffle_3d +from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 +from .distributions import DiagonalGaussianDistribution + + +def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module: + """ + Return a normalization layer. 
+ """ + if norm_type == "group": + return GroupNorm32(32, *args, **kwargs) + elif norm_type == "layer": + return ChannelLayerNorm32(*args, **kwargs) + else: + raise ValueError(f"Invalid norm type {norm_type}") + + +class ResBlock3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + norm_type: Literal["group", "layer"] = "layer", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.norm1 = norm_layer(norm_type, channels) + self.norm2 = norm_layer(norm_type, self.out_channels) + self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1) + self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1)) + self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm1(x) + h = F.silu(h) + h = self.conv1(h) + h = self.norm2(h) + h = F.silu(h) + h = self.conv2(h) + h = h + self.skip_connection(x) + return h + + +class DownsampleBlock3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + mode: Literal["conv", "avgpool"] = "conv", + ): + assert mode in ["conv", "avgpool"], f"Invalid mode {mode}" + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + if mode == "conv": + self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2) + elif mode == "avgpool": + assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "conv"): + return self.conv(x) + else: + return F.avg_pool3d(x, 2) + + +class UpsampleBlock3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + mode: Literal["conv", "nearest"] = "conv", + ): + assert mode in ["conv", "nearest"], f"Invalid mode {mode}" + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + if mode == "conv": + self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1) + elif mode == "nearest": + assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "conv"): + x = self.conv(x) + return pixel_shuffle_3d(x, 2) + else: + return F.interpolate(x, scale_factor=2, mode="nearest") + + +class SparseStructureEncoder(nn.Module): + """ + Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3). + + Args: + in_channels (int): Channels of the input. + latent_channels (int): Channels of the latent representation. + num_res_blocks (int): Number of residual blocks at each resolution. + channels (List[int]): Channels of the encoder blocks. + num_res_blocks_middle (int): Number of residual blocks in the middle. + norm_type (Literal["group", "layer"]): Type of normalization layer. + use_fp16 (bool): Whether to use FP16. 
+ """ + def __init__( + self, + in_channels: int, + latent_channels: int, + num_res_blocks: int, + channels: List[int], + num_res_blocks_middle: int = 2, + norm_type: Literal["group", "layer"] = "layer", + use_fp16: bool = False, + use_checkpoint: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.latent_channels = latent_channels + self.num_res_blocks = num_res_blocks + self.channels = channels + self.num_res_blocks_middle = num_res_blocks_middle + self.norm_type = norm_type + self.use_fp16 = use_fp16 + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.use_checkpoint = use_checkpoint + + self.input_layer = nn.Conv3d(in_channels, channels[0], 3, padding=1) + + self.blocks = nn.ModuleList([]) + for i, ch in enumerate(channels): + self.blocks.extend([ + ResBlock3d(ch, ch) + for _ in range(num_res_blocks) + ]) + if i < len(channels) - 1: + self.blocks.append( + DownsampleBlock3d(ch, channels[i+1]) + ) + + self.middle_block = nn.Sequential(*[ + ResBlock3d(channels[-1], channels[-1]) + for _ in range(num_res_blocks_middle) + ]) + + self.out_layer = nn.Sequential( + norm_layer(norm_type, channels[-1]), + nn.SiLU(), + nn.Conv3d(channels[-1], latent_channels*2, 3, padding=1) + ) + + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_fp16 = True + self.dtype = torch.float16 + self.blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.use_fp16 = False + self.dtype = torch.float32 + self.blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.input_layer(x) + + for block in self.blocks: + h = block(h) + h = self.middle_block(h) + + h = self.out_layer(h) + + return h + + +class SparseStructureDecoder(nn.Module): + """ + Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3). + + Args: + out_channels (int): Channels of the output. + latent_channels (int): Channels of the latent representation. + num_res_blocks (int): Number of residual blocks at each resolution. + channels (List[int]): Channels of the decoder blocks. + num_res_blocks_middle (int): Number of residual blocks in the middle. + norm_type (Literal["group", "layer"]): Type of normalization layer. + use_fp16 (bool): Whether to use FP16. 
+ """ + def __init__( + self, + out_channels: int, + latent_channels: int, + num_res_blocks: int, + channels: List[int], + num_res_blocks_middle: int = 2, + norm_type: Literal["group", "layer"] = "layer", + use_fp16: bool = False, + use_checkpoint: bool = False, + ): + super().__init__() + self.out_channels = out_channels + self.latent_channels = latent_channels + self.num_res_blocks = num_res_blocks + self.channels = channels + self.num_res_blocks_middle = num_res_blocks_middle + self.norm_type = norm_type + self.use_fp16 = use_fp16 + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.use_checkpoint = use_checkpoint + + self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1) + + self.middle_block = nn.Sequential(*[ + ResBlock3d(channels[0], channels[0]) + for _ in range(num_res_blocks_middle) + ]) + + self.blocks = nn.ModuleList([]) + for i, ch in enumerate(channels): + self.blocks.extend([ + ResBlock3d(ch, ch) + for _ in range(num_res_blocks) + ]) + if i < len(channels) - 1: + self.blocks.append( + UpsampleBlock3d(ch, channels[i+1]) + ) + + self.out_layer = nn.Sequential( + norm_layer(norm_type, channels[-1]), + nn.SiLU(), + nn.Conv3d(channels[-1], out_channels, 3, padding=1) + ) + + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_fp16 = True + self.dtype = torch.float16 + # self.blocks.apply(convert_module_to_f16) + # self.middle_block.apply(convert_module_to_f16) + self.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. 
+ """ + self.use_fp16 = False + self.dtype = torch.float32 + self.blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.input_layer(x) + + h = self.middle_block(h) + for block in self.blocks: + h = block(h) + + h = self.out_layer(h) + return h + + +class DenseShapeVAE(nn.Module): + def __init__(self, + embed_dim: int = 0, + model_channels_encoder: list = [32, 128, 512], + model_channels_decoder: list = [512, 128, 32], + num_res_blocks_encoder: int = 2, + num_res_blocks_middle_encoder: int = 2, + num_res_blocks_decoder: int = 2, + num_res_blocks_middle_decoder: int=2, + in_channels: int = 1, + out_channels: int = 1, + use_fp16: bool = False, + use_checkpoint: bool = False, + latents_scale: float = 1.0, + latents_shift: float = 0.0): + + super().__init__() + + self.use_checkpoint = use_checkpoint + self.latents_scale = latents_scale + self.latents_shift = latents_shift + + self.encoder = SparseStructureEncoder( + in_channels=in_channels, + latent_channels=embed_dim, + num_res_blocks=num_res_blocks_encoder, + channels=model_channels_encoder, + num_res_blocks_middle=num_res_blocks_middle_encoder, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + ) + + self.decoder = SparseStructureDecoder( + num_res_blocks=num_res_blocks_decoder, + num_res_blocks_middle=num_res_blocks_middle_decoder, + channels=model_channels_decoder, + latent_channels=embed_dim, + out_channels=out_channels, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + ) + + self.embed_dim = embed_dim + + def encode(self, batch, sample_posterior: bool = True): + + x = batch['dense_index'] * 2.0 - 1.0 + h = self.encoder(x) + posterior = DiagonalGaussianDistribution(h, feat_dim=1) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + + return z, posterior + + def forward(self, batch): + + z, posterior = self.encode(batch) + reconst_x = self.decoder(z) + outputs = {'reconst_x': reconst_x, 'posterior': posterior} + + return outputs + + def decode_mesh(self, + latents, + voxel_resolution: int = 64, + mc_threshold: float = 0.5, + return_index: bool = False): + x = self.decoder(latents) + if return_index: + outputs = [] + for i in range(len(x)): + occ = x[i].sigmoid() + occ = (occ >= mc_threshold).float().squeeze(0) + index = occ.unsqueeze(0).nonzero() + outputs.append(index) + else: + outputs = self.dense2mesh(x, voxel_resolution=voxel_resolution, mc_threshold=mc_threshold) + + return outputs + + def dense2mesh(self, + x: torch.FloatTensor, + voxel_resolution: int = 64, + mc_threshold: float = 0.5): + + meshes = [] + for i in range(len(x)): + occ = x[i].sigmoid() + occ = (occ >= 0.1).float().squeeze(0).cpu().detach().numpy() + vertices, faces, _, _ = measure.marching_cubes( + occ, + mc_threshold, + method="lewiner", + ) + vertices = vertices / voxel_resolution * 2 - 1 + meshes.append(trimesh.Trimesh(vertices, faces)) + + return meshes diff --git a/direct3d_s2/models/autoencoders/distributions.py b/direct3d_s2/models/autoencoders/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..a702359a8d2dedcb28c2e654bb60221af7f72c8f --- /dev/null +++ b/direct3d_s2/models/autoencoders/distributions.py @@ -0,0 +1,51 @@ +import torch +import numpy as np +from typing import Union, List + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1): + self.feat_dim = feat_dim + self.parameters = 
parameters + + if isinstance(parameters, list): + self.mean = parameters[0] + self.logvar = parameters[1] + else: + self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim) + + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean) + + def sample(self): + x = self.mean + self.std * torch.randn_like(self.mean) + return x + + def kl(self, other=None, dims=(1, 2, 3)): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.mean(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=dims) + else: + return 0.5 * torch.mean( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=dims) + + def nll(self, sample, dims=(1, 2, 3)): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean diff --git a/direct3d_s2/models/autoencoders/encoder.py b/direct3d_s2/models/autoencoders/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..49fe567ff4dd1b62bae84c8627b27502ec4cdd09 --- /dev/null +++ b/direct3d_s2/models/autoencoders/encoder.py @@ -0,0 +1,133 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from ...modules import sparse as sp +from .base import SparseTransformerBase + + +class SparseDownBlock3d(nn.Module): + + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + num_groups: int = 32, + use_checkpoint: bool = False, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.act_layers = nn.Sequential( + sp.SparseGroupNorm32(num_groups, channels), + sp.SparseSiLU() + ) + + self.down = sp.SparseDownsample(2) + self.out_layers = nn.Sequential( + sp.SparseConv3d(channels, self.out_channels, 3, padding=1), + sp.SparseGroupNorm32(num_groups, self.out_channels), + sp.SparseSiLU(), + sp.SparseConv3d(self.out_channels, self.out_channels, 3, padding=1), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + else: + self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1) + + self.use_checkpoint = use_checkpoint + + def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + h = self.act_layers(x) + h = self.down(h) + x = self.down(x) + h = self.out_layers(h) + h = h + self.skip_connection(x) + return h + + def forward(self, x: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseSDFEncoder(SparseTransformerBase): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + latent_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", + window_size: int = 8, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__( + in_channels=in_channels, + model_channels=model_channels, + num_blocks=num_blocks, + num_heads=num_heads, + 
num_head_channels=num_head_channels, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + pe_mode=pe_mode, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + qk_rms_norm=qk_rms_norm, + ) + + self.input_layer1 = sp.SparseLinear(1, model_channels // 16) + + self.downsample = nn.ModuleList([ + SparseDownBlock3d( + channels=model_channels//16, + out_channels=model_channels // 8, + use_checkpoint=use_checkpoint, + ), + SparseDownBlock3d( + channels=model_channels // 8, + out_channels=model_channels // 4, + use_checkpoint=use_checkpoint, + ), + SparseDownBlock3d( + channels=model_channels // 4, + out_channels=model_channels, + use_checkpoint=use_checkpoint, + ) + ]) + + self.resolution = resolution + self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + def initialize_weights(self) -> None: + super().initialize_weights() + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward(self, x: sp.SparseTensor, factor: float = None): + + x = self.input_layer1(x) + for block in self.downsample: + x = block(x) + h = super().forward(x, factor) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h) + + return h \ No newline at end of file diff --git a/direct3d_s2/models/autoencoders/ss_vae.py b/direct3d_s2/models/autoencoders/ss_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..b32fd20767df4f90d7d5bc55b2b393676e8e47a3 --- /dev/null +++ b/direct3d_s2/models/autoencoders/ss_vae.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- +import torch +import torch.nn as nn +import torch.nn.functional as F + +import trimesh +from skimage import measure + +from ...modules import sparse as sp +from .encoder import SparseSDFEncoder +from .decoder import SparseSDFDecoder +from .distributions import DiagonalGaussianDistribution + + +class SparseSDFVAE(nn.Module): + def __init__(self, *, + embed_dim: int = 0, + resolution: int = 64, + model_channels_encoder: int = 512, + num_blocks_encoder: int = 4, + num_heads_encoder: int = 8, + num_head_channels_encoder: int = 64, + model_channels_decoder: int = 512, + num_blocks_decoder: int = 4, + num_heads_decoder: int = 8, + num_head_channels_decoder: int = 64, + out_channels: int = 1, + use_fp16: bool = False, + use_checkpoint: bool = False, + chunk_size: int = 1, + latents_scale: float = 1.0, + latents_shift: float = 0.0): + + super().__init__() + + self.use_checkpoint = use_checkpoint + self.resolution = resolution + self.latents_scale = latents_scale + self.latents_shift = latents_shift + + self.encoder = SparseSDFEncoder( + resolution=resolution, + in_channels=model_channels_encoder, + model_channels=model_channels_encoder, + latent_channels=embed_dim, + num_blocks=num_blocks_encoder, + num_heads=num_heads_encoder, + num_head_channels=num_head_channels_encoder, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + ) + + self.decoder = SparseSDFDecoder( + resolution=resolution, + model_channels=model_channels_decoder, + latent_channels=embed_dim, + num_blocks=num_blocks_decoder, + num_heads=num_heads_decoder, + num_head_channels=num_head_channels_decoder, + out_channels=out_channels, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + chunk_size=chunk_size, + ) + self.embed_dim = embed_dim + + def forward(self, batch): + + z, posterior = self.encode(batch) + + reconst_x = self.decoder(z) + outputs = {'reconst_x': 
reconst_x, 'posterior': posterior} + return outputs + + def encode(self, batch, sample_posterior: bool = True): + + feat, xyz, batch_idx = batch['sparse_sdf'], batch['sparse_index'], batch['batch_idx'] + if feat.ndim == 1: + feat = feat.unsqueeze(-1) + coords = torch.cat([batch_idx.unsqueeze(-1), xyz], dim=-1).int() + + x = sp.SparseTensor(feat, coords) + h = self.encoder(x, batch.get('factor', None)) + posterior = DiagonalGaussianDistribution(h.feats, feat_dim=1) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + z = h.replace(z) + + return z, posterior + + def decode_mesh(self, + latents, + voxel_resolution: int = 512, + mc_threshold: float = 0.2, + return_feat: bool = False, + factor: float = 1.0): + voxel_resolution = int(voxel_resolution / factor) + reconst_x = self.decoder(latents, factor=factor, return_feat=return_feat) + if return_feat: + return reconst_x + outputs = self.sparse2mesh(reconst_x, voxel_resolution=voxel_resolution, mc_threshold=mc_threshold) + + return outputs + + def sparse2mesh(self, + reconst_x: torch.FloatTensor, + voxel_resolution: int = 512, + mc_threshold: float = 0.0): + + sparse_sdf, sparse_index = reconst_x.feats.float(), reconst_x.coords + batch_size = int(sparse_index[..., 0].max().cpu().numpy() + 1) + + meshes = [] + for i in range(batch_size): + idx = sparse_index[..., 0] == i + sparse_sdf_i, sparse_index_i = sparse_sdf[idx].squeeze(-1).cpu(), sparse_index[idx][..., 1:].detach().cpu() + sdf = torch.ones((voxel_resolution, voxel_resolution, voxel_resolution)) + sdf[sparse_index_i[..., 0], sparse_index_i[..., 1], sparse_index_i[..., 2]] = sparse_sdf_i + vertices, faces, _, _ = measure.marching_cubes( + sdf.numpy(), + mc_threshold, + method="lewiner", + ) + vertices = vertices / voxel_resolution * 2 - 1 + meshes.append(trimesh.Trimesh(vertices, faces)) + + return meshes diff --git a/direct3d_s2/models/conditioner.py b/direct3d_s2/models/conditioner.py new file mode 100644 index 0000000000000000000000000000000000000000..55bcee7c5a737c8a5bfc54b34eb84b0da83b54b0 --- /dev/null +++ b/direct3d_s2/models/conditioner.py @@ -0,0 +1,40 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import transforms + + +class DinoEncoder(nn.Module): + + def __init__( + self, + model="facebookresearch/dinov2", + version="dinov2_vitl14_reg", + size=518, + ): + super().__init__() + + dino_model = torch.hub.load(model, version, pretrained=True) + dino_model = dino_model.eval() + self.encoder = dino_model + self.transform = transforms.Compose( + [ + transforms.Resize(size, transforms.InterpolationMode.BILINEAR, antialias=True), + transforms.CenterCrop(size), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + ), + ] + ) + + def forward(self, image, image_mask=None): + + z = self.encoder(self.transform(image), is_training=True)['x_prenorm'] + z = F.layer_norm(z, z.shape[-1:]) + + if image_mask is not None: + image_mask_patch = F.max_pool2d(image_mask, kernel_size=14, stride=14).squeeze(1) > 0 + return z, image_mask_patch + + return z diff --git a/direct3d_s2/models/refiner/__pycache__/unet3d.cpython-310.pyc b/direct3d_s2/models/refiner/__pycache__/unet3d.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e2fe6c6d663e72c4d7838040802d2f39100969b Binary files /dev/null and b/direct3d_s2/models/refiner/__pycache__/unet3d.cpython-310.pyc differ diff --git a/direct3d_s2/models/refiner/__pycache__/unet_refiner.cpython-310.pyc 
b/direct3d_s2/models/refiner/__pycache__/unet_refiner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7f2a773caa62efda28f97306284bb4536aa4bbb Binary files /dev/null and b/direct3d_s2/models/refiner/__pycache__/unet_refiner.cpython-310.pyc differ diff --git a/direct3d_s2/models/refiner/unet3d.py b/direct3d_s2/models/refiner/unet3d.py new file mode 100644 index 0000000000000000000000000000000000000000..42c7228576fec4fc1cf8f0688b44cef22ab16fb4 --- /dev/null +++ b/direct3d_s2/models/refiner/unet3d.py @@ -0,0 +1,640 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +ACTIVATION_FUNCTIONS = { + "swish": nn.SiLU(), + "silu": nn.SiLU(), + "mish": nn.Mish(), + "gelu": nn.GELU(), + "relu": nn.ReLU(), +} + +def get_activation(act_fn: str) -> nn.Module: + + act_fn = act_fn.lower() + if act_fn in ACTIVATION_FUNCTIONS: + return ACTIVATION_FUNCTIONS[act_fn] + else: + raise ValueError(f"Unsupported activation function: {act_fn}") + +def get_down_block( + down_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + add_downsample: bool, + resnet_eps: float, + resnet_act_fn: str, + resnet_groups: Optional[int] = None, + downsample_padding: Optional[int] = None, + dropout: float = 0.0, +) -> Union[ + "DownBlock3D", +]: + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + dropout=dropout, + ) + + raise ValueError(f"{down_block_type} does not exist.") + +def get_up_block( + up_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + prev_output_channel: int, + add_upsample: bool, + resnet_eps: float, + resnet_act_fn: str, + resnet_groups: Optional[int] = None, + dropout: float = 0.0, +) -> Union[ + "UpBlock3D", +]: + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + dropout=dropout, + ) + raise ValueError(f"{up_block_type} does not exist.") + +class Downsample3D(nn.Module): + + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + kernel_size=2, + bias=True, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + stride = 2 + + self.conv = nn.Conv3d( + self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias + ) + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + assert hidden_states.shape[1] == self.channels + + assert hidden_states.shape[1] == self.channels + + hidden_states = self.conv(hidden_states) + + return hidden_states + +class Upsample3D(nn.Module): + + def __init__( + self, + channels: int, + use_conv: bool = False, + use_conv_transpose: bool = True, + out_channels: Optional[int] = None, + name: str = "conv", + kernel_size: Optional[int] = None, + padding=1, + bias=True, + interpolate=False, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + self.interpolate = interpolate + + conv = None + 
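+        # Pick the upsampling operator: a stride-2 ConvTranspose3d when use_conv_transpose is set,
+        # a 3x3 Conv3d when use_conv is set (optionally paired with nearest interpolation in forward),
+        # otherwise conv stays None and forward() only interpolates when self.interpolate is True.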
if use_conv_transpose: + conv = nn.ConvTranspose3d( + channels, self.out_channels, kernel_size=2, stride=2, padding=0, bias=bias + ) + elif use_conv: + if kernel_size is None: + kernel_size = 3 + conv = nn.Conv3d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias) + + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None) -> torch.Tensor: + + assert hidden_states.shape[1] == self.channels + + + if self.use_conv_transpose: + return self.conv(hidden_states) + + if hidden_states.shape[0] >= 64 or hidden_states.shape[-1] >= 64: + hidden_states = hidden_states.contiguous() + + if self.interpolate: + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.Conv2d_0(hidden_states) + + return hidden_states + +class ResnetBlock3D(nn.Module): + + def __init__( + self, + *, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + groups: int = 32, + groups_out: Optional[int] = None, + eps: float = 1e-6, + non_linearity: str = "swish", + output_scale_factor: float = 1.0, + use_in_shortcut: Optional[bool] = None, + up: bool = False, + down: bool = False, + conv_shortcut_bias: bool = True, + conv_2d_out_channels: Optional[int] = None, + ): + super().__init__() + + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.up = up + self.down = down + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + self.dropout = torch.nn.Dropout(dropout) + conv_2d_out_channels = conv_2d_out_channels or out_channels + self.conv2 = nn.Conv3d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) + + self.nonlinearity = get_activation(non_linearity) + + self.upsample = self.downsample = None + if self.up: + self.upsample = Upsample3D(in_channels, use_conv=False) + elif self.down: + self.downsample = Downsample3D(in_channels) + + self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = nn.Conv3d( + in_channels, + conv_2d_out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=conv_shortcut_bias, + ) + + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: + + hidden_states = input_tensor + dtype = hidden_states.dtype + hidden_states = self.norm1(hidden_states.float()).to(dtype) + hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + if hidden_states.shape[0] >= 64: + input_tensor = input_tensor.contiguous() + hidden_states = hidden_states.contiguous() + input_tensor = self.upsample(input_tensor) + hidden_states = self.upsample(hidden_states) + elif self.downsample is not None: + input_tensor = 
self.downsample(input_tensor) + hidden_states = self.downsample(hidden_states) + + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states.float()).to(dtype) + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, + out_channels=out_channels + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: + output_states = () + + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + upsample_size: Optional[int] = None, + ) -> torch.Tensor: + for resnet in self.resnets: + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class 
UNetMidBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + dropout: float = 0.0, + num_layers: int = 2, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + output_scale_factor: float = 1.0, + use_linear_projection: bool = True, + ): + super().__init__() + + self.has_cross_attention = True + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + resnets = [ + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ] + + for _ in range(num_layers): + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + return hidden_states + +def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + +class UNet3DModel(nn.Module): + + def __init__( + self, + in_channels: int = 4, + out_channels: int = 4, + use_conv_out: bool=True, + down_block_types: Tuple[str, ...] = ( + "DownBlock3D", + "DownBlock3D", + "DownBlock3D", + "DownBlock3D", + ), + up_block_types: Tuple[str, ...] = ( + "UpBlock3D", + "UpBlock3D", + "UpBlock3D", + "UpBlock3D", + ), + block_out_channels: Tuple[int, ...] = (4, 16, 64,256), + layers_per_block: int = 4, + layers_mid_block: int=4, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 4, + norm_eps: float = 1e-5, + use_checkpoint: bool = True, + ): + super().__init__() + + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 
+ ) + + conv_in_kernel = 3 + conv_out_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv3d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + downsample_padding=downsample_padding, + ) + self.down_blocks.append(down_block) + + self.mid_block = UNetMidBlock3D( + in_channels=block_out_channels[-1], + num_layers=layers_mid_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_groups=norm_num_groups, + ) + + self.num_upsamplers = 0 + + reversed_block_out_channels = list(reversed(block_out_channels)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = get_activation("silu") + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + if use_conv_out: + self.conv_out = nn.Conv3d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + else: + self.conv_out = None + + self.use_checkpoint = use_checkpoint + + def forward( + self, + sample: torch.Tensor, + ) : + + default_overall_up_factor = 2**self.num_upsamplers + + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + forward_upsample_size = True + + sample = self.conv_in(sample) + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if self.use_checkpoint: + sample, res_samples = torch.utils.checkpoint.checkpoint(downsample_block, sample, use_reentrant=False) + else: + sample, res_samples = downsample_block(hidden_states=sample) + + down_block_res_samples += res_samples + + if self.mid_block is not None: + if self.use_checkpoint: + sample = torch.utils.checkpoint.checkpoint(self.mid_block, sample, use_reentrant=False) + else: + sample = self.mid_block(sample) + + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = 
down_block_res_samples[: -len(upsample_block.resnets)] + + + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + if self.use_checkpoint: + sample = torch.utils.checkpoint.checkpoint(upsample_block, (sample, res_samples, upsample_size), use_reentrant=False) + else: + sample = upsample_block( + hidden_states=sample, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + ) + + if self.conv_norm_out: + dtype = sample.dtype + sample = self.conv_norm_out(sample.float()).to(dtype) + sample = self.conv_act(sample) + if self.conv_out!=None: + sample = self.conv_out(sample) + + return F.tanh(sample)*2 + else: + return sample diff --git a/direct3d_s2/models/refiner/unet_refiner.py b/direct3d_s2/models/refiner/unet_refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..056ddfc7d4a987e02160a73060f60d27893a1cd3 --- /dev/null +++ b/direct3d_s2/models/refiner/unet_refiner.py @@ -0,0 +1,203 @@ +# -*- coding: utf-8 -*- +import itertools +import torch +import torch.nn as nn +import torch.nn.functional as F +from .unet3d import UNet3DModel +import trimesh +from tqdm import tqdm +from skimage import measure +from ...modules.utils import convert_module_to_f16, convert_module_to_f32 + + +def adaptive_conv(inputs,weights): + padding = (1, 1, 1, 1, 1, 1) + padded_input = F.pad(inputs, padding, mode="constant", value=0) + output = torch.zeros_like(inputs) + size=inputs.shape[-1] + for i in range(3): + for j in range(3): + for k in range(3): + output=output+padded_input[:,:,i:i+size,j:j+size,k:k+size]*weights[:,i*9+j*3+k:i*9+j*3+k+1] + return output + +def adaptive_block(inputs,conv,weights_=None): + if weights_ != None: + weights = conv(weights_) + else: + weights = conv(inputs) + weights = F.normalize(weights, dim=1, p=1) + for i in range(3): + inputs = adaptive_conv(inputs, weights) + return inputs + +class GeoDecoder(nn.Module): + + def __init__(self, + n_features: int, + hidden_dim: int = 32, + num_layers: int = 4, + use_sdf: bool = False, + activation: nn.Module = nn.ReLU): + super().__init__() + self.use_sdf=use_sdf + self.net = nn.Sequential( + nn.Linear(n_features, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 8), + ) + + # init all bias to zero + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.zeros_(m.bias) + + def forward(self, x): + x = self.net(x) + return x + + +class Voxel_RefinerXL(nn.Module): + def __init__(self, + in_channels: int = 1, + out_channels: int = 1, + layers_per_block: int = 2, + layers_mid_block: int = 2, + patch_size: int = 192, + res: int = 512, + use_checkpoint: bool=False, + use_fp16: bool = False): + + super().__init__() + + self.unet3d1 = UNet3DModel(in_channels=16, out_channels=8, use_conv_out=False, + layers_per_block=layers_per_block, layers_mid_block=layers_mid_block, + block_out_channels=(8, 32, 128,512), norm_num_groups=4, use_checkpoint=use_checkpoint) + self.conv_in = nn.Conv3d(in_channels, 8, kernel_size=3, padding=1) + self.latent_mlp = GeoDecoder(32) + self.adaptive_conv1 = nn.Sequential(nn.Conv3d(8, 8, kernel_size=3, padding=1), + nn.ReLU(), + nn.Conv3d(8, 27, kernel_size=3, padding=1, bias=False)) + self.adaptive_conv2 = nn.Sequential(nn.Conv3d(8, 8, kernel_size=3, padding=1), + nn.ReLU(), + nn.Conv3d(8, 27, kernel_size=3, padding=1, bias=False)) + self.adaptive_conv3 = nn.Sequential(nn.Conv3d(8, 8, 
kernel_size=3, padding=1), + nn.ReLU(), + nn.Conv3d(8, 27, kernel_size=3, padding=1, bias=False)) + self.mid_conv = nn.Conv3d(8, 8, kernel_size=3, padding=1) + self.conv_out = nn.Conv3d(8, out_channels, kernel_size=3, padding=1) + self.patch_size = patch_size + self.res = res + + self.use_fp16 = use_fp16 + self.dtype = torch.float16 if use_fp16 else torch.float32 + if use_fp16: + self.convert_to_fp16() + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + # self.blocks.apply(convert_module_to_f16) + self.apply(convert_module_to_f16) + + def run(self, + reconst_x, + feat, + mc_threshold=0, + ): + batch_size = int(reconst_x.coords[..., 0].max()) + 1 + sparse_sdf, sparse_index = reconst_x.feats, reconst_x.coords + sparse_feat = feat.feats + device = sparse_sdf.device + dtype = sparse_sdf.dtype + res = self.res + + sdfs = [] + for i in range(batch_size): + idx = sparse_index[..., 0] == i + sparse_sdf_i, sparse_index_i = sparse_sdf[idx].squeeze(-1), sparse_index[idx][..., 1:] + sdf = torch.ones((res, res, res)).to(device).to(dtype) + sdf[sparse_index_i[..., 0], sparse_index_i[..., 1], sparse_index_i[..., 2]] = sparse_sdf_i + sdfs.append(sdf.unsqueeze(0)) + + sdfs = torch.stack(sdfs, dim=0) + feats = torch.zeros((batch_size, sparse_feat.shape[-1], res, res, res), + device=device, dtype=dtype) + feats[sparse_index[...,0],:,sparse_index[...,1],sparse_index[...,2],sparse_index[...,3]] = sparse_feat + + N = sdfs.shape[0] + outputs = torch.ones([N,1,res,res,res], dtype=dtype, device=device) + stride = 160 + patch_size = self.patch_size + step = 3 + sdfs = sdfs.to(dtype) + feats = feats.to(dtype) + patchs=[] + for i in range(step): + for j in range(step): + for k in tqdm(range(step)): + sdf = sdfs[:, :, stride * i: stride * i + patch_size, + stride * j: stride * j + patch_size, + stride * k: stride * k + patch_size] + crop_feats = feats[:, :, stride * i: stride * i + patch_size, + stride * j: stride * j + patch_size, + stride * k: stride * k + patch_size] + inputs = self.conv_in(sdf) + crop_feats = self.latent_mlp(crop_feats.permute(0,2,3,4,1)).permute(0,4,1,2,3) + inputs = torch.cat([inputs, crop_feats],dim=1) + mid_feat = self.unet3d1(inputs) + mid_feat = adaptive_block(mid_feat, self.adaptive_conv1) + mid_feat = self.mid_conv(mid_feat) + mid_feat = adaptive_block(mid_feat, self.adaptive_conv2) + final_feat = self.conv_out(mid_feat) + final_feat = adaptive_block(final_feat, self.adaptive_conv3, weights_=mid_feat) + output = F.tanh(final_feat) + patchs.append(output) + weights = torch.linspace(0, 1, steps=32, device=device, dtype=dtype) + lines=[] + for i in range(9): + out1 = patchs[i * 3] + out2 = patchs[i * 3 + 1] + out3 = patchs[i * 3 + 2] + line = torch.ones([N, 1, 192, 192,res], dtype=dtype, device=device) * 2 + line[:, :, :, :, :160] = out1[:, :, :, :, :160] + line[:, :, :, :, 192:320] = out2[:, :, :, :, 32:160] + line[:, :, :, :, 352:] = out3[:, :, :, :, 32:] + + line[:,:,:,:,160:192] = out1[:,:,:,:,160:] * (1-weights.reshape(1,1,1,1,-1)) + out2[:,:,:,:,:32] * weights.reshape(1,1,1,1,-1) + line[:,:,:,:,320:352] = out2[:,:,:,:,160:] * (1-weights.reshape(1,1,1,1,-1)) + out3[:,:,:,:,:32] * weights.reshape(1,1,1,1,-1) + lines.append(line) + layers=[] + for i in range(3): + line1 = lines[i*3] + line2 = lines[i*3+1] + line3 = lines[i*3+2] + layer = torch.ones([N,1,192,res,res], device=device, dtype=dtype) * 2 + layer[:,:,:,:160] = line1[:,:,:,:160] + layer[:,:,:,192:320] = line2[:,:,:,32:160] + layer[:,:,:,352:] = line3[:,:,:,32:] + layer[:,:,:,160:192] = 
line1[:,:,:,160:]*(1-weights.reshape(1,1,1,-1,1))+line2[:,:,:,:32]*weights.reshape(1,1,1,-1,1) + layer[:,:,:,320:352] = line2[:,:,:,160:]*(1-weights.reshape(1,1,1,-1,1))+line3[:,:,:,:32]*weights.reshape(1,1,1,-1,1) + layers.append(layer) + outputs[:,:,:160] = layers[0][:,:,:160] + outputs[:,:,192:320] = layers[1][:,:,32:160] + outputs[:,:,352:] = layers[2][:,:,32:] + outputs[:,:,160:192] = layers[0][:,:,160:]*(1-weights.reshape(1,1,-1,1,1))+layers[1][:,:,:32]*weights.reshape(1,1,-1,1,1) + outputs[:,:,320:352] = layers[1][:,:,160:]*(1-weights.reshape(1,1,-1,1,1))+layers[2][:,:,:32]*weights.reshape(1,1,-1,1,1) + # outputs = -outputs + + meshes = [] + for i in range(outputs.shape[0]): + vertices, faces, _, _ = measure.marching_cubes(outputs[i, 0].cpu().numpy(), level=mc_threshold, method='lewiner') + vertices = vertices / res * 2 - 1 + meshes.append(trimesh.Trimesh(vertices, faces)) + + return meshes + + diff --git a/direct3d_s2/models/transformers/__pycache__/dense_dit.cpython-310.pyc b/direct3d_s2/models/transformers/__pycache__/dense_dit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0299464b4e9143f416b13a762160da2a37543b00 Binary files /dev/null and b/direct3d_s2/models/transformers/__pycache__/dense_dit.cpython-310.pyc differ diff --git a/direct3d_s2/models/transformers/__pycache__/sparse_dit.cpython-310.pyc b/direct3d_s2/models/transformers/__pycache__/sparse_dit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d2eb4b0aab5a6dc2b6e5ba116fe4f0a5810d771 Binary files /dev/null and b/direct3d_s2/models/transformers/__pycache__/sparse_dit.cpython-310.pyc differ diff --git a/direct3d_s2/models/transformers/dense_dit.py b/direct3d_s2/models/transformers/dense_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..f19ef8511dc0d3fe5c561cffe01376e9eed713fc --- /dev/null +++ b/direct3d_s2/models/transformers/dense_dit.py @@ -0,0 +1,203 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ...modules.utils import convert_module_to_f16, convert_module_to_f32 +from ...modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock +from ...modules.spatial import patchify, unpatchify + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + Args: + t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + dim: the dimension of the output. + max_period: controls the minimum frequency of the embeddings. + + Returns: + an (N, D) Tensor of positional embeddings. 
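+        Example (illustrative): timestep_embedding(torch.tensor([0., 5.]), 256) -> shape (2, 256).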
+ """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_freq = t_freq.to(self.mlp[0].weight.dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class DenseDiT(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + patch_size: int = 2, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + share_mod: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + latent_shape: list = [8, 16, 16, 16], + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.patch_size = patch_size + self.pe_mode = pe_mode + self.use_fp16 = use_fp16 + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.latent_shape = latent_shape + + self.t_embedder = TimestepEmbedder(model_channels) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + if pe_mode == "ape": + pos_embedder = AbsolutePositionEmbedder(model_channels, 3) + coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution // patch_size] * 3], indexing='ij') + coords = torch.stack(coords, dim=-1).reshape(-1, 3) + pos_emb = pos_embedder(coords) + self.register_buffer("pos_emb", pos_emb) + + self.input_layer = nn.Linear(in_channels * patch_size**3, model_channels) + + self.blocks = nn.ModuleList([ + ModulatedTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode='full', + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + share_mod=share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + ) + for _ in range(num_blocks) + ]) + + self.out_layer = nn.Linear(model_channels, out_channels * patch_size**3) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + # self.blocks.apply(convert_module_to_f16) + self.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. 
+ """ + self.blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: + assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \ + f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}" + + h = patchify(x, self.patch_size) + h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous() + h = self.input_layer(h) + h = h + self.pos_emb[None] + t_emb = self.t_embedder(t) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = t_emb.type(self.dtype) + h = h.type(self.dtype) + cond = cond.type(self.dtype) + for block in self.blocks: + h = block(h, t_emb, cond) + h = h.type(x.dtype) + h = F.layer_norm(h, h.shape[-1:]) + h = self.out_layer(h) + + h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3) + h = unpatchify(h, self.patch_size).contiguous() + + return h \ No newline at end of file diff --git a/direct3d_s2/models/transformers/sparse_dit.py b/direct3d_s2/models/transformers/sparse_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..0bb6c2e32b54214b621581a268273c0960a9c940 --- /dev/null +++ b/direct3d_s2/models/transformers/sparse_dit.py @@ -0,0 +1,171 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 +from ...modules.transformer import AbsolutePositionEmbedder +from ...modules import sparse as sp +from ...modules.sparse.transformer.modulated import ModulatedSparseTransformerCrossBlock +from .dense_dit import TimestepEmbedder + + +class SparseDiT(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + num_kv_heads: Optional[int] = 2, + compression_block_size: int = 4, + selection_block_size: int = 8, + topk: int = 8, + compression_version: str = 'v2', + mlp_ratio: float = 4, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + share_mod: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + sparse_conditions: bool = False, + factor: float = 1.0, + window_size: Optional[int] = 8, + use_shift: bool = True, + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + 
self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.pe_mode = pe_mode + self.use_fp16 = use_fp16 + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.sparse_conditions = sparse_conditions + self.factor = factor + self.compression_block_size = compression_block_size + self.selection_block_size = selection_block_size + + self.t_embedder = TimestepEmbedder(model_channels) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + if sparse_conditions: + self.cond_proj = sp.SparseLinear(cond_channels, cond_channels) + self.pos_embedder_cond = AbsolutePositionEmbedder(model_channels, in_channels=3) + + if pe_mode == "ape": + self.pos_embedder = AbsolutePositionEmbedder(model_channels) + + self.input_layer = sp.SparseLinear(in_channels, model_channels) + + self.blocks = nn.ModuleList([ + ModulatedSparseTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + num_kv_heads=num_kv_heads, + compression_block_size=compression_block_size, + selection_block_size=selection_block_size, + topk=topk, + mlp_ratio=self.mlp_ratio, + attn_mode='full', + compression_version=compression_version, + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + share_mod=self.share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + resolution=resolution, + window_size=window_size, + shift_window=window_size // 2 * (_ % 2) if use_shift else window_size // 2, + ) + for _ in range(num_blocks) + ]) + + self.out_layer = sp.SparseLinear(model_channels, out_channels) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + # self.blocks.apply(convert_module_to_f16) + self.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. 
+ """ + self.blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward(self, x: sp.SparseTensor, t: torch.Tensor, cond: Union[torch.Tensor, sp.SparseTensor]) -> sp.SparseTensor: + h = self.input_layer(x).type(self.dtype) + t_emb = self.t_embedder(t) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = t_emb.type(self.dtype) + cond = cond.type(self.dtype) + + if self.sparse_conditions: + cond = self.cond_proj(cond) + cond = cond + self.pos_embedder_cond(cond.coords[:, 1:]).type(self.dtype) + if self.pe_mode == "ape": + h = h + self.pos_embedder(h.coords[:, 1:], factor=self.factor).type(self.dtype) + for block in self.blocks: + h = block(h, t_emb, cond) + + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h.type(x.dtype)) + return h diff --git a/direct3d_s2/modules/__pycache__/norm.cpython-310.pyc b/direct3d_s2/modules/__pycache__/norm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7499560d85c1c87edbe8fd1ebe8a2c81dfbb17f7 Binary files /dev/null and b/direct3d_s2/modules/__pycache__/norm.cpython-310.pyc differ diff --git a/direct3d_s2/modules/__pycache__/spatial.cpython-310.pyc b/direct3d_s2/modules/__pycache__/spatial.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49064a2098551353971c619198fedf94458ad625 Binary files /dev/null and b/direct3d_s2/modules/__pycache__/spatial.cpython-310.pyc differ diff --git a/direct3d_s2/modules/__pycache__/utils.cpython-310.pyc b/direct3d_s2/modules/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2961e1cbf78cb5c2511fd69340e6be1928ef3cd6 Binary files /dev/null and b/direct3d_s2/modules/__pycache__/utils.cpython-310.pyc differ diff --git a/direct3d_s2/modules/attention/__init__.py b/direct3d_s2/modules/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d500ecd8ce72fd4e072ecdd9c008d2ae030e0629 --- /dev/null +++ b/direct3d_s2/modules/attention/__init__.py @@ -0,0 +1,35 @@ +from typing import * +BACKEND = 'flash_attn' +DEBUG = False + +def __from_env(): + import os + + global BACKEND + global DEBUG + + env_attn_backend = os.environ.get('ATTN_BACKEND') + env_sttn_debug = os.environ.get('ATTN_DEBUG') + + if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']: + BACKEND = env_attn_backend + if env_sttn_debug is not None: + DEBUG = env_sttn_debug == '1' + + print(f"[ATTENTION] Using backend: {BACKEND}") + + +__from_env() + + +def set_backend(backend: Literal['xformers', 'flash_attn']): + global BACKEND + BACKEND = backend + +def 
set_debug(debug: bool): + global DEBUG + DEBUG = debug + + +from .full_attn import * +from .modules import * diff --git a/direct3d_s2/modules/attention/__pycache__/__init__.cpython-310.pyc b/direct3d_s2/modules/attention/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b78deb562dc65c50d91cbd6d09e74f5a4f390dac Binary files /dev/null and b/direct3d_s2/modules/attention/__pycache__/__init__.cpython-310.pyc differ diff --git a/direct3d_s2/modules/attention/__pycache__/full_attn.cpython-310.pyc b/direct3d_s2/modules/attention/__pycache__/full_attn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a0450d41a7c226b03c5e7c9c476bea7c2b47b28 Binary files /dev/null and b/direct3d_s2/modules/attention/__pycache__/full_attn.cpython-310.pyc differ diff --git a/direct3d_s2/modules/attention/__pycache__/modules.cpython-310.pyc b/direct3d_s2/modules/attention/__pycache__/modules.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82529b982c3e709e0822fe0f2b592edae451777d Binary files /dev/null and b/direct3d_s2/modules/attention/__pycache__/modules.cpython-310.pyc differ diff --git a/direct3d_s2/modules/attention/full_attn.py b/direct3d_s2/modules/attention/full_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..f94cf46843412c4d349d2d5dcd7277fac938e507 --- /dev/null +++ b/direct3d_s2/modules/attention/full_attn.py @@ -0,0 +1,140 @@ +from typing import * +import torch +import math +from . import DEBUG, BACKEND + +if BACKEND == 'xformers': + import xformers.ops as xops +elif BACKEND == 'flash_attn': + import flash_attn +elif BACKEND == 'sdpa': + from torch.nn.functional import scaled_dot_product_attention as sdpa +elif BACKEND == 'naive': + pass +else: + raise ValueError(f"Unknown attention backend: {BACKEND}") + + +__all__ = [ + 'scaled_dot_product_attention', +] + + +def _naive_sdpa(q, k, v): + """ + Naive implementation of scaled dot product attention. + """ + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + scale_factor = 1 / math.sqrt(q.size(-1)) + attn_weight = q @ k.transpose(-2, -1) * scale_factor + attn_weight = torch.softmax(attn_weight, dim=-1) + out = attn_weight @ v + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + return out + + +@overload +def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs. + """ + ... + +@overload +def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + q (torch.Tensor): A [N, L, H, C] tensor containing Qs. + kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs. + """ + ... + +@overload +def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs. + k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks. + v (torch.Tensor): A [N, L, H, Co] tensor containing Vs. + + Note: + k and v are assumed to have the same coordinate map. + """ + ... 
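[Editor's note — a minimal usage sketch of the dispatcher implemented below, covering the three documented calling conventions (packed qkv, q + packed kv, and separate q/k/v). It is a sketch only: it assumes the direct3d_s2 package imports cleanly in your environment, and it forces the `sdpa` backend through the `ATTN_BACKEND` environment variable read by this module so that no third-party attention library is needed; shapes are illustrative.]

import os
os.environ['ATTN_BACKEND'] = 'sdpa'   # must be set before the module is first imported
import torch
from direct3d_s2.modules.attention import scaled_dot_product_attention

N, L, H, C = 2, 16, 4, 32
qkv = torch.randn(N, L, 3, H, C)                # packed Q, K, V
out = scaled_dot_product_attention(qkv)         # -> [N, L, H, C]

q = torch.randn(N, L, H, C)
kv = torch.randn(N, L, 2, H, C)                 # packed K, V (e.g. cross-attention context)
out = scaled_dot_product_attention(q, kv)       # -> [N, L, H, C]

k, v = kv.unbind(dim=2)
out = scaled_dot_product_attention(q, k, v)     # -> [N, L, H, C]

[End of editor's note; the implementation of the dispatcher follows in the patch.]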
+ +def scaled_dot_product_attention(*args, **kwargs): + arg_names_dict = { + 1: ['qkv'], + 2: ['q', 'kv'], + 3: ['q', 'k', 'v'] + } + num_all_args = len(args) + len(kwargs) + assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" + for key in arg_names_dict[num_all_args][len(args):]: + assert key in kwargs, f"Missing argument {key}" + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs['qkv'] + assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]" + device = qkv.device + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs['q'] + kv = args[1] if len(args) > 1 else kwargs['kv'] + assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" + assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" + device = q.device + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" + assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" + assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" + device = q.device + + if BACKEND == 'xformers': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = xops.memory_efficient_attention(q, k, v) + elif BACKEND == 'flash_attn': + if num_all_args == 1: + out = flash_attn.flash_attn_qkvpacked_func(qkv) + elif num_all_args == 2: + out = flash_attn.flash_attn_kvpacked_func(q, kv) + elif num_all_args == 3: + out = flash_attn.flash_attn_func(q, k, v) + elif BACKEND == 'sdpa': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + out = sdpa(q, k, v) # [N, H, L, C] + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + elif BACKEND == 'naive': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = _naive_sdpa(q, k, v) + else: + raise ValueError(f"Unknown attention module: {BACKEND}") + + return out diff --git a/direct3d_s2/modules/attention/modules.py b/direct3d_s2/modules/attention/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..dbe6235c27134f0477e48d3e12de3068c6a500ef --- /dev/null +++ b/direct3d_s2/modules/attention/modules.py @@ -0,0 +1,146 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from .full_attn import scaled_dot_product_attention + + +class MultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype) + + +class RotaryPositionEmbedder(nn.Module): + def __init__(self, hidden_size: int, in_channels: int = 3): + super().__init__() + assert hidden_size % 2 
== 0, "Hidden size must be divisible by 2" + self.hidden_size = hidden_size + self.in_channels = in_channels + self.freq_dim = hidden_size // in_channels // 2 + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = 1.0 / (10000 ** self.freqs) + + def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: + self.freqs = self.freqs.to(indices.device) + phases = torch.outer(indices, self.freqs) + phases = torch.polar(torch.ones_like(phases), phases) + return phases + + def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: + x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_rotated = x_complex * phases + x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) + return x_embed + + def forward(self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + q (sp.SparseTensor): [..., N, D] tensor of queries + k (sp.SparseTensor): [..., N, D] tensor of keys + indices (torch.Tensor): [..., N, C] tensor of spatial positions + """ + if indices is None: + indices = torch.arange(q.shape[-2], device=q.device) + if len(q.shape) > 2: + indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,)) + + phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) + if phases.shape[1] < self.hidden_size // 2: + phases = torch.cat([phases, torch.polar( + torch.ones(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device), + torch.zeros(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device) + )], dim=-1) + q_embed = self._rotary_embedding(q, phases) + k_embed = self._rotary_embedding(k, phases) + return q_embed, k_embed + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int]=None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + qkv_bias: bool = True, + use_rope: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__() + assert channels % num_heads == 0 + assert type in ["self", "cross"], f"Invalid attention type: {type}" + assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}" + assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" + + if attn_mode == "windowed": + raise NotImplementedError("Windowed attention is not yet implemented") + + self.channels = channels + self.head_dim = channels // num_heads + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_window = shift_window + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + else: + self.to_q = nn.Linear(channels, channels, bias=qkv_bias) + self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + + if self.qk_rms_norm: + self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + + self.to_out = nn.Linear(channels, channels) + + if use_rope: + self.rope = RotaryPositionEmbedder(channels) + + def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, indices: 
Optional[torch.Tensor] = None) -> torch.Tensor: + B, L, C = x.shape + if self._type == "self": + qkv = self.to_qkv(x) + qkv = qkv.reshape(B, L, 3, self.num_heads, -1) + if self.use_rope: + q, k, v = qkv.unbind(dim=2) + q, k = self.rope(q, k, indices) + qkv = torch.stack([q, k, v], dim=2) + if self.attn_mode == "full": + if self.qk_rms_norm: + q, k, v = qkv.unbind(dim=2) + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(qkv) + elif self.attn_mode == "windowed": + raise NotImplementedError("Windowed attention is not yet implemented") + else: + Lkv = context.shape[1] + q = self.to_q(x) + kv = self.to_kv(context) + q = q.reshape(B, L, self.num_heads, -1) + kv = kv.reshape(B, Lkv, 2, self.num_heads, -1) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=2) + k = self.k_rms_norm(k) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(q, kv) + h = h.reshape(B, L, -1) + h = self.to_out(h) + return h diff --git a/direct3d_s2/modules/norm.py b/direct3d_s2/modules/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..09035726081fb7afda2c62504d5474cfa483c58f --- /dev/null +++ b/direct3d_s2/modules/norm.py @@ -0,0 +1,25 @@ +import torch +import torch.nn as nn + + +class LayerNorm32(nn.LayerNorm): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return super().forward(x.float()).type(x.dtype) + + +class GroupNorm32(nn.GroupNorm): + """ + A GroupNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: torch.Tensor) -> torch.Tensor: + return super().forward(x.float()).type(x.dtype) + + +class ChannelLayerNorm32(LayerNorm32): + def forward(self, x: torch.Tensor) -> torch.Tensor: + DIM = x.dim() + x = x.permute(0, *range(2, DIM), 1).contiguous() + x = super().forward(x) + x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous() + return x + \ No newline at end of file diff --git a/direct3d_s2/modules/sparse/.ipynb_checkpoints/basic-checkpoint.py b/direct3d_s2/modules/sparse/.ipynb_checkpoints/basic-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..8837f44052f6d573d09e3bfb897e659e10516bb5 --- /dev/null +++ b/direct3d_s2/modules/sparse/.ipynb_checkpoints/basic-checkpoint.py @@ -0,0 +1,459 @@ +from typing import * +import torch +import torch.nn as nn +from . import BACKEND, DEBUG +SparseTensorData = None # Lazy import + + +__all__ = [ + 'SparseTensor', + 'sparse_batch_broadcast', + 'sparse_batch_op', + 'sparse_cat', + 'sparse_unbind', +] + + +class SparseTensor: + """ + Sparse tensor with support for both torchsparse and spconv backends. + + Parameters: + - feats (torch.Tensor): Features of the sparse tensor. + - coords (torch.Tensor): Coordinates of the sparse tensor. + - shape (torch.Size): Shape of the sparse tensor. + - layout (List[slice]): Layout of the sparse tensor for each batch + - data (SparseTensorData): Sparse tensor data used for convolusion + + NOTE: + - Data corresponding to a same batch should be contiguous. + - Coords should be in [0, 1023] + """ + @overload + def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ... + + @overload + def __init__(self, data, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ... 
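[Editor's note — a minimal construction sketch for this SparseTensor class, for orientation only. It assumes either the torchsparse or spconv backend is installed (select it via the SPARSE_BACKEND environment variable used by the sparse package); coordinate and feature values are illustrative. Coordinates are [batch, x, y, z] rows, grouped contiguously per batch, and `shape`/`layout` are derived from them.]

import os
os.environ.setdefault('SPARSE_BACKEND', 'spconv')   # or 'torchsparse'; must match an installed backend
import torch
import direct3d_s2.modules.sparse as sp

coords = torch.tensor([[0, 0, 0, 0],
                       [0, 1, 2, 3],
                       [1, 4, 5, 6]], dtype=torch.int32)   # batch 0: two voxels, batch 1: one voxel
feats = torch.randn(3, 8)                                  # one 8-channel feature per voxel

x = sp.SparseTensor(feats=feats, coords=coords)
print(x.shape)                       # torch.Size([2, 8])  -> (batch_size, channels)
print(x.layout)                      # [slice(0, 2, None), slice(2, 3, None)] -> per-batch ranges
print(x.feats[x.layout[0]].shape)    # torch.Size([2, 8])  -> features belonging to batch 0

[End of editor's note; the class implementation continues in the patch.]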
+ + def __init__(self, *args, **kwargs): + # Lazy import of sparse tensor backend + global SparseTensorData + if SparseTensorData is None: + import importlib + if BACKEND == 'torchsparse': + SparseTensorData = importlib.import_module('torchsparse').SparseTensor + elif BACKEND == 'spconv': + SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor + + method_id = 0 + if len(args) != 0: + method_id = 0 if isinstance(args[0], torch.Tensor) else 1 + else: + method_id = 1 if 'data' in kwargs else 0 + + if method_id == 0: + feats, coords, shape, layout = args + (None,) * (4 - len(args)) + if 'feats' in kwargs: + feats = kwargs['feats'] + del kwargs['feats'] + if 'coords' in kwargs: + coords = kwargs['coords'] + del kwargs['coords'] + if 'shape' in kwargs: + shape = kwargs['shape'] + del kwargs['shape'] + if 'layout' in kwargs: + layout = kwargs['layout'] + del kwargs['layout'] + + if shape is None: + shape = self.__cal_shape(feats, coords) + if layout is None: + layout = self.__cal_layout(coords, shape[0]) + if BACKEND == 'torchsparse': + self.data = SparseTensorData(feats, coords, **kwargs) + elif BACKEND == 'spconv': + spatial_shape = list(coords.max(0)[0] + 1)[1:] + self.data = SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape, shape[0], **kwargs) + self.data._features = feats + elif method_id == 1: + data, shape, layout = args + (None,) * (3 - len(args)) + if 'data' in kwargs: + data = kwargs['data'] + del kwargs['data'] + if 'shape' in kwargs: + shape = kwargs['shape'] + del kwargs['shape'] + if 'layout' in kwargs: + layout = kwargs['layout'] + del kwargs['layout'] + + self.data = data + if shape is None: + shape = self.__cal_shape(self.feats, self.coords) + if layout is None: + layout = self.__cal_layout(self.coords, shape[0]) + + self._shape = shape + self._layout = layout + self._scale = kwargs.get('scale', (1, 1, 1)) + self._spatial_cache = kwargs.get('spatial_cache', {}) + + if DEBUG: + try: + assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}" + assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}" + assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}" + for i in range(self.shape[0]): + assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous" + except Exception as e: + print('Debugging information:') + print(f"- Shape: {self.shape}") + print(f"- Layout: {self.layout}") + print(f"- Scale: {self._scale}") + print(f"- Coords: {self.coords}") + raise e + + def __cal_shape(self, feats, coords): + shape = [] + shape.append(coords[:, 0].max().item() + 1) + shape.extend([*feats.shape[1:]]) + return torch.Size(shape) + + def __cal_layout(self, coords, batch_size): + seq_len = torch.bincount(coords[:, 0], minlength=batch_size) + offset = torch.cumsum(seq_len, dim=0) + layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)] + return layout + + @property + def shape(self) -> torch.Size: + return self._shape + + def dim(self) -> int: + return len(self.shape) + + @property + def layout(self) -> List[slice]: + return self._layout + + @property + def feats(self) -> torch.Tensor: + if BACKEND == 'torchsparse': + return self.data.F + elif BACKEND == 'spconv': + return self.data.features + + @feats.setter + def feats(self, value: torch.Tensor): + if BACKEND == 'torchsparse': + self.data.F = value + elif BACKEND == 
'spconv': + self.data.features = value + + @property + def coords(self) -> torch.Tensor: + if BACKEND == 'torchsparse': + return self.data.C + elif BACKEND == 'spconv': + return self.data.indices + + @coords.setter + def coords(self, value: torch.Tensor): + if BACKEND == 'torchsparse': + self.data.C = value + elif BACKEND == 'spconv': + self.data.indices = value + + @property + def dtype(self): + return self.feats.dtype + + @property + def device(self): + return self.feats.device + + @overload + def to(self, dtype: torch.dtype) -> 'SparseTensor': ... + + @overload + def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None) -> 'SparseTensor': ... + + def to(self, *args, **kwargs) -> 'SparseTensor': + device = None + dtype = None + if len(args) == 2: + device, dtype = args + elif len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype = args[0] + else: + device = args[0] + if 'dtype' in kwargs: + assert dtype is None, "to() received multiple values for argument 'dtype'" + dtype = kwargs['dtype'] + if 'device' in kwargs: + assert device is None, "to() received multiple values for argument 'device'" + device = kwargs['device'] + + new_feats = self.feats.to(device=device, dtype=dtype) + new_coords = self.coords.to(device=device) + return self.replace(new_feats, new_coords) + + def type(self, dtype): + new_feats = self.feats.type(dtype) + return self.replace(new_feats) + + def cpu(self) -> 'SparseTensor': + new_feats = self.feats.cpu() + new_coords = self.coords.cpu() + return self.replace(new_feats, new_coords) + + def cuda(self) -> 'SparseTensor': + new_feats = self.feats.cuda() + new_coords = self.coords.cuda() + return self.replace(new_feats, new_coords) + + def half(self) -> 'SparseTensor': + new_feats = self.feats.half() + return self.replace(new_feats) + + def float(self) -> 'SparseTensor': + new_feats = self.feats.float() + return self.replace(new_feats) + + def detach(self) -> 'SparseTensor': + new_coords = self.coords.detach() + new_feats = self.feats.detach() + return self.replace(new_feats, new_coords) + + def dense(self) -> torch.Tensor: + if BACKEND == 'torchsparse': + return self.data.dense() + elif BACKEND == 'spconv': + return self.data.dense() + + def reshape(self, *shape) -> 'SparseTensor': + new_feats = self.feats.reshape(self.feats.shape[0], *shape) + return self.replace(new_feats) + + def unbind(self, dim: int) -> List['SparseTensor']: + return sparse_unbind(self, dim) + + def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor': + new_shape = [self.shape[0]] + new_shape.extend(feats.shape[1:]) + if BACKEND == 'torchsparse': + new_data = SparseTensorData( + feats=feats, + coords=self.data.coords if coords is None else coords, + stride=self.data.stride, + spatial_range=self.data.spatial_range, + ) + new_data._caches = self.data._caches + elif BACKEND == 'spconv': + new_data = SparseTensorData( + self.data.features.reshape(self.data.features.shape[0], -1), + self.data.indices, + self.data.spatial_shape, + self.data.batch_size, + self.data.grid, + self.data.voxel_num, + self.data.indice_dict + ) + new_data._features = feats + new_data.benchmark = self.data.benchmark + new_data.benchmark_record = self.data.benchmark_record + new_data.thrust_allocator = self.data.thrust_allocator + new_data._timer = self.data._timer + new_data.force_algo = self.data.force_algo + new_data.int8_scale = self.data.int8_scale + if coords is not None: + new_data.indices = coords + new_tensor = 
SparseTensor(new_data, shape=torch.Size(new_shape), layout=self.layout, scale=self._scale, spatial_cache=self._spatial_cache) + return new_tensor + + @staticmethod + def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor': + N, C = dim + x = torch.arange(aabb[0], aabb[3] + 1) + y = torch.arange(aabb[1], aabb[4] + 1) + z = torch.arange(aabb[2], aabb[5] + 1) + coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3) + coords = torch.cat([ + torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1), + coords.repeat(N, 1), + ], dim=1).to(dtype=torch.int32, device=device) + feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device) + return SparseTensor(feats=feats, coords=coords) + + def __merge_sparse_cache(self, other: 'SparseTensor') -> dict: + new_cache = {} + for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())): + if k in self._spatial_cache: + new_cache[k] = self._spatial_cache[k] + if k in other._spatial_cache: + if k not in new_cache: + new_cache[k] = other._spatial_cache[k] + else: + new_cache[k].update(other._spatial_cache[k]) + return new_cache + + def __neg__(self) -> 'SparseTensor': + return self.replace(-self.feats) + + def __elemwise__(self, other: Union[torch.Tensor, 'SparseTensor'], op: callable) -> 'SparseTensor': + if isinstance(other, torch.Tensor): + try: + other = torch.broadcast_to(other, self.shape) + other = sparse_batch_broadcast(self, other) + except: + pass + if isinstance(other, SparseTensor): + other = other.feats + new_feats = op(self.feats, other) + new_tensor = self.replace(new_feats) + if isinstance(other, SparseTensor): + new_tensor._spatial_cache = self.__merge_sparse_cache(other) + return new_tensor + + def __add__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.add) + + def __radd__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.add) + + def __sub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.sub) + + def __rsub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, lambda x, y: torch.sub(y, x)) + + def __mul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.mul) + + def __rmul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.mul) + + def __truediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.div) + + def __rtruediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, lambda x, y: torch.div(y, x)) + + def __getitem__(self, idx): + if isinstance(idx, int): + idx = [idx] + elif isinstance(idx, slice): + idx = range(*idx.indices(self.shape[0])) + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: + assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}" + idx = idx.nonzero().squeeze(1) + elif idx.dtype in [torch.int32, torch.int64]: + assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}" + else: + raise ValueError(f"Unknown index type: {idx.dtype}") + else: + raise ValueError(f"Unknown index type: {type(idx)}") + + coords = [] + feats = [] + for new_idx, old_idx in enumerate(idx): + 
coords.append(self.coords[self.layout[old_idx]].clone()) + coords[-1][:, 0] = new_idx + feats.append(self.feats[self.layout[old_idx]]) + coords = torch.cat(coords, dim=0).contiguous() + feats = torch.cat(feats, dim=0).contiguous() + return SparseTensor(feats=feats, coords=coords) + + def register_spatial_cache(self, key, value) -> None: + """ + Register a spatial cache. + The spatial cache can be any thing you want to cache. + The registery and retrieval of the cache is based on current scale. + """ + scale_key = str(self._scale) + if scale_key not in self._spatial_cache: + self._spatial_cache[scale_key] = {} + self._spatial_cache[scale_key][key] = value + + def get_spatial_cache(self, key=None): + """ + Get a spatial cache. + """ + scale_key = str(self._scale) + cur_scale_cache = self._spatial_cache.get(scale_key, {}) + if key is None: + return cur_scale_cache + return cur_scale_cache.get(key, None) + + +def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor: + """ + Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation. + + Args: + input (torch.Tensor): 1D tensor to broadcast. + target (SparseTensor): Sparse tensor to broadcast to. + op (callable): Operation to perform after broadcasting. Defaults to torch.add. + """ + coords, feats = input.coords, input.feats + broadcasted = torch.zeros_like(feats) + for k in range(input.shape[0]): + broadcasted[input.layout[k]] = other[k] + return broadcasted + + +def sparse_batch_op(input: SparseTensor, other: torch.Tensor, op: callable = torch.add) -> SparseTensor: + """ + Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation. + + Args: + input (torch.Tensor): 1D tensor to broadcast. + target (SparseTensor): Sparse tensor to broadcast to. + op (callable): Operation to perform after broadcasting. Defaults to torch.add. + """ + return input.replace(op(input.feats, sparse_batch_broadcast(input, other))) + + +def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor: + """ + Concatenate a list of sparse tensors. + + Args: + inputs (List[SparseTensor]): List of sparse tensors to concatenate. + """ + if dim == 0: + start = 0 + coords = [] + for input in inputs: + coords.append(input.coords.clone()) + coords[-1][:, 0] += start + start += input.shape[0] + coords = torch.cat(coords, dim=0) + feats = torch.cat([input.feats for input in inputs], dim=0) + output = SparseTensor( + coords=coords, + feats=feats, + ) + else: + feats = torch.cat([input.feats for input in inputs], dim=dim) + output = inputs[0].replace(feats) + + return output + + +def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]: + """ + Unbind a sparse tensor along a dimension. + + Args: + input (SparseTensor): Sparse tensor to unbind. + dim (int): Dimension to unbind. 
+ """ + if dim == 0: + return [input[i] for i in range(input.shape[0])] + else: + feats = input.feats.unbind(dim) + return [input.replace(f) for f in feats] diff --git a/direct3d_s2/modules/sparse/__init__.py b/direct3d_s2/modules/sparse/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..036ff1998f29424abb720642cd83a83c6abf750f --- /dev/null +++ b/direct3d_s2/modules/sparse/__init__.py @@ -0,0 +1,105 @@ +from typing import * + +BACKEND = 'torchsparse' +DEBUG = False +ATTN = 'flash_attn' + +def __from_env(): + import os + + global BACKEND + global DEBUG + global ATTN + + env_sparse_backend = os.environ.get('SPARSE_BACKEND') + env_sparse_debug = os.environ.get('SPARSE_DEBUG') + env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND') + if env_sparse_attn is None: + env_sparse_attn = os.environ.get('ATTN_BACKEND') + + if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']: + BACKEND = env_sparse_backend + if env_sparse_debug is not None: + DEBUG = env_sparse_debug == '1' + if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']: + ATTN = env_sparse_attn + + print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}") + + +__from_env() + + +def set_backend(backend: Literal['spconv', 'torchsparse']): + global BACKEND + BACKEND = backend + +def set_debug(debug: bool): + global DEBUG + DEBUG = debug + +def set_attn(attn: Literal['xformers', 'flash_attn']): + global ATTN + ATTN = attn + + +import importlib + +__attributes = { + 'SparseTensor': 'basic', + 'sparse_batch_broadcast': 'basic', + 'sparse_batch_op': 'basic', + 'sparse_cat': 'basic', + 'sparse_unbind': 'basic', + 'SparseGroupNorm': 'norm', + 'SparseLayerNorm': 'norm', + 'SparseGroupNorm32': 'norm', + 'SparseLayerNorm32': 'norm', + 'SparseSigmoid': 'nonlinearity', + 'SparseReLU': 'nonlinearity', + 'SparseSiLU': 'nonlinearity', + 'SparseGELU': 'nonlinearity', + 'SparseTanh': 'nonlinearity', + 'SparseActivation': 'nonlinearity', + 'SparseLinear': 'linear', + 'sparse_scaled_dot_product_attention': 'attention', + 'SerializeMode': 'attention', + 'sparse_serialized_scaled_dot_product_self_attention': 'attention', + 'sparse_windowed_scaled_dot_product_self_attention': 'attention', + 'SparseMultiHeadAttention': 'attention', + 'SparseConv3d': 'conv', + 'SparseInverseConv3d': 'conv', + 'sparseconv3d_func': 'conv', + 'SparseDownsample': 'spatial', + 'SparseUpsample': 'spatial', + 'SparseSubdivide' : 'spatial' +} + +__submodules = ['transformer'] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .basic import * + from .norm import * + from .nonlinearity import * + from .linear import * + from .attention import * + from .conv import * + from .spatial import * + import transformer diff --git a/direct3d_s2/modules/sparse/__pycache__/__init__.cpython-310.pyc b/direct3d_s2/modules/sparse/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8f04ec54c69aa7df0aeeaaef71775535bb39ec3 Binary files /dev/null and 
b/direct3d_s2/modules/sparse/__pycache__/__init__.cpython-310.pyc differ diff --git a/direct3d_s2/modules/sparse/__pycache__/basic.cpython-310.pyc b/direct3d_s2/modules/sparse/__pycache__/basic.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f0ddea519a69f53a825d46e59a1160f63c83ff9 Binary files /dev/null and b/direct3d_s2/modules/sparse/__pycache__/basic.cpython-310.pyc differ diff --git a/direct3d_s2/modules/sparse/__pycache__/linear.cpython-310.pyc b/direct3d_s2/modules/sparse/__pycache__/linear.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..907fa3835a412ca530196c7fd18fa983f120a391 Binary files /dev/null and b/direct3d_s2/modules/sparse/__pycache__/linear.cpython-310.pyc differ diff --git a/direct3d_s2/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc b/direct3d_s2/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b74cae998582e393318ab29540a9978aa1db82a Binary files /dev/null and b/direct3d_s2/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc differ diff --git a/direct3d_s2/modules/sparse/__pycache__/norm.cpython-310.pyc b/direct3d_s2/modules/sparse/__pycache__/norm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9be3e9bb12a48fcb2ed9fc1bd077ae091d22815a Binary files /dev/null and b/direct3d_s2/modules/sparse/__pycache__/norm.cpython-310.pyc differ diff --git a/direct3d_s2/modules/sparse/__pycache__/spatial.cpython-310.pyc b/direct3d_s2/modules/sparse/__pycache__/spatial.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e962d7fe5d46ec6b52ef1b440b257e0e09bbd7de Binary files /dev/null and b/direct3d_s2/modules/sparse/__pycache__/spatial.cpython-310.pyc differ diff --git a/direct3d_s2/modules/sparse/attention/__init__.py b/direct3d_s2/modules/sparse/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9e2732f12d139e579fb27f224c523e27f1e8cefb --- /dev/null +++ b/direct3d_s2/modules/sparse/attention/__init__.py @@ -0,0 +1,5 @@ +from .full_attn import * +from .serialized_attn import * +from .windowed_attn import * +from .modules import * +from .spatial_sparse_attention.module.spatial_sparse_attention import SpatialSparseAttention diff --git a/direct3d_s2/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc b/direct3d_s2/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7bb9bf49687ac68fa570e049d22f87b21a5d0b9c Binary files /dev/null and b/direct3d_s2/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc differ diff --git a/direct3d_s2/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc b/direct3d_s2/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..636e11586f86da80ac873a5538be2bb474098cdb Binary files /dev/null and b/direct3d_s2/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc differ diff --git a/direct3d_s2/modules/sparse/attention/__pycache__/modules.cpython-310.pyc b/direct3d_s2/modules/sparse/attention/__pycache__/modules.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c6687e1746733deb74c38271108c42655893d5c Binary files /dev/null and b/direct3d_s2/modules/sparse/attention/__pycache__/modules.cpython-310.pyc differ diff --git 
a/direct3d_s2/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc b/direct3d_s2/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d102abe0ef2316b2492f990a9c20195485042eb5 Binary files /dev/null and b/direct3d_s2/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc differ diff --git a/direct3d_s2/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc b/direct3d_s2/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f697615079d97b18827305e1d37421078e963bd0 Binary files /dev/null and b/direct3d_s2/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc differ diff --git a/direct3d_s2/modules/sparse/attention/full_attn.py b/direct3d_s2/modules/sparse/attention/full_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..e9e27aeb98419621f3f9999fd3b11eebf2b90a40 --- /dev/null +++ b/direct3d_s2/modules/sparse/attention/full_attn.py @@ -0,0 +1,215 @@ +from typing import * +import torch +from .. import SparseTensor +from .. import DEBUG, ATTN + +if ATTN == 'xformers': + import xformers.ops as xops +elif ATTN == 'flash_attn': + import flash_attn +else: + raise ValueError(f"Unknown attention module: {ATTN}") + + +__all__ = [ + 'sparse_scaled_dot_product_attention', +] + + +@overload +def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + qkv (SparseTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, *, H, C] sparse tensor containing Qs. + kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: SparseTensor) -> torch.Tensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, L, H, C] dense tensor containing Qs. + kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: SparseTensor, k: SparseTensor, v: SparseTensor) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs. + k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks. + v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs. + + Note: + k and v are assumed to have the same coordinate map. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: SparseTensor, k: torch.Tensor, v: torch.Tensor) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs. + k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks. + v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: torch.Tensor, k: SparseTensor, v: SparseTensor) -> torch.Tensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs. 
+ k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks. + v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs. + """ + ... + +def sparse_scaled_dot_product_attention(*args, **kwargs): + arg_names_dict = { + 1: ['qkv'], + 2: ['q', 'kv'], + 3: ['q', 'k', 'v'] + } + num_all_args = len(args) + len(kwargs) + assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" + for key in arg_names_dict[num_all_args][len(args):]: + assert key in kwargs, f"Missing argument {key}" + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs['qkv'] + assert isinstance(qkv, SparseTensor), f"qkv must be a SparseTensor, got {type(qkv)}" + assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + device = qkv.device + + s = qkv + q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])] + kv_seqlen = q_seqlen + qkv = qkv.feats # [T, 3, H, C] + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs['q'] + kv = args[1] if len(args) > 1 else kwargs['kv'] + assert isinstance(q, SparseTensor) and isinstance(kv, (SparseTensor, torch.Tensor)) or \ + isinstance(q, torch.Tensor) and isinstance(kv, SparseTensor), \ + f"Invalid types, got {type(q)} and {type(kv)}" + assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" + device = q.device + + if isinstance(q, SparseTensor): + assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]" + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, C] + else: + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" + s = None + N, L, H, C = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, C) # [T_Q, H, C] + + if isinstance(kv, SparseTensor): + assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]" + kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])] + kv = kv.feats # [T_KV, 2, H, C] + else: + assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" + N, L, _, H, C = kv.shape + kv_seqlen = [L] * N + kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C] + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + assert isinstance(q, SparseTensor) and isinstance(k, (SparseTensor, torch.Tensor)) and type(k) == type(v) or \ + isinstance(q, torch.Tensor) and isinstance(k, SparseTensor) and isinstance(v, SparseTensor), \ + f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}" + assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" + device = q.device + + if isinstance(q, SparseTensor): + assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]" + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, Ci] + else: + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" + s = None + N, L, H, CI = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, CI) # [T_Q, H, Ci] + + if isinstance(k, SparseTensor): + assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]" + assert len(v.shape) == 3, f"Invalid shape for 
v, got {v.shape}, expected [N, *, H, Co]" + kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])] + k = k.feats # [T_KV, H, Ci] + v = v.feats # [T_KV, H, Co] + else: + assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" + assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" + N, L, H, CI, CO = *k.shape, v.shape[-1] + kv_seqlen = [L] * N + k = k.reshape(N * L, H, CI) # [T_KV, H, Ci] + v = v.reshape(N * L, H, CO) # [T_KV, H, Co] + + if DEBUG: + if s is not None: + for i in range(s.shape[0]): + assert (s.coords[s.layout[i]] == i).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch" + if num_all_args in [2, 3]: + assert q.shape[:2] == [1, sum(q_seqlen)], f"SparseScaledDotProductSelfAttention: q shape mismatch" + if num_all_args == 3: + assert k.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: k shape mismatch" + assert v.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: v shape mismatch" + + if ATTN == 'xformers': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=1) + elif num_all_args == 2: + k, v = kv.unbind(dim=1) + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) + out = xops.memory_efficient_attention(q, k, v, mask)[0] + elif ATTN == 'flash_attn': + cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) + if num_all_args in [2, 3]: + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + if num_all_args == 1: + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen)) + elif num_all_args == 2: + out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + elif num_all_args == 3: + out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + else: + raise ValueError(f"Unknown attention module: {ATTN}") + + if s is not None: + return s.replace(out) + else: + return out.reshape(N, L, H, -1) diff --git a/direct3d_s2/modules/sparse/attention/modules.py b/direct3d_s2/modules/sparse/attention/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2fe782b0947700e308e9ec0325e7e91c84e3c2 --- /dev/null +++ b/direct3d_s2/modules/sparse/attention/modules.py @@ -0,0 +1,139 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from .. 
import SparseTensor +from .full_attn import sparse_scaled_dot_product_attention +from .serialized_attn import SerializeMode, sparse_serialized_scaled_dot_product_self_attention +from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention +from ...attention import RotaryPositionEmbedder + + +class SparseMultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim)) + + def forward(self, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]: + x_type = x.dtype + x = x.float() + if isinstance(x, SparseTensor): + x = x.replace(F.normalize(x.feats, dim=-1)) + else: + x = F.normalize(x, dim=-1) + return (x * self.gamma * self.scale).to(x_type) + + +class SparseMultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int] = None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "serialized", "windowed"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + qkv_bias: bool = True, + use_rope: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__() + assert channels % num_heads == 0 + assert type in ["self", "cross"], f"Invalid attention type: {type}" + assert attn_mode in ["full", "serialized", "windowed"], f"Invalid attention mode: {attn_mode}" + assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" + assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention" + self.channels = channels + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_sequence = shift_sequence + self.shift_window = shift_window + self.serialize_mode = serialize_mode + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + else: + self.to_q = nn.Linear(channels, channels, bias=qkv_bias) + self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + + if self.qk_rms_norm: + self.q_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads) + self.k_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads) + + self.to_out = nn.Linear(channels, channels) + + if use_rope: + self.rope = RotaryPositionEmbedder(channels) + + @staticmethod + def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]: + if isinstance(x, SparseTensor): + return x.replace(module(x.feats)) + else: + return module(x) + + @staticmethod + def _reshape_chs(x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[SparseTensor, torch.Tensor]: + if isinstance(x, SparseTensor): + return x.reshape(*shape) + else: + return x.reshape(*x.shape[:2], *shape) + + def _fused_pre(self, x: Union[SparseTensor, torch.Tensor], num_fused: int) -> Union[SparseTensor, torch.Tensor]: + if isinstance(x, SparseTensor): + x_feats = x.feats.unsqueeze(0) + else: + x_feats = x + x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1) + return x.replace(x_feats.squeeze(0)) if isinstance(x, SparseTensor) else x_feats + + def _rope(self, qkv: SparseTensor) -> SparseTensor: + q, 
k, v = qkv.feats.unbind(dim=1) # [T, H, C] + q, k = self.rope(q, k, qkv.coords[:, 1:]) + qkv = qkv.replace(torch.stack([q, k, v], dim=1)) + return qkv + + def forward(self, x: Union[SparseTensor, torch.Tensor], context: Optional[Union[SparseTensor, torch.Tensor]] = None) -> Union[SparseTensor, torch.Tensor]: + if self._type == "self": + qkv = self._linear(self.to_qkv, x) + qkv = self._fused_pre(qkv, num_fused=3) + if self.use_rope: + qkv = self._rope(qkv) + if self.qk_rms_norm: + q, k, v = qkv.unbind(dim=1) + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1)) + if self.attn_mode == "full": + h = sparse_scaled_dot_product_attention(qkv) + elif self.attn_mode == "serialized": + h = sparse_serialized_scaled_dot_product_self_attention( + qkv, self.window_size, serialize_mode=self.serialize_mode, shift_sequence=self.shift_sequence, shift_window=self.shift_window + ) + elif self.attn_mode == "windowed": + h = sparse_windowed_scaled_dot_product_self_attention( + qkv, self.window_size, shift_window=self.shift_window + ) + else: + q = self._linear(self.to_q, x) + q = self._reshape_chs(q, (self.num_heads, -1)) + kv = self._linear(self.to_kv, context) + kv = self._fused_pre(kv, num_fused=2) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=1) + k = self.k_rms_norm(k) + kv = kv.replace(torch.stack([k.feats, v.feats], dim=1)) + h = sparse_scaled_dot_product_attention(q, kv) + h = self._reshape_chs(h, (-1,)) + h = self._linear(self.to_out, h) + return h diff --git a/direct3d_s2/modules/sparse/attention/serialized_attn.py b/direct3d_s2/modules/sparse/attention/serialized_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..5950b75b2f5a6d6e79ab6d472b8501aaa5ec4a26 --- /dev/null +++ b/direct3d_s2/modules/sparse/attention/serialized_attn.py @@ -0,0 +1,193 @@ +from typing import * +from enum import Enum +import torch +import math +from .. import SparseTensor +from .. import DEBUG, ATTN + +if ATTN == 'xformers': + import xformers.ops as xops +elif ATTN == 'flash_attn': + import flash_attn +else: + raise ValueError(f"Unknown attention module: {ATTN}") + + +__all__ = [ + 'sparse_serialized_scaled_dot_product_self_attention', +] + + +class SerializeMode(Enum): + Z_ORDER = 0 + Z_ORDER_TRANSPOSED = 1 + HILBERT = 2 + HILBERT_TRANSPOSED = 3 + + +SerializeModes = [ + SerializeMode.Z_ORDER, + SerializeMode.Z_ORDER_TRANSPOSED, + SerializeMode.HILBERT, + SerializeMode.HILBERT_TRANSPOSED +] + + +def calc_serialization( + tensor: SparseTensor, + window_size: int, + serialize_mode: SerializeMode = SerializeMode.Z_ORDER, + shift_sequence: int = 0, + shift_window: Tuple[int, int, int] = (0, 0, 0) +) -> Tuple[torch.Tensor, torch.Tensor, List[int]]: + """ + Calculate serialization and partitioning for a set of coordinates. + + Args: + tensor (SparseTensor): The input tensor. + window_size (int): The window size to use. + serialize_mode (SerializeMode): The serialization mode to use. + shift_sequence (int): The shift of serialized sequence. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + + Returns: + (torch.Tensor, torch.Tensor): Forwards and backwards indices. 
+ """ + fwd_indices = [] + bwd_indices = [] + seq_lens = [] + seq_batch_indices = [] + offsets = [0] + + if 'vox2seq' not in globals(): + import vox2seq + + # Serialize the input + serialize_coords = tensor.coords[:, 1:].clone() + serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3) + if serialize_mode == SerializeMode.Z_ORDER: + code = vox2seq.encode(serialize_coords, mode='z_order', permute=[0, 1, 2]) + elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED: + code = vox2seq.encode(serialize_coords, mode='z_order', permute=[1, 0, 2]) + elif serialize_mode == SerializeMode.HILBERT: + code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[0, 1, 2]) + elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED: + code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[1, 0, 2]) + else: + raise ValueError(f"Unknown serialize mode: {serialize_mode}") + + for bi, s in enumerate(tensor.layout): + num_points = s.stop - s.start + num_windows = (num_points + window_size - 1) // window_size + valid_window_size = num_points / num_windows + to_ordered = torch.argsort(code[s.start:s.stop]) + if num_windows == 1: + fwd_indices.append(to_ordered) + bwd_indices.append(torch.zeros_like(to_ordered).scatter_(0, to_ordered, torch.arange(num_points, device=tensor.device))) + fwd_indices[-1] += s.start + bwd_indices[-1] += offsets[-1] + seq_lens.append(num_points) + seq_batch_indices.append(bi) + offsets.append(offsets[-1] + seq_lens[-1]) + else: + # Partition the input + offset = 0 + mids = [(i + 0.5) * valid_window_size + shift_sequence for i in range(num_windows)] + split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)] + bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device) + for i in range(num_windows): + mid = mids[i] + valid_start = split[i] + valid_end = split[i + 1] + padded_start = math.floor(mid - 0.5 * window_size) + padded_end = padded_start + window_size + fwd_indices.append(to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points]) + offset += valid_start - padded_start + bwd_index.scatter_(0, fwd_indices[-1][valid_start-padded_start:valid_end-padded_start], torch.arange(offset, offset + valid_end - valid_start, device=tensor.device)) + offset += padded_end - valid_start + fwd_indices[-1] += s.start + seq_lens.extend([window_size] * num_windows) + seq_batch_indices.extend([bi] * num_windows) + bwd_indices.append(bwd_index + offsets[-1]) + offsets.append(offsets[-1] + num_windows * window_size) + + fwd_indices = torch.cat(fwd_indices) + bwd_indices = torch.cat(bwd_indices) + + return fwd_indices, bwd_indices, seq_lens, seq_batch_indices + + +def sparse_serialized_scaled_dot_product_self_attention( + qkv: SparseTensor, + window_size: int, + serialize_mode: SerializeMode = SerializeMode.Z_ORDER, + shift_sequence: int = 0, + shift_window: Tuple[int, int, int] = (0, 0, 0) +) -> SparseTensor: + """ + Apply serialized scaled dot product self attention to a sparse tensor. + + Args: + qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + window_size (int): The window size to use. + serialize_mode (SerializeMode): The serialization mode to use. + shift_sequence (int): The shift of serialized sequence. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + shift (int): The shift to use. 
+ """ + assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + + serialization_spatial_cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}' + serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) + if serialization_spatial_cache is None: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window) + qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices)) + else: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache + + M = fwd_indices.shape[0] + T = qkv.feats.shape[0] + H = qkv.feats.shape[2] + C = qkv.feats.shape[3] + + qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] + + if DEBUG: + start = 0 + qkv_coords = qkv.coords[fwd_indices] + for i in range(len(seq_lens)): + assert (qkv_coords[start:start+seq_lens[i], 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch" + start += seq_lens[i] + + if all([seq_len == window_size for seq_len in seq_lens]): + B = len(seq_lens) + N = window_size + qkv_feats = qkv_feats.reshape(B, N, 3, H, C) + if ATTN == 'xformers': + q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C] + out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C] + elif ATTN == 'flash_attn': + out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C] + else: + raise ValueError(f"Unknown attention module: {ATTN}") + out = out.reshape(B * N, H, C) # [M, H, C] + else: + if ATTN == 'xformers': + q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] + mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) + out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C] + elif ATTN == 'flash_attn': + cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \ + .to(qkv.device).int() + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C] + + out = out[bwd_indices] # [T, H, C] + + if DEBUG: + qkv_coords = qkv_coords[bwd_indices] + assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" + + return qkv.replace(out) diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/__init__.py b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/__pycache__/__init__.cpython-310.pyc b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53c0cd4d3e95a487a67f2c4a79227c5d2a852078 Binary files /dev/null and b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/__pycache__/__init__.cpython-310.pyc differ diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/__init__.py b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ 
b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/__init__.py @@ -0,0 +1 @@ + diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/__pycache__/__init__.cpython-310.pyc b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a157b61e6024f1d194e5f5eea8deb290f86dbc8d Binary files /dev/null and b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/__pycache__/__init__.cpython-310.pyc differ diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/__pycache__/compression_block.cpython-310.pyc b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/__pycache__/compression_block.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a742c89f7169dfd9747021ca3ceceda3ef353862 Binary files /dev/null and b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/__pycache__/compression_block.cpython-310.pyc differ diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/__pycache__/spatial_sparse_attention.cpython-310.pyc b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/__pycache__/spatial_sparse_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56b02b9d9dbfb5ed384762e4a59ffd509e1318d4 Binary files /dev/null and b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/__pycache__/spatial_sparse_attention.cpython-310.pyc differ diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/compression_block.py b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/compression_block.py new file mode 100644 index 0000000000000000000000000000000000000000..579a611862995abfdd3478dad22ba912fdb8c1bb --- /dev/null +++ b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/compression_block.py @@ -0,0 +1,65 @@ +import torch.nn as nn +import direct3d_s2.modules.sparse as sp + + +class SparseDownBlock3d_v1(nn.Module): + + def __init__( + self, + channels: int, + out_channels: int = None, + factor: int = 2, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.act_layers = nn.Sequential( + sp.SparseConv3d(self.out_channels, self.out_channels, 1, padding=0), + sp.SparseSiLU() + ) + self.down = sp.SparseDownsample(factor) + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + h = self.act_layers(x) + h = self.down(h) + return h + +class SparseDownBlock3d_v2(nn.Module): + + def __init__( + self, + channels: int, + out_channels: int = None, + num_groups: int = 32, + factor: int = 2, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.act_layers = nn.Sequential( + sp.SparseGroupNorm32(num_groups, channels), + sp.SparseSiLU() + ) + + self.down = sp.SparseDownsample(factor) + self.out_layers = nn.Sequential( + sp.SparseConv3d(channels, self.out_channels, 3, padding=1), + sp.SparseGroupNorm32(num_groups, self.out_channels), + sp.SparseSiLU(), + sp.SparseConv3d(self.out_channels, self.out_channels, 3, padding=1), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + else: + self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1) + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + h = self.act_layers(x) + h = self.down(h) + x = 
self.down(x) + h = self.out_layers(h) + h = h + self.skip_connection(x) + return h \ No newline at end of file diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/spatial_sparse_attention.py b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/spatial_sparse_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..8cb8977d4b7c4256d6e95cfb3d308449badfa8b9 --- /dev/null +++ b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/module/spatial_sparse_attention.py @@ -0,0 +1,155 @@ +import torch +import torch.nn as nn +from einops import rearrange +from flash_attn import flash_attn_varlen_func +from ..ops import ( + spatial_selection_attention, + get_block_score, + sparse_window_attention, +) +from .compression_block import SparseDownBlock3d_v1, SparseDownBlock3d_v2 +import direct3d_s2.modules.sparse as sp + + +class SpatialSparseAttention(torch.nn.Module): + def __init__( + self, + hidden_size: int, + num_q_heads: int, + num_kv_heads: int, + head_dim: int, + compression_block_size: int, + selection_block_size: int, + topk: int, + window_size: int, + shift_window: int, + resolution: int = 64, + compression_version: str = 'v2', + ): + super().__init__() + self.hidden_size = hidden_size + self.num_q_heads = num_q_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.compression_block_size = compression_block_size + self.selection_block_size = selection_block_size + self.topk = topk + self.window_size = window_size + self.shift_window = shift_window + self.resolution = resolution + + # qkv proj and o proj + self.proj_q = sp.SparseLinear( + hidden_size, num_q_heads * head_dim, bias=False + ) + self.proj_k = sp.SparseLinear( + hidden_size, num_kv_heads * head_dim, bias=False + ) + self.proj_v = sp.SparseLinear( + hidden_size, num_kv_heads * head_dim, bias=False + ) + self.proj_o = torch.nn.Linear( + num_q_heads * head_dim, hidden_size, bias=False + ) + + # ssa parameteres + if compression_version == 'v1': + compression_block = SparseDownBlock3d_v1 + elif compression_version == 'v2': + compression_block = SparseDownBlock3d_v2 + else: + raise NotImplementedError('only support v1 or v2 compression block') + self.compression_key = compression_block( + num_kv_heads * head_dim, num_kv_heads * head_dim, factor=compression_block_size + ) + self.compression_value = compression_block( + num_kv_heads * head_dim, num_kv_heads * head_dim, factor=compression_block_size + ) + self.intra_block_pe = torch.nn.Parameter( + torch.zeros(compression_block_size, + compression_block_size, + compression_block_size, + num_kv_heads * head_dim, + ) + ) + + # gate function + self.gate = torch.nn.Sequential( + sp.SparseLinear(hidden_size, 3, bias=False), sp.SparseSigmoid(), + ) + + def sparse3d_compression(self, x, key=True): + _, num_heads, num_dim = x.feats.shape + x = x.replace(x.feats.view(-1, num_heads * num_dim)) + if key: + coords = x.coords + intra_block_coords = coords[..., 1:] % self.compression_block_size + intra_block_pos = self.intra_block_pe[intra_block_coords[:, 0], intra_block_coords[:, 1], intra_block_coords[:, 2]].to(x.dtype) + x = x.replace(x.feats + intra_block_pos) + y = self.compression_key(x) + else: + y = self.compression_value(x) + y = y.replace(y.feats.view(-1, num_heads, num_dim)) + return y + + def forward(self, x: sp.SparseTensor): + # dtype and shape check + assert x.shape[-1] == self.hidden_size + assert self.selection_block_size % self.compression_block_size == 0 + # qkv proj + q = 
x.replace(self.proj_q(x).feats.view(-1, self.num_q_heads, self.head_dim)) + k = x.replace(self.proj_k(x).feats.view(-1, self.num_kv_heads, self.head_dim)) + v = x.replace(self.proj_v(x).feats.view(-1, self.num_kv_heads, self.head_dim)) + + # compression attention + compressed_k = self.sparse3d_compression(k, key=True) + compressed_v = self.sparse3d_compression(v, key=False) + + compressed_cu_seqlens = torch.tensor([s.start for s in compressed_v.layout] + [s.stop for s in compressed_v.layout if s.stop not in [s.start for s in compressed_v.layout]]).to(compressed_v.device).to(torch.int32) + compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1] + + cu_seqlens = torch.tensor([s.start for s in x.layout] + [s.stop for s in x.layout if s.stop not in [s.start for s in x.layout]]).to(x.device).to(torch.int32) + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + + compressed_attn_output, lse, _ = flash_attn_varlen_func( + q.feats, + compressed_k.feats, + compressed_v.feats, + cu_seqlens, + compressed_cu_seqlens, + seqlens.max().item(), + compressed_seqlens.max().item(), + causal=False, + return_attn_probs=True, + ) + + with torch.no_grad(): + block_topk, cu_seqblocks, cu_block_include_tokens = get_block_score( + q, compressed_k, lse, self.resolution, self.compression_block_size, + self.selection_block_size, self.topk, cu_seqlens, compressed_cu_seqlens, + seqlens, compressed_seqlens, None) + + # spatial selection attention + selection_attn_output = spatial_selection_attention( + q.feats, k.feats, v.feats, block_topk, cu_seqblocks, + cu_block_include_tokens, self.selection_block_size, cu_seqlens, None, + ) + + # window attention + window_attn_output = sparse_window_attention( + q, k, v, window_size=self.window_size, shift_window=self.shift_window, + ).feats + + # gate average + gate = self.gate(x).feats + attn_output = ( + gate[:, 0:1, None] * compressed_attn_output + + gate[:, 1:2, None] * selection_attn_output + + gate[:, 2:3, None] * window_attn_output + ) + + # rearrange and output proj + attn_output = rearrange(attn_output, "n h d -> n (h d)") + attn_output = self.proj_o(attn_output) + + return x.replace(attn_output) diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__init__.py b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ef3bfe015052a8b72b5c9d1d7337621656bc7ce7 --- /dev/null +++ b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__init__.py @@ -0,0 +1,3 @@ +from .compressed_attention import get_block_score +from .selection_attention import spatial_selection_attention +from .window_attention import sparse_window_attention diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/__init__.cpython-310.pyc b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbef7c9a876454135c3ef88993cca9f0f93dac51 Binary files /dev/null and b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/__init__.cpython-310.pyc differ diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/compressed_attention.cpython-310.pyc b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/compressed_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f93816c494979e04f17d8e3fb8423d82891122bd 
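# --- Editor's sketch (not part of the diff): how SpatialSparseAttention.forward fuses
# its three branches. The gate is a per-token SparseLinear(hidden_size, 3) + sigmoid, so
# each token mixes the compressed, selection, and window outputs with independent
# weights. Shapes and values below are hypothetical.
import torch

N, H, D = 5, 8, 64                                  # tokens, query heads, head dim (assumed)
compressed_out = torch.randn(N, H, D)               # flash_attn over compressed K/V
selection_out  = torch.randn(N, H, D)               # spatial_selection_attention output
window_out     = torch.randn(N, H, D)               # sparse_window_attention output
gate = torch.sigmoid(torch.randn(N, 3))             # [N, 3], one weight per branch

fused = (gate[:, 0:1, None] * compressed_out
         + gate[:, 1:2, None] * selection_out
         + gate[:, 2:3, None] * window_out)          # [N, H, D]
# fused is then flattened to [N, H*D] ("n h d -> n (h d)") and passed through proj_o.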
Binary files /dev/null and b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/compressed_attention.cpython-310.pyc differ diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/selection_attention.cpython-310.pyc b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/selection_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bff8490d9552b66ad7d745146a9ec17ac594b44 Binary files /dev/null and b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/selection_attention.cpython-310.pyc differ diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/window_attention.cpython-310.pyc b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/window_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df310ed5db6d8d209b9b77b0d3464ed484881e44 Binary files /dev/null and b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/window_attention.cpython-310.pyc differ diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/compressed_attention.py b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/compressed_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..5edc6bfb55511f4f9376a27aac91d2eb6d04f7dd --- /dev/null +++ b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/compressed_attention.py @@ -0,0 +1,275 @@ +# Copyright 2025 Xunhao Lai. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
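# --- Editor's note (not part of the diff): the score_kernel defined below recomputes the
# query-key logits but reuses the log-sum-exp (lse) returned by flash_attn, so the per-row
# softmax normalizer never has to be recomputed. The kernel works in base 2 (exp2 with the
# 1.44269504 = log2(e) factor), and get_block_score converts the natural-log lse
# accordingly via lse.exp().log2(). The plain-torch equivalent, with hypothetical sizes:
import math
import torch

q = torch.randn(4, 64)                     # 4 queries, head_dim 64 (assumed)
k = torch.randn(6, 64)                     # 6 compressed keys (assumed)
sm_scale = 1 / math.sqrt(64)
scores = (q @ k.T) * sm_scale              # [4, 6]
lse = torch.logsumexp(scores, dim=-1, keepdim=True)
p = torch.exp(scores - lse)                # identical to softmax(scores, dim=-1)
assert torch.allclose(p, torch.softmax(scores, dim=-1))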
+# --------------------------------------------------------------------- +# Copyright 2025 Shuang Wu +# adapted from https://github.com/XunhaoLai/native-sparse-attention-triton/blob/main/native_sparse_attention/ops/triton/compressed_attention.py + +import math +import torch +from copy import deepcopy +import triton +import triton.language as tl +import direct3d_s2.modules.sparse as sp + + +@triton.jit +def score_kernel( + q_ptr, + k_ptr, + lse_ptr, + s_ptr, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_lh, + stride_ln, + stride_sh, + stride_sq, + stride_sk, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_bkh = tl.program_id(0) + pid_b = pid_bkh // NUM_KV_HEADS + pid_kh = pid_bkh % NUM_KV_HEADS + pid_q = tl.program_id(1) + pid_k = tl.program_id(2) + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + if pid_q * BLOCK_SIZE_Q >= q_len or pid_k * BLOCK_SIZE_K >= k_len: + return + # init k pointer and load k + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(HEAD_DIM, k_len), + strides=(stride_kd, stride_kn), + offsets=(0, pid_k * BLOCK_SIZE_K), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + # init score + s = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + # loop over gqa heads + for h in range(NUM_SHARE_Q_HEADS): + pid_h = pid_kh * NUM_SHARE_Q_HEADS + h + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_h * stride_qh, + shape=(q_len, HEAD_DIM), + strides=(stride_qn, stride_qd), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + q_start * stride_ln + pid_h * stride_lh, + shape=(q_len, 1), + strides=(stride_ln, stride_lh), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + # load q and lse + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.dot(q, k) * qk_scale + # compute score + s += tl.exp2(qk - lse) + # save output + s_ptrs = tl.make_block_ptr( + base=s_ptr + pid_kh * stride_sh + q_start * stride_sq, + shape=(q_len, k_len), + strides=(stride_sq, stride_sk), + offsets=(pid_q * BLOCK_SIZE_Q, pid_k * BLOCK_SIZE_K), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_K), + order=(1, 0), + ) + tl.store(s_ptrs, s.to(s_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +def _get_attention_score( + q: torch.Tensor, # [total_query_len, num_q_heads, head_dim] + k: torch.Tensor, # [total_key_len, num_k_heads, head_dim] + lse: torch.Tensor, # [num_q_heads, total_query_len] + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, +) -> torch.Tensor: + # dtype check + assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 + assert q.dtype == k.dtype + assert 
cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + assert lse.dtype == torch.float32 + # shape + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + batch_size = cu_seqlens_q.shape[0] - 1 + assert q_len > k_len + if sm_scale is None: + sm_scale = 1 / math.sqrt(head_dim) + # gqa + assert num_q_heads % num_k_heads == 0 + num_share_q_heads = num_q_heads // num_k_heads + # init score + score = torch.zeros( + num_k_heads, q_len, max_seqlen_k, dtype=torch.float32, device=q.device + ) + # launch kernel + grid = lambda META: ( + batch_size * num_k_heads, + triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]), + triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]), + ) + num_warps = 4 if head_dim <= 64 else 8 + num_stages = 3 + BLOCK_SIZE_Q = 128 + BLOCK_SIZE_K = 128 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + score_kernel[grid]( + q, + k, + lse, + score, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + lse.stride(0), + lse.stride(1), + score.stride(0), + score.stride(1), + score.stride(2), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + return score + + +def get_block_score( + q: sp.SparseTensor, + compressed_k: sp.SparseTensor, + lse: sp.SparseTensor, + resolution: int, + kernel_stride: int, + block_size: int, + topk: int, + cu_seqlens: torch.Tensor, + compressed_cu_seqlens: torch.Tensor, + seqlens: torch.Tensor, + compressed_seqlens: torch.Tensor, + sm_scale: float = None, +) -> torch.Tensor: + attn_score = _get_attention_score( + q.feats, + compressed_k.feats, + lse.exp().log2(), + cu_seqlens, + compressed_cu_seqlens, + seqlens.max().item(), + compressed_seqlens.max().item(), + sm_scale, + ) + + batch_size = len(cu_seqlens) - 1 + num_kv_head = attn_score.shape[0] + block_res = resolution // block_size + seqblocks, block_include_tokens = [], [] + block_topk = torch.ones((num_kv_head, cu_seqlens[-1], topk), device=q.device, dtype=torch.int32) * -1 + + q_coords = deepcopy(q.coords) + for b in range(batch_size): + q_start, q_end, q_len = cu_seqlens[b], cu_seqlens[b + 1], seqlens[b] + + compressed_k_start, compressed_k_end = compressed_cu_seqlens[b], compressed_cu_seqlens[b + 1] + attn_score_b = attn_score[:, q_start: q_end, :(compressed_k_end-compressed_k_start)] + compressed_block_coords_b = deepcopy(compressed_k.coords[compressed_k_start: compressed_k_end]) + if block_size == kernel_stride: + score_block_b = attn_score_b + real_topk = min(topk, compressed_k_end - compressed_k_start) + block_topk_b = score_block_b.topk(real_topk, dim=-1).indices.sort(-1).values + block_topk[:, q_start: q_end, :real_topk] = block_topk_b + else: + compressed_block_coords_b[:, 1:] = compressed_block_coords_b[:, 1:] // (block_size//kernel_stride) + compressed_block_coords_flatten_b = compressed_block_coords_b[:, 1] * block_res**2 + compressed_block_coords_b[:, 2] * block_res + compressed_block_coords_b[:, 3] + score_block_b = torch.scatter_reduce( + torch.zeros((num_kv_head, q_len, block_res**3), device=attn_score_b.device, dtype=attn_score_b.dtype), + index=compressed_block_coords_flatten_b.long().unsqueeze(0).unsqueeze(0).expand_as(attn_score_b), + src=attn_score_b, + reduce="sum", + dim=2, + ) + compressed_block_coords_flatten_unique_b = compressed_block_coords_flatten_b.unique() + score_block_b = score_block_b[..., 
compressed_block_coords_flatten_unique_b] + real_topk = min(topk, len(compressed_block_coords_flatten_unique_b)) + block_topk_b = score_block_b.topk(real_topk, dim=-1).indices.sort(-1).values + block_topk[:, q_start: q_end, :real_topk] = block_topk_b + + block_coords_b = q_coords[q_start: q_end] + block_coords_b[:, 1:] = block_coords_b[:, 1:] // block_size + block_coords_flatten_b = block_coords_b[:, 1] * block_res**2 + block_coords_b[:, 2] * block_res + block_coords_b[:, 3] + block_bins_b = torch.histc(block_coords_flatten_b, bins=block_res**3, min=0, max=block_res**3-1) + block_include_tokens.append(block_bins_b[block_bins_b > 0]) + seqblocks.append(len(block_include_tokens[-1])) + seqblocks = torch.Tensor(seqblocks).to(attn_score.device) + cu_seqblocks = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device="cuda"), + torch.cumsum(seqblocks, dim=0), + ], + dim=0, + ).to(torch.int32) + block_include_tokens = torch.cat(block_include_tokens) + cu_block_include_tokens = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device="cuda"), + torch.cumsum(block_include_tokens, dim=0), + ], + dim=0, + ).to(torch.int32) + return block_topk.to(torch.int32), cu_seqblocks, cu_block_include_tokens diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/selection_attention.py b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/selection_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..62607d6f5a0a1c2b18180b5495a6413d9867402f --- /dev/null +++ b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/selection_attention.py @@ -0,0 +1,1256 @@ +# Copyright 2025 Xunhao Lai & Jianqiao Lu. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
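# --- Editor's sketch (not part of the diff): the bookkeeping get_block_score performs
# above. Voxel coordinates are floored to selection-block coordinates, flattened to a
# single index, and histogrammed; empty blocks are dropped, giving the per-block token
# counts that the selection kernels consume. All sizes here are hypothetical.
import torch

resolution, block_size = 64, 8
block_res = resolution // block_size
coords = torch.randint(0, resolution, (100, 3))           # [num_tokens, xyz] voxel coords
bcoords = coords // block_size                            # selection-block coordinate per token
flat = bcoords[:, 0] * block_res**2 + bcoords[:, 1] * block_res + bcoords[:, 2]
bins = torch.histc(flat.float(), bins=block_res**3, min=0, max=block_res**3 - 1)
block_include_tokens = bins[bins > 0]                     # tokens per non-empty block
cu_block_include_tokens = torch.cat(
    [torch.zeros(1), torch.cumsum(block_include_tokens, dim=0)]).int()
# cu_seqblocks is built the same way from the number of non-empty blocks per sequence.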
+# --------------------------------------------------------------------- +# Copyright 2025 Shuang Wu +# adapted from https://github.com/XunhaoLai/native-sparse-attention-triton/blob/main/native_sparse_attention/ops/triton/topk_sparse_attention.py + +import math +from typing import Any, Optional + +import torch +import triton +import triton.language as tl + + +@triton.jit +def forward_kernel( + q_ptr, # Q: n x h x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + t_ptr, # topk_idx: kh x n x k + o_ptr, # O: n x h x d + lse_ptr, # LSE: h x n + # seqlens + cu_seqlens_q, + cu_seqlens_k, + cu_seqblocks, + cu_block_include_tokens, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + TOPK, + # q loop num + num_q_loop, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_th, + stride_tn, + stride_tk, + stride_on, + stride_oh, + stride_od, + stride_lh, + stride_ln, + # META parameters + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_T: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_kh = tl.program_id(1) + pid_h = pid_kh * NUM_SHARE_Q_HEADS + pid_q = tl.program_id(2) + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + block_start = tl.load(cu_seqblocks + pid_b) + block_len = tl.load(cu_seqblocks + pid_b + 1) - block_start + if pid_q * num_q_loop >= q_len: + return + num_q_loop_ = min(num_q_loop, q_len - pid_q * num_q_loop) + for j in range(num_q_loop_): + pid_q_j = pid_q * num_q_loop + j + # init topk idx pointer + off_t = tl.arange(0, BLOCK_SIZE_T) + t_ptr_j = t_ptr + (q_start + pid_q_j) * stride_tn + pid_kh * stride_th + topk_idx = tl.load(t_ptr_j + off_t * stride_tk, mask=off_t < TOPK, other=-1) + real_topk = tl.sum( + tl.where((topk_idx >= 0) & (topk_idx < block_len), 1, 0), + axis=0, + ) + # init qkv pointer + q_ptrs = tl.make_block_ptr( + base=q_ptr + (q_start + pid_q_j) * stride_qn + pid_h * stride_qh, + shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), + strides=(stride_qh, stride_qd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(1, 0), + ) + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(HEAD_DIM, k_len), + strides=(stride_kd, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # load q + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + # init statistics + off_h = tl.arange(0, BLOCK_SIZE_H) + off_k = tl.arange(0, BLOCK_SIZE_K) + m_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32) + lse_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32) + acc_o = tl.full((BLOCK_SIZE_H, BLOCK_SIZE_D), 0, dtype=tl.float32) + # sparse attention + for i in range(real_topk): + # get current block start index + cur_block_idx = tl.load(t_ptr_j).to(tl.int32) + cur_token_start = tl.load(cu_block_include_tokens + block_start + cur_block_idx).to(tl.int32) + cur_block_size = tl.load(cu_block_include_tokens + block_start + 
cur_block_idx + 1).to(tl.int32) - cur_token_start + c = cur_token_start - k_start + t_ptr_j = t_ptr_j + stride_tk + for b_j in range(0, cur_block_size, BLOCK_SIZE_K): + # load k + k = tl.load( + tl.advance(k_ptrs, (0, c + b_j)), + boundary_check=(1, 0), padding_option="zero", + ) + # compute qk + qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where(((c + b_j + off_k < k_len) & (b_j + off_k < cur_block_size))[None, :], 0, float("-inf")) + # [BLOCK_SIZE_H, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_K] -> [BLOCK_SIZE_H, BLOCK_SIZE_K] + qk += tl.dot(q, k) * qk_scale + # compute m_ij and l_ij + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + # scale acc_o + acc_o_scale = tl.exp2(m_i - m_ij) + acc_o = acc_o * acc_o_scale[:, None] + # load v and update acc_o + v = tl.load( + tl.advance(v_ptrs, (c + b_j, 0)), + boundary_check=(0, 1), padding_option="zero" + ) + p = p.to(v.dtype) + acc_o += tl.dot(p, v) + # update statistics + m_i = m_ij + lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij) + # final scale + acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None] + # save output + o_ptrs = tl.make_block_ptr( + base=o_ptr + (q_start + pid_q_j) * stride_on + pid_h * stride_oh, + shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), + strides=(stride_oh, stride_od), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(1, 0), + ) + tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1)) + # save lse + lse_ptrs = ( + lse_ptr + (q_start + pid_q_j) * stride_ln + (pid_h + off_h) * stride_lh + ) + tl.store(lse_ptrs, lse_i, mask=off_h < NUM_SHARE_Q_HEADS) + + +@triton.jit +def backward_sum_o_do( + o_ptr, # O: n x h x d + do_ptr, # dO: n x h x d + delta_ptr, # D: h x n + o_len, + HEAD_DIM, + stride_on, + stride_oh, + stride_od, + stride_don, + stride_doh, + stride_dod, + stride_dh, + stride_dn, + BLOCK_SIZE_O: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_h = tl.program_id(1) + off_o = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O) + off_d = tl.arange(0, BLOCK_SIZE_D) + o = tl.load( + o_ptr + + off_o[:, None] * stride_on + + pid_h * stride_oh + + off_d[None, :] * stride_od, + mask=(off_o[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), + other=0, + ).to(tl.float32) + do = tl.load( + do_ptr + + off_o[:, None] * stride_don + + pid_h * stride_doh + + off_d[None, :] * stride_dod, + mask=(off_o[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), + other=0, + ).to(tl.float32) + delta = tl.sum(o * do, axis=1) + tl.store( + delta_ptr + pid_h * stride_dh + off_o * stride_dn, delta, mask=off_o < o_len + ) + + +@triton.jit +def count_kernel( + x_ptr, # [num_kv_heads, total_len, topk] + y_ptr, # [num_kv_heads, total_blocks] + cu_seqlens, # [batch_size + 1] + cu_seqblocks, # [batch_size + 1] + topk, + stride_xh, + stride_xn, + stride_xk, + stride_yh, + stride_yn, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_R: tl.constexpr, +): + pid_h = tl.program_id(0) + pid_b = tl.program_id(1) + # get start and len after rmpad + seq_start = tl.load(cu_seqlens + pid_b) + seq_len = tl.load(cu_seqlens + pid_b + 1) - seq_start + blocks_start = tl.load(cu_seqblocks + pid_b) + num_blocks = tl.load(cu_seqblocks + pid_b + 1) - blocks_start + # load x + off_k = tl.arange(0, BLOCK_SIZE_K) + off_n = tl.arange(0, BLOCK_SIZE_N) + x_ptr = x_ptr + pid_h * stride_xh + seq_start * stride_xn + x_ptrs = x_ptr + off_n[:, None] * stride_xn + off_k[None, :] * stride_xk + # init y + y = 
tl.zeros((BLOCK_SIZE_R,), dtype=tl.int32) + # loop + for i in range(0, seq_len, BLOCK_SIZE_N): + x = tl.load( + x_ptrs, + mask=(off_n < seq_len - i)[:, None] & (off_k < topk)[None, :], + other=-1, + ) + x = tl.ravel(x) + y += tl.histogram(x, BLOCK_SIZE_R) + x_ptrs += BLOCK_SIZE_N * stride_xn + # store result + off_r = tl.arange(0, BLOCK_SIZE_R) + y_ptr = y_ptr + pid_h * stride_yh + blocks_start * stride_yn + y_ptrs = y_ptr + off_r * stride_yn + tl.store(y_ptrs, y.to(y_ptr.dtype.element_ty), mask=off_r < num_blocks) + + +def count_query( + topk_idx: torch.Tensor, + cu_seqlens: torch.Tensor, + cu_seqblocks: torch.Tensor, + block_size: int, +): + num_kv_heads, total_len, topk = topk_idx.shape + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + seqblocks = cu_seqblocks[1:] - cu_seqblocks[:-1] + batch_size = seqlens.shape[0] + BLOCK_SIZE_K = triton.next_power_of_2(topk) + BLOCK_SIZE_N = triton.next_power_of_2(4096 // BLOCK_SIZE_K) + BLOCK_SIZE_R = triton.next_power_of_2(seqblocks.max().item() + 2) + active_query_count = torch.zeros( + num_kv_heads, cu_seqblocks[-1], dtype=torch.int32, device=topk_idx.device + ) + grid = (num_kv_heads, batch_size) + count_kernel[grid]( + topk_idx, + active_query_count, + cu_seqlens, + cu_seqblocks, + topk, + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + active_query_count.stride(0), + active_query_count.stride(1), + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_R=BLOCK_SIZE_R, + num_warps=4, + num_stages=3, + ) + return active_query_count + + +@triton.jit +def pad_topk_idx_kernel( + t_ptr, + p_ptr, + cu_seqlens, + topk, + stride_th, + stride_tn, + stride_tk, + stride_pb, + stride_ph, + stride_pn, + stride_pk, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_T: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_n = tl.program_id(2) + # get q start and len after rmpad + q_start = tl.load(cu_seqlens + pid_b) + q_len = tl.load(cu_seqlens + pid_b + 1) - q_start + if BLOCK_SIZE_N * pid_n >= q_len: + return + # init prts + t_ptrs = tl.make_block_ptr( + base=t_ptr + pid_h * stride_th + q_start * stride_tn, + shape=(q_len, topk), + strides=(stride_tn, stride_tk), + offsets=(pid_n * BLOCK_SIZE_N, 0), + block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_T), + order=(1, 0), + ) + p_ptrs = tl.make_block_ptr( + base=p_ptr + pid_b * stride_pb + pid_h * stride_ph, + shape=(q_len, topk), + strides=(stride_pn, stride_pk), + offsets=(pid_n * BLOCK_SIZE_N, 0), + block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_T), + order=(1, 0), + ) + # load and save + idxs = tl.load(t_ptrs, boundary_check=(0, 1)) + tl.store(p_ptrs, idxs, boundary_check=(0, 1)) + + +@triton.jit +def save_topk_idx_kernel( + p_ptr, + t_ptr, + cu_seqblocks, + cu_topk_q_count, + n_len, + stride_pb, + stride_ph, + stride_pn, + stride_th, + stride_tn, + stride_ch, + stride_cn, + BLOCK_SIZE_N: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_n = tl.program_id(2) + # get q start and len after rmpad + q_block_start = tl.load(cu_seqblocks + pid_b) + q_block_end = tl.load(cu_seqblocks + pid_b + 1) + c_start = tl.load(cu_topk_q_count + pid_h * stride_ch + q_block_start * stride_cn) + c_end = tl.load(cu_topk_q_count + pid_h * stride_ch + q_block_end * stride_cn) + c_len = c_end - c_start + if c_len <= 0: + return + if pid_n * BLOCK_SIZE_N >= c_len: + return + # init ptrs + p_ptrs = tl.make_block_ptr( + base=p_ptr + + pid_b * stride_pb + + pid_h * stride_ph + + (n_len - c_len) * stride_pn, + shape=(c_len,), + strides=(stride_pn,), + offsets=(pid_n * BLOCK_SIZE_N,), 
+ block_shape=(BLOCK_SIZE_N,), + order=(0,), + ) + t_ptrs = tl.make_block_ptr( + base=t_ptr + pid_h * stride_th + c_start * stride_tn, + shape=(c_len,), + strides=(stride_tn,), + offsets=(pid_n * BLOCK_SIZE_N,), + block_shape=(BLOCK_SIZE_N,), + order=(0,), + ) + # load and save + idxs = tl.load(p_ptrs, boundary_check=(0,)) + tl.store(t_ptrs, idxs, boundary_check=(0,)) + + +def reorder_topk_idx( + topk_idx: torch.Tensor, + cu_topk_q_count: torch.Tensor, + cu_seqlens: torch.Tensor, + cu_seqblocks: torch.Tensor, + block_size: int, +): + num_kv_heads, total_len, topk = topk_idx.shape + batch_size = cu_seqlens.shape[0] - 1 + seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] + max_seqlen = seq_lens.max().item() + # pad shape [num_kv_heads, total_seqlen, topk] to [batch_size, num_kv_heads, max_seqlen, topk] + pad_topk_idx = torch.full( + (batch_size, num_kv_heads, max_seqlen, topk), + fill_value=-1, + device=topk_idx.device, + dtype=torch.int32, + ) + BLOCK_SIZE_T = triton.next_power_of_2(topk) + BLOCK_SIZE_N = min( + triton.next_power_of_2(max_seqlen), triton.next_power_of_2(8192 // BLOCK_SIZE_T) + ) + grid = (batch_size, num_kv_heads, triton.cdiv(max_seqlen, BLOCK_SIZE_N)) + pad_topk_idx_kernel[grid]( + topk_idx, + pad_topk_idx, + cu_seqlens, + topk, + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + pad_topk_idx.stride(0), + pad_topk_idx.stride(1), + pad_topk_idx.stride(2), + pad_topk_idx.stride(3), + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_T=BLOCK_SIZE_T, + ) + # argsort + pad_topk_q_idx = pad_topk_idx.view(batch_size, num_kv_heads, -1).argsort(-1) // topk + pad_topk_q_idx = pad_topk_q_idx.to(torch.int32) + # save as remove pad version + topk_q_idx = torch.full( + (num_kv_heads, cu_topk_q_count[:, -1].max().item()), + fill_value=-1, + device=topk_idx.device, + dtype=torch.int32, + ) + max_len = ( + ( + cu_topk_q_count[:, cu_seqblocks][:, 1:] + - cu_topk_q_count[:, cu_seqblocks][:, :-1] + ) + .max() + .item() + ) + + BLOCK_SIZE_N = min(triton.next_power_of_2(max_len), 8192) + grid = (batch_size, num_kv_heads, triton.cdiv(max_len, BLOCK_SIZE_N)) + save_topk_idx_kernel[grid]( + pad_topk_q_idx, + topk_q_idx, + cu_seqblocks, + cu_topk_q_count, + pad_topk_q_idx.shape[-1], + pad_topk_q_idx.stride(0), + pad_topk_q_idx.stride(1), + pad_topk_q_idx.stride(2), + topk_q_idx.stride(0), + topk_q_idx.stride(1), + cu_topk_q_count.stride(0), + cu_topk_q_count.stride(1), + BLOCK_SIZE_N=BLOCK_SIZE_N, + ) + return topk_q_idx + + +@triton.jit +def backward_dkdv( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + tq_ptr, # topk_q_idx: kh x N + lse_ptr, # LSE: qh x n + d_ptr, # Delta: qh x n + do_ptr, + dk_ptr, # DK: sh x n x kh x d + dv_ptr, # DK: sh x n x kh x d + # seqlens + cu_seqlens_q, # [batch_size + 1] + cu_seqlens_k, # [batch_size + 1] + cu_seqblocks, # [batch_size + 1] + cu_topk_q_count, # [kh, total_blocks] + cu_block_include_tokens, # [total_blocks + 1] + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + TOPK, + max_seqblocks, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_tqh, + stride_tqn, + stride_ctqh, + stride_ctqn, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dks, + stride_dkn, + stride_dkh, + stride_dkd, + stride_dvs, + stride_dvn, + stride_dvh, + stride_dvd, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: 
tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_kh = pid_h // NUM_SHARE_Q_HEADS + pid_sh = pid_h % NUM_SHARE_Q_HEADS + pid_kb = tl.program_id(2) + pid_k = pid_kb % max_seqblocks + pid_block = pid_kb // max_seqblocks + + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + # get topk_q_idx + b_start = tl.load(cu_seqblocks + pid_b) # how many blocks before current sequence + b_len = tl.load(cu_seqblocks + pid_b + 1) - b_start + + if pid_k >= b_len: + return + + cur_token_start = tl.load(cu_block_include_tokens + b_start + pid_k).to(tl.int32) + cur_block_size = tl.load(cu_block_include_tokens + b_start + pid_k + 1).to(tl.int32) - cur_token_start + cur_token_start_in_seq = cur_token_start - k_start + + if pid_block * BLOCK_SIZE_K >= cur_block_size: + return + + act_q_start = tl.load( + cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k) * stride_ctqn + ) + act_q_end = tl.load( + cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k + 1) * stride_ctqn + ) + act_q_len = act_q_end - act_q_start + tq_ptr = tq_ptr + pid_kh * stride_tqh + act_q_start * stride_tqn + # offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + off_k = tl.arange(0, BLOCK_SIZE_K) #+ pid_k * BLOCK_SIZE_K + off_d = tl.arange(0, BLOCK_SIZE_D) + # init pointers + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(cur_token_start_in_seq + pid_block * BLOCK_SIZE_K, 0), #(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dk_ptrs = dk_ptr + (cur_token_start + pid_block * BLOCK_SIZE_K + off_k[:, None]) * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks + off_d[None, :] * stride_dkd + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(cur_token_start_in_seq + pid_block * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dv_ptrs = dv_ptr + (cur_token_start + pid_block * BLOCK_SIZE_K + off_k[:, None]) * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs + off_d[None, :] * stride_dvd + # load k v and keep in SRAM + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dk dv + dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + # init ptrs + q_ptrs = ( + q_ptr + q_start * stride_qn + pid_h * stride_qh + off_d[None, :] * stride_qd + ) + do_ptrs = ( + do_ptr + q_start * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod + ) + d_ptrs = d_ptr + q_start * stride_dn + pid_h * stride_dh + lse_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + # loop for q blocks + for i in range(0, act_q_len, BLOCK_SIZE_Q): + # load + idx_q = tl.load(tq_ptr + i + off_q, mask=off_q < act_q_len - i, other=0).to( + tl.int32 + ) + q = tl.load( + q_ptrs + idx_q[:, None] * stride_qn, + mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :], + other=0, + ) + do = tl.load( + do_ptrs + idx_q[:, None] * stride_don, + mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :], + other=0, + ) + lse = tl.load( 
+ lse_ptrs + idx_q[:, None] * stride_ln, + mask=(off_q < act_q_len - i)[:, None], + other=0, + ) + d = tl.load( + d_ptrs + idx_q[:, None] * stride_dn, + mask=(off_q < act_q_len - i)[:, None], + other=0, + ) + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where(((pid_block * BLOCK_SIZE_K + off_k < cur_block_size)[None, :] & (off_q < act_q_len - i)[:, None]), float(0.0), float("-inf")) + qk += tl.dot(q, k.T) * qk_scale + # compute p, ds + p = tl.exp2(qk - lse) + dp = tl.dot(do, v.T) + ds = sm_scale * p * (dp - d) + # cast dtype + p = p.to(do.dtype) + ds = ds.to(q.dtype) + # update dk and dv + dk += tl.dot(ds.T, q) + dv += tl.dot(p.T, do) + # save dk dv + tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), mask=(pid_block * BLOCK_SIZE_K + off_k < cur_block_size)[:, None]) + tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), mask=(pid_block * BLOCK_SIZE_K + off_k < cur_block_size)[:, None]) + + +@triton.jit +def backward_dq( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + t_ptr, # topk_idx: kh x n x k + lse_ptr, # LSE: qh x n + d_ptr, # Delta: qh x n + do_ptr, + dq_ptr, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + cu_seqblocks, + cu_block_include_tokens, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + TOPK, + # q loop num + num_q_loop, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_th, + stride_tn, + stride_tk, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dqn, + stride_dqh, + stride_dqd, + # META parameters + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_T: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_kh = tl.program_id(1) + pid_q = tl.program_id(2) + pid_h = pid_kh * NUM_SHARE_Q_HEADS + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + block_start = tl.load(cu_seqblocks + pid_b) + block_len = tl.load(cu_seqblocks + pid_b + 1) - block_start + if pid_q * num_q_loop >= q_len: + return + num_q_loop_ = min(num_q_loop, q_len - pid_q * num_q_loop) + for j in range(num_q_loop_): + pid_q_j = pid_q * num_q_loop + j + # init topk idx pointer + off_t = tl.arange(0, BLOCK_SIZE_T) + t_ptr_j = t_ptr + (q_start + pid_q_j) * stride_tn + pid_kh * stride_th + topk_idx = tl.load(t_ptr_j + off_t * stride_tk, mask=off_t < TOPK, other=-1) + real_topk = tl.sum( + tl.where((topk_idx >= 0) & (topk_idx < block_len), 1, 0), + axis=0, + ) + # init pointers + q_ptrs = tl.make_block_ptr( + base=q_ptr + (q_start + pid_q_j) * stride_qn + pid_h * stride_qh, + shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), + strides=(stride_qh, stride_qd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(1, 0), + ) + dq_ptrs = tl.make_block_ptr( + base=dq_ptr + (q_start + pid_q_j) * stride_dqn + pid_h * stride_dqh, + shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), + strides=(stride_dqh, stride_dqd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(1, 0), + ) + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, 
BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + do_ptrs = tl.make_block_ptr( + base=do_ptr + (q_start + pid_q_j) * stride_don + pid_h * stride_doh, + shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), + strides=(stride_doh, stride_dod), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(1, 0), + ) + d_ptrs = tl.make_block_ptr( + base=d_ptr + (q_start + pid_q_j) * stride_dn + pid_h * stride_dh, + shape=(NUM_SHARE_Q_HEADS, 1), + strides=(stride_dh, stride_dn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, 1), + order=(1, 0), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + (q_start + pid_q_j) * stride_ln + pid_h * stride_lh, + shape=(NUM_SHARE_Q_HEADS, 1), + strides=(stride_lh, stride_ln), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, 1), + order=(1, 0), + ) + # offsets + off_k = tl.arange(0, BLOCK_SIZE_K) + # load q, do, lse, delta, and keep in SRAM + q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero") + do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero") + lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") + d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dq + dq = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_D), dtype=tl.float32) + # sparse + for i in range(real_topk): + # get current block start index + cur_block_idx = tl.load(t_ptr_j).to(tl.int32) + cur_token_start = tl.load(cu_block_include_tokens + block_start + cur_block_idx).to(tl.int32) + cur_block_size = tl.load(cu_block_include_tokens + block_start + cur_block_idx + 1).to(tl.int32) - cur_token_start + c = cur_token_start - k_start + t_ptr_j = t_ptr_j + stride_tk + + for b_j in range(0, cur_block_size, BLOCK_SIZE_K): + # load kv + k = tl.load( + tl.advance(k_ptrs, (c + b_j, 0)), boundary_check=(1, 0), padding_option="zero" + ) + v = tl.load( + tl.advance(v_ptrs, (c + b_j, 0)),boundary_check=(0, 1),padding_option="zero" + ) + # compute qk + qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where((off_k + b_j < cur_block_size)[None, :], 0, float("-inf")) + # [BLOCK_SIZE_H, HEAD_DIM] @ [BLOCK_SIZE_K, HEAD_DIM].T -> [BLOCK_SIZE_H, BLOCK_SIZE_K] + qk += tl.dot(q, tl.trans(k)) * qk_scale + # compute p, ds + p = tl.exp2(qk - lse) + dp = tl.dot(do, tl.trans(v)) + ds = sm_scale * p * (dp - d) + # cast dtype + ds = ds.to(q.dtype) + # update dq + dq += tl.dot(ds, k) + + # save dq + tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +def _spatial_selection_attention_fwd( + q: torch.Tensor, # [total_len, num_q_heads, head_dim] + k: torch.Tensor, # [total_len, num_k_heads, head_dim] + v: torch.Tensor, # [total_len, num_k_heads, head_dim] + topk_idx: torch.Tensor, # [num_kv_heads, total_len, topk] + cu_seqblocks: torch.Tensor, + cu_block_include_tokens: torch.Tensor, + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, +): + # dtype check + assert k.dtype == q.dtype and v.dtype == q.dtype + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + # shape + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + batch_size = cu_seqlens_q.shape[0] - 1 + assert q_len == k_len and k_len == v_len + topk = topk_idx.shape[-1] + assert 
topk_idx.shape[0] == num_k_heads + assert topk_idx.shape[1] == q_len + # gqa + assert num_k_heads == num_v_heads + assert num_q_heads % num_k_heads == 0 + num_share_q_heads = num_q_heads // num_k_heads + # output tensor + o = torch.zeros_like(q) + lse = torch.zeros(num_q_heads, q_len, dtype=torch.float32, device=q.device) + # launch kernel + num_q_loop = ( + cu_seqlens_q[-1].item() // 32768 + 1 + ) # calculate multiple querys in one kernel if seqlence length is too long + grid = (batch_size, num_k_heads, triton.cdiv(max_seqlen_q, num_q_loop)) + num_warps = 4 if head_dim <= 64 else 8 + num_stages = 3 + BLOCK_SIZE_K = 64 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + BLOCK_SIZE_H = max(16, triton.next_power_of_2(num_share_q_heads)) + BLOCK_SIZE_T = triton.next_power_of_2(topk) + forward_kernel[grid]( + q, + k, + v, + topk_idx, + o, + lse, + cu_seqlens_q, + cu_seqlens_k, + cu_seqblocks, + cu_block_include_tokens, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + num_q_loop, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + lse.stride(0), + lse.stride(1), + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + BLOCK_SIZE_H=BLOCK_SIZE_H, + BLOCK_SIZE_T=BLOCK_SIZE_T, + num_warps=num_warps, + num_stages=num_stages, + ) + return o, lse + + +def _spatial_selection_attention_bwd( + o: torch.Tensor, + do: torch.Tensor, + lse: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + topk_idx: torch.Tensor, + cu_seqblocks: torch.Tensor, + cu_block_include_tokens: torch.Tensor, + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, +): + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + o_len, num_o_heads, head_dim = o.shape + num_share_q_heads = num_q_heads // num_k_heads + topk = topk_idx.shape[-1] + # compute D + delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32) + BLOCK_SIZE_O = 256 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps = 4 if head_dim <= 64 else 8 + num_stages = 3 + grid = (triton.cdiv(o_len, BLOCK_SIZE_O), num_o_heads) + backward_sum_o_do[grid]( + o, + do, + delta, + o_len, + head_dim, + o.stride(0), + o.stride(1), + o.stride(2), + do.stride(0), + do.stride(1), + do.stride(2), + delta.stride(0), + delta.stride(1), + BLOCK_SIZE_O=BLOCK_SIZE_O, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + # count active querys for each key block, shape: (num_k_heads, total_k_blocks) + topk_q_count = count_query(topk_idx, cu_seqlens_q, cu_seqblocks, block_size) # [num_kv_head, total_block] + cu_topk_q_count = torch.cat( + [ + torch.zeros( + topk_q_count.shape[0], 1, dtype=torch.int32, device=topk_idx.device + ), + torch.cumsum(topk_q_count, dim=-1), + ], + dim=-1, + ).to(torch.int32) # [num_kv_head, cu_total_block + 1] + # active query idx for each key block + # how to get active query idx for sequence b, head h, kv block i? 
+ # topk_q_idx[h, cu_topk_q_count[h, cu_seqblocks[b] + i]:cu_topk_q_count[h, cu_seqblocks[b] + i + 1]] + topk_q_idx = reorder_topk_idx( + topk_idx, cu_topk_q_count, cu_seqlens_q, cu_seqblocks, block_size + ) + # compute dk dv + dk = torch.zeros( + num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype + ) + dv = torch.zeros( + num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype + ) + batch_size = cu_seqlens_q.shape[0] - 1 + num_warps = 4 if head_dim <= 64 else 8 + num_stages = 3 + max_include_block = (cu_block_include_tokens[..., 1:] - cu_block_include_tokens[..., :-1]).max().item() + BLOCK_SIZE_K = 64 + BLOCK_SIZE_Q = 128 if BLOCK_SIZE_K <= 64 else 64 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + max_seqblocks = (cu_seqblocks[1:] - cu_seqblocks[:-1]).max().item() + grid = (batch_size, num_q_heads, max_seqblocks * triton.cdiv(max_include_block, BLOCK_SIZE_K)) + + backward_dkdv[grid]( + q, + k, + v, + topk_q_idx, + lse, + delta, + do, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + cu_seqblocks, + cu_topk_q_count, + cu_block_include_tokens, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + max_seqblocks, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + topk_q_idx.stride(0), + topk_q_idx.stride(1), + cu_topk_q_count.stride(0), + cu_topk_q_count.stride(1), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dk.stride(0), + dk.stride(1), + dk.stride(2), + dk.stride(3), + dv.stride(0), + dv.stride(1), + dv.stride(2), + dv.stride(3), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + dk = dk.sum(0) + dv = dv.sum(0) + # compute dq + dq = torch.zeros_like(q) + num_q_loop = ( + cu_seqlens_q[-1].item() // 32768 + 1 + ) # calculate multiple querys in one kernel if seqlence length is too long + grid = (batch_size, num_k_heads, triton.cdiv(max_seqlen_q, num_q_loop)) + num_warps = 4 if head_dim <= 64 else 8 + num_stages = 3 + BLOCK_SIZE_K = 64 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + BLOCK_SIZE_H = max(16, triton.next_power_of_2(num_share_q_heads)) + BLOCK_SIZE_T = triton.next_power_of_2(topk) + backward_dq[grid]( + q, + k, + v, + topk_idx, + lse, + delta, + do, + dq, + cu_seqlens_q, + cu_seqlens_k, + cu_seqblocks, + cu_block_include_tokens, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + num_q_loop, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dq.stride(0), + dq.stride(1), + dq.stride(2), + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + BLOCK_SIZE_H=BLOCK_SIZE_H, + BLOCK_SIZE_T=BLOCK_SIZE_T, + num_warps=num_warps, + num_stages=num_stages, + ) + return dq, dk, dv + + +class SpatialSelectionAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, # [total_len, num_q_heads, head_dim] + k: torch.Tensor, # [total_len, num_k_heads, head_dim] + v: torch.Tensor, # [total_len, num_k_heads, head_dim] + topk_idx: torch.Tensor, # [num_kv_heads, total_len, topk] + cu_seqblocks: torch.Tensor, # [batch_size + 1] + cu_block_include_tokens: 
torch.Tensor, # [total_block_len] + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + sm_scale=None, + ): + # dtype check + assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 + assert q.dtype == k.dtype and k.dtype == v.dtype + assert topk_idx.dtype == torch.int32 + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + # softmax scale + if sm_scale is None: + sm_scale = 1 / math.sqrt(q.shape[-1]) + o, lse = _spatial_selection_attention_fwd( + q, + k, + v, + topk_idx, + cu_seqblocks, + cu_block_include_tokens, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx, cu_seqblocks, cu_block_include_tokens) + ctx.sm_scale = sm_scale + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.block_size = block_size + # return + return o + + @staticmethod + def backward(ctx, do: torch.Tensor, *args) -> Any: + q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx, cu_seqblocks, cu_block_include_tokens = ctx.saved_tensors + max_seqlen_q = ctx.max_seqlen_q + max_seqlen_k = ctx.max_seqlen_k + sm_scale = ctx.sm_scale + block_size = ctx.block_size + + dq, dk, dv = _spatial_selection_attention_bwd( + o, + do, + lse, + q, + k, + v, + topk_idx, + cu_seqblocks, + cu_block_include_tokens, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + return dq, dk, dv, None, None, None, None, None, None, None, None, None + + +def spatial_selection_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_topk: torch.Tensor, + cu_seqblocks: torch.Tensor, + cu_block_include_tokens: torch.Tensor, + block_size: int, + cu_seqlens: torch.Tensor, + softmax_scale: Optional[float] = None, +) -> torch.Tensor: + """Spatial selection attention implemented in triton. + + Args: + q (torch.Tensor): shape [total_len, num_q_heads, head_dim] + k (torch.Tensor): shape [total_len, num_kv_heads, head_dim] + v (torch.Tensor): shape [total_len, num_kv_heads, head_dim] + block_topk (torch.Tensor): topk block idx for each query, shape [num_kv_heads, total_len, topk]. -1 means padding. + block_size (int): key value block size. + cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens in flash_attn_func_varlen. + cu_block_include_tokens (torch.Tensor) shape [total_block_len]: number of tokens within each block + softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim). 
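        cu_seqblocks (torch.Tensor): shape [batch_size + 1], cumulative count of
            non-empty selection blocks per sequence (as returned by get_block_score).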
+ + Returns: + torch.Tensor: attention output, shape [total_len, num_q_heads, head_dim] + """ + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + return SpatialSelectionAttention.apply( + q, + k, + v, + block_topk, + cu_seqblocks, + cu_block_include_tokens, + block_size, + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + softmax_scale, + ) diff --git a/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/window_attention.py b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/window_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..6c8f8b420d6f3d533d1f00bcacebf37805089826 --- /dev/null +++ b/direct3d_s2/modules/sparse/attention/spatial_sparse_attention/ops/window_attention.py @@ -0,0 +1,59 @@ +from typing import * +import torch +import flash_attn +from direct3d_s2.modules.sparse import SparseTensor +from direct3d_s2.modules.sparse.attention.windowed_attn import calc_window_partition + + +def sparse_window_attention( + q: SparseTensor, + k: SparseTensor, + v: SparseTensor, + window_size: int, + shift_window: Tuple[int, int, int] = (0, 0, 0) +) -> SparseTensor: + """ + Apply windowed scaled dot product self attention to a sparse tensor. + + Args: + q (SparseTensor): [N, *, H_q, C] sparse tensor containing query. + k (SparseTensor): [N, *, H_kv, C] sparse tensor containing key. + v (SparseTensor): [N, *, H_kv, C] sparse tensor containing value. + window_size (int): The window size to use. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + shift (int): The shift to use. + """ + + serialization_spatial_cache_name = f'window_partition_{window_size}_{shift_window}' + serialization_spatial_cache = q.get_spatial_cache(serialization_spatial_cache_name) + if serialization_spatial_cache is None: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition(q, window_size, shift_window) + q.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices)) + else: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache + + M = fwd_indices.shape[0] + T = q.feats.shape[0] + H = q.feats.shape[1] + H_kv = k.feats.shape[1] + C = q.feats.shape[2] + q_feats = q.feats[fwd_indices] # [M, H, C] + k_feats = k.feats[fwd_indices] + v_feats = v.feats[fwd_indices] + + if all([seq_len == window_size for seq_len in seq_lens]): + B = len(seq_lens) + N = window_size + q_feats = q_feats.reshape(B, N, H, C) + k_feats = k_feats.reshape(B, N, H_kv, C) + v_feats = v_feats.reshape(B, N, H_kv, C) + out = flash_attn.flash_attn_func(q_feats, k_feats, v_feats) + out = out.reshape(B * N, H, C) # [M, H, C] + else: + cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \ + .to(q.device).int() + out = flash_attn.flash_attn_varlen_func(q_feats, k_feats, v_feats, cu_seqlens, cu_seqlens, max(seq_lens), max(seq_lens)) + + out = out[bwd_indices] # [T, H, C] + + return q.replace(out) \ No newline at end of file diff --git a/direct3d_s2/modules/sparse/attention/windowed_attn.py b/direct3d_s2/modules/sparse/attention/windowed_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..2c450398c634ba314c0f2100bf4949207a40847b --- /dev/null +++ b/direct3d_s2/modules/sparse/attention/windowed_attn.py @@ -0,0 +1,133 @@ +from typing import * +import torch +import math +from .. import SparseTensor +from .. 
import DEBUG, ATTN + +if ATTN == 'xformers': + import xformers.ops as xops +elif ATTN == 'flash_attn': + import flash_attn +else: + raise ValueError(f"Unknown attention module: {ATTN}") + + +__all__ = [ + 'sparse_windowed_scaled_dot_product_self_attention', +] + + +def calc_window_partition( + tensor: SparseTensor, + window_size: Union[int, Tuple[int, ...]], + shift_window: Union[int, Tuple[int, ...]] = 0 +) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]: + """ + Calculate serialization and partitioning for a set of coordinates. + + Args: + tensor (SparseTensor): The input tensor. + window_size (int): The window size to use. + shift_window (Tuple[int, ...]): The shift of serialized coordinates. + + Returns: + (torch.Tensor): Forwards indices. + (torch.Tensor): Backwards indices. + (List[int]): Sequence lengths. + (List[int]): Sequence batch indices. + """ + DIM = tensor.coords.shape[1] - 1 + shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window + window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size + shifted_coords = tensor.coords.clone().detach() + shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0) + + MAX_COORDS = shifted_coords[:, 1:].max(dim=0).values.tolist() + NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)] + OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1] + + shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0) + shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1) + fwd_indices = torch.argsort(shifted_indices) + bwd_indices = torch.empty_like(fwd_indices) + bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device) + seq_lens = torch.bincount(shifted_indices) + seq_batch_indices = torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32) // OFFSET[0] + mask = seq_lens != 0 + seq_lens = seq_lens[mask].tolist() + seq_batch_indices = seq_batch_indices[mask].tolist() + + return fwd_indices, bwd_indices, seq_lens, seq_batch_indices + + +def sparse_windowed_scaled_dot_product_self_attention( + qkv: SparseTensor, + window_size: int, + shift_window: Tuple[int, int, int] = (0, 0, 0) +) -> SparseTensor: + """ + Apply windowed scaled dot product self attention to a sparse tensor. + + Args: + qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + window_size (int): The window size to use. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + shift (int): The shift to use. 
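The serialization step in `calc_window_partition` above is easiest to follow on a toy example. Below is a minimal standalone sketch (plain PyTorch, made-up coordinates, not the library call itself) of the core idea: bucket each voxel into a window, build a linear window key that also folds in the batch index, and argsort so every window becomes a contiguous run.

```python
import torch

def window_partition_sketch(coords: torch.Tensor, window_size: int):
    # coords: [M, 4] int tensor of (batch, x, y, z) voxel coordinates (assumed layout).
    win = coords.clone()
    win[:, 1:] //= window_size                           # window index along each axis
    num_windows = (win[:, 1:].max(dim=0).values + 1).tolist()
    key = win[:, 0]                                      # start from the batch index
    for axis, n in enumerate(num_windows):
        key = key * n + win[:, axis + 1]                 # linearize (batch, wx, wy, wz)
    fwd = torch.argsort(key)                             # gather order: windows contiguous
    bwd = torch.empty_like(fwd)
    bwd[fwd] = torch.arange(fwd.numel())                 # inverse permutation
    counts = torch.bincount(key)
    seq_lens = counts[counts > 0].tolist()               # tokens per non-empty window
    return fwd, bwd, seq_lens

coords = torch.tensor([[0, 0, 0, 0], [0, 9, 0, 0], [0, 1, 1, 1], [1, 0, 0, 0]])
print(window_partition_sketch(coords, window_size=8))    # voxels 0 and 2 share one window
```

After gathering `feats[fwd]`, each window's tokens form one contiguous segment, which is what lets the equal-size path reshape to `[B, N, ...]` and the ragged path fall back to varlen attention.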
+ """ + assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + + serialization_spatial_cache_name = f'window_partition_{window_size}_{shift_window}' + serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) + if serialization_spatial_cache is None: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition(qkv, window_size, shift_window) + qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices)) + else: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache + + M = fwd_indices.shape[0] + T = qkv.feats.shape[0] + H = qkv.feats.shape[2] + C = qkv.feats.shape[3] + qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] + if DEBUG: + start = 0 + qkv_coords = qkv.coords[fwd_indices] + for i in range(len(seq_lens)): + seq_coords = qkv_coords[start:start+seq_lens[i]] + assert (seq_coords[:, 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch" + assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \ + f"SparseWindowedScaledDotProductSelfAttention: window size exceeded" + start += seq_lens[i] + + if all([seq_len == window_size for seq_len in seq_lens]): + B = len(seq_lens) + N = window_size + qkv_feats = qkv_feats.reshape(B, N, 3, H, C) + if ATTN == 'xformers': + q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C] + out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C] + elif ATTN == 'flash_attn': + out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C] + else: + raise ValueError(f"Unknown attention module: {ATTN}") + out = out.reshape(B * N, H, C) # [M, H, C] + else: + if ATTN == 'xformers': + q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] + mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) + out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C] + elif ATTN == 'flash_attn': + cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \ + .to(qkv.device).int() + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C] + + out = out[bwd_indices] # [T, H, C] + + if DEBUG: + qkv_coords = qkv_coords[bwd_indices] + assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" + + return qkv.replace(out) diff --git a/direct3d_s2/modules/sparse/basic.py b/direct3d_s2/modules/sparse/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..8837f44052f6d573d09e3bfb897e659e10516bb5 --- /dev/null +++ b/direct3d_s2/modules/sparse/basic.py @@ -0,0 +1,459 @@ +from typing import * +import torch +import torch.nn as nn +from . import BACKEND, DEBUG +SparseTensorData = None # Lazy import + + +__all__ = [ + 'SparseTensor', + 'sparse_batch_broadcast', + 'sparse_batch_op', + 'sparse_cat', + 'sparse_unbind', +] + + +class SparseTensor: + """ + Sparse tensor with support for both torchsparse and spconv backends. + + Parameters: + - feats (torch.Tensor): Features of the sparse tensor. + - coords (torch.Tensor): Coordinates of the sparse tensor. + - shape (torch.Size): Shape of the sparse tensor. 
+ - layout (List[slice]): Layout of the sparse tensor for each batch + - data (SparseTensorData): Sparse tensor data used for convolusion + + NOTE: + - Data corresponding to a same batch should be contiguous. + - Coords should be in [0, 1023] + """ + @overload + def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ... + + @overload + def __init__(self, data, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ... + + def __init__(self, *args, **kwargs): + # Lazy import of sparse tensor backend + global SparseTensorData + if SparseTensorData is None: + import importlib + if BACKEND == 'torchsparse': + SparseTensorData = importlib.import_module('torchsparse').SparseTensor + elif BACKEND == 'spconv': + SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor + + method_id = 0 + if len(args) != 0: + method_id = 0 if isinstance(args[0], torch.Tensor) else 1 + else: + method_id = 1 if 'data' in kwargs else 0 + + if method_id == 0: + feats, coords, shape, layout = args + (None,) * (4 - len(args)) + if 'feats' in kwargs: + feats = kwargs['feats'] + del kwargs['feats'] + if 'coords' in kwargs: + coords = kwargs['coords'] + del kwargs['coords'] + if 'shape' in kwargs: + shape = kwargs['shape'] + del kwargs['shape'] + if 'layout' in kwargs: + layout = kwargs['layout'] + del kwargs['layout'] + + if shape is None: + shape = self.__cal_shape(feats, coords) + if layout is None: + layout = self.__cal_layout(coords, shape[0]) + if BACKEND == 'torchsparse': + self.data = SparseTensorData(feats, coords, **kwargs) + elif BACKEND == 'spconv': + spatial_shape = list(coords.max(0)[0] + 1)[1:] + self.data = SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape, shape[0], **kwargs) + self.data._features = feats + elif method_id == 1: + data, shape, layout = args + (None,) * (3 - len(args)) + if 'data' in kwargs: + data = kwargs['data'] + del kwargs['data'] + if 'shape' in kwargs: + shape = kwargs['shape'] + del kwargs['shape'] + if 'layout' in kwargs: + layout = kwargs['layout'] + del kwargs['layout'] + + self.data = data + if shape is None: + shape = self.__cal_shape(self.feats, self.coords) + if layout is None: + layout = self.__cal_layout(self.coords, shape[0]) + + self._shape = shape + self._layout = layout + self._scale = kwargs.get('scale', (1, 1, 1)) + self._spatial_cache = kwargs.get('spatial_cache', {}) + + if DEBUG: + try: + assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}" + assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}" + assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}" + for i in range(self.shape[0]): + assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous" + except Exception as e: + print('Debugging information:') + print(f"- Shape: {self.shape}") + print(f"- Layout: {self.layout}") + print(f"- Scale: {self._scale}") + print(f"- Coords: {self.coords}") + raise e + + def __cal_shape(self, feats, coords): + shape = [] + shape.append(coords[:, 0].max().item() + 1) + shape.extend([*feats.shape[1:]]) + return torch.Size(shape) + + def __cal_layout(self, coords, batch_size): + seq_len = torch.bincount(coords[:, 0], minlength=batch_size) + offset = torch.cumsum(seq_len, dim=0) + layout = [slice((offset[i] - 
seq_len[i]).item(), offset[i].item()) for i in range(batch_size)] + return layout + + @property + def shape(self) -> torch.Size: + return self._shape + + def dim(self) -> int: + return len(self.shape) + + @property + def layout(self) -> List[slice]: + return self._layout + + @property + def feats(self) -> torch.Tensor: + if BACKEND == 'torchsparse': + return self.data.F + elif BACKEND == 'spconv': + return self.data.features + + @feats.setter + def feats(self, value: torch.Tensor): + if BACKEND == 'torchsparse': + self.data.F = value + elif BACKEND == 'spconv': + self.data.features = value + + @property + def coords(self) -> torch.Tensor: + if BACKEND == 'torchsparse': + return self.data.C + elif BACKEND == 'spconv': + return self.data.indices + + @coords.setter + def coords(self, value: torch.Tensor): + if BACKEND == 'torchsparse': + self.data.C = value + elif BACKEND == 'spconv': + self.data.indices = value + + @property + def dtype(self): + return self.feats.dtype + + @property + def device(self): + return self.feats.device + + @overload + def to(self, dtype: torch.dtype) -> 'SparseTensor': ... + + @overload + def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None) -> 'SparseTensor': ... + + def to(self, *args, **kwargs) -> 'SparseTensor': + device = None + dtype = None + if len(args) == 2: + device, dtype = args + elif len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype = args[0] + else: + device = args[0] + if 'dtype' in kwargs: + assert dtype is None, "to() received multiple values for argument 'dtype'" + dtype = kwargs['dtype'] + if 'device' in kwargs: + assert device is None, "to() received multiple values for argument 'device'" + device = kwargs['device'] + + new_feats = self.feats.to(device=device, dtype=dtype) + new_coords = self.coords.to(device=device) + return self.replace(new_feats, new_coords) + + def type(self, dtype): + new_feats = self.feats.type(dtype) + return self.replace(new_feats) + + def cpu(self) -> 'SparseTensor': + new_feats = self.feats.cpu() + new_coords = self.coords.cpu() + return self.replace(new_feats, new_coords) + + def cuda(self) -> 'SparseTensor': + new_feats = self.feats.cuda() + new_coords = self.coords.cuda() + return self.replace(new_feats, new_coords) + + def half(self) -> 'SparseTensor': + new_feats = self.feats.half() + return self.replace(new_feats) + + def float(self) -> 'SparseTensor': + new_feats = self.feats.float() + return self.replace(new_feats) + + def detach(self) -> 'SparseTensor': + new_coords = self.coords.detach() + new_feats = self.feats.detach() + return self.replace(new_feats, new_coords) + + def dense(self) -> torch.Tensor: + if BACKEND == 'torchsparse': + return self.data.dense() + elif BACKEND == 'spconv': + return self.data.dense() + + def reshape(self, *shape) -> 'SparseTensor': + new_feats = self.feats.reshape(self.feats.shape[0], *shape) + return self.replace(new_feats) + + def unbind(self, dim: int) -> List['SparseTensor']: + return sparse_unbind(self, dim) + + def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor': + new_shape = [self.shape[0]] + new_shape.extend(feats.shape[1:]) + if BACKEND == 'torchsparse': + new_data = SparseTensorData( + feats=feats, + coords=self.data.coords if coords is None else coords, + stride=self.data.stride, + spatial_range=self.data.spatial_range, + ) + new_data._caches = self.data._caches + elif BACKEND == 'spconv': + new_data = SparseTensorData( + 
self.data.features.reshape(self.data.features.shape[0], -1), + self.data.indices, + self.data.spatial_shape, + self.data.batch_size, + self.data.grid, + self.data.voxel_num, + self.data.indice_dict + ) + new_data._features = feats + new_data.benchmark = self.data.benchmark + new_data.benchmark_record = self.data.benchmark_record + new_data.thrust_allocator = self.data.thrust_allocator + new_data._timer = self.data._timer + new_data.force_algo = self.data.force_algo + new_data.int8_scale = self.data.int8_scale + if coords is not None: + new_data.indices = coords + new_tensor = SparseTensor(new_data, shape=torch.Size(new_shape), layout=self.layout, scale=self._scale, spatial_cache=self._spatial_cache) + return new_tensor + + @staticmethod + def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor': + N, C = dim + x = torch.arange(aabb[0], aabb[3] + 1) + y = torch.arange(aabb[1], aabb[4] + 1) + z = torch.arange(aabb[2], aabb[5] + 1) + coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3) + coords = torch.cat([ + torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1), + coords.repeat(N, 1), + ], dim=1).to(dtype=torch.int32, device=device) + feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device) + return SparseTensor(feats=feats, coords=coords) + + def __merge_sparse_cache(self, other: 'SparseTensor') -> dict: + new_cache = {} + for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())): + if k in self._spatial_cache: + new_cache[k] = self._spatial_cache[k] + if k in other._spatial_cache: + if k not in new_cache: + new_cache[k] = other._spatial_cache[k] + else: + new_cache[k].update(other._spatial_cache[k]) + return new_cache + + def __neg__(self) -> 'SparseTensor': + return self.replace(-self.feats) + + def __elemwise__(self, other: Union[torch.Tensor, 'SparseTensor'], op: callable) -> 'SparseTensor': + if isinstance(other, torch.Tensor): + try: + other = torch.broadcast_to(other, self.shape) + other = sparse_batch_broadcast(self, other) + except: + pass + if isinstance(other, SparseTensor): + other = other.feats + new_feats = op(self.feats, other) + new_tensor = self.replace(new_feats) + if isinstance(other, SparseTensor): + new_tensor._spatial_cache = self.__merge_sparse_cache(other) + return new_tensor + + def __add__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.add) + + def __radd__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.add) + + def __sub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.sub) + + def __rsub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, lambda x, y: torch.sub(y, x)) + + def __mul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.mul) + + def __rmul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.mul) + + def __truediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.div) + + def __rtruediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, lambda x, y: torch.div(y, x)) + + def __getitem__(self, idx): + if isinstance(idx, int): 
+ idx = [idx] + elif isinstance(idx, slice): + idx = range(*idx.indices(self.shape[0])) + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: + assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}" + idx = idx.nonzero().squeeze(1) + elif idx.dtype in [torch.int32, torch.int64]: + assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}" + else: + raise ValueError(f"Unknown index type: {idx.dtype}") + else: + raise ValueError(f"Unknown index type: {type(idx)}") + + coords = [] + feats = [] + for new_idx, old_idx in enumerate(idx): + coords.append(self.coords[self.layout[old_idx]].clone()) + coords[-1][:, 0] = new_idx + feats.append(self.feats[self.layout[old_idx]]) + coords = torch.cat(coords, dim=0).contiguous() + feats = torch.cat(feats, dim=0).contiguous() + return SparseTensor(feats=feats, coords=coords) + + def register_spatial_cache(self, key, value) -> None: + """ + Register a spatial cache. + The spatial cache can be any thing you want to cache. + The registery and retrieval of the cache is based on current scale. + """ + scale_key = str(self._scale) + if scale_key not in self._spatial_cache: + self._spatial_cache[scale_key] = {} + self._spatial_cache[scale_key][key] = value + + def get_spatial_cache(self, key=None): + """ + Get a spatial cache. + """ + scale_key = str(self._scale) + cur_scale_cache = self._spatial_cache.get(scale_key, {}) + if key is None: + return cur_scale_cache + return cur_scale_cache.get(key, None) + + +def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor: + """ + Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation. + + Args: + input (torch.Tensor): 1D tensor to broadcast. + target (SparseTensor): Sparse tensor to broadcast to. + op (callable): Operation to perform after broadcasting. Defaults to torch.add. + """ + coords, feats = input.coords, input.feats + broadcasted = torch.zeros_like(feats) + for k in range(input.shape[0]): + broadcasted[input.layout[k]] = other[k] + return broadcasted + + +def sparse_batch_op(input: SparseTensor, other: torch.Tensor, op: callable = torch.add) -> SparseTensor: + """ + Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation. + + Args: + input (torch.Tensor): 1D tensor to broadcast. + target (SparseTensor): Sparse tensor to broadcast to. + op (callable): Operation to perform after broadcasting. Defaults to torch.add. + """ + return input.replace(op(input.feats, sparse_batch_broadcast(input, other))) + + +def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor: + """ + Concatenate a list of sparse tensors. + + Args: + inputs (List[SparseTensor]): List of sparse tensors to concatenate. + """ + if dim == 0: + start = 0 + coords = [] + for input in inputs: + coords.append(input.coords.clone()) + coords[-1][:, 0] += start + start += input.shape[0] + coords = torch.cat(coords, dim=0) + feats = torch.cat([input.feats for input in inputs], dim=0) + output = SparseTensor( + coords=coords, + feats=feats, + ) + else: + feats = torch.cat([input.feats for input in inputs], dim=dim) + output = inputs[0].replace(feats) + + return output + + +def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]: + """ + Unbind a sparse tensor along a dimension. + + Args: + input (SparseTensor): Sparse tensor to unbind. + dim (int): Dimension to unbind. 
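`sparse_batch_broadcast` and `sparse_batch_op` rely on the per-batch `layout` slices of the packed feature tensor. A toy illustration (made-up layout and values) of what the broadcast produces:

```python
import torch

layout = [slice(0, 3), slice(3, 5)]                   # batch 0 owns rows 0:3, batch 1 rows 3:5
feats  = torch.ones(5, 2)                             # packed [T, C] features
other  = torch.tensor([[10.0, 10.0],
                       [100.0, 100.0]])               # one [C] row per batch element
broadcasted = torch.zeros_like(feats)
for b, sl in enumerate(layout):
    broadcasted[sl] = other[b]                        # repeat batch b's row over its slice
print(feats + broadcasted)                            # what sparse_batch_op(..., torch.add) yields
# rows 0-2 -> 11., rows 3-4 -> 101.
```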
+ """ + if dim == 0: + return [input[i] for i in range(input.shape[0])] + else: + feats = input.feats.unbind(dim) + return [input.replace(f) for f in feats] diff --git a/direct3d_s2/modules/sparse/conv/__init__.py b/direct3d_s2/modules/sparse/conv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..340a87126a8de574ee0276feb96b49824a2ce234 --- /dev/null +++ b/direct3d_s2/modules/sparse/conv/__init__.py @@ -0,0 +1,21 @@ +from .. import BACKEND + + +SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native' + +def __from_env(): + import os + + global SPCONV_ALGO + env_spconv_algo = os.environ.get('SPCONV_ALGO') + if env_spconv_algo is not None and env_spconv_algo in ['auto', 'implicit_gemm', 'native']: + SPCONV_ALGO = env_spconv_algo + print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}") + + +__from_env() + +if BACKEND == 'torchsparse': + from .conv_torchsparse import * +elif BACKEND == 'spconv': + from .conv_spconv import * diff --git a/direct3d_s2/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc b/direct3d_s2/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70d75bb2732eb260f2c78ebe5d62e70335c2cfbc Binary files /dev/null and b/direct3d_s2/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc differ diff --git a/direct3d_s2/modules/sparse/conv/__pycache__/conv_spconv.cpython-310.pyc b/direct3d_s2/modules/sparse/conv/__pycache__/conv_spconv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbeac5a85eecb6501557b46c516dfba80fb6d4ae Binary files /dev/null and b/direct3d_s2/modules/sparse/conv/__pycache__/conv_spconv.cpython-310.pyc differ diff --git a/direct3d_s2/modules/sparse/conv/__pycache__/conv_torchsparse.cpython-310.pyc b/direct3d_s2/modules/sparse/conv/__pycache__/conv_torchsparse.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae86bfa9fa6dfe94672df2191848c09d88d17c5c Binary files /dev/null and b/direct3d_s2/modules/sparse/conv/__pycache__/conv_torchsparse.cpython-310.pyc differ diff --git a/direct3d_s2/modules/sparse/conv/conv_spconv.py b/direct3d_s2/modules/sparse/conv/conv_spconv.py new file mode 100644 index 0000000000000000000000000000000000000000..ff058302f033f8f340c9f75efc869006cbf5b993 --- /dev/null +++ b/direct3d_s2/modules/sparse/conv/conv_spconv.py @@ -0,0 +1,86 @@ +import torch +import torch.nn as nn +from .. import SparseTensor +from .. import DEBUG +from . 
import SPCONV_ALGO + +class SparseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): + super(SparseConv3d, self).__init__() + if 'spconv' not in globals(): + import spconv.pytorch as spconv + algo = None + if SPCONV_ALGO == 'native': + algo = spconv.ConvAlgo.Native + elif SPCONV_ALGO == 'implicit_gemm': + algo = spconv.ConvAlgo.MaskImplicitGemm + if stride == 1 and (padding is None): + self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo) + else: + self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo) + self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) + self.padding = padding + + def forward(self, x: SparseTensor) -> SparseTensor: + + spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None) + new_data = self.conv(x.data) + new_shape = [x.shape[0], self.conv.out_channels] + new_layout = None if spatial_changed else x.layout + + if spatial_changed and (x.shape[0] != 1): + # spconv was non-1 stride will break the contiguous of the output tensor, sort by the coords + fwd = new_data.indices[:, 0].argsort() + bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device)) + sorted_feats = new_data.features #[fwd] + sorted_coords = new_data.indices #[fwd] + unsorted_data = new_data + + indice_dict = new_data.indice_dict + + if 'spconv' not in globals(): + import spconv.pytorch as spconv + new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size, indice_dict=indice_dict) # type: ignore + + out = SparseTensor( + new_data, shape=torch.Size(new_shape), layout=new_layout, + scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]), + spatial_cache=x._spatial_cache, + ) + + if spatial_changed and (x.shape[0] != 1): + out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data) + out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd) + + return out + + +class SparseInverseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + super(SparseInverseConv3d, self).__init__() + if 'spconv' not in globals(): + import spconv.pytorch as spconv + self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key) + self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) + + def forward(self, x: SparseTensor) -> SparseTensor: + spatial_changed = any(s != 1 for s in self.stride) + if spatial_changed: + # recover the original spconv order + data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data') + bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd') + data = data.replace_feature(x.feats[bwd]) + if DEBUG: + assert torch.equal(data.indices, x.coords[bwd]), 'Recover the original order failed' + else: + data = x.data + + new_data = self.conv(data) + new_shape = [x.shape[0], self.conv.out_channels] + new_layout = None if spatial_changed else x.layout + out = SparseTensor( + new_data, shape=torch.Size(new_shape), layout=new_layout, + scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]), + spatial_cache=x._spatial_cache, + ) + return out diff --git 
a/direct3d_s2/modules/sparse/conv/conv_torchsparse.py b/direct3d_s2/modules/sparse/conv/conv_torchsparse.py new file mode 100644 index 0000000000000000000000000000000000000000..fdc6561a2734d5c71794e24575741209b279de89 --- /dev/null +++ b/direct3d_s2/modules/sparse/conv/conv_torchsparse.py @@ -0,0 +1,68 @@ +import torch +import torch.nn as nn +from .. import SparseTensor +from torchsparse.utils import make_ntuple + + +def sparseconv3d_func(input: SparseTensor, weight: torch.Tensor, kernel_size: int, stride: int = 1, dilation: int = 1, padding: int = 0, bias: torch.Tensor = None, training: bool = True): + if 'torchsparse' not in globals(): + import torchsparse + stride = make_ntuple(stride, ndim=3) + kernel_size = make_ntuple(kernel_size, ndim=3) + _padding = make_ntuple(padding, 3) + padding = () + for i in range(3): + if kernel_size[i] % 2 == 1 and stride[i] == 1: + padding += ((kernel_size[i] - 1) // 2,) + else: + padding += (_padding[i],) + out = torchsparse.nn.functional.conv3d(input.data, weight, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, training=training) + spatial_range = out.spatial_range + new_shape = [input.shape[0], weight.shape[1]] + out = SparseTensor(out, shape=torch.Size(new_shape), layout=input.layout if all(s == 1 for s in stride) else None) + out._spatial_cache = input._spatial_cache + out._scale = tuple([s * stride for s, stride in zip(input._scale, stride)]) + out.data.spatial_range = spatial_range + return out + +class SparseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, bias=True, indice_key=None): + super(SparseConv3d, self).__init__() + if 'torchsparse' not in globals(): + import torchsparse + self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias) + + def forward(self, x: SparseTensor) -> SparseTensor: + out = self.conv(x.data) + + spatial_range = out.spatial_range + + new_shape = [x.shape[0], self.conv.out_channels] + out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) + out._spatial_cache = x._spatial_cache + out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)]) + + out.data.spatial_range = spatial_range + + return out + + +class SparseInverseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + super(SparseInverseConv3d, self).__init__() + if 'torchsparse' not in globals(): + import torchsparse + self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True) + + def forward(self, x: SparseTensor) -> SparseTensor: + out = self.conv(x.data) + + new_shape = [x.shape[0], self.conv.out_channels] + out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) + out._spatial_cache = x._spatial_cache + out._scale = tuple([s // stride for s, stride in zip(x._scale, self.conv.stride)]) + + return out + + + diff --git a/direct3d_s2/modules/sparse/linear.py b/direct3d_s2/modules/sparse/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..a854e77ce87d1a190b9730d91f363a821ff250bd --- /dev/null +++ b/direct3d_s2/modules/sparse/linear.py @@ -0,0 +1,15 @@ +import torch +import torch.nn as nn +from . 
import SparseTensor + +__all__ = [ + 'SparseLinear' +] + + +class SparseLinear(nn.Linear): + def __init__(self, in_features, out_features, bias=True): + super(SparseLinear, self).__init__(in_features, out_features, bias) + + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) diff --git a/direct3d_s2/modules/sparse/nonlinearity.py b/direct3d_s2/modules/sparse/nonlinearity.py new file mode 100644 index 0000000000000000000000000000000000000000..2e6bfd855271d238fe34ab8bec2744bf9db58b94 --- /dev/null +++ b/direct3d_s2/modules/sparse/nonlinearity.py @@ -0,0 +1,42 @@ +import torch +import torch.nn as nn +from . import SparseTensor + +__all__ = [ + 'SparseReLU', + 'SparseSiLU', + 'SparseGELU', + 'SparseActivation' + 'SparseTanh', + 'SparseSigmoid', +] + +class SparseSigmoid(nn.Sigmoid): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + +class SparseReLU(nn.ReLU): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + + +class SparseSiLU(nn.SiLU): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + +class SparseGELU(nn.GELU): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + +class SparseTanh(nn.Tanh): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + +class SparseActivation(nn.Module): + def __init__(self, activation: nn.Module): + super().__init__() + self.activation = activation + + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(self.activation(input.feats)) + diff --git a/direct3d_s2/modules/sparse/norm.py b/direct3d_s2/modules/sparse/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..6b38a36682c098210000dc31d68ddc31ccd2929d --- /dev/null +++ b/direct3d_s2/modules/sparse/norm.py @@ -0,0 +1,58 @@ +import torch +import torch.nn as nn +from . import SparseTensor +from . 
import DEBUG + +__all__ = [ + 'SparseGroupNorm', + 'SparseLayerNorm', + 'SparseGroupNorm32', + 'SparseLayerNorm32', +] + + +class SparseGroupNorm(nn.GroupNorm): + def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): + super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine) + + def forward(self, input: SparseTensor) -> SparseTensor: + nfeats = torch.zeros_like(input.feats) + for k in range(input.shape[0]): + if DEBUG: + assert (input.coords[input.layout[k], 0] == k).all(), f"SparseGroupNorm: batch index mismatch" + bfeats = input.feats[input.layout[k]] + bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) + bfeats = super().forward(bfeats) + bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) + nfeats[input.layout[k]] = bfeats + return input.replace(nfeats) + + +class SparseLayerNorm(nn.LayerNorm): + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine) + + def forward(self, input: SparseTensor) -> SparseTensor: + nfeats = torch.zeros_like(input.feats) + for k in range(input.shape[0]): + bfeats = input.feats[input.layout[k]] + bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) + bfeats = super().forward(bfeats) + bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) + nfeats[input.layout[k]] = bfeats + return input.replace(nfeats) + + +class SparseGroupNorm32(SparseGroupNorm): + """ + A GroupNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: SparseTensor) -> SparseTensor: + return super().forward(x.float()).type(x.dtype) + +class SparseLayerNorm32(SparseLayerNorm): + """ + A LayerNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: SparseTensor) -> SparseTensor: + return super().forward(x.float()).type(x.dtype) diff --git a/direct3d_s2/modules/sparse/spatial.py b/direct3d_s2/modules/sparse/spatial.py new file mode 100644 index 0000000000000000000000000000000000000000..557438b50d9aaa99dee5d36c5f8a8a042ac70017 --- /dev/null +++ b/direct3d_s2/modules/sparse/spatial.py @@ -0,0 +1,115 @@ +from typing import * +import torch +import torch.nn as nn +from . import SparseTensor + +__all__ = [ + 'SparseDownsample', + 'SparseUpsample', + 'SparseSubdivide' +] + + +class SparseDownsample(nn.Module): + """ + Downsample a sparse tensor by a factor of `factor`. + Implemented as average pooling. + """ + def __init__(self, factor: Union[int, Tuple[int, ...], List[int]], mode="mean"): + super(SparseDownsample, self).__init__() + self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor + self.mode = mode + + def forward(self, input: SparseTensor) -> SparseTensor: + DIM = input.coords.shape[-1] - 1 + factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM + assert DIM == len(factor), 'Input coordinates must have the same dimension as the downsample factor.' 
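`SparseGroupNorm` and `SparseLayerNorm` normalize each batch element's voxels separately by temporarily viewing that element's packed `[N_k, C]` slice as a dense `[1, C, N_k]` tensor. The reshape at the heart of that loop, shown standalone with toy sizes:

```python
import torch
import torch.nn as nn

gn = nn.GroupNorm(num_groups=4, num_channels=32)
feats_k = torch.randn(100, 32)                        # all voxels of ONE batch element, [N_k, C]
x = feats_k.permute(1, 0).reshape(1, 32, -1)          # -> [1, C, N_k]: voxels act as spatial dims
y = gn(x).reshape(32, -1).permute(1, 0)               # back to [N_k, C]
print(y.shape)                                        # torch.Size([100, 32])
```

Normalizing per sample matters because voxel counts differ across the batch; one call over the whole packed tensor would mix statistics between samples. Separately, the `__all__` list in nonlinearity.py is missing a comma after `'SparseActivation'`, so adjacent-string concatenation currently exports the single name `'SparseActivationSparseTanh'`.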
+ + coord = list(input.coords.unbind(dim=-1)) + for i, f in enumerate(factor): + coord[i+1] = coord[i+1] // f + + MAX = [coord[i+1].max().item() + 1 for i in range(DIM)] + OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1] + code = sum([c * o for c, o in zip(coord, OFFSET)]) + code, idx = code.unique(return_inverse=True) + + #### using fp16 could cause overflow when factor is large ###### + dtype = input.feats.dtype + new_feats = torch.scatter_reduce( + torch.zeros(code.shape[0], input.feats.shape[1], device=input.feats.device, dtype=torch.float64), + dim=0, + index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]), + src=input.feats.double(), + reduce=self.mode, + ) + new_feats = new_feats.to(dtype) + + new_coords = torch.stack( + [code // OFFSET[0]] + + [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)], + dim=-1 + ) + out = SparseTensor(new_feats, new_coords, input.shape,) + out._scale = tuple([s // f for s, f in zip(input._scale, factor)]) + out._spatial_cache = input._spatial_cache + + out.register_spatial_cache(f'upsample_{factor}_coords', input.coords) + out.register_spatial_cache(f'upsample_{factor}_layout', input.layout) + out.register_spatial_cache(f'upsample_{factor}_idx', idx) + + return out + + +class SparseUpsample(nn.Module): + """ + Upsample a sparse tensor by a factor of `factor`. + Implemented as nearest neighbor interpolation. + """ + def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]): + super(SparseUpsample, self).__init__() + self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor + + def forward(self, input: SparseTensor) -> SparseTensor: + DIM = input.coords.shape[-1] - 1 + factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM + assert DIM == len(factor), 'Input coordinates must have the same dimension as the upsample factor.' + + new_coords = input.get_spatial_cache(f'upsample_{factor}_coords') + new_layout = input.get_spatial_cache(f'upsample_{factor}_layout') + idx = input.get_spatial_cache(f'upsample_{factor}_idx') + if any([x is None for x in [new_coords, new_layout, idx]]): + raise ValueError('Upsample cache not found. SparseUpsample must be paired with SparseDownsample.') + new_feats = input.feats[idx] + out = SparseTensor(new_feats, new_coords, input.shape, new_layout) + out._scale = tuple([s * f for s, f in zip(input._scale, factor)]) + out._spatial_cache = input._spatial_cache + return out + +class SparseSubdivide(nn.Module): + """ + Upsample a sparse tensor by a factor of `factor`. + Implemented as nearest neighbor interpolation. 
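The pooling step of `SparseDownsample` encodes each voxel's target cell as a linear key, takes `unique(return_inverse=True)`, and scatter-reduces features into the surviving cells. A self-contained toy version (hypothetical coordinates and a simplified key, not the module itself):

```python
import torch

coords = torch.tensor([[0, 0, 0, 0], [0, 1, 0, 0], [0, 2, 2, 2]])   # (batch, x, y, z)
feats  = torch.tensor([[1.0], [3.0], [10.0]])
cell = coords.clone()
cell[:, 1:] //= 2                                     # 2x2x2 downsample: cell index per voxel
key = ((cell[:, 0] * 10 + cell[:, 1]) * 10 + cell[:, 2]) * 10 + cell[:, 3]  # toy linear key
uniq, idx = key.unique(return_inverse=True)
pooled = torch.zeros(uniq.numel(), feats.shape[1]).scatter_reduce(
    0, idx.unsqueeze(1).expand(-1, feats.shape[1]), feats,
    reduce="mean", include_self=False,                # true mean over the voxels in each cell
)
print(pooled)                                         # tensor([[ 2.], [10.]])
```

Note that the call in the diff uses the default `include_self=True`, which also folds the zero-initialized output row into the mean; worth double-checking whether that is intended.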
+ """ + def __init__(self): + super(SparseSubdivide, self).__init__() + + def forward(self, input: SparseTensor) -> SparseTensor: + DIM = input.coords.shape[-1] - 1 + # upsample scale=2^DIM + n_cube = torch.ones([2] * DIM, device=input.device, dtype=torch.int) + n_coords = torch.nonzero(n_cube) + n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1) + factor = n_coords.shape[0] + assert factor == 2 ** DIM + # print(n_coords.shape) + new_coords = input.coords.clone() + new_coords[:, 1:] *= 2 + new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(new_coords.dtype) + + new_feats = input.feats.unsqueeze(1).expand(input.feats.shape[0], factor, *input.feats.shape[1:]) + out = SparseTensor(new_feats.flatten(0, 1), new_coords.flatten(0, 1), input.shape) + out._scale = input._scale * 2 + out._spatial_cache = input._spatial_cache + return out + diff --git a/direct3d_s2/modules/sparse/transformer/__init__.py b/direct3d_s2/modules/sparse/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b08b0d4e5bc24060a2cdc8df75d06dce122972bd --- /dev/null +++ b/direct3d_s2/modules/sparse/transformer/__init__.py @@ -0,0 +1,2 @@ +from .blocks import * +from .modulated import * \ No newline at end of file diff --git a/direct3d_s2/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc b/direct3d_s2/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..969a1cb28304a96ecc20b610829b2e20e79d992b Binary files /dev/null and b/direct3d_s2/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc differ diff --git a/direct3d_s2/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc b/direct3d_s2/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ee096ac0e59a2a4530cfc7dd8014fc678962cf8 Binary files /dev/null and b/direct3d_s2/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc differ diff --git a/direct3d_s2/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc b/direct3d_s2/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b14f225257c4670a232ee5f135b4598065182d63 Binary files /dev/null and b/direct3d_s2/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc differ diff --git a/direct3d_s2/modules/sparse/transformer/blocks.py b/direct3d_s2/modules/sparse/transformer/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..9d037a49bf83e1c2dfb2f8c4b23d2e9d6c51e9f0 --- /dev/null +++ b/direct3d_s2/modules/sparse/transformer/blocks.py @@ -0,0 +1,151 @@ +from typing import * +import torch +import torch.nn as nn +from ..basic import SparseTensor +from ..linear import SparseLinear +from ..nonlinearity import SparseGELU +from ..attention import SparseMultiHeadAttention, SerializeMode +from ...norm import LayerNorm32 + + +class SparseFeedForwardNet(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp = nn.Sequential( + SparseLinear(channels, int(channels * mlp_ratio)), + SparseGELU(approximate="tanh"), + SparseLinear(int(channels * mlp_ratio), channels), + ) + + def forward(self, x: SparseTensor) -> SparseTensor: + return self.mlp(x) + + +class SparseTransformerBlock(nn.Module): + """ + Sparse Transformer block (MSA + FFN). 
+ """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: SparseTensor) -> SparseTensor: + h = x.replace(self.norm1(x.feats)) + h = self.attn(h) + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: SparseTensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseTransformerCrossBlock(nn.Module): + """ + Sparse Transformer cross-attention block (MSA + MCA + FFN). + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.self_attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = SparseMultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor): + h = x.replace(self.norm1(x.feats)) + h = self.self_attn(h) + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.cross_attn(h, context) + x = x + h + h = x.replace(self.norm3(x.feats)) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: SparseTensor, context: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, context, 
use_reentrant=False) + else: + return self._forward(x, context) diff --git a/direct3d_s2/modules/sparse/transformer/modulated.py b/direct3d_s2/modules/sparse/transformer/modulated.py new file mode 100644 index 0000000000000000000000000000000000000000..617af7338d7c5414a0666e9cf91f63cd71cc3447 --- /dev/null +++ b/direct3d_s2/modules/sparse/transformer/modulated.py @@ -0,0 +1,213 @@ +from typing import * +import torch +import torch.nn as nn +from ..basic import SparseTensor +from ..attention import SparseMultiHeadAttention, SerializeMode, SpatialSparseAttention +from ...norm import LayerNorm32 +from .blocks import SparseFeedForwardNet + + +class ModulatedSparseTransformerBlock(nn.Module): + """ + Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = x.replace(self.norm1(x.feats)) + h = h * (1 + scale_msa) + shift_msa + h = self.attn(h) + h = h * gate_msa + x = x + h + h = x.replace(self.norm2(x.feats)) + h = h * (1 + scale_mlp) + shift_mlp + h = self.mlp(h) + h = h * gate_mlp + x = x + h + return x + + def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) + else: + return self._forward(x, mod) + + +class ModulatedSparseTransformerCrossBlock(nn.Module): + """ + Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. 
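The modulated sparse blocks apply DiT-style adaLN conditioning, but because features are packed as `[T, C]` with a per-batch layout, the shift/scale/gate rows have to be applied slice by slice. A sketch of that pattern with toy sizes (the attention and MLP calls elided):

```python
import torch
import torch.nn as nn

C = 8
layout = [slice(0, 5), slice(5, 12)]                  # rows owned by each batch element
feats  = torch.randn(12, C)                           # packed voxel features [T, C]
cond   = torch.randn(2, C)                            # one conditioning vector per sample
adaLN  = nn.Sequential(nn.SiLU(), nn.Linear(C, 6 * C))
shift_msa, scale_msa, gate_msa, *_ = adaLN(cond).chunk(6, dim=1)

h = nn.LayerNorm(C, elementwise_affine=False)(feats)
h = torch.cat([h[sl] * (1 + scale_msa[i:i+1]) + shift_msa[i:i+1]
               for i, sl in enumerate(layout)], dim=0)          # modulate per sample
# ... self-attention over h would run here ...
h = torch.cat([h[sl] * gate_msa[i:i+1] for i, sl in enumerate(layout)], dim=0)
feats = feats + h                                               # gated residual
```

When `share_mod` is set, the six chunks arrive precomputed from a shared modulation network instead of the per-block `adaLN_modulation`.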
+ """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + compression_version: str = "v2", + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + use_ssa: bool = True, + num_kv_heads: int = 2, + compression_block_size: int = 4, + selection_block_size: int = 8, + topk: int = 8, + resolution: int = 64, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + if use_ssa: + self.self_attn = SpatialSparseAttention( + channels, + num_q_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=channels//num_heads, + compression_block_size=compression_block_size, + compression_version=compression_version, + selection_block_size=selection_block_size, + topk=topk, + window_size=window_size, + shift_window=shift_window, + resolution=resolution, + ) + else: + self.self_attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = SparseMultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = x.replace(self.norm1(x.feats)) + + feats_h = h.feats + layouts = h.layout + ada_r1 = [] + for i in range(len(layouts)): + ada_r1.append(feats_h[layouts[i]] * (1 + scale_msa[i:i+1]) + shift_msa[i:i+1]) + h = h.replace(torch.cat(ada_r1, dim=0)) + h = self.self_attn(h) + + feats_h = h.feats + layouts = h.layout + ada_r2 = [] + for i in range(len(layouts)): + ada_r2.append(feats_h[layouts[i]] * gate_msa[i:i+1]) + h = h.replace(torch.cat(ada_r2, dim=0)) + + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.cross_attn(h, context) + x = x + h + h = x.replace(self.norm3(x.feats)) + + feats_h = h.feats + layouts = h.layout + ada_r3 = [] + for i in range(len(layouts)): + ada_r3.append(feats_h[layouts[i]] * (1 + scale_mlp[i:i+1]) + shift_mlp[i:i+1]) + h = h.replace(torch.cat(ada_r3, dim=0)) + h = self.mlp(h) + + feats_h = h.feats + layouts = h.layout + ada_r4 = [] + for i in range(len(layouts)): + ada_r4.append(feats_h[layouts[i]] * gate_mlp[i:i+1]) + h = h.replace(torch.cat(ada_r4, dim=0)) + + x = x + h + return x + + def 
forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False) + else: + return self._forward(x, mod, context) diff --git a/direct3d_s2/modules/spatial.py b/direct3d_s2/modules/spatial.py new file mode 100644 index 0000000000000000000000000000000000000000..79e268d36c2ba49b0275744022a1a1e19983dae3 --- /dev/null +++ b/direct3d_s2/modules/spatial.py @@ -0,0 +1,48 @@ +import torch + + +def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: + """ + 3D pixel shuffle. + """ + B, C, H, W, D = x.shape + C_ = C // scale_factor**3 + x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4) + x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor) + return x + + +def patchify(x: torch.Tensor, patch_size: int): + """ + Patchify a tensor. + + Args: + x (torch.Tensor): (N, C, *spatial) tensor + patch_size (int): Patch size + """ + DIM = x.dim() - 2 + for d in range(2, DIM + 2): + assert x.shape[d] % patch_size == 0, f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}" + + x = x.reshape(*x.shape[:2], *sum([[x.shape[d] // patch_size, patch_size] for d in range(2, DIM + 2)], [])) + x = x.permute(0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)])) + x = x.reshape(x.shape[0], x.shape[1] * (patch_size ** DIM), *(x.shape[-DIM:])) + return x + + +def unpatchify(x: torch.Tensor, patch_size: int): + """ + Unpatchify a tensor. + + Args: + x (torch.Tensor): (N, C, *spatial) tensor + patch_size (int): Patch size + """ + DIM = x.dim() - 2 + assert x.shape[1] % (patch_size ** DIM) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}" + + x = x.reshape(x.shape[0], x.shape[1] // (patch_size ** DIM), *([patch_size] * DIM), *(x.shape[-DIM:])) + x = x.permute(0, 1, *(sum([[2 + DIM + i, 2 + i] for i in range(DIM)], []))) + x = x.reshape(x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)]) + return x diff --git a/direct3d_s2/modules/transformer/__init__.py b/direct3d_s2/modules/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b08b0d4e5bc24060a2cdc8df75d06dce122972bd --- /dev/null +++ b/direct3d_s2/modules/transformer/__init__.py @@ -0,0 +1,2 @@ +from .blocks import * +from .modulated import * \ No newline at end of file diff --git a/direct3d_s2/modules/transformer/__pycache__/__init__.cpython-310.pyc b/direct3d_s2/modules/transformer/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8f2a4e028c3083fbd5787b393d7f1027d27b9f2 Binary files /dev/null and b/direct3d_s2/modules/transformer/__pycache__/__init__.cpython-310.pyc differ diff --git a/direct3d_s2/modules/transformer/__pycache__/blocks.cpython-310.pyc b/direct3d_s2/modules/transformer/__pycache__/blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38ad84b709f92193cecb88fba9300b54d6d2848c Binary files /dev/null and b/direct3d_s2/modules/transformer/__pycache__/blocks.cpython-310.pyc differ diff --git a/direct3d_s2/modules/transformer/__pycache__/modulated.cpython-310.pyc b/direct3d_s2/modules/transformer/__pycache__/modulated.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..38d12b1952b57c97234e214a5385a947f7ed0a3d Binary files /dev/null and b/direct3d_s2/modules/transformer/__pycache__/modulated.cpython-310.pyc differ diff --git a/direct3d_s2/modules/transformer/blocks.py b/direct3d_s2/modules/transformer/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..605ae33bc44276f45f73789f547dc756ac3999da --- /dev/null +++ b/direct3d_s2/modules/transformer/blocks.py @@ -0,0 +1,184 @@ +from typing import * +import torch +import torch.nn as nn +from ..attention import MultiHeadAttention +from ..norm import LayerNorm32 + + +class AbsolutePositionEmbedder(nn.Module): + """ + Embeds spatial positions into vector representations. + """ + def __init__(self, channels: int, in_channels: int = 3): + super().__init__() + self.channels = channels + self.in_channels = in_channels + self.freq_dim = channels // in_channels // 2 + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = 1.0 / (10000 ** self.freqs) + + def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor: + """ + Create sinusoidal position embeddings. + + Args: + x: a 1-D Tensor of N indices + + Returns: + an (N, D) Tensor of positional embeddings. + """ + self.freqs = self.freqs.to(x.device) + out = torch.outer(x, self.freqs) + out = torch.cat([torch.sin(out), torch.cos(out)], dim=-1) + return out + + def forward(self, x: torch.Tensor, factor: float = None) -> torch.Tensor: + """ + Args: + x (torch.Tensor): (N, D) tensor of spatial positions + """ + if factor is not None: + x = x * factor + N, D = x.shape + assert D == self.in_channels, "Input dimension must match number of input channels" + embed = self._sin_cos_embedding(x.reshape(-1)) + embed = embed.reshape(N, -1) + if embed.shape[1] < self.channels: + embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1], device=embed.device)], dim=-1) + return embed + + +class FeedForwardNet(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(channels, int(channels * mlp_ratio)), + nn.GELU(approximate="tanh"), + nn.Linear(int(channels * mlp_ratio), channels), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mlp(x) + + +class TransformerBlock(nn.Module): + """ + Transformer block (MSA + FFN). 
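`AbsolutePositionEmbedder` splits the channel budget across the 3 coordinate axes and uses the standard `1/10000^(k/K)` sinusoidal frequencies. The frequency and embedding construction, reduced to a standalone sketch with toy sizes:

```python
import torch

channels, in_channels = 48, 3
freq_dim = channels // in_channels // 2               # sin/cos pairs per axis -> 8
freqs = 1.0 / (10000 ** (torch.arange(freq_dim, dtype=torch.float32) / freq_dim))

pos = torch.tensor([[1.0, 2.0, 3.0],
                    [4.0, 5.0, 6.0]])                 # (N, 3) spatial positions
angles = torch.outer(pos.reshape(-1), freqs)          # (N * 3, freq_dim)
embed = torch.cat([angles.sin(), angles.cos()], dim=-1).reshape(pos.shape[0], -1)
print(embed.shape)                                    # torch.Size([2, 48]) == (N, channels)
```

If `channels` is not divisible by `2 * in_channels`, the module zero-pads the tail so the output always has exactly `channels` features.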
+ """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[int] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.attn = MultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm1(x) + h = self.attn(h) + x = x + h + h = self.norm2(x) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class TransformerCrossBlock(nn.Module): + """ + Transformer cross-attention block (MSA + MCA + FFN). + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.self_attn = MultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: torch.Tensor, context: torch.Tensor): + h = self.norm1(x) + h = self.self_attn(h) + x = x + h + h = self.norm2(x) + h = self.cross_attn(h, context) + x = x + h + h = self.norm3(x) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: torch.Tensor, context: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False) + else: + return self._forward(x, context) + \ No newline at end of file diff --git a/direct3d_s2/modules/transformer/modulated.py b/direct3d_s2/modules/transformer/modulated.py new file mode 100644 index 0000000000000000000000000000000000000000..d4aeca0689e68f656b08f7aa822b7be839aa727d --- /dev/null +++ b/direct3d_s2/modules/transformer/modulated.py @@ -0,0 +1,157 @@ +from typing import * +import torch +import torch.nn as nn +from ..attention import MultiHeadAttention +from ..norm import LayerNorm32 +from .blocks import FeedForwardNet + + +class 
ModulatedTransformerBlock(nn.Module): + """ + Transformer block (MSA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.attn = MultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = self.norm1(x) + h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + h = self.attn(h) + h = h * gate_msa.unsqueeze(1) + x = x + h + h = self.norm2(x) + h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + h = self.mlp(h) + h = h * gate_mlp.unsqueeze(1) + x = x + h + return x + + def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) + else: + return self._forward(x, mod) + + +class ModulatedTransformerCrossBlock(nn.Module): + """ + Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. 
+ """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.self_attn = MultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor): + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = self.norm1(x) + h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + h = self.self_attn(h) + h = h * gate_msa.unsqueeze(1) + x = x + h + h = self.norm2(x) + h = self.cross_attn(h, context) + x = x + h + h = self.norm3(x) + h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + h = self.mlp(h) + h = h * gate_mlp.unsqueeze(1) + x = x + h + return x + + def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False) + else: + return self._forward(x, mod, context) + \ No newline at end of file diff --git a/direct3d_s2/modules/utils.py b/direct3d_s2/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f0afb1b6c767aa2ad00bad96649fb30315e696ea --- /dev/null +++ b/direct3d_s2/modules/utils.py @@ -0,0 +1,54 @@ +import torch.nn as nn +from ..modules import sparse as sp + +FP16_MODULES = ( + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + nn.Linear, + sp.SparseConv3d, + sp.SparseInverseConv3d, + sp.SparseLinear, +) + +def convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, FP16_MODULES): + for p in l.parameters(): + p.data = p.data.half() + + +def convert_module_to_f32(l): + """ + Convert primitive modules to float32, undoing convert_module_to_f16(). + """ + if isinstance(l, FP16_MODULES): + for p in l.parameters(): + p.data = p.data.float() + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. 
+ """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) diff --git a/direct3d_s2/pipeline.py b/direct3d_s2/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..ce0e9e7c64329b4092654254235f38f011b05c29 --- /dev/null +++ b/direct3d_s2/pipeline.py @@ -0,0 +1,357 @@ +import os +import torch +import numpy as np +from typing import Any +from PIL import Image +from tqdm import tqdm +from omegaconf import OmegaConf +from huggingface_hub import hf_hub_download +from typing import Union, List, Optional +from direct3d_s2.modules import sparse as sp +from direct3d_s2.utils import ( + instantiate_from_config, + preprocess_image, + sort_block, + extract_tokens_and_coords, + normalize_mesh, + mesh2index, +) + + +class Direct3DS2Pipeline(object): + + def __init__(self, + dense_vae, + dense_dit, + sparse_vae_512, + sparse_dit_512, + sparse_vae_1024, + sparse_dit_1024, + refiner, + dense_image_encoder, + sparse_image_encoder, + dense_scheduler, + sparse_scheduler_512, + sparse_scheduler_1024, + dtype=torch.float16, + ): + self.dense_vae = dense_vae + self.dense_dit = dense_dit + self.sparse_vae_512 = sparse_vae_512 + self.sparse_dit_512 = sparse_dit_512 + self.sparse_vae_1024 = sparse_vae_1024 + self.sparse_dit_1024 = sparse_dit_1024 + self.refiner = refiner + self.dense_image_encoder = dense_image_encoder + self.sparse_image_encoder = sparse_image_encoder + self.dense_scheduler = dense_scheduler + self.sparse_scheduler_512 = sparse_scheduler_512 + self.sparse_scheduler_1024 = sparse_scheduler_1024 + self.dtype = dtype + + def to(self, device): + self.device = torch.device(device) + self.dense_vae.to(device) + self.dense_dit.to(device) + self.sparse_vae_512.to(device) + self.sparse_dit_512.to(device) + self.sparse_vae_1024.to(device) + self.sparse_dit_1024.to(device) + self.refiner.to(device) + self.dense_image_encoder.to(device) + self.sparse_image_encoder.to(device) + + @classmethod + def from_pretrained(cls, pipeline_path, subfolder="direct3d-s2-v-1-1"): + + if os.path.isdir(pipeline_path): + config_path = os.path.join(pipeline_path, 'config.yaml') + model_dense_path = os.path.join(pipeline_path, 'model_dense.ckpt') + model_sparse_512_path = os.path.join(pipeline_path, 'model_sparse_512.ckpt') + model_sparse_1024_path = os.path.join(pipeline_path, 'model_sparse_1024.ckpt') + model_refiner_path = os.path.join(pipeline_path, 'model_refiner.ckpt') + else: + config_path = hf_hub_download( + repo_id=pipeline_path, + subfolder=subfolder, + filename="config.yaml", + repo_type="model" + ) + model_dense_path = hf_hub_download( + repo_id=pipeline_path, + subfolder=subfolder, + filename="model_dense.ckpt", + repo_type="model" + ) + model_sparse_512_path = hf_hub_download( + repo_id=pipeline_path, + subfolder=subfolder, + filename="model_sparse_512.ckpt", + repo_type="model" + ) + model_sparse_1024_path = hf_hub_download( + repo_id=pipeline_path, + subfolder=subfolder, + filename="model_sparse_1024.ckpt", + repo_type="model" + ) + model_refiner_path = hf_hub_download( + repo_id=pipeline_path, + subfolder=subfolder, + filename="model_refiner.ckpt", + repo_type="model" + ) + + cfg = OmegaConf.load(config_path) + + state_dict_dense = torch.load(model_dense_path, map_location='cpu', weights_only=True) + dense_vae = instantiate_from_config(cfg.dense_vae) + dense_vae.load_state_dict(state_dict_dense["vae"], strict=True) + dense_vae.eval() + dense_dit = 
instantiate_from_config(cfg.dense_dit) + dense_dit.load_state_dict(state_dict_dense["dit"], strict=True) + dense_dit.eval() + + state_dict_sparse_512 = torch.load(model_sparse_512_path, map_location='cpu', weights_only=True) + sparse_vae_512 = instantiate_from_config(cfg.sparse_vae_512) + sparse_vae_512.load_state_dict(state_dict_sparse_512["vae"], strict=True) + sparse_vae_512.eval() + sparse_dit_512 = instantiate_from_config(cfg.sparse_dit_512) + sparse_dit_512.load_state_dict(state_dict_sparse_512["dit"], strict=True) + sparse_dit_512.eval() + + state_dict_sparse_1024 = torch.load(model_sparse_1024_path, map_location='cpu', weights_only=True) + sparse_vae_1024 = instantiate_from_config(cfg.sparse_vae_1024) + sparse_vae_1024.load_state_dict(state_dict_sparse_1024["vae"], strict=True) + sparse_vae_1024.eval() + sparse_dit_1024 = instantiate_from_config(cfg.sparse_dit_1024) + sparse_dit_1024.load_state_dict(state_dict_sparse_1024["dit"], strict=True) + sparse_dit_1024.eval() + + state_dict_refiner = torch.load(model_refiner_path, map_location='cpu', weights_only=True) + refiner = instantiate_from_config(cfg.refiner) + refiner.load_state_dict(state_dict_refiner["refiner"], strict=True) + refiner.eval() + + dense_image_encoder = instantiate_from_config(cfg.dense_image_encoder) + sparse_image_encoder = instantiate_from_config(cfg.sparse_image_encoder) + + dense_scheduler = instantiate_from_config(cfg.dense_scheduler) + sparse_scheduler_512 = instantiate_from_config(cfg.sparse_scheduler_512) + sparse_scheduler_1024 = instantiate_from_config(cfg.sparse_scheduler_1024) + + return cls( + dense_vae=dense_vae, + dense_dit=dense_dit, + sparse_vae_512=sparse_vae_512, + sparse_dit_512=sparse_dit_512, + sparse_vae_1024=sparse_vae_1024, + sparse_dit_1024=sparse_dit_1024, + dense_image_encoder=dense_image_encoder, + sparse_image_encoder=sparse_image_encoder, + dense_scheduler=dense_scheduler, + sparse_scheduler_512=sparse_scheduler_512, + sparse_scheduler_1024=sparse_scheduler_1024, + refiner=refiner, + ) + + def preprocess(self, image): + if image.mode == 'RGBA': + image = np.array(image) + else: + if getattr(self, 'birefnet_model', None) is None: + from direct3d_s2.utils import BiRefNet + self.birefnet_model = BiRefNet(self.device) + image = self.birefnet_model.run(image) + image = preprocess_image(image) + return image + + def prepare_image(self, image: Union[str, List[str], Image.Image, List[Image.Image]]): + if not isinstance(image, list): + image = [image] + if isinstance(image[0], str): + image = [Image.open(img) for img in image] + image = [self.preprocess(img) for img in image] + image = torch.stack([img for img in image]).to(self.device) + return image + + def encode_image(self, image: torch.Tensor, conditioner: Any, + do_classifier_free_guidance: bool = True, use_mask: bool = False): + if use_mask: + cond = conditioner(image[:, :3], image[:, 3:]) + else: + cond = conditioner(image[:, :3]) + + if isinstance(cond, tuple): + cond, cond_mask = cond + cond, cond_coords = extract_tokens_and_coords(cond, cond_mask) + else: + cond_mask, cond_coords = None, None + + if do_classifier_free_guidance: + uncond = torch.zeros_like(cond) + else: + uncond = None + + if cond_coords is not None: + cond = sp.SparseTensor(cond, cond_coords.int()) + if uncond is not None: + uncond = sp.SparseTensor(uncond, cond_coords.int()) + + return cond, uncond + + def inference( + self, + image, + vae, + dit, + conditioner, + scheduler, + num_inference_steps: int = 30, + guidance_scale: int = 7.0, + generator: 
Optional[torch.Generator] = None, + latent_index: torch.Tensor = None, + mode: str = 'dense', # 'dense', 'sparse512' or 'sparse1024 + remove_interior: bool = False, + mc_threshold: float = 0.02): + + do_classifier_free_guidance = guidance_scale > 0 + if mode == 'dense': + sparse_conditions = False + else: + sparse_conditions = dit.sparse_conditions + cond, uncond = self.encode_image(image, conditioner, + do_classifier_free_guidance, sparse_conditions) + batch_size = cond.shape[0] + + if mode == 'dense': + latent_shape = (batch_size, *dit.latent_shape) + else: + latent_shape = (len(latent_index), dit.out_channels) + latents = torch.randn(latent_shape, dtype=self.dtype, device=self.device, generator=generator) + + scheduler.set_timesteps(num_inference_steps, device=self.device) + timesteps = scheduler.timesteps + + extra_step_kwargs = { + "generator": generator + } + + for i, t in enumerate(tqdm(timesteps, desc=f"{mode} Sampling:")): + latent_model_input = latents + timestep_tensor = torch.tensor([t], dtype=latent_model_input.dtype, device=self.device) + + if mode == 'dense': + x_input = latent_model_input + elif mode in ['sparse512', 'sparse1024']: + x_input = sp.SparseTensor(latent_model_input, latent_index.int()) + + diffusion_inputs = { + "x": x_input, + "t": timestep_tensor, + "cond": cond, + } + + noise_pred_cond = dit(**diffusion_inputs) + if mode != 'dense': + noise_pred_cond = noise_pred_cond.feats + + if do_classifier_free_guidance: + diffusion_inputs["cond"] = uncond + noise_pred_uncond = dit(**diffusion_inputs) + if mode != 'dense': + noise_pred_uncond = noise_pred_uncond.feats + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: + noise_pred = noise_pred_cond + + latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + latents = 1. 
/ vae.latents_scale * latents + vae.latents_shift + + if mode != 'dense': + latents = sp.SparseTensor(latents, latent_index.int()) + + decoder_inputs = { + "latents": latents, + "mc_threshold": mc_threshold, + } + if mode == 'dense': + decoder_inputs['return_index'] = True + elif remove_interior: + decoder_inputs['return_feat'] = True + if mode == 'sparse1024': + decoder_inputs['voxel_resolution'] = 1024 + + outputs = vae.decode_mesh(**decoder_inputs) + + if remove_interior: + del latents, noise_pred, noise_pred_cond, noise_pred_uncond, x_input, cond, uncond + torch.cuda.empty_cache() + outputs = self.refiner.run(*outputs, mc_threshold=mc_threshold*2.0) + + return outputs + + @torch.no_grad() + def __call__( + self, + image: Union[str, List[str], Image.Image, List[Image.Image]] = None, + sdf_resolution: int = 1024, + dense_sampler_params: dict = {'num_inference_steps': 50, 'guidance_scale': 7.0}, + sparse_512_sampler_params: dict = {'num_inference_steps': 30, 'guidance_scale': 7.0}, + sparse_1024_sampler_params: dict = {'num_inference_steps': 15, 'guidance_scale': 7.0}, + generator: Optional[torch.Generator] = None, + remesh: bool = False, + simplify_ratio: float = 0.95, + mc_threshold: float = 0.2): + + image = self.prepare_image(image) + + latent_index = self.inference(image, self.dense_vae, self.dense_dit, self.dense_image_encoder, + self.dense_scheduler, generator=generator, mode='dense', mc_threshold=0.1, **dense_sampler_params)[0] + + latent_index = sort_block(latent_index, self.sparse_dit_512.selection_block_size) + + torch.cuda.empty_cache() + + if sdf_resolution == 512: + remove_interior = False + else: + remove_interior = True + + mesh = self.inference(image, self.sparse_vae_512, self.sparse_dit_512, + self.sparse_image_encoder, self.sparse_scheduler_512, + generator=generator, mode='sparse512', + mc_threshold=mc_threshold, latent_index=latent_index, + remove_interior=remove_interior, **sparse_512_sampler_params)[0] + + if sdf_resolution == 1024: + del latent_index + torch.cuda.empty_cache() + mesh = normalize_mesh(mesh) + latent_index = mesh2index(mesh, size=1024, factor=8) + latent_index = sort_block(latent_index, self.sparse_dit_1024.selection_block_size) + print(f"number of latent tokens: {len(latent_index)}") + + mesh = self.inference(image, self.sparse_vae_1024, self.sparse_dit_1024, + self.sparse_image_encoder, self.sparse_scheduler_1024, + generator=generator, mode='sparse1024', + mc_threshold=mc_threshold, latent_index=latent_index, + **sparse_1024_sampler_params)[0] + + if remesh: + import trimesh + from direct3d_s2.utils import postprocess_mesh + filled_mesh = postprocess_mesh( + vertices=mesh.vertices, + faces=mesh.faces, + simplify=True, + simplify_ratio=simplify_ratio, + verbose=True, + ) + mesh = trimesh.Trimesh(filled_mesh[0], filled_mesh[1]) + + outputs = {"mesh": mesh} + + return outputs + \ No newline at end of file diff --git a/direct3d_s2/utils/__init__.py b/direct3d_s2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e46f0ec24c6f4da95eb3a9b8df5a28ae1377c036 --- /dev/null +++ b/direct3d_s2/utils/__init__.py @@ -0,0 +1,6 @@ +from .util import instantiate_from_config, get_obj_from_str +from .image import preprocess_image +from .rembg import BiRefNet +from .sparse import sort_block, extract_tokens_and_coords +from .mesh import mesh2index, normalize_mesh +from .fill_hole import postprocess_mesh \ No newline at end of file diff --git a/direct3d_s2/utils/__pycache__/__init__.cpython-310.pyc 
b/direct3d_s2/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ce8ae171e4f700ec661f50952122d1fc6241a85 Binary files /dev/null and b/direct3d_s2/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/direct3d_s2/utils/__pycache__/fill_hole.cpython-310.pyc b/direct3d_s2/utils/__pycache__/fill_hole.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf3af72c632aac5aa27083aba8fbd70d58482cd6 Binary files /dev/null and b/direct3d_s2/utils/__pycache__/fill_hole.cpython-310.pyc differ diff --git a/direct3d_s2/utils/__pycache__/fix_hole.cpython-310.pyc b/direct3d_s2/utils/__pycache__/fix_hole.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e96698412e82122026ca3c9a9d9a0c935ed3b30c Binary files /dev/null and b/direct3d_s2/utils/__pycache__/fix_hole.cpython-310.pyc differ diff --git a/direct3d_s2/utils/__pycache__/image.cpython-310.pyc b/direct3d_s2/utils/__pycache__/image.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..053a83a090808081b97dd683ac8a7a3be2766a8c Binary files /dev/null and b/direct3d_s2/utils/__pycache__/image.cpython-310.pyc differ diff --git a/direct3d_s2/utils/__pycache__/mesh.cpython-310.pyc b/direct3d_s2/utils/__pycache__/mesh.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3f4aa988d21ca631a4b0587892bd50afae7c448 Binary files /dev/null and b/direct3d_s2/utils/__pycache__/mesh.cpython-310.pyc differ diff --git a/direct3d_s2/utils/__pycache__/rembg.cpython-310.pyc b/direct3d_s2/utils/__pycache__/rembg.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c89319caf9b74ccacd6ec6cdd352b85efc07ab8 Binary files /dev/null and b/direct3d_s2/utils/__pycache__/rembg.cpython-310.pyc differ diff --git a/direct3d_s2/utils/__pycache__/sparse.cpython-310.pyc b/direct3d_s2/utils/__pycache__/sparse.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..390eca94ae38c2c4a39b84bb1b2d977b55eb8a50 Binary files /dev/null and b/direct3d_s2/utils/__pycache__/sparse.cpython-310.pyc differ diff --git a/direct3d_s2/utils/__pycache__/util.cpython-310.pyc b/direct3d_s2/utils/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66a7bb9d4838016184c6fe6b7aef44615d442415 Binary files /dev/null and b/direct3d_s2/utils/__pycache__/util.cpython-310.pyc differ diff --git a/direct3d_s2/utils/fill_hole.py b/direct3d_s2/utils/fill_hole.py new file mode 100644 index 0000000000000000000000000000000000000000..48a56d979a4b367ba1cddc834c7214e947b98adc --- /dev/null +++ b/direct3d_s2/utils/fill_hole.py @@ -0,0 +1,272 @@ +import torch +import numpy as np +from tqdm import tqdm +import utils3d +from pymeshfix import _meshfix +import igraph +import pyvista as pv + +PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53] + +def radical_inverse(base, n): + val = 0 + inv_base = 1.0 / base + inv_base_n = inv_base + while n > 0: + digit = n % base + val += digit * inv_base_n + n //= base + inv_base_n *= inv_base + return val + +def halton_sequence(dim, n): + return [radical_inverse(PRIMES[dim], n) for dim in range(dim)] + +def hammersley_sequence(dim, n, num_samples): + return [n / num_samples] + halton_sequence(dim - 1, n) + +def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False): + u, v = hammersley_sequence(2, n, num_samples) + u += offset[0] / num_samples + v += offset[1] + if 
remap: + u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3 + theta = np.arccos(1 - 2 * u) - np.pi / 2 + phi = v * 2 * np.pi + return [phi, theta] + +@torch.no_grad() +def _fill_holes( + verts, + faces, + max_hole_size=0.04, + max_hole_nbe=32, + resolution=128, + num_views=500, + debug=False, + verbose=False +): + """ + Rasterize a mesh from multiple views and remove invisible faces. + Also includes postprocessing to: + 1. Remove connected components that are have low visibility. + 2. Mincut to remove faces at the inner side of the mesh connected to the outer side with a small hole. + + Args: + verts (torch.Tensor): Vertices of the mesh. Shape (V, 3). + faces (torch.Tensor): Faces of the mesh. Shape (F, 3). + max_hole_size (float): Maximum area of a hole to fill. + resolution (int): Resolution of the rasterization. + num_views (int): Number of views to rasterize the mesh. + verbose (bool): Whether to print progress. + """ + # Construct cameras + yaws = [] + pitchs = [] + for i in range(num_views): + y, p = sphere_hammersley_sequence(i, num_views) + yaws.append(y) + pitchs.append(p) + yaws = torch.tensor(yaws).cuda() + pitchs = torch.tensor(pitchs).cuda() + radius = 2.0 + fov = torch.deg2rad(torch.tensor(40)).cuda() + projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3) + views = [] + for (yaw, pitch) in zip(yaws, pitchs): + orig = torch.tensor([ + torch.sin(yaw) * torch.cos(pitch), + torch.cos(yaw) * torch.cos(pitch), + torch.sin(pitch), + ]).cuda().float() * radius + view = utils3d.torch.view_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) + views.append(view) + views = torch.stack(views, dim=0) + + # Rasterize + visblity = torch.zeros(faces.shape[0], dtype=torch.int32, device=verts.device) + rastctx = utils3d.torch.RastContext(backend='cuda') + for i in tqdm(range(views.shape[0]), total=views.shape[0], disable=not verbose, desc='Rasterizing'): + view = views[i] + buffers = utils3d.torch.rasterize_triangle_faces( + rastctx, verts[None].float(), faces, resolution, resolution, view=view, projection=projection + ) + face_id = buffers['face_id'][0][buffers['mask'][0] > 0.95] - 1 + face_id = torch.unique(face_id).long() + visblity[face_id] += 1 + visblity = visblity.float() / num_views + + # Mincut + ## construct outer faces + edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces) + boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1) + connected_components = utils3d.torch.compute_connected_components(faces, edges, face2edge) + outer_face_indices = torch.zeros(faces.shape[0], dtype=torch.bool, device=faces.device) + for i in range(len(connected_components)): + outer_face_indices[connected_components[i]] = visblity[connected_components[i]] > min(max(visblity[connected_components[i]].quantile(0.75).item(), 0.25), 0.5) + outer_face_indices = outer_face_indices.nonzero().reshape(-1) + + ## construct inner faces + inner_face_indices = torch.nonzero(visblity == 0).reshape(-1) + if verbose: + tqdm.write(f'Found {inner_face_indices.shape[0]} invisible faces') + if inner_face_indices.shape[0] == 0: + return verts, faces + + ## Construct dual graph (faces as nodes, edges as edges) + dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(face2edge) + dual_edge2edge = edges[dual_edge2edge] + dual_edges_weights = torch.norm(verts[dual_edge2edge[:, 0]] - verts[dual_edge2edge[:, 1]], dim=1) + if verbose: + tqdm.write(f'Dual graph: {dual_edges.shape[0]} edges') + + ## solve mincut problem + ### construct main graph + g = 
igraph.Graph() + g.add_vertices(faces.shape[0]) + g.add_edges(dual_edges.cpu().numpy()) + g.es['weight'] = dual_edges_weights.cpu().numpy() + + ### source and target + g.add_vertex('s') + g.add_vertex('t') + + ### connect invisible faces to source + g.add_edges([(f, 's') for f in inner_face_indices], attributes={'weight': torch.ones(inner_face_indices.shape[0], dtype=torch.float32).cpu().numpy()}) + + ### connect outer faces to target + g.add_edges([(f, 't') for f in outer_face_indices], attributes={'weight': torch.ones(outer_face_indices.shape[0], dtype=torch.float32).cpu().numpy()}) + + ### solve mincut + cut = g.mincut('s', 't', (np.array(g.es['weight']) * 1000).tolist()) + remove_face_indices = torch.tensor([v for v in cut.partition[0] if v < faces.shape[0]], dtype=torch.long, device=faces.device) + if verbose: + tqdm.write(f'Mincut solved, start checking the cut') + + ### check if the cut is valid with each connected component + to_remove_cc = utils3d.torch.compute_connected_components(faces[remove_face_indices]) + if debug: + tqdm.write(f'Number of connected components of the cut: {len(to_remove_cc)}') + valid_remove_cc = [] + cutting_edges = [] + for cc in to_remove_cc: + #### check if the connected component has low visibility + visblity_median = visblity[remove_face_indices[cc]].median() + if debug: + tqdm.write(f'visblity_median: {visblity_median}') + if visblity_median > 0.25: + continue + + #### check if the cuting loop is small enough + cc_edge_indices, cc_edges_degree = torch.unique(face2edge[remove_face_indices[cc]], return_counts=True) + cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1] + cc_new_boundary_edge_indices = cc_boundary_edge_indices[~torch.isin(cc_boundary_edge_indices, boundary_edge_indices)] + if len(cc_new_boundary_edge_indices) > 0: + cc_new_boundary_edge_cc = utils3d.torch.compute_edge_connected_components(edges[cc_new_boundary_edge_indices]) + cc_new_boundary_edges_cc_center = [verts[edges[cc_new_boundary_edge_indices[edge_cc]]].mean(dim=1).mean(dim=0) for edge_cc in cc_new_boundary_edge_cc] + cc_new_boundary_edges_cc_area = [] + for i, edge_cc in enumerate(cc_new_boundary_edge_cc): + _e1 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 0]] - cc_new_boundary_edges_cc_center[i] + _e2 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 1]] - cc_new_boundary_edges_cc_center[i] + cc_new_boundary_edges_cc_area.append(torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum() * 0.5) + if debug: + cutting_edges.append(cc_new_boundary_edge_indices) + tqdm.write(f'Area of the cutting loop: {cc_new_boundary_edges_cc_area}') + if any([l > max_hole_size for l in cc_new_boundary_edges_cc_area]): + continue + + valid_remove_cc.append(cc) + + if debug: + face_v = verts[faces].mean(dim=1).cpu().numpy() + vis_dual_edges = dual_edges.cpu().numpy() + vis_colors = np.zeros((faces.shape[0], 3), dtype=np.uint8) + vis_colors[inner_face_indices.cpu().numpy()] = [0, 0, 255] + vis_colors[outer_face_indices.cpu().numpy()] = [0, 255, 0] + vis_colors[remove_face_indices.cpu().numpy()] = [255, 0, 255] + if len(valid_remove_cc) > 0: + vis_colors[remove_face_indices[torch.cat(valid_remove_cc)].cpu().numpy()] = [255, 0, 0] + utils3d.io.write_ply('dbg_dual.ply', face_v, edges=vis_dual_edges, vertex_colors=vis_colors) + + vis_verts = verts.cpu().numpy() + vis_edges = edges[torch.cat(cutting_edges)].cpu().numpy() + utils3d.io.write_ply('dbg_cut.ply', vis_verts, edges=vis_edges) + + + if len(valid_remove_cc) > 0: + remove_face_indices = 
remove_face_indices[torch.cat(valid_remove_cc)] + mask = torch.ones(faces.shape[0], dtype=torch.bool, device=faces.device) + mask[remove_face_indices] = 0 + faces = faces[mask] + faces, verts = utils3d.torch.remove_unreferenced_vertices(faces, verts) + if verbose: + tqdm.write(f'Removed {(~mask).sum()} faces by mincut') + else: + if verbose: + tqdm.write(f'Removed 0 faces by mincut') + + mesh = _meshfix.PyTMesh() + mesh.load_array(verts.cpu().numpy(), faces.cpu().numpy()) + mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True) + verts, faces = mesh.return_arrays() + verts, faces = torch.tensor(verts, device='cuda', dtype=torch.float32), torch.tensor(faces, device='cuda', dtype=torch.int32) + + return verts, faces + +def postprocess_mesh( + vertices: np.array, + faces: np.array, + simplify: bool = False, + simplify_ratio: float = 0.9, + fill_holes: bool = False, + fill_holes_max_hole_size: float = 0.04, + fill_holes_max_hole_nbe: int = 32, + fill_holes_resolution: int = 1024, + fill_holes_num_views: int = 1000, + debug: bool = False, + verbose: bool = False, +): + """ + Postprocess a mesh by simplifying, removing invisible faces, and removing isolated pieces. + + Args: + vertices (np.array): Vertices of the mesh. Shape (V, 3). + faces (np.array): Faces of the mesh. Shape (F, 3). + simplify (bool): Whether to simplify the mesh, using quadric edge collapse. + simplify_ratio (float): Ratio of faces to keep after simplification. + fill_holes (bool): Whether to fill holes in the mesh. + fill_holes_max_hole_size (float): Maximum area of a hole to fill. + fill_holes_max_hole_nbe (int): Maximum number of boundary edges of a hole to fill. + fill_holes_resolution (int): Resolution of the rasterization. + fill_holes_num_views (int): Number of views to rasterize the mesh. + verbose (bool): Whether to print progress. 
+ """ + + if verbose: + tqdm.write(f'Before postprocess: {vertices.shape[0]} vertices, {faces.shape[0]} faces') + + # Simplify + if simplify and simplify_ratio > 0: + mesh = pv.PolyData(vertices, np.concatenate([np.full((faces.shape[0], 1), 3), faces], axis=1)) + mesh = mesh.decimate(simplify_ratio, progress_bar=verbose) + vertices, faces = mesh.points, mesh.faces.reshape(-1, 4)[:, 1:] + if verbose: + tqdm.write(f'After decimate: {vertices.shape[0]} vertices, {faces.shape[0]} faces') + + # Remove invisible faces + if fill_holes: + vertices, faces = torch.tensor(vertices).cuda(), torch.tensor(faces.astype(np.int32)).cuda() + vertices, faces = _fill_holes( + vertices, faces, + max_hole_size=fill_holes_max_hole_size, + max_hole_nbe=fill_holes_max_hole_nbe, + resolution=fill_holes_resolution, + num_views=fill_holes_num_views, + debug=debug, + verbose=verbose, + ) + vertices, faces = vertices.cpu().numpy(), faces.cpu().numpy() + if verbose: + tqdm.write(f'After remove invisible faces: {vertices.shape[0]} vertices, {faces.shape[0]} faces') + + return vertices, faces \ No newline at end of file diff --git a/direct3d_s2/utils/image.py b/direct3d_s2/utils/image.py new file mode 100644 index 0000000000000000000000000000000000000000..a2094ffa0f71d3c8d44df26cfde91371bffbad74 --- /dev/null +++ b/direct3d_s2/utils/image.py @@ -0,0 +1,109 @@ +import numpy as np +from PIL import Image +import torch +import random +from torchvision import transforms +import torchvision.transforms.functional as TF + + +def apply_joint_transforms(rgb, mask, img_size, img_aug=True, test=True): + if test: + extra_pad = 16 + else: + extra_pad = random.randint(0, 32) + W_img, H_img = rgb.size[:2] + max_HW = max(H_img, W_img) + top_pad = (max_HW - H_img) // 2 + bottom_pad = max_HW - H_img - top_pad + left_pad = (max_HW - W_img) // 2 + right_pad = max_HW - W_img - left_pad + + # 1. padding + rgb = TF.pad(rgb, (left_pad, top_pad, right_pad, bottom_pad), fill=255) + mask = TF.pad(mask, (left_pad, top_pad, right_pad, bottom_pad), fill=0) + + if img_aug and (not test): + # 2. random rotate + if random.random() < 0.1: + angle = random.uniform(-10, 10) + rgb = TF.rotate(rgb, angle, fill=255) + mask = TF.rotate(mask, angle, fill=0) + + # 3. random crop + if random.random() < 0.1: + crop_ratio = random.uniform(0.9, 1.0) + crop_size = int(max_HW * crop_ratio) + i, j, h, w = transforms.RandomCrop.get_params(rgb, (crop_size, crop_size)) + rgb = TF.crop(rgb, i, j, h, w) + mask = TF.crop(mask, i, j, h, w) + + # 4. resize + target_size = (img_size, img_size) + rgb = TF.resize(rgb, target_size, interpolation=TF.InterpolationMode.BILINEAR) + mask = TF.resize(mask, target_size, interpolation=TF.InterpolationMode.NEAREST) + + # 5. 
extra padding + rgb = TF.pad(rgb, extra_pad, fill=255) + mask = TF.pad(mask, extra_pad, fill=0) + rgb = TF.resize(rgb, target_size, interpolation=TF.InterpolationMode.BILINEAR) + mask = TF.resize(mask, target_size, interpolation=TF.InterpolationMode.NEAREST) + + # to tensor + rgb_tensor = TF.to_tensor(rgb) + mask_tensor = TF.to_tensor(mask) + + return rgb_tensor, mask_tensor + +def crop_recenter(image_no_bg, thereshold=100): + image_no_bg_np = np.array(image_no_bg) + mask = (image_no_bg_np[..., -1]).astype(np.uint8) + mask_bin = mask > thereshold + + H, W = image_no_bg_np.shape[:2] + + valid_pixels = mask_bin.astype(np.float32).nonzero() # [N, 2] + if np.sum(mask_bin) < (H*W) * 0.001: + min_h =0 + max_h = H - 1 + min_w = 0 + max_w = W -1 + else: + min_h, max_h = valid_pixels[0].min(), valid_pixels[0].max() + min_w, max_w = valid_pixels[1].min(), valid_pixels[1].max() + + if min_h < 0: + min_h = 0 + if min_w < 0: + min_w = 0 + if max_h > H: + max_h = H + if max_w > W: + max_w = W + + image_no_bg_np = image_no_bg_np[min_h:max_h+1, min_w:max_w+1] + return image_no_bg_np + +def preprocess_image(img): + + if isinstance(img, str): + img = Image.open(img) + img = np.array(img) + elif isinstance(img, Image.Image): + img = np.array(img) + + if img.shape[-1] == 3: + mask = np.ones_like(img[..., 0:1]) + img = np.concatenate([img, mask], axis=-1) + + img = crop_recenter(img, thereshold=0) / 255. + + mask = img[..., 3] + img = img[..., :3] * img[..., 3:] + (1 - img[..., 3:]) + img = Image.fromarray((img * 255).astype(np.uint8)) + mask = Image.fromarray((mask * 255).astype(np.uint8)) + + img, mask = apply_joint_transforms(img, mask, img_size=518, + img_aug=False, test=True) + img = torch.cat([img, mask], dim=0) + return img + \ No newline at end of file diff --git a/direct3d_s2/utils/mesh.py b/direct3d_s2/utils/mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..321172edfb7814e63222d2b44ff9f9ba2365323f --- /dev/null +++ b/direct3d_s2/utils/mesh.py @@ -0,0 +1,34 @@ +import torch +import numpy as np +import udf_ext + + +def compute_valid_udf(vertices, faces, dim=512, threshold=8.0): + if not faces.is_cuda or not vertices.is_cuda: + raise ValueError("Both maze and visited tensors must be CUDA tensors") + udf = torch.zeros(dim**3,device=vertices.device).int() + 10000000 + n_faces = faces.shape[0] + udf_ext.compute_valid_udf(vertices, faces, udf, n_faces, dim, threshold) + return udf.float()/10000000. 
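# Illustrative usage sketch (editorial addition, not part of the original file): how the
# compute_valid_udf helper above is typically driven. Assumes a CUDA device, the udf_ext
# extension built from third_party/voxelize, trimesh installed, and a placeholder mesh
# path "my_mesh.obj"; normalize_mesh referenced here is defined just below.
def _example_udf_usage():
    import trimesh
    mesh = normalize_mesh(trimesh.load("my_mesh.obj", force="mesh"))
    verts = torch.tensor(mesh.vertices, dtype=torch.float32, device="cuda") * 0.5  # keep coords inside the kernel's [-0.5, 0.5] grid
    faces = torch.tensor(mesh.faces, dtype=torch.int32, device="cuda")
    udf = compute_valid_udf(verts, faces, dim=512, threshold=4.0)      # flat (512**3,) grid of clamped distances
    near_surface = (udf.reshape(512, 512, 512) < 4 / 512).nonzero()    # voxel coords within ~4 cells of the surface
    return near_surface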
+ +def normalize_mesh(mesh, scale=0.95): + vertices = mesh.vertices + min_coords, max_coords = vertices.min(axis=0), vertices.max(axis=0) + dxyz = max_coords - min_coords + dist = max(dxyz) + mesh_scale = 2.0 * scale / dist + mesh_offset = -(min_coords + max_coords) / 2 + vertices = (vertices + mesh_offset) * mesh_scale + mesh.vertices = vertices + return mesh + +def mesh2index(mesh, size=1024, factor=8): + vertices = torch.Tensor(mesh.vertices).float().cuda() * 0.5 + faces = torch.Tensor(mesh.faces).int().cuda() + sdf = compute_valid_udf(vertices, faces, dim=size, threshold=4.0) + sdf = sdf.reshape(size, size, size).unsqueeze(0) + + sparse_index = (sdf < 4/size).nonzero() + sparse_index[..., 1:] = sparse_index[..., 1:] // factor + latent_index = torch.unique(sparse_index, dim=0) + return latent_index \ No newline at end of file diff --git a/direct3d_s2/utils/rembg.py b/direct3d_s2/utils/rembg.py new file mode 100644 index 0000000000000000000000000000000000000000..af64dccbf186afd4bfdf51e3a56cbcbae6ceef21 --- /dev/null +++ b/direct3d_s2/utils/rembg.py @@ -0,0 +1,35 @@ +import numpy as np +import torch +from torchvision import transforms + + +class BiRefNet(object): + def __init__(self, device): + from transformers import AutoModelForImageSegmentation + self.birefnet_model = AutoModelForImageSegmentation.from_pretrained( + 'ZhengPeng7/BiRefNet', + trust_remote_code=True, + ).to(device) + self.birefnet_model.eval() + self.device = device + + def run(self, image): + image = image.convert('RGB') + image_size = (1024, 1024) + transform_image = transforms.Compose([ + transforms.Resize(image_size), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + + input_images = transform_image(image).unsqueeze(0).to(self.device) + + with torch.no_grad(): + preds = self.birefnet_model(input_images)[-1].sigmoid().cpu() + + pred = preds[0].squeeze() + pred_pil = transforms.ToPILImage()(pred) + mask = pred_pil.resize(image.size) + mask = np.array(mask) + image = np.concatenate([np.array(image), mask[..., None]], axis=-1) + return image \ No newline at end of file diff --git a/direct3d_s2/utils/sparse.py b/direct3d_s2/utils/sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..9682ed73c7f281c6717185a1067ec0e05daafdaa --- /dev/null +++ b/direct3d_s2/utils/sparse.py @@ -0,0 +1,68 @@ +import torch +import numpy as np + +def sort_block(latent_index, block_size): + device = latent_index.device + latent_index_block = latent_index.cpu().numpy() + latent_index_block[..., 1:] = latent_index_block[..., 1:] // block_size + latent_index_inblock = latent_index.cpu().numpy() + latent_index_inblock[..., 1:] = latent_index_inblock[..., 1:] % block_size + sort_index = np.lexsort(( + latent_index_inblock[..., 3], + latent_index_inblock[..., 2], + latent_index_inblock[..., 1], + latent_index_block[..., 3], + latent_index_block[..., 2], + latent_index_block[..., 1]) + ) + sort_index = torch.from_numpy(sort_index).to(device) + return latent_index[sort_index] + +def extract_tokens_and_coords(conditions, token_mask, num_cls=1, num_reg=4): + device = conditions.device + B = conditions.size(0) + patch_size = token_mask.size(1) + + class_tokens = conditions[:, 0:num_cls, :] # [B, 1, 1024] + register_tokens = conditions[:, num_cls:num_cls+num_reg, :] # [B, 4, 1024] + patch_tokens = conditions[:, num_cls+num_reg:, :] # [B, 1369, 1024] + + selected_tokens_list = [] + coords_list = [] + + for batch_idx in range(B): + cls_tokens = class_tokens[batch_idx] # [1, 1024] + 
reg_tokens = register_tokens[batch_idx] # [4, 1024] + cls_reg_tokens = torch.cat([cls_tokens, reg_tokens], dim=0) # [5, 1024] + + cls_coord = torch.tensor([[batch_idx, 0, 0, 1]] * num_cls, device=device) + reg_coords = torch.tensor([[batch_idx, 0, 0, 1]] * num_reg, device=device) + cls_reg_coords = torch.cat([cls_coord, reg_coords], dim=0) + + mask = token_mask[batch_idx] + pos = mask.nonzero(as_tuple=False) + K = pos.size(0) + + if K > 0: + h, w = pos[:, 0], pos[:, 1] + indices = h * patch_size + w # + patches = patch_tokens[batch_idx][indices] + + batch_ids = torch.full((K, 1), batch_idx, device=device) + x = w.unsqueeze(1) + y = h.unsqueeze(1) + patch_coords = torch.cat([batch_ids, x, y, torch.zeros((K, 1), device=device)], dim=1) + + combined_tokens = torch.cat([cls_reg_tokens, patches], dim=0) + combined_coords = torch.cat([cls_reg_coords, patch_coords], dim=0) + else: + combined_tokens = cls_reg_tokens + combined_coords = cls_reg_coords + + selected_tokens_list.append(combined_tokens) + coords_list.append(combined_coords) + + selected_tokens = torch.cat(selected_tokens_list, dim=0) + coords = torch.cat(coords_list, dim=0) + + return selected_tokens, coords \ No newline at end of file diff --git a/direct3d_s2/utils/util.py b/direct3d_s2/utils/util.py new file mode 100644 index 0000000000000000000000000000000000000000..51bd0eaa45994f0a3afe18040854288064535e7a --- /dev/null +++ b/direct3d_s2/utils/util.py @@ -0,0 +1,19 @@ +import importlib + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..3ad4f7d80bfe6190ac69ec3585246a9360587c04 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,17 @@ +scikit-image +trimesh +omegaconf +tqdm +huggingface_hub +einops +numpy +transformers==4.40.2 +diffusers +triton==3.1.0 +flash-attn --no-build-isolation +pymeshfix +git+https://github.com/EasternJournalist/utils3d.git#egg=utils3d +pyvista +igraph +git+https://github.com/mit-han-lab/torchsparse.git +third_party/voxelize/ \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..c8db6251e3ed2b257ea34ef3c4ad056e4b14e515 --- /dev/null +++ b/setup.py @@ -0,0 +1,18 @@ +from setuptools import setup, find_packages + + +setup( + name="direct3d_s2", + version="1.0.0", + description="Direct3D-S2: Gigascale 3D Generation Made Easy with Spatial Sparse Attention", + packages=find_packages(), + python_requires=">=3.10", + install_requires=[ + "torch", + "numpy", + "cython", + "trimesh", + "diffusers", + "triton", + ], +) \ No newline at end of file diff --git a/third_party/voxelize/build/lib.linux-x86_64-3.10/udf_ext.cpython-310-x86_64-linux-gnu.so b/third_party/voxelize/build/lib.linux-x86_64-3.10/udf_ext.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..02b864c83d160429ce5c8caab893df304a2a7316 --- /dev/null +++ 
b/third_party/voxelize/build/lib.linux-x86_64-3.10/udf_ext.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dba5cb992519f31d01d20161f44fed88ec5bb7f4000a0efdf2f344222e54095d +size 9334072 diff --git a/third_party/voxelize/build/temp.linux-x86_64-3.10/.ninja_deps b/third_party/voxelize/build/temp.linux-x86_64-3.10/.ninja_deps new file mode 100644 index 0000000000000000000000000000000000000000..e7ebba6788c4ab705a00bb27ea9c9b5c4707ad2c --- /dev/null +++ b/third_party/voxelize/build/temp.linux-x86_64-3.10/.ninja_deps @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bfaad9b34926f6b0320aa5208865a236d2fabab01194c00662fa1f8c5c123474 +size 557356 diff --git a/third_party/voxelize/build/temp.linux-x86_64-3.10/.ninja_log b/third_party/voxelize/build/temp.linux-x86_64-3.10/.ninja_log new file mode 100644 index 0000000000000000000000000000000000000000..5d2d9d37878743c6735597258366c8cedd790ec2 --- /dev/null +++ b/third_party/voxelize/build/temp.linux-x86_64-3.10/.ninja_log @@ -0,0 +1,7 @@ +# ninja log v5 +1 10231 1748533567318943487 /opt/ml/input/data/CFS/wushuang/research/TOG2025/open_source/Direct3D-S2/third_part/voxelize/build/temp.linux-x86_64-3.10/src/udf_kernel.o 974b8952d6695070 +1 23455 1748533580544026989 /opt/ml/input/data/CFS/wushuang/research/TOG2025/open_source/Direct3D-S2/third_part/voxelize/build/temp.linux-x86_64-3.10/src/udf_cuda.o 451cfe46e7f34448 +94 9916 1748614488097090921 /opt/ml/input/data/CFS/wushuang/research/TOG2025/open_source/Direct3D-S2/third_part/voxelize/build/temp.linux-x86_64-3.10/src/udf_kernel.o 974b8952d6695070 +94 22534 1748614500786335832 /opt/ml/input/data/CFS/wushuang/research/TOG2025/open_source/Direct3D-S2/third_part/voxelize/build/temp.linux-x86_64-3.10/src/udf_cuda.o 451cfe46e7f34448 +8 9872 1748636870244409619 /opt/ml/input/data/CFS/wushuang/research/TOG2025/open_source/Direct3D-S2/third_party/voxelize/build/temp.linux-x86_64-3.10/src/udf_kernel.o 52db95a09a5c3658 +8 21856 1748636882230425289 /opt/ml/input/data/CFS/wushuang/research/TOG2025/open_source/Direct3D-S2/third_party/voxelize/build/temp.linux-x86_64-3.10/src/udf_cuda.o e9d1e10a200931c1 diff --git a/third_party/voxelize/build/temp.linux-x86_64-3.10/build.ninja b/third_party/voxelize/build/temp.linux-x86_64-3.10/build.ninja new file mode 100644 index 0000000000000000000000000000000000000000..64c09e77d7ad90229414c2242f6813329c9f3e56 --- /dev/null +++ b/third_party/voxelize/build/temp.linux-x86_64-3.10/build.ninja @@ -0,0 +1,33 @@ +ninja_required_version = 1.3 +cxx = c++ +nvcc = /usr/local/cuda/bin/nvcc + +cflags = -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -g -fstack-protector-strong -Wformat -Werror=format-security -Wdate-time -D_FORTIFY_SOURCE=2 -fPIC -I/usr/local/lib/python3.10/dist-packages/torch/include -I/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.10/dist-packages/torch/include/TH -I/usr/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/local/cuda/include -I/usr/include/python3.10 -c +post_cflags = -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=udf_ext -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++17 +cuda_cflags = -I/usr/local/lib/python3.10/dist-packages/torch/include -I/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include 
-I/usr/local/lib/python3.10/dist-packages/torch/include/TH -I/usr/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/local/cuda/include -I/usr/include/python3.10 -c +cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=udf_ext -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_52,code=sm_52 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_90,code=compute_90 -gencode=arch=compute_90,code=sm_90 -std=c++17 +cuda_dlink_post_cflags = +ldflags = + +rule compile + command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags + depfile = $out.d + deps = gcc + +rule cuda_compile + depfile = $out.d + deps = gcc + command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags + + + + + +build /opt/ml/input/data/CFS/wushuang/research/TOG2025/open_source/Direct3D-S2/third_party/voxelize/build/temp.linux-x86_64-3.10/src/udf_cuda.o: compile /opt/ml/input/data/CFS/wushuang/research/TOG2025/open_source/Direct3D-S2/third_party/voxelize/src/udf_cuda.cpp +build /opt/ml/input/data/CFS/wushuang/research/TOG2025/open_source/Direct3D-S2/third_party/voxelize/build/temp.linux-x86_64-3.10/src/udf_kernel.o: cuda_compile /opt/ml/input/data/CFS/wushuang/research/TOG2025/open_source/Direct3D-S2/third_party/voxelize/src/udf_kernel.cu + + + + + + diff --git a/third_party/voxelize/build/temp.linux-x86_64-3.10/src/udf_cuda.o b/third_party/voxelize/build/temp.linux-x86_64-3.10/src/udf_cuda.o new file mode 100644 index 0000000000000000000000000000000000000000..e233bb97c7ecbccb7f9a9b86e1afd267089e6e51 --- /dev/null +++ b/third_party/voxelize/build/temp.linux-x86_64-3.10/src/udf_cuda.o @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:36e81102ff162bc0cbbb09814def88484f0d4ff68fd6aec52c5e5ee10895a3c3 +size 13639720 diff --git a/third_party/voxelize/build/temp.linux-x86_64-3.10/src/udf_kernel.o b/third_party/voxelize/build/temp.linux-x86_64-3.10/src/udf_kernel.o new file mode 100644 index 0000000000000000000000000000000000000000..46bfc650dc8c076f035aed2a36d4901b555954fe --- /dev/null +++ b/third_party/voxelize/build/temp.linux-x86_64-3.10/src/udf_kernel.o @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d90cd327d8506bd0db2654081633bb519f8367a80c8841a685e79f127018a2ca +size 234112 diff --git a/third_party/voxelize/setup.py b/third_party/voxelize/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..2c68d1aa3b6feb7dae5b468b690c2d2f5a6a9a20 --- /dev/null +++ b/third_party/voxelize/setup.py @@ -0,0 +1,16 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name='udf_ext', + ext_modules=[ + CUDAExtension('udf_ext', [ + 'src/udf_kernel.cu', + 'src/udf_cuda.cpp' + ]), + ], + cmdclass={ + 'build_ext': BuildExtension + } +) + diff --git a/third_party/voxelize/src/udf_cuda.cpp b/third_party/voxelize/src/udf_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..877772fd6372be13e04c7ddf757175c0213e56d4 --- /dev/null +++ 
b/third_party/voxelize/src/udf_cuda.cpp @@ -0,0 +1,13 @@ +#include <torch/extension.h> + +void compute_valid_udf_cuda(float* vertices, int* faces, int* udf, const int numTriangles, const int DIM=512, const float threshold=8); + +extern "C" +void compute_valid_udf_wrapper(torch::Tensor vertices, torch::Tensor faces, torch::Tensor udf, const int numTriangles, const int DIM=512, const float threshold=8.0) { + compute_valid_udf_cuda(vertices.data_ptr<float>(), faces.data_ptr<int>(), udf.data_ptr<int>(), numTriangles, DIM, threshold); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("compute_valid_udf", &compute_valid_udf_wrapper, "Compute UDF using CUDA"); +} + diff --git a/third_party/voxelize/src/udf_kernel.cu b/third_party/voxelize/src/udf_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..2488606605b8185ec2ffdacbc59152d32767aa66 --- /dev/null +++ b/third_party/voxelize/src/udf_kernel.cu @@ -0,0 +1,214 @@ +#include +#include +#include + + +struct Point3D { + float x, y, z; +}; + +struct Triangle { + Point3D v0, v1, v2; +}; +__device__ Point3D cross(const Point3D& v1, const Point3D& v2) { + Point3D result; + result.x = v1.y * v2.z - v1.z * v2.y; + result.y = v1.z * v2.x - v1.x * v2.z; + result.z = v1.x * v2.y - v1.y * v2.x; + return result; +} + +// Compute the dot product of two vectors +__device__ float dot(const Point3D& v1, const Point3D& v2) { + return v1.x * v2.x + v1.y * v2.y + v1.z * v2.z; +} + +// Subtract two 3D points (vector subtraction) +__device__ Point3D subtract(const Point3D& p1, const Point3D& p2) { + Point3D result = {p1.x - p2.x, p1.y - p2.y, p1.z - p2.z}; + return result; +} +__device__ float magnitude(const Point3D &v) { + return sqrtf(v.x * v.x + v.y * v.y + v.z * v.z); +} +__device__ bool is_identical(const Point3D & p1, const Point3D & p2){ + Point3D check = subtract(p1, p2); + if(check.x==0 && check.y == 0 && check.z == 0) + return true; + return false; +} + +// Compute the squared distance between two points +__device__ float squaredDistance(const Point3D& p1, const Point3D& p2) { + return (p1.x - p2.x) * (p1.x - p2.x) + + (p1.y - p2.y) * (p1.y - p2.y) + + (p1.z - p2.z) * (p1.z - p2.z); +} +__device__ Point3D normalize(Point3D v){ + float len = sqrtf(dot(v, v)); + if (len ==0) + return v; + float scale = 1 / len; + Point3D result = {v.x * scale, v.y * scale, v.z * scale}; + return result; +} + +__device__ float point_to_line_distance(const Point3D &p, const Point3D &v0, const Point3D &v1) { + // Direction vector of the line + Point3D d = subtract(v1, v0); + + // Vector from v0 to point p + Point3D v0_to_p = subtract(p, v0); + + // Scalar projection of v0_to_p onto the direction vector d + float t = dot(v0_to_p, d) / dot(d, d); + + Point3D closest_point; + + // Check where the projection falls + if (t < 0) { + // Projection falls before v0, so the closest point is v0 + closest_point = v0; + } else if (t > 1) { + // Projection falls beyond v1, so the closest point is v1 + closest_point = v1; + } else { + // Projection falls within the segment, compute the projection point + closest_point.x = v0.x + t * d.x; + closest_point.y = v0.y + t * d.y; + closest_point.z = v0.z + t * d.z; + } + + // Calculate the distance between p and the closest point + Point3D closest_to_p = subtract(p, closest_point); + return magnitude(closest_to_p); +} + +// Compute the distance between a point and a triangle face +__device__ float pointToTriangleDistance(const Point3D& queryPoint, const Point3D& v0, const Point3D& v1, const Point3D& v2, bool inverse=false) { + // Edge vectors + Point3D 
edge0 = subtract(v1, v0); + Point3D edge1 = subtract(v2, v0); + if (is_identical(v0, v1) && is_identical(v0, v2)) + return sqrtf(squaredDistance(queryPoint, v0)); + if (is_identical(v0, v1)) + return point_to_line_distance(queryPoint, v0, v2); + if (is_identical(v0, v2)) + return point_to_line_distance(queryPoint, v0, v1); + if (is_identical(v1, v2)) + return point_to_line_distance(queryPoint, v0, v1); + // Normal vector to the triangle plane + Point3D normal = cross(edge0, edge1); + if (inverse) + normal = cross(edge1, edge0); + + // Vector from v0 to queryPoint + Point3D queryVec = subtract(queryPoint, v0); + if (dot(normal, normal)==0) + return sqrtf(dot(queryVec, queryVec)); + normal = normalize(normal); + //return 1.0; + + // Project the query point onto the triangle's plane + float distanceToPlane = dot(normal, queryVec); // / sqrtf(dot(normal, normal)); + +// return fabsf(distanceToPlane); + Point3D projectionPoint = { + queryPoint.x - distanceToPlane * normal.x, + queryPoint.y - distanceToPlane * normal.y, + queryPoint.z - distanceToPlane * normal.z + }; + // Check if the projection point is inside the triangle using barycentric coordinates + edge0 = subtract(v0, v1); + edge1 = subtract(v1, v2); + Point3D edge2 = subtract(v2, v0); + Point3D projVec0 = subtract(v0, projectionPoint); + Point3D projVec1 = subtract(v1, projectionPoint); + Point3D projVec2 = subtract(v2, projectionPoint); + Point3D c0 = cross(edge0, projVec0); + Point3D c1 = cross(edge1, projVec1); + Point3D c2 = cross(edge2, projVec2); + if (dot(c0, c1) > 0 && dot(c1, c2) > 0 && dot(c0, c2) > 0) + return fabsf(distanceToPlane); + + // Otherwise, return the minimum distance to the triangle's edges + float minEdgeDistance = 1e6f; + minEdgeDistance = fmin(minEdgeDistance, point_to_line_distance(queryPoint, v0, v1)); + minEdgeDistance = fmin(minEdgeDistance, point_to_line_distance(queryPoint, v0, v2)); + minEdgeDistance = fmin(minEdgeDistance, point_to_line_distance(queryPoint, v1, v2)); + + + return minEdgeDistance; +} + + +__device__ void updateUDF(Triangle t, int* udf, const int DIM, const float threshold) { + // Compute the bounding box of the triangle + float minX = fminf(fminf(t.v0.x, t.v1.x), t.v2.x); + float minY = fminf(fminf(t.v0.y, t.v1.y), t.v2.y); + float minZ = fminf(fminf(t.v0.z, t.v1.z), t.v2.z); + float maxX = fmaxf(fmaxf(t.v0.x, t.v1.x), t.v2.x); + float maxY = fmaxf(fmaxf(t.v0.y, t.v1.y), t.v2.y); + float maxZ = fmaxf(fmaxf(t.v0.z, t.v1.z), t.v2.z); + + // Convert bounding box to grid coordinates + int iMin = max(0, (int)floorf((minX + 0.5) * (DIM-1))); + int jMin = max(0, (int)floorf((minY + 0.5) * (DIM-1))); + int kMin = max(0, (int)floorf((minZ + 0.5) * (DIM-1))); + int iMax = min(DIM - 1, (int)floorf((maxX + 0.5) * (DIM-1))); + int jMax = min(DIM - 1, (int)floorf((maxY + 0.5) * (DIM-1))); + int kMax = min(DIM - 1, (int)floorf((maxZ + 0.5) * (DIM-1))); + + int range = (int)(threshold + 1); + + // Make the bounding box larger than the original + iMax = min(DIM - 1, iMax + range); + iMin = max(0, iMin - range); + jMax = min(DIM - 1, jMax + range); + jMin = max(0, jMin - range); + kMax = min(DIM - 1, kMax + range); + kMin = max(0, kMin - range); + + // Update the valid grids within the bounding box + for (int i = iMin; i <= iMax; ++i) { + for (int j = jMin; j <= jMax; ++j) { + for (int k = kMin; k <= kMax; ++k) { + int idx = i * DIM * DIM + j * DIM + k; + + // Compute the distance from the query point to the triangle + Point3D queryPoint = {(float)i/(DIM-1) - 0.5, (float)j/(DIM-1) - 0.5, (float)k/(DIM-1) 
-0.5}; + float distance = pointToTriangleDistance(queryPoint, t.v0, t.v1, t.v2); + float distance2 = pointToTriangleDistance(queryPoint, t.v0, t.v1, t.v2, true); + if (distance < threshold / DIM or distance2 < threshold / DIM){ + //distance = distance2; + int int_dist = (int)(distance * 10000000); + atomicMin(&udf[idx], int_dist); + } + } + + } + } +} + +__global__ void compute_udf_kernel(float* vertices, int* faces, int * udf, int numTriangles, const int DIM, const float threshold) { + int t = blockIdx.x * blockDim.x + threadIdx.x; + if (t < numTriangles) { + int f0 = faces[t * 3 + 0]; + int f1 = faces[t * 3 + 1]; + int f2 = faces[t * 3 + 2]; + Point3D v0 = {vertices[f0 * 3 + 0], vertices[f0 * 3 + 1], vertices[f0 * 3 + 2]}; + Point3D v1 = {vertices[f1 * 3 + 0], vertices[f1 * 3 + 1], vertices[f1 * 3 + 2]}; + Point3D v2 = {vertices[f2 * 3 + 0], vertices[f2 * 3 + 1], vertices[f2 * 3 + 2]}; + Triangle triangle = {v0, v1, v2}; + updateUDF(triangle, udf, DIM, threshold); + } +} + +void compute_valid_udf_cuda(float* vertices, int* faces, int* udf, int numTriangles, const int DIM=512, const float threshold=8) { + int blockSize = 256; + int gridSize = (numTriangles + blockSize - 1) / blockSize; + + // Launch the kernel + compute_udf_kernel<<<gridSize, blockSize>>>(vertices, faces, udf, numTriangles, DIM, threshold); +} + diff --git a/third_party/voxelize/udf_ext.cpython-310-x86_64-linux-gnu.so b/third_party/voxelize/udf_ext.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..02b864c83d160429ce5c8caab893df304a2a7316 --- /dev/null +++ b/third_party/voxelize/udf_ext.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dba5cb992519f31d01d20161f44fed88ec5bb7f4000a0efdf2f344222e54095d +size 9334072 diff --git a/third_party/voxelize/udf_ext.egg-info/PKG-INFO b/third_party/voxelize/udf_ext.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..21cdbe2d582ff6f92783c01ac46ffd62f8c1a480 --- /dev/null +++ b/third_party/voxelize/udf_ext.egg-info/PKG-INFO @@ -0,0 +1,8 @@ +Metadata-Version: 2.1 +Name: udf-ext +Version: 0.0.0 +Summary: UNKNOWN +License: UNKNOWN +Platform: UNKNOWN + +UNKNOWN diff --git a/third_party/voxelize/udf_ext.egg-info/SOURCES.txt b/third_party/voxelize/udf_ext.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..c8f3f0e01b4cc56421b16b394c3ad796ef613412 --- /dev/null +++ b/third_party/voxelize/udf_ext.egg-info/SOURCES.txt @@ -0,0 +1,7 @@ +setup.py +src/udf_cuda.cpp +src/udf_kernel.cu +udf_ext.egg-info/PKG-INFO +udf_ext.egg-info/SOURCES.txt +udf_ext.egg-info/dependency_links.txt +udf_ext.egg-info/top_level.txt \ No newline at end of file diff --git a/third_party/voxelize/udf_ext.egg-info/dependency_links.txt b/third_party/voxelize/udf_ext.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/third_party/voxelize/udf_ext.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/third_party/voxelize/udf_ext.egg-info/top_level.txt b/third_party/voxelize/udf_ext.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..07a52f59f2f35df1f9db761ec17412a7406b9802 --- /dev/null +++ b/third_party/voxelize/udf_ext.egg-info/top_level.txt @@ -0,0 +1 @@ +udf_ext
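Usage note (editorial addition): the diff above packages two installable pieces, the direct3d_s2 Python package (setup.py at the repo root) and the udf_ext CUDA extension under third_party/voxelize, and exposes Direct3DS2Pipeline in direct3d_s2/pipeline.py as the entry point. A minimal sketch of how the pipeline is driven, assuming the listed requirements are installed and a CUDA GPU is available; the Hugging Face repo id and image path below are placeholders, and a local checkpoint directory containing config.yaml plus the four .ckpt files also works:

from direct3d_s2.pipeline import Direct3DS2Pipeline

pipeline = Direct3DS2Pipeline.from_pretrained("ORG/Direct3D-S2")  # placeholder repo id, or a local checkpoint dir
pipeline.to("cuda")
outputs = pipeline(image="path/to/image.png", sdf_resolution=1024, remesh=False)
mesh = outputs["mesh"]        # mesh object exposing .vertices / .faces
mesh.export("output.obj")     # assumes a trimesh-compatible mesh; adapt if decode_mesh returns another type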