diff --git a/UniRig/.gitattributes b/UniRig/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..520ec4e3434c3d3ba2d43125fe86933ec680c3ba
--- /dev/null
+++ b/UniRig/.gitattributes
@@ -0,0 +1,5 @@
+examples/*.glb filter=lfs diff=lfs merge=lfs -text
+examples/*.fbx filter=lfs diff=lfs merge=lfs -text
+examples/*.vrm filter=lfs diff=lfs merge=lfs -text
+examples/*.FBX filter=lfs diff=lfs merge=lfs -text
+examples/*.obj filter=lfs diff=lfs merge=lfs -text
diff --git a/UniRig/.gitignore b/UniRig/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..474041aa2dbf654480297bd340b6e5fb3a57d24b
--- /dev/null
+++ b/UniRig/.gitignore
@@ -0,0 +1,59 @@
+# ignore all pycache
+**/__pycache__/
+*.py[cod]
+*$py.class
+
+# ignore tmp & output files
+_data/
+tmp/
+*.npz
+*.blend
+*.blend1
+*.blend2
+
+# ignore logs
+wandb/
+lightning_logs/
+*.log
+
+# ignore experiments
+experiments/
+results/
+dataset_clean/
+logs/
+datalist/
+dataset_inference/
+dataset_inference_clean/
+feature_viz/
+
+# Distribution / packaging
+dist/
+build/
+*.egg-info/
+*.egg
+*.whl
+
+# Virtual environments
+venv/
+env/
+.env/
+.venv/
+
+# IDE specific files
+.idea/
+.vscode/
+*.swp
+*.swo
+.DS_Store
+
+# Jupyter Notebook
+.ipynb_checkpoints
+*.ipynb
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+coverage.xml
+*.cover
\ No newline at end of file
diff --git a/UniRig/LICENSE b/UniRig/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..a20d482774136ee1858d92964e896031598902b4
--- /dev/null
+++ b/UniRig/LICENSE
@@ -0,0 +1,21 @@
+# MIT License
+
+# Copyright (c) 2025 VAST-AI-Research and contributors.
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
diff --git a/UniRig/README.md b/UniRig/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b222e6161e9bf181d27cad9041c9d27c640224a6
--- /dev/null
+++ b/UniRig/README.md
@@ -0,0 +1,164 @@
+# UniRig: One Model to Rig Them All
+
+
+
+[Project Page](https://zjp-shadow.github.io/works/UniRig/)
+[Paper (arXiv)](https://arxiv.org/abs/2504.12451)
+[Model (Hugging Face)](https://huggingface.co/VAST-AI/UniRig)
+
+
+![UniRig teaser](assets/doc/unirig_teaser.png)
+
+This repository contains the official implementation of the **SIGGRAPH'25 (TOG) UniRig** framework, a unified solution for automatic 3D model rigging, developed by Tsinghua University and [Tripo](https://www.tripo3d.ai).
+
+**Paper:** [One Model to Rig Them All: Diverse Skeleton Rigging with UniRig](https://arxiv.org/abs/2504.12451)
+
+## Overview
+
+Rigging 3D models – creating a skeleton and assigning skinning weights – is a crucial but often complex and time-consuming step in 3D animation. UniRig tackles this challenge by introducing a novel, unified framework leveraging large autoregressive models to automate the process for a diverse range of 3D assets.
+
+Combining UniRig with keyframe animation produces results like the following:
+
+| ![Devil](assets/doc/devil.gif) | ![Dragon](assets/doc/dragon.gif) | ![Rabbit](assets/doc/rabbit.gif) |
+|:-----------------------------:|:-------------------------------:|:-------------------------------:|
+
+The full UniRig system consists of two main stages:
+1. **Skeleton Prediction:** A GPT-like transformer autoregressively predicts a topologically valid skeleton hierarchy using a novel **Skeleton Tree Tokenization** scheme.
+2. **Skinning Weight & Attribute Prediction:** A **Bone-Point Cross Attention** mechanism predicts per-vertex skinning weights and relevant bone attributes (e.g., for physics simulation) based on the predicted skeleton and input mesh geometry.
+
+This repository provides the implementation of the full framework, with components released progressively.
+
+## Key Features (Full UniRig Framework)
+
+* **Unified Model:** Aims to handle diverse model categories (humans, animals, objects) with a single framework.
+* **Automated Skeleton Generation:** Predicts topologically valid skeleton structures. **(✅ Available in current release)**
+* **Automated Skinning Prediction:** Predicts per-vertex skinning weights. **(✅ Available in current release)**
+* **Bone Attribute Prediction:** Predicts attributes like stiffness for physics-based secondary motion. **(⏳ Coming Soon)**
+* **High Accuracy & Robustness:** Achieves state-of-the-art results on challenging datasets (as shown in the paper with Rig-XL/VRoid training).
+* **Efficient Tokenization:** Uses Skeleton Tree Tokenization for compact representation and efficient processing.
+* **Human-in-the-Loop Ready:** Designed to potentially support iterative refinement workflows.
+
+## 🚨 Current Release Status & Roadmap 🚨
+
+We are open-sourcing UniRig progressively. Please note the current status:
+
+**Available Now (Initial Release):**
+* ✅ **Code:** Implementation for skeleton and skinning prediction.
+* ✅ **Model:** Skeleton & Skinning Prediction checkpoint trained on [**Articulation-XL2.0**](https://huggingface.co/datasets/Seed3D/Articulation-XL2.0). Available on [Hugging Face](https://huggingface.co/VAST-AI/UniRig).
+
+**Planned Future Releases:**
+* ⏳ Release of the **Rig-XL** and **VRoid** datasets used in the paper.
+* ⏳ Full UniRig model checkpoints (Skeleton + Skinning) trained on Rig-XL/VRoid, replicating the paper's main results.
+
+We appreciate your patience as we prepare these components for release. Follow [VAST-AI-Research](https://github.com/orgs/VAST-AI-Research) announcements for updates!
+
+## Installation
+
+1. **Prerequisites:**
+ * Python 3.11
+ * PyTorch (tested with version >=2.3.1)
+
+2. **Clone the repository:**
+ ```bash
+ git clone https://github.com/VAST-AI-Research/UniRig
+ cd UniRig
+ ```
+
+3. **Set up a virtual environment (recommended):**
+ ```bash
+ conda create -n UniRig python=3.11
+ conda activate UniRig
+ ```
+
+4. **Install dependencies:**
+ ```bash
+ python -m pip install torch torchvision
+ python -m pip install -r requirements.txt
+   python -m pip install spconv-{your-cuda-version}  # e.g. spconv-cu118 for CUDA 11.8
+ python -m pip install torch_scatter torch_cluster -f https://data.pyg.org/whl/torch-{your-torch-version}+{your-cuda-version}.html --no-cache-dir
+ python -m pip install numpy==1.26.4
+ ```
+
+5. **Download Model Checkpoint:**
+   The pretrained skeleton and skinning checkpoints are hosted on Hugging Face and are typically downloaded automatically by the provided scripts.
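+
+   If you prefer to fetch the checkpoint ahead of time, a minimal sketch using the `huggingface-cli` tool from the `huggingface_hub` package (the repo id comes from the model page above; files land in your local Hugging Face cache):
+   ```bash
+   # Sketch: pre-download the UniRig checkpoint into the local HF cache.
+   huggingface-cli download VAST-AI/UniRig
+   ```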
+
+6. **(Optional, for importing/exporting .vrm) Install the blender addon:**
+   The blender addon is modified from [VRM-Addon-for-Blender](https://github.com/saturday06/VRM-Addon-for-Blender).
+
+ Make sure you are in the root directory of the project, then:
+ ```bash
+ python -c "import bpy, os; bpy.ops.preferences.addon_install(filepath=os.path.abspath('blender/add-on-vrm-v2.20.77_modified.zip'))"
+ ```
+
+## Usage
+
+### Skeleton Prediction (Available Now)
+
+Generate a skeleton for your 3D model using our pre-trained model. The process automatically analyzes the geometry and predicts an appropriate skeletal structure.
+
+```bash
+# Process a single file
+bash launch/inference/generate_skeleton.sh --input examples/giraffe.glb --output results/giraffe_skeleton.fbx
+
+# Process multiple files in a directory
+bash launch/inference/generate_skeleton.sh --input_dir <input_directory> --output_dir <output_directory>
+
+# Try different skeleton variations by changing the random seed
+bash launch/inference/generate_skeleton.sh --input examples/giraffe.glb --output results/giraffe_skeleton.fbx --seed 42
+```
+
+Supported input formats: `.obj`, `.fbx`, `.glb`, and `.vrm`
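+
+For reference, `generate_skeleton.sh` ultimately drives `run.py`; a direct invocation is sketched below, assuming the mesh has already been extracted into the intermediate npz directory (default `tmp`, see `launch/inference/extract.sh`) and that these flags map one-to-one onto the wrapper's:
+
+```bash
+# Sketch: invoke the skeleton task directly with run.py.
+python run.py \
+    --task=configs/task/quick_inference_skeleton_articulationxl_ar_256.yaml \
+    --input=examples/giraffe.glb \
+    --output=results/giraffe_skeleton.fbx \
+    --seed=42
+```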
+
+### Skinning Weight Prediction (Available Now)
+```bash
+# Skin a single file
+bash launch/inference/generate_skin.sh --input examples/skeleton/giraffe.fbx --output results/giraffe_skin.fbx
+
+# Process multiple files in a directory
+bash launch/inference/generate_skin.sh --input_dir <input_directory> --output_dir <output_directory>
+```
+
+Note that the command above uses an **edited version** of the skeleton from the skeleton-prediction phase. Results may degrade significantly if the skeleton is inaccurate, for example when tail or wing bones are missing, so it is recommended to refine the skeleton before skinning.
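+
+The skinning task is chained to the skeleton phase through the intermediate npz file: `configs/task/quick_inference_unirig_skin.yaml` sets `data_name: predict_skeleton.npz`, which `run.py` also exposes as `--data_name`. A direct invocation sketch, under the same assumptions as the skeleton sketch above:
+
+```bash
+# Sketch: invoke the skin task directly, consuming the npz written by the
+# skeleton phase.
+python run.py \
+    --task=configs/task/quick_inference_unirig_skin.yaml \
+    --input=examples/skeleton/giraffe.fbx \
+    --output=results/giraffe_skin.fbx \
+    --data_name=predict_skeleton.npz
+```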
+
+### Merge the Predicted Results
+
+Combine the predicted skeleton with your original 3D model to create a fully rigged asset:
+
+```bash
+# Merge skeleton from skeleton prediction
+bash launch/inference/merge.sh --source results/giraffe_skeleton.fbx --target examples/giraffe.glb --output results/giraffe_rigged.glb
+
+# Or merge skin from skin prediction
+bash launch/inference/merge.sh --source results/giraffe_skin.fbx --target examples/giraffe.glb --output results/giraffe_rigged.glb
+```
+
+## Models
+
+Available models are hosted on Hugging Face: https://huggingface.co/VAST-AI/UniRig
+
+## System Requirements
+
+- CUDA-enabled GPU with at least 8GB VRAM
+
+## Citation
+
+```
+@article{zhang2025unirig,
+ title={One Model to Rig Them All: Diverse Skeleton Rigging with UniRig},
+ author={Zhang, Jia-Peng and Pu, Cheng-Feng and Guo, Meng-Hao and Cao, Yan-Pei and Hu, Shi-Min},
+ journal={arXiv preprint arXiv:2504.12451},
+ year={2025}
+}
+```
+
+## Acknowledgements
+
+We would like to thank the following open-source projects and research works:
+
+- [OPT](https://huggingface.co/facebook/opt-350m) for model architecture
+- [3DShape2VecSet](https://github.com/1zb/3DShape2VecSet) for 3D shape representation
+- [SAMPart3D](https://github.com/Pointcept/SAMPart3D) and [Michelangelo](https://github.com/NeuralCarver/Michelangelo/) for shape encoder implementation
+- [Articulation-XL2.0](https://huggingface.co/datasets/Seed3D/Articulation-XL2.0) for a curated dataset
+
+We are grateful to the broader research community for their open exploration and contributions to the field of 3D generation.
diff --git a/UniRig/assets/doc/devil.gif b/UniRig/assets/doc/devil.gif
new file mode 100644
index 0000000000000000000000000000000000000000..181a003e5b4a3b61e88bdb9dd672d8e332e20a79
--- /dev/null
+++ b/UniRig/assets/doc/devil.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1a6d00b42bd25d6708df9ca5b69900865eb50cf2bfe3cc72d64d87873cd99749
+size 1831961
diff --git a/UniRig/assets/doc/dragon.gif b/UniRig/assets/doc/dragon.gif
new file mode 100644
index 0000000000000000000000000000000000000000..4dc0ac2fe070a3e8d9b3e139dffcbb52233300c7
--- /dev/null
+++ b/UniRig/assets/doc/dragon.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1459fb925b79b710d9496be5105eed48539942a47763029ae0cc855684c9a0e9
+size 1922869
diff --git a/UniRig/assets/doc/rabbit.gif b/UniRig/assets/doc/rabbit.gif
new file mode 100644
index 0000000000000000000000000000000000000000..1322b653d3a1ddbaea58d3bea9621eef486512f2
--- /dev/null
+++ b/UniRig/assets/doc/rabbit.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:124cad359767a48a4f5cf84103c5400ceb2815728c7f4c7bab7d57d0944d93a3
+size 732336
diff --git a/UniRig/assets/doc/unirig_teaser.png b/UniRig/assets/doc/unirig_teaser.png
new file mode 100644
index 0000000000000000000000000000000000000000..7cb8af8d8b071643c625f54f0a799d6834a4d0cf
--- /dev/null
+++ b/UniRig/assets/doc/unirig_teaser.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:44b056c9355386b872de584b734c57c823c7e73be7d20257e01e15861204247f
+size 12417252
diff --git a/UniRig/blender/add-on-vrm-v2.20.77_modified.zip b/UniRig/blender/add-on-vrm-v2.20.77_modified.zip
new file mode 100644
index 0000000000000000000000000000000000000000..76de53479b744097e69c3a0aa2070821f6c43ee5
--- /dev/null
+++ b/UniRig/blender/add-on-vrm-v2.20.77_modified.zip
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8fe1e8b7e31ec602d9b32db33c533b03d68df747837bfad3479e1057bc9937c5
+size 1331571
diff --git a/UniRig/configs/data/quick_inference.yaml b/UniRig/configs/data/quick_inference.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..25f988646a99d0cdf97a72cb1ae0a7e66d98caa6
--- /dev/null
+++ b/UniRig/configs/data/quick_inference.yaml
@@ -0,0 +1,16 @@
+input_dataset_dir: &input_dataset_dir ./dataset_inference
+output_dataset_dir: &output_dataset_dir ./dataset_inference_clean
+
+predict_dataset_config:
+ shuffle: False
+ batch_size: 1
+ num_workers: 1
+ pin_memory: False
+ persistent_workers: False
+ datapath_config:
+ input_dataset_dir: *output_dataset_dir
+ use_prob: False
+ data_path:
+ inference: [
+ [./dataset_inference_clean/inference_datalist.txt, 1.0],
+ ]
\ No newline at end of file
diff --git a/UniRig/configs/model/unirig_ar_350m_1024_81920_float32.yaml b/UniRig/configs/model/unirig_ar_350m_1024_81920_float32.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7b5a83e17433bd6c9b655cd72d3c05f935a3c48a
--- /dev/null
+++ b/UniRig/configs/model/unirig_ar_350m_1024_81920_float32.yaml
@@ -0,0 +1,32 @@
+__target__: unirig_ar
+llm:
+ pretrained_model_name_or_path: facebook/opt-350m
+ n_positions: 3076
+ max_position_embeddings: 3076
+ hidden_size: 1024
+ word_embed_proj_dim: 1024
+ do_layer_norm_before: True
+ _attn_implementation: flash_attention_2
+
+mesh_encoder:
+ __target__: michelangelo_encoder
+ pretrained_path: ~
+ freeze_encoder: False
+ device: cpu
+ dtype: float32
+ num_latents: 512
+ embed_dim: 64
+ point_feats: 3
+ num_freqs: 8
+ include_pi: False
+ heads: 8
+ width: 512
+ num_encoder_layers: 16
+ use_ln_post: True
+ init_scale: 0.25
+ qkv_bias: False
+ use_checkpoint: False
+ flash: True
+ supervision_type: sdf
+ query_method: False
+ token_num: 1024
\ No newline at end of file
diff --git a/UniRig/configs/model/unirig_skin.yaml b/UniRig/configs/model/unirig_skin.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e13e1bb6e7388538af0e9b0e975ee4ad72f3ca3e
--- /dev/null
+++ b/UniRig/configs/model/unirig_skin.yaml
@@ -0,0 +1,52 @@
+__target__: unirig_skin
+
+num_train_vertex: 512 # increase this for faster speed at the cost of memory
+num_heads: 16
+feat_dim: 768
+grid_size: 0.005
+mlp_dim: 512
+num_bone_attn: 8
+num_mesh_bone_attn: 16
+bone_embed_dim: 1024
+voxel_mask: 3.0
+
+mesh_encoder:
+ # vertex groups are handled in model
+ __target__: ptv3obj
+ pretrained_path: ~
+ freeze_encoder: False
+ in_channels: 9
+ cls_mode: False
+ shuffle_orders: True
+ drop_path: 0.0
+ upcast_attention: False
+ upcast_softmax: False
+ enc_depths: [3, 3, 3, 6, 16]
+ enc_channels: [32, 64, 128, 256, 384]
+ enc_num_head: [2, 4, 8, 16, 24]
+ enable_qknorm: True
+ layer_norm: False
+ res_linear: True
+
+global_encoder:
+ __target__: michelangelo_encoder
+ pretrained_path: ~
+ freeze_encoder: False
+ device: cpu
+ dtype: float32
+ num_latents: 512
+ embed_dim: 64
+ point_feats: 3
+ num_freqs: 8
+ include_pi: False
+ heads: 8
+ width: 512
+ num_encoder_layers: 16
+ use_ln_post: True
+ init_scale: 0.25
+ qkv_bias: False
+ use_checkpoint: False
+ flash: True
+ supervision_type: sdf
+ query_method: False
+ token_num: 1024
\ No newline at end of file
diff --git a/UniRig/configs/skeleton/mixamo.yaml b/UniRig/configs/skeleton/mixamo.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6271e7b9f59eec59ae08ce77138d1a054be382c8
--- /dev/null
+++ b/UniRig/configs/skeleton/mixamo.yaml
@@ -0,0 +1,59 @@
+parts_order: [body, hand]
+
+parts:
+ body: [
+ mixamorig:Hips,
+ mixamorig:Spine,
+ mixamorig:Spine1,
+ mixamorig:Spine2,
+ mixamorig:Neck,
+ mixamorig:Head,
+ mixamorig:LeftShoulder,
+ mixamorig:LeftArm,
+ mixamorig:LeftForeArm,
+ mixamorig:LeftHand,
+ mixamorig:RightShoulder,
+ mixamorig:RightArm,
+ mixamorig:RightForeArm,
+ mixamorig:RightHand,
+ mixamorig:LeftUpLeg,
+ mixamorig:LeftLeg,
+ mixamorig:LeftFoot,
+ mixamorig:LeftToeBase,
+ mixamorig:RightUpLeg,
+ mixamorig:RightLeg,
+ mixamorig:RightFoot,
+ mixamorig:RightToeBase,
+ ]
+ hand: [
+ mixamorig:LeftHandThumb1,
+ mixamorig:LeftHandThumb2,
+ mixamorig:LeftHandThumb3,
+ mixamorig:LeftHandIndex1,
+ mixamorig:LeftHandIndex2,
+ mixamorig:LeftHandIndex3,
+ mixamorig:LeftHandMiddle1,
+ mixamorig:LeftHandMiddle2,
+ mixamorig:LeftHandMiddle3,
+ mixamorig:LeftHandRing1,
+ mixamorig:LeftHandRing2,
+ mixamorig:LeftHandRing3,
+ mixamorig:LeftHandPinky1,
+ mixamorig:LeftHandPinky2,
+ mixamorig:LeftHandPinky3,
+ mixamorig:RightHandIndex1,
+ mixamorig:RightHandIndex2,
+ mixamorig:RightHandIndex3,
+ mixamorig:RightHandThumb1,
+ mixamorig:RightHandThumb2,
+ mixamorig:RightHandThumb3,
+ mixamorig:RightHandMiddle1,
+ mixamorig:RightHandMiddle2,
+ mixamorig:RightHandMiddle3,
+ mixamorig:RightHandRing1,
+ mixamorig:RightHandRing2,
+ mixamorig:RightHandRing3,
+ mixamorig:RightHandPinky1,
+ mixamorig:RightHandPinky2,
+ mixamorig:RightHandPinky3,
+ ]
\ No newline at end of file
diff --git a/UniRig/configs/skeleton/vroid.yaml b/UniRig/configs/skeleton/vroid.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1d6f066e78686d4f3e7d9f50dae8b9ab73922aa9
--- /dev/null
+++ b/UniRig/configs/skeleton/vroid.yaml
@@ -0,0 +1,59 @@
+parts_order: [body, hand]
+
+parts:
+ body: [
+ J_Bip_C_Hips,
+ J_Bip_C_Spine,
+ J_Bip_C_Chest,
+ J_Bip_C_UpperChest,
+ J_Bip_C_Neck,
+ J_Bip_C_Head,
+ J_Bip_L_Shoulder,
+ J_Bip_L_UpperArm,
+ J_Bip_L_LowerArm,
+ J_Bip_L_Hand,
+ J_Bip_R_Shoulder,
+ J_Bip_R_UpperArm,
+ J_Bip_R_LowerArm,
+ J_Bip_R_Hand,
+ J_Bip_L_UpperLeg,
+ J_Bip_L_LowerLeg,
+ J_Bip_L_Foot,
+ J_Bip_L_ToeBase,
+ J_Bip_R_UpperLeg,
+ J_Bip_R_LowerLeg,
+ J_Bip_R_Foot,
+ J_Bip_R_ToeBase,
+ ]
+ hand: [
+ J_Bip_L_Thumb1,
+ J_Bip_L_Thumb2,
+ J_Bip_L_Thumb3,
+ J_Bip_L_Index1,
+ J_Bip_L_Index2,
+ J_Bip_L_Index3,
+ J_Bip_L_Middle1,
+ J_Bip_L_Middle2,
+ J_Bip_L_Middle3,
+ J_Bip_L_Ring1,
+ J_Bip_L_Ring2,
+ J_Bip_L_Ring3,
+ J_Bip_L_Little1,
+ J_Bip_L_Little2,
+ J_Bip_L_Little3,
+ J_Bip_R_Index1,
+ J_Bip_R_Index2,
+ J_Bip_R_Index3,
+ J_Bip_R_Thumb1,
+ J_Bip_R_Thumb2,
+ J_Bip_R_Thumb3,
+ J_Bip_R_Middle1,
+ J_Bip_R_Middle2,
+ J_Bip_R_Middle3,
+ J_Bip_R_Ring1,
+ J_Bip_R_Ring2,
+ J_Bip_R_Ring3,
+ J_Bip_R_Little1,
+ J_Bip_R_Little2,
+ J_Bip_R_Little3,
+ ]
\ No newline at end of file
diff --git a/UniRig/configs/system/ar_inference_articulationxl.yaml b/UniRig/configs/system/ar_inference_articulationxl.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..87fab782cf9dd5a6970e579d1f170c0bcc67b9ba
--- /dev/null
+++ b/UniRig/configs/system/ar_inference_articulationxl.yaml
@@ -0,0 +1,14 @@
+__target__: ar
+val_interval: 1
+generate_kwargs:
+ max_new_tokens: 2048
+ num_return_sequences: 1
+ num_beams: 15
+ do_sample: True
+ top_k: 5
+ top_p: 0.95
+ repetition_penalty: 3.0
+ temperature: 1.5 # must be a float
+ no_cls: False
+ assign_cls: articulationxl
+ use_dir_cls: False
\ No newline at end of file
diff --git a/UniRig/configs/system/skin.yaml b/UniRig/configs/system/skin.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..782de0e970ece5cfb4fbdaf3078eb2af60ab70cb
--- /dev/null
+++ b/UniRig/configs/system/skin.yaml
@@ -0,0 +1,5 @@
+__target__: skin
+val_interval: 1
+val_start_from: 1
+output_path: tmp_skin
+record_res: True
\ No newline at end of file
diff --git a/UniRig/configs/task/quick_inference_skeleton_articulationxl_ar_256.yaml b/UniRig/configs/task/quick_inference_skeleton_articulationxl_ar_256.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7477b9c107cb95070510745a8544e894809b57e7
--- /dev/null
+++ b/UniRig/configs/task/quick_inference_skeleton_articulationxl_ar_256.yaml
@@ -0,0 +1,30 @@
+mode: predict
+debug: False
+experiment_name: quick_inference_skeleton_articulationxl_ar_256
+resume_from_checkpoint: experiments/skeleton/articulation-xl_quantization_256/model.ckpt
+
+components:
+ data: quick_inference
+ tokenizer: tokenizer_parts_articulationxl_256
+ transform: inference_ar_transform
+ model: unirig_ar_350m_1024_81920_float32
+ system: ar_inference_articulationxl
+ data_name: raw_data.npz
+
+writer:
+ __target__: ar
+ output_dir: ~ # export results into the same input folder
+ add_num: False
+ repeat: 1
+ export_npz: predict_skeleton
+ export_obj: skeleton
+ export_fbx: skeleton
+ # export_pc: pc
+
+trainer:
+ max_epochs: 1
+ num_nodes: 1
+ devices: 1
+ precision: bf16-mixed
+ accelerator: gpu
+ strategy: auto
\ No newline at end of file
diff --git a/UniRig/configs/task/quick_inference_unirig_skin.yaml b/UniRig/configs/task/quick_inference_unirig_skin.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c0797ef545edf5995610f2b8698340222b709f10
--- /dev/null
+++ b/UniRig/configs/task/quick_inference_unirig_skin.yaml
@@ -0,0 +1,28 @@
+mode: predict
+debug: False
+experiment_name: quick_inference_skin
+resume_from_checkpoint: experiments/skin/articulation-xl/model.ckpt
+
+components:
+ data: quick_inference
+ transform: inference_skin_transform
+ model: unirig_skin
+ system: skin
+ data_name: predict_skeleton.npz # capture data from ar phase
+
+writer:
+ __target__: skin
+ output_dir: results
+ add_num: False
+ repeat: 1
+ save_name: predict
+ export_npz: predict_skin # this must be specified if textured results are required
+ export_fbx: result_fbx
+
+trainer:
+ num_nodes: 1
+ devices: 1
+ precision: bf16-mixed
+ accelerator: gpu
+ strategy: auto
+ inference_mode: True
diff --git a/UniRig/configs/tokenizer/tokenizer_parts_articulationxl_256.yaml b/UniRig/configs/tokenizer/tokenizer_parts_articulationxl_256.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..570b810058053fdb9212977518310e1fae4719e0
--- /dev/null
+++ b/UniRig/configs/tokenizer/tokenizer_parts_articulationxl_256.yaml
@@ -0,0 +1,14 @@
+method: tokenizer_part
+num_discrete: 256
+continuous_range: [-1, 1]
+cls_token_id:
+ vroid: 0
+ mixamo: 1 # this is currently untrained, do not use it
+ articulationxl: 2
+parts_token_id:
+ body: 0
+ hand: 1
+order_config:
+ skeleton_path:
+ vroid: ./configs/skeleton/vroid.yaml
+ mixamo: ./configs/skeleton/mixamo.yaml
\ No newline at end of file
diff --git a/UniRig/configs/transform/inference_ar_transform.yaml b/UniRig/configs/transform/inference_ar_transform.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e2ca90f588c5325c5e8a2c42868d7980910131be
--- /dev/null
+++ b/UniRig/configs/transform/inference_ar_transform.yaml
@@ -0,0 +1,30 @@
+sampler_config: &sampler_config
+ method: mix
+ num_samples: 65536
+ vertex_samples: 8192
+
+tail_config: &tail_config
+ copy_joint_to_tail: False # Be careful ! If tail is important, keep it False !!!
+ connect_tail_to_unique_son: True
+
+order_config: &order_config
+ skeleton_path:
+ vroid: ./configs/skeleton/vroid.yaml
+ mixamo: ./configs/skeleton/mixamo.yaml
+
+vertex_group_config: &vertex_group_config
+
+validate_transform_config: &validate_transform_config
+ augment_config:
+ augment_affine_config:
+ normalize_into: [-1.0, 1.0]
+ random_scale_p: 0.0
+ random_scale: [1.0, 1.0]
+ random_shift_p: 0.0
+ random_shift: [0.0, 0.0]
+ tail_config: *tail_config
+ order_config: *order_config
+ vertex_group_config: *vertex_group_config
+ sampler_config: *sampler_config
+
+predict_transform_config: *validate_transform_config
\ No newline at end of file
diff --git a/UniRig/configs/transform/inference_skin_transform.yaml b/UniRig/configs/transform/inference_skin_transform.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..04284fd6d7fe312a76a0a2f55e31cc5b52e6a899
--- /dev/null
+++ b/UniRig/configs/transform/inference_skin_transform.yaml
@@ -0,0 +1,32 @@
+sampler_config: &sampler_config
+ method: mix
+ num_samples: 32768
+ vertex_samples: 8192
+
+tail_config: &tail_config
+ copy_joint_to_tail: False # Be careful ! If tail is important, keep it False !!!
+ connect_tail_to_unique_son: True
+
+order_config: &order_config
+ skeleton_path:
+ vroid: ./configs/skeleton/vroid.yaml
+ mixamo: ./configs/skeleton/mixamo.yaml
+
+predict_transform_config:
+ augment_config:
+ augment_affine_config:
+ normalize_into: [-1.0, 1.0]
+ tail_config: *tail_config
+ order_config: *order_config
+ vertex_group_config:
+ names: ['voxel_skin']
+ kwargs:
+ voxel_skin:
+ grid: 196 # increase this for better results
+ alpha: 0.5
+ link_dis: 0.00001
+ grid_query: 7
+ vertex_query: 1
+ grid_weight: 3.0
+ # mode: exp
+ sampler_config: *sampler_config
\ No newline at end of file
diff --git a/UniRig/examples/bird.glb b/UniRig/examples/bird.glb
new file mode 100644
index 0000000000000000000000000000000000000000..c82417b3c66e411ac5b7b26926b5a2329167df60
--- /dev/null
+++ b/UniRig/examples/bird.glb
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bb59726f598ab4a4e4c431b9789317f5b2f4252a6fb57364d929f5e1ddd7b5bb
+size 8032388
diff --git a/UniRig/examples/giraffe.glb b/UniRig/examples/giraffe.glb
new file mode 100644
index 0000000000000000000000000000000000000000..9cec09a7ea93692b61033b9de5f3f7bcab4c790c
--- /dev/null
+++ b/UniRig/examples/giraffe.glb
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9a947cae00b169345802c08885c1e54313c6ce93c885bacf8e37e8a1f18f9e3b
+size 6310044
diff --git a/UniRig/examples/skeleton/bird.fbx b/UniRig/examples/skeleton/bird.fbx
new file mode 100644
index 0000000000000000000000000000000000000000..7d075b3cd1093150cdf13a0b8871390e36d6f203
--- /dev/null
+++ b/UniRig/examples/skeleton/bird.fbx
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:885d432850506ab673d509ae2481067cc623452d5910090e1e15f323b9f83fa2
+size 401084
diff --git a/UniRig/examples/skeleton/giraffe.fbx b/UniRig/examples/skeleton/giraffe.fbx
new file mode 100644
index 0000000000000000000000000000000000000000..805b4bc0267f94425f2d8c472ce74f09cf44e048
--- /dev/null
+++ b/UniRig/examples/skeleton/giraffe.fbx
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:73673a15c8103fbf9a6b39762768ae48e7c9404eab4850a60e4863ee400336fd
+size 759180
diff --git a/UniRig/examples/skeleton/tira.fbx b/UniRig/examples/skeleton/tira.fbx
new file mode 100644
index 0000000000000000000000000000000000000000..61ad187c7ed805ee6ac521bdf2ec6c53cbe94235
--- /dev/null
+++ b/UniRig/examples/skeleton/tira.fbx
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:57ed6969266d642da2d8a579059462c7cd4a36c9bd7a8415236dcbea36607fee
+size 1694668
diff --git a/UniRig/examples/skeleton/tripo_carrot.fbx b/UniRig/examples/skeleton/tripo_carrot.fbx
new file mode 100644
index 0000000000000000000000000000000000000000..66e22239154887c8a59dfa07d5ef7708b467f5cb
--- /dev/null
+++ b/UniRig/examples/skeleton/tripo_carrot.fbx
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:50e08d70c7bdfa96eb842aed18564d8486a4622a474dbc0e0ef9304af1d4c6d3
+size 1879420
diff --git a/UniRig/examples/tira.glb b/UniRig/examples/tira.glb
new file mode 100644
index 0000000000000000000000000000000000000000..961966ed74868fcaa26aa397cf31dfc82efa7af0
--- /dev/null
+++ b/UniRig/examples/tira.glb
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8e5a282d0a99c61a8d2439496057b15b0c6ea02643c16da2dcbec85571157799
+size 32346060
diff --git a/UniRig/examples/tripo_carrot.glb b/UniRig/examples/tripo_carrot.glb
new file mode 100644
index 0000000000000000000000000000000000000000..9e334e2549a3576d2996fc468d7a8a337a1eca03
--- /dev/null
+++ b/UniRig/examples/tripo_carrot.glb
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3c00b7e6ff1e71a019a02128a5b37e8d2db259a402313ad9b4eaab4f183ae40b
+size 9511824
diff --git a/UniRig/launch/inference/extract.sh b/UniRig/launch/inference/extract.sh
new file mode 100644
index 0000000000000000000000000000000000000000..42bd86d77504943b42acc2d4db5be30dd6979149
--- /dev/null
+++ b/UniRig/launch/inference/extract.sh
@@ -0,0 +1,60 @@
+# extract mesh
+config="configs/data/quick_inference.yaml"
+require_suffix="obj,fbx,FBX,dae,glb,gltf,vrm"
+num_runs=1
+force_override="false"
+faces_target_count=50000
+
+while [[ "$#" -gt 0 ]]; do
+ case $1 in
+ --config) config="$2"; shift ;;
+ --require_suffix) require_suffix="$2"; shift ;;
+ --num_runs) num_runs="$2"; shift ;;
+ --force_override) force_override="$2"; shift ;;
+ --faces_target_count) faces_target_count="$2"; shift ;;
+ --time) time="$2"; shift ;;
+ --input) input="$2"; shift ;;
+ --input_dir) input_dir="$2"; shift ;;
+ --output_dir) output_dir="$2"; shift ;;
+ *) echo "Unknown parameter: $1"; exit 1 ;;
+ esac
+ shift
+done
+
+# ensure psutil is installed for memory management
+pip install psutil --quiet
+if [ $? -ne 0 ]; then
+ echo "Warning: Failed to install psutil. Memory management may not work properly."
+fi
+
+# set the time for all processes to use
+time=$(date "+%Y_%m_%d_%H_%M_%S")
+
+for (( i=0; i<num_runs; i++ ))
+do
+    python -m src.data.extract \
+        --config="$config" \
+        --require_suffix="$require_suffix" \
+        --force_override="$force_override" \
+        --num_runs="$num_runs" \
+        --id="$i" \
+        --time="$time" \
+        --input="$input" \
+        --input_dir="$input_dir" \
+        --output_dir="$output_dir" \
+        --faces_target_count="$faces_target_count"
+done
diff --git a/UniRig/run.py b/UniRig/run.py
new file mode 100644
--- /dev/null
+++ b/UniRig/run.py
+import argparse
+import os
+import yaml
+
+import lightning as L
+import torch
+from box import Box
+from lightning.pytorch.callbacks import ModelCheckpoint
+
+from src.data.datapath import Datapath
+from src.data.dataset import UniRigDatasetModule, DatasetConfig
+from src.data.extract import get_files
+from src.data.transform import TransformConfig
+from src.inference.download import download
+from src.model.parse import get_model
+from src.system.parse import get_system, get_writer
+from src.tokenizer.parse import get_tokenizer
+from src.tokenizer.spec import TokenizerConfig
+
+def load(task, path) -> Box:
+ if path.endswith('.yaml'):
+ path = path.removesuffix('.yaml')
+ path += '.yaml'
+ print(f"\033[92mload {task} config: {path}\033[0m")
+ return Box(yaml.safe_load(open(path, 'r')))
+
+def nullable_string(val):
+ if not val:
+ return None
+ return val
+
+if __name__ == "__main__":
+ torch.set_float32_matmul_precision('high')
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--task", type=str, required=True)
+ parser.add_argument("--seed", type=int, required=False, default=123,
+ help="random seed")
+ parser.add_argument("--input", type=nullable_string, required=False, default=None,
+                        help="a single input file, or multiple files separated by commas")
+ parser.add_argument("--input_dir", type=nullable_string, required=False, default=None,
+ help="input directory")
+ parser.add_argument("--output", type=nullable_string, required=False, default=None,
+ help="filename for a single output")
+ parser.add_argument("--output_dir", type=nullable_string, required=False, default=None,
+ help="output directory")
+ parser.add_argument("--npz_dir", type=nullable_string, required=False, default='tmp',
+ help="intermediate npz directory")
+ parser.add_argument("--cls", type=nullable_string, required=False, default=None,
+ help="class name")
+ parser.add_argument("--data_name", type=nullable_string, required=False, default=None,
+ help="npz filename from skeleton phase")
+ args = parser.parse_args()
+
+ L.seed_everything(args.seed, workers=True)
+
+ task = load('task', args.task)
+ mode = task.mode
+ assert mode in ['predict']
+
+ if args.input is not None or args.input_dir is not None:
+ assert args.output_dir is not None or args.output is not None, 'output or output_dir must be specified'
+ assert args.npz_dir is not None, 'npz_dir must be specified'
+ files = get_files(
+ data_name=task.components.data_name,
+ inputs=args.input,
+ input_dataset_dir=args.input_dir,
+ output_dataset_dir=args.npz_dir,
+ force_override=True,
+ warning=False,
+ )
+ files = [f[1] for f in files]
+ if len(files) > 1 and args.output is not None:
+            print("\033[92mwarning: output is specified, but multiple files are detected; all results will share the same output name\033[0m")
+ datapath = Datapath(files=files, cls=args.cls)
+ else:
+ datapath = None
+
+ data_config = load('data', os.path.join('configs/data', task.components.data))
+ transform_config = load('transform', os.path.join('configs/transform', task.components.transform))
+
+ # get tokenizer
+ tokenizer_config = task.components.get('tokenizer', None)
+ if tokenizer_config is not None:
+ tokenizer_config = load('tokenizer', os.path.join('configs/tokenizer', task.components.tokenizer))
+ tokenizer_config = TokenizerConfig.parse(config=tokenizer_config)
+
+ # get data name
+ data_name = task.components.get('data_name', 'raw_data.npz')
+ if args.data_name is not None:
+ data_name = args.data_name
+
+ # get predict dataset
+ predict_dataset_config = data_config.get('predict_dataset_config', None)
+ if predict_dataset_config is not None:
+ predict_dataset_config = DatasetConfig.parse(config=predict_dataset_config).split_by_cls()
+
+ # get predict transform
+ predict_transform_config = transform_config.get('predict_transform_config', None)
+ if predict_transform_config is not None:
+ predict_transform_config = TransformConfig.parse(config=predict_transform_config)
+
+ # get model
+ model_config = task.components.get('model', None)
+ if model_config is not None:
+ model_config = load('model', os.path.join('configs/model', model_config))
+ if tokenizer_config is not None:
+ tokenizer = get_tokenizer(config=tokenizer_config)
+ else:
+ tokenizer = None
+ model = get_model(tokenizer=tokenizer, **model_config)
+ else:
+ model = None
+
+ # set data
+ data = UniRigDatasetModule(
+ process_fn=None if model is None else model._process_fn,
+ predict_dataset_config=predict_dataset_config,
+ predict_transform_config=predict_transform_config,
+ tokenizer_config=tokenizer_config,
+ debug=False,
+ data_name=data_name,
+ datapath=datapath,
+ cls=args.cls,
+ )
+
+ # add call backs
+ callbacks = []
+
+ ## get checkpoint callback
+ checkpoint_config = task.get('checkpoint', None)
+ if checkpoint_config is not None:
+ checkpoint_config['dirpath'] = os.path.join('experiments', task.experiment_name)
+ callbacks.append(ModelCheckpoint(**checkpoint_config))
+
+ ## get writer callback
+ writer_config = task.get('writer', None)
+ if writer_config is not None:
+ assert predict_transform_config is not None, 'missing predict_transform_config in transform'
+ if args.output_dir is not None or args.output is not None:
+ if args.output is not None:
+ assert args.output.endswith('.fbx'), 'output must be .fbx'
+ writer_config['npz_dir'] = args.npz_dir
+ writer_config['output_dir'] = args.output_dir
+ writer_config['output_name'] = args.output
+ writer_config['user_mode'] = True
+ callbacks.append(get_writer(**writer_config, order_config=predict_transform_config.order_config))
+
+ # get trainer
+ trainer_config = task.get('trainer', {})
+
+ # get system
+ system_config = task.components.get('system', None)
+ if system_config is not None:
+ system_config = load('system', os.path.join('configs/system', system_config))
+ system = get_system(
+ **system_config,
+ model=model,
+ steps_per_epoch=1,
+ )
+ else:
+ system = None
+
+ logger = None
+
+ # set ckpt path
+ resume_from_checkpoint = task.get('resume_from_checkpoint', None)
+ resume_from_checkpoint = download(resume_from_checkpoint)
+ trainer = L.Trainer(
+ callbacks=callbacks,
+ logger=logger,
+ **trainer_config,
+ )
+
+ if mode == 'predict':
+ assert resume_from_checkpoint is not None, 'expect resume_from_checkpoint in task'
+ trainer.predict(system, datamodule=data, ckpt_path=resume_from_checkpoint, return_predictions=False)
+ else:
+ assert 0
\ No newline at end of file
diff --git a/UniRig/src/data/__init__.py b/UniRig/src/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/UniRig/src/data/asset.py b/UniRig/src/data/asset.py
new file mode 100644
index 0000000000000000000000000000000000000000..24f56e470d6f7b58e6d42f4a55dd0b93a5cdf63f
--- /dev/null
+++ b/UniRig/src/data/asset.py
@@ -0,0 +1,433 @@
+from collections import defaultdict
+from dataclasses import dataclass
+import numpy as np
+from numpy import ndarray
+
+from typing import Dict, Union, List, Tuple, Any
+
+from .order import Order
+from .raw_data import RawData
+from .exporter import Exporter
+
+from ..tokenizer.spec import TokenizeInput
+from .utils import linear_blend_skinning
+
+import trimesh
+
+
+@dataclass
+class Asset(Exporter):
+ '''
+ Dataclass to handle data parsed from raw data.
+ '''
+
+ # data class
+ cls: str
+
+ # where is this asset from
+ path: str
+
+ # data file name
+ data_name: str
+
+ # vertices of the mesh, shape (N, 3), float32
+ vertices: ndarray
+
+ # normals of vertices, shape (N, 3), float32
+ vertex_normals: ndarray
+
+ # faces of mesh, shape (F, 3), face id starts from 0 to F-1, int64
+ faces: ndarray
+
+ # face normal of mesh, shape (F, 3), float32
+ face_normals: ndarray
+
+ # joints of bones, shape (J, 3), float32
+ joints: Union[ndarray, None]=None
+
+ # tails of joints, shape (J, 3), float32
+ tails: Union[ndarray, None]=None
+
+ # skinning of joints, shape (N, J), float32
+ skin: Union[ndarray, None]=None
+
+ # whether the joint has skin, bool
+ no_skin: Union[ndarray, None]=None
+
+ # vertex groups
+ vertex_groups: Union[Dict[str, ndarray], None]=None
+
+    # parents of joints, None represents no parent (a root joint)
+ # make sure parent[k] < k
+ parents: Union[List[Union[int, None]], None]=None
+
+ # names of joints
+ names: Union[List[str], None]=None
+
+ # sampled vertices, shape (N, 3)
+ sampled_vertices: Union[ndarray, None]=None
+
+ # sampled normals, shape (N, 3)
+ sampled_normals: Union[ndarray, None]=None
+
+ # sampled vertex groups, every vertex group should be (N, J)
+ sampled_vertex_groups: Union[Dict[str, ndarray], None]=None
+
+ # {id: part}, part==None -> a spring token
+ parts_bias: Union[Dict[int, Union[str, None]], None]=None
+
+ # local coordinate, shape (J, 4, 4)
+ matrix_local: Union[ndarray, None]=None
+
+ # pose matrix for skinning loss calculation, shape (J, 4, 4)
+ pose_matrix: Union[ndarray, None]=None
+
+    meta: Union[Dict[str, Any], None]=None
+
+ @property
+ def N(self):
+ '''
+ number of vertices
+ '''
+ return self.vertices.shape[0]
+
+ @property
+ def F(self):
+ '''
+ number of faces
+ '''
+ return self.faces.shape[0]
+
+ @property
+ def J(self):
+ '''
+ number of joints
+ '''
+ return self.joints.shape[0]
+
+ def get_matrix(self, matrix_basis: ndarray, matrix_local: Union[ndarray, None]=None):
+ '''
+        compute the global transform of every joint by composing the parent's
+        global matrix with the local offset and the per-joint pose matrix_basis
+        (forward kinematics)
+
+        matrix_basis: (J, 4, 4)
+ '''
+ if matrix_local is None:
+ assert self.joints is not None
+ matrix_local = self.matrix_local
+ if matrix_local is None:
+ matrix_local = np.zeros((self.J, 4, 4))
+ matrix_local[:, 0, 0] = 1.
+ matrix_local[:, 1, 1] = 1.
+ matrix_local[:, 2, 2] = 1.
+ matrix_local[:, 3, 3] = 1.
+ for i in range(self.J):
+ matrix_local[i, :3, 3] = self.joints[i]
+
+ matrix = np.zeros((self.J, 4, 4))
+ for i in range(self.J):
+ if i==0:
+ matrix[i] = matrix_local[i] @ matrix_basis[i]
+ else:
+ pid = self.parents[i]
+ matrix_parent = matrix[pid]
+ matrix_local_parent = matrix_local[pid]
+
+ matrix[i] = (
+ matrix_parent @
+ (np.linalg.inv(matrix_local_parent) @ matrix_local[i]) @
+ matrix_basis[i]
+ )
+ return matrix
+
+ def apply_matrix_basis(self, matrix_basis: ndarray):
+ '''
+ apply a pose to armature
+
+ matrix_basis: (J, 4, 4)
+ '''
+ matrix_local = self.matrix_local
+ if matrix_local is None:
+ matrix_local = np.zeros((self.J, 4, 4))
+ matrix_local[:, 0, 0] = 1.
+ matrix_local[:, 1, 1] = 1.
+ matrix_local[:, 2, 2] = 1.
+ matrix_local[:, 3, 3] = 1.
+ for i in range(self.J):
+ matrix_local[i, :3, 3] = self.joints[i].copy()
+
+ matrix = self.get_matrix(matrix_basis=matrix_basis, matrix_local=matrix_local)
+ self.joints = matrix[:, :3, 3].copy()
+ vertices = linear_blend_skinning(self.vertices, matrix_local, matrix, self.skin, pad=1, value=1.)
+ # update matrix_local
+ self.matrix_local = matrix.copy()
+
+ # change tails
+ if self.tails is not None:
+ t_skin = np.eye(self.J)
+ self.tails = linear_blend_skinning(self.tails, matrix_local, matrix, t_skin, pad=1, value=1.)
+ # in accordance with trimesh's normals
+ mesh = trimesh.Trimesh(vertices=vertices, faces=self.faces, process=False)
+ self.vertices = vertices
+ self.vertex_normals = mesh.vertex_normals.copy()
+ self.face_normals = mesh.face_normals.copy()
+
+ def set_order_by_names(self, new_names: List[str]):
+ assert len(new_names) == len(self.names)
+ name_to_id = {name: id for (id, name) in enumerate(self.names)}
+ new_name_to_id = {name: id for (id, name) in enumerate(new_names)}
+ perm = []
+ new_parents = []
+ for (new_id, name) in enumerate(new_names):
+ perm.append(name_to_id[name])
+ pid = self.parents[name_to_id[name]]
+ if new_id == 0:
+ assert pid is None, 'first bone is not root bone'
+ else:
+ pname = self.names[pid]
+ pid = new_name_to_id[pname]
+ assert pid < new_id, 'new order does not form a tree'
+ new_parents.append(pid)
+
+ if self.joints is not None:
+ self.joints = self.joints[perm]
+ self.parents = new_parents
+ if self.tails is not None:
+ self.tails = self.tails[perm]
+ if self.skin is not None:
+ self.skin = self.skin[:, perm]
+ if self.no_skin is not None:
+ self.no_skin = self.no_skin[perm]
+ if self.matrix_local is not None:
+ self.matrix_local = self.matrix_local[perm]
+ self.names = new_names
+
+ def set_order(self, order: Order):
+ if self.names is None or self.parents is None:
+ return
+ new_names, self.parts_bias = order.arrange_names(cls=self.cls, names=self.names, parents=self.parents)
+ self.set_order_by_names(new_names=new_names)
+
+ def collapse(self, keep: List[str]):
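+        # Union-find (DSU) over joint indices: every joint not listed in
+        # `keep` is merged into its nearest kept ancestor, and its skin
+        # weights are accumulated onto that ancestor below.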
+ dsu = [i for i in range(self.J)]
+
+ def find(x: int) -> int:
+ if dsu[x] == x:
+ return x
+ y = find(dsu[x])
+ dsu[x] = y
+ return y
+
+ def merge(x: int, y: int):
+ dsu[find(x)] = find(y)
+
+ if self.tails is not None:
+ new_tails = self.tails.copy()
+ else:
+ new_tails = None
+ if self.skin is not None:
+ new_skin = self.skin.copy()
+ else:
+ new_skin = None
+
+ if self.no_skin is not None:
+ new_no_skin = self.no_skin.copy()
+ else:
+ new_no_skin = None
+
+ if self.matrix_local is not None:
+ matrix_local = self.matrix_local.copy()
+ else:
+ matrix_local = None
+ new_names = []
+ new_parents = []
+ perm = []
+ new_name_to_id = {}
+ tot = 0
+ for (i, name) in enumerate(self.names):
+ if name in keep:
+ new_names.append(name)
+ new_name_to_id[name] = tot
+ tot += 1
+ perm.append(i)
+ pid = self.parents[i]
+ if pid is None:
+ new_parents.append(None)
+ else:
+ pid = find(pid)
+ new_parents.append(new_name_to_id[self.names[pid]])
+ continue
+ assert i != 0, 'cannot remove root'
+ id = find(i)
+ pid = find(self.parents[id])
+ # be careful !
+        # do not copy the tail here because we don't know which child to inherit from
+ if new_skin is not None:
+ new_skin[:, pid] += new_skin[:, id]
+ if new_no_skin is not None:
+ new_no_skin[pid] &= new_no_skin[id]
+ merge(id, pid)
+
+ if new_tails is not None:
+ new_tails = new_tails[perm]
+ if new_skin is not None:
+ new_skin = new_skin[:, perm]
+ if new_no_skin is not None:
+ new_no_skin = new_no_skin[perm]
+ if matrix_local is not None:
+ matrix_local = matrix_local[perm]
+
+ if self.joints is not None:
+ self.joints = self.joints[perm]
+ self.parents = new_parents
+ self.tails = new_tails
+ self.skin = new_skin
+ self.no_skin = new_no_skin
+ self.names = new_names
+ self.matrix_local = matrix_local
+
+ @staticmethod
+ def from_raw_data(
+ raw_data: RawData,
+ cls: str,
+ path: str,
+ data_name: str,
+ ) -> 'Asset':
+ '''
+        Return an asset initialized from raw data.
+ '''
+ return Asset(
+ cls=cls,
+ path=path,
+ data_name=data_name,
+ vertices=raw_data.vertices,
+ vertex_normals=raw_data.vertex_normals,
+ faces=raw_data.faces,
+ face_normals=raw_data.face_normals,
+ joints=raw_data.joints,
+ tails=raw_data.tails,
+ skin=raw_data.skin,
+ no_skin=raw_data.no_skin,
+ parents=raw_data.parents,
+ names=raw_data.names,
+ matrix_local=raw_data.matrix_local,
+ meta={},
+ )
+
+ def get_tokenize_input(self) -> TokenizeInput:
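+        # Each bone is encoded as the concatenated (parent joint, joint)
+        # coordinate pair; `branch` marks joints whose parent is not the
+        # previously emitted joint (the start of a new branch), and
+        # `is_leaf` marks joints without children.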
+ children = defaultdict(list)
+
+ for (id, p) in enumerate(self.parents):
+ if p is not None:
+ children[p].append(id)
+ bones = []
+ branch = []
+ is_leaf = []
+ last = None
+ for i in range(self.J):
+ is_leaf.append(len(children[i])==0)
+ if i == 0:
+ bones.append(np.concatenate([self.joints[i], self.joints[i]]))
+ branch.append(False)
+ else:
+ pid = self.parents[i]
+ bones.append(np.concatenate([self.joints[pid], self.joints[i]]))
+ branch.append(pid!=last)
+ last = i
+ bones = np.stack(bones)
+ branch = np.array(branch, dtype=bool)
+ is_leaf = np.array(is_leaf, dtype=bool)
+ return TokenizeInput(
+ bones=bones,
+ tails=self.tails,
+ branch=branch,
+ is_leaf=is_leaf,
+ no_skin=self.no_skin,
+ cls=self.cls,
+ parts_bias=self.parts_bias,
+ )
+
+ def export_pc(self, path: str, with_normal: bool=True, normal_size=0.01):
+ '''
+ export point cloud
+ '''
+ vertices = self.vertices
+ normals = self.vertex_normals
+ if self.sampled_vertices is not None:
+ vertices = self.sampled_vertices
+ normals = self.sampled_normals
+ if with_normal == False:
+ normals = None
+ self._export_pc(vertices=vertices, path=path, vertex_normals=normals, normal_size=normal_size)
+
+ def export_mesh(self, path: str):
+ '''
+ export mesh
+ '''
+ self._export_mesh(vertices=self.vertices, faces=self.faces, path=path)
+
+ def export_skeleton(self, path: str):
+ '''
+        export skeleton
+ '''
+ self._export_skeleton(joints=self.joints, parents=self.parents, path=path)
+
+ def export_skeleton_sequence(self, path: str):
+ '''
+        export skeleton sequence
+ '''
+ self._export_skeleton_sequence(joints=self.joints, parents=self.parents, path=path)
+
+ def export_fbx(
+ self,
+ path: str,
+ vertex_group_name: str,
+ extrude_size: float=0.03,
+ group_per_vertex: int=-1,
+ add_root: bool=False,
+ do_not_normalize: bool=False,
+ use_extrude_bone: bool=True,
+ use_connect_unique_child: bool=True,
+ extrude_from_parent: bool=True,
+ use_tail: bool=False,
+ use_origin: bool=False,
+ ):
+ '''
+        export the whole model with skinning
+ '''
+ self._export_fbx(
+ path=path,
+ vertices=self.vertices if use_origin else self.sampled_vertices,
+ joints=self.joints,
+ skin=self.sampled_vertex_groups[vertex_group_name],
+ parents=self.parents,
+ names=self.names,
+ faces=self.faces if use_origin else None,
+ extrude_size=extrude_size,
+ group_per_vertex=group_per_vertex,
+ add_root=add_root,
+ do_not_normalize=do_not_normalize,
+ use_extrude_bone=use_extrude_bone,
+ use_connect_unique_child=use_connect_unique_child,
+ extrude_from_parent=extrude_from_parent,
+ tails=self.tails if use_tail else None,
+ )
+
+    def export_render(self, path: str, resolution: Tuple[int, int]=(256, 256), use_tail: bool=False):
+ if use_tail:
+ assert self.tails is not None
+ self._export_render(
+ path=path,
+ vertices=self.vertices,
+ faces=self.faces,
+ bones=np.concatenate([self.joints, self.tails], axis=-1),
+ resolution=resolution,
+ )
+ else:
+ pjoints = self.joints[self.parents[1:]]
+ self._export_render(
+ path=path,
+ vertices=self.vertices,
+ faces=self.faces,
+ bones=np.concatenate([pjoints, self.joints[1:]], axis=-1),
+ resolution=resolution,
+ )
\ No newline at end of file
diff --git a/UniRig/src/data/augment.py b/UniRig/src/data/augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..80e2e9810384b898ae0762d506974d9c4be0eced
--- /dev/null
+++ b/UniRig/src/data/augment.py
@@ -0,0 +1,152 @@
+from dataclasses import dataclass
+from typing import Tuple, Union, List, Dict
+from numpy import ndarray
+import numpy as np
+from abc import ABC, abstractmethod
+from scipy.spatial.transform import Rotation as R
+
+from .spec import ConfigSpec
+from .asset import Asset
+from .utils import axis_angle_to_matrix
+
+@dataclass(frozen=True)
+class AugmentAffineConfig(ConfigSpec):
+ # final normalization cube
+ normalize_into: Tuple[float, float]
+
+ # randomly scale coordinates with probability p
+ random_scale_p: float
+
+ # scale range (lower, upper)
+ random_scale: Tuple[float, float]
+
+ # randomly shift coordinates with probability p
+ random_shift_p: float
+
+ # shift range (lower, upper)
+ random_shift: Tuple[float, float]
+
+ @classmethod
+ def parse(cls, config) -> Union['AugmentAffineConfig', None]:
+ if config is None:
+ return None
+ cls.check_keys(config)
+ return AugmentAffineConfig(
+ normalize_into=config.normalize_into,
+ random_scale_p=config.get('random_scale_p', 0.),
+ random_scale=config.get('random_scale', [1., 1.]),
+ random_shift_p=config.get('random_shift_p', 0.),
+ random_shift=config.get('random_shift', [0., 0.]),
+ )
+
+@dataclass(frozen=True)
+class AugmentConfig(ConfigSpec):
+ '''
+    Config for the final lightweight augmentation of vertices, normals and bones before sampling.
+ '''
+ augment_affine_config: Union[AugmentAffineConfig, None]
+
+ @classmethod
+ def parse(cls, config) -> 'AugmentConfig':
+ cls.check_keys(config)
+ return AugmentConfig(
+ augment_affine_config=AugmentAffineConfig.parse(config.get('augment_affine_config', None)),
+ )
+
+class Augment(ABC):
+ '''
+ Abstract class for augmentation
+ '''
+ def __init__(self):
+ pass
+
+ @abstractmethod
+ def transform(self, asset: Asset, **kwargs):
+ pass
+
+ @abstractmethod
+ def inverse(self, asset: Asset):
+ pass
+
+class AugmentAffine(Augment):
+
+ def __init__(self, config: AugmentAffineConfig):
+ super().__init__()
+ self.config = config
+
+ def _apply(self, v: ndarray, trans: ndarray) -> ndarray:
+ return np.matmul(v, trans[:3, :3].transpose()) + trans[:3, 3]
+
+ def transform(self, asset: Asset, **kwargs):
+ bound_min = asset.vertices.min(axis=0)
+ bound_max = asset.vertices.max(axis=0)
+ if asset.joints is not None:
+ joints_bound_min = asset.joints.min(axis=0)
+ joints_bound_max = asset.joints.max(axis=0)
+ bound_min = np.minimum(bound_min, joints_bound_min)
+ bound_max = np.maximum(bound_max, joints_bound_max)
+
+ trans_vertex = np.eye(4, dtype=np.float32)
+
+ trans_vertex = _trans_to_m(-(bound_max + bound_min)/2) @ trans_vertex
+
+ # scale into the cube
+ normalize_into = self.config.normalize_into
+ scale = np.max((bound_max - bound_min) / (normalize_into[1] - normalize_into[0]))
+ trans_vertex = _scale_to_m(1. / scale) @ trans_vertex
+
+ bias = (normalize_into[0] + normalize_into[1]) / 2
+ trans_vertex = _trans_to_m(np.array([bias, bias, bias], dtype=np.float32)) @ trans_vertex
+
+ if np.random.rand() < self.config.random_scale_p:
+ scale = _scale_to_m(np.random.uniform(self.config.random_scale[0], self.config.random_scale[1]))
+ trans_vertex = scale @ trans_vertex
+
+ if np.random.rand() < self.config.random_shift_p:
+ l, r = self.config.random_shift
+            shift = _trans_to_m(np.array([np.random.uniform(l, r), np.random.uniform(l, r), np.random.uniform(l, r)], dtype=np.float32))
+ trans_vertex = shift @ trans_vertex
+
+ asset.vertices = self._apply(asset.vertices, trans_vertex)
+ # do not affect scale in matrix
+ if asset.matrix_local is not None:
+ asset.matrix_local[:, :, 3:4] = trans_vertex @ asset.matrix_local[:, :, 3:4]
+ if asset.pose_matrix is not None:
+ asset.pose_matrix[:, :, 3:4] = trans_vertex @ asset.pose_matrix[:, :, 3:4]
+ # do not affect normal here
+ if asset.joints is not None:
+ asset.joints = self._apply(asset.joints, trans_vertex)
+ if asset.tails is not None:
+ asset.tails = self._apply(asset.tails, trans_vertex)
+
+ self.trans_vertex = trans_vertex
+
+ def inverse(self, asset: Asset):
+ m = np.linalg.inv(self.trans_vertex)
+ asset.vertices = self._apply(asset.vertices, m)
+ if asset.joints is not None:
+ asset.joints = self._apply(asset.joints, m)
+ if asset.tails is not None:
+ asset.tails = self._apply(asset.tails, m)
+
+def _trans_to_m(v: ndarray):
+ m = np.eye(4, dtype=np.float32)
+ m[0:3, 3] = v
+ return m
+
+def _scale_to_m(r: float):
+ m = np.zeros((4, 4), dtype=np.float32)
+ m[0, 0] = r
+ m[1, 1] = r
+ m[2, 2] = r
+ m[3, 3] = 1.
+ return m
+
+def get_augments(config: AugmentConfig) -> Tuple[List[Augment], List[Augment]]:
+ first_augments = [] # augments before sample
+ second_augments = [] # augments after sample
+ augment_affine_config = config.augment_affine_config
+
+ if augment_affine_config is not None:
+ second_augments.append(AugmentAffine(config=augment_affine_config))
+ return first_augments, second_augments
\ No newline at end of file
diff --git a/UniRig/src/data/datapath.py b/UniRig/src/data/datapath.py
new file mode 100644
index 0000000000000000000000000000000000000000..18c9e6ebc1a23569a6ded4462b8b3d2d812b7a88
--- /dev/null
+++ b/UniRig/src/data/datapath.py
@@ -0,0 +1,149 @@
+from copy import deepcopy
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Dict, Union, Tuple, List
+import numpy as np
+from numpy import ndarray
+import os
+from random import shuffle
+from box import Box
+
+from .spec import ConfigSpec
+
+@dataclass
+class DatapathConfig(ConfigSpec):
+ '''
+ Config to handle input data paths.
+ '''
+ # root
+ input_dataset_dir: str
+
+ # use proportion data sampling
+ use_prob: bool
+
+ # cls: [(path_1, p_1), ...]
+ data_path: Dict[str, List[Tuple[str, float]]]
+
+ # how many files to return when using data sampling
+ num_files: Union[int, None]
+
+ @classmethod
+ def from_args(cls, **kwargs) -> 'DatapathConfig':
+ '''
+ Make a temporary datapath from user inputs.
+ '''
+ input = kwargs.get('input', None)
+ output = kwargs.get('output', None)
+ recursive = kwargs.get('recursive', False)
+
+
+ @classmethod
+ def parse(cls, config) -> 'DatapathConfig':
+ cls.check_keys(config)
+ return DatapathConfig(
+ input_dataset_dir=config.input_dataset_dir,
+ use_prob=config.get('use_prob', True),
+ data_path=config.data_path,
+ num_files=config.get('num_files', None),
+ )
+
+ def split_by_cls(self) -> Dict[str, 'DatapathConfig']:
+ res: Dict[str, DatapathConfig] = {}
+ for cls in self.data_path:
+ res[cls] = deepcopy(self)
+ res[cls].data_path = {cls: self.data_path[cls]}
+ return res
+
+class Datapath():
+ def __init__(
+ self,
+ config: Union[DatapathConfig, None]=None,
+ files: Union[List[str], None]=None,
+ cls: Union[str, None]=None,
+ ):
+ if config is not None:
+ self.config = config
+ self.file_list = []
+ cls_probs_first = []
+ cls_first = []
+
+ self.files_by_class: Dict[str, List[Dict]] = defaultdict(list)
+ self.class_positions: Dict[str, List[int]] = defaultdict(list)
+            self.cls_probs_second: Dict[str, ndarray] = defaultdict(list)
+
+ for cls in self.config.data_path:
+ prob = 0.
+ probs_second = []
+ for (path, p) in self.config.data_path[cls]:
+ prob += p
+ probs_second.append(p)
+ with open(path, 'r') as f:
+ file_items = []
+ missing = 0
+ for l in f.readlines():
+ raw_data_path = os.path.join(self.config.input_dataset_dir, l.strip(), 'raw_data.npz')
+ if not os.path.exists(raw_data_path):
+ missing += 1
+ continue
+ file_items.append({
+ 'cls': cls,
+ 'path': os.path.join(self.config.input_dataset_dir, l.strip()),
+ 'prob': p
+ })
+ assert len(file_items) > 0, f"files in {path} are all missing! root: {self.config.input_dataset_dir}"
+ if missing > 0:
+ print(f"\033[31m{cls}: {missing} missing files\033[0m")
+ self.files_by_class[cls].append(file_items)
+ self.class_positions[cls].append(0)
+ self.file_list.extend(file_items)
+ probs_second = np.array(probs_second)
+ self.cls_probs_second[cls] = probs_second / probs_second.sum()
+ cls_first.append(cls)
+ cls_probs_first.append(prob)
+ cls_probs_first = np.array(cls_probs_first)
+ self.cls_first: List[str] = cls_first
+            self.cls_probs_first: ndarray = cls_probs_first / cls_probs_first.sum()
+ elif files is not None:
+ if cls is None:
+ cls = 'inference'
+ self.file_list = [{'cls': cls, 'path': file} for file in files]
+ cls_probs_first = np.array([1.])
+ cls_first = []
+
+ self.files_by_class: Dict[str, List[Dict]] = {cls: self.file_list.copy()}
+ self.class_positions: Dict[str, List[int]] = {cls: [0]}
+ self.cls_probs_second: Dict[str, ndarray] = {cls: np.array([1.])}
+ self.config = Box({'use_prob': False})
+ else:
+            assert 0, 'either config or files must be provided'
+
+ def __len__(self):
+ if self.config.use_prob:
+ assert self.config.num_files is not None, 'num_files is not specified'
+ return self.config.num_files
+ return len(self.file_list)
+
+ def __getitem__(self, index) -> Tuple[str, str]:
+ if self.config.use_prob:
+ # first sample a class
+ cls = np.random.choice(self.cls_first, p=self.cls_probs_first)
+
+ # second sample in this class
+ idx = np.random.choice(len(self.files_by_class[cls]), p=self.cls_probs_second[cls])
+
+ # get the current position
+ pos = self.class_positions[cls][idx]
+ files = self.files_by_class[cls][idx]
+
+            # get the item and update position
+ item = files[pos]
+ self.class_positions[cls][idx] = (pos + 1) % len(files)
+ if (pos + 1) % len(files) == 0:
+ shuffle(self.files_by_class[cls][idx])
+ else:
+ item = self.file_list[index]
+ return (item['cls'], item['path'])
+
+ def get_data(self) -> List[Tuple[str, str]]:
+ return [self[i] for i in range(len(self))]
\ No newline at end of file
diff --git a/UniRig/src/data/dataset.py b/UniRig/src/data/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..6dce4c7a8b8b8239300a4b757f291df8bd0dd0fa
--- /dev/null
+++ b/UniRig/src/data/dataset.py
@@ -0,0 +1,231 @@
+from copy import deepcopy
+from dataclasses import dataclass
+import lightning.pytorch as pl
+# from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
+import torch
+from torch import LongTensor
+from torch.utils import data
+from torch.utils.data import DataLoader, Dataset
+from typing import Dict, List, Tuple, Union, Callable
+import os
+import numpy as np
+
+from .raw_data import RawData
+from .asset import Asset
+from .transform import TransformConfig, transform_asset
+from .datapath import DatapathConfig, Datapath
+from .spec import ConfigSpec
+
+from ..tokenizer.spec import TokenizerSpec, TokenizerConfig
+from ..tokenizer.parse import get_tokenizer
+from ..model.spec import ModelInput
+
+@dataclass
+class DatasetConfig(ConfigSpec):
+ '''
+ Config to handle dataset format.
+ '''
+ # shuffle dataset
+ shuffle: bool
+
+ # batch size
+ batch_size: int
+
+ # number of workers
+ num_workers: int
+
+ # datapath
+ datapath_config: DatapathConfig
+
+ # use pin memory
+ pin_memory: bool = True
+
+ # use persistent workers
+ persistent_workers: bool = True
+
+ @classmethod
+ def parse(cls, config) -> 'DatasetConfig':
+ cls.check_keys(config)
+ return DatasetConfig(
+ shuffle=config.shuffle,
+ batch_size=config.batch_size,
+ num_workers=config.num_workers,
+ pin_memory=config.pin_memory,
+ persistent_workers=config.persistent_workers,
+ datapath_config=DatapathConfig.parse(config.datapath_config),
+ )
+
+ def split_by_cls(self) -> Dict[str, 'DatasetConfig']:
+ res: Dict[str, DatasetConfig] = {}
+ datapath_config_dict = self.datapath_config.split_by_cls()
+ for cls in self.datapath_config.data_path:
+ res[cls] = deepcopy(self)
+ res[cls].datapath_config = datapath_config_dict[cls]
+ return res
+
+class UniRigDatasetModule(pl.LightningDataModule):
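+ '''
+ Lightning data module wrapping UniRigDataset. Only the predict stage is
+ wired up here; when an explicit Datapath is given, the train/validate
+ datapaths are left unset.
+ '''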
+ def __init__(
+ self,
+ process_fn: Union[Callable[[List[ModelInput]], Dict], None]=None,
+ predict_dataset_config: Union[Dict[str, DatasetConfig], None]=None,
+ predict_transform_config: Union[TransformConfig, None]=None,
+ tokenizer_config: Union[TokenizerConfig, None]=None,
+ debug: bool=False,
+ data_name: str='raw_data.npz',
+ datapath: Union[Datapath, None]=None,
+ cls: Union[str, None]=None,
+ ):
+ super().__init__()
+ self.process_fn = process_fn
+ self.predict_dataset_config = predict_dataset_config
+ self.predict_transform_config = predict_transform_config
+ self.tokenizer_config = tokenizer_config
+ self.debug = debug
+ self.data_name = data_name
+
+ if debug:
+ print("\033[31mWARNING: debug mode, dataloader will be extremely slow !!!\033[0m")
+
+ if datapath is not None:
+ self.train_datapath = None
+ self.validate_datapath = None
+ self.predict_datapath = {
+ cls: deepcopy(datapath),
+ }
+ self.predict_dataset_config = {
+ cls: DatasetConfig(
+ shuffle=False,
+ batch_size=1,
+ num_workers=0,
+ datapath_config=deepcopy(datapath), # note: holds a Datapath here; only the loader settings above are consumed in this branch
+ pin_memory=False,
+ persistent_workers=False,
+ )
+ }
+ else:
+ # build predict datapath
+ if self.predict_dataset_config is not None:
+ self.predict_datapath = {
+ cls: Datapath(self.predict_dataset_config[cls].datapath_config)
+ for cls in self.predict_dataset_config
+ }
+ else:
+ self.predict_datapath = None
+
+ # get tokenizer
+ if tokenizer_config is None:
+ self.tokenizer = None
+ else:
+ self.tokenizer = get_tokenizer(config=tokenizer_config)
+
+ def prepare_data(self):
+ pass
+
+ def setup(self, stage=None):
+ if self.predict_datapath is not None:
+ self._predict_ds = {}
+ for cls in self.predict_datapath:
+ self._predict_ds[cls] = UniRigDataset(
+ process_fn=self.process_fn,
+ data=self.predict_datapath[cls].get_data(),
+ name=f"predict-{cls}",
+ tokenizer=self.tokenizer,
+ transform_config=self.predict_transform_config,
+ debug=self.debug,
+ data_name=self.data_name,
+ )
+
+ def predict_dataloader(self):
+ if not hasattr(self, "_predict_ds"):
+ self.setup()
+ return self._create_dataloader(
+ dataset=self._predict_ds,
+ config=self.predict_dataset_config,
+ is_train=False,
+ drop_last=False,
+ )
+
+ def _create_dataloader(
+ self,
+ dataset: Union[Dataset, Dict[str, Dataset]],
+ config: Union[DatasetConfig, Dict[str, DatasetConfig]],
+ is_train: bool,
+ **kwargs,
+ ) -> Union[DataLoader, Dict[str, DataLoader]]:
+ def create_single_dataloader(dataset, config: DatasetConfig, **kwargs):
+ return DataLoader(
+ dataset,
+ batch_size=config.batch_size,
+ shuffle=config.shuffle,
+ num_workers=config.num_workers,
+ pin_memory=config.pin_memory,
+ persistent_workers=config.persistent_workers,
+ collate_fn=dataset.collate_fn,
+ **kwargs,
+ )
+ if isinstance(dataset, Dict):
+ return {k: create_single_dataloader(v, config[k], **kwargs) for k, v in dataset.items()}
+ else:
+ return create_single_dataloader(dataset, config, **kwargs)
+
+class UniRigDataset(Dataset):
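+ '''
+ Dataset yielding ModelInput from extracted raw_data.npz directories:
+ load RawData, build an Asset, apply transforms, then optionally tokenize.
+ '''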
+ def __init__(
+ self,
+ data: List[Tuple[str, str]], # (cls, part)
+ name: str,
+ process_fn: Union[Callable[[List[ModelInput]], Dict]]=None,
+ tokenizer: Union[TokenizerSpec, None]=None,
+ transform_config: Union[TransformConfig, None]=None,
+ debug: bool=False,
+ data_name: str='raw_data.npz',
+ ) -> None:
+ super().__init__()
+
+ self.data = data
+ self.name = name
+ self.process_fn = process_fn
+ self.tokenizer = tokenizer
+ self.transform_config = transform_config
+ self.debug = debug
+ self.data_name = data_name
+
+ if not debug:
+ assert self.process_fn is not None, 'missing data processing function'
+
+ def __len__(self) -> int:
+ return len(self.data)
+
+ def __getitem__(self, idx) -> ModelInput:
+ cls, dir_path = self.data[idx]
+ raw_data = RawData.load(path=os.path.join(dir_path, self.data_name))
+ asset = Asset.from_raw_data(raw_data=raw_data, cls=cls, path=dir_path, data_name=self.data_name)
+
+ first_augments, second_augments = transform_asset(
+ asset=asset,
+ transform_config=self.transform_config,
+ )
+ if self.tokenizer is not None and asset.parents is not None:
+ tokens = self.tokenizer.tokenize(input=asset.get_tokenize_input())
+ else:
+ tokens = None
+ return ModelInput(
+ tokens=tokens,
+ pad=None if self.tokenizer is None else self.tokenizer.pad,
+ vertices=asset.sampled_vertices.astype(np.float32),
+ normals=asset.sampled_normals.astype(np.float32),
+ joints=None if asset.joints is None else asset.joints.astype(np.float32),
+ tails=None if asset.tails is None else asset.tails.astype(np.float32),
+ asset=asset,
+ augments=None,
+ )
+
+ def _collate_fn_debug(self, batch):
+ return batch
+
+ def _collate_fn(self, batch):
+ return data.dataloader.default_collate(self.process_fn(batch))
+
+ def collate_fn(self, batch):
+ if self.debug:
+ return self._collate_fn_debug(batch)
+ return self._collate_fn(batch)
\ No newline at end of file
diff --git a/UniRig/src/data/exporter.py b/UniRig/src/data/exporter.py
new file mode 100644
index 0000000000000000000000000000000000000000..c84ff5f2ce1e76fcceab813552a523b41f6207a6
--- /dev/null
+++ b/UniRig/src/data/exporter.py
@@ -0,0 +1,486 @@
+import numpy as np
+from numpy import ndarray
+from typing import List, Union, Tuple
+from collections import defaultdict
+import os
+
+try:
+ import open3d as o3d
+ OPEN3D_EQUIPPED = True
+except:
+ print("do not have open3d")
+ OPEN3D_EQUIPPED = False
+
+class Exporter():
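+ '''
+ Mixin with export helpers for skeletons, meshes, point clouds, FBX files
+ and rendered previews. OBJ outputs convert Blender's Z-up coordinates to
+ Y-up via (x, y, z) -> (x, z, -y).
+ '''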
+
+ def _safe_make_dir(self, path):
+ if os.path.dirname(path) == '':
+ return
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+
+ def _export_skeleton(self, joints: ndarray, parents: List[Union[int, None]], path: str):
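+ '''
+ Write each bone (child head, parent head, and an epsilon-offset copy of
+ the parent head) as a degenerate triangle so plain OBJ viewers display
+ it as a line segment.
+ '''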
+ ext = path.split('.')[-1]
+ assert ext in ['obj']
+ name = path.removesuffix('.obj')
+ path = name + ".obj"
+ self._safe_make_dir(path)
+ J = joints.shape[0]
+ with open(path, 'w') as file:
+ file.write("o spring_joint\n")
+ _joints = []
+ for id in range(J):
+ pid = parents[id]
+ if pid is None or pid == -1:
+ continue
+ bx, by, bz = joints[id]
+ ex, ey, ez = joints[pid]
+ _joints.extend([
+ f"v {bx} {bz} {-by}\n",
+ f"v {ex} {ez} {-ey}\n",
+ f"v {ex} {ez} {-ey + 0.00001}\n"
+ ])
+ file.writelines(_joints)
+
+ # one degenerate triangle per written bone; roots are skipped above
+ _faces = [f"f {i*3+1} {i*3+2} {i*3+3}\n" for i in range(len(_joints) // 3)]
+ file.writelines(_faces)
+
+ def _export_bones(self, bones: ndarray, path: str):
+ ext = path.split('.')[-1]
+ assert ext in ['obj']
+ name = path.removesuffix('.obj')
+ path = name + ".obj"
+ self._safe_make_dir(path)
+ J = bones.shape[0]
+ with open(path, 'w') as file:
+ file.write("o bones\n")
+ _joints = []
+ for bone in bones:
+ bx, by, bz = bone[:3]
+ ex, ey, ez = bone[3:]
+ _joints.extend([
+ f"v {bx} {bz} {-by}\n",
+ f"v {ex} {ez} {-ey}\n",
+ f"v {ex} {ez} {-ey + 0.00001}\n"
+ ])
+ file.writelines(_joints)
+
+ _faces = [f"f {i*3+1} {i*3+2} {i*3+3}\n" for i in range(J)]
+ file.writelines(_faces)
+
+ def _export_skeleton_sequence(self, joints: ndarray, parents: List[Union[int, None]], path: str):
+ ext = path.split('.')[-1]
+ assert ext in ['obj']
+ name = path.removesuffix('.obj')
+ path = name + ".obj"
+ self._safe_make_dir(path)
+ J = joints.shape[0]
+ for i in range(J):
+ with open(name + f"_{i}.obj", 'w') as file:
+ file.write("o spring_joint\n")
+ _joints = []
+ for id in range(i + 1):
+ pid = parents[id]
+ if pid is None:
+ continue
+ bx, by, bz = joints[id]
+ ex, ey, ez = joints[pid]
+ _joints.extend([
+ f"v {bx} {bz} {-by}\n",
+ f"v {ex} {ez} {-ey}\n",
+ f"v {ex} {ez} {-ey + 0.00001}\n"
+ ])
+ file.writelines(_joints)
+
+ # one degenerate triangle per written bone; roots are skipped above
+ _faces = [f"f {k*3+1} {k*3+2} {k*3+3}\n" for k in range(len(_joints) // 3)]
+ file.writelines(_faces)
+
+ def _export_mesh(self, vertices: ndarray, faces: ndarray, path: str):
+ ext = path.split('.')[-1]
+ assert ext in ['obj', 'ply']
+ if path.endswith('ply'):
+ if not OPEN3D_EQUIPPED:
+ raise RuntimeError("open3d is not available")
+ mesh = o3d.geometry.TriangleMesh()
+ mesh.vertices = o3d.utility.Vector3dVector(vertices)
+ mesh.triangles = o3d.utility.Vector3iVector(faces)
+ self._safe_make_dir(path)
+ o3d.io.write_triangle_mesh(path, mesh)
+ return
+ name = path.removesuffix('.obj')
+ path = name + ".obj"
+ self._safe_make_dir(path)
+ with open(path, 'w') as file:
+ file.write("o mesh\n")
+ _vertices = []
+ for co in vertices:
+ _vertices.append(f"v {co[0]} {co[2]} {-co[1]}\n")
+ file.writelines(_vertices)
+ _faces = []
+ for face in faces:
+ _faces.append(f"f {face[0]+1} {face[1]+1} {face[2]+1}\n")
+ file.writelines(_faces)
+
+ def _export_pc(self, vertices: ndarray, path: str, vertex_normals: Union[ndarray, None]=None, normal_size: float=0.01):
+ if path.endswith('.ply'):
+ if vertex_normals is not None:
+ print("normal result will not be displayed in .ply format")
+ name = path.removesuffix('.ply')
+ path = name + ".ply"
+ pc = o3d.geometry.PointCloud()
+ pc.points = o3d.utility.Vector3dVector(vertices)
+ # open3d may segfault with numpy >= 2.0; use the torch environment instead
+ self._safe_make_dir(path)
+ o3d.io.write_point_cloud(path, pc)
+ return
+ name = path.removesuffix('.obj')
+ path = name + ".obj"
+ self._safe_make_dir(path)
+ with open(path, 'w') as file:
+ file.write("o pc\n")
+ _vertex = []
+ for co in vertices:
+ _vertex.append(f"v {co[0]} {co[2]} {-co[1]}\n")
+ file.writelines(_vertex)
+ if vertex_normals is not None:
+ new_path = path.replace('.obj', '_normal.obj')
+ with open(new_path, 'w') as nfile:
+ nfile.write("o normal\n")
+ _normal = []
+ for i in range(vertices.shape[0]):
+ co = vertices[i]
+ x = vertex_normals[i, 0]
+ y = vertex_normals[i, 1]
+ z = vertex_normals[i, 2]
+ _normal.extend([
+ f"v {co[0]} {co[2]} {-co[1]}\n",
+ f"v {co[0]+0.0001} {co[2]} {-co[1]}\n",
+ f"v {co[0]+x*normal_size} {co[2]+z*normal_size} {-(co[1]+y*normal_size)}\n",
+ f"f {i*3+1} {i*3+2} {i*3+3}\n",
+ ])
+ nfile.writelines(_normal)
+
+ def _make_armature(
+ self,
+ vertices: Union[ndarray, None],
+ joints: ndarray,
+ skin: Union[ndarray, None],
+ parents: List[Union[int, None]],
+ names: List[str],
+ faces: Union[ndarray, None]=None,
+ extrude_size: float=0.03,
+ group_per_vertex: int=-1,
+ add_root: bool=False,
+ do_not_normalize: bool=False,
+ use_extrude_bone: bool=True,
+ use_connect_unique_child: bool=True,
+ extrude_from_parent: bool=True,
+ tails: Union[ndarray, None]=None,
+ ):
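+ '''
+ Build a Blender armature (and optionally a skinned mesh) from joints,
+ parents and names: create the collection and mesh, extrude edit bones,
+ then assign the top-k skin weights to vertex groups.
+ '''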
+ import bpy # type: ignore
+ from mathutils import Vector # type: ignore
+
+ # make collection
+ collection = bpy.data.collections.new('new_collection')
+ bpy.context.scene.collection.children.link(collection)
+
+ # make mesh
+ if vertices is not None:
+ mesh = bpy.data.meshes.new('mesh')
+ if faces is None:
+ faces = []
+ mesh.from_pydata(vertices, [], faces)
+ mesh.update()
+
+ # make object from mesh
+ obj = bpy.data.objects.new('character', mesh)
+
+ # add object to scene collection
+ collection.objects.link(obj)
+
+ # add a new armature and enter edit mode
+ bpy.ops.object.armature_add(enter_editmode=True)
+ armature = bpy.data.armatures.get('Armature')
+ edit_bones = armature.edit_bones
+
+ J = joints.shape[0]
+ if tails is None:
+ tails = joints.copy()
+ tails[:, 2] += extrude_size
+ connects = [False for _ in range(J)]
+ children = defaultdict(list)
+ for i in range(1, J):
+ children[parents[i]].append(i)
+ if tails is not None:
+ if use_extrude_bone:
+ for i in range(J):
+ if len(children[i]) != 1 and extrude_from_parent and i != 0:
+ pjoint = joints[parents[i]]
+ joint = joints[i]
+ d = joint - pjoint
+ if np.linalg.norm(d) < 0.000001:
+ d = np.array([0., 0., 1.]) # fallback when the child's head coincides with the parent's head
+ else:
+ d = d / np.linalg.norm(d)
+ tails[i] = joint + d * extrude_size
+ if use_connect_unique_child:
+ for i in range(J):
+ if len(children[i]) == 1:
+ child = children[i][0]
+ tails[i] = joints[child]
+ if parents[i] is not None and len(children[parents[i]]) == 1:
+ connects[i] = True
+
+ if add_root:
+ bone_root = edit_bones.get('Bone')
+ bone_root.name = 'Root'
+ bone_root.tail = Vector((joints[0, 0], joints[0, 1], joints[0, 2]))
+ else:
+ bone_root = edit_bones.get('Bone')
+ bone_root.name = names[0]
+ bone_root.head = Vector((joints[0, 0], joints[0, 1], joints[0, 2]))
+ bone_root.tail = Vector((joints[0, 0], joints[0, 1], joints[0, 2] + extrude_size))
+
+ def extrude_bone(
+ edit_bones,
+ name: str,
+ parent_name: str,
+ head: Tuple[float, float, float],
+ tail: Tuple[float, float, float],
+ connect: bool
+ ):
+ bone = edit_bones.new(name)
+ bone.head = Vector((head[0], head[1], head[2]))
+ bone.tail = Vector((tail[0], tail[1], tail[2]))
+ bone.name = name
+ parent_bone = edit_bones.get(parent_name)
+ bone.parent = parent_bone
+ bone.use_connect = connect
+ assert not np.isnan(head).any(), f"nan found in head of bone {name}"
+ assert not np.isnan(tail).any(), f"nan found in tail of bone {name}"
+
+ for i in range(J):
+ if not add_root and i == 0:
+ continue
+ edit_bones = armature.edit_bones
+ pname = 'Root' if parents[i] is None else names[parents[i]]
+ extrude_bone(edit_bones, names[i], pname, joints[i], tails[i], connects[i])
+ for i in range(J):
+ bone = edit_bones.get(names[i])
+ bone.head = Vector((joints[i, 0], joints[i, 1], joints[i, 2]))
+ bone.tail = Vector((tails[i, 0], tails[i, 1], tails[i, 2]))
+
+ if vertices is None or skin is None:
+ return
+ # must set to object mode to enable parent_set
+ bpy.ops.object.mode_set(mode='OBJECT')
+ objects = bpy.data.objects
+ for o in bpy.context.selected_objects:
+ o.select_set(False)
+ ob = objects['character']
+ arm = bpy.data.objects['Armature']
+ ob.select_set(True)
+ arm.select_set(True)
+ bpy.ops.object.parent_set(type='ARMATURE_NAME')
+ vis = []
+ for x in ob.vertex_groups:
+ vis.append(x.name)
+ # sparsify: keep only the top-k weights per vertex
+ argsorted = np.argsort(-skin, axis=1)
+ vertex_group_reweight = skin[np.arange(skin.shape[0])[..., None], argsorted]
+ if group_per_vertex == -1:
+ group_per_vertex = vertex_group_reweight.shape[-1]
+ if not do_not_normalize:
+ vertex_group_reweight = vertex_group_reweight / vertex_group_reweight[..., :group_per_vertex].sum(axis=1)[...,None]
+
+ for v, w in enumerate(skin):
+ for ii in range(group_per_vertex):
+ i = argsorted[v, ii]
+ if i >= J:
+ continue
+ n = names[i]
+ if n not in vis:
+ continue
+ ob.vertex_groups[n].add([v], vertex_group_reweight[v, ii], 'REPLACE')
+
+ def _clean_bpy(self):
+ import bpy # type: ignore
+ for c in bpy.data.actions:
+ bpy.data.actions.remove(c)
+ for c in bpy.data.armatures:
+ bpy.data.armatures.remove(c)
+ for c in bpy.data.cameras:
+ bpy.data.cameras.remove(c)
+ for c in bpy.data.collections:
+ bpy.data.collections.remove(c)
+ for c in bpy.data.images:
+ bpy.data.images.remove(c)
+ for c in bpy.data.materials:
+ bpy.data.materials.remove(c)
+ for c in bpy.data.meshes:
+ bpy.data.meshes.remove(c)
+ for c in bpy.data.objects:
+ bpy.data.objects.remove(c)
+ for c in bpy.data.textures:
+ bpy.data.textures.remove(c)
+
+ def _export_fbx(
+ self,
+ path: str,
+ vertices: Union[ndarray, None],
+ joints: ndarray,
+ skin: Union[ndarray, None],
+ parents: List[Union[int, None]],
+ names: List[str],
+ faces: Union[ndarray, None]=None,
+ extrude_size: float=0.03,
+ group_per_vertex: int=-1,
+ add_root: bool=False,
+ do_not_normalize: bool=False,
+ use_extrude_bone: bool=True,
+ use_connect_unique_child: bool=True,
+ extrude_from_parent: bool=True,
+ tails: Union[ndarray, None]=None,
+ ):
+ '''
+ Requires bpy installed
+ '''
+ import bpy # type: ignore
+ self._safe_make_dir(path)
+ self._clean_bpy()
+ self._make_armature(
+ vertices=vertices,
+ joints=joints,
+ skin=skin,
+ parents=parents,
+ names=names,
+ faces=faces,
+ extrude_size=extrude_size,
+ group_per_vertex=group_per_vertex,
+ add_root=add_root,
+ do_not_normalize=do_not_normalize,
+ use_extrude_bone=use_extrude_bone,
+ use_connect_unique_child=use_connect_unique_child,
+ extrude_from_parent=extrude_from_parent,
+ tails=tails,
+ )
+
+ # keep add_leaf_bones disabled so the exporter does not append extra leaf bones
+ bpy.ops.export_scene.fbx(filepath=path, check_existing=False, add_leaf_bones=False)
+
+ def _export_render(
+ self,
+ path: str,
+ vertices: Union[ndarray, None],
+ faces: Union[ndarray, None],
+ bones: Union[ndarray, None],
+ resolution: Tuple[int, int]=(256, 256),
+ ):
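+ '''
+ Render a normalized preview with the Workbench engine from a fixed
+ camera; if bones are given, project them to image space and draw them
+ over the result with PIL.
+ '''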
+ import bpy # type: ignore
+ import bpy_extras # type: ignore
+ from mathutils import Vector # type: ignore
+
+ self._safe_make_dir(path)
+ # normalize into [-1, 1]^3
+ # copied from augment
+ assert (vertices is not None) or (bones is not None)
+ bounds = []
+ if vertices is not None:
+ bounds.append(vertices)
+ if bones is not None:
+ bounds.append(bones[:, :3])
+ bounds.append(bones[:, 3:])
+ bounds = np.concatenate(bounds, axis=0)
+ bound_min = bounds.min(axis=0)
+ bound_max = bounds.max(axis=0)
+
+ trans_vertex = np.eye(4)
+
+ trans_vertex = _trans_to_m(-(bound_max + bound_min)/2) @ trans_vertex
+
+ # scale into the cube [-1, 1]
+ scale = np.max((bound_max - bound_min) / 2)
+ trans_vertex = _scale_to_m(1. / scale) @ trans_vertex
+
+ def _apply(v: ndarray, trans: ndarray) -> ndarray:
+ return np.matmul(v, trans[:3, :3].transpose()) + trans[:3, 3]
+
+ if vertices is not None:
+ vertices = _apply(vertices, trans_vertex)
+ if bones is not None:
+ bones[:, :3] = _apply(bones[:, :3], trans_vertex)
+ bones[:, 3:] = _apply(bones[:, 3:], trans_vertex)
+
+ # bpy api calls
+ self._clean_bpy()
+ bpy.context.scene.render.engine = 'BLENDER_WORKBENCH'
+ bpy.context.scene.render.film_transparent = True
+ bpy.context.scene.display.shading.background_type = 'VIEWPORT'
+
+ collection = bpy.data.collections.new('new_collection')
+ bpy.context.scene.collection.children.link(collection)
+
+ if vertices is not None:
+ mesh_data = bpy.data.meshes.new(name="MeshData")
+ mesh_obj = bpy.data.objects.new(name="MeshObject", object_data=mesh_data)
+ collection.objects.link(mesh_obj)
+
+ mesh_data.from_pydata(vertices.tolist(), [], faces.tolist())
+ mesh_data.update()
+
+ def look_at(camera, point):
+ direction = point - camera.location
+ rot_quat = direction.to_track_quat('-Z', 'Y')
+ camera.rotation_euler = rot_quat.to_euler()
+
+ bpy.ops.object.camera_add(location=(4, -4, 2.5))
+ camera = bpy.context.object
+ camera.data.angle = np.radians(25.0)
+ look_at(camera, Vector((0, 0, -0.2)))
+ bpy.context.scene.camera = camera
+
+ bpy.context.scene.render.resolution_x = resolution[0]
+ bpy.context.scene.render.resolution_y = resolution[1]
+ bpy.context.scene.render.image_settings.file_format = 'PNG'
+ bpy.context.scene.render.filepath = path
+
+ bpy.ops.render.render(write_still=True)
+ # draw bones over the rendered image in 2D using PIL
+ if bones is not None:
+ # TODO: do not save image after rendering
+ from PIL import Image, ImageDraw
+ img_pil = Image.open(path).convert("RGBA")
+ draw = ImageDraw.Draw(img_pil)
+
+ from bpy_extras.image_utils import load_image # type: ignore
+ bpy.context.scene.use_nodes = True
+ nodes = bpy.context.scene.node_tree.nodes
+ # nodes.clear()
+
+ img = load_image(path)
+ image_node = nodes.new(type='CompositorNodeImage')
+ image_node.image = img
+
+ for i, bone in enumerate(bones):
+ head, tail = bone[:3], bone[3:]
+ head_2d = bpy_extras.object_utils.world_to_camera_view(bpy.context.scene, camera, Vector(head))
+ tail_2d = bpy_extras.object_utils.world_to_camera_view(bpy.context.scene, camera, Vector(tail))
+
+ res_x, res_y = resolution
+ head_pix = (head_2d.x * res_x, (1 - head_2d.y) * res_y)
+ tail_pix = (tail_2d.x * res_x, (1 - tail_2d.y) * res_y)
+ draw.line([head_pix, tail_pix], fill=(255, 0, 0, 255), width=1)
+ img_pil.save(path)
+
+def _trans_to_m(v: ndarray):
+ m = np.eye(4)
+ m[0:3, 3] = v
+ return m
+
+def _scale_to_m(r: float):
+ m = np.zeros((4, 4))
+ m[0, 0] = r
+ m[1, 1] = r
+ m[2, 2] = r
+ m[3, 3] = 1.
+ return m
\ No newline at end of file
diff --git a/UniRig/src/data/extract.py b/UniRig/src/data/extract.py
new file mode 100644
index 0000000000000000000000000000000000000000..828915d52c6c08fc1b5b72c2c71a465007f61b5f
--- /dev/null
+++ b/UniRig/src/data/extract.py
@@ -0,0 +1,523 @@
+import bpy, os
+from collections import defaultdict
+from tqdm import tqdm
+import numpy as np
+from numpy import ndarray
+from typing import Dict, Tuple, List, Optional, Union
+import trimesh
+import fast_simplification
+
+import argparse
+import yaml
+from box import Box
+
+from .log import new_entry, add_error, add_warning, new_log, end_log
+from .raw_data import RawData
+
+def load(filepath: str):
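+ '''
+ Import a model (vrm/obj/fbx/glb/gltf/dae/blend) into the current bpy
+ scene, zero all bone rolls, and return its armature, or None if the file
+ contains no armature.
+ '''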
+ old_objs = set(bpy.context.scene.objects)
+
+ if not os.path.exists(filepath):
+ raise ValueError(f'File {filepath} does not exist !')
+
+ try:
+ if filepath.endswith(".vrm"):
+ # enable vrm addon and load vrm model
+ bpy.ops.preferences.addon_enable(module='vrm')
+
+ bpy.ops.import_scene.vrm(
+ filepath=filepath,
+ use_addon_preferences=True,
+ extract_textures_into_folder=False,
+ make_new_texture_folder=False,
+ set_shading_type_to_material_on_import=False,
+ set_view_transform_to_standard_on_import=True,
+ set_armature_display_to_wire=True,
+ set_armature_display_to_show_in_front=True,
+ set_armature_bone_shape_to_default=True,
+ disable_bake=True, # customized option for better performance
+ )
+ elif filepath.endswith(".obj"):
+ bpy.ops.wm.obj_import(filepath=filepath)
+ elif filepath.endswith(".fbx") or filepath.endswith(".FBX"):
+ # end bone is removed using remove_dummy_bone
+ bpy.ops.import_scene.fbx(filepath=filepath, ignore_leaf_bones=False, use_image_search=False)
+ elif filepath.endswith(".glb") or filepath.endswith(".gltf"):
+ bpy.ops.import_scene.gltf(filepath=filepath, import_pack_images=False)
+ elif filepath.endswith(".dae"):
+ bpy.ops.wm.collada_import(filepath=filepath)
+ elif filepath.endswith(".blend"):
+ with bpy.data.libraries.load(filepath) as (data_from, data_to):
+ data_to.objects = data_from.objects
+ for obj in data_to.objects:
+ if obj is not None:
+ bpy.context.collection.objects.link(obj)
+ else:
+ raise ValueError(f"unsupported file type {filepath}")
+ except Exception as e:
+ raise ValueError(f"failed to load {filepath}: {e}")
+
+ armature = [x for x in set(bpy.context.scene.objects)-old_objs if x.type=="ARMATURE"]
+ if len(armature)==0:
+ return None
+ if len(armature)>1:
+ raise ValueError(f"multiple armatures found")
+ armature = armature[0]
+
+ armature.select_set(True)
+ bpy.context.view_layer.objects.active = armature
+ bpy.ops.object.mode_set(mode='EDIT')
+ for bone in armature.data.edit_bones: # use the loaded armature's data, not an arbitrary datablock
+ bone.roll = 0. # zero all rolls to prevent inconsistent bone orientations
+
+ bpy.ops.object.mode_set(mode='OBJECT')
+ armature.select_set(False)
+
+ bpy.ops.object.select_all(action='DESELECT')
+ return armature
+
+# remove all data in bpy
+def clean_bpy():
+ # First try to purge orphan data
+ try:
+ bpy.ops.outliner.orphans_purge(do_local_ids=True, do_linked_ids=True, do_recursive=True)
+ except Exception as e:
+ print(f"Warning: Could not purge orphans: {e}")
+
+ # Then remove all data by type
+ data_types = [
+ bpy.data.actions,
+ bpy.data.armatures,
+ bpy.data.cameras,
+ bpy.data.collections,
+ bpy.data.curves,
+ bpy.data.images,
+ bpy.data.lights,
+ bpy.data.materials,
+ bpy.data.meshes,
+ bpy.data.objects,
+ bpy.data.textures,
+ bpy.data.worlds,
+ bpy.data.node_groups
+ ]
+
+ for data_collection in data_types:
+ try:
+ for item in data_collection:
+ try:
+ data_collection.remove(item)
+ except Exception as e:
+ print(f"Warning: Could not remove {item.name} from {data_collection}: {e}")
+ except Exception as e:
+ print(f"Warning: Error processing {data_collection}: {e}")
+
+ # Force garbage collection to free memory
+ import gc
+ gc.collect()
+
+def get_arranged_bones(armature):
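+ '''
+ Collect pose bones in a deterministic DFS order starting from the root,
+ sorting siblings by their world-space head position (z, then x, then y).
+ '''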
+ matrix_world = armature.matrix_world
+ arranged_bones = []
+ root = armature.pose.bones[0]
+ while root.parent is not None:
+ root = root.parent
+ Q = [root]
+ rot = np.array(matrix_world)[:3, :3]
+
+ # dfs and sort
+ while len(Q) != 0:
+ b = Q.pop(0)
+ arranged_bones.append(b)
+ children = []
+ for cb in b.children:
+ head = rot @ np.array(cb.head) # sort by each child's own head; using the parent's head would make all keys identical
+ children.append((cb, head[0], head[1], head[2]))
+ children = sorted(children, key=lambda x: (x[3], x[1], x[2]))
+ _c = [x[0] for x in children]
+ Q = _c + Q
+ return arranged_bones
+
+def process_mesh():
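+ '''
+ Gather all mesh objects in the scene, transform vertices and normals to
+ world space, triangulate each polygon by fanning its ordered edge loop,
+ and fix winding against the face normal. Returns vertices and 1-based
+ faces (converted back with faces-1 before saving).
+ '''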
+ meshes = []
+ for v in bpy.data.objects:
+ if v.type == 'MESH':
+ meshes.append(v)
+
+ _dict_mesh = {}
+ for obj in meshes:
+ m = np.array(obj.matrix_world)
+ matrix_world_rot = m[:3, :3]
+ matrix_world_bias = m[:3, 3]
+ rot = matrix_world_rot
+ total_vertices = len(obj.data.vertices)
+ vertex = np.zeros((4, total_vertices))
+ vertex_normal = np.zeros((total_vertices, 3))
+ obj_verts = obj.data.vertices
+ faces = []
+ normals = []
+
+ for v in obj_verts:
+ vertex_normal[v.index] = rot @ np.array(v.normal) # be careful !
+ vv = rot @ v.co
+ vv = np.array(vv) + matrix_world_bias
+ vertex[0:3, v.index] = vv
+ vertex[3][v.index] = 1 # affine coordinate
+
+ for polygon in obj.data.polygons:
+ edges = polygon.edge_keys
+ nodes = []
+ adj = {}
+ for edge in edges:
+ if adj.get(edge[0]) is None:
+ adj[edge[0]] = []
+ adj[edge[0]].append(edge[1])
+ if adj.get(edge[1]) is None:
+ adj[edge[1]] = []
+ adj[edge[1]].append(edge[0])
+ nodes.append(edge[0])
+ nodes.append(edge[1])
+ normal = polygon.normal
+ nodes = sorted(set(nodes)) # deduplicate first, then sort; set(sorted(...)) discards the ordering
+ first = nodes[0]
+ loop = []
+ now = first
+ vis = {}
+ while True:
+ loop.append(now)
+ vis[now] = True
+ if vis.get(adj[now][0]) is None:
+ now = adj[now][0]
+ elif vis.get(adj[now][1]) is None:
+ now = adj[now][1]
+ else:
+ break
+ for (second, third) in zip(loop[1:], loop[2:]):
+ faces.append((first + 1, second + 1, third + 1)) # OBJ-style 1-based indices
+ normals.append(rot @ normal) # Blender normals are object-space; rotate into world space
+
+ correct_faces = []
+ for (i, face) in enumerate(faces):
+ normal = normals[i]
+ v0 = face[0] - 1
+ v1 = face[1] - 1
+ v2 = face[2] - 1
+ v = np.cross(
+ vertex[:3, v1] - vertex[:3, v0],
+ vertex[:3, v2] - vertex[:3, v0],
+ )
+ if (v*normal).sum() > 0:
+ correct_faces.append(face)
+ else:
+ correct_faces.append((face[0], face[2], face[1]))
+ if len(correct_faces) > 0:
+ _dict_mesh[obj.name] = {
+ 'vertex': vertex,
+ 'face': correct_faces,
+ }
+
+ vertex = np.concatenate([_dict_mesh[name]['vertex'] for name in _dict_mesh], axis=1)[:3, :].transpose()
+
+ total_faces = 0
+ now_bias = 0
+ for name in _dict_mesh:
+ total_faces += len(_dict_mesh[name]['face'])
+ faces = np.zeros((total_faces, 3), dtype=np.int64)
+ tot = 0
+ for name in _dict_mesh:
+ f = np.array(_dict_mesh[name]['face'], dtype=np.int64)
+ faces[tot:tot+f.shape[0]] = f + now_bias
+ now_bias += _dict_mesh[name]['vertex'].shape[1]
+ tot += f.shape[0]
+
+ return vertex, faces
+
+def process_armature(
+ armature,
+ arranged_bones,
+) -> Tuple[np.ndarray, np.ndarray, List[Union[int, None]], List[str], np.ndarray]:
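+ '''
+ Read joints, tails, parents, names and scale-free local matrices from
+ the armature's bones, in the order given by arranged_bones.
+ '''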
+ matrix_world = armature.matrix_world
+ index = {}
+
+ for (id, pbone) in enumerate(arranged_bones):
+ index[pbone.name] = id
+
+ root = armature.pose.bones[0]
+ while root.parent is not None:
+ root = root.parent
+ m = np.array(matrix_world.to_4x4())
+ scale_inv = np.linalg.inv(np.diag(matrix_world.to_scale()))
+ rot = m[:3, :3]
+ bias = m[:3, 3]
+
+ bpy.ops.object.editmode_toggle()
+ edit_bones = armature.data.edit_bones
+
+ J = len(arranged_bones)
+ joints = np.zeros((J, 3), dtype=np.float32)
+ tails = np.zeros((J, 3), dtype=np.float32)
+ parents = []
+ name_to_id = {}
+ names = []
+ matrix_local_stack = np.zeros((J, 4, 4), dtype=np.float32)
+ for (id, pbone) in enumerate(arranged_bones):
+ name = pbone.name
+ names.append(name)
+ matrix_local = np.array(pbone.bone.matrix_local)
+ use_inherit_rotation = pbone.bone.use_inherit_rotation
+ if not use_inherit_rotation:
+ add_warning(f"use_inherit_rotation of bone {name} is False !")
+ head = rot @ matrix_local[0:3, 3] + bias
+ edit_bone = edit_bones.get(name)
+ tail = rot @ np.array(edit_bone.tail) + bias
+
+ name_to_id[name] = id
+ joints[id] = head
+ tails[id] = tail
+ parents.append(None if pbone.parent not in arranged_bones else name_to_id[pbone.parent.name])
+ # remove scale part
+ matrix_local[:, 3:4] = m @ matrix_local[:, 3:4]
+ matrix_local[:3, :3] = scale_inv @ matrix_local[:3, :3]
+ matrix_local_stack[id] = matrix_local
+ bpy.ops.object.editmode_toggle()
+
+ return joints, tails, parents, names, matrix_local_stack
+
+def save_raw_data(
+ path: str,
+ vertices: ndarray,
+ faces: ndarray,
+ joints: Union[ndarray, None],
+ tails: Union[ndarray, None],
+ parents: Union[List[Union[int, None]], None],
+ names: Union[List[str], None],
+ matrix_local: Union[ndarray, None],
+ target_count: int,
+):
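+ '''
+ Simplify the mesh to at most target_count faces with fast_simplification,
+ recompute normals via trimesh, and save everything as a RawData npz.
+ '''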
+ mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
+ vertices = np.array(mesh.vertices, dtype=np.float32)
+ faces = np.array(mesh.faces, dtype=np.int64)
+ if faces.shape[0] > target_count:
+ vertices, faces = fast_simplification.simplify(vertices, faces, target_count=target_count)
+ mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
+
+ new_vertices = np.array(mesh.vertices, dtype=np.float32)
+ new_vertex_normals = np.array(mesh.vertex_normals, dtype=np.float32)
+ new_faces = np.array(mesh.faces, dtype=np.int64)
+ new_face_normals = np.array(mesh.face_normals, dtype=np.float32)
+ if joints is not None:
+ new_joints = np.array(joints, dtype=np.float32)
+ else:
+ new_joints = None
+ raw_data = RawData(
+ vertices=new_vertices,
+ vertex_normals=new_vertex_normals,
+ faces=new_faces,
+ face_normals=new_face_normals,
+ joints=new_joints,
+ tails=tails,
+ skin=None,
+ no_skin=None,
+ parents=parents,
+ names=names,
+ matrix_local=matrix_local,
+ )
+ raw_data.check()
+ raw_data.save(path=path)
+
+def extract_builtin(
+ output_folder: str,
+ target_count: int,
+ num_runs: int,
+ id: int,
+ time: str,
+ files: List[Tuple[str, str]],
+):
+ log_path = "./logs"
+ log_path = os.path.join(log_path, time)
+
+ num_files = len(files)
+ gap = num_files // num_runs
+ start = gap * id
+ end = gap * (id + 1)
+ if id+1==num_runs:
+ end = num_files
+
+ files = sorted(files)
+ if end!=-1:
+ files = files[:end]
+ new_log(log_path, f"extract_builtin_{start}_{end}")
+ tot = 0
+ for file in tqdm(files[start:]):
+ input_file = file[0]
+ output_dir = file[1]
+ clean_bpy()
+ new_entry(input_file)
+ try:
+ print(f"Now processing {input_file}...")
+
+ armature = load(input_file)
+
+ print('save to:', output_dir)
+ os.makedirs(output_dir, exist_ok=True)
+
+ vertices, faces = process_mesh()
+ if armature is not None:
+ arranged_bones = get_arranged_bones(armature)
+ joints, tails, parents, names, matrix_local = process_armature(armature, arranged_bones)
+
+ else:
+ joints = None
+ tails = None
+ parents = None
+ names = None
+ matrix_local = None
+
+ save_file = os.path.join(output_dir, 'raw_data.npz')
+ save_raw_data(
+ path=save_file,
+ vertices=vertices,
+ faces=faces-1,
+ joints=joints,
+ tails=tails,
+ parents=parents,
+ names=names,
+ matrix_local=matrix_local,
+ target_count=target_count,
+ )
+
+ tot += 1
+
+ except ValueError as e:
+ add_error(str(e))
+ print(f"ValueError: {str(e)}")
+ except RuntimeError as e:
+ add_error(str(e))
+ print(f"RuntimeError: {str(e)}")
+ except TimeoutError:
+ add_error("time out")
+ print("TimeoutError: Processing timed out")
+ except Exception as e:
+ add_error(f"Unexpected error: {str(e)}")
+ print(f"Unexpected error: {str(e)}")
+ end_log()
+ print(f"{tot} models processed")
+
+def str2bool(v):
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
+ return True
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+ return False
+ else:
+ raise argparse.ArgumentTypeError('Boolean value expected.')
+
+def nullable_string(val):
+ if not val:
+ return None
+ return val
+
+def get_files(
+ data_name: str,
+ input_dataset_dir: str,
+ output_dataset_dir: str,
+ inputs: Union[str, None]=None,
+ require_suffix: List[str]=['obj','fbx','FBX','dae','glb','gltf','vrm'],
+ force_override: bool=False,
+ warning: bool=True,
+) -> List[Tuple[str, str]]:
+
+ files = [] # (input_file, output_dir)
+ if inputs is not None: # specified input file(s)
+ vis = {}
+ inputs = inputs.split(',')
+ for file in inputs:
+ file_name = file.removeprefix("./")
+ # remove suffix
+ file_name = '.'.join(file_name.split('.')[:-1])
+ output_dir = os.path.join(output_dataset_dir, file_name)
+ raw_data_npz = os.path.join(output_dir, data_name)
+ if not force_override and os.path.exists(raw_data_npz):
+ continue
+ if warning and output_dir in vis:
+ print(f"\033[33mWARNING: duplicate output directory: {output_dir}, you need to rename prefix of files to avoid ambiguity\033[0m")
+ vis[output_dir] = True
+ files.append((file, output_dir))
+ else:
+ vis = {}
+ for root, dirs, f in os.walk(input_dataset_dir):
+ for file in f:
+ if file.split('.')[-1] in require_suffix:
+ file_name = file.removeprefix("./")
+ # remove suffix
+ file_name = '.'.join(file_name.split('.')[:-1])
+
+ output_dir = os.path.join(output_dataset_dir, os.path.relpath(root, input_dataset_dir), file_name)
+ raw_data_npz = os.path.join(output_dir, data_name)
+
+ # Check if all required files exist
+ if not force_override and os.path.exists(raw_data_npz):
+ continue
+ if warning and output_dir in vis:
+ print(f"\033[33mWARNING: duplicate output directory: {output_dir}, you need to rename prefix of files to avoid ambiguity\033[0m")
+ vis[output_dir] = True
+ files.append((os.path.join(root, file), output_dir))
+
+ return files
+
+def parse():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', type=str, required=True)
+ parser.add_argument('--require_suffix', type=str, required=True)
+ parser.add_argument('--faces_target_count', type=int, required=True)
+ parser.add_argument('--num_runs', type=int, required=True)
+ parser.add_argument('--force_override', type=str2bool, required=True)
+ parser.add_argument('--id', type=int, required=True)
+ parser.add_argument('--time', type=str, required=True)
+
+ parser.add_argument('--input', type=nullable_string, required=False, default=None)
+ parser.add_argument('--input_dir', type=nullable_string, required=False, default=None)
+ parser.add_argument('--output_dir', type=nullable_string, required=False, default=None)
+ return parser.parse_args()
+
+if __name__ == "__main__":
+ args = parse()
+
+ config = Box(yaml.safe_load(open(args.config, "r")))
+
+ num_runs = args.num_runs
+ id = args.id
+ timestamp = args.time
+ require_suffix = args.require_suffix.split(',')
+ force_override = args.force_override
+ target_count = args.faces_target_count
+
+ if args.input_dir:
+ config.input_dataset_dir = args.input_dir
+ if args.output_dir:
+ config.output_dataset_dir = args.output_dir
+
+ assert config.input_dataset_dir is not None or args.input is None, 'input_dataset_dir must be set (via config or --input_dir) when --input is used'
+
+ files = get_files(
+ data_name='raw_data.npz',
+ inputs=args.input,
+ input_dataset_dir=config.input_dataset_dir,
+ output_dataset_dir=config.output_dataset_dir,
+ require_suffix=require_suffix,
+ force_override=force_override,
+ warning=True,
+ )
+
+ extract_builtin(
+ output_folder=config.output_dataset_dir,
+ target_count=target_count,
+ num_runs=num_runs,
+ id=id,
+ time=timestamp,
+ files=files,
+ )
\ No newline at end of file
diff --git a/UniRig/src/data/log.py b/UniRig/src/data/log.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcf3d211728aa5548de7eb240cdd914ed01a593c
--- /dev/null
+++ b/UniRig/src/data/log.py
@@ -0,0 +1,50 @@
+import os
+from typing import List
+
+log_filepath = ''
+
+class Entry:
+ def __init__(self, entry_name):
+ self.entry = entry_name
+ self.error = None
+ self.warning = []
+
+ def have_error(self):
+ return self.error is not None
+
+ def have_warning(self):
+ return len(self.warning) != 0
+
+logs: List[Entry] = []
+
+def new_log(path, log_name):
+ global log_filepath
+ log_filepath = os.path.join(path, f"{log_name}.txt")
+ os.makedirs(path, exist_ok=True)
+ with open(log_filepath, 'a') as file:
+ file.write(f"Log: {log_name}\n")
+
+def end_log():
+ global log_filepath
+ with open(log_filepath, 'a') as file:
+ file.write(f"End of file\n")
+
+def new_entry(entry_name):
+ global log_filepath
+ print(f"\033[32mNow processing {entry_name}...\033[0m")
+ logs.append(Entry(entry_name))
+
+def add_error(error):
+ global log_filepath
+ print(f"\033[31mError found when processing {logs[-1].entry}: {error}\033[0m")
+ logs[-1].error = error
+ with open(log_filepath, 'a') as file:
+ file.write(f"Entry: {logs[-1].entry}, Error: {error}\n")
+
+def add_warning(warning):
+ global log_filepath
+ print(f"\033[33mWarning found when processing {logs[-1].entry}: {warning}\033[0m")
+ logs[-1].warning.append(warning)
+ with open(log_filepath, 'a') as file:
+ file.write(f"Entry: {logs[-1].entry}, Warning: {warning}\n")
\ No newline at end of file
diff --git a/UniRig/src/data/order.py b/UniRig/src/data/order.py
new file mode 100644
index 0000000000000000000000000000000000000000..a33deeed9ae177d6991e7444f31650e4902c7ab5
--- /dev/null
+++ b/UniRig/src/data/order.py
@@ -0,0 +1,112 @@
+from typing import Dict, List, Tuple, Union
+from collections import defaultdict
+from dataclasses import dataclass
+import yaml
+from box import Box
+
+from .spec import ConfigSpec
+
+@dataclass
+class OrderConfig(ConfigSpec):
+ '''
+ Config to handle bones re-ordering.
+ '''
+
+ # {skeleton_name: path}
+ skeleton_path: Dict[str, str]
+
+ # {cls: {part_name: [bone_name_1, bone_name_2, ...]}}
+ parts: Dict[str, Dict[str, List[str]]]
+
+ # {cls: parts of bones to be arranged in [part_name_1, part_name_2, ...]}
+ parts_order: Dict[str, List[str]]
+
+ @classmethod
+ def parse(cls, config):
+ cls.check_keys(config)
+ skeleton_path = config.skeleton_path
+ parts = {}
+ parts_order = {}
+ for (skeleton_cls, path) in skeleton_path.items(): # do not shadow the classmethod's cls parameter
+ assert skeleton_cls not in parts, 'duplicate skeleton class'
+ d = Box(yaml.safe_load(open(path, 'r')))
+ parts[skeleton_cls] = d.parts
+ parts_order[skeleton_cls] = d.parts_order
+ return OrderConfig(
+ skeleton_path=skeleton_path,
+ parts=parts,
+ parts_order=parts_order,
+ )
+
+class Order():
+
+ # {cls: {part_name: [bone_name_1, bone_name_2, ...]}}
+ parts: Dict[str, Dict[str, List[str]]]
+
+ # {cls: [part_name_1, part_name_2, ...] in arrangement order}
+ parts_order: Dict[str, List[str]]
+
+ def __init__(self, config: OrderConfig):
+ self.parts = config.parts
+ self.parts_order = config.parts_order
+
+ def part_exists(self, cls: str, part: str, names: List[str]) -> bool:
+ '''
+ Check if part exists.
+ '''
+ if part not in self.parts[cls]:
+ return False
+ for name in self.parts[cls][part]:
+ if name not in names:
+ return False
+ return True
+
+ def make_names(self, cls: Union[str, None], parts: List[Union[str, None]], num_bones: int) -> List[str]:
+ '''
+ Get names for specified cls.
+ '''
+ names = []
+ for part in parts:
+ if part is None: # spring
+ continue
+ if cls in self.parts and part in self.parts[cls]:
+ names.extend(self.parts[cls][part])
+ assert len(names) <= num_bones, "number of bones in required skeleton is more than existing bones"
+ for i in range(len(names), num_bones):
+ names.append(f"bone_{i}")
+ return names
+
+ def arrange_names(self, cls: str, names: List[str], parents: List[Union[int, None]]) -> Tuple[List[str], Dict[int, Union[str, None]]]:
+ '''
+ Arrange names according to required parts order.
+ '''
+ if cls not in self.parts_order:
+ return names, {0: None} # add a spring token
+ vis = defaultdict(bool)
+ name_to_id = {name: i for (i, name) in enumerate(names)}
+ new_names = []
+ parts_bias = {}
+ for part in self.parts_order[cls]:
+ if self.part_exists(cls=cls, part=part, names=names):
+ for name in self.parts[cls][part]:
+ vis[name] = True
+ flag = False
+ for name in self.parts[cls][part]:
+ pid = parents[name_to_id[name]]
+ if pid is None:
+ continue
+ if not vis[names[pid]]:
+ flag = True
+ break
+ if flag: # incorrect parts order and should immediately add a spring token
+ break
+ parts_bias[len(new_names)] = part
+ new_names.extend(self.parts[cls][part])
+ parts_bias[len(new_names)] = None # add a spring token
+ for name in names:
+ if name not in new_names:
+ new_names.append(name)
+ return new_names, parts_bias
+
+def get_order(config: OrderConfig) -> Order:
+ return Order(config=config)
\ No newline at end of file
diff --git a/UniRig/src/data/raw_data.py b/UniRig/src/data/raw_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..286342f6f7d4b47aa31f62908ee6d92b90d13131
--- /dev/null
+++ b/UniRig/src/data/raw_data.py
@@ -0,0 +1,307 @@
+from dataclasses import dataclass
+import numpy as np
+from numpy import ndarray
+
+import os
+from typing import Union, List, Tuple
+
+from .exporter import Exporter
+
+from ..tokenizer.spec import DetokenzeOutput
+from .order import Order
+
+@dataclass(frozen=True)
+class RawData(Exporter):
+ '''
+ Dataclass to handle data from processed model files.
+ '''
+
+ # vertices of the mesh, shape (N, 3), float32
+ vertices: Union[ndarray, None]
+
+ # normals of vertices, shape (N, 3), float32
+ vertex_normals: Union[ndarray, None]
+
+ # faces of mesh, shape (F, 3), face id starts from 0 to F-1, int64
+ faces: Union[ndarray, None]
+
+ # face normal of mesh, shape (F, 3), float32
+ face_normals: Union[ndarray, None]
+
+ # joints of bones, shape (J, 3), float32
+ joints: Union[ndarray, None]
+
+ # tails of joints, shape (J, 3), float32
+ tails: Union[ndarray, None]
+
+ # skinning of joints, shape (N, J), float32
+ skin: Union[ndarray, None]
+
+ # whether the joint has skin, bool
+ no_skin: Union[ndarray, None]
+
+ # parents of joints, None represents no parent (a root joint)
+ # make sure parent[k] < k
+ parents: Union[List[Union[int, None]], None]
+
+ # names of joints
+ names: Union[List[str], None]
+
+ # local coordinate
+ matrix_local: Union[ndarray, None]
+
+ # path to data
+ path: Union[str, None]=None
+
+ # data cls
+ cls: Union[str, None]=None
+
+ @staticmethod
+ def load(path: str) -> 'RawData':
+ data = np.load(path, allow_pickle=True)
+ d = {name: data[name][()] for name in data}
+ d['path'] = path
+ return RawData(**d)
+
+ def save(self, path: str):
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ np.savez(file=path, **self.__dict__)
+
+ @property
+ def N(self):
+ '''
+ number of vertices
+ '''
+ return self.vertices.shape[0]
+
+ @property
+ def F(self):
+ '''
+ number of faces
+ '''
+ return self.faces.shape[0]
+
+ @property
+ def J(self):
+ '''
+ number of joints
+ '''
+ return self.joints.shape[0]
+
+ def check(self):
+ if self.names is not None and self.joints is not None:
+ assert len(self.names) == self.J
+ if self.names is not None and self.parents is not None:
+ assert len(self.names) == len(self.parents)
+ if self.parents is not None:
+ for (i, pid) in enumerate(self.parents):
+ if i==0:
+ assert pid is None
+ else:
+ assert pid is not None
+ assert pid < i
+
+ def export_pc(self, path: str, with_normal: bool=True, normal_size=0.01):
+ '''
+ export point cloud
+ '''
+ if with_normal:
+ self._export_pc(vertices=self.vertices, path=path, vertex_normals=self.vertex_normals, normal_size=normal_size)
+ else:
+ self._export_pc(vertices=self.vertices, path=path, vertex_normals=None, normal_size=normal_size)
+
+ def export_mesh(self, path: str):
+ '''
+ export mesh
+ '''
+ self._export_mesh(vertices=self.vertices, faces=self.faces, path=path)
+
+ def export_skeleton(self, path: str):
+ '''
+ export skeleton
+ '''
+ self._export_skeleton(joints=self.joints, parents=self.parents, path=path)
+
+ def export_skeleton_sequence(self, path: str):
+ '''
+ export skeleton sequence
+ '''
+ self._export_skeleton_sequence(joints=self.joints, parents=self.parents, path=path)
+
+ def export_fbx(
+ self,
+ path: str,
+ extrude_size: float=0.03,
+ group_per_vertex: int=-1,
+ add_root: bool=False,
+ do_not_normalize: bool=False,
+ use_extrude_bone: bool=True,
+ use_connect_unique_child: bool=True,
+ extrude_from_parent: bool=True,
+ use_tail: bool=False,
+ custom_vertex_group: Union[ndarray, None]=None,
+ ):
+ '''
+ export the whole model with skinning
+ '''
+ self._export_fbx(
+ path=path,
+ vertices=self.vertices,
+ joints=self.joints,
+ skin=self.skin if custom_vertex_group is None else custom_vertex_group,
+ parents=self.parents,
+ names=self.names,
+ faces=self.faces,
+ extrude_size=extrude_size,
+ group_per_vertex=group_per_vertex,
+ add_root=add_root,
+ do_not_normalize=do_not_normalize,
+ use_extrude_bone=use_extrude_bone,
+ use_connect_unique_child=use_connect_unique_child,
+ extrude_from_parent=extrude_from_parent,
+ tails=self.tails if use_tail else None,
+ )
+
+ def export_render(self, path: str, resolution: Tuple[int, int]=(256, 256)):
+ self._export_render(
+ path=path,
+ vertices=self.vertices,
+ faces=self.faces,
+ bones=np.concatenate([self.joints, self.tails], axis=-1),
+ resolution=resolution,
+ )
+
+@dataclass(frozen=True)
+class RawSkeleton(Exporter):
+ '''
+ Dataclass to handle skeleton from AR.
+ '''
+ # joints of bones, shape (J, 3), float32
+ joints: Union[ndarray, None]
+
+ # tails of joints, shape (J, 3), float32
+ tails: Union[ndarray, None]
+
+ # whether the joint has skin, bool
+ no_skin: Union[ndarray, None]
+
+ # parents of joints, None represents no parent (a root joint)
+ # make sure parent[k] < k
+ parents: Union[List[Union[int, None]], None]
+
+ # names of joints
+ names: Union[List[str], None]
+
+ @staticmethod
+ def load(path: str) -> 'RawSkeleton':
+ data = np.load(path, allow_pickle=True)
+ return RawSkeleton(**{name: data[name][()] for name in data})
+
+ def save(self, path: str):
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ np.savez(file=path, **self.__dict__)
+
+ @staticmethod
+ def from_detokenize_output(res: DetokenzeOutput, order: Union[Order, None]) -> 'RawSkeleton':
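+ '''
+ Rebuild a skeleton from AR detokenization output: for each joint, the
+ parent is chosen as the earlier joint closest to its predicted parent
+ position.
+ '''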
+ J = len(res.bones)
+ names = order.make_names(cls=res.cls, parts=res.parts, num_bones=J)
+ joints = res.joints
+ p_joints = res.p_joints
+ parents = []
+ for (i, joint) in enumerate(joints):
+ if i == 0:
+ parents.append(None)
+ continue
+ p_joint = p_joints[i]
+ dis = float('inf')
+ pid = None
+ for j in reversed(range(i)):
+ n_dis = ((joints[j] - p_joint)**2).sum()
+ if n_dis < dis:
+ pid = j
+ dis = n_dis
+ parents.append(pid)
+ return RawSkeleton(
+ joints=joints,
+ tails=res.tails,
+ no_skin=res.no_skin,
+ parents=parents,
+ names=names,
+ )
+
+ def export_skeleton(self, path: str):
+ '''
+ export skeleton
+ '''
+ self._export_skeleton(joints=self.joints, parents=self.parents, path=path)
+
+ def export_skeleton_sequence(self, path: str):
+ '''
+ export skeleton sequence
+ '''
+ self._export_skeleton_sequence(joints=self.joints, parents=self.parents, path=path)
+
+ def export_fbx(
+ self,
+ path: str,
+ extrude_size: float=0.03,
+ group_per_vertex: int=-1,
+ add_root: bool=False,
+ do_not_normalize: bool=False,
+ use_extrude_bone: bool=True,
+ use_connect_unique_child: bool=True,
+ extrude_from_parent: bool=True,
+ use_tail: bool=False,
+ ):
+ '''
+ export the whole model with skinning
+ '''
+ self._export_fbx(
+ path=path,
+ vertices=None,
+ joints=self.joints,
+ skin=None,
+ parents=self.parents,
+ names=self.names,
+ faces=None,
+ extrude_size=extrude_size,
+ group_per_vertex=group_per_vertex,
+ add_root=add_root,
+ do_not_normalize=do_not_normalize,
+ use_extrude_bone=use_extrude_bone,
+ use_connect_unique_child=use_connect_unique_child,
+ extrude_from_parent=extrude_from_parent,
+ tails=self.tails if use_tail else None,
+ )
+
+ def export_render(self, path: str, resolution: Tuple[int, int]=(256, 256)):
+ self._export_render(
+ path=path,
+ vertices=None,
+ faces=None,
+ bones=np.concatenate([self.joints, self.tails], axis=-1),
+ resolution=resolution,
+ )
+
+@dataclass
+class RawSkin(Exporter):
+ '''
+ Dataclass to handle skinning weights.
+ '''
+ # skin, shape (J, N)
+ skin: ndarray
+
+ # always sampled, shape (N, 3)
+ vertices: Union[ndarray, None]=None
+
+ # for future use, shape (J, 3)
+ joints: Union[ndarray, None]=None
+
+ @staticmethod
+ def load(path: str) -> 'RawSkin':
+ data = np.load(path, allow_pickle=True)
+ return RawSkin(**{name: data[name][()] for name in data})
+
+ def save(self, path: str):
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ np.savez(file=path, **self.__dict__)
\ No newline at end of file
diff --git a/UniRig/src/data/sampler.py b/UniRig/src/data/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..919bb14a514747b45afd16eaa9481f7c4798dccb
--- /dev/null
+++ b/UniRig/src/data/sampler.py
@@ -0,0 +1,210 @@
+from typing import List
+from dataclasses import dataclass
+from abc import ABC, abstractmethod
+import numpy as np
+from numpy import ndarray
+
+from typing import Dict, Tuple
+
+from .asset import Asset
+from .spec import ConfigSpec
+
+@dataclass
+class SamplerConfig(ConfigSpec):
+ '''
+ Config for surface point sampling.
+ '''
+ # which sampler to use
+ method: str
+
+ # how many samples in total
+ num_samples: int
+
+ # how many vertex samples
+ vertex_samples: int
+
+ # kwargs
+ kwargs: Dict[str, Dict]
+
+ @classmethod
+ def parse(cls, config) -> 'SamplerConfig':
+ cls.check_keys(config)
+ return SamplerConfig(
+ method=config.method,
+ num_samples=config.get('num_samples', 0),
+ vertex_samples=config.get('vertex_samples', 0),
+ kwargs=config.get('kwargs', {}),
+ )
+
+@dataclass
+class SamplerResult():
+ # sampled vertices
+ vertices: ndarray
+
+ # sampled normals
+ normals: ndarray
+
+ # sampled vertex groups
+ vertex_groups: Dict[str, ndarray]
+
+class Sampler(ABC):
+ '''
+ Abstract class for samplers.
+ '''
+
+ def _sample_barycentric(
+ self,
+ vertex_group: ndarray,
+ faces: ndarray,
+ face_index: ndarray,
+ random_lengths: ndarray,
+ ):
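+ '''
+ Interpolate per-vertex attributes onto sampled surface points using the
+ same barycentric weights (random_lengths) that produced the samples.
+ '''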
+ v_origins = vertex_group[faces[face_index, 0]]
+ v_vectors = vertex_group[faces[face_index, 1:]]
+ v_vectors -= v_origins[:, np.newaxis, :]
+
+ sample_vector = (v_vectors * random_lengths).sum(axis=1)
+ v_samples = sample_vector + v_origins
+ return v_samples
+
+ @abstractmethod
+ def __init__(self, config: SamplerConfig):
+ pass
+
+ @abstractmethod
+ def sample(
+ self,
+ asset: Asset,
+ ) -> SamplerResult:
+ '''
+ Return sampled vertices, sampled normals and vertex groups.
+ '''
+ pass
+
+class SamplerOrigin(Sampler):
+ def __init__(self, config: SamplerConfig):
+ super().__init__(config)
+ self.num_samples = config.num_samples
+ self.vertex_samples = config.vertex_samples
+
+ def sample(
+ self,
+ asset: Asset,
+ ) -> SamplerResult:
+ perm = np.random.permutation(asset.vertices.shape[0])
+ if asset.vertices.shape[0] < self.num_samples:
+ m = self.num_samples - asset.vertices.shape[0]
+ perm = np.concatenate([perm, np.random.randint(0, asset.vertices.shape[0], (m,))])
+ perm = perm[:self.num_samples]
+ n_v = asset.vertices[perm]
+ n_n = asset.vertex_normals[perm]
+ n_vg = {name: v[perm] for name, v in asset.vertex_groups.items()}
+ return SamplerResult(
+ vertices=n_v,
+ normals=n_n,
+ vertex_groups=n_vg,
+ )
+
+class SamplerMix(Sampler):
+ def __init__(self, config: SamplerConfig):
+ super().__init__(config)
+ self.num_samples = config.num_samples
+ self.vertex_samples = config.vertex_samples
+ assert self.num_samples >= self.vertex_samples, 'num_samples should >= vertex_samples'
+
+ @property
+ def mesh_preserve(self):
+ return self.num_samples==-1
+
+ def sample(
+ self,
+ asset: Asset,
+ ) -> SamplerResult:
+ # 1. sample vertices
+ num_samples = self.num_samples
+ perm = np.random.permutation(asset.vertices.shape[0])
+ vertex_samples = min(self.vertex_samples, asset.vertices.shape[0])
+ num_samples -= vertex_samples
+ perm = perm[:vertex_samples]
+ n_vertex = asset.vertices[perm]
+ n_normal = asset.vertex_normals[perm]
+ n_v = {name: v[perm] for name, v in asset.vertex_groups.items()}
+
+ # 2. sample surface
+ vertex_samples, face_index, random_lengths = sample_surface(
+ num_samples=num_samples,
+ vertices=asset.vertices,
+ faces=asset.faces,
+ return_weight=True,
+ )
+ vertex_samples = np.concatenate([n_vertex, vertex_samples], axis=0)
+ normal_samples = np.concatenate([n_normal, asset.face_normals[face_index]], axis=0)
+ vertex_group_samples = {}
+ for n, v in asset.vertex_groups.items():
+ g = self._sample_barycentric(
+ vertex_group=v,
+ faces=asset.faces,
+ face_index=face_index,
+ random_lengths=random_lengths,
+ )
+ vertex_group_samples[n] = np.concatenate([n_v[n], g], axis=0)
+ return SamplerResult(
+ vertices=vertex_samples,
+ normals=normal_samples,
+ vertex_groups=vertex_group_samples,
+ )
+
+def sample_surface(
+ num_samples: int,
+ vertices: ndarray,
+ faces: ndarray,
+ return_weight: bool=False,
+):
+ '''
+ Randomly pick samples according to face area.
+
+ See sample_surface: https://github.com/mikedh/trimesh/blob/main/trimesh/sample.py
+ '''
+ # get face area
+ offset_0 = vertices[faces[:, 1]] - vertices[faces[:, 0]]
+ offset_1 = vertices[faces[:, 2]] - vertices[faces[:, 0]]
+ face_weight = np.cross(offset_0, offset_1, axis=-1)
+ face_weight = np.sqrt((face_weight * face_weight).sum(axis=1)) # cross-product norm = 2x face area; squaring it would bias toward large faces
+
+ weight_cum = np.cumsum(face_weight, axis=0)
+ face_pick = np.random.rand(num_samples) * weight_cum[-1]
+ face_index = np.searchsorted(weight_cum, face_pick)
+
+ # pull triangles into the form of an origin + 2 vectors
+ tri_origins = vertices[faces[:, 0]]
+ tri_vectors = vertices[faces[:, 1:]]
+ tri_vectors -= np.tile(tri_origins, (1, 2)).reshape((-1, 2, 3))
+
+ # pull the vectors for the faces we are going to sample from
+ tri_origins = tri_origins[face_index]
+ tri_vectors = tri_vectors[face_index]
+
+ # randomly generate two 0-1 scalar components to multiply edge vectors by
+ random_lengths = np.random.rand(len(tri_vectors), 2, 1)
+
+ random_test = random_lengths.sum(axis=1).reshape(-1) > 1.0
+ random_lengths[random_test] -= 1.0
+ random_lengths = np.abs(random_lengths)
+
+ sample_vector = (tri_vectors * random_lengths).sum(axis=1)
+ vertex_samples = sample_vector + tri_origins
+ if not return_weight:
+ return vertex_samples
+ return vertex_samples, face_index, random_lengths
+
+def get_sampler(config: SamplerConfig) -> Sampler:
+ method = config.method
+ if method=='origin':
+ sampler = SamplerOrigin(config)
+ elif method=='mix':
+ sampler = SamplerMix(config)
+ else:
+ raise ValueError(f"sampler method {method} not supported")
+ return sampler
\ No newline at end of file
diff --git a/UniRig/src/data/spec.py b/UniRig/src/data/spec.py
new file mode 100644
index 0000000000000000000000000000000000000000..9331427e1edd1ef69ad87f9b65a29ca0bb71f9f8
--- /dev/null
+++ b/UniRig/src/data/spec.py
@@ -0,0 +1,15 @@
+from abc import ABC, abstractmethod
+from dataclasses import fields
+
+class ConfigSpec(ABC):
+ @classmethod
+ def check_keys(cls, config):
+ expect = [field.name for field in fields(cls)]
+ for key in config.keys():
+ if key not in expect:
+ raise ValueError(f"expect names {expect} in {cls.__name__}, found {key}")
+
+ @classmethod
+ @abstractmethod
+ def parse(cls, config):
+ pass
diff --git a/UniRig/src/data/tail.py b/UniRig/src/data/tail.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba6307ae63c97efdd515ecf715a130319fff1542
--- /dev/null
+++ b/UniRig/src/data/tail.py
@@ -0,0 +1,50 @@
+from collections import defaultdict
+from dataclasses import dataclass
+import numpy as np
+from numpy import ndarray
+
+from typing import Tuple
+
+from .asset import Asset
+from .spec import ConfigSpec
+
+@dataclass
+class TailConfig(ConfigSpec):
+ '''
+ Config to handle tails.
+ '''
+
+ # copy joints to tails
+ copy_joint_to_tail: bool
+
+ # if a joint has exactly one child, connect its tail to that child's joint
+ connect_tail_to_unique_son: bool
+
+ @classmethod
+ def parse(cls, config) -> 'TailConfig':
+ cls.check_keys(config)
+ return TailConfig(
+ copy_joint_to_tail=config.copy_joint_to_tail,
+ connect_tail_to_unique_son=config.connect_tail_to_unique_son,
+ )
+
+class Tail():
+
+ def __init__(self, config: TailConfig):
+ self.config = config
+
+ def process_tail(self, asset: Asset):
+ if self.config.copy_joint_to_tail:
+            assert asset.tails is None, 'copying joints to existing tails is not permitted; set copy_joint_to_tail to False in the transform config'
+ asset.tails = asset.joints.copy()
+ if self.config.connect_tail_to_unique_son and asset.tails is not None:
+ children = defaultdict(list)
+ for (id, p) in enumerate(asset.parents):
+ if p is not None:
+ children[p].append(id)
+ for i in range(asset.J):
+ if len(children[i]) == 1:
+ asset.tails[i] = asset.joints[children[i][0]]
+
+def get_tail(config: TailConfig) -> Tail:
+ return Tail(config=config)
\ No newline at end of file
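
The `connect_tail_to_unique_son` rule only moves the tails of joints with exactly one child. A standalone sketch of the child-counting logic used above:

```python
from collections import defaultdict

parents = [None, 0, 0, 1]  # a root, two children of it, one grandchild
children = defaultdict(list)
for i, p in enumerate(parents):
    if p is not None:
        children[p].append(i)

assert children[0] == [1, 2]  # two children: joint 0's tail is left untouched
assert children[1] == [3]     # unique child: joint 1's tail snaps to joints[3]
```
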
diff --git a/UniRig/src/data/transform.py b/UniRig/src/data/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..018124f17e1758d41062200ecaccf02a5fbbb6ae
--- /dev/null
+++ b/UniRig/src/data/transform.py
@@ -0,0 +1,107 @@
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Union, List, Tuple
+from copy import deepcopy
+
+from .asset import Asset
+from .augment import AugmentConfig, Augment, get_augments
+from .order import OrderConfig, Order, get_order
+from .sampler import SamplerConfig, get_sampler
+from .vertex_group import VertexGroupConfig, get_vertex_groups
+from .tail import TailConfig, get_tail
+from .spec import ConfigSpec
+
+@dataclass
+class TransformConfig(ConfigSpec):
+
+    tail_config: Union[TailConfig, None] = None
+
+    order_config: Union[OrderConfig, None] = None
+
+    vertex_group_config: Union[VertexGroupConfig, None] = None
+
+    augment_config: Union[AugmentConfig, None] = None
+
+    sampler_config: Union[SamplerConfig, None] = None
+
+ @classmethod
+ def parse(cls, config) -> 'TransformConfig':
+ cls.check_keys(config)
+ tail_config = config.get('tail_config', None)
+ order_config = config.get('order_config', None)
+ vertex_group_config = config.get('vertex_group_config', None)
+ augment_config = config.get('augment_config', None)
+ sampler_config = config.get('sampler_config', None)
+
+ if tail_config is not None:
+ tail_config = TailConfig.parse(config=tail_config)
+ if order_config is not None:
+ order_config = OrderConfig.parse(config=order_config)
+ if vertex_group_config is not None:
+ vertex_group_config = VertexGroupConfig.parse(config=vertex_group_config)
+ if augment_config is not None:
+ augment_config = AugmentConfig.parse(config=augment_config)
+ if sampler_config is not None:
+ sampler_config = SamplerConfig.parse(config=sampler_config)
+
+ return TransformConfig(
+ tail_config=tail_config,
+ order_config=order_config,
+ vertex_group_config=vertex_group_config,
+ augment_config=augment_config,
+ sampler_config=sampler_config,
+ )
+
+def transform_asset(
+ asset: Asset,
+ transform_config: TransformConfig,
+) -> Tuple[List[Augment], List[Augment]]:
+ assert isinstance(transform_config, TransformConfig), f"found {type(transform_config)}"
+ # 1. try processing tails
+ # TODO: use a better method
+ if transform_config.tail_config is not None:
+ tail = get_tail(config=transform_config.tail_config)
+ tail.process_tail(asset=asset)
+
+ # 2. arrange bones
+ if transform_config.order_config is not None:
+ order = get_order(config=transform_config.order_config)
+ asset.set_order(order=order)
+
+    # 3. first-stage augments (collapse must run before vertex-group extraction)
+    if transform_config.augment_config is not None:
+ first_augments, second_augments = get_augments(config=transform_config.augment_config)
+ else:
+ first_augments = []
+ second_augments = []
+
+ kwargs = {}
+ for augment in first_augments:
+ augment.transform(asset=asset, **kwargs)
+
+ # 4. get vertex groups
+ if transform_config.vertex_group_config is not None:
+ vertex_groups = get_vertex_groups(config=transform_config.vertex_group_config)
+ d = {}
+ for v in vertex_groups:
+ d.update(v.get_vertex_group(asset=asset))
+ asset.vertex_groups = d
+ else:
+ asset.vertex_groups = {}
+
+ # 5. regular augments
+ for augment in second_augments:
+ augment.transform(asset=asset, **kwargs)
+
+ # 6. sample
+ if transform_config.sampler_config is not None:
+ sampler = get_sampler(config=transform_config.sampler_config)
+ res = sampler.sample(asset=asset)
+ asset.sampled_vertices = res.vertices
+ asset.sampled_normals = res.normals
+ asset.sampled_vertex_groups = res.vertex_groups
+ else:
+ asset.sampled_vertices = asset.vertices.copy()
+ asset.sampled_normals = asset.vertex_normals.copy()
+ asset.sampled_vertex_groups = deepcopy(asset.vertex_groups)
+ return first_augments, second_augments
\ No newline at end of file
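
`transform_asset` always runs the same fixed pipeline (tails, bone order, first-stage augments, vertex groups, second-stage augments, sampling), skipping each stage whose sub-config is `None`. A minimal config sketch, assuming `box.Box` as the config carrier like elsewhere in the repo:

```python
from box import Box

cfg = TransformConfig.parse(Box({
    'tail_config': {
        'copy_joint_to_tail': False,
        'connect_tail_to_unique_son': True,
    },
    'sampler_config': None,  # fall back to the raw mesh vertices
}))
# transform_asset(asset, cfg) would then run only the tail stage and
# copy the raw vertices/normals as the "sampled" outputs
```
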
diff --git a/UniRig/src/data/utils.py b/UniRig/src/data/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..aec48872ce4332d02017f4ffaf3ee8e7aad6348e
--- /dev/null
+++ b/UniRig/src/data/utils.py
@@ -0,0 +1,258 @@
+import torch
+import numpy as np
+from numpy import ndarray
+from torch import Tensor, FloatTensor
+from typing import Tuple, Union
+
+from scipy.spatial.transform import Rotation as R
+from scipy.sparse import csc_matrix
+import numpy as np
+
+def quaternion_to_matrix(x, use_4x4=True) -> FloatTensor:
+ """
+ Ref: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#quaternion_to_matrix
+ """
+ if not isinstance(x, Tensor):
+ quaternions = torch.tensor(x, dtype=torch.float32)
+ else:
+ quaternions = x
+ r, i, j, k = torch.unbind(quaternions, -1)
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
+ device = quaternions.device
+
+ if use_4x4:
+ o = torch.stack(
+ (
+ 1 - two_s * (j * j + k * k),
+ two_s * (i * j - k * r),
+ two_s * (i * k + j * r),
+ torch.zeros(quaternions.shape[:-1], device=device, dtype=torch.float32),
+ two_s * (i * j + k * r),
+ 1 - two_s * (i * i + k * k),
+ two_s * (j * k - i * r),
+ torch.zeros(quaternions.shape[:-1], device=device, dtype=torch.float32),
+ two_s * (i * k - j * r),
+ two_s * (j * k + i * r),
+ 1 - two_s * (i * i + j * j),
+ torch.zeros(quaternions.shape[:-1], device=device, dtype=torch.float32),
+ torch.zeros(quaternions.shape[:-1], device=device, dtype=torch.float32),
+ torch.zeros(quaternions.shape[:-1], device=device, dtype=torch.float32),
+ torch.zeros(quaternions.shape[:-1], device=device, dtype=torch.float32),
+ torch.ones(quaternions.shape[:-1], device=device, dtype=torch.float32),
+ ),
+ -1,
+ )
+ return o.reshape(quaternions.shape[:-1] + (4, 4))
+ else:
+ o = torch.stack(
+ (
+ 1 - two_s * (j * j + k * k),
+ two_s * (i * j - k * r),
+ two_s * (i * k + j * r),
+ two_s * (i * j + k * r),
+ 1 - two_s * (i * i + k * k),
+ two_s * (j * k - i * r),
+ two_s * (i * k - j * r),
+ two_s * (j * k + i * r),
+ 1 - two_s * (i * i + j * j),
+ ),
+ -1,
+ )
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+def axis_angle_to_quaternion(axis_angle: FloatTensor) -> FloatTensor:
+ """
+ Ref: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#axis_angle_to_quaternion
+ """
+ angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
+ half_angles = angles * 0.5
+ eps = 1e-6
+ small_angles = angles.abs() < eps
+ sin_half_angles_over_angles = torch.empty_like(angles)
+ sin_half_angles_over_angles[~small_angles] = (
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
+ )
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
+ sin_half_angles_over_angles[small_angles] = (
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
+ )
+ quaternions = torch.cat(
+ [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
+ )
+ return quaternions
+
+def axis_angle_to_matrix(axis_angle: Union[FloatTensor, ndarray]) -> Union[FloatTensor, ndarray]:
+ """
+ Ref: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#axis_angle_to_matrix
+ """
+    if isinstance(axis_angle, Tensor):
+        return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
+ else:
+ res = np.pad(R.from_rotvec(axis_angle).as_matrix(), ((0, 0), (0, 1), (0, 1)), 'constant', constant_values=((0, 0), (0, 0), (0, 0)))
+ assert res.ndim == 3
+ res[:, -1, -1] = 1
+ return res
+
+def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
+ """
+ Returns torch.sqrt(torch.max(0, x))
+ but with a zero subgradient where x is 0.
+ """
+ ret = torch.zeros_like(x)
+ positive_mask = x > 0
+ if torch.is_grad_enabled():
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
+ else:
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
+ return ret
+
+def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert a unit quaternion to a standard form: one in which the real
+ part is non negative.
+
+ Args:
+ quaternions: Quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Standardized quaternions as tensor of shape (..., 4).
+ """
+ return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
+
+def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as rotation matrices to quaternions.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+ Returns:
+ quaternions with real part first, as tensor of shape (..., 4).
+ """
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+
+ batch_dim = matrix.shape[:-2]
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
+ matrix.reshape(batch_dim + (9,)), dim=-1
+ )
+
+ q_abs = _sqrt_positive_part(
+ torch.stack(
+ [
+ 1.0 + m00 + m11 + m22,
+ 1.0 + m00 - m11 - m22,
+ 1.0 - m00 + m11 - m22,
+ 1.0 - m00 - m11 + m22,
+ ],
+ dim=-1,
+ )
+ )
+
+ # we produce the desired quaternion multiplied by each of r, i, j, k
+ quat_by_rijk = torch.stack(
+ [
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
+ ],
+ dim=-2,
+ )
+
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
+ # the candidate won't be picked.
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
+
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
+ # forall i; we pick the best-conditioned one (with the largest denominator)
+ out = quat_candidates[
+ torch.nn.functional.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
+ ].reshape(batch_dim + (4,))
+ return standardize_quaternion(out)
+
+def linear_blend_skinning(
+ vertex: Union[FloatTensor, ndarray],
+ matrix_local: Union[FloatTensor, ndarray],
+ matrix: Union[FloatTensor, ndarray],
+ skin: Union[FloatTensor, ndarray],
+ pad: int=0,
+ value: float=0.,
+) -> Union[FloatTensor, ndarray]:
+ '''
+ Args:
+ vertex: (B, N, 4-pad) or (N, 4-pad)
+ matrix_local: (B, J, 4, 4) or (J, 4, 4)
+ matrix: (B, J, 4, 4) or (J, 4, 4)
+ skin: (B, N, J) or (N, J), value of pseudo bones should be 0
+ Returns:
+ (B, N, 3) or (N, 3)
+ '''
+ assert vertex.shape[-1] + pad == 4
+ if isinstance(vertex, Tensor):
+ dims = vertex.dim()
+ elif isinstance(vertex, ndarray):
+ dims = vertex.ndim
+ else:
+ raise NotImplementedError()
+ if dims == 3: # Case: (B, N, 3+pad)
+ assert isinstance(vertex, Tensor)
+ J = matrix_local.shape[1]
+ # (B, J, 3+pad, N)
+ offset = (
+ matrix_local.inverse() @
+ torch.nn.functional.pad(vertex, (0, pad, 0, 0, 0, 0), value=value).unsqueeze(1).transpose(2, 3).repeat(1, J, 1, 1)
+ )
+ # (B, J, 4, N)
+ per_bone_matrix = matrix @ offset
+ # (B, J, 4, N)
+ weighted_per_bone_matrix = skin.transpose(1, 2).unsqueeze(2) * per_bone_matrix
+ # (B, 3, N)
+ g = weighted_per_bone_matrix.sum(dim=1)
+ # (B, 3, N)
+ final = g[:, 0:3, :] / (skin.transpose(1, 2).sum(dim=1) + 1e-8).unsqueeze(1)
+ return final.permute(0, 2, 1)
+
+ elif dims == 2: # Case: (N, 3+pad)
+ if isinstance(vertex, Tensor):
+ J = matrix_local.shape[0]
+ offset = (
+ matrix_local.inverse() @
+ torch.nn.functional.pad(vertex, (0, pad, 0, 0), value=value).unsqueeze(0).transpose(1, 2).repeat(J, 1, 1)
+ )
+ per_bone_matrix = matrix @ offset
+ weighted_per_bone_matrix = skin.transpose(0, 1).unsqueeze(1) * per_bone_matrix
+ g = weighted_per_bone_matrix.sum(dim=0)
+ final = g[0:3, :] / (skin.transpose(0, 1).sum(dim=0) + 1e-8).unsqueeze(0)
+ return final.permute(1, 0) # Output shape (N, 3)
+ else:
+ J = matrix_local.shape[0]
+ N = vertex.shape[0]
+ # (4, N)
+ padded = np.pad(vertex, ((0, 0), (0, pad)), 'constant', constant_values=(0, value)).T
+ # (J, 4, 4)
+ trans = matrix @ np.linalg.inv(matrix_local)
+ weighted_per_bone_matrix = []
+ # (J, N)
+ mask = (skin > 0).T
+ for i in range(J):
+ offset = np.zeros((4, N), dtype=np.float32)
+ offset[:, mask[i]] = (trans[i] @ padded[:, mask[i]]) * skin.T[i, mask[i]]
+ weighted_per_bone_matrix.append(offset)
+ weighted_per_bone_matrix = np.stack(weighted_per_bone_matrix)
+ g = np.sum(weighted_per_bone_matrix, axis=0)
+ final = g[:3, :] / (np.sum(skin, axis=1) + 1e-8)
+ return final.T
+ else:
+        raise ValueError(f'unsupported shape: {vertex.shape}')
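
As a sanity check on the numpy branch of `linear_blend_skinning`: when every bone's pose matrix equals its rest matrix, the skinned output must reproduce the input vertices. A minimal sketch (assuming the function is importable from `src/data/utils.py`):

```python
import numpy as np

N, J = 4, 2
vertex = np.random.rand(N, 3).astype(np.float32)        # (N, 3); pad=1 below
eye = np.tile(np.eye(4, dtype=np.float32), (J, 1, 1))   # rest == pose
skin = np.full((N, J), 0.5, dtype=np.float32)           # weights sum to 1

out = linear_blend_skinning(vertex, eye, eye, skin, pad=1, value=1.0)
assert np.allclose(out, vertex, atol=1e-4)              # identity transform
```
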
diff --git a/UniRig/src/data/vertex_group.py b/UniRig/src/data/vertex_group.py
new file mode 100644
index 0000000000000000000000000000000000000000..645e065712cd825452c359f201893ad28656d916
--- /dev/null
+++ b/UniRig/src/data/vertex_group.py
@@ -0,0 +1,577 @@
+import platform
+import os
+if platform.system() == "Linux":
+ os.environ['PYOPENGL_PLATFORM'] = 'egl'
+
+from typing import Dict, List, Tuple
+from dataclasses import dataclass
+from collections import defaultdict
+from abc import ABC, abstractmethod
+import numpy as np
+from numpy import ndarray
+
+from scipy.spatial import cKDTree
+from scipy.sparse import csr_matrix
+from scipy.sparse.csgraph import shortest_path, connected_components
+
+from .asset import Asset
+from .spec import ConfigSpec
+
+@dataclass
+class VertexGroupConfig(ConfigSpec):
+ '''
+ Config to sample vertex group.
+ '''
+
+ # names
+ names: List[str]
+
+ # kwargs
+ kwargs: Dict[str, Dict]
+
+ @classmethod
+ def parse(cls, config) -> 'VertexGroupConfig':
+ cls.check_keys(config)
+ return VertexGroupConfig(
+ names=config.get('names', []),
+ kwargs=config.get('kwargs', {}),
+ )
+
+class VertexGroup(ABC):
+
+ @abstractmethod
+ def __init__(self, **kwargs):
+ pass
+
+ @abstractmethod
+ def get_vertex_group(self, asset: Asset) -> Dict[str, ndarray]:
+ pass
+
+class VertexGroupSkin(VertexGroup):
+ '''
+ Capture skin.
+ '''
+
+ def __init__(self, **kwargs):
+ pass
+
+ def get_vertex_group(self, asset: Asset) -> Dict[str, ndarray]:
+ return {
+ 'skin': asset.skin / (asset.skin.sum(axis=-1, keepdims=True) + 1e-6),
+ }
+
+class VertexGroupGeodesicDistance(VertexGroup):
+ '''
+ Calculate geodesic distance.
+ '''
+ def __init__(self, **kwargs):
+ self.deterministic = kwargs.get('deterministic', False)
+ self.soft_mask = kwargs.get('soft_mask', False)
+
+ def _prepare(
+ self,
+ joints: ndarray, # (J, 3)
+ edges: List[Tuple[int, int]],
+ ) -> Tuple[ndarray, ndarray]:
+ J = joints.shape[0]
+ dis_matrix = np.ones((J, J)) * 100.0
+ step_matrix = np.ones((J, J)) * 100.0
+ def dis(x: ndarray, y: ndarray):
+ return np.linalg.norm(x-y)
+ for i in range(J):
+ dis_matrix[i, i] = 0.
+ step_matrix[i, i] = 0.
+ for edge in edges:
+ dis_matrix[edge[0], edge[1]] = dis(joints[edge[0]], joints[edge[1]])
+ dis_matrix[edge[1], edge[0]] = dis(joints[edge[0]], joints[edge[1]])
+ step_matrix[edge[0], edge[1]] = 1
+ step_matrix[edge[1], edge[0]] = 1
+        # Floyd-Warshall all-pairs shortest paths over the joint graph
+ for k in range(J):
+ dis_matrix = np.minimum(dis_matrix, dis_matrix[:, k][:, np.newaxis] + dis_matrix[k, :][np.newaxis, :])
+ step_matrix = np.minimum(step_matrix, step_matrix[:, k][:, np.newaxis] + step_matrix[k, :][np.newaxis, :])
+ return dis_matrix, step_matrix
+
+ def get_vertex_group(self, asset: Asset) -> Dict[str, ndarray]:
+ children = defaultdict(list)
+ edges = []
+ for (id, p) in enumerate(asset.parents):
+ if p is not None:
+ edges.append((id, p))
+ children[p].append(id)
+ child = []
+ tails = asset.tails.copy()
+ for id in range(asset.J):
+ if len(children[id]) == 1:
+ child.append(children[id][0])
+ else:
+ child.append(id)
+ if self.deterministic:
+ tails[id] = asset.joints[id]
+ child = np.array(child)
+ dis_matrix, step_matrix = self._prepare(
+ joints=asset.joints,
+ edges=edges,
+ )
+ geo_dis, geo_mask = get_geodesic_distance(
+ vertices=asset.vertices,
+ joints=asset.joints,
+ tails=tails,
+ dis_matrix=dis_matrix,
+ step_matrix=step_matrix,
+ child=child,
+ soft_mask=self.soft_mask,
+ )
+ return {
+ 'geodesic_distance': geo_dis,
+ 'geodesic_mask': geo_mask,
+ }
+
+class VertexGroupVoxelSkin(VertexGroup):
+ '''
+ Capture voxel skin.
+ '''
+
+ def __init__(self, **kwargs):
+ self.grid = kwargs.get('grid', 64)
+ self.alpha = kwargs.get('alpha', 0.5)
+ self.link_dis = kwargs.get('link_dis', 0.00001)
+ self.grid_query = kwargs.get('grid_query', 27)
+ self.vertex_query = kwargs.get('vertex_query', 27)
+ self.grid_weight = kwargs.get('grid_weight', 3.0)
+ self.mode = kwargs.get('mode', 'square')
+
+ def get_vertex_group(self, asset: Asset) -> Dict[str, ndarray]:
+
+ # normalize into [-1, 1] first
+ min_vals = np.min(asset.vertices, axis=0)
+ max_vals = np.max(asset.vertices, axis=0)
+
+ center = (min_vals + max_vals) / 2
+
+ scale = np.max(max_vals - min_vals) / 2
+
+ normalized_vertices = (asset.vertices - center) / scale
+ normalized_joints = (asset.joints - center) / scale
+
+ grid_indices, grid_coords = voxelization(
+ vertices=normalized_vertices,
+ faces=asset.faces,
+ grid=self.grid,
+ )
+ skin = voxel_skin(
+ grid=self.grid,
+ grid_coords=grid_coords,
+ joints=normalized_joints,
+ vertices=normalized_vertices,
+ faces=asset.faces,
+ alpha=self.alpha,
+ link_dis=self.link_dis,
+ grid_query=self.grid_query,
+ vertex_query=self.vertex_query,
+ grid_weight=self.grid_weight,
+ mode=self.mode,
+ )
+ skin = np.nan_to_num(skin, nan=0., posinf=0., neginf=0.)
+ return {
+ 'voxel_skin': skin,
+ }
+
+class VertexGroupMeshPartDistance(VertexGroup):
+ def __init__(self, **kwargs):
+ self.part_dim = kwargs['part_dim']
+ self.dis_dim = kwargs['dis_dim']
+
+ def get_vertex_group(self, asset: Asset) -> Dict[str, ndarray]:
+ tot, vertex_labels, face_labels = find_connected_components(asset.vertices, asset.faces)
+ # (N, dis_dim)
+ part_distances = compute_distances_in_components(asset.vertices, asset.faces, vertex_labels, tot, self.dis_dim)
+        # (tot, part_dim)
+        spread_vectors = generate_spread_vectors(tot, self.part_dim)
+        # (N, part_dim)
+        part_vectors = np.zeros((asset.vertices.shape[0], self.part_dim))
+        for i in range(tot):
+            part_vectors[vertex_labels == i] = spread_vectors[i]
+ return {
+ 'num_parts': tot,
+ 'part_vectors': part_vectors,
+ 'part_distances': part_distances,
+ }
+
+# TODO: move this into a new file
+class VertexGroupMeshParts(VertexGroup):
+ def __init__(self, **kwargs):
+ pass
+
+ def get_vertex_group(self, asset: Asset) -> Dict[str, ndarray]:
+ tot, vertex_labels, face_labels = find_connected_components(asset.vertices, asset.faces)
+ asset.meta['num_parts'] = tot
+ asset.meta['vertex_labels'] = vertex_labels
+ asset.meta['face_labels'] = face_labels
+ return {}
+
+def get_geodesic_distance(
+ vertices: ndarray, # (N, 3)
+ joints: ndarray, # (J, 3)
+ tails: ndarray, # (J, 3)
+ dis_matrix: ndarray, # (J, J)
+ step_matrix: ndarray, # (J, J)
+ child: ndarray,
+ eps: float=1e-4,
+ soft_mask: bool=False,
+) -> Tuple[ndarray, ndarray]:
+ # (J, 3)
+ offset = tails - joints
+ inv = (1./(offset * offset + eps).sum(axis=-1))[np.newaxis, ...]
+ # head
+ g0 = tails[np.newaxis, ...] - vertices[:, np.newaxis, :]
+ c0 = (g0 * offset[np.newaxis, ...]).sum(axis=-1) * inv
+ # tail
+ g1 = vertices[:, np.newaxis, :] - joints[np.newaxis, ...]
+ c1 = (g1 * offset[np.newaxis, ...]).sum(axis=-1) * inv
+ # (N, J)
+ scale0 = (np.clip(c0, 0., 1.) + eps) / (np.clip(c0, 0., 1.) + np.clip(c1, 0., 1.) + eps * 2)
+ scale1 = -scale0 + 1
+ # (N, J, 3)
+ nearest = scale0[..., np.newaxis] * joints[np.newaxis, ...] + scale1[..., np.newaxis] * tails[np.newaxis, ...]
+ # (N, J)
+ dis = np.linalg.norm(vertices[:, np.newaxis, :] - nearest, axis=-1)
+ # (N)
+ index = np.argmin(dis, axis=1)
+ # (N)
+ r = np.arange(dis.shape[0])
+ # (N, J)
+ res = (
+ dis_matrix[index] * scale0[r[:, np.newaxis], index[:, np.newaxis]] +
+ dis_matrix[child[index]] * scale1[r[:, np.newaxis], index[:, np.newaxis]]
+ )
+ if soft_mask:
+ mask = (1.0 - (
+ step_matrix[index] * scale0[r[:, np.newaxis], index[:, np.newaxis]] +
+ step_matrix[child[index]] * scale1[r[:, np.newaxis], index[:, np.newaxis]]
+ )).clip(0., 1.).astype(np.float32)
+ else:
+ mask = ((
+ step_matrix[index] * scale0[r[:, np.newaxis], index[:, np.newaxis]] +
+ step_matrix[child[index]] * scale1[r[:, np.newaxis], index[:, np.newaxis]]
+ ) <= 1.).astype(np.float32)
+
+ # normalize geo dis
+ row_min = np.min(res, axis=0, keepdims=True)
+ row_max = np.max(res, axis=0, keepdims=True)
+ res = (res - row_min) / (row_max - row_min)
+ res = np.nan_to_num(res, nan=0., posinf=0., neginf=0.)
+ return res, mask
+
+def get_vertex_groups(config: VertexGroupConfig) -> List[VertexGroup]:
+ vertex_groups = []
+ MAP = {
+ 'geodesic_distance': VertexGroupGeodesicDistance,
+ 'skin': VertexGroupSkin,
+ 'voxel_skin': VertexGroupVoxelSkin,
+ 'mesh_part_distance': VertexGroupMeshPartDistance,
+ 'mesh_parts': VertexGroupMeshParts,
+ }
+ for name in config.names:
+ assert name in MAP, f"expect: [{','.join(MAP.keys())}], found: {name}"
+ vertex_groups.append(MAP[name](**config.kwargs.get(name, {})))
+ return vertex_groups
+
+def voxelization(
+ vertices: ndarray,
+ faces: ndarray,
+ grid: int=256,
+ scale: float=1.0,
+):
+ import pyrender
+ znear = 0.05
+ zfar = 4.0
+ eye_dis = 2.0 # distance from eye to origin
+ r_faces = np.stack([faces[:, 0], faces[:, 2], faces[:, 1]], axis=-1)
+ # get zbuffers
+ mesh = pyrender.Mesh(
+ primitives=[
+ pyrender.Primitive(
+ positions=vertices,
+ indices=np.concatenate([faces, r_faces]), # double sided
+ mode=pyrender.GLTF.TRIANGLES,
+ )
+ ]
+ )
+ scene = pyrender.Scene(bg_color=[0, 0, 0, 0])
+ scene.add(mesh)
+
+ camera = pyrender.OrthographicCamera(xmag=scale, ymag=scale, znear=znear, zfar=zfar)
+ camera_poses = {}
+ # coordinate:
+ # see https://pyrender.readthedocs.io/en/latest/examples/cameras.html
+ camera_poses['+z'] = np.array([
+ [1, 0, 0, 0],
+ [0, 1, 0, 0],
+ [0, 0, 1, eye_dis],
+ [0, 0, 0, 1],
+ ], dtype=np.float32) # look at +z (bottom to top)
+ camera_poses['-z'] = np.array([
+ [-1, 0, 0, 0],
+ [ 0, 1, 0, 0],
+ [ 0, 0,-1, -eye_dis],
+ [ 0, 0, 0, 1],
+ ], dtype=np.float32) # look at -z (top to bottom)
+ camera_poses['+y'] = np.array([
+ [1, 0, 0, 0],
+ [0, 0,-1, -eye_dis],
+ [0, 1, 0, 0],
+ [0, 0, 0, 1],
+ ], dtype=np.float32) # look at +y (because model is looking at -y)(front to back)
+ camera_poses['-y'] = np.array([
+ [1, 0, 0, 0],
+ [0, 0, 1, eye_dis],
+ [0,-1, 0, 0],
+ [0, 0, 0, 1],
+ ], dtype=np.float32) # look at -y (back to front)
+ camera_poses['+x'] = np.array([
+ [0, 0,-1, -eye_dis],
+ [0, 1, 0, 0],
+ [1, 0, 0, 0],
+ [0, 0, 0, 1],
+ ], dtype=np.float32) # look at +x (left to right)
+ camera_poses['-x'] = np.array([
+ [ 0, 0, 1, eye_dis],
+ [ 0, 1, 0, 0],
+ [-1, 0, 0, 0],
+ [ 0, 0, 0, 1],
+    ], dtype=np.float32) # look at -x (right to left)
+ for name, pose in camera_poses.items():
+ scene.add(camera, name=name, pose=pose)
+ camera_nodes = [node for node in scene.get_nodes() if isinstance(node, pyrender.Node) and node.camera is not None]
+ renderer = pyrender.OffscreenRenderer(viewport_width=grid, viewport_height=grid)
+
+ i, j, k = np.indices((grid, grid, grid))
+ grid_indices = np.stack((i.ravel(), j.ravel(), k.ravel()), axis=1, dtype=np.int64)
+ grid_coords = np.stack((i.ravel(), j.ravel(), grid-1-k.ravel()), axis=1, dtype=np.float32) * 2 / grid - 1.0 + 1.0 / grid # every position is in the middle of the grid
+ depths = {}
+ for cam_node in camera_nodes:
+ # a = time.time()
+ scene.main_camera_node = cam_node
+ name = cam_node.name
+ proj_depth = renderer.render(scene, flags=pyrender.constants.RenderFlags.DEPTH_ONLY | pyrender.constants.RenderFlags.OFFSCREEN)
+        proj_depth[proj_depth == 0] = zfar  # background pixels report depth 0; push them to the far plane
+        depths[name] = proj_depth
+
+def find_connected_components(vertices: ndarray, faces: ndarray) -> Tuple[int, ndarray, ndarray]:
+    '''
+    Find connected components of a mesh.
+
+    Returns:
+        int: number of connected components
+        ndarray: per-vertex component labels
+        ndarray: per-face component labels
+    '''
+ N = vertices.shape[0]
+ edges = []
+ for face in faces:
+ v0, v1, v2 = face
+ edges.append([v0, v1])
+ edges.append([v1, v2])
+ edges.append([v2, v0])
+
+ edges = np.array(edges)
+ row = edges[:, 0]
+ col = edges[:, 1]
+ data = np.ones(len(edges), dtype=int)
+ adj_matrix = csr_matrix((data, (row, col)), shape=(N, N))
+ adj_matrix = adj_matrix + adj_matrix.T
+
+ tot, vertex_labels = connected_components(adj_matrix, directed=False, return_labels=True)
+ face_labels = vertex_labels[faces[:, 0]]
+ return tot, vertex_labels, face_labels
+
+def compute_distances_in_components(vertices: ndarray, faces: ndarray, vertex_labels: ndarray, tot: int, k: int) -> ndarray:
+ N = vertices.shape[0]
+ edges = []
+ weights = []
+ for face in faces:
+ v0, v1, v2 = face
+ w01 = np.linalg.norm(vertices[v0] - vertices[v1])
+ w12 = np.linalg.norm(vertices[v1] - vertices[v2])
+ w20 = np.linalg.norm(vertices[v2] - vertices[v0])
+ edges.extend([[v0, v1], [v1, v2], [v2, v0]])
+ weights.extend([w01, w12, w20])
+
+ edges = np.array(edges)
+ weights = np.array(weights)
+ row = edges[:, 0]
+ col = edges[:, 1]
+ adj_matrix = csr_matrix((weights, (row, col)), shape=(N, N))
+ adj_matrix = adj_matrix + adj_matrix.T
+
+ distance_matrix = np.full((N, k), np.inf) # (N, k)
+
+ for component_id in range(tot):
+ component_mask = (vertex_labels == component_id)
+ component_vertices_idx = np.where(component_mask)[0]
+ n_component = len(component_vertices_idx)
+
+ if n_component == 0:
+ continue
+
+ if n_component >= k:
+ sampled_indices = np.random.permutation(n_component)[:k]
+ else:
+ sampled_indices = np.concatenate([
+ np.random.permutation(n_component),
+ np.random.randint(0, n_component, k - n_component)
+ ])
+ sampled_vertices = component_vertices_idx[sampled_indices]
+
+ dist_matrix = shortest_path(adj_matrix, indices=sampled_vertices, directed=False)
+ dist_matrix = dist_matrix[:, component_mask].T
+ # normalize into [0, 1]
+ max_value = dist_matrix.max()
+ min_value = dist_matrix.min()
+ if max_value < min_value + 1e-6:
+ dist_matrix[...] = 0.
+ else:
+ dist_matrix = (dist_matrix - min_value) / (max_value - min_value)
+
+ distance_matrix[component_mask, :] = dist_matrix
+
+ return distance_matrix
+
+def generate_spread_vectors(tot: int, dim: int, iterations: int=100, lr: float=1.0) -> ndarray:
+ if tot <= 0:
+ return None
+
+ vectors = np.random.randn(tot, dim)
+ vectors = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)
+ vectors = np.nan_to_num(vectors, nan=1.0, posinf=1.0, neginf=1.0)
+
+ for _ in range(iterations):
+ diff = vectors[np.newaxis, :, :] - vectors[:, np.newaxis, :]
+ norm_sq = np.sum(diff ** 2, axis=2)
+ weight = 1. / (norm_sq + 1.)
+ vectors += np.sum(diff * weight[:, :, np.newaxis] * lr, axis=1)
+ vectors = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)
+
+ return vectors
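
`find_connected_components` underpins both the part-distance and mesh-parts vertex groups; two disjoint triangles give the expected labels. A standalone sketch (assuming the module is importable):

```python
import numpy as np

vertices = np.array([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.],
                     [5., 5., 5.], [6., 5., 5.], [5., 6., 5.]])
faces = np.array([[0, 1, 2], [3, 4, 5]])

tot, vertex_labels, face_labels = find_connected_components(vertices, faces)
assert tot == 2
assert (vertex_labels[:3] == vertex_labels[0]).all()  # first triangle is one part
assert face_labels.shape == (2,)                      # one label per face
```
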
diff --git a/UniRig/src/inference/download.py b/UniRig/src/inference/download.py
new file mode 100644
index 0000000000000000000000000000000000000000..febd2bbe62edad62de66f09686605574623fb623
--- /dev/null
+++ b/UniRig/src/inference/download.py
@@ -0,0 +1,19 @@
+from huggingface_hub import hf_hub_download
+
+def download(ckpt_name: str) -> str:
+ MAP = {
+ 'experiments/skeleton/articulation-xl_quantization_256/model.ckpt': 'skeleton/articulation-xl_quantization_256/model.ckpt',
+ 'experiments/skin/articulation-xl/model.ckpt': 'skin/articulation-xl/model.ckpt',
+ }
+
+ try:
+ if ckpt_name not in MAP:
+ print(f"not found: {ckpt_name}")
+ return ckpt_name
+ return hf_hub_download(
+ repo_id='VAST-AI/UniRig',
+ filename=MAP[ckpt_name],
+ )
+ except Exception as e:
+ print(f"Failed to download {ckpt_name}: {e}")
+ return ckpt_name
\ No newline at end of file
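
`download` maps the two known experiment checkpoint paths to files in the `VAST-AI/UniRig` Hugging Face repo and returns anything else untouched, so local checkpoints keep working. Usage (needs network access or a warm Hugging Face cache):

```python
ckpt = download('experiments/skeleton/articulation-xl_quantization_256/model.ckpt')
local = download('my/own/model.ckpt')  # prints "not found: ..." and returns the path as-is
```
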
diff --git a/UniRig/src/inference/get_list.py b/UniRig/src/inference/get_list.py
new file mode 100644
index 0000000000000000000000000000000000000000..3286e91a6ae6da60085e4326e1313ef6bd09c37d
--- /dev/null
+++ b/UniRig/src/inference/get_list.py
@@ -0,0 +1,25 @@
+import os
+import argparse
+from tqdm import tqdm
+from box import Box
+import yaml
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', type=str)
+ args = parser.parse_args()
+
+ config = Box(yaml.safe_load(open(args.config, "r")))
+
+ dataset = config.output_dataset_dir
+
+ paths = []
+ for root, dirs, files in tqdm(os.walk(dataset)):
+ for file in files:
+ if file == 'raw_data.npz':
+ paths.append(os.path.relpath(root, dataset))
+ tot = len(paths)
+ os.makedirs(dataset, exist_ok=True)
+    with open(os.path.join(dataset, "inference_datalist.txt"), 'w') as f:
+        f.writelines('\n'.join(paths))
\ No newline at end of file
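
The datalist is simply every directory, relative to the dataset root, that contains a `raw_data.npz`. A self-contained sketch of the same walk on a throwaway directory (hypothetical layout, not from the repo):

```python
import os, tempfile

root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, 'props', 'chest'))
open(os.path.join(root, 'props', 'chest', 'raw_data.npz'), 'wb').close()

paths = [os.path.relpath(r, root)
         for r, _, files in os.walk(root) if 'raw_data.npz' in files]
print(paths)  # ['props/chest'] (path separator differs on Windows)
```
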
diff --git a/UniRig/src/inference/merge.py b/UniRig/src/inference/merge.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b3613cc29e4e246563e9408cc6ce2d1fed614a5
--- /dev/null
+++ b/UniRig/src/inference/merge.py
@@ -0,0 +1,527 @@
+'''
+Inject the result in res.npz into model.vrm and export it as res_textured.vrm.
+'''
+import argparse
+import yaml
+import os
+import numpy as np
+from numpy import ndarray
+
+from typing import Tuple, Union, List
+
+from tqdm import tqdm
+from box import Box
+
+from scipy.spatial import cKDTree
+
+import open3d as o3d
+import itertools
+
+import bpy
+from mathutils import Vector
+
+from ..data.raw_data import RawData, RawSkin
+from ..data.extract import process_mesh, process_armature, get_arranged_bones
+
+def parser():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', type=str)
+ parser.add_argument('--num_runs', type=int)
+ parser.add_argument('--id', type=int)
+ return parser.parse_args()
+
+def clean_bpy():
+ for c in bpy.data.actions:
+ bpy.data.actions.remove(c)
+ for c in bpy.data.armatures:
+ bpy.data.armatures.remove(c)
+ for c in bpy.data.cameras:
+ bpy.data.cameras.remove(c)
+ for c in bpy.data.collections:
+ bpy.data.collections.remove(c)
+ for c in bpy.data.images:
+ bpy.data.images.remove(c)
+ for c in bpy.data.materials:
+ bpy.data.materials.remove(c)
+ for c in bpy.data.meshes:
+ bpy.data.meshes.remove(c)
+ for c in bpy.data.objects:
+ bpy.data.objects.remove(c)
+ for c in bpy.data.textures:
+ bpy.data.textures.remove(c)
+
+def load(filepath: str, return_armature: bool=False):
+ if return_armature:
+ old_objs = set(bpy.context.scene.objects)
+
+ if not os.path.exists(filepath):
+ raise ValueError(f'File {filepath} does not exist !')
+ try:
+ if filepath.endswith(".vrm"):
+ # enable vrm addon and load vrm model
+ bpy.ops.preferences.addon_enable(module='vrm')
+
+ bpy.ops.import_scene.vrm(
+ filepath=filepath,
+ use_addon_preferences=True,
+ extract_textures_into_folder=False,
+ make_new_texture_folder=False,
+ set_shading_type_to_material_on_import=False,
+ set_view_transform_to_standard_on_import=True,
+ set_armature_display_to_wire=True,
+ set_armature_display_to_show_in_front=True,
+ set_armature_bone_shape_to_default=True,
+ disable_bake=True, # customized option for better performance
+ )
+ elif filepath.endswith(".fbx") or filepath.endswith(".FBX"):
+ bpy.ops.import_scene.fbx(filepath=filepath, ignore_leaf_bones=False, use_image_search=False)
+ elif filepath.endswith(".glb") or filepath.endswith(".gltf"):
+ bpy.ops.import_scene.gltf(filepath=filepath, import_pack_images=False)
+ elif filepath.endswith(".dae"):
+ bpy.ops.wm.collada_import(filepath=filepath)
+ elif filepath.endswith(".blend"):
+ with bpy.data.libraries.load(filepath) as (data_from, data_to):
+ data_to.objects = data_from.objects
+ for obj in data_to.objects:
+ if obj is not None:
+ bpy.context.collection.objects.link(obj)
+ else:
+ raise ValueError(f"not suported type {filepath}")
+ except:
+ raise ValueError(f"failed to load {filepath}")
+ if return_armature:
+ armature = [x for x in set(bpy.context.scene.objects)-old_objs if x.type=="ARMATURE"]
+ if len(armature)==0:
+ return None
+ if len(armature)>1:
+ raise ValueError(f"multiple armatures found")
+ armature = armature[0]
+
+ armature.select_set(True)
+ bpy.context.view_layer.objects.active = armature
+ bpy.ops.object.mode_set(mode='EDIT')
+        for bone in armature.data.edit_bones:
+            bone.roll = 0. # set all bone rolls to 0 to prevent weird behaviour
+
+ bpy.ops.object.mode_set(mode='OBJECT')
+ armature.select_set(False)
+
+ bpy.ops.object.select_all(action='DESELECT')
+ return armature
+
+def get_skin(arranged_bones):
+ meshes = []
+ for v in bpy.data.objects:
+ if v.type == 'MESH':
+ meshes.append(v)
+ index = {}
+ for (id, pbone) in enumerate(arranged_bones):
+ index[pbone.name] = id
+ _dict_skin = {}
+ total_bones = len(arranged_bones)
+ for obj in meshes:
+ total_vertices = len(obj.data.vertices)
+ skin_weight = np.zeros((total_vertices, total_bones))
+ obj_group_names = [g.name for g in obj.vertex_groups]
+ obj_verts = obj.data.vertices
+ for bone in arranged_bones:
+ if bone.name not in obj_group_names:
+ continue
+
+ gidx = obj.vertex_groups[bone.name].index
+ bone_verts = [v for v in obj_verts if gidx in [g.group for g in v.groups]]
+ for v in bone_verts:
+ which = [id for id in range(len(v.groups)) if v.groups[id].group==gidx]
+ w = v.groups[which[0]].weight
+ skin_weight[v.index, index[bone.name]] = w
+ _dict_skin[obj.name] = {
+ 'skin': skin_weight,
+ }
+
+ skin = np.concatenate([
+ _dict_skin[d]['skin'] for d in _dict_skin
+ ], axis=0)
+ return skin
+
+def axis(a: np.ndarray):
+ b = np.concatenate([-a[:, 0:1], -a[:, 1:2], a[:, 2:3]], axis=1)
+ return b
+
+def get_correct_orientation_kdtree(a: np.ndarray, b: np.ndarray, bones: np.ndarray, num: int=16384) -> np.ndarray:
+ '''
+    a: sampled_vertices
+ b: mesh_vertices
+ '''
+ min_loss = float('inf')
+ best_transformed = a.copy()
+ axis_permutations = list(itertools.permutations([0, 1, 2]))
+ sign_combinations = [(x, y, z) for x in [1, -1]
+ for y in [1, -1]
+ for z in [1, -1]]
+ _bones = bones.copy()
+ for perm in axis_permutations:
+ permuted_a = a[np.random.permutation(a.shape[0])[:num]][:, perm]
+ for signs in sign_combinations:
+ transformed = permuted_a * np.array(signs)
+ tree = cKDTree(transformed)
+ distances, indices = tree.query(b)
+ current_loss = distances.mean()
+ if current_loss < min_loss: # prevent from mirroring
+ min_loss = current_loss
+ best_transformed = a[:, perm] * np.array(signs)
+ bones[:, :3] = _bones[:, :3][:, perm] * np.array(signs)
+ bones[:, 3:] = _bones[:, 3:][:, perm] * np.array(signs)
+
+ return best_transformed, bones
+
+def denormalize_vertices(mesh_vertices: ndarray, vertices: ndarray, bones: ndarray) -> np.ndarray:
+ min_vals = np.min(mesh_vertices, axis=0)
+ max_vals = np.max(mesh_vertices, axis=0)
+ center = (min_vals + max_vals) / 2
+ scale = np.max(max_vals - min_vals) / 2
+ denormalized_vertices = vertices * scale + center
+ denormalized_bones = bones * scale
+ denormalized_bones[:, :3] += center
+ denormalized_bones[:, 3:] += center
+
+ return denormalized_vertices, denormalized_bones
+
+def make_armature(
+ vertices: ndarray,
+ bones: ndarray, # (joint, tail)
+ parents: list[Union[int, None]],
+ names: list[str],
+ skin: ndarray,
+ group_per_vertex: int=4,
+ add_root: bool=False,
+ is_vrm: bool=False,
+):
+ context = bpy.context
+
+ mesh_vertices = []
+ for ob in bpy.data.objects:
+ if ob.type != 'MESH':
+ continue
+ m = np.array(ob.matrix_world)
+ matrix_world_rot = m[:3, :3]
+ matrix_world_bias = m[:3, 3]
+ for v in ob.data.vertices:
+ mesh_vertices.append(matrix_world_rot @ np.array(v.co) + matrix_world_bias)
+
+ mesh_vertices = np.stack(mesh_vertices)
+ vertices, bones = denormalize_vertices(mesh_vertices, vertices, bones)
+
+ bpy.ops.object.add(type="ARMATURE", location=(0, 0, 0))
+ armature = context.object
+ if hasattr(armature.data, 'vrm_addon_extension'):
+ armature.data.vrm_addon_extension.spec_version = "1.0"
+ humanoid = armature.data.vrm_addon_extension.vrm1.humanoid
+ is_vrm = True
+ bpy.ops.object.mode_set(mode="EDIT")
+ edit_bones = armature.data.edit_bones
+ if add_root:
+ bone_root = edit_bones.new('Root')
+ bone_root.name = 'Root'
+ bone_root.head = (0., 0., 0.)
+ bone_root.tail = (bones[0, 0], bones[0, 1], bones[0, 2])
+
+ J = len(names)
+ def extrude_bone(
+ name: Union[None, str],
+ parent_name: Union[None, str],
+ head: Tuple[float, float, float],
+ tail: Tuple[float, float, float],
+ ):
+ bone = edit_bones.new(name)
+ bone.head = (head[0], head[1], head[2])
+ bone.tail = (tail[0], tail[1], tail[2])
+ bone.name = name
+ if parent_name is None:
+ return
+ parent_bone = edit_bones.get(parent_name)
+ bone.parent = parent_bone
+ bone.use_connect = False # always False currently
+
+ vertices, bones = get_correct_orientation_kdtree(vertices, mesh_vertices, bones)
+ for i in range(J):
+ if add_root:
+ pname = 'Root' if parents[i] is None else names[parents[i]]
+ else:
+ pname = None if parents[i] is None else names[parents[i]]
+ extrude_bone(names[i], pname, bones[i, :3], bones[i, 3:])
+
+ # must set to object mode to enable parent_set
+ bpy.ops.object.mode_set(mode='OBJECT')
+ objects = bpy.data.objects
+ for o in bpy.context.selected_objects:
+ o.select_set(False)
+
+ argsorted = np.argsort(-skin, axis=1)
+ vertex_group_reweight = skin[np.arange(skin.shape[0])[..., None], argsorted]
+ vertex_group_reweight = vertex_group_reweight / vertex_group_reweight[..., :group_per_vertex].sum(axis=1)[...,None]
+ vertex_group_reweight = np.nan_to_num(vertex_group_reweight)
+ tree = cKDTree(vertices)
+ for ob in objects:
+ if ob.type != 'MESH':
+ continue
+ ob.select_set(True)
+ armature.select_set(True)
+ bpy.ops.object.parent_set(type='ARMATURE_NAME')
+ vis = []
+ for x in ob.vertex_groups:
+ vis.append(x.name)
+
+ n_vertices = []
+ m = np.array(ob.matrix_world)
+ matrix_world_rot = m[:3, :3]
+ matrix_world_bias = m[:3, 3]
+ for v in ob.data.vertices:
+ n_vertices.append(matrix_world_rot @ np.array(v.co) + matrix_world_bias)
+ n_vertices = np.stack(n_vertices)
+
+ _, index = tree.query(n_vertices)
+
+ for v, co in enumerate(tqdm(n_vertices)):
+ for ii in range(group_per_vertex):
+ i = argsorted[index[v], ii]
+ if i >= len(names):
+ continue
+ n = names[i]
+ if n not in ob.vertex_groups:
+ continue
+
+ ob.vertex_groups[n].add([v], vertex_group_reweight[index[v], ii], 'REPLACE')
+ armature.select_set(False)
+ ob.select_set(False)
+
+ # set vrm bones link
+ if is_vrm:
+ armature.data.vrm_addon_extension.spec_version = "1.0"
+ humanoid.human_bones.hips.node.bone_name = "J_Bip_C_Hips"
+ humanoid.human_bones.spine.node.bone_name = "J_Bip_C_Spine"
+
+ humanoid.human_bones.chest.node.bone_name = "J_Bip_C_Chest"
+ humanoid.human_bones.neck.node.bone_name = "J_Bip_C_Neck"
+ humanoid.human_bones.head.node.bone_name = "J_Bip_C_Head"
+ humanoid.human_bones.left_upper_leg.node.bone_name = "J_Bip_L_UpperLeg"
+ humanoid.human_bones.left_lower_leg.node.bone_name = "J_Bip_L_LowerLeg"
+ humanoid.human_bones.left_foot.node.bone_name = "J_Bip_L_Foot"
+ humanoid.human_bones.right_upper_leg.node.bone_name = "J_Bip_R_UpperLeg"
+ humanoid.human_bones.right_lower_leg.node.bone_name = "J_Bip_R_LowerLeg"
+ humanoid.human_bones.right_foot.node.bone_name = "J_Bip_R_Foot"
+ humanoid.human_bones.left_upper_arm.node.bone_name = "J_Bip_L_UpperArm"
+ humanoid.human_bones.left_lower_arm.node.bone_name = "J_Bip_L_LowerArm"
+ humanoid.human_bones.left_hand.node.bone_name = "J_Bip_L_Hand"
+ humanoid.human_bones.right_upper_arm.node.bone_name = "J_Bip_R_UpperArm"
+ humanoid.human_bones.right_lower_arm.node.bone_name = "J_Bip_R_LowerArm"
+ humanoid.human_bones.right_hand.node.bone_name = "J_Bip_R_Hand"
+
+ bpy.ops.vrm.assign_vrm1_humanoid_human_bones_automatically(armature_name="Armature")
+
+def merge(
+ path: str,
+ output_path: str,
+ vertices: ndarray,
+ joints: ndarray,
+ skin: ndarray,
+ parents: List[Union[None, int]],
+ names: List[str],
+ tails: ndarray,
+ add_root: bool=False,
+ is_vrm: bool=False,
+):
+ '''
+ Merge skin and bone into original file.
+ '''
+ clean_bpy()
+ try:
+ load(path)
+ except Exception as e:
+ print(f"Failed to load {path}: {e}")
+ return
+ for c in bpy.data.armatures:
+ bpy.data.armatures.remove(c)
+
+ bones = np.concatenate([joints, tails], axis=1)
+    # if the result looks weird, the orientation may be wrong
+ make_armature(
+ vertices=vertices,
+ bones=bones,
+ parents=parents,
+ names=names,
+ skin=skin,
+ group_per_vertex=4,
+ add_root=add_root,
+ is_vrm=is_vrm,
+ )
+
+ dirpath = os.path.dirname(output_path)
+ if dirpath != '':
+ os.makedirs(dirpath, exist_ok=True)
+ try:
+ if is_vrm:
+ bpy.ops.export_scene.vrm(filepath=output_path)
+ elif output_path.endswith(".fbx") or output_path.endswith(".FBX"):
+ bpy.ops.export_scene.fbx(filepath=output_path, add_leaf_bones=True)
+ elif output_path.endswith(".glb") or output_path.endswith(".gltf"):
+ bpy.ops.export_scene.gltf(filepath=output_path)
+ elif output_path.endswith(".dae"):
+ bpy.ops.wm.collada_export(filepath=output_path)
+ elif output_path.endswith(".blend"):
+ with bpy.data.libraries.load(output_path) as (data_from, data_to):
+ data_to.objects = data_from.objects
+ else:
+ raise ValueError(f"not suported type {output_path}")
+ except:
+ raise ValueError(f"failed to export {output_path}")
+
+def str2bool(v):
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
+ return True
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+ return False
+ else:
+ raise argparse.ArgumentTypeError('Boolean value expected.')
+
+def nullable_string(val):
+ if not val:
+ return None
+ return val
+
+def parse():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--require_suffix', type=str, required=True)
+ parser.add_argument('--num_runs', type=int, required=True)
+ parser.add_argument('--id', type=int, required=True)
+ parser.add_argument('--data_config', type=str, required=False)
+ parser.add_argument('--skeleton_config', type=str, required=False)
+ parser.add_argument('--skin_config', type=str, required=False)
+ parser.add_argument('--merge_dir', type=str, required=False)
+ parser.add_argument('--merge_name', type=str, required=False)
+ parser.add_argument('--add_root', type=str2bool, required=False, default=False)
+ parser.add_argument('--source', type=nullable_string, required=False, default=None)
+ parser.add_argument('--target', type=nullable_string, required=False, default=None)
+ parser.add_argument('--output', type=nullable_string, required=False, default=None)
+ return parser.parse_args()
+
+def transfer(source: str, target: str, output: str, add_root: bool=False):
+ try:
+ armature = load(filepath=source, return_armature=True)
+ assert armature is not None
+ except Exception as e:
+ print(f"failed to load {source}")
+ return
+ vertices, faces = process_mesh()
+ arranged_bones = get_arranged_bones(armature)
+ skin = get_skin(arranged_bones)
+ joints, tails, parents, names, matrix_local = process_armature(armature, arranged_bones)
+ merge(
+ path=target,
+ output_path=output,
+ vertices=vertices,
+ joints=joints,
+ skin=skin,
+ parents=parents,
+ names=names,
+ tails=tails,
+ add_root=add_root,
+ )
+
+if __name__ == "__main__":
+ args = parse()
+
+ if args.source is not None or args.target is not None:
+ assert args.source is not None and args.target is not None
+ transfer(args.source, args.target, args.output, args.add_root)
+ exit()
+
+ data_config = Box(yaml.safe_load(open(args.data_config, "r")))
+ skeleton_config = Box(yaml.safe_load(open(args.skeleton_config, "r")))
+ skin_config = Box(yaml.safe_load(open(args.skin_config, "r")))
+
+ num_runs = args.num_runs
+ id = args.id
+ require_suffix = args.require_suffix.split(',')
+ merge_dir = args.merge_dir
+ merge_name = args.merge_name
+ add_root = args.add_root
+
+ input_dataset_dir = data_config.input_dataset_dir
+ dataset_name = data_config.output_dataset_dir
+
+ skin_output_dataset_dir = skin_config.writer.output_dir
+ skin_name = skin_config.writer.export_npz
+
+ skeleton_output_dataset_dir = skeleton_config.writer.output_dir
+ skeleton_name = skeleton_config.writer.export_npz
+
+ def make_path(output_dataset_dir, dataset_name, root, file_name):
+ if output_dataset_dir is None:
+ return os.path.join(
+ dataset_name,
+ os.path.relpath(root, input_dataset_dir),
+ file_name,
+ )
+ return os.path.join(
+ output_dataset_dir,
+ dataset_name,
+ os.path.relpath(root, input_dataset_dir),
+ file_name,
+ )
+
+ files = []
+ for root, dirs, f in os.walk(input_dataset_dir):
+ for file in f:
+ if file.split('.')[-1] in require_suffix:
+ file_name = file.removeprefix("./")
+ suffix = file.split('.')[-1]
+ # remove suffix
+ file_name = '.'.join(file_name.split('.')[:-1])
+
+ skin_path = make_path(skin_output_dataset_dir, dataset_name, root, os.path.join(file_name, skin_name+'.npz'))
+ skeleton_path = make_path(skeleton_output_dataset_dir, dataset_name, root, os.path.join(file_name, skeleton_name+'.npz'))
+ merge_path = make_path(merge_dir, dataset_name, root, os.path.join(file_name, merge_name+"."+suffix))
+
+ # check if inference result exists
+ if os.path.exists(skin_path) and os.path.exists(skeleton_path):
+ files.append((os.path.join(root, file), skin_path, skeleton_path, merge_path))
+
+ num_files = len(files)
+ print("num_files", num_files)
+ gap = num_files // num_runs
+ start = gap * id
+ end = gap * (id + 1)
+ if id+1==num_runs:
+ end = num_files
+
+ files = sorted(files)
+ if end!=-1:
+ files = files[:end]
+ tot = 0
+ for file in tqdm(files[start:]):
+ origin_file = file[0]
+ skin_path = file[1]
+ skeleton_path = file[2]
+ merge_file = file[3]
+
+ raw_skin = RawSkin.load(path=skin_path)
+ raw_data = RawData.load(path=skeleton_path)
+
+ try:
+ merge(
+ path=origin_file,
+ output_path=merge_file,
+ vertices=raw_skin.vertices,
+ joints=raw_skin.joints,
+ skin=raw_skin.skin,
+ parents=raw_data.parents,
+ names=raw_data.names,
+ tails=raw_data.tails,
+ add_root=add_root,
+ is_vrm=(raw_data.cls=='vroid'),
+ )
+ except Exception as e:
+ print(f"failed to merge {origin_file}: {e}")
\ No newline at end of file
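
`make_armature` keeps only the `group_per_vertex` strongest skin weights per vertex and renormalizes them before writing Blender vertex groups. In isolation, the reweighting trick looks like this (standalone numpy sketch):

```python
import numpy as np

skin = np.array([[0.10, 0.60, 0.20, 0.05, 0.05]])      # (N=1, J=5)
argsorted = np.argsort(-skin, axis=1)                   # bones by weight, descending
reweight = skin[np.arange(skin.shape[0])[..., None], argsorted]
reweight = reweight / reweight[..., :4].sum(axis=1)[..., None]

assert np.isclose(reweight[0, :4].sum(), 1.0)           # top 4 weights now sum to 1
```
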
diff --git a/UniRig/src/model/__init__.py b/UniRig/src/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/UniRig/src/model/michelangelo/LICENSE b/UniRig/src/model/michelangelo/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..e72bfddabc15be5718a7cc061ac10e47741d8219
--- /dev/null
+++ b/UniRig/src/model/michelangelo/LICENSE
@@ -0,0 +1,674 @@
+ GNU GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc.
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU General Public License is a free, copyleft license for
+software and other kinds of works.
+
+ The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+the GNU General Public License is intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users. We, the Free Software Foundation, use the
+GNU General Public License for most of our software; it applies also to
+any other work released this way by its authors. You can apply it to
+your programs, too.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+ To protect your rights, we need to prevent others from denying you
+these rights or asking you to surrender the rights. Therefore, you have
+certain responsibilities if you distribute copies of the software, or if
+you modify it: responsibilities to respect the freedom of others.
+
+ For example, if you distribute copies of such a program, whether
+gratis or for a fee, you must pass on to the recipients the same
+freedoms that you received. You must make sure that they, too, receive
+or can get the source code. And you must show them these terms so they
+know their rights.
+
+ Developers that use the GNU GPL protect your rights with two steps:
+(1) assert copyright on the software, and (2) offer you this License
+giving you legal permission to copy, distribute and/or modify it.
+
+ For the developers' and authors' protection, the GPL clearly explains
+that there is no warranty for this free software. For both users' and
+authors' sake, the GPL requires that modified versions be marked as
+changed, so that their problems will not be attributed erroneously to
+authors of previous versions.
+
+ Some devices are designed to deny users access to install or run
+modified versions of the software inside them, although the manufacturer
+can do so. This is fundamentally incompatible with the aim of
+protecting users' freedom to change the software. The systematic
+pattern of such abuse occurs in the area of products for individuals to
+use, which is precisely where it is most unacceptable. Therefore, we
+have designed this version of the GPL to prohibit the practice for those
+products. If such problems arise substantially in other domains, we
+stand ready to extend this provision to those domains in future versions
+of the GPL, as needed to protect the freedom of users.
+
+ Finally, every program is threatened constantly by software patents.
+States should not allow patents to restrict development and use of
+software on general-purpose computers, but in those that do, we wish to
+avoid the special danger that patents applied to a free program could
+make it effectively proprietary. To prevent this, the GPL assures that
+patents cannot be used to render the program non-free.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+ The Corresponding Source for a work in source code form is that
+same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+ 13. Use with the GNU Affero General Public License.
+
+ Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU Affero General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the special requirements of the GNU Affero General Public License,
+section 13, concerning interaction through a network will apply to the
+combination as such.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+the GNU General Public License from time to time. Such new versions will
+be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+versions of the GNU General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+ Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+
+ <one line to give the program's name and a brief idea of what it does.>
+ Copyright (C) <year>  <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If the program does terminal interaction, make it output a short
+notice like this when it starts in an interactive mode:
+
+ <program>  Copyright (C) <year>  <name of author>
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
+ This is free software, and you are welcome to redistribute it
+ under certain conditions; type `show c' for details.
+
+The hypothetical commands `show w' and `show c' should show the appropriate
+parts of the General Public License. Of course, your program's commands
+might be different; for a GUI interface, you would use an "about box".
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU GPL, see
+<https://www.gnu.org/licenses/>.
+
+ The GNU General Public License does not permit incorporating your program
+into proprietary programs. If your program is a subroutine library, you
+may consider it more useful to permit linking proprietary applications with
+the library. If this is what you want to do, use the GNU Lesser General
+Public License instead of this License. But first, please read
+<https://www.gnu.org/licenses/why-not-lgpl.html>.
\ No newline at end of file
diff --git a/UniRig/src/model/michelangelo/__init__.py b/UniRig/src/model/michelangelo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..16d324c6fbb109889666d97ec8aaf96dffbb802b
--- /dev/null
+++ b/UniRig/src/model/michelangelo/__init__.py
@@ -0,0 +1,21 @@
+# -*- coding: utf-8 -*-
+#
+# This file is part of UniRig.
+#
+# This file is derived from https://github.com/NeuralCarver/Michelangelo
+#
+# Copyright (c) https://github.com/NeuralCarver/Michelangelo original authors
+# Copyright (c) 2025 VAST-AI-Research and contributors.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
\ No newline at end of file
diff --git a/UniRig/src/model/michelangelo/get_model.py b/UniRig/src/model/michelangelo/get_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b4c5197b2d438c32a2b487ea9b87303bc2dd0b6
--- /dev/null
+++ b/UniRig/src/model/michelangelo/get_model.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+#
+# This file is part of UniRig.
+#
+# This file is derived from https://github.com/NeuralCarver/Michelangelo
+#
+# Copyright (c) https://github.com/NeuralCarver/Michelangelo original authors
+# Copyright (c) 2025 VAST-AI-Research and contributors.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import torch
+
+from .models.tsal.sal_perceiver import AlignedShapeLatentPerceiver, ShapeAsLatentPerceiverEncoder
+
+def get_encoder(
+ pretrained_path: str=None,
+ freeze_decoder: bool=False,
+ **kwargs
+) -> AlignedShapeLatentPerceiver:
+ model = AlignedShapeLatentPerceiver(**kwargs)
+ if pretrained_path is not None:
+ state_dict = torch.load(pretrained_path, weights_only=True)
+ model.load_state_dict(state_dict)
+ if freeze_decoder:
+ model.geo_decoder.requires_grad_(False)
+ model.encoder.query.requires_grad_(False)
+ model.pre_kl.requires_grad_(False)
+ model.post_kl.requires_grad_(False)
+ model.transformer.requires_grad_(False)
+ return model
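+
+# Illustrative usage (a sketch, not part of the original file): all keyword
+# arguments are forwarded to AlignedShapeLatentPerceiver, so a call might look
+# like this, where "encoder.pt" and perceiver_kwargs are hypothetical:
+#
+#   encoder = get_encoder(pretrained_path="encoder.pt", freeze_decoder=True,
+#                         **perceiver_kwargs)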
+
+def get_encoder_simplified(
+ pretrained_path: str=None,
+ **kwargs
+) -> ShapeAsLatentPerceiverEncoder:
+ model = ShapeAsLatentPerceiverEncoder(**kwargs)
+ if pretrained_path is not None:
+ state_dict = torch.load(pretrained_path, weights_only=True)
+ model.load_state_dict(state_dict)
+ return model
\ No newline at end of file
diff --git a/UniRig/src/model/michelangelo/models/__init__.py b/UniRig/src/model/michelangelo/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..16d324c6fbb109889666d97ec8aaf96dffbb802b
--- /dev/null
+++ b/UniRig/src/model/michelangelo/models/__init__.py
@@ -0,0 +1,21 @@
+# -*- coding: utf-8 -*-
+#
+# This file is part of UniRig.
+#
+# This file is derived from https://github.com/NeuralCarver/Michelangelo
+#
+# Copyright (c) https://github.com/NeuralCarver/Michelangelo original authors
+# Copyright (c) 2025 VAST-AI-Research and contributors.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
\ No newline at end of file
diff --git a/UniRig/src/model/michelangelo/models/modules/__init__.py b/UniRig/src/model/michelangelo/models/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..deb22f87c60542124ec634ddcecc36c9699e5207
--- /dev/null
+++ b/UniRig/src/model/michelangelo/models/modules/__init__.py
@@ -0,0 +1,23 @@
+# -*- coding: utf-8 -*-
+#
+# This file is part of UniRig.
+#
+# This file is derived from https://github.com/NeuralCarver/Michelangelo
+#
+# Copyright (c) https://github.com/NeuralCarver/Michelangelo original authors
+# Copyright (c) 2025 VAST-AI-Research and contributors.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from .checkpoint import checkpoint
diff --git a/UniRig/src/model/michelangelo/models/modules/checkpoint.py b/UniRig/src/model/michelangelo/models/modules/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab5d51f05e5c7b2818fb8cd0e207bfa2f260415f
--- /dev/null
+++ b/UniRig/src/model/michelangelo/models/modules/checkpoint.py
@@ -0,0 +1,106 @@
+# -*- coding: utf-8 -*-
+#
+# This file is part of UniRig.
+#
+# This file is derived from https://github.com/NeuralCarver/Michelangelo
+#
+# Copyright (c) https://github.com/NeuralCarver/Michelangelo original authors
+# Copyright (c) 2025 VAST-AI-Research and contributors.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124
+"""
+
+import torch
+from typing import Callable, Iterable, Sequence, Union
+from packaging import version
+
+
+def checkpoint(
+ func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
+ inputs: Sequence[torch.Tensor],
+ params: Iterable[torch.Tensor],
+ flag: bool,
+ use_deepspeed: bool = False
+):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ :param use_deepspeed: if True, use DeepSpeed's activation checkpointing
+ instead of the local implementation.
+ """
+ if flag:
+ if use_deepspeed:
+ import deepspeed
+ return deepspeed.checkpointing.checkpoint(func, *inputs)
+
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
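+
+# Illustrative usage (a sketch, not from the original file): recompute a
+# block's activations in the backward pass instead of storing them in forward.
+#
+#   layer = torch.nn.Linear(16, 16)
+#   x = torch.randn(4, 16, requires_grad=True)
+#   y = checkpoint(layer, (x,), layer.parameters(), flag=True)
+#   y.sum().backward()  # re-runs layer's forward to rebuild activations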
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def _get_fwd_decorator():
+ if version.parse(torch.__version__) >= version.parse('2.5.0'):
+ return torch.amp.custom_fwd(device_type='cuda')
+ else:
+ return torch.cuda.amp.custom_fwd()
+
+ @staticmethod
+ def _get_bwd_decorator():
+ if version.parse(torch.__version__) >= version.parse('2.5.0'):
+ return torch.amp.custom_bwd(device_type='cuda')
+ else:
+ def custom_bwd(bwd):
+ return torch.cuda.amp.custom_bwd(bwd=bwd)
+ return custom_bwd
+
+ @staticmethod
+ @_get_fwd_decorator()
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ @_get_bwd_decorator()
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad():
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
diff --git a/UniRig/src/model/michelangelo/models/modules/embedder.py b/UniRig/src/model/michelangelo/models/modules/embedder.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fd01d55c79cb5bf3e357f2ac5bb45d465884195
--- /dev/null
+++ b/UniRig/src/model/michelangelo/models/modules/embedder.py
@@ -0,0 +1,233 @@
+# -*- coding: utf-8 -*-
+#
+# This file is part of UniRig.
+#
+# This file is derived from https://github.com/NeuralCarver/Michelangelo
+#
+# Copyright (c) https://github.com/NeuralCarver/Michelangelo original authors
+# Copyright (c) 2025 VAST-AI-Research and contributors.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import numpy as np
+import torch
+import torch.nn as nn
+import math
+
+VALID_EMBED_TYPES = ["identity", "fourier", "hashgrid", "sphere_harmonic", "triplane_fourier"]
+
+
+class FourierEmbedder(nn.Module):
+ """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
+ each feature dimension of `x[..., i]` into:
+ [
+ sin(x[..., i]),
+ sin(f_1*x[..., i]),
+ sin(f_2*x[..., i]),
+ ...
+ sin(f_N * x[..., i]),
+ cos(x[..., i]),
+ cos(f_1*x[..., i]),
+ cos(f_2*x[..., i]),
+ ...
+ cos(f_N * x[..., i]),
+ x[..., i] # only present if include_input is True.
+ ], here f_i is the frequency.
+
+ If logspace is True, the frequencies are powers of two, f_i = 2^i for
+ i in [0, num_freqs - 1]; otherwise, they are linearly spaced between
+ 1.0 and 2^(num_freqs - 1). If include_pi is True, every frequency is
+ additionally multiplied by pi.
+
+ Args:
+ num_freqs (int): the number of frequencies, default is 6;
+ logspace (bool): If True, the frequencies are powers of two (f_i = 2^i);
+ otherwise, they are linearly spaced between 1.0 and 2^(num_freqs - 1);
+ input_dim (int): the input dimension, default is 3;
+ include_input (bool): include the input tensor or not, default is True.
+
+ Attributes:
+ frequencies (torch.Tensor): the frequency bank described above,
+ registered as a non-persistent buffer;
+
+ out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
+ otherwise, it is input_dim * num_freqs * 2.
+
+ """
+
+ def __init__(self,
+ num_freqs: int = 6,
+ logspace: bool = True,
+ input_dim: int = 3,
+ include_input: bool = True,
+ include_pi: bool = True) -> None:
+
+ """The initialization"""
+
+ super().__init__()
+
+ if logspace:
+ frequencies = 2.0 ** torch.arange(
+ num_freqs,
+ dtype=torch.float32
+ )
+ else:
+ frequencies = torch.linspace(
+ 1.0,
+ 2.0 ** (num_freqs - 1),
+ num_freqs,
+ dtype=torch.float32
+ )
+
+ if include_pi:
+ frequencies *= torch.pi
+
+ self.register_buffer("frequencies", frequencies, persistent=False)
+ self.include_input = include_input
+ self.num_freqs = num_freqs
+
+ self.out_dim = self.get_dims(input_dim)
+
+ def get_dims(self, input_dim):
+ temp = 1 if self.include_input or self.num_freqs == 0 else 0
+ out_dim = input_dim * (self.num_freqs * 2 + temp)
+
+ return out_dim
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """ Forward process.
+
+ Args:
+ x: tensor of shape [..., dim]
+
+ Returns:
+ embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
+ where temp is 1 if include_input is True and 0 otherwise.
+ """
+
+ if self.num_freqs > 0:
+ embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1)
+ if self.include_input:
+ return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
+ else:
+ return torch.cat((embed.sin(), embed.cos()), dim=-1)
+ else:
+ return x
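+
+# Worked example (illustrative): with num_freqs=6, input_dim=3 and
+# include_input=True, each coordinate expands into 6 sines + 6 cosines plus
+# the raw value, so out_dim = 3 * (6 * 2 + 1) = 39.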
+
+
+class LearnedFourierEmbedder(nn.Module):
+ """ following @crowsonkb "s lead with learned sinusoidal pos emb """
+ """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
+
+ def __init__(self, in_channels, dim):
+ super().__init__()
+ assert (dim % 2) == 0
+ half_dim = dim // 2
+ per_channel_dim = half_dim // in_channels
+ self.weights = nn.Parameter(torch.randn(per_channel_dim))
+
+ def forward(self, x):
+ """
+
+ Args:
+ x (torch.FloatTensor): [..., c]
+
+ Returns:
+ x (torch.FloatTensor): [..., d]
+ """
+
+ # [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d]
+ freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1)
+ fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1)
+ return fouriered
+
+
+class TriplaneLearnedFourierEmbedder(nn.Module):
+ def __init__(self, in_channels, dim):
+ super().__init__()
+
+ self.yz_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
+ self.xz_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
+ self.xy_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
+
+ self.out_dim = in_channels + dim
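+ # Note: each LearnedFourierEmbedder returns in_channels + 2 * in_channels *
+ # ((dim // 2) // in_channels) features, which equals in_channels + dim when
+ # in_channels divides dim // 2; the three plane embeddings are summed.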
+
+ def forward(self, x):
+
+ yz_embed = self.yz_plane_embedder(x)
+ xz_embed = self.xz_plane_embedder(x)
+ xy_embed = self.xy_plane_embedder(x)
+
+ embed = yz_embed + xz_embed + xy_embed
+
+ return embed
+
+
+def sequential_pos_embed(num_len, embed_dim):
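+ # Classic 1-D sin/cos position table: row m is
+ # [sin(m * w_0), ..., sin(m * w_{D/2-1}), cos(m * w_0), ..., cos(m * w_{D/2-1})]
+ # with w_k = 1 / 10000^(2k / D).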
+ assert embed_dim % 2 == 0
+
+ pos = torch.arange(num_len, dtype=torch.float32)
+ omega = torch.arange(embed_dim // 2, dtype=torch.float32)
+ omega /= embed_dim / 2.
+ omega = 1. / 10000 ** omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = torch.sin(out) # (M, D/2)
+ emb_cos = torch.cos(out) # (M, D/2)
+
+ embeddings = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
+
+ return embeddings
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].to(timesteps.dtype) * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
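+
+# Sanity check (illustrative): timestep_embedding(torch.tensor([0.]), 4) yields
+# [[1., 1., 0., 0.]] because cosines come first, then sines; note this ordering
+# differs from the sin-first layout of sequential_pos_embed above.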
+
+
+def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, degree=4,
+ num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16,
+ log2_hashmap_size=19, desired_resolution=None):
+ if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1):
+ return nn.Identity(), input_dim
+
+ elif embed_type == "fourier":
+ embedder_obj = FourierEmbedder(num_freqs=num_freqs, input_dim=input_dim,
+ logspace=True, include_input=True)
+ return embedder_obj, embedder_obj.out_dim
+
+ elif embed_type == "hashgrid":
+ raise NotImplementedError
+
+ elif embed_type == "sphere_harmonic":
+ raise NotImplementedError
+
+ else:
+ raise ValueError(f"{embed_type} is not valid. Currently only supprts {VALID_EMBED_TYPES}")
diff --git a/UniRig/src/model/michelangelo/models/modules/transformer_blocks.py b/UniRig/src/model/michelangelo/models/modules/transformer_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..4df460c7dac0ae68dd675b5504e0bfeeee2e6026
--- /dev/null
+++ b/UniRig/src/model/michelangelo/models/modules/transformer_blocks.py
@@ -0,0 +1,320 @@
+# -*- coding: utf-8 -*-
+#
+# This file is part of UniRig.
+#
+# This file is derived from https://github.com/NeuralCarver/Michelangelo
+#
+# Copyright (c) https://github.com/NeuralCarver/Michelangelo original authors
+# Copyright (c) 2025 VAST-AI-Research and contributors.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Optional
+
+from .checkpoint import checkpoint
+
+def init_linear(l, stddev):
+ nn.init.normal_(l.weight, std=stddev)
+ if l.bias is not None:
+ nn.init.constant_(l.bias, 0.0)
+
+def flash_attention(q, k, v):
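+ # q, k, v arrive as [batch, seq, heads, head_dim]; scaled_dot_product_attention
+ # expects [batch, heads, seq, head_dim], hence the transposes around the call.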
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+ out = F.scaled_dot_product_attention(q, k, v)
+ out = out.transpose(1, 2)
+ # print("use flash atten 2")
+
+ return out
+
+class MultiheadAttention(nn.Module):
+ def __init__(
+ self,
+ *,
+ device: torch.device,
+ dtype: torch.dtype,
+ n_ctx: int,
+ width: int,
+ heads: int,
+ init_scale: float,
+ qkv_bias: bool,
+ flash: bool = False
+ ):
+ super().__init__()
+ self.n_ctx = n_ctx
+ self.width = width
+ self.heads = heads
+ self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype)
+ self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
+ self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx, flash=flash)
+ init_linear(self.c_qkv, init_scale)
+ init_linear(self.c_proj, init_scale)
+
+ def forward(self, x):
+ x = self.c_qkv(x)
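+ # flag=False below means checkpoint() simply calls self.attention directly.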
+ x = checkpoint(self.attention, (x,), (), False)
+ x = self.c_proj(x)
+ return x
+
+
+class QKVMultiheadAttention(nn.Module):
+ def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int, flash: bool = False):
+ super().__init__()
+ self.device = device
+ self.dtype = dtype
+ self.heads = heads
+ self.n_ctx = n_ctx
+ self.flash = flash
+
+ def forward(self, qkv):
+ bs, n_ctx, width = qkv.shape
+ attn_ch = width // self.heads // 3
+ scale = 1 / math.sqrt(math.sqrt(attn_ch))
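+ # Scaling q and k each by attn_ch^(-1/4) is equivalent to the usual
+ # softmax(Q K^T / sqrt(attn_ch)) attention scaling.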
+ qkv = qkv.view(bs, n_ctx, self.heads, -1)
+ q, k, v = torch.split(qkv, attn_ch, dim=-1)
+
+ if self.flash:
+ out = flash_attention(q, k, v)
+ out = out.reshape(out.shape[0], out.shape[1], -1)
+ else:
+ weight = torch.einsum(
+ "bthc,bshc->bhts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ wdtype = weight.dtype
+ weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
+ out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
+
+ return out
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ *,
+ device: torch.device,
+ dtype: torch.dtype,
+ n_ctx: int,
+ width: int,
+ heads: int,
+ init_scale: float = 1.0,
+ qkv_bias: bool = True,
+ flash: bool = False,
+ use_checkpoint: bool = False
+ ):
+ super().__init__()
+
+ self.use_checkpoint = use_checkpoint
+
+ self.attn = MultiheadAttention(
+ device=device,
+ dtype=dtype,
+ n_ctx=n_ctx,
+ width=width,
+ heads=heads,
+ init_scale=init_scale,
+ qkv_bias=qkv_bias,
+ flash=flash
+ )
+ self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
+ self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
+ self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)
+
+ def _forward(self, x: torch.Tensor):
+ x = x + self.attn(self.ln_1(x))
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+ def forward(self, x: torch.Tensor):
+ return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)
+
+
+class MultiheadCrossAttention(nn.Module):
+ def __init__(
+ self,
+ *,
+ device: torch.device,
+ dtype: torch.dtype,
+ width: int,
+ heads: int,
+ init_scale: float,
+ qkv_bias: bool = True,
+ flash: bool = False,
+ n_data: Optional[int] = None,
+ data_width: Optional[int] = None,
+ ):
+ super().__init__()
+ self.n_data = n_data
+ self.width = width
+ self.heads = heads
+ self.data_width = width if data_width is None else data_width
+ self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype)
+ self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype)
+ self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
+ self.attention = QKVMultiheadCrossAttention(
+ device=device, dtype=dtype, heads=heads, n_data=n_data, flash=flash
+ )
+ init_linear(self.c_q, init_scale)
+ init_linear(self.c_kv, init_scale)
+ init_linear(self.c_proj, init_scale)
+
+ def forward(self, x, data):
+ x = self.c_q(x)
+ data = self.c_kv(data)
+ x = checkpoint(self.attention, (x, data), (), False)
+ x = self.c_proj(x)
+ return x
+
+
+class QKVMultiheadCrossAttention(nn.Module):
+ def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int,
+ flash: bool = False, n_data: Optional[int] = None):
+
+ super().__init__()
+ self.device = device
+ self.dtype = dtype
+ self.heads = heads
+ self.n_data = n_data
+ self.flash = flash
+
+ def forward(self, q, kv):
+ _, n_ctx, _ = q.shape
+ bs, n_data, width = kv.shape
+ attn_ch = width // self.heads // 2
+ scale = 1 / math.sqrt(math.sqrt(attn_ch))
+ q = q.view(bs, n_ctx, self.heads, -1)
+ kv = kv.view(bs, n_data, self.heads, -1)
+ k, v = torch.split(kv, attn_ch, dim=-1)
+
+ if self.flash:
+ out = flash_attention(q, k, v)
+ out = out.reshape(out.shape[0], out.shape[1], -1)
+ else:
+ weight = torch.einsum(
+ "bthc,bshc->bhts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ wdtype = weight.dtype
+ weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
+ out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
+
+ return out
+
+
+class ResidualCrossAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ *,
+ device: Optional[torch.device],
+ dtype: Optional[torch.dtype],
+ n_data: Optional[int] = None,
+ width: int,
+ heads: int,
+ data_width: Optional[int] = None,
+ mlp_width_scale: int = 4,
+ init_scale: float = 0.25,
+ qkv_bias: bool = True,
+ flash: bool = False
+ ):
+ super().__init__()
+
+ if data_width is None:
+ data_width = width
+
+ self.attn = MultiheadCrossAttention(
+ device=device,
+ dtype=dtype,
+ n_data=n_data,
+ width=width,
+ heads=heads,
+ data_width=data_width,
+ init_scale=init_scale,
+ qkv_bias=qkv_bias,
+ flash=flash,
+ )
+ self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
+ self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)
+ self.mlp = MLP(device=device, dtype=dtype, width=width, hidden_width_scale=mlp_width_scale, init_scale=init_scale)
+ self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)
+
+ def forward(self, x: torch.Tensor, data: torch.Tensor):
+ x = x + self.attn(self.ln_1(x), self.ln_2(data))
+ x = x + self.mlp(self.ln_3(x))
+ return x
+
+
+class MLP(nn.Module):
+ def __init__(self, *,
+ device: Optional[torch.device],
+ dtype: Optional[torch.dtype],
+ width: int,
+ hidden_width_scale: int = 4,
+ init_scale: float):
+ super().__init__()
+ self.width = width
+ self.c_fc = nn.Linear(width, width * hidden_width_scale, device=device, dtype=dtype)
+ self.c_proj = nn.Linear(width * hidden_width_scale, width, device=device, dtype=dtype)
+ self.gelu = nn.GELU()
+ init_linear(self.c_fc, init_scale)
+ init_linear(self.c_proj, init_scale)
+
+ def forward(self, x):
+ return self.c_proj(self.gelu(self.c_fc(x)))
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ *,
+ device: Optional[torch.device],
+ dtype: Optional[torch.dtype],
+ n_ctx: int,
+ width: int,
+ layers: int,
+ heads: int,
+ init_scale: float = 0.25,
+ qkv_bias: bool = True,
+ flash: bool = False,
+ use_checkpoint: bool = False
+ ):
+ super().__init__()
+ self.n_ctx = n_ctx
+ self.width = width
+ self.layers = layers
+ self.resblocks = nn.ModuleList(
+ [
+ ResidualAttentionBlock(
+ device=device,
+ dtype=dtype,
+ n_ctx=n_ctx,
+ width=width,
+ heads=heads,
+ init_scale=init_scale,
+ qkv_bias=qkv_bias,
+ flash=flash,
+ use_checkpoint=use_checkpoint
+ )
+ for _ in range(layers)
+ ]
+ )
+
+ def forward(self, x: torch.Tensor):
+ for block in self.resblocks:
+ x = block(x)
+ return x
diff --git a/UniRig/src/model/michelangelo/models/tsal/__init__.py b/UniRig/src/model/michelangelo/models/tsal/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d3779a628a7c29f7c0177a252a3b41ef3d6c82c
--- /dev/null
+++ b/UniRig/src/model/michelangelo/models/tsal/__init__.py
@@ -0,0 +1,21 @@
+# -*- coding: utf-8 -*-
+#
+# This file is part of UniRig.
+#
+# This file is derived from https://github.com/NeuralCarver/Michelangelo
+#
+# Copyright (c) https://github.com/NeuralCarver/Michelangelo original authors
+# Copyright (c) 2025 VAST-AI-Research and contributors.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
diff --git a/UniRig/src/model/michelangelo/models/tsal/sal_perceiver.py b/UniRig/src/model/michelangelo/models/tsal/sal_perceiver.py
new file mode 100644
index 0000000000000000000000000000000000000000..3300fd07c5207cd2fae713942ec154e5d4ff9325
--- /dev/null
+++ b/UniRig/src/model/michelangelo/models/tsal/sal_perceiver.py
@@ -0,0 +1,621 @@
+# -*- coding: utf-8 -*-
+#
+# This file is part of UniRig.
+#
+# This file is derived from https://github.com/NeuralCarver/Michelangelo
+#
+# Copyright (c) https://github.com/NeuralCarver/Michelangelo original authors
+# Copyright (c) 2025 VAST-AI-Research and contributors.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import torch
+import torch.nn as nn
+from typing import Optional, Union
+from einops import repeat
+import math
+from torch_cluster import fps
+import random
+import time
+import numpy as np
+
+from ..modules import checkpoint
+from ..modules.embedder import FourierEmbedder
+from ..modules.transformer_blocks import (
+ ResidualCrossAttentionBlock,
+ Transformer
+)
+
+from .tsal_base import ShapeAsLatentModule
+
+
+class CrossAttentionEncoder(nn.Module):
+
+ def __init__(self, *,
+ device: Optional[torch.device],
+ dtype: Optional[torch.dtype],
+ num_latents: int,
+ fourier_embedder: FourierEmbedder,
+ point_feats: int,
+ width: int,
+ heads: int,
+ layers: int,
+ init_scale: float = 0.25,
+ qkv_bias: bool = True,
+ flash: bool = False,
+ use_ln_post: bool = False,
+ use_checkpoint: bool = False,
+ query_method: bool = False,
+ use_full_input: bool = True,
+ token_num: Union[int, list] = 256,
+ no_query: bool=False):
+
+ super().__init__()
+
+ self.query_method = query_method
+ self.token_num = token_num
+ self.use_full_input = use_full_input
+
+ self.use_checkpoint = use_checkpoint
+ self.num_latents = num_latents
+
+ if no_query:
+ self.query = None
+ else:
+ self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02)
+
+ self.fourier_embedder = fourier_embedder
+ self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width, device=device, dtype=dtype)
+ self.cross_attn = ResidualCrossAttentionBlock(
+ device=device,
+ dtype=dtype,
+ width=width,
+ heads=heads,
+ init_scale=init_scale,
+ qkv_bias=qkv_bias,
+ flash=flash,
+ )
+
+ self.self_attn = Transformer(
+ device=device,
+ dtype=dtype,
+ n_ctx=num_latents,
+ width=width,
+ layers=layers,
+ heads=heads,
+ init_scale=init_scale,
+ qkv_bias=qkv_bias,
+ flash=flash,
+ use_checkpoint=False
+ )
+
+ if use_ln_post:
+ self.ln_post = nn.LayerNorm(width, dtype=dtype, device=device)
+ else:
+ self.ln_post = None
+
+ def _forward(self, pc, feats):
+ """
+
+ Args:
+ pc (torch.FloatTensor): [B, N, 3]
+ feats (torch.FloatTensor or None): [B, N, C]
+
+ Returns:
+ (latents, pc, token_num, pre_pc): the encoded latents, the input cloud,
+ the number of tokens used, and the (possibly None) pre-sampled points.
+ """
+ if self.query_method:
+ token_num = self.num_latents
+ bs = pc.shape[0]
+ data = self.fourier_embedder(pc)
+ if feats is not None:
+ data = torch.cat([data, feats], dim=-1)
+ data = self.input_proj(data)
+
+ query = repeat(self.query, "m c -> b m c", b=bs)
+
+ latents = self.cross_attn(query, data)
+ latents = self.self_attn(latents)
+
+ if self.ln_post is not None:
+ latents = self.ln_post(latents)
+
+ pre_pc = None
+ else:
+
+ if isinstance(self.token_num, int):
+ token_num = self.token_num
+ else:
+ token_num = random.choice(self.token_num)
+
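+ # A fixed seed at eval time makes the random subsampling deterministic.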
+ if self.training:
+ rng = np.random.default_rng()
+ else:
+ rng = np.random.default_rng(seed=0)
+ ind = rng.choice(pc.shape[1], token_num * 4, replace=token_num * 4 > pc.shape[1])
+
+ pre_pc = pc[:,ind,:]
+ pre_feats = feats[:,ind,:]
+
+
+ B, N, D = pre_pc.shape
+ C = pre_feats.shape[-1]
+ ###### fps
+ pos = pre_pc.view(B*N, D)
+ pos_feats = pre_feats.view(B*N, C)
+ batch = torch.arange(B).to(pc.device)
+ batch = torch.repeat_interleave(batch, N)
+
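+            # farthest point sampling over the flattened batch; ratio 1/4 keeps
+            # token_num points per sample as the latent query positions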
+ idx = fps(pos, batch, ratio=1. / 4, random_start=self.training)
+
+ sampled_pc = pos[idx]
+ sampled_pc = sampled_pc.view(B, -1, 3)
+
+ sampled_feats = pos_feats[idx]
+ sampled_feats = sampled_feats.view(B, -1, C)
+
+ ######
+ if self.use_full_input:
+ data = self.fourier_embedder(pc)
+ else:
+ data = self.fourier_embedder(pre_pc)
+
+ if feats is not None:
+ if not self.use_full_input:
+ feats = pre_feats
+ data = torch.cat([data, feats], dim=-1)
+ data = self.input_proj(data)
+
+ sampled_data = self.fourier_embedder(sampled_pc)
+ if feats is not None:
+ sampled_data = torch.cat([sampled_data, sampled_feats], dim=-1)
+ sampled_data = self.input_proj(sampled_data)
+
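+            # FPS-sampled tokens act as queries cross-attending over the full
+            # (or pre-sampled) point set, then self-attend among themselves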
+ latents = self.cross_attn(sampled_data, data)
+ latents = self.self_attn(latents)
+
+ if self.ln_post is not None:
+ latents = self.ln_post(latents)
+
+ pre_pc = torch.cat([pre_pc, pre_feats], dim=-1)
+
+ return latents, pc, token_num, pre_pc
+
+ def forward(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None):
+ """
+
+ Args:
+ pc (torch.FloatTensor): [B, N, 3]
+ feats (torch.FloatTensor or None): [B, N, C]
+
+        Returns:
+            latents, pc, token_num, pre_pc (see `_forward`)
+        """
+
+ return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint)
+
+
+class CrossAttentionDecoder(nn.Module):
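+    """
+    Decodes occupancy logits (plus SDF logits for 'occupancy-sdf' supervision)
+    at query points by cross-attending the embedded queries over the shape
+    latents.
+    """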
+
+ def __init__(self, *,
+ device: Optional[torch.device],
+ dtype: Optional[torch.dtype],
+ num_latents: int,
+ out_channels: int,
+ fourier_embedder: FourierEmbedder,
+ width: int,
+ heads: int,
+ init_scale: float = 0.25,
+ qkv_bias: bool = True,
+ flash: bool = False,
+ use_checkpoint: bool = False,
+ mlp_width_scale: int = 4,
+ supervision_type: str = 'occupancy'):
+
+ super().__init__()
+
+ self.use_checkpoint = use_checkpoint
+ self.fourier_embedder = fourier_embedder
+ self.supervision_type = supervision_type
+
+ self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)
+
+ self.cross_attn_decoder = ResidualCrossAttentionBlock(
+ device=device,
+ dtype=dtype,
+ n_data=num_latents,
+ width=width,
+ heads=heads,
+ init_scale=init_scale,
+ qkv_bias=qkv_bias,
+ flash=flash,
+ mlp_width_scale=mlp_width_scale,
+ )
+
+ self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
+ self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype)
+ if self.supervision_type == 'occupancy-sdf':
+ self.output_proj_sdf = nn.Linear(width, out_channels, device=device, dtype=dtype)
+
+
+
+ def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
+ if next(self.query_proj.parameters()).dtype == torch.float16:
+ queries = queries.half()
+ latents = latents.half()
+ # print(f"queries: {queries.dtype}, {queries.device}")
+ # print(f"latents: {latents.dtype}, {latents.device}"z)
+ queries = self.query_proj(self.fourier_embedder(queries))
+ x = self.cross_attn_decoder(queries, latents)
+ x = self.ln_post(x)
+ x_1 = self.output_proj(x)
+ if self.supervision_type == 'occupancy-sdf':
+ x_2 = self.output_proj_sdf(x)
+ return x_1, x_2
+ else:
+ return x_1
+
+ def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
+ return checkpoint(self._forward, (queries, latents), self.parameters(), self.use_checkpoint)
+
+
+class ShapeAsLatentPerceiver(ShapeAsLatentModule):
+ def __init__(self, *,
+ device: Optional[torch.device],
+ dtype: Optional[torch.dtype],
+ num_latents: int,
+ point_feats: int = 0,
+ embed_dim: int = 0,
+ num_freqs: int = 8,
+ include_pi: bool = True,
+ width: int,
+ heads: int,
+ num_encoder_layers: int,
+ num_decoder_layers: int,
+ decoder_width: Optional[int] = None,
+ init_scale: float = 0.25,
+ qkv_bias: bool = True,
+ flash: bool = False,
+ use_ln_post: bool = False,
+ use_checkpoint: bool = False,
+ supervision_type: str = 'occupancy',
+ query_method: bool = False,
+ token_num: int = 256,
+ grad_type: str = "numerical",
+ grad_interval: float = 0.005,
+ use_full_input: bool = True,
+ freeze_encoder: bool = False,
+ decoder_mlp_width_scale: int = 4,
+ residual_kl: bool = False,
+ ):
+
+ super().__init__()
+
+ self.use_checkpoint = use_checkpoint
+
+ self.num_latents = num_latents
+ assert grad_type in ["numerical", "analytical"]
+ self.grad_type = grad_type
+ self.grad_interval = grad_interval
+ self.supervision_type = supervision_type
+ self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
+
+ init_scale = init_scale * math.sqrt(1.0 / width)
+ self.encoder = CrossAttentionEncoder(
+ device=device,
+ dtype=dtype,
+ fourier_embedder=self.fourier_embedder,
+ num_latents=num_latents,
+ point_feats=point_feats,
+ width=width,
+ heads=heads,
+ layers=num_encoder_layers,
+ init_scale=init_scale,
+ qkv_bias=qkv_bias,
+ flash=flash,
+ use_ln_post=use_ln_post,
+ use_checkpoint=use_checkpoint,
+ query_method=query_method,
+ use_full_input=use_full_input,
+ token_num=token_num
+ )
+
+ self.embed_dim = embed_dim
+ self.residual_kl = residual_kl
+ if decoder_width is None:
+ decoder_width = width
+ if embed_dim > 0:
+ # VAE embed
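+            # pre_kl predicts mean and logvar (2 * embed_dim) of a diagonal Gaussian
+            # posterior; post_kl lifts sampled latents back to the decoder width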
+ self.pre_kl = nn.Linear(width, embed_dim * 2, device=device, dtype=dtype)
+ self.post_kl = nn.Linear(embed_dim, decoder_width, device=device, dtype=dtype)
+ self.latent_shape = (num_latents, embed_dim)
+ if self.residual_kl:
+ assert self.post_kl.out_features % self.post_kl.in_features == 0
+ assert self.pre_kl.in_features % self.pre_kl.out_features == 0
+ else:
+ self.latent_shape = (num_latents, width)
+
+ self.transformer = Transformer(
+ device=device,
+ dtype=dtype,
+ n_ctx=num_latents,
+ width=decoder_width,
+ layers=num_decoder_layers,
+ heads=heads,
+ init_scale=init_scale,
+ qkv_bias=qkv_bias,
+ flash=flash,
+ use_checkpoint=use_checkpoint
+ )
+
+ # geometry decoder
+ self.geo_decoder = CrossAttentionDecoder(
+ device=device,
+ dtype=dtype,
+ fourier_embedder=self.fourier_embedder,
+ out_channels=1,
+ num_latents=num_latents,
+ width=decoder_width,
+ heads=heads,
+ init_scale=init_scale,
+ qkv_bias=qkv_bias,
+ flash=flash,
+ use_checkpoint=use_checkpoint,
+ supervision_type=supervision_type,
+ mlp_width_scale=decoder_mlp_width_scale
+ )
+
+ if freeze_encoder:
+ for p in self.encoder.parameters():
+ p.requires_grad = False
+            # pre_kl only exists in VAE mode (embed_dim > 0)
+            if embed_dim > 0:
+                for p in self.pre_kl.parameters():
+                    p.requires_grad = False
+            print("freeze encoder and pre kl")
+
+ def forward(self,
+ pc: torch.FloatTensor,
+ feats: torch.FloatTensor,
+ volume_queries: torch.FloatTensor,
+ sample_posterior: bool = True):
+ """
+
+ Args:
+ pc (torch.FloatTensor): [B, N, 3]
+ feats (torch.FloatTensor or None): [B, N, C]
+ volume_queries (torch.FloatTensor): [B, P, 3]
+ sample_posterior (bool):
+
+ Returns:
+ logits (torch.FloatTensor): [B, P]
+ center_pos (torch.FloatTensor): [B, M, 3]
+ posterior (DiagonalGaussianDistribution or None).
+
+ """
+
+ latents, center_pos, posterior = self.encode(pc, feats, sample_posterior=sample_posterior)
+
+ latents = self.decode(latents)
+ logits = self.query_geometry(volume_queries, latents)
+
+ return logits, center_pos, posterior
+
+
+class AlignedShapeLatentPerceiver(ShapeAsLatentPerceiver):
+
+ def __init__(self, *,
+ device: Optional[torch.device],
+ dtype: Optional[str],
+ num_latents: int,
+ point_feats: int = 0,
+ embed_dim: int = 0,
+ num_freqs: int = 8,
+ include_pi: bool = True,
+ width: int,
+ heads: int,
+ num_encoder_layers: int,
+ num_decoder_layers: int,
+ decoder_width: Optional[int] = None,
+ init_scale: float = 0.25,
+ qkv_bias: bool = True,
+ flash: bool = False,
+ use_ln_post: bool = False,
+ use_checkpoint: bool = False,
+ supervision_type: str = 'occupancy',
+ grad_type: str = "numerical",
+ grad_interval: float = 0.005,
+ query_method: bool = False,
+ use_full_input: bool = True,
+ token_num: int = 256,
+ freeze_encoder: bool = False,
+ decoder_mlp_width_scale: int = 4,
+ residual_kl: bool = False,
+ ):
+
+ MAP_DTYPE = {
+ 'float32': torch.float32,
+ 'float16': torch.float16,
+ 'bfloat16': torch.bfloat16,
+ }
+ if dtype is not None:
+ dtype = MAP_DTYPE[dtype]
+ super().__init__(
+ device=device,
+ dtype=dtype,
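+            # one extra latent is reserved as the global shape token
+            # (read back as x[:, 0] in encode_latents)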
+ num_latents=1 + num_latents,
+ point_feats=point_feats,
+ embed_dim=embed_dim,
+ num_freqs=num_freqs,
+ include_pi=include_pi,
+ width=width,
+ decoder_width=decoder_width,
+ heads=heads,
+ num_encoder_layers=num_encoder_layers,
+ num_decoder_layers=num_decoder_layers,
+ init_scale=init_scale,
+ qkv_bias=qkv_bias,
+ flash=flash,
+ use_ln_post=use_ln_post,
+ use_checkpoint=use_checkpoint,
+ supervision_type=supervision_type,
+ grad_type=grad_type,
+ grad_interval=grad_interval,
+ query_method=query_method,
+ token_num=token_num,
+ use_full_input=use_full_input,
+ freeze_encoder=freeze_encoder,
+ decoder_mlp_width_scale=decoder_mlp_width_scale,
+ residual_kl=residual_kl,
+ )
+
+ self.width = width
+
+ def encode(self,
+ pc: torch.FloatTensor,
+ feats: Optional[torch.FloatTensor] = None,
+ sample_posterior: bool = True,
+ only_shape: bool=False):
+ """
+
+ Args:
+ pc (torch.FloatTensor): [B, N, 3]
+ feats (torch.FloatTensor or None): [B, N, c]
+ sample_posterior (bool):
+
+        Returns:
+            shape_embed (torch.FloatTensor): global shape token
+            kl_embed (torch.FloatTensor): sampled (or mean) latent embedding
+            posterior (DiagonalGaussianDistribution or None)
+            token_num (int)
+            pre_pc (torch.FloatTensor or None)
+        """
+
+ shape_embed, latents, token_num, pre_pc = self.encode_latents(pc, feats)
+ if only_shape:
+ return shape_embed
+ kl_embed, posterior = self.encode_kl_embed(latents, sample_posterior)
+
+ return shape_embed, kl_embed, posterior, token_num, pre_pc
+
+ def encode_latents(self,
+ pc: torch.FloatTensor,
+ feats: Optional[torch.FloatTensor] = None):
+
+ x, _, token_num, pre_pc = self.encoder(pc, feats)
+
+ shape_embed = x[:, 0]
+ # latents = x[:, 1:]
+ # use all tokens
+ latents = x
+
+ return shape_embed, latents, token_num, pre_pc
+
+ def forward(self,
+ pc: torch.FloatTensor,
+ feats: torch.FloatTensor,
+ volume_queries: torch.FloatTensor,
+ sample_posterior: bool = True):
+ raise NotImplementedError()
+
+#####################################################
+# a simplified version of the perceiver encoder
+#####################################################
+
+class ShapeAsLatentPerceiverEncoder(ShapeAsLatentModule):
+ def __init__(self, *,
+ device: Optional[torch.device],
+ dtype: Optional[Union[torch.dtype, str]],
+ num_latents: int,
+ point_feats: int = 0,
+ embed_dim: int = 0,
+ num_freqs: int = 8,
+ include_pi: bool = True,
+ width: int,
+ heads: int,
+ num_encoder_layers: int,
+ init_scale: float = 0.25,
+ qkv_bias: bool = True,
+ flash: bool = False,
+ use_ln_post: bool = False,
+ use_checkpoint: bool = False,
+ supervision_type: str = 'occupancy',
+ query_method: bool = False,
+ token_num: int = 256,
+ grad_type: str = "numerical",
+ grad_interval: float = 0.005,
+ use_full_input: bool = True,
+ freeze_encoder: bool = False,
+ residual_kl: bool = False,
+ ):
+
+ super().__init__()
+
+
+ MAP_DTYPE = {
+ 'float32': torch.float32,
+ 'float16': torch.float16,
+ 'bfloat16': torch.bfloat16,
+ }
+
+ if dtype is not None and isinstance(dtype, str):
+ dtype = MAP_DTYPE[dtype]
+
+ self.use_checkpoint = use_checkpoint
+
+ self.num_latents = num_latents
+ assert grad_type in ["numerical", "analytical"]
+ self.grad_type = grad_type
+ self.grad_interval = grad_interval
+ self.supervision_type = supervision_type
+ self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
+
+ init_scale = init_scale * math.sqrt(1.0 / width)
+ self.encoder = CrossAttentionEncoder(
+ device=device,
+ dtype=dtype,
+ fourier_embedder=self.fourier_embedder,
+ num_latents=num_latents,
+ point_feats=point_feats,
+ width=width,
+ heads=heads,
+ layers=num_encoder_layers,
+ init_scale=init_scale,
+ qkv_bias=qkv_bias,
+ flash=flash,
+ use_ln_post=use_ln_post,
+ use_checkpoint=use_checkpoint,
+ query_method=query_method,
+ use_full_input=use_full_input,
+ token_num=token_num,
+ no_query=True,
+ )
+
+ self.embed_dim = embed_dim
+ self.residual_kl = residual_kl
+ if freeze_encoder:
+ for p in self.encoder.parameters():
+ p.requires_grad = False
+ print("freeze encoder")
+ self.width = width
+
+ def encode_latents(self,
+ pc: torch.FloatTensor,
+ feats: Optional[torch.FloatTensor] = None):
+
+ x, _, token_num, pre_pc = self.encoder(pc, feats)
+
+ shape_embed = x[:, 0]
+ latents = x
+
+ return shape_embed, latents, token_num, pre_pc
+
+ def forward(self):
+ raise NotImplementedError()
\ No newline at end of file
diff --git a/UniRig/src/model/michelangelo/models/tsal/tsal_base.py b/UniRig/src/model/michelangelo/models/tsal/tsal_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..c785fe23df80a59ead9f165b7f209fe9721b95d5
--- /dev/null
+++ b/UniRig/src/model/michelangelo/models/tsal/tsal_base.py
@@ -0,0 +1,141 @@
+# -*- coding: utf-8 -*-
+#
+# This file is part of UniRig.
+#
+# This file is derived from https://github.com/NeuralCarver/Michelangelo
+#
+# Copyright (c) https://github.com/NeuralCarver/Michelangelo original authors
+# Copyright (c) 2025 VAST-AI-Research and contributors.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import torch.nn as nn
+from typing import Tuple, List, Optional
+import lightning.pytorch as pl
+
+
+class Point2MeshOutput(object):
+ def __init__(self):
+ self.mesh_v = None
+ self.mesh_f = None
+ self.center = None
+ self.pc = None
+
+
+class Latent2MeshOutput(object):
+
+ def __init__(self):
+ self.mesh_v = None
+ self.mesh_f = None
+
+
+class AlignedMeshOutput(object):
+
+ def __init__(self):
+ self.mesh_v = None
+ self.mesh_f = None
+ self.surface = None
+ self.image = None
+ self.text: Optional[str] = None
+ self.shape_text_similarity: Optional[float] = None
+ self.shape_image_similarity: Optional[float] = None
+
+
+class ShapeAsLatentPLModule(pl.LightningModule):
+ latent_shape: Tuple[int]
+
+ def encode(self, surface, *args, **kwargs):
+ raise NotImplementedError
+
+ def decode(self, z_q, *args, **kwargs):
+ raise NotImplementedError
+
+ def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]:
+ raise NotImplementedError
+
+ def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]:
+ raise NotImplementedError
+
+
+class ShapeAsLatentModule(nn.Module):
+ latent_shape: Tuple[int, int]
+
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def decode(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def query_geometry(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+class AlignedShapeAsLatentPLModule(pl.LightningModule):
+ latent_shape: Tuple[int]
+
+ def set_shape_model_only(self):
+ raise NotImplementedError
+
+ def encode(self, surface, *args, **kwargs):
+ raise NotImplementedError
+
+ def decode(self, z_q, *args, **kwargs):
+ raise NotImplementedError
+
+ def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]:
+ raise NotImplementedError
+
+ def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]:
+ raise NotImplementedError
+
+
+class AlignedShapeAsLatentModule(nn.Module):
+ shape_model: ShapeAsLatentModule
+ latent_shape: Tuple[int, int]
+
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+
+ def set_shape_model_only(self):
+ raise NotImplementedError
+
+ def encode_image_embed(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def encode_text_embed(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def encode_shape_embed(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+class TexturedShapeAsLatentModule(nn.Module):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def decode(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def query_geometry(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def query_color(self, *args, **kwargs):
+ raise NotImplementedError
diff --git a/UniRig/src/model/parse.py b/UniRig/src/model/parse.py
new file mode 100644
index 0000000000000000000000000000000000000000..be065d388725f5876d6c4ace8a47fe1aec379bb1
--- /dev/null
+++ b/UniRig/src/model/parse.py
@@ -0,0 +1,14 @@
+from .unirig_ar import UniRigAR
+from .unirig_skin import UniRigSkin
+
+from .spec import ModelSpec
+
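+# dispatch on the '__target__' key, e.g. get_model(__target__='unirig_ar', ...)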
+def get_model(**kwargs) -> ModelSpec:
+ MAP = {
+ 'unirig_ar': UniRigAR,
+ 'unirig_skin': UniRigSkin,
+ }
+ __target__ = kwargs['__target__']
+ del kwargs['__target__']
+ assert __target__ in MAP, f"expect: [{','.join(MAP.keys())}], found: {__target__}"
+ return MAP[__target__](**kwargs)
\ No newline at end of file
diff --git a/UniRig/src/model/parse_encoder.py b/UniRig/src/model/parse_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d51f0f63ced6bc924ccf1ece871e6905d04da8b
--- /dev/null
+++ b/UniRig/src/model/parse_encoder.py
@@ -0,0 +1,28 @@
+from dataclasses import dataclass
+
+from .michelangelo.get_model import get_encoder as get_encoder_michelangelo
+from .michelangelo.get_model import AlignedShapeLatentPerceiver
+from .michelangelo.get_model import get_encoder_simplified as get_encoder_michelangelo_encoder
+from .michelangelo.get_model import ShapeAsLatentPerceiverEncoder
+from .pointcept.models.PTv3Object import get_encoder as get_encoder_ptv3obj
+from .pointcept.models.PTv3Object import PointTransformerV3Object
+
+@dataclass(frozen=True)
+class _MAP_MESH_ENCODER:
+ ptv3obj = PointTransformerV3Object
+ michelangelo = AlignedShapeLatentPerceiver
+ michelangelo_encoder = ShapeAsLatentPerceiverEncoder
+
+MAP_MESH_ENCODER = _MAP_MESH_ENCODER()
+
+
+def get_mesh_encoder(**kwargs):
+ MAP = {
+ 'ptv3obj': get_encoder_ptv3obj,
+ 'michelangelo': get_encoder_michelangelo,
+ 'michelangelo_encoder': get_encoder_michelangelo_encoder,
+ }
+ __target__ = kwargs['__target__']
+ del kwargs['__target__']
+ assert __target__ in MAP, f"expect: [{','.join(MAP.keys())}], found: {__target__}"
+ return MAP[__target__](**kwargs)
\ No newline at end of file
diff --git a/UniRig/src/model/pointcept/LICENSE b/UniRig/src/model/pointcept/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..9293bb11a97a3da9232d22db6cc3005e8df219a7
--- /dev/null
+++ b/UniRig/src/model/pointcept/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Pointcept
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/UniRig/src/model/pointcept/README.md b/UniRig/src/model/pointcept/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7585d78c428e3a0b9e593dde0c2b1a958c545d9d
--- /dev/null
+++ b/UniRig/src/model/pointcept/README.md
@@ -0,0 +1 @@
+original repo: https://github.com/Pointcept/SAMPart3D/tree/main
\ No newline at end of file
diff --git a/UniRig/src/model/pointcept/__init__.py b/UniRig/src/model/pointcept/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/UniRig/src/model/pointcept/datasets/__init__.py b/UniRig/src/model/pointcept/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e6d08bd5d43f12670958df4687791b5b073d815
--- /dev/null
+++ b/UniRig/src/model/pointcept/datasets/__init__.py
@@ -0,0 +1,4 @@
+from .builder import build_dataset
+from .utils import point_collate_fn, collate_fn
+
+from .dataset_render_16views import SAMPart3DDataset16Views
\ No newline at end of file
diff --git a/UniRig/src/model/pointcept/datasets/builder.py b/UniRig/src/model/pointcept/datasets/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b01c6defba0a8afc7f835184744dc7f6356fb82
--- /dev/null
+++ b/UniRig/src/model/pointcept/datasets/builder.py
@@ -0,0 +1,8 @@
+from pointcept.utils.registry import Registry
+
+DATASETS = Registry("datasets")
+
+
+def build_dataset(cfg):
+ """Build datasets."""
+ return DATASETS.build(cfg)
diff --git a/UniRig/src/model/pointcept/datasets/dataset_render_16views.py b/UniRig/src/model/pointcept/datasets/dataset_render_16views.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ed6e7ff1c929b4ad4884a76b7b89177e321afc6
--- /dev/null
+++ b/UniRig/src/model/pointcept/datasets/dataset_render_16views.py
@@ -0,0 +1,438 @@
+import os
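+# enable OpenEXR support in OpenCV (must be set before cv2 is imported) so the
+# .exr depth maps below can be read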
+os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
+from os.path import join
+import glob
+import numpy as np
+import torch
+import trimesh
+import json
+import cv2
+import pointops
+from copy import deepcopy
+from torch.utils.data import Dataset
+from collections.abc import Sequence
+from transformers import pipeline, SamModel
+from PIL import Image
+
+from pointcept.utils.logger import get_root_logger
+from pointcept.utils.cache import shared_dict
+from .builder import DATASETS
+from .transform import Compose, TRANSFORMS
+from .sampart3d_util import *
+
+
+@DATASETS.register_module()
+class SAMPart3DDataset16Views(Dataset):
+
+ def __init__(
+ self,
+ split="train",
+ data_root="data/scannet",
+ mesh_root="",
+ mesh_path_mapping=None,
+ oid="",
+ label="",
+ sample_num=15000,
+ pixels_per_image=256,
+ batch_size=90,
+ transform=None,
+ loop=1,
+ extent_scale=10.0
+ ):
+ super(SAMPart3DDataset16Views, self).__init__()
+
+ data_root = os.path.join(data_root, str(oid))
+ mesh_path = os.path.join(mesh_root, f"{oid}.glb")
+ self.data_root = data_root
+ self.split = split
+ self.pixels_per_image = pixels_per_image
+ self.batch_size = batch_size
+ self.device = 'cuda'
+ self.logger = get_root_logger()
+
+ self.extent_scale = extent_scale
+
+ self.meta_data = json.load(open(os.path.join(data_root, "meta.json")))
+
+ # Load mesh and sample pointclouds
+ self.mesh_path = mesh_path
+ transform = Compose(transform)
+ self.load_mesh(mesh_path, transform, sample_num)
+
+ # Prepare SAM masks and depth mapping
+ if self.split == "train":
+
+ self.prepare_meta_data()
+
+ self.loop = loop
+ self.data_list = self.get_data_list()
+ self.logger.info(
+ "Totally {} x {} samples in {} set.".format(
+ len(self.data_list), self.loop, split
+ )
+ )
+
+ def sample_pixel(self, masks, image_height=512, image_width=512):
+ masks = masks.to(self.device)
+ indices_batch = torch.zeros((self.batch_size*self.pixels_per_image, 3), device=self.device)
+ random_imgs = torch.randint(0, len(masks), (self.batch_size,), device=self.device)
+ for i in range(self.batch_size):
+ # Find the indices of the valid points in the mask
+ valid_indices = torch.nonzero(masks[random_imgs[i]], as_tuple=False)
+ # if len(valid_indices) == 0:
+ # continue
+ # Randomly sample from the valid indices
+ if len(valid_indices) >= self.pixels_per_image:
+ indices = valid_indices[torch.randint(0, len(valid_indices), (self.pixels_per_image,))]
+ else:
+ # Repeat the indices to fill up to pixels_per_image
+ repeat_times = self.pixels_per_image // len(valid_indices) + 1
+ indices = valid_indices.repeat(repeat_times, 1)[:self.pixels_per_image]
+
+ indices_batch[i * self.pixels_per_image : (i + 1) * self.pixels_per_image, 0] = random_imgs[i]
+ indices_batch[i * self.pixels_per_image : (i + 1) * self.pixels_per_image, 1:] = indices
+
+ return indices_batch
+
+
+ def load_mesh(self, mesh_path, transform, sample_num=15000, pcd_path=None):
+ mesh = trimesh.load(mesh_path)
+ if isinstance(mesh, trimesh.Scene):
+ mesh = mesh.dump(concatenate=True)
+ coord, face_index, color = sample_surface(mesh, count=sample_num, sample_color=True)
+ color = color[..., :3]
+ face_normals = mesh.face_normals
+ normal = face_normals[face_index]
+ # self.mesh_scale, self.mesh_center_offset = cal_scale(mesh_path)
+ mesh_scale = self.meta_data["scaling_factor"]
+ mesh_center_offset = self.meta_data["mesh_offset"]
+
+ object_org_coord = coord.copy()
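+        # rotate into the render frame (y-up -> z-up axis swap), then apply the
+        # scale and offset recorded in meta.json at render time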
+ rotation_matrix = np.array([
+ [1, 0, 0],
+ [0, 0, 1],
+ [0, -1, 0]])
+ object_org_coord = np.dot(object_org_coord, rotation_matrix)
+ object_org_coord = object_org_coord * mesh_scale + mesh_center_offset
+
+ offset = torch.tensor(coord.shape[0])
+ obj = dict(coord=coord, normal=normal, color=color, offset=offset, origin_coord=object_org_coord, face_index=face_index)
+ obj = transform(obj)
+ self.object_org_coord = obj["origin_coord"].clone()
+ self.face_index = obj["face_index"].clone().numpy()
+ self.pcd_inverse = obj["inverse"].clone().numpy()
+ # print("object_org_coord", torch.unique(self.object_org_coord, return_counts=True))
+ del obj["origin_coord"], obj["face_index"], obj["inverse"]
+ self.object = obj
+
+
+
+ def prepare_meta_data(self, data_path=None):
+ SAM_model = pipeline("mask-generation", model="facebook/sam-vit-huge", device=self.device)
+ pixel_level_keys_list = []
+ scale_list = []
+ group_cdf_list = []
+ depth_valid_list = []
+ mapping_list = []
+ mapping_valid_list = []
+ object_org_coord = self.object_org_coord.to(self.device).contiguous().float()
+ obj_offset = torch.tensor(object_org_coord.shape[0]).to(self.device)
+
+ camera_angle_x = self.meta_data['camera_angle_x']
+ for i, c2w_opengl in enumerate(self.meta_data["transforms"]):
+ # print(frame['index'])
+ c2w_opengl = np.array(c2w_opengl)
+ self.logger.info(f"Processing frame_{i}")
+ rgb_path = join(self.data_root, f"render_{i:04d}.webp")
+ img = np.array(Image.open(rgb_path))
+ if img.shape[-1] == 4:
+ mask_img = img[..., 3] == 0
+ img[mask_img] = [255, 255, 255, 255]
+                img = img[..., :3]
+            else:
+                # no alpha channel: treat every pixel as foreground
+                mask_img = np.zeros(img.shape[:2], dtype=bool)
+            img = Image.fromarray(img.astype('uint8'))
+
+ # Calculate mapping
+ depth_path = join(self.data_root, f"depth_{i:04d}.exr")
+ depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
+ depth = depth[..., 0]
+ depth_valid = torch.tensor(depth < 65500.0)
+ org_points = gen_pcd(depth, c2w_opengl, camera_angle_x)
+ org_points = torch.from_numpy(org_points)
+ points_tensor = org_points.to(self.device).contiguous().float()
+ offset = torch.tensor(points_tensor.shape[0]).to(self.device)
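+            # 1-NN lookup: map every back-projected depth pixel to its nearest
+            # sampled surface point so 2D masks can supervise 3D points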
+ indices, distances = pointops.knn_query(1, object_org_coord, obj_offset, points_tensor, offset)
+ mapping = torch.zeros((depth.shape[0], depth.shape[1]), dtype=torch.int) - 1
+
+ # Create a mask where distances are less than 0.03
+ mask_dis = distances[..., 0] < 0.03
+ indices[~mask_dis] = -1
+ mapping[depth_valid] = indices.cpu().flatten()
+ mapping_valid = mapping != -1
+
+ # Calculate groups
+ try:
+ masks = SAM_model(img, points_per_side=32, pred_iou_thresh=0.9, stability_score_thresh=0.9)
+ masks = masks['masks']
+ masks = sorted(masks, key=lambda x: x.sum())
+            except Exception:
+                # SAM can fail on degenerate renders; fall back to no masks
+                masks = []
+
+ # mask filter
+ masks_filtered = []
+ img_valid = ~mask_img
+ for mask in masks:
+ valid_ratio = mask[img_valid].sum() / img_valid.sum()
+ invalid_ratio = mask[mask_img].sum() / mask_img.sum()
+ if valid_ratio == 0 or invalid_ratio > 0.1:
+ continue
+ else:
+ masks_filtered.append(mask)
+ pixel_level_keys, scale, mask_cdf = self._calculate_3d_groups(torch.from_numpy(depth), mapping_valid, masks_filtered, points_tensor[mask_dis])
+
+ pixel_level_keys_list.append(pixel_level_keys)
+ scale_list.append(scale)
+ group_cdf_list.append(mask_cdf)
+ depth_valid_list.append(depth_valid)
+ mapping_list.append(mapping)
+ mapping_valid_list.append(mapping_valid)
+
+ self.pixel_level_keys = torch.nested.nested_tensor(
+ pixel_level_keys_list
+ )
+ self.scale_3d_statistics = torch.cat(scale_list)
+ self.scale_3d = torch.nested.nested_tensor(scale_list)
+ self.group_cdf = torch.nested.nested_tensor(group_cdf_list)
+ self.depth_valid = torch.stack(depth_valid_list)
+ self.mapping = torch.stack(mapping_list)
+ self.mapping_valid = torch.stack(mapping_valid_list)
+
+ def _calculate_3d_groups(
+ self,
+ depth: torch.Tensor,
+ valid: torch.Tensor,
+ masks: torch.Tensor,
+ point: torch.Tensor,
+ max_scale: float = 2.0,
+ ):
+ """
+ Calculate the set of groups and their 3D scale for each pixel, and the cdf.
+ Returns:
+ - pixel_level_keys: [H, W, max_masks]
+ - scale: [num_masks, 1]
+ - mask_cdf: [H, W, max_masks]
+        max_masks is the maximum number of masks assigned to any pixel in the image;
+        shorter per-pixel lists are padded with -1s. mask_cdf does *not* include the -1s.
+ Refer to the main paper for more details.
+ """
+ image_shape = depth.shape[:2]
+ depth_valid = valid
+ point = point.to(self.device)
+
+ def helper_return_no_masks():
+ # Fail gracefully when no masks are found.
+ # Create dummy data (all -1s), which will be ignored later.
+ # See: `get_loss_dict_group` in `garfield_model.py`
+ pixel_level_keys = torch.full(
+ (image_shape[0], image_shape[1], 1), -1, dtype=torch.int
+ )
+ scale = torch.Tensor([0.0]).view(-1, 1)
+ mask_cdf = torch.full(
+ (image_shape[0], image_shape[1], 1), 1, dtype=torch.float
+ )
+ return (pixel_level_keys, scale, mask_cdf)
+
+
+ # If no masks are found, return dummy data.
+ if len(masks) == 0:
+ return helper_return_no_masks()
+
+ sam_mask = []
+ scale = []
+
+ # For all 2D groups,
+ # 1) Denoise the masks (through eroding)
+ all_masks = torch.stack(
+ # [torch.from_numpy(_["segmentation"]).to(self.device) for _ in masks]
+ [torch.from_numpy(_).to(self.device) for _ in masks]
+ )
+ # erode all masks using 3x3 kernel
+ # ignore erode
+ eroded_masks = torch.conv2d(
+ all_masks.unsqueeze(1).float(),
+ torch.full((3, 3), 1.0).view(1, 1, 3, 3).to("cuda"),
+ padding=1,
+ )
+ eroded_masks = (eroded_masks >= 5).squeeze(1) # (num_masks, H, W)
+
+ # 2) Calculate 3D scale
+ # Don't include groups with scale > max_scale (likely to be too noisy to be useful)
+ for i in range(len(masks)):
+ curr_mask_org = eroded_masks[i]
+ curr_mask = curr_mask_org[depth_valid]
+ curr_points = point[curr_mask]
+ extent = (curr_points.std(dim=0) * self.extent_scale).norm()
+ if extent.item() < max_scale:
+ sam_mask.append(curr_mask_org)
+ scale.append(extent.item())
+
+ # If no masks are found, after postprocessing, return dummy data.
+ if len(sam_mask) == 0:
+ return helper_return_no_masks()
+
+ sam_mask = torch.stack(sam_mask) # (num_masks, H, W)
+ scale = torch.Tensor(scale).view(-1, 1).to(self.device) # (num_masks, 1)
+
+ # Calculate "pixel level keys", which is a 2D array of shape (H, W, max_masks)
+ # Each pixel has a list of group indices that it belongs to, in order of increasing scale.
+ pixel_level_keys = self.create_pixel_mask_array(
+ sam_mask
+ ).long() # (H, W, max_masks)
+ depth_invalid = ~depth_valid
+ pixel_level_keys[depth_invalid, :] = -1
+
+ # Calculate group sampling CDF, to bias sampling towards smaller groups
+ # Be careful to not include -1s in the CDF (padding, or unlabeled pixels)
+ # Inversely proportional to log of mask size.
+ mask_inds, counts = torch.unique(pixel_level_keys, return_counts=True)
+ counts[0] = 0 # don't include -1
+ probs = counts / counts.sum() # [-1, 0, ...]
+
+ pixel_shape = pixel_level_keys.shape
+        # map each key to its row in `probs`: when mask ids are non-contiguous,
+        # torch.unique's inverse indices give the correct positions; otherwise a
+        # +1 shift accounts for the leading -1 entry
+        if (pixel_level_keys.max() + 2) != probs.shape[0]:
+            pixel_level_keys_new = pixel_level_keys.reshape(-1)
+            unique_values, inverse_indices = torch.unique(pixel_level_keys_new, return_inverse=True)
+            pixel_level_keys_new = inverse_indices.reshape(-1)
+        else:
+            pixel_level_keys_new = pixel_level_keys.reshape(-1) + 1
+
+        mask_probs = torch.gather(probs, 0, pixel_level_keys_new).view(
+            pixel_shape
+        )
+ mask_log_probs = torch.log(mask_probs)
+ never_masked = mask_log_probs.isinf()
+ mask_log_probs[never_masked] = 0.0
+ mask_log_probs = mask_log_probs / (
+ mask_log_probs.sum(dim=-1, keepdim=True) + 1e-6
+ )
+ mask_cdf = torch.cumsum(mask_log_probs, dim=-1)
+ mask_cdf[never_masked] = 1.0
+
+ return (pixel_level_keys.cpu(), scale.cpu(), mask_cdf.cpu())
+
+ @staticmethod
+ def create_pixel_mask_array(masks: torch.Tensor):
+ """
+ Create per-pixel data structure for grouping supervision.
+ pixel_mask_array[x, y] = [m1, m2, ...] means that pixel (x, y) belongs to masks m1, m2, ...
+ where Area(m1) < Area(m2) < ... (sorted by area).
+ """
+ max_masks = masks.sum(dim=0).max().item()
+ # print(max_masks)
+ image_shape = masks.shape[1:]
+ pixel_mask_array = torch.full(
+ (max_masks, image_shape[0], image_shape[1]), -1, dtype=torch.int
+ ).to(masks.device)
+
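+        # masks arrive sorted by area; each mask greedily claims the first
+        # still-free layer for its pixels, so layer order follows group scale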
+ for m, mask in enumerate(masks):
+ mask_clone = mask.clone()
+ for i in range(max_masks):
+ free = pixel_mask_array[i] == -1
+ masked_area = mask_clone == 1
+ right_index = free & masked_area
+ if len(pixel_mask_array[i][right_index]) != 0:
+ pixel_mask_array[i][right_index] = m
+ mask_clone[right_index] = 0
+ pixel_mask_array = pixel_mask_array.permute(1, 2, 0)
+
+ return pixel_mask_array
+
+ def get_data_list(self):
+ data_list = glob.glob(os.path.join(self.data_root, "*.exr"))
+ return data_list
+
+ def get_data(self, idx):
+ indices = self.sample_pixel(self.mapping_valid, 512, 512).long().detach().cpu()
+ npximg = self.pixels_per_image
+ img_ind = indices[:, 0]
+ x_ind = indices[:, 1]
+ y_ind = indices[:, 2]
+
+ # sampled_imgs = img_ind[::npximg]
+ mask_id = torch.zeros((indices.shape[0],), device=self.device)
+ scale = torch.zeros((indices.shape[0],), device=self.device)
+ mapping = torch.zeros((indices.shape[0],), device=self.device)
+
+ random_vec_sampling = (torch.rand((1,)) * torch.ones((npximg,))).view(-1, 1)
+ random_vec_densify = (torch.rand((1,)) * torch.ones((npximg,))).view(-1, 1)
+
+ for i in range(0, indices.shape[0], npximg):
+ img_idx = img_ind[i]
+
+ # calculate mapping
+ mapping[i : i + npximg] = self.mapping[img_idx][x_ind[i : i + npximg], y_ind[i : i + npximg]]
+
+ # Use `random_vec` to choose a group for each pixel.
+ per_pixel_index = self.pixel_level_keys[img_idx][
+ x_ind[i : i + npximg], y_ind[i : i + npximg]
+ ]
+ random_index = torch.sum(
+ random_vec_sampling.view(-1, 1)
+ > self.group_cdf[img_idx][x_ind[i : i + npximg], y_ind[i : i + npximg]],
+ dim=-1,
+ )
+
+ # `per_pixel_index` encodes the list of groups that each pixel belongs to.
+ # If there's only one group, then `per_pixel_index` is a 1D tensor
+ # -- this will mess up the future `gather` operations.
+ if per_pixel_index.shape[-1] == 1:
+ per_pixel_mask = per_pixel_index.squeeze()
+ else:
+ per_pixel_mask = torch.gather(
+ per_pixel_index, 1, random_index.unsqueeze(-1)
+ ).squeeze()
+ per_pixel_mask_ = torch.gather(
+ per_pixel_index,
+ 1,
+ torch.max(random_index.unsqueeze(-1) - 1, torch.Tensor([0]).int()),
+ ).squeeze()
+
+ mask_id[i : i + npximg] = per_pixel_mask.to(self.device)
+
+ # interval scale supervision
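+            # interpolate between the sampled group's scale and the next-smaller
+            # group's scale so supervision densely covers the scale interval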
+ curr_scale = self.scale_3d[img_idx][per_pixel_mask]
+ curr_scale[random_index == 0] = (
+ self.scale_3d[img_idx][per_pixel_mask][random_index == 0]
+ * random_vec_densify[random_index == 0]
+ )
+ for j in range(1, self.group_cdf[img_idx].shape[-1]):
+ if (random_index == j).sum() == 0:
+ continue
+ curr_scale[random_index == j] = (
+ self.scale_3d[img_idx][per_pixel_mask_][random_index == j]
+ + (
+ self.scale_3d[img_idx][per_pixel_mask][random_index == j]
+ - self.scale_3d[img_idx][per_pixel_mask_][random_index == j]
+ )
+ * random_vec_densify[random_index == j]
+ )
+ scale[i : i + npximg] = curr_scale.squeeze().to(self.device)
+
+ batch = dict()
+ batch["mask_id"] = mask_id
+ batch["scale"] = scale
+ batch["nPxImg"] = npximg
+ batch["obj"] = self.object
+ batch["mapping"] = mapping.long()
+ return batch
+
+ def val_data(self):
+ return dict(obj=self.object)
+
+ def get_data_name(self, idx):
+ return os.path.basename(self.data_list[idx % len(self.data_list)]).split(".")[0]
+
+ def __getitem__(self, idx):
+ return self.get_data(idx % len(self.data_list))
+
+ def __len__(self):
+ return len(self.data_list) * self.loop
diff --git a/UniRig/src/model/pointcept/datasets/sampart3d_util.py b/UniRig/src/model/pointcept/datasets/sampart3d_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e11a17ca59027a3750c1ccaa2c9f07a3f73fcab
--- /dev/null
+++ b/UniRig/src/model/pointcept/datasets/sampart3d_util.py
@@ -0,0 +1,152 @@
+import numpy as np
+import trimesh
+import os
+import json
+import math
+import open3d as o3d
+import torch
+
+
+def sample_surface(mesh, count, face_weight=None, sample_color=False, seed=147):
+
+ if face_weight is None:
+ # len(mesh.faces) float, array of the areas
+ # of each face of the mesh
+ face_weight = mesh.area_faces
+
+ # cumulative sum of weights (len(mesh.faces))
+ weight_cum = np.cumsum(face_weight)
+
+ # seed the random number generator as requested
+ random = np.random.default_rng(seed).random
+
+ # last value of cumulative sum is total summed weight/area
+ face_pick = random(count) * weight_cum[-1]
+ # get the index of the selected faces
+ face_index = np.searchsorted(weight_cum, face_pick)
+
+ # pull triangles into the form of an origin + 2 vectors
+ tri_origins = mesh.vertices[mesh.faces[:, 0]]
+ tri_vectors = mesh.vertices[mesh.faces[:, 1:]].copy()
+ tri_vectors -= np.tile(tri_origins, (1, 2)).reshape((-1, 2, 3))
+
+ # pull the vectors for the faces we are going to sample from
+ tri_origins = tri_origins[face_index]
+ tri_vectors = tri_vectors[face_index]
+
+ if sample_color and hasattr(mesh.visual, "uv"):
+ uv_origins = mesh.visual.uv[mesh.faces[:, 0]]
+ uv_vectors = mesh.visual.uv[mesh.faces[:, 1:]].copy()
+ uv_origins_tile = np.tile(uv_origins, (1, 2)).reshape((-1, 2, 2))
+ uv_vectors -= uv_origins_tile
+ uv_origins = uv_origins[face_index]
+ uv_vectors = uv_vectors[face_index]
+
+ # randomly generate two 0-1 scalar components to multiply edge vectors b
+ random_lengths = random((len(tri_vectors), 2, 1))
+
+ # points will be distributed on a quadrilateral if we use 2 0-1 samples
+ # if the two scalar components sum less than 1.0 the point will be
+ # inside the triangle, so we find vectors longer than 1.0 and
+ # transform them to be inside the triangle
+ random_test = random_lengths.sum(axis=1).reshape(-1) > 1.0
+ random_lengths[random_test] -= 1.0
+ random_lengths = np.abs(random_lengths)
+
+ # multiply triangle edge vectors by the random lengths and sum
+ sample_vector = (tri_vectors * random_lengths).sum(axis=1)
+
+ # finally, offset by the origin to generate
+ # (n,3) points in space on the triangle
+ samples = sample_vector + tri_origins
+
+ if sample_color:
+ if hasattr(mesh.visual, "uv"):
+ sample_uv_vector = (uv_vectors * random_lengths).sum(axis=1)
+ uv_samples = sample_uv_vector + uv_origins
+ try:
+ texture = mesh.visual.material.baseColorTexture
+            except AttributeError:
+ texture = mesh.visual.material.image
+ colors = trimesh.visual.color.uv_to_interpolated_color(uv_samples, texture)
+ else:
+ colors = mesh.visual.face_colors[face_index]
+
+ return samples, face_index, colors
+
+ return samples, face_index
+
+
+def get_ray_directions(W, H, fx, fy, cx, cy, use_pixel_centers=True):
+ pixel_center = 0.5 if use_pixel_centers else 0
+ i, j = np.meshgrid(
+ np.arange(W, dtype=np.float32) + pixel_center,
+ np.arange(H, dtype=np.float32) + pixel_center,
+ indexing="xy",
+ )
+ directions = np.stack(
+ [(i - cx) / fx, -(j - cy) / fy, -np.ones_like(i)], -1
+ )
+
+ return directions
+
+
+def gen_pcd(depth, c2w_opengl, camera_angle_x):
+
+ h, w = depth.shape
+
+ depth_valid = depth < 65500.0
+ depth = depth[depth_valid]
+ focal = (
+ 0.5 * w / math.tan(0.5 * camera_angle_x)
+ ) # scaled focal length
+ ray_directions = get_ray_directions(w, h, focal, focal, w // 2, h // 2)
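+    # back-project each valid depth pixel along its camera ray (OpenGL frame:
+    # camera looks down -z), then lift to world space with the homogeneous c2w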
+ points_c = ray_directions[depth_valid] * depth[:, None]
+ points_c_homo = np.concatenate(
+ [points_c, np.ones_like(points_c[..., :1])], axis=-1
+ )
+ org_points = (points_c_homo @ c2w_opengl.T)[..., :3]
+
+ return org_points
+
+
+def save_point_cloud(coord, color=None, file_path="pc.ply", logger=None):
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
+ coord = np.array(coord)
+ if color is not None:
+ color = np.array(color)
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(coord)
+ pcd.colors = o3d.utility.Vector3dVector(np.ones_like(coord) if color is None else color)
+ o3d.io.write_point_cloud(file_path, pcd)
+ if logger is not None:
+ logger.info(f"Save Point Cloud to: {file_path}")
+
+
+def vis_pcd_feat(coord, point_feat, save_path):
+ class TorchPCA(object):
+
+ def __init__(self, n_components):
+ self.n_components = n_components
+
+ def fit(self, X):
+ self.mean_ = X.mean(dim=0)
+ unbiased = X - self.mean_.unsqueeze(0)
+ U, S, V = torch.pca_lowrank(unbiased, q=self.n_components, center=False, niter=4)
+ self.components_ = V.T
+ self.singular_values_ = S
+ return self
+
+ def transform(self, X):
+ t0 = X - self.mean_.unsqueeze(0)
+ projected = t0 @ self.components_.T
+ return projected
+
+ fit_pca = TorchPCA(n_components=3).fit(point_feat)
+ x_red = fit_pca.transform(point_feat)
+ if isinstance(x_red, np.ndarray):
+ x_red = torch.from_numpy(x_red)
+ x_red -= x_red.min(dim=0, keepdim=True).values
+ x_red /= x_red.max(dim=0, keepdim=True).values
+
+ save_point_cloud(coord.detach().cpu(), x_red.detach().cpu(), save_path)
\ No newline at end of file
diff --git a/UniRig/src/model/pointcept/datasets/transform.py b/UniRig/src/model/pointcept/datasets/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecb11b80fc28ad592e453cb69df23669f413aaa4
--- /dev/null
+++ b/UniRig/src/model/pointcept/datasets/transform.py
@@ -0,0 +1,1143 @@
+import random
+import numbers
+import scipy
+import scipy.ndimage
+import scipy.interpolate
+import scipy.stats
+import numpy as np
+import torch
+import copy
+from collections.abc import Sequence, Mapping
+
+from pointcept.utils.registry import Registry
+
+TRANSFORMS = Registry("transforms")
+
+
+@TRANSFORMS.register_module()
+class Collect(object):
+ def __init__(self, keys, offset_keys_dict=None, **kwargs):
+ """
+ e.g. Collect(keys=[coord], feat_keys=[coord, color])
+ """
+ if offset_keys_dict is None:
+ offset_keys_dict = dict(offset="coord")
+ self.keys = keys
+ self.offset_keys = offset_keys_dict
+ self.kwargs = kwargs
+
+ def __call__(self, data_dict):
+ data = dict()
+ if isinstance(self.keys, str):
+ self.keys = [self.keys]
+ for key in self.keys:
+ data[key] = data_dict[key]
+ for key, value in self.offset_keys.items():
+ data[key] = torch.tensor([data_dict[value].shape[0]])
+ for name, keys in self.kwargs.items():
+ name = name.replace("_keys", "")
+ assert isinstance(keys, Sequence)
+ data[name] = torch.cat([data_dict[key].float() for key in keys], dim=1)
+ return data
+
+
+@TRANSFORMS.register_module()
+class Copy(object):
+ def __init__(self, keys_dict=None):
+ if keys_dict is None:
+ keys_dict = dict(coord="origin_coord", segment="origin_segment")
+ self.keys_dict = keys_dict
+
+ def __call__(self, data_dict):
+ for key, value in self.keys_dict.items():
+ if isinstance(data_dict[key], np.ndarray):
+ data_dict[value] = data_dict[key].copy()
+ elif isinstance(data_dict[key], torch.Tensor):
+ data_dict[value] = data_dict[key].clone().detach()
+ else:
+ data_dict[value] = copy.deepcopy(data_dict[key])
+ return data_dict
+
+
+@TRANSFORMS.register_module()
+class ToTensor(object):
+ def __call__(self, data):
+ if isinstance(data, torch.Tensor):
+ return data
+ elif isinstance(data, str):
+ # note that str is also a kind of sequence, judgement should before sequence
+ return data
+ elif isinstance(data, int):
+ return torch.LongTensor([data])
+ elif isinstance(data, float):
+ return torch.FloatTensor([data])
+ elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, bool):
+ return torch.from_numpy(data)
+ elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, np.integer):
+ return torch.from_numpy(data).long()
+ elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, np.floating):
+ return torch.from_numpy(data).float()
+ elif isinstance(data, Mapping):
+ result = {sub_key: self(item) for sub_key, item in data.items()}
+ return result
+ elif isinstance(data, Sequence):
+ result = [self(item) for item in data]
+ return result
+ else:
+ raise TypeError(f"type {type(data)} cannot be converted to tensor.")
+
+
+@TRANSFORMS.register_module()
+class Add(object):
+ def __init__(self, keys_dict=None):
+ if keys_dict is None:
+ keys_dict = dict()
+ self.keys_dict = keys_dict
+
+ def __call__(self, data_dict):
+ for key, value in self.keys_dict.items():
+ data_dict[key] = value
+ return data_dict
+
+
+@TRANSFORMS.register_module()
+class NormalizeColor(object):
+ def __call__(self, data_dict):
+ if "color" in data_dict.keys():
+ data_dict["color"] = data_dict["color"] / 127.5 - 1
+ return data_dict
+
+
+@TRANSFORMS.register_module()
+class NormalizeCoord(object):
+ def __call__(self, data_dict):
+ if "coord" in data_dict.keys():
+ # modified from pointnet2
+ centroid = np.mean(data_dict["coord"], axis=0)
+ data_dict["coord"] -= centroid
+ m = np.max(np.sqrt(np.sum(data_dict["coord"] ** 2, axis=1)))
+ data_dict["coord"] = data_dict["coord"] / m
+ return data_dict
+
+
+@TRANSFORMS.register_module()
+class PositiveShift(object):
+ def __call__(self, data_dict):
+ if "coord" in data_dict.keys():
+ coord_min = np.min(data_dict["coord"], 0)
+ data_dict["coord"] -= coord_min
+ return data_dict
+
+
+@TRANSFORMS.register_module()
+class CenterShift(object):
+ def __init__(self, apply_z=True):
+ self.apply_z = apply_z
+
+ def __call__(self, data_dict):
+ if "coord" in data_dict.keys():
+ x_min, y_min, z_min = data_dict["coord"].min(axis=0)
+ x_max, y_max, _ = data_dict["coord"].max(axis=0)
+ if self.apply_z:
+ shift = [(x_min + x_max) / 2, (y_min + y_max) / 2, z_min]
+ else:
+ shift = [(x_min + x_max) / 2, (y_min + y_max) / 2, 0]
+ data_dict["coord"] -= shift
+ return data_dict
+
+
+@TRANSFORMS.register_module()
+class RandomShift(object):
+ def __init__(self, shift=((-0.2, 0.2), (-0.2, 0.2), (0, 0))):
+ self.shift = shift
+
+ def __call__(self, data_dict):
+ if "coord" in data_dict.keys():
+ shift_x = np.random.uniform(self.shift[0][0], self.shift[0][1])
+ shift_y = np.random.uniform(self.shift[1][0], self.shift[1][1])
+ shift_z = np.random.uniform(self.shift[2][0], self.shift[2][1])
+ data_dict["coord"] += [shift_x, shift_y, shift_z]
+ return data_dict
+
+
+@TRANSFORMS.register_module()
+class PointClip(object):
+ def __init__(self, point_cloud_range=(-80, -80, -3, 80, 80, 1)):
+ self.point_cloud_range = point_cloud_range
+
+ def __call__(self, data_dict):
+ if "coord" in data_dict.keys():
+ data_dict["coord"] = np.clip(
+ data_dict["coord"],
+ a_min=self.point_cloud_range[:3],
+ a_max=self.point_cloud_range[3:],
+ )
+ return data_dict
+
+
+@TRANSFORMS.register_module()
+class RandomDropout(object):
+ def __init__(self, dropout_ratio=0.2, dropout_application_ratio=0.5):
+ """
+        dropout_ratio: fraction of points dropped when dropout fires
+        dropout_application_ratio: probability of applying dropout at all
+ """
+ self.dropout_ratio = dropout_ratio
+ self.dropout_application_ratio = dropout_application_ratio
+
+ def __call__(self, data_dict):
+ if random.random() < self.dropout_application_ratio:
+ n = len(data_dict["coord"])
+ idx = np.random.choice(n, int(n * (1 - self.dropout_ratio)), replace=False)
+ if "sampled_index" in data_dict:
+ # for ScanNet data efficient, we need to make sure labeled point is sampled.
+ idx = np.unique(np.append(idx, data_dict["sampled_index"]))
+ mask = np.zeros_like(data_dict["segment"]).astype(bool)
+ mask[data_dict["sampled_index"]] = True
+ data_dict["sampled_index"] = np.where(mask[idx])[0]
+ if "coord" in data_dict.keys():
+ data_dict["coord"] = data_dict["coord"][idx]
+ if "color" in data_dict.keys():
+ data_dict["color"] = data_dict["color"][idx]
+ if "normal" in data_dict.keys():
+ data_dict["normal"] = data_dict["normal"][idx]
+ if "strength" in data_dict.keys():
+ data_dict["strength"] = data_dict["strength"][idx]
+ if "segment" in data_dict.keys():
+ data_dict["segment"] = data_dict["segment"][idx]
+ if "instance" in data_dict.keys():
+ data_dict["instance"] = data_dict["instance"][idx]
+ return data_dict
+
+
+@TRANSFORMS.register_module()
+class RandomRotate(object):
+ def __init__(self, angle=None, center=None, axis="z", always_apply=False, p=0.5):
+ self.angle = [-1, 1] if angle is None else angle
+ self.axis = axis
+ self.always_apply = always_apply
+ self.p = p if not self.always_apply else 1
+ self.center = center
+
+ def __call__(self, data_dict):
+ if random.random() > self.p:
+ return data_dict
+ angle = np.random.uniform(self.angle[0], self.angle[1]) * np.pi
+ rot_cos, rot_sin = np.cos(angle), np.sin(angle)
+ if self.axis == "x":
+ rot_t = np.array([[1, 0, 0], [0, rot_cos, -rot_sin], [0, rot_sin, rot_cos]])
+ elif self.axis == "y":
+ rot_t = np.array([[rot_cos, 0, rot_sin], [0, 1, 0], [-rot_sin, 0, rot_cos]])
+ elif self.axis == "z":
+ rot_t = np.array([[rot_cos, -rot_sin, 0], [rot_sin, rot_cos, 0], [0, 0, 1]])
+ else:
+ raise NotImplementedError
+ if "coord" in data_dict.keys():
+ if self.center is None:
+ x_min, y_min, z_min = data_dict["coord"].min(axis=0)
+ x_max, y_max, z_max = data_dict["coord"].max(axis=0)
+ center = [(x_min + x_max) / 2, (y_min + y_max) / 2, (z_min + z_max) / 2]
+ else:
+ center = self.center
+ data_dict["coord"] -= center
+ data_dict["coord"] = np.dot(data_dict["coord"], np.transpose(rot_t))
+ data_dict["coord"] += center
+ if "normal" in data_dict.keys():
+ data_dict["normal"] = np.dot(data_dict["normal"], np.transpose(rot_t))
+ return data_dict
+
+
+@TRANSFORMS.register_module()
+class RandomRotateTargetAngle(object):
+ def __init__(
+ self, angle=(1 / 2, 1, 3 / 2), center=None, axis="z", always_apply=False, p=0.75
+ ):
+ self.angle = angle
+ self.axis = axis
+ self.always_apply = always_apply
+ self.p = p if not self.always_apply else 1
+ self.center = center
+
+ def __call__(self, data_dict):
+ if random.random() > self.p:
+ return data_dict
+ angle = np.random.choice(self.angle) * np.pi
+ rot_cos, rot_sin = np.cos(angle), np.sin(angle)
+ if self.axis == "x":
+ rot_t = np.array([[1, 0, 0], [0, rot_cos, -rot_sin], [0, rot_sin, rot_cos]])
+ elif self.axis == "y":
+ rot_t = np.array([[rot_cos, 0, rot_sin], [0, 1, 0], [-rot_sin, 0, rot_cos]])
+ elif self.axis == "z":
+ rot_t = np.array([[rot_cos, -rot_sin, 0], [rot_sin, rot_cos, 0], [0, 0, 1]])
+ else:
+ raise NotImplementedError
+ if "coord" in data_dict.keys():
+ if self.center is None:
+ x_min, y_min, z_min = data_dict["coord"].min(axis=0)
+ x_max, y_max, z_max = data_dict["coord"].max(axis=0)
+ center = [(x_min + x_max) / 2, (y_min + y_max) / 2, (z_min + z_max) / 2]
+ else:
+ center = self.center
+ data_dict["coord"] -= center
+ data_dict["coord"] = np.dot(data_dict["coord"], np.transpose(rot_t))
+ data_dict["coord"] += center
+ if "normal" in data_dict.keys():
+ data_dict["normal"] = np.dot(data_dict["normal"], np.transpose(rot_t))
+ return data_dict
+
+
+@TRANSFORMS.register_module()
+class RandomScale(object):
+ def __init__(self, scale=None, anisotropic=False):
+ self.scale = scale if scale is not None else [0.95, 1.05]
+ self.anisotropic = anisotropic
+
+ def __call__(self, data_dict):
+ if "coord" in data_dict.keys():
+ scale = np.random.uniform(
+ self.scale[0], self.scale[1], 3 if self.anisotropic else 1
+ )
+ data_dict["coord"] *= scale
+ return data_dict
+
+
+@TRANSFORMS.register_module()
+class RandomFlip(object):
+ def __init__(self, p=0.5):
+ self.p = p
+
+ def __call__(self, data_dict):
+ if np.random.rand() < self.p:
+ if "coord" in data_dict.keys():
+ data_dict["coord"][:, 0] = -data_dict["coord"][:, 0]
+ if "normal" in data_dict.keys():
+ data_dict["normal"][:, 0] = -data_dict["normal"][:, 0]
+ if np.random.rand() < self.p:
+ if "coord" in data_dict.keys():
+ data_dict["coord"][:, 1] = -data_dict["coord"][:, 1]
+ if "normal" in data_dict.keys():
+ data_dict["normal"][:, 1] = -data_dict["normal"][:, 1]
+ return data_dict
+
+
+@TRANSFORMS.register_module()
+class RandomJitter(object):
+ def __init__(self, sigma=0.01, clip=0.05):
+ assert clip > 0
+ self.sigma = sigma
+ self.clip = clip
+
+ def __call__(self, data_dict):
+ if "coord" in data_dict.keys():
+ jitter = np.clip(
+ self.sigma * np.random.randn(data_dict["coord"].shape[0], 3),
+ -self.clip,
+ self.clip,
+ )
+ data_dict["coord"] += jitter
+ return data_dict
+
+
+@TRANSFORMS.register_module()
+class ClipGaussianJitter(object):
+    def __init__(self, scalar=0.02, store_jitter=False):
+        self.scalar = scalar
+        # zero-mean, identity-covariance 3D Gaussian; multivariate_normal
+        # requires a 1-D mean vector
+        self.mean = np.zeros(3)
+        self.cov = np.identity(3)
+        # ~95% two-sided quantile used to clip the jitter
+        self.quantile = 1.96
+        self.store_jitter = store_jitter
+
+ def __call__(self, data_dict):
+ if "coord" in data_dict.keys():
+ jitter = np.random.multivariate_normal(
+ self.mean, self.cov, data_dict["coord"].shape[0]
+ )
+            jitter = self.scalar * np.clip(jitter / self.quantile, -1, 1)
+ data_dict["coord"] += jitter
+ if self.store_jitter:
+ data_dict["jitter"] = jitter
+ return data_dict
+
+
+@TRANSFORMS.register_module()
+class ChromaticAutoContrast(object):
+ def __init__(self, p=0.2, blend_factor=None):
+ self.p = p
+ self.blend_factor = blend_factor
+
+ def __call__(self, data_dict):
+ if "color" in data_dict.keys() and np.random.rand() < self.p:
+ lo = np.min(data_dict["color"], 0, keepdims=True)
+ hi = np.max(data_dict["color"], 0, keepdims=True)
+ scale = 255 / (hi - lo)
+ contrast_feat = (data_dict["color"][:, :3] - lo) * scale
+ blend_factor = (
+ np.random.rand() if self.blend_factor is None else self.blend_factor
+ )
+ data_dict["color"][:, :3] = (1 - blend_factor) * data_dict["color"][
+ :, :3
+ ] + blend_factor * contrast_feat
+ return data_dict
+
+
+@TRANSFORMS.register_module()
+class ChromaticTranslation(object):
+ def __init__(self, p=0.95, ratio=0.05):
+ self.p = p
+ self.ratio = ratio
+
+ def __call__(self, data_dict):
+ if "color" in data_dict.keys() and np.random.rand() < self.p:
+ tr = (np.random.rand(1, 3) - 0.5) * 255 * 2 * self.ratio
+ data_dict["color"][:, :3] = np.clip(tr + data_dict["color"][:, :3], 0, 255)
+ return data_dict
+
+
+@TRANSFORMS.register_module()
+class ChromaticJitter(object):
+ def __init__(self, p=0.95, std=0.005):
+ self.p = p
+ self.std = std
+
+ def __call__(self, data_dict):
+ if "color" in data_dict.keys() and np.random.rand() < self.p:
+ noise = np.random.randn(data_dict["color"].shape[0], 3)
+ noise *= self.std * 255
+ data_dict["color"][:, :3] = np.clip(
+ noise + data_dict["color"][:, :3], 0, 255
+ )
+ return data_dict
+
+
+@TRANSFORMS.register_module()
+class RandomColorGrayScale(object):
+ def __init__(self, p):
+ self.p = p
+
+ @staticmethod
+ def rgb_to_grayscale(color, num_output_channels=1):
+ if color.shape[-1] < 3:
+ raise TypeError(
+ "Input color should have at least 3 dimensions, but found {}".format(
+ color.shape[-1]
+ )
+ )
+
+ if num_output_channels not in (1, 3):
+ raise ValueError("num_output_channels should be either 1 or 3")
+
+ r, g, b = color[..., 0], color[..., 1], color[..., 2]
+ gray = (0.2989 * r + 0.587 * g + 0.114 * b).astype(color.dtype)
+ gray = np.expand_dims(gray, axis=-1)
+
+ if num_output_channels == 3:
+ gray = np.broadcast_to(gray, color.shape)
+
+ return gray
+
+ def __call__(self, data_dict):
+ if np.random.rand() < self.p:
+ data_dict["color"] = self.rgb_to_grayscale(data_dict["color"], 3)
+ return data_dict
+
+
+@TRANSFORMS.register_module()
+class RandomColorJitter(object):
+ """
+ Random Color Jitter for 3D point cloud (refer torchvision)
+ """
+
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, p=0.95):
+ self.brightness = self._check_input(brightness, "brightness")
+ self.contrast = self._check_input(contrast, "contrast")
+ self.saturation = self._check_input(saturation, "saturation")
+ self.hue = self._check_input(
+ hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False
+ )
+ self.p = p
+
+ @staticmethod
+ def _check_input(
+ value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True
+ ):
+ if isinstance(value, numbers.Number):
+ if value < 0:
+ raise ValueError(
+ "If {} is a single number, it must be non negative.".format(name)
+ )
+ value = [center - float(value), center + float(value)]
+ if clip_first_on_zero:
+ value[0] = max(value[0], 0.0)
+ elif isinstance(value, (tuple, list)) and len(value) == 2:
+ if not bound[0] <= value[0] <= value[1] <= bound[1]:
+ raise ValueError("{} values should be between {}".format(name, bound))
+ else:
+ raise TypeError(
+ "{} should be a single number or a list/tuple with length 2.".format(
+ name
+ )
+ )
+
+ # if value is 0 or (1., 1.) for brightness/contrast/saturation
+ # or (0., 0.) for hue, do nothing
+ if value[0] == value[1] == center:
+ value = None
+ return value
+
+ @staticmethod
+ def blend(color1, color2, ratio):
+ ratio = float(ratio)
+ bound = 255.0
+ return (
+ (ratio * color1 + (1.0 - ratio) * color2)
+ .clip(0, bound)
+ .astype(color1.dtype)
+ )
+
+ @staticmethod
+ def rgb2hsv(rgb):
+ r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]
+ maxc = np.max(rgb, axis=-1)
+ minc = np.min(rgb, axis=-1)
+ eqc = maxc == minc
+ cr = maxc - minc
+ s = cr / (np.ones_like(maxc) * eqc + maxc * (1 - eqc))
+ cr_divisor = np.ones_like(maxc) * eqc + cr * (1 - eqc)
+ rc = (maxc - r) / cr_divisor
+ gc = (maxc - g) / cr_divisor
+ bc = (maxc - b) / cr_divisor
+
+ hr = (maxc == r) * (bc - gc)
+ hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
+ hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
+ h = hr + hg + hb
+ h = (h / 6.0 + 1.0) % 1.0
+ return np.stack((h, s, maxc), axis=-1)
+
+ @staticmethod
+ def hsv2rgb(hsv):
+ h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2]
+ i = np.floor(h * 6.0)
+ f = (h * 6.0) - i
+ i = i.astype(np.int32)
+
+ p = np.clip((v * (1.0 - s)), 0.0, 1.0)
+ q = np.clip((v * (1.0 - s * f)), 0.0, 1.0)
+ t = np.clip((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
+ i = i % 6
+ mask = np.expand_dims(i, axis=-1) == np.arange(6)
+
+ a1 = np.stack((v, q, p, p, t, v), axis=-1)
+ a2 = np.stack((t, v, v, q, p, p), axis=-1)
+ a3 = np.stack((p, p, t, v, v, q), axis=-1)
+ a4 = np.stack((a1, a2, a3), axis=-1)
+
+ return np.einsum("...na, ...nab -> ...nb", mask.astype(hsv.dtype), a4)
+
+ def adjust_brightness(self, color, brightness_factor):
+ if brightness_factor < 0:
+ raise ValueError(
+ "brightness_factor ({}) is not non-negative.".format(brightness_factor)
+ )
+
+ return self.blend(color, np.zeros_like(color), brightness_factor)
+
+ def adjust_contrast(self, color, contrast_factor):
+ if contrast_factor < 0:
+ raise ValueError(
+ "contrast_factor ({}) is not non-negative.".format(contrast_factor)
+ )
+ mean = np.mean(RandomColorGrayScale.rgb_to_grayscale(color))
+ return self.blend(color, mean, contrast_factor)
+
+ def adjust_saturation(self, color, saturation_factor):
+ if saturation_factor < 0:
+ raise ValueError(
+ "saturation_factor ({}) is not non-negative.".format(saturation_factor)
+ )
+ gray = RandomColorGrayScale.rgb_to_grayscale(color)
+ return self.blend(color, gray, saturation_factor)
+
+ def adjust_hue(self, color, hue_factor):
+ if not (-0.5 <= hue_factor <= 0.5):
+ raise ValueError(
+ "hue_factor ({}) is not in [-0.5, 0.5].".format(hue_factor)
+ )
+ orig_dtype = color.dtype
+ hsv = self.rgb2hsv(color / 255.0)
+ h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2]
+ h = (h + hue_factor) % 1.0
+ hsv = np.stack((h, s, v), axis=-1)
+ color_hue_adj = (self.hsv2rgb(hsv) * 255.0).astype(orig_dtype)
+ return color_hue_adj
+
+ @staticmethod
+ def get_params(brightness, contrast, saturation, hue):
+ fn_idx = torch.randperm(4)
+ b = (
+ None
+ if brightness is None
+ else np.random.uniform(brightness[0], brightness[1])
+ )
+ c = None if contrast is None else np.random.uniform(contrast[0], contrast[1])
+ s = (
+ None
+ if saturation is None
+ else np.random.uniform(saturation[0], saturation[1])
+ )
+ h = None if hue is None else np.random.uniform(hue[0], hue[1])
+ return fn_idx, b, c, s, h
+
+ def __call__(self, data_dict):
+ (
+ fn_idx,
+ brightness_factor,
+ contrast_factor,
+ saturation_factor,
+ hue_factor,
+ ) = self.get_params(self.brightness, self.contrast, self.saturation, self.hue)
+
+ for fn_id in fn_idx:
+ if (
+ fn_id == 0
+ and brightness_factor is not None
+ and np.random.rand() < self.p
+ ):
+ data_dict["color"] = self.adjust_brightness(
+ data_dict["color"], brightness_factor
+ )
+ elif (
+ fn_id == 1 and contrast_factor is not None and np.random.rand() < self.p
+ ):
+ data_dict["color"] = self.adjust_contrast(
+ data_dict["color"], contrast_factor
+ )
+ elif (
+ fn_id == 2
+ and saturation_factor is not None
+ and np.random.rand() < self.p
+ ):
+ data_dict["color"] = self.adjust_saturation(
+ data_dict["color"], saturation_factor
+ )
+ elif fn_id == 3 and hue_factor is not None and np.random.rand() < self.p:
+ data_dict["color"] = self.adjust_hue(data_dict["color"], hue_factor)
+ return data_dict
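+
+# Illustrative usage sketch (assumes "color" holds RGB values in [0, 255]):
+#   jitter = RandomColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.02)
+#   data_dict = jitter(data_dict)  # applies the four adjustments in random order,
+#                                  # each with probability p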
+
+
+@TRANSFORMS.register_module()
+class HueSaturationTranslation(object):
+ @staticmethod
+ def rgb_to_hsv(rgb):
+ # Translated from source of colorsys.rgb_to_hsv
+    # r, g, b should be numpy arrays with values between 0 and 255
+ # rgb_to_hsv returns an array of floats between 0.0 and 1.0.
+ rgb = rgb.astype("float")
+ hsv = np.zeros_like(rgb)
+ # in case an RGBA array was passed, just copy the A channel
+ hsv[..., 3:] = rgb[..., 3:]
+ r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]
+ maxc = np.max(rgb[..., :3], axis=-1)
+ minc = np.min(rgb[..., :3], axis=-1)
+ hsv[..., 2] = maxc
+ mask = maxc != minc
+ hsv[mask, 1] = (maxc - minc)[mask] / maxc[mask]
+ rc = np.zeros_like(r)
+ gc = np.zeros_like(g)
+ bc = np.zeros_like(b)
+ rc[mask] = (maxc - r)[mask] / (maxc - minc)[mask]
+ gc[mask] = (maxc - g)[mask] / (maxc - minc)[mask]
+ bc[mask] = (maxc - b)[mask] / (maxc - minc)[mask]
+ hsv[..., 0] = np.select(
+ [r == maxc, g == maxc], [bc - gc, 2.0 + rc - bc], default=4.0 + gc - rc
+ )
+ hsv[..., 0] = (hsv[..., 0] / 6.0) % 1.0
+ return hsv
+
+ @staticmethod
+ def hsv_to_rgb(hsv):
+ # Translated from source of colorsys.hsv_to_rgb
+    # h, s should be numpy arrays with values between 0.0 and 1.0
+ # v should be a numpy array with values between 0.0 and 255.0
+ # hsv_to_rgb returns an array of uints between 0 and 255.
+ rgb = np.empty_like(hsv)
+ rgb[..., 3:] = hsv[..., 3:]
+ h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2]
+ i = (h * 6.0).astype("uint8")
+ f = (h * 6.0) - i
+ p = v * (1.0 - s)
+ q = v * (1.0 - s * f)
+ t = v * (1.0 - s * (1.0 - f))
+ i = i % 6
+ conditions = [s == 0.0, i == 1, i == 2, i == 3, i == 4, i == 5]
+ rgb[..., 0] = np.select(conditions, [v, q, p, p, t, v], default=v)
+ rgb[..., 1] = np.select(conditions, [v, v, v, q, p, p], default=t)
+ rgb[..., 2] = np.select(conditions, [v, p, t, v, v, q], default=p)
+ return rgb.astype("uint8")
+
+ def __init__(self, hue_max=0.5, saturation_max=0.2):
+ self.hue_max = hue_max
+ self.saturation_max = saturation_max
+
+ def __call__(self, data_dict):
+ if "color" in data_dict.keys():
+ # Assume color[:, :3] is rgb
+ hsv = HueSaturationTranslation.rgb_to_hsv(data_dict["color"][:, :3])
+ hue_val = (np.random.rand() - 0.5) * 2 * self.hue_max
+ sat_ratio = 1 + (np.random.rand() - 0.5) * 2 * self.saturation_max
+ hsv[..., 0] = np.remainder(hue_val + hsv[..., 0] + 1, 1)
+ hsv[..., 1] = np.clip(sat_ratio * hsv[..., 1], 0, 1)
+ data_dict["color"][:, :3] = np.clip(
+ HueSaturationTranslation.hsv_to_rgb(hsv), 0, 255
+ )
+ return data_dict
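+
+# Note: rgb_to_hsv/hsv_to_rgb follow colorsys conventions (h, s in [0, 1],
+# v in [0, 255]), so a zero hue/saturation shift round-trips RGB values up to
+# uint8 rounding:
+#   hsv = HueSaturationTranslation.rgb_to_hsv(rgb)
+#   rgb_back = HueSaturationTranslation.hsv_to_rgb(hsv)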
+
+
+@TRANSFORMS.register_module()
+class RandomColorDrop(object):
+ def __init__(self, p=0.2, color_augment=0.0):
+ self.p = p
+ self.color_augment = color_augment
+
+ def __call__(self, data_dict):
+ if "color" in data_dict.keys() and np.random.rand() < self.p:
+ data_dict["color"] *= self.color_augment
+ return data_dict
+
+ def __repr__(self):
+ return "RandomColorDrop(color_augment: {}, p: {})".format(
+ self.color_augment, self.p
+ )
+
+
+@TRANSFORMS.register_module()
+class ElasticDistortion(object):
+ def __init__(self, distortion_params=None):
+ self.distortion_params = (
+ [[0.2, 0.4], [0.8, 1.6]] if distortion_params is None else distortion_params
+ )
+
+ @staticmethod
+ def elastic_distortion(coords, granularity, magnitude):
+        """
+        Apply elastic distortion on a sparse coordinate space.
+        coords: numpy array of (number of points, at least 3 spatial dims)
+        granularity: size of the noise grid (in the same scale [m/cm] as the voxel grid)
+        magnitude: noise multiplier
+        """
+ blurx = np.ones((3, 1, 1, 1)).astype("float32") / 3
+ blury = np.ones((1, 3, 1, 1)).astype("float32") / 3
+ blurz = np.ones((1, 1, 3, 1)).astype("float32") / 3
+ coords_min = coords.min(0)
+
+ # Create Gaussian noise tensor of the size given by granularity.
+ noise_dim = ((coords - coords_min).max(0) // granularity).astype(int) + 3
+ noise = np.random.randn(*noise_dim, 3).astype(np.float32)
+
+ # Smoothing.
+        # scipy.ndimage.filters was removed in SciPy 1.10; call scipy.ndimage directly
+        for _ in range(2):
+            noise = scipy.ndimage.convolve(noise, blurx, mode="constant", cval=0)
+            noise = scipy.ndimage.convolve(noise, blury, mode="constant", cval=0)
+            noise = scipy.ndimage.convolve(noise, blurz, mode="constant", cval=0)
+
+        # Trilinearly interpolate the smoothed noise at each point's coordinates.
+ ax = [
+ np.linspace(d_min, d_max, d)
+ for d_min, d_max, d in zip(
+ coords_min - granularity,
+ coords_min + granularity * (noise_dim - 2),
+ noise_dim,
+ )
+ ]
+ interp = scipy.interpolate.RegularGridInterpolator(
+ ax, noise, bounds_error=False, fill_value=0
+ )
+ coords += interp(coords) * magnitude
+ return coords
+
+ def __call__(self, data_dict):
+ if "coord" in data_dict.keys() and self.distortion_params is not None:
+ if random.random() < 0.95:
+ for granularity, magnitude in self.distortion_params:
+ data_dict["coord"] = self.elastic_distortion(
+ data_dict["coord"], granularity, magnitude
+ )
+ return data_dict
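+
+# The default distortion_params run two passes of elastic_distortion, a fine
+# grid (granularity 0.2, magnitude 0.4) followed by a coarse grid (granularity
+# 0.8, magnitude 1.6), with both passes applied together with probability 0.95.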
+
+
+@TRANSFORMS.register_module()
+class GridSample(object):
+ def __init__(
+ self,
+ grid_size=0.05,
+ hash_type="fnv",
+ mode="train",
+ keys=("coord", "color", "normal", "segment"),
+ return_inverse=False,
+ return_grid_coord=False,
+ return_min_coord=False,
+ return_displacement=False,
+ project_displacement=False,
+ return_idx=False,
+ ):
+ self.grid_size = grid_size
+ self.hash = self.fnv_hash_vec if hash_type == "fnv" else self.ravel_hash_vec
+ assert mode in ["train", "test"]
+ self.mode = mode
+ self.keys = keys
+ self.return_inverse = return_inverse
+ self.return_grid_coord = return_grid_coord
+ self.return_min_coord = return_min_coord
+ self.return_displacement = return_displacement
+ self.project_displacement = project_displacement
+        self.return_idx = return_idx
+
+ def __call__(self, data_dict):
+ assert "coord" in data_dict.keys()
+ scaled_coord = data_dict["coord"] / np.array(self.grid_size)
+ grid_coord = np.floor(scaled_coord).astype(int)
+ min_coord = grid_coord.min(0)
+ grid_coord -= min_coord
+ scaled_coord -= min_coord
+ min_coord = min_coord * np.array(self.grid_size)
+ key = self.hash(grid_coord)
+ idx_sort = np.argsort(key)
+ key_sort = key[idx_sort]
+ _, inverse, count = np.unique(key_sort, return_inverse=True, return_counts=True)
+ if self.mode == "train": # train mode
+ idx_select = (
+ np.cumsum(np.insert(count, 0, 0)[0:-1])
+ + np.random.randint(0, count.max(), count.size) % count
+ )
+ idx_unique = idx_sort[idx_select]
+ if "sampled_index" in data_dict:
+                    # for the ScanNet data-efficient benchmark, make sure labeled points are sampled
+ idx_unique = np.unique(
+ np.append(idx_unique, data_dict["sampled_index"])
+ )
+ mask = np.zeros_like(data_dict["segment"]).astype(bool)
+ mask[data_dict["sampled_index"]] = True
+ data_dict["sampled_index"] = np.where(mask[idx_unique])[0]
+ if self.return_inverse:
+ data_dict["inverse"] = np.zeros_like(inverse)
+ data_dict["inverse"][idx_sort] = inverse
+ if self.return_grid_coord:
+ data_dict["grid_coord"] = grid_coord[idx_unique]
+ if self.return_min_coord:
+ data_dict["min_coord"] = min_coord.reshape([1, 3])
+ if self.return_displacement:
+ displacement = (
+ scaled_coord - grid_coord - 0.5
+ ) # [0, 1] -> [-0.5, 0.5] displacement to center
+ if self.project_displacement:
+ displacement = np.sum(
+ displacement * data_dict["normal"], axis=-1, keepdims=True
+ )
+ data_dict["displacement"] = displacement[idx_unique]
+ for key in self.keys:
+ data_dict[key] = data_dict[key][idx_unique]
+            if self.return_idx:
+ data_dict["idx_unique"] = idx_unique
+ return data_dict
+
+ elif self.mode == "test": # test mode
+ data_part_list = []
+ for i in range(count.max()):
+ idx_select = np.cumsum(np.insert(count, 0, 0)[0:-1]) + i % count
+ idx_part = idx_sort[idx_select]
+ data_part = dict(index=idx_part)
+ if self.return_inverse:
+ data_dict["inverse"] = np.zeros_like(inverse)
+ data_dict["inverse"][idx_sort] = inverse
+ if self.return_grid_coord:
+ data_part["grid_coord"] = grid_coord[idx_part]
+ if self.return_min_coord:
+ data_part["min_coord"] = min_coord.reshape([1, 3])
+ if self.return_displacement:
+ displacement = (
+ scaled_coord - grid_coord - 0.5
+ ) # [0, 1] -> [-0.5, 0.5] displacement to center
+ if self.project_displacement:
+ displacement = np.sum(
+ displacement * data_dict["normal"], axis=-1, keepdims=True
+ )
+ data_dict["displacement"] = displacement[idx_part]
+ for key in data_dict.keys():
+ if key in self.keys:
+ data_part[key] = data_dict[key][idx_part]
+ else:
+ data_part[key] = data_dict[key]
+ data_part_list.append(data_part)
+ return data_part_list
+ else:
+ raise NotImplementedError
+
+ @staticmethod
+ def ravel_hash_vec(arr):
+ """
+ Ravel the coordinates after subtracting the min coordinates.
+ """
+ assert arr.ndim == 2
+ arr = arr.copy()
+ arr -= arr.min(0)
+ arr = arr.astype(np.uint64, copy=False)
+ arr_max = arr.max(0).astype(np.uint64) + 1
+
+ keys = np.zeros(arr.shape[0], dtype=np.uint64)
+ # Fortran style indexing
+ for j in range(arr.shape[1] - 1):
+ keys += arr[:, j]
+ keys *= arr_max[j + 1]
+ keys += arr[:, -1]
+ return keys
+
+ @staticmethod
+ def fnv_hash_vec(arr):
+ """
+ FNV64-1A
+ """
+ assert arr.ndim == 2
+ # Floor first for negative coordinates
+ arr = arr.copy()
+ arr = arr.astype(np.uint64, copy=False)
+ hashed_arr = np.uint64(14695981039346656037) * np.ones(
+ arr.shape[0], dtype=np.uint64
+ )
+ for j in range(arr.shape[1]):
+ hashed_arr *= np.uint64(1099511628211)
+ hashed_arr = np.bitwise_xor(hashed_arr, arr[:, j])
+ return hashed_arr
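+
+# Illustrative: fnv_hash_vec maps each integer grid cell to a 64-bit key, so
+# points falling in the same voxel share a key, e.g.
+#   GridSample.fnv_hash_vec(np.array([[0, 0, 0], [0, 0, 0], [0, 0, 1]]))
+# returns identical keys for the first two rows and a different one for the third.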
+
+
+@TRANSFORMS.register_module()
+class SphereCrop(object):
+ def __init__(self, point_max=80000, sample_rate=None, mode="random"):
+ self.point_max = point_max
+ self.sample_rate = sample_rate
+ assert mode in ["random", "center", "all"]
+ self.mode = mode
+
+ def __call__(self, data_dict):
+ point_max = (
+ int(self.sample_rate * data_dict["coord"].shape[0])
+ if self.sample_rate is not None
+ else self.point_max
+ )
+
+ assert "coord" in data_dict.keys()
+ if self.mode == "all":
+ # TODO: Optimize
+ if "index" not in data_dict.keys():
+ data_dict["index"] = np.arange(data_dict["coord"].shape[0])
+ data_part_list = []
+ # coord_list, color_list, dist2_list, idx_list, offset_list = [], [], [], [], []
+ if data_dict["coord"].shape[0] > point_max:
+ coord_p, idx_uni = np.random.rand(
+ data_dict["coord"].shape[0]
+ ) * 1e-3, np.array([])
+ while idx_uni.size != data_dict["index"].shape[0]:
+ init_idx = np.argmin(coord_p)
+ dist2 = np.sum(
+ np.power(data_dict["coord"] - data_dict["coord"][init_idx], 2),
+ 1,
+ )
+ idx_crop = np.argsort(dist2)[:point_max]
+
+ data_crop_dict = dict()
+ if "coord" in data_dict.keys():
+ data_crop_dict["coord"] = data_dict["coord"][idx_crop]
+ if "grid_coord" in data_dict.keys():
+ data_crop_dict["grid_coord"] = data_dict["grid_coord"][idx_crop]
+ if "normal" in data_dict.keys():
+ data_crop_dict["normal"] = data_dict["normal"][idx_crop]
+ if "color" in data_dict.keys():
+ data_crop_dict["color"] = data_dict["color"][idx_crop]
+ if "displacement" in data_dict.keys():
+ data_crop_dict["displacement"] = data_dict["displacement"][
+ idx_crop
+ ]
+ if "strength" in data_dict.keys():
+ data_crop_dict["strength"] = data_dict["strength"][idx_crop]
+ data_crop_dict["weight"] = dist2[idx_crop]
+ data_crop_dict["index"] = data_dict["index"][idx_crop]
+ data_part_list.append(data_crop_dict)
+
+ delta = np.square(
+ 1 - data_crop_dict["weight"] / np.max(data_crop_dict["weight"])
+ )
+ coord_p[idx_crop] += delta
+ idx_uni = np.unique(
+ np.concatenate((idx_uni, data_crop_dict["index"]))
+ )
+ else:
+ data_crop_dict = data_dict.copy()
+ data_crop_dict["weight"] = np.zeros(data_dict["coord"].shape[0])
+ data_crop_dict["index"] = data_dict["index"]
+ data_part_list.append(data_crop_dict)
+ return data_part_list
+ # mode is "random" or "center"
+ elif data_dict["coord"].shape[0] > point_max:
+ if self.mode == "random":
+ center = data_dict["coord"][
+ np.random.randint(data_dict["coord"].shape[0])
+ ]
+ elif self.mode == "center":
+ center = data_dict["coord"][data_dict["coord"].shape[0] // 2]
+ else:
+ raise NotImplementedError
+ idx_crop = np.argsort(np.sum(np.square(data_dict["coord"] - center), 1))[
+ :point_max
+ ]
+ if "coord" in data_dict.keys():
+ data_dict["coord"] = data_dict["coord"][idx_crop]
+ if "origin_coord" in data_dict.keys():
+ data_dict["origin_coord"] = data_dict["origin_coord"][idx_crop]
+ if "grid_coord" in data_dict.keys():
+ data_dict["grid_coord"] = data_dict["grid_coord"][idx_crop]
+ if "color" in data_dict.keys():
+ data_dict["color"] = data_dict["color"][idx_crop]
+ if "normal" in data_dict.keys():
+ data_dict["normal"] = data_dict["normal"][idx_crop]
+ if "segment" in data_dict.keys():
+ data_dict["segment"] = data_dict["segment"][idx_crop]
+ if "instance" in data_dict.keys():
+ data_dict["instance"] = data_dict["instance"][idx_crop]
+ if "displacement" in data_dict.keys():
+ data_dict["displacement"] = data_dict["displacement"][idx_crop]
+ if "strength" in data_dict.keys():
+ data_dict["strength"] = data_dict["strength"][idx_crop]
+ return data_dict
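+
+# In "all" mode the crop repeatedly grows a sphere of point_max points around
+# the currently least-covered point (tracked by the coord_p weights) until every
+# index has been emitted at least once, so the parts jointly cover the cloud.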
+
+
+@TRANSFORMS.register_module()
+class ShufflePoint(object):
+ def __call__(self, data_dict):
+ assert "coord" in data_dict.keys()
+ shuffle_index = np.arange(data_dict["coord"].shape[0])
+ np.random.shuffle(shuffle_index)
+ if "coord" in data_dict.keys():
+ data_dict["coord"] = data_dict["coord"][shuffle_index]
+ if "grid_coord" in data_dict.keys():
+ data_dict["grid_coord"] = data_dict["grid_coord"][shuffle_index]
+ if "displacement" in data_dict.keys():
+ data_dict["displacement"] = data_dict["displacement"][shuffle_index]
+ if "color" in data_dict.keys():
+ data_dict["color"] = data_dict["color"][shuffle_index]
+ if "normal" in data_dict.keys():
+ data_dict["normal"] = data_dict["normal"][shuffle_index]
+ if "segment" in data_dict.keys():
+ data_dict["segment"] = data_dict["segment"][shuffle_index]
+ if "instance" in data_dict.keys():
+ data_dict["instance"] = data_dict["instance"][shuffle_index]
+ return data_dict
+
+
+@TRANSFORMS.register_module()
+class CropBoundary(object):
+ def __call__(self, data_dict):
+ assert "segment" in data_dict
+ segment = data_dict["segment"].flatten()
+ mask = (segment != 0) * (segment != 1)
+ if "coord" in data_dict.keys():
+ data_dict["coord"] = data_dict["coord"][mask]
+ if "grid_coord" in data_dict.keys():
+ data_dict["grid_coord"] = data_dict["grid_coord"][mask]
+ if "color" in data_dict.keys():
+ data_dict["color"] = data_dict["color"][mask]
+ if "normal" in data_dict.keys():
+ data_dict["normal"] = data_dict["normal"][mask]
+ if "segment" in data_dict.keys():
+ data_dict["segment"] = data_dict["segment"][mask]
+ if "instance" in data_dict.keys():
+ data_dict["instance"] = data_dict["instance"][mask]
+ return data_dict
+
+
+@TRANSFORMS.register_module()
+class ContrastiveViewsGenerator(object):
+ def __init__(
+ self,
+ view_keys=("coord", "color", "normal", "origin_coord"),
+ view_trans_cfg=None,
+ ):
+ self.view_keys = view_keys
+ self.view_trans = Compose(view_trans_cfg)
+
+ def __call__(self, data_dict):
+ view1_dict = dict()
+ view2_dict = dict()
+ for key in self.view_keys:
+ view1_dict[key] = data_dict[key].copy()
+ view2_dict[key] = data_dict[key].copy()
+ view1_dict = self.view_trans(view1_dict)
+ view2_dict = self.view_trans(view2_dict)
+ for key, value in view1_dict.items():
+ data_dict["view1_" + key] = value
+ for key, value in view2_dict.items():
+ data_dict["view2_" + key] = value
+ return data_dict
+
+
+@TRANSFORMS.register_module()
+class InstanceParser(object):
+ def __init__(self, segment_ignore_index=(-1, 0, 1), instance_ignore_index=-1):
+ self.segment_ignore_index = segment_ignore_index
+ self.instance_ignore_index = instance_ignore_index
+
+ def __call__(self, data_dict):
+ coord = data_dict["coord"]
+ segment = data_dict["segment"]
+ instance = data_dict["instance"]
+ mask = ~np.in1d(segment, self.segment_ignore_index)
+ # mapping ignored instance to ignore index
+ instance[~mask] = self.instance_ignore_index
+ # reorder left instance
+ unique, inverse = np.unique(instance[mask], return_inverse=True)
+ instance_num = len(unique)
+ instance[mask] = inverse
+ # init instance information
+ centroid = np.ones((coord.shape[0], 3)) * self.instance_ignore_index
+ bbox = np.ones((instance_num, 8)) * self.instance_ignore_index
+ vacancy = [
+ index for index in self.segment_ignore_index if index >= 0
+ ] # vacate class index
+
+ for instance_id in range(instance_num):
+ mask_ = instance == instance_id
+ coord_ = coord[mask_]
+ bbox_min = coord_.min(0)
+ bbox_max = coord_.max(0)
+ bbox_centroid = coord_.mean(0)
+ bbox_center = (bbox_max + bbox_min) / 2
+ bbox_size = bbox_max - bbox_min
+ bbox_theta = np.zeros(1, dtype=coord_.dtype)
+ bbox_class = np.array([segment[mask_][0]], dtype=coord_.dtype)
+ # shift class index to fill vacate class index caused by segment ignore index
+ bbox_class -= np.greater(bbox_class, vacancy).sum()
+
+ centroid[mask_] = bbox_centroid
+ bbox[instance_id] = np.concatenate(
+ [bbox_center, bbox_size, bbox_theta, bbox_class]
+ ) # 3 + 3 + 1 + 1 = 8
+ data_dict["instance"] = instance
+ data_dict["instance_centroid"] = centroid
+ data_dict["bbox"] = bbox
+ return data_dict
+
+
+class Compose(object):
+ def __init__(self, cfg=None):
+ self.cfg = cfg if cfg is not None else []
+ self.transforms = []
+ for t_cfg in self.cfg:
+ self.transforms.append(TRANSFORMS.build(t_cfg))
+
+ def __call__(self, data_dict):
+ for t in self.transforms:
+ data_dict = t(data_dict)
+ return data_dict
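+
+# Minimal usage sketch: transform configs are dicts resolved via the TRANSFORMS
+# registry, e.g.
+#   pipeline = Compose([
+#       dict(type="ChromaticJitter", p=0.95, std=0.005),
+#       dict(type="GridSample", grid_size=0.02, mode="train"),
+#   ])
+#   data_dict = pipeline(data_dict)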
diff --git a/UniRig/src/model/pointcept/datasets/utils.py b/UniRig/src/model/pointcept/datasets/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d24e17bf619306adf14f6746530829502b8e546
--- /dev/null
+++ b/UniRig/src/model/pointcept/datasets/utils.py
@@ -0,0 +1,58 @@
+import random
+from collections.abc import Mapping, Sequence
+import numpy as np
+import torch
+# import trimesh
+from torch.utils.data.dataloader import default_collate
+
+
+def collate_fn_dummy(batch):
+ return batch
+
+
+def collate_fn(batch):
+    """
+    Collate function for point clouds that supports dicts and lists;
+    'coord' is necessary to determine 'offset'.
+    """
+ if not isinstance(batch, Sequence):
+        raise TypeError(f"{type(batch)} is not supported.")
+
+ if isinstance(batch[0], torch.Tensor):
+ return torch.cat(list(batch))
+ elif isinstance(batch[0], str):
+        # str is also a Sequence, so this check must come before the Sequence branch
+ return list(batch)
+ elif isinstance(batch[0], Sequence):
+ for data in batch:
+ data.append(torch.tensor([data[0].shape[0]]))
+ batch = [collate_fn(samples) for samples in zip(*batch)]
+ batch[-1] = torch.cumsum(batch[-1], dim=0).int()
+ return batch
+ elif isinstance(batch[0], Mapping):
+ batch = {key: collate_fn([d[key] for d in batch]) for key in batch[0]}
+ for key in batch.keys():
+ if "offset" in key:
+ batch[key] = torch.cumsum(batch[key], dim=0)
+ return batch
+ else:
+ return default_collate(batch)
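+
+# Example: for two samples with 3 and 2 points,
+#   collate_fn([{"coord": torch.zeros(3, 3), "offset": torch.tensor([3])},
+#               {"coord": torch.zeros(2, 3), "offset": torch.tensor([2])}])
+# concatenates "coord" to shape (5, 3) and cumsums "offset" to tensor([3, 5]).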
+
+
+def point_collate_fn(batch, mix_prob=0):
+ assert isinstance(
+ batch[0], Mapping
+ ) # currently, only support input_dict, rather than input_list
+ batch = collate_fn(batch)
+ if "offset" in batch.keys():
+ # Mix3d (https://arxiv.org/pdf/2110.02210.pdf)
+ if random.random() < mix_prob:
+ batch["offset"] = torch.cat(
+ [batch["offset"][1:-1:2], batch["offset"][-1].unsqueeze(0)], dim=0
+ )
+ return batch
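+
+# Mix3d fuses consecutive scene pairs by dropping every other offset boundary,
+# e.g. offsets [3, 5, 9, 12] become [5, 12], so scenes (0, 1) and (2, 3) are
+# each treated as one mixed scene.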
+
+
+def gaussian_kernel(dist2: np.ndarray, a: float = 1, c: float = 5):
+ return a * np.exp(-dist2 / (2 * c**2))
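+
+# e.g. with the defaults a=1, c=5 (2 * c**2 = 50):
+#   gaussian_kernel(np.array([0.0, 25.0])) -> [1.0, exp(-0.5)] ~= [1.0, 0.607]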
+
diff --git a/UniRig/src/model/pointcept/engines/__init__.py b/UniRig/src/model/pointcept/engines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/UniRig/src/model/pointcept/engines/defaults.py b/UniRig/src/model/pointcept/engines/defaults.py
new file mode 100644
index 0000000000000000000000000000000000000000..832bbcac47851c7ee033b63f3374022dea5e49dc
--- /dev/null
+++ b/UniRig/src/model/pointcept/engines/defaults.py
@@ -0,0 +1,143 @@
+import os
+import sys
+import argparse
+import multiprocessing as mp
+from torch.nn.parallel import DistributedDataParallel
+
+
+import pointcept.utils.comm as comm
+from pointcept.utils.env import get_random_seed, set_seed
+from pointcept.utils.config import Config, DictAction
+
+
+def create_ddp_model(model, *, fp16_compression=False, **kwargs):
+ """
+ Create a DistributedDataParallel model if there are >1 processes.
+ Args:
+ model: a torch.nn.Module
+ fp16_compression: add fp16 compression hooks to the ddp object.
+ See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
+ kwargs: other arguments of :module:`torch.nn.parallel.DistributedDataParallel`.
+ """
+ if comm.get_world_size() == 1:
+ return model
+ # kwargs['find_unused_parameters'] = True
+ if "device_ids" not in kwargs:
+ kwargs["device_ids"] = [comm.get_local_rank()]
+ if "output_device" not in kwargs:
+ kwargs["output_device"] = [comm.get_local_rank()]
+ ddp = DistributedDataParallel(model, **kwargs)
+ if fp16_compression:
+ from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks
+
+ ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
+ return ddp
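+
+# Illustrative: under multi-process launch this wraps the model in DDP on the
+# local rank's device, e.g.
+#   model = create_ddp_model(model.cuda(), fp16_compression=True)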
+
+
+def worker_init_fn(worker_id, num_workers, rank, seed):
+ """Worker init func for dataloader.
+
+    The seed of each worker equals num_workers * rank + worker_id + user_seed.
+
+ Args:
+ worker_id (int): Worker id.
+ num_workers (int): Number of workers.
+ rank (int): The rank of current process.
+ seed (int): The random seed to use.
+ """
+
+ worker_seed = num_workers * rank + worker_id + seed
+ set_seed(worker_seed)
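+
+# Typically bound with functools.partial and passed to a DataLoader, e.g.
+#   init_fn = partial(worker_init_fn, num_workers=cfg.num_worker_per_gpu,
+#                     rank=comm.get_rank(), seed=cfg.seed)
+#   loader = DataLoader(dataset, num_workers=cfg.num_worker_per_gpu,
+#                       worker_init_fn=init_fn)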
+
+
+def default_argument_parser(epilog=None):
+ parser = argparse.ArgumentParser(
+ epilog=epilog
+ or f"""
+ Examples:
+ Run on single machine:
+ $ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml
+ Change some config options:
+ $ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001
+ Run on multiple machines:
+ (machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url [--other-flags]
+ (machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url [--other-flags]
+ """,
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ )
+ parser.add_argument(
+ "--config-file", default="", metavar="FILE", help="path to config file"
+ )
+ parser.add_argument(
+ "--num-gpus", type=int, default=1, help="number of gpus *per machine*"
+ )
+ parser.add_argument(
+ "--num-machines", type=int, default=1, help="total number of machines"
+ )
+ parser.add_argument(
+ "--machine-rank",
+ type=int,
+ default=0,
+ help="the rank of this machine (unique per machine)",
+ )
+ # PyTorch still may leave orphan processes in multi-gpu training.
+ # Therefore we use a deterministic way to obtain port,
+ # so that users are aware of orphan processes by seeing the port occupied.
+ # port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
+ parser.add_argument(
+ "--dist-url",
+ # default="tcp://127.0.0.1:{}".format(port),
+ default="auto",
+ help="initialization URL for pytorch distributed backend. See "
+ "https://pytorch.org/docs/stable/distributed.html for details.",
+ )
+ parser.add_argument(
+ "--options", nargs="+", action=DictAction, help="custom options"
+ )
+ return parser
+
+
+def default_config_parser(file_path, options):
+ # config name protocol: dataset_name/model_name-exp_name
+ if os.path.isfile(file_path):
+ cfg = Config.fromfile(file_path)
+ else:
+ sep = file_path.find("-")
+ cfg = Config.fromfile(os.path.join(file_path[:sep], file_path[sep + 1 :]))
+
+ if options is not None:
+ cfg.merge_from_dict(options)
+
+ if cfg.seed is None:
+ cfg.seed = get_random_seed()
+
+ cfg.data.train.loop = cfg.epoch // cfg.eval_epoch
+
+ os.makedirs(os.path.join(cfg.save_path, "model"), exist_ok=True)
+ if not cfg.resume:
+ cfg.dump(os.path.join(cfg.save_path, "config.py"))
+ return cfg
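+
+# Per the name protocol above, a non-file path such as "scannet-semseg-pt" is
+# split at the first "-" and resolved to os.path.join("scannet", "semseg-pt").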
+
+
+def default_setup(cfg):
+    # scale by world size
+ world_size = comm.get_world_size()
+ cfg.num_worker = cfg.num_worker if cfg.num_worker is not None else mp.cpu_count()
+ cfg.num_worker_per_gpu = cfg.num_worker // world_size
+ assert cfg.batch_size % world_size == 0
+ assert cfg.batch_size_val is None or cfg.batch_size_val % world_size == 0
+ assert cfg.batch_size_test is None or cfg.batch_size_test % world_size == 0
+ cfg.batch_size_per_gpu = cfg.batch_size // world_size
+ cfg.batch_size_val_per_gpu = (
+ cfg.batch_size_val // world_size if cfg.batch_size_val is not None else 1
+ )
+ cfg.batch_size_test_per_gpu = (
+ cfg.batch_size_test // world_size if cfg.batch_size_test is not None else 1
+ )
+ # update data loop
+ assert cfg.epoch % cfg.eval_epoch == 0
+ # settle random seed
+ rank = comm.get_rank()
+ seed = None if cfg.seed is None else cfg.seed * cfg.num_worker_per_gpu + rank
+ set_seed(seed)
+ return cfg
diff --git a/UniRig/src/model/pointcept/engines/eval.py b/UniRig/src/model/pointcept/engines/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5ebea23c91df213d90d0113987bd13597f08230
--- /dev/null
+++ b/UniRig/src/model/pointcept/engines/eval.py
@@ -0,0 +1,407 @@
+import os
+import sys
+import weakref
+import torch
+torch.multiprocessing.set_start_method('spawn', force=True)  # must run before CUDA contexts or workers start
+import torch.nn as nn
+import torch.utils.data
+from functools import partial
+
+from collections.abc import Iterator  # collections.Iterator was removed in Python 3.10
+from tensorboardX import SummaryWriter
+
+from .defaults import create_ddp_model, worker_init_fn
+from .hooks import HookBase, build_hooks
+import pointcept.utils.comm as comm
+from pointcept.datasets import build_dataset, point_collate_fn, collate_fn
+from pointcept.models import build_model
+from pointcept.utils.logger import get_root_logger
+from pointcept.utils.optimizer import build_optimizer
+from pointcept.utils.scheduler import build_scheduler
+from pointcept.utils.events import EventStorage
+from pointcept.utils.registry import Registry
+
+from sklearn.preprocessing import QuantileTransformer
+from pointcept.utils.timer import Timer
+
+TRAINERS = Registry("trainers")
+from cuml.cluster.hdbscan import HDBSCAN
+# from sklearn.cluster import HDBSCAN
+import open3d as o3d
+import matplotlib.colors as mcolors
+import numpy as np
+from collections import OrderedDict
+import trimesh
+import pointops
+
+class TrainerBase:
+ def __init__(self) -> None:
+ self.hooks = []
+ self.epoch = 0
+ self.start_epoch = 0
+ self.max_epoch = 0
+ self.max_iter = 0
+ self.comm_info = dict()
+ self.data_iterator: Iterator = enumerate([])
+ self.storage: EventStorage
+ self.writer: SummaryWriter
+ self._iter_timer = Timer()
+
+ def register_hooks(self, hooks) -> None:
+ hooks = build_hooks(hooks)
+ for h in hooks:
+ assert isinstance(h, HookBase)
+ # To avoid circular reference, hooks and trainer cannot own each other.
+ # This normally does not matter, but will cause memory leak if the
+ # involved objects contain __del__:
+ # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
+ h.trainer = weakref.proxy(self)
+ self.hooks.extend(hooks)
+
+ def train(self):
+ with EventStorage() as self.storage:
+ # => before train
+ self.before_train()
+ for self.epoch in range(self.start_epoch, self.max_epoch):
+ # => before epoch
+ self.before_epoch()
+ # => run_epoch
+ for (
+ self.comm_info["iter"],
+ self.comm_info["input_dict"],
+ ) in self.data_iterator:
+ # => before_step
+ self.before_step()
+ # => run_step
+ self.run_step()
+ # => after_step
+ self.after_step()
+ # => after epoch
+ self.after_epoch()
+ # => after train
+ self.after_train()
+
+ def before_train(self):
+ for h in self.hooks:
+ h.before_train()
+
+ def before_epoch(self):
+ for h in self.hooks:
+ h.before_epoch()
+
+ def before_step(self):
+ for h in self.hooks:
+ h.before_step()
+
+ def run_step(self):
+ raise NotImplementedError
+
+ def after_step(self):
+ for h in self.hooks:
+ h.after_step()
+
+ def after_epoch(self):
+ for h in self.hooks:
+ h.after_epoch()
+ self.storage.reset_histories()
+
+ def after_train(self):
+ # Sync GPU before running train hooks
+ comm.synchronize()
+ for h in self.hooks:
+ h.after_train()
+ if comm.is_main_process():
+ self.writer.close()
+
+
+@TRAINERS.register_module("DefaultTrainer")
+class Trainer(TrainerBase):
+ def __init__(self, cfg):
+ super(Trainer, self).__init__()
+ self.epoch = 0
+ self.start_epoch = 0
+ self.max_epoch = cfg.eval_epoch
+ self.best_metric_value = -torch.inf
+ self.logger = get_root_logger(
+ log_file=os.path.join(cfg.save_path, "train.log"),
+ # file_mode="a" if cfg.resume else "w",
+ file_mode="a",
+ )
+ self.logger.info("=> Loading config ...")
+ self.cfg = cfg
+ self.logger.info(f"Save path: {cfg.save_path}")
+ self.logger.info(f"Config:\n{cfg.pretty_text}")
+ self.logger.info("=> Building model ...")
+ self.model = self.build_model()
+ self.logger.info("=> Building val dataset & dataloader ...")
+ self.train_loader = self.build_train_loader()
+ self.logger.info("=> Building hooks ...")
+ self.register_hooks(self.cfg.hooks)
+
+        # evaluation-specific options
+ self.val_scales_list = self.cfg.val_scales_list
+ self.mesh_voting = self.cfg.mesh_voting
+ self.backbone_weight_path = self.cfg.backbone_weight_path
+
+
+ def eval(self):
+ # val_data = build_dataset(self.cfg.data.val)
+ self.logger.info("=> Loading checkpoint & weight ...")
+        if self.backbone_weight_path is not None:
+ self.logger.info("=> Loading checkpoint of pretrained backbone")
+ if os.path.isfile(self.backbone_weight_path):
+ checkpoint = torch.load(
+ self.backbone_weight_path,
+ map_location=lambda storage, loc: storage.cuda(),
+ )
+ weight = OrderedDict()
+ for key, value in checkpoint["state_dict"].items():
+ if not key.startswith("module."):
+ if comm.get_world_size() > 1:
+ key = "module." + key # xxx.xxx -> module.xxx.xxx
+                        # Now all keys contain "module." whether or not DDP is used.
+ # if self.keywords in key:
+ # key = key.replace(self.keywords, self.replacement)
+ if comm.get_world_size() == 1:
+ key = key[7:] # module.xxx.xxx -> xxx.xxx
+ # if key.startswith("backbone."):
+ # key = key[9:] # backbone.xxx.xxx -> xxx.xxx
+ key = "backbone." + key # xxx.xxx -> backbone.xxx.xxx
+ weight[key] = value
+ load_state_info = self.model.load_state_dict(weight, strict=False)
+ self.logger.info(f"Missing keys: {load_state_info[0]}")
+ else:
+ self.logger.info(f"No weight found at: {self.backbone_weight_path}")
+
+ if self.cfg.weight and os.path.isfile(self.cfg.weight):
+ checkpoint = torch.load(
+ self.cfg.weight,
+ map_location=lambda storage, loc: storage.cuda(),
+ )
+ load_state_info = self.model.load_state_dict(checkpoint["state_dict"], strict=False)
+ self.logger.info(f"Missing keys: {load_state_info[0]}")
+ scale_statistics = checkpoint["state_dict"]["scale_statistics"]
+ self.model.quantile_transformer = self._get_quantile_func(scale_statistics)
+ else:
+ self.logger.info(f"No weight found at: {self.cfg.weight}")
+ self.cfg.weight = "last"
+
+ self.model.eval()
+ save_root = os.path.join(self.cfg.save_path, "vis_pcd", os.path.splitext(os.path.basename(self.cfg.weight))[0])
+ os.makedirs(save_root, exist_ok=True)
+ group_save_root = os.path.join(self.cfg.save_path, "results", os.path.splitext(os.path.basename(self.cfg.weight))[0])
+ os.makedirs(group_save_root, exist_ok=True)
+
+ hex_colors = list(mcolors.CSS4_COLORS.values())
+        rgb_colors = np.array(
+            [mcolors.to_rgb(c) for c in hex_colors if c not in ("#000000", "#FFFFFF")]
+        )
+
+        def relative_luminance(color):
+            return 0.2126 * color[0] + 0.7152 * color[1] + 0.0722 * color[2]
+
+        rgb_colors = [
+            color for color in rgb_colors if 0.4 < relative_luminance(color) < 0.8
+        ]
+ np.random.shuffle(rgb_colors)
+ input_dict = self.train_loader.val_data()
+
+ pcd_inverse = self.train_loader.pcd_inverse
+ if self.mesh_voting:
+ mesh = trimesh.load(self.train_loader.mesh_path)
+ if isinstance(mesh, trimesh.Scene):
+ mesh = mesh.dump(concatenate=True)
+ mesh.visual = trimesh.visual.ColorVisuals(mesh=mesh)
+
+ for scale in self.val_scales_list:
+ input_dict["scale"] = scale
+ instance_feat = self.model(input_dict).cpu().detach().numpy()
+
+ clusterer = HDBSCAN(
+ cluster_selection_epsilon=0.1,
+ min_samples=30,
+ min_cluster_size=30,
+ allow_single_cluster=False,
+ ).fit(instance_feat)
+
+ labels = clusterer.labels_
+ invalid_label_mask = labels == -1
+ if invalid_label_mask.sum() > 0:
+ if invalid_label_mask.sum() == len(invalid_label_mask):
+ labels = np.zeros_like(labels)
+ else:
+ coord = input_dict["obj"]["coord"].cuda().contiguous().float()
+ valid_coord = coord[~invalid_label_mask]
+ valid_offset = torch.tensor(valid_coord.shape[0]).cuda()
+ invalid_coord = coord[invalid_label_mask]
+ invalid_offset = torch.tensor(invalid_coord.shape[0]).cuda()
+ indices, distances = pointops.knn_query(1, valid_coord, valid_offset, invalid_coord, invalid_offset)
+ indices = indices[:, 0].cpu().numpy()
+ labels[invalid_label_mask] = labels[~invalid_label_mask][indices]
+
+
+ # np.save(os.path.join(group_save_root, f"{str(scale)}.npy"), labels)
+ save_path = os.path.join(save_root, f"{str(scale)}.ply")
+ coord = input_dict["obj"]["coord"].cpu().numpy()
+ random_color = []
+ for i in range(max(labels) + 1):
+ random_color.append(rgb_colors[i % len(rgb_colors)])
+ random_color.append(np.array([0, 0, 0]))
+ color = [random_color[i] for i in labels]
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(coord)
+ pcd.colors = o3d.utility.Vector3dVector(color)
+ o3d.io.write_point_cloud(save_path, pcd)
+
+ labels = labels[pcd_inverse]
+
+ # print(len(clusterer.labels_))
+ self.logger.info(f"scale_{scale} has {max(labels)+1} groups")
+ if self.mesh_voting:
+ face_index = self.train_loader.face_index
+ face_index = face_index[pcd_inverse]
+
+                # Accumulate per-face label votes with np.add.at (unbuffered scatter-add)
+ num_faces = len(mesh.faces)
+ num_labels = max(labels) + 1
+ votes = np.zeros((num_faces, num_labels), dtype=np.int32)
+ np.add.at(votes, (face_index, labels), 1)
+
+ # Find the label with most votes for each face using NumPy's argmax function
+ max_votes_labels = np.argmax(votes, axis=1)
+ # Set the label to -1 for faces that have no corresponding points
+ max_votes_labels[np.all(votes == 0, axis=1)] = -1
+
+ valid_mask = max_votes_labels != -1
+ face_centroids = mesh.triangles_center
+ coord = torch.tensor(face_centroids).cuda().contiguous().float()
+ valid_coord = coord[valid_mask]
+ valid_offset = torch.tensor(valid_coord.shape[0]).cuda()
+ invalid_coord = coord[~valid_mask]
+ invalid_offset = torch.tensor(invalid_coord.shape[0]).cuda()
+ indices, distances = pointops.knn_query(1, valid_coord, valid_offset, invalid_coord, invalid_offset)
+                # the query set is disjoint from the search set here, so the nearest
+                # neighbour sits in column 0 (a face is not its own neighbour)
+                indices = indices[:, 0].cpu().numpy()
+ mesh_group = max_votes_labels.copy()
+ mesh_group[~valid_mask] = mesh_group[valid_mask][indices]
+
+ np.save(os.path.join(group_save_root, f"mesh_{str(scale)}.npy"), mesh_group)
+
+ # Assign color to each face based on the label with most votes
+ for face, label in enumerate(mesh_group):
+ color = (random_color[label] * 255).astype(np.uint8)
+ color_with_alpha = np.append(color, 255) # Add alpha value
+ mesh.visual.face_colors[face] = color_with_alpha
+
+ # Save the new mesh
+ mesh_save_path = os.path.join(save_root, f"mesh_{str(scale)}.ply")
+ mesh.export(mesh_save_path)
+
+
+ def _get_quantile_func(self, scales: torch.Tensor, distribution="normal"):
+ """
+ Use 3D scale statistics to normalize scales -- use quantile transformer.
+ """
+ scales = scales.flatten()
+ max_grouping_scale = 2
+ scales = scales[(scales > 0) & (scales < max_grouping_scale)]
+
+ scales = scales.detach().cpu().numpy()
+
+ # Calculate quantile transformer
+ quantile_transformer = QuantileTransformer(output_distribution=distribution)
+ quantile_transformer = quantile_transformer.fit(scales.reshape(-1, 1))
+
+ def quantile_transformer_func(scales):
+ # This function acts as a wrapper for QuantileTransformer.
+ # QuantileTransformer expects a numpy array, while we have a torch tensor.
+ return torch.Tensor(
+ quantile_transformer.transform(scales.cpu().numpy())
+ ).to(scales.device)
+
+ return quantile_transformer_func
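+
+    # Illustrative: with distribution="normal" the transformer maps the empirical
+    # CDF of the training scale statistics onto a standard normal, so the median
+    # scale maps to ~0 and rare extreme scales to large |z| values.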
+
+ def run_step(self):
+ input_dict = self.comm_info["input_dict"]
+ for key in input_dict.keys():
+ if isinstance(input_dict[key], torch.Tensor):
+ input_dict[key] = input_dict[key].cuda(non_blocking=True)
+ with torch.cuda.amp.autocast(enabled=self.cfg.enable_amp):
+ output_dict = self.model(input_dict)
+ loss = output_dict["loss"]
+ self.optimizer.zero_grad()
+ if self.cfg.enable_amp:
+ self.scaler.scale(loss).backward()
+ self.scaler.step(self.optimizer)
+
+                # When AMP is enabled, optimizer.step is skipped if the loss scaling
+                # factor is too large; stepping the scheduler only after a real
+                # optimizer step avoids the torch warning about their ordering.
+ scaler = self.scaler.get_scale()
+ self.scaler.update()
+ if scaler <= self.scaler.get_scale():
+ self.scheduler.step()
+ else:
+ loss.backward()
+ self.optimizer.step()
+ self.scheduler.step()
+ if self.cfg.empty_cache:
+ torch.cuda.empty_cache()
+ self.comm_info["model_output_dict"] = output_dict
+
+ def build_model(self):
+ model = build_model(self.cfg.model)
+ if self.cfg.sync_bn:
+ model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ # logger.info(f"Model: \n{self.model}")
+ self.logger.info(f"Num params: {n_parameters}")
+ model = create_ddp_model(
+ model.cuda(),
+ broadcast_buffers=False,
+ find_unused_parameters=self.cfg.find_unused_parameters,
+ )
+ return model
+
+ def build_writer(self):
+ writer = SummaryWriter(self.cfg.save_path) if comm.is_main_process() else None
+ self.logger.info(f"Tensorboard writer logging dir: {self.cfg.save_path}")
+ return writer
+
+ def build_train_loader(self):
+ self.cfg.data.train.split = "val"
+ self.cfg.data.train.oid = self.cfg.oid
+ self.cfg.data.train.label = self.cfg.label
+ train_data = build_dataset(self.cfg.data.train)
+ return train_data
+
+ def build_val_loader(self):
+ val_loader = None
+ if self.cfg.evaluate:
+ val_data = build_dataset(self.cfg.data.val)
+ if comm.get_world_size() > 1:
+ val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
+ else:
+ val_sampler = None
+ val_loader = torch.utils.data.DataLoader(
+ val_data,
+ batch_size=self.cfg.batch_size_val_per_gpu,
+ shuffle=False,
+ num_workers=self.cfg.num_worker_per_gpu,
+ pin_memory=True,
+ sampler=val_sampler,
+ collate_fn=collate_fn,
+ )
+ return val_loader
+
+ def build_optimizer(self):
+ return build_optimizer(self.cfg.optimizer, self.model, self.cfg.param_dicts)
+
+ def build_scheduler(self):
+ assert hasattr(self, "optimizer")
+ assert hasattr(self, "train_loader")
+ # self.cfg.scheduler.total_steps = len(self.train_loader) * self.cfg.eval_epoch
+ self.cfg.scheduler.total_steps = self.max_epoch
+ return build_scheduler(self.cfg.scheduler, self.optimizer)
+
+ def build_scaler(self):
+ scaler = torch.cuda.amp.GradScaler() if self.cfg.enable_amp else None
+ return scaler
diff --git a/UniRig/src/model/pointcept/engines/hooks/__init__.py b/UniRig/src/model/pointcept/engines/hooks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8434169b30ac1563d5f60fa8fd89c61979c20568
--- /dev/null
+++ b/UniRig/src/model/pointcept/engines/hooks/__init__.py
@@ -0,0 +1,6 @@
+from .default import HookBase
+from .misc import *
+from .evaluator import *
+# from .partseg import *
+
+from .builder import build_hooks
diff --git a/UniRig/src/model/pointcept/engines/hooks/builder.py b/UniRig/src/model/pointcept/engines/hooks/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..834258b8f03c5451569a11cd3eaf299f9234fae8
--- /dev/null
+++ b/UniRig/src/model/pointcept/engines/hooks/builder.py
@@ -0,0 +1,11 @@
+from pointcept.utils.registry import Registry
+
+
+HOOKS = Registry("hooks")
+
+
+def build_hooks(cfg):
+ hooks = []
+ for hook_cfg in cfg:
+ hooks.append(HOOKS.build(hook_cfg))
+ return hooks
diff --git a/UniRig/src/model/pointcept/engines/hooks/default.py b/UniRig/src/model/pointcept/engines/hooks/default.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ccd1f0fd3bd34215274c65b92e6a0755166d313
--- /dev/null
+++ b/UniRig/src/model/pointcept/engines/hooks/default.py
@@ -0,0 +1,24 @@
+class HookBase:
+ """
+ Base class for hooks that can be registered with :class:`TrainerBase`.
+ """
+
+ trainer = None # A weak reference to the trainer object.
+
+ def before_train(self):
+ pass
+
+ def before_epoch(self):
+ pass
+
+ def before_step(self):
+ pass
+
+ def after_step(self):
+ pass
+
+ def after_epoch(self):
+ pass
+
+ def after_train(self):
+ pass
diff --git a/UniRig/src/model/pointcept/engines/hooks/evaluator.py b/UniRig/src/model/pointcept/engines/hooks/evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b06ef6c553ae291e1a87c4a31ed56859a1a83c
--- /dev/null
+++ b/UniRig/src/model/pointcept/engines/hooks/evaluator.py
@@ -0,0 +1,574 @@
+import numpy as np
+import torch
+import torch.distributed as dist
+import pointops
+from uuid import uuid4
+
+import pointcept.utils.comm as comm
+from pointcept.utils.misc import intersection_and_union_gpu
+
+from .default import HookBase
+from .builder import HOOKS
+
+
+@HOOKS.register_module()
+class ClsEvaluator(HookBase):
+ def after_epoch(self):
+ if self.trainer.cfg.evaluate:
+ self.eval()
+
+ def eval(self):
+ self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
+ self.trainer.model.eval()
+ for i, input_dict in enumerate(self.trainer.val_loader):
+ for key in input_dict.keys():
+ if isinstance(input_dict[key], torch.Tensor):
+ input_dict[key] = input_dict[key].cuda(non_blocking=True)
+ with torch.no_grad():
+ output_dict = self.trainer.model(input_dict)
+ output = output_dict["cls_logits"]
+ loss = output_dict["loss"]
+ pred = output.max(1)[1]
+ label = input_dict["category"]
+ intersection, union, target = intersection_and_union_gpu(
+ pred,
+ label,
+ self.trainer.cfg.data.num_classes,
+ self.trainer.cfg.data.ignore_index,
+ )
+ if comm.get_world_size() > 1:
+                dist.all_reduce(intersection)
+                dist.all_reduce(union)
+                dist.all_reduce(target)
+ intersection, union, target = (
+ intersection.cpu().numpy(),
+ union.cpu().numpy(),
+ target.cpu().numpy(),
+ )
+ # Here there is no need to sync since sync happened in dist.all_reduce
+ self.trainer.storage.put_scalar("val_intersection", intersection)
+ self.trainer.storage.put_scalar("val_union", union)
+ self.trainer.storage.put_scalar("val_target", target)
+ self.trainer.storage.put_scalar("val_loss", loss.item())
+ self.trainer.logger.info(
+ "Test: [{iter}/{max_iter}] "
+ "Loss {loss:.4f} ".format(
+ iter=i + 1, max_iter=len(self.trainer.val_loader), loss=loss.item()
+ )
+ )
+ loss_avg = self.trainer.storage.history("val_loss").avg
+ intersection = self.trainer.storage.history("val_intersection").total
+ union = self.trainer.storage.history("val_union").total
+ target = self.trainer.storage.history("val_target").total
+ iou_class = intersection / (union + 1e-10)
+ acc_class = intersection / (target + 1e-10)
+ m_iou = np.mean(iou_class)
+ m_acc = np.mean(acc_class)
+ all_acc = sum(intersection) / (sum(target) + 1e-10)
+ self.trainer.logger.info(
+ "Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.".format(
+ m_iou, m_acc, all_acc
+ )
+ )
+ for i in range(self.trainer.cfg.data.num_classes):
+ self.trainer.logger.info(
+ "Class_{idx}-{name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format(
+ idx=i,
+ name=self.trainer.cfg.data.names[i],
+ iou=iou_class[i],
+ accuracy=acc_class[i],
+ )
+ )
+ current_epoch = self.trainer.epoch + 1
+ if self.trainer.writer is not None:
+ self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch)
+ self.trainer.writer.add_scalar("val/mIoU", m_iou, current_epoch)
+ self.trainer.writer.add_scalar("val/mAcc", m_acc, current_epoch)
+ self.trainer.writer.add_scalar("val/allAcc", all_acc, current_epoch)
+ self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
+ self.trainer.comm_info["current_metric_value"] = all_acc # save for saver
+ self.trainer.comm_info["current_metric_name"] = "allAcc" # save for saver
+
+ def after_train(self):
+ self.trainer.logger.info(
+ "Best {}: {:.4f}".format("allAcc", self.trainer.best_metric_value)
+ )
+
+
+@HOOKS.register_module()
+class SemSegEvaluator(HookBase):
+ def after_epoch(self):
+ if self.trainer.cfg.evaluate:
+ self.eval()
+
+ def eval(self):
+ self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
+ self.trainer.model.eval()
+ for i, input_dict in enumerate(self.trainer.val_loader):
+ for key in input_dict.keys():
+ if isinstance(input_dict[key], torch.Tensor):
+ input_dict[key] = input_dict[key].cuda(non_blocking=True)
+ with torch.no_grad():
+ output_dict = self.trainer.model(input_dict)
+ output = output_dict["seg_logits"]
+ loss = output_dict["loss"]
+ pred = output.max(1)[1]
+ segment = input_dict["segment"]
+ if "origin_coord" in input_dict.keys():
+ idx, _ = pointops.knn_query(
+ 1,
+ input_dict["coord"].float(),
+ input_dict["offset"].int(),
+ input_dict["origin_coord"].float(),
+ input_dict["origin_offset"].int(),
+ )
+ pred = pred[idx.flatten().long()]
+ segment = input_dict["origin_segment"]
+ intersection, union, target = intersection_and_union_gpu(
+ pred,
+ segment,
+ self.trainer.cfg.data.num_classes,
+ self.trainer.cfg.data.ignore_index,
+ )
+ if comm.get_world_size() > 1:
+                dist.all_reduce(intersection)
+                dist.all_reduce(union)
+                dist.all_reduce(target)
+ intersection, union, target = (
+ intersection.cpu().numpy(),
+ union.cpu().numpy(),
+ target.cpu().numpy(),
+ )
+ # Here there is no need to sync since sync happened in dist.all_reduce
+ self.trainer.storage.put_scalar("val_intersection", intersection)
+ self.trainer.storage.put_scalar("val_union", union)
+ self.trainer.storage.put_scalar("val_target", target)
+ self.trainer.storage.put_scalar("val_loss", loss.item())
+ info = "Test: [{iter}/{max_iter}] ".format(
+ iter=i + 1, max_iter=len(self.trainer.val_loader)
+ )
+ if "origin_coord" in input_dict.keys():
+ info = "Interp. " + info
+            self.trainer.logger.info(
+                info + "Loss {loss:.4f} ".format(loss=loss.item())
+            )
+ loss_avg = self.trainer.storage.history("val_loss").avg
+ intersection = self.trainer.storage.history("val_intersection").total
+ union = self.trainer.storage.history("val_union").total
+ target = self.trainer.storage.history("val_target").total
+ iou_class = intersection / (union + 1e-10)
+ acc_class = intersection / (target + 1e-10)
+ m_iou = np.mean(iou_class)
+ m_acc = np.mean(acc_class)
+ all_acc = sum(intersection) / (sum(target) + 1e-10)
+ self.trainer.logger.info(
+ "Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.".format(
+ m_iou, m_acc, all_acc
+ )
+ )
+ for i in range(self.trainer.cfg.data.num_classes):
+ self.trainer.logger.info(
+ "Class_{idx}-{name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format(
+ idx=i,
+ name=self.trainer.cfg.data.names[i],
+ iou=iou_class[i],
+ accuracy=acc_class[i],
+ )
+ )
+ current_epoch = self.trainer.epoch + 1
+ if self.trainer.writer is not None:
+ self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch)
+ self.trainer.writer.add_scalar("val/mIoU", m_iou, current_epoch)
+ self.trainer.writer.add_scalar("val/mAcc", m_acc, current_epoch)
+ self.trainer.writer.add_scalar("val/allAcc", all_acc, current_epoch)
+ self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
+ self.trainer.comm_info["current_metric_value"] = m_iou # save for saver
+ self.trainer.comm_info["current_metric_name"] = "mIoU" # save for saver
+
+ def after_train(self):
+ self.trainer.logger.info(
+ "Best {}: {:.4f}".format("mIoU", self.trainer.best_metric_value)
+ )
+
+
+@HOOKS.register_module()
+class InsSegEvaluator(HookBase):
+ def __init__(self, segment_ignore_index=(-1,), instance_ignore_index=-1):
+ self.segment_ignore_index = segment_ignore_index
+ self.instance_ignore_index = instance_ignore_index
+
+ self.valid_class_names = None # update in before train
+ self.overlaps = np.append(np.arange(0.5, 0.95, 0.05), 0.25)
+ self.min_region_sizes = 100
+ self.distance_threshes = float("inf")
+ self.distance_confs = -float("inf")
+
+ def before_train(self):
+ self.valid_class_names = [
+ self.trainer.cfg.data.names[i]
+ for i in range(self.trainer.cfg.data.num_classes)
+ if i not in self.segment_ignore_index
+ ]
+
+ def after_epoch(self):
+ if self.trainer.cfg.evaluate:
+ self.eval()
+
+ def associate_instances(self, pred, segment, instance):
+ segment = segment.cpu().numpy()
+ instance = instance.cpu().numpy()
+ void_mask = np.in1d(segment, self.segment_ignore_index)
+
+ assert (
+ pred["pred_classes"].shape[0]
+ == pred["pred_scores"].shape[0]
+ == pred["pred_masks"].shape[0]
+ )
+ assert pred["pred_masks"].shape[1] == segment.shape[0] == instance.shape[0]
+ # get gt instances
+ gt_instances = dict()
+ for i in range(self.trainer.cfg.data.num_classes):
+ if i not in self.segment_ignore_index:
+ gt_instances[self.trainer.cfg.data.names[i]] = []
+ instance_ids, idx, counts = np.unique(
+ instance, return_index=True, return_counts=True
+ )
+ segment_ids = segment[idx]
+ for i in range(len(instance_ids)):
+ if instance_ids[i] == self.instance_ignore_index:
+ continue
+ if segment_ids[i] in self.segment_ignore_index:
+ continue
+ gt_inst = dict()
+ gt_inst["instance_id"] = instance_ids[i]
+ gt_inst["segment_id"] = segment_ids[i]
+ gt_inst["dist_conf"] = 0.0
+ gt_inst["med_dist"] = -1.0
+ gt_inst["vert_count"] = counts[i]
+ gt_inst["matched_pred"] = []
+ gt_instances[self.trainer.cfg.data.names[segment_ids[i]]].append(gt_inst)
+
+ # get pred instances and associate with gt
+ pred_instances = dict()
+ for i in range(self.trainer.cfg.data.num_classes):
+ if i not in self.segment_ignore_index:
+ pred_instances[self.trainer.cfg.data.names[i]] = []
+ instance_id = 0
+ for i in range(len(pred["pred_classes"])):
+ if pred["pred_classes"][i] in self.segment_ignore_index:
+ continue
+ pred_inst = dict()
+ pred_inst["uuid"] = uuid4()
+ pred_inst["instance_id"] = instance_id
+ pred_inst["segment_id"] = pred["pred_classes"][i]
+ pred_inst["confidence"] = pred["pred_scores"][i]
+ pred_inst["mask"] = np.not_equal(pred["pred_masks"][i], 0)
+ pred_inst["vert_count"] = np.count_nonzero(pred_inst["mask"])
+ pred_inst["void_intersection"] = np.count_nonzero(
+ np.logical_and(void_mask, pred_inst["mask"])
+ )
+ if pred_inst["vert_count"] < self.min_region_sizes:
+                continue  # skip if below the minimum region size
+ segment_name = self.trainer.cfg.data.names[pred_inst["segment_id"]]
+ matched_gt = []
+ for gt_idx, gt_inst in enumerate(gt_instances[segment_name]):
+ intersection = np.count_nonzero(
+ np.logical_and(
+ instance == gt_inst["instance_id"], pred_inst["mask"]
+ )
+ )
+ if intersection > 0:
+ gt_inst_ = gt_inst.copy()
+ pred_inst_ = pred_inst.copy()
+ gt_inst_["intersection"] = intersection
+ pred_inst_["intersection"] = intersection
+ matched_gt.append(gt_inst_)
+ gt_inst["matched_pred"].append(pred_inst_)
+ pred_inst["matched_gt"] = matched_gt
+ pred_instances[segment_name].append(pred_inst)
+ instance_id += 1
+ return gt_instances, pred_instances
+
+ def evaluate_matches(self, scenes):
+ overlaps = self.overlaps
+ min_region_sizes = [self.min_region_sizes]
+ dist_threshes = [self.distance_threshes]
+ dist_confs = [self.distance_confs]
+
+ # results: class x overlap
+ ap_table = np.zeros(
+ (len(dist_threshes), len(self.valid_class_names), len(overlaps)), float
+ )
+ for di, (min_region_size, distance_thresh, distance_conf) in enumerate(
+ zip(min_region_sizes, dist_threshes, dist_confs)
+ ):
+ for oi, overlap_th in enumerate(overlaps):
+ pred_visited = {}
+ for scene in scenes:
+ for _ in scene["pred"]:
+ for label_name in self.valid_class_names:
+ for p in scene["pred"][label_name]:
+ if "uuid" in p:
+ pred_visited[p["uuid"]] = False
+ for li, label_name in enumerate(self.valid_class_names):
+ y_true = np.empty(0)
+ y_score = np.empty(0)
+ hard_false_negatives = 0
+ has_gt = False
+ has_pred = False
+ for scene in scenes:
+ pred_instances = scene["pred"][label_name]
+ gt_instances = scene["gt"][label_name]
+ # filter groups in ground truth
+ gt_instances = [
+ gt
+ for gt in gt_instances
+ if gt["vert_count"] >= min_region_size
+ and gt["med_dist"] <= distance_thresh
+ and gt["dist_conf"] >= distance_conf
+ ]
+ if gt_instances:
+ has_gt = True
+ if pred_instances:
+ has_pred = True
+
+ cur_true = np.ones(len(gt_instances))
+ cur_score = np.ones(len(gt_instances)) * (-float("inf"))
+ cur_match = np.zeros(len(gt_instances), dtype=bool)
+ # collect matches
+ for gti, gt in enumerate(gt_instances):
+ found_match = False
+ for pred in gt["matched_pred"]:
+ # greedy assignments
+ if pred_visited[pred["uuid"]]:
+ continue
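+ # overlap is the IoU between the prediction and gt masks:
+ # |P ∩ G| / (|G| + |P| - |P ∩ G|)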
+ overlap = float(pred["intersection"]) / (
+ gt["vert_count"]
+ + pred["vert_count"]
+ - pred["intersection"]
+ )
+ if overlap > overlap_th:
+ confidence = pred["confidence"]
+ # if this gt already has a matched prediction,
+ # the lower-scoring prediction is automatically a false positive
+ if cur_match[gti]:
+ max_score = max(cur_score[gti], confidence)
+ min_score = min(cur_score[gti], confidence)
+ cur_score[gti] = max_score
+ # append false positive
+ cur_true = np.append(cur_true, 0)
+ cur_score = np.append(cur_score, min_score)
+ cur_match = np.append(cur_match, True)
+ # otherwise set score
+ else:
+ found_match = True
+ cur_match[gti] = True
+ cur_score[gti] = confidence
+ pred_visited[pred["uuid"]] = True
+ if not found_match:
+ hard_false_negatives += 1
+ # remove non-matched ground truth instances
+ cur_true = cur_true[cur_match]
+ cur_score = cur_score[cur_match]
+
+ # collect non-matched predictions as false positive
+ for pred in pred_instances:
+ found_gt = False
+ for gt in pred["matched_gt"]:
+ overlap = float(gt["intersection"]) / (
+ gt["vert_count"]
+ + pred["vert_count"]
+ - gt["intersection"]
+ )
+ if overlap > overlap_th:
+ found_gt = True
+ break
+ if not found_gt:
+ num_ignore = pred["void_intersection"]
+ for gt in pred["matched_gt"]:
+ if gt["segment_id"] in self.segment_ignore_index:
+ num_ignore += gt["intersection"]
+ # also ignore overlap with gt instances filtered out above (too small, too distant, or low confidence)
+ if (
+ gt["vert_count"] < min_region_size
+ or gt["med_dist"] > distance_thresh
+ or gt["dist_conf"] < distance_conf
+ ):
+ num_ignore += gt["intersection"]
+ proportion_ignore = (
+ float(num_ignore) / pred["vert_count"]
+ )
+ # if not ignored append false positive
+ if proportion_ignore <= overlap_th:
+ cur_true = np.append(cur_true, 0)
+ confidence = pred["confidence"]
+ cur_score = np.append(cur_score, confidence)
+
+ # append to overall results
+ y_true = np.append(y_true, cur_true)
+ y_score = np.append(y_score, cur_score)
+
+ # compute average precision
+ if has_gt and has_pred:
+ # compute precision recall curve first
+
+ # sorting and cumsum
+ score_arg_sort = np.argsort(y_score)
+ y_score_sorted = y_score[score_arg_sort]
+ y_true_sorted = y_true[score_arg_sort]
+ y_true_sorted_cumsum = np.cumsum(y_true_sorted)
+
+ # unique thresholds
+ (thresholds, unique_indices) = np.unique(
+ y_score_sorted, return_index=True
+ )
+ num_prec_recall = len(unique_indices) + 1
+
+ # prepare precision recall
+ num_examples = len(y_score_sorted)
+ # https://github.com/ScanNet/ScanNet/pull/26
+ # all predictions are non-matched but also all of them are ignored and not counted as FP
+ # y_true_sorted_cumsum is empty
+ # num_true_examples = y_true_sorted_cumsum[-1]
+ num_true_examples = (
+ y_true_sorted_cumsum[-1]
+ if len(y_true_sorted_cumsum) > 0
+ else 0
+ )
+ precision = np.zeros(num_prec_recall)
+ recall = np.zeros(num_prec_recall)
+
+ # append a trailing 0 so cumsum[idx_scores - 1] reads 0 when idx_scores == 0
+ y_true_sorted_cumsum = np.append(y_true_sorted_cumsum, 0)
+ # fill precision/recall at each unique score threshold
+ for idx_res, idx_scores in enumerate(unique_indices):
+ cumsum = y_true_sorted_cumsum[idx_scores - 1]
+ tp = num_true_examples - cumsum
+ fp = num_examples - idx_scores - tp
+ fn = cumsum + hard_false_negatives
+ p = float(tp) / (tp + fp)
+ r = float(tp) / (tp + fn)
+ precision[idx_res] = p
+ recall[idx_res] = r
+
+ # the final entry is an artificial end point of the curve (precision 1, recall 0)
+ precision[-1] = 1.0
+ recall[-1] = 0.0
+
+ # compute average of precision-recall curve
+ recall_for_conv = np.copy(recall)
+ recall_for_conv = np.append(recall_for_conv[0], recall_for_conv)
+ recall_for_conv = np.append(recall_for_conv, 0.0)
+
+ stepWidths = np.convolve(
+ recall_for_conv, [-0.5, 0, 0.5], "valid"
+ )
+ # integrate is now simply a dot product
+ ap_current = np.dot(precision, stepWidths)
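+ # np.convolve with kernel [-0.5, 0, 0.5] in "valid" mode yields
+ # stepWidths[i] = (recall_for_conv[i] - recall_for_conv[i + 2]) / 2, the
+ # midpoint-rule width around each recall point; e.g. recall [1.0, 0.5, 0.0]
+ # gives widths [0.25, 0.5, 0.25], and the dot product with precision
+ # approximates the area under the precision-recall curve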
+
+ elif has_gt:
+ ap_current = 0.0
+ else:
+ ap_current = float("nan")
+ ap_table[di, li, oi] = ap_current
+ d_inf = 0
+ o50 = np.where(np.isclose(self.overlaps, 0.5))
+ o25 = np.where(np.isclose(self.overlaps, 0.25))
+ oAllBut25 = np.where(np.logical_not(np.isclose(self.overlaps, 0.25)))
+ ap_scores = dict()
+ ap_scores["all_ap"] = np.nanmean(ap_table[d_inf, :, oAllBut25])
+ ap_scores["all_ap_50%"] = np.nanmean(ap_table[d_inf, :, o50])
+ ap_scores["all_ap_25%"] = np.nanmean(ap_table[d_inf, :, o25])
+ ap_scores["classes"] = {}
+ for li, label_name in enumerate(self.valid_class_names):
+ ap_scores["classes"][label_name] = {}
+ ap_scores["classes"][label_name]["ap"] = np.average(
+ ap_table[d_inf, li, oAllBut25]
+ )
+ ap_scores["classes"][label_name]["ap50%"] = np.average(
+ ap_table[d_inf, li, o50]
+ )
+ ap_scores["classes"][label_name]["ap25%"] = np.average(
+ ap_table[d_inf, li, o25]
+ )
+ return ap_scores
+
+ def eval(self):
+ self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
+ self.trainer.model.eval()
+ scenes = []
+ for i, input_dict in enumerate(self.trainer.val_loader):
+ assert (
+ len(input_dict["offset"]) == 1
+ ) # currently only batch size 1 per GPU is supported
+ for key in input_dict.keys():
+ if isinstance(input_dict[key], torch.Tensor):
+ input_dict[key] = input_dict[key].cuda(non_blocking=True)
+ with torch.no_grad():
+ output_dict = self.trainer.model(input_dict)
+
+ loss = output_dict["loss"]
+
+ segment = input_dict["segment"]
+ instance = input_dict["instance"]
+ # map to origin
+ if "origin_coord" in input_dict.keys():
+ idx, _ = pointops.knn_query(
+ 1,
+ input_dict["coord"].float(),
+ input_dict["offset"].int(),
+ input_dict["origin_coord"].float(),
+ input_dict["origin_offset"].int(),
+ )
+ idx = idx.cpu().flatten().long()
+ output_dict["pred_masks"] = output_dict["pred_masks"][:, idx]
+ segment = input_dict["origin_segment"]
+ instance = input_dict["origin_instance"]
+
+ gt_instances, pred_instances = self.associate_instances(
+ output_dict, segment, instance
+ )
+ scenes.append(dict(gt=gt_instances, pred=pred_instances))
+
+ self.trainer.storage.put_scalar("val_loss", loss.item())
+ self.trainer.logger.info(
+ "Test: [{iter}/{max_iter}] "
+ "Loss {loss:.4f} ".format(
+ iter=i + 1, max_iter=len(self.trainer.val_loader), loss=loss.item()
+ )
+ )
+
+ loss_avg = self.trainer.storage.history("val_loss").avg
+ comm.synchronize()
+ scenes_sync = comm.gather(scenes, dst=0)
+ scenes = [scene for scenes_ in scenes_sync for scene in scenes_]
+ ap_scores = self.evaluate_matches(scenes)
+ all_ap = ap_scores["all_ap"]
+ all_ap_50 = ap_scores["all_ap_50%"]
+ all_ap_25 = ap_scores["all_ap_25%"]
+ self.trainer.logger.info(
+ "Val result: mAP/AP50/AP25 {:.4f}/{:.4f}/{:.4f}.".format(
+ all_ap, all_ap_50, all_ap_25
+ )
+ )
+ for i, label_name in enumerate(self.valid_class_names):
+ ap = ap_scores["classes"][label_name]["ap"]
+ ap_50 = ap_scores["classes"][label_name]["ap50%"]
+ ap_25 = ap_scores["classes"][label_name]["ap25%"]
+ self.trainer.logger.info(
+ "Class_{idx}-{name} Result: AP/AP50/AP25 {AP:.4f}/{AP50:.4f}/{AP25:.4f}".format(
+ idx=i, name=label_name, AP=ap, AP50=ap_50, AP25=ap_25
+ )
+ )
+ current_epoch = self.trainer.epoch + 1
+ if self.trainer.writer is not None:
+ self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch)
+ self.trainer.writer.add_scalar("val/mAP", all_ap, current_epoch)
+ self.trainer.writer.add_scalar("val/AP50", all_ap_50, current_epoch)
+ self.trainer.writer.add_scalar("val/AP25", all_ap_25, current_epoch)
+ self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
+ self.trainer.comm_info["current_metric_value"] = all_ap_50 # save for saver
+ self.trainer.comm_info["current_metric_name"] = "AP50" # save for saver
diff --git a/UniRig/src/model/pointcept/engines/hooks/misc.py b/UniRig/src/model/pointcept/engines/hooks/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..995e3646bc24abbfd2af40678f239306d5da6181
--- /dev/null
+++ b/UniRig/src/model/pointcept/engines/hooks/misc.py
@@ -0,0 +1,429 @@
+import sys
+import glob
+import os
+import shutil
+import time
+import torch
+import torch.utils.data
+from collections import OrderedDict
+
+if sys.version_info >= (3, 10):
+ from collections.abc import Sequence
+else:
+ from collections import Sequence
+from pointcept.utils.timer import Timer
+from pointcept.utils.comm import is_main_process, synchronize, get_world_size
+from pointcept.utils.cache import shared_dict
+
+import pointcept.utils.comm as comm
+# from pointcept.engines.test import TESTERS
+
+from .default import HookBase
+from .builder import HOOKS
+
+
+@HOOKS.register_module()
+class IterationTimer(HookBase):
+ def __init__(self, warmup_iter=1):
+ self._warmup_iter = warmup_iter
+ self._start_time = time.perf_counter()
+ self._iter_timer = Timer()
+ self._remain_iter = 0
+
+ def before_train(self):
+ self._start_time = time.perf_counter()
+ self._remain_iter = self.trainer.max_epoch * len(self.trainer.train_loader)
+
+ def before_epoch(self):
+ self._iter_timer.reset()
+
+ def before_step(self):
+ data_time = self._iter_timer.seconds()
+ self.trainer.storage.put_scalar("data_time", data_time)
+
+ def after_step(self):
+ batch_time = self._iter_timer.seconds()
+ self._iter_timer.reset()
+ self.trainer.storage.put_scalar("batch_time", batch_time)
+ self._remain_iter -= 1
+ remain_time = self._remain_iter * self.trainer.storage.history("batch_time").avg
+ t_m, t_s = divmod(remain_time, 60)
+ t_h, t_m = divmod(t_m, 60)
+ remain_time = "{:02d}:{:02d}:{:02d}".format(int(t_h), int(t_m), int(t_s))
+ if "iter_info" in self.trainer.comm_info.keys():
+ info = (
+ "Data {data_time_val:.3f} ({data_time_avg:.3f}) "
+ "Batch {batch_time_val:.3f} ({batch_time_avg:.3f}) "
+ "Remain {remain_time} ".format(
+ data_time_val=self.trainer.storage.history("data_time").val,
+ data_time_avg=self.trainer.storage.history("data_time").avg,
+ batch_time_val=self.trainer.storage.history("batch_time").val,
+ batch_time_avg=self.trainer.storage.history("batch_time").avg,
+ remain_time=remain_time,
+ )
+ )
+ self.trainer.comm_info["iter_info"] += info
+ if self.trainer.comm_info["iter"] <= self._warmup_iter:
+ self.trainer.storage.history("data_time").reset()
+ self.trainer.storage.history("batch_time").reset()
+
+
+@HOOKS.register_module()
+class InformationWriter(HookBase):
+ def __init__(self):
+ self.curr_iter = 0
+ self.model_output_keys = []
+
+ def before_train(self):
+ self.trainer.comm_info["iter_info"] = ""
+ self.curr_iter = self.trainer.start_epoch * len(self.trainer.train_loader)
+
+ def before_step(self):
+ self.curr_iter += 1
+ # MSC pretraining does not provide offset information; the block below is kept commented out so MSC remains supported
+ # info = "Train: [{epoch}/{max_epoch}][{iter}/{max_iter}] " \
+ # "Scan {batch_size} ({points_num}) ".format(
+ # epoch=self.trainer.epoch + 1, max_epoch=self.trainer.max_epoch,
+ # iter=self.trainer.comm_info["iter"], max_iter=len(self.trainer.train_loader),
+ # batch_size=len(self.trainer.comm_info["input_dict"]["offset"]),
+ # points_num=self.trainer.comm_info["input_dict"]["offset"][-1]
+ # )
+ info = "Train: [{epoch}/{max_epoch}][{iter}/{max_iter}] ".format(
+ epoch=self.trainer.epoch + 1,
+ max_epoch=self.trainer.max_epoch,
+ iter=self.trainer.comm_info["iter"] + 1,
+ max_iter=len(self.trainer.train_loader),
+ )
+ self.trainer.comm_info["iter_info"] += info
+
+ def after_step(self):
+ if "model_output_dict" in self.trainer.comm_info.keys():
+ model_output_dict = self.trainer.comm_info["model_output_dict"]
+ self.model_output_keys = model_output_dict.keys()
+ for key in self.model_output_keys:
+ self.trainer.storage.put_scalar(key, model_output_dict[key].item())
+
+ for key in self.model_output_keys:
+ self.trainer.comm_info["iter_info"] += "{key}: {value:.4f} ".format(
+ key=key, value=self.trainer.storage.history(key).val
+ )
+ lr = self.trainer.optimizer.state_dict()["param_groups"][0]["lr"]
+ self.trainer.comm_info["iter_info"] += "Lr: {lr:.5f}".format(lr=lr)
+ self.trainer.logger.info(self.trainer.comm_info["iter_info"])
+ self.trainer.comm_info["iter_info"] = "" # reset iter info
+ if self.trainer.writer is not None:
+ self.trainer.writer.add_scalar("lr", lr, self.curr_iter)
+ for key in self.model_output_keys:
+ self.trainer.writer.add_scalar(
+ "train_batch/" + key,
+ self.trainer.storage.history(key).val,
+ self.curr_iter,
+ )
+
+ def after_epoch(self):
+ epoch_info = "Train result: "
+ for key in self.model_output_keys:
+ epoch_info += "{key}: {value:.4f} ".format(
+ key=key, value=self.trainer.storage.history(key).avg
+ )
+ self.trainer.logger.info(epoch_info)
+ if self.trainer.writer is not None:
+ for key in self.model_output_keys:
+ self.trainer.writer.add_scalar(
+ "train/" + key,
+ self.trainer.storage.history(key).avg,
+ self.trainer.epoch + 1,
+ )
+
+
+@HOOKS.register_module()
+class CheckpointSaver(HookBase):
+ def __init__(self, save_freq=None):
+ self.save_freq = save_freq # None or int; None means only the last model is saved
+
+ def after_epoch(self):
+ if is_main_process():
+ is_best = False
+ if self.trainer.cfg.evaluate:
+ current_metric_value = self.trainer.comm_info["current_metric_value"]
+ current_metric_name = self.trainer.comm_info["current_metric_name"]
+ if current_metric_value > self.trainer.best_metric_value:
+ self.trainer.best_metric_value = current_metric_value
+ is_best = True
+ self.trainer.logger.info(
+ "Best validation {} updated to: {:.4f}".format(
+ current_metric_name, current_metric_value
+ )
+ )
+ self.trainer.logger.info(
+ "Currently Best {}: {:.4f}".format(
+ current_metric_name, self.trainer.best_metric_value
+ )
+ )
+
+ filename = os.path.join(
+ self.trainer.cfg.save_path, "model", "model_last.pth"
+ )
+ self.trainer.logger.info("Saving checkpoint to: " + filename)
+ torch.save(
+ {
+ "epoch": self.trainer.epoch + 1,
+ "state_dict": self.trainer.model.state_dict(),
+ "optimizer": self.trainer.optimizer.state_dict(),
+ "scheduler": self.trainer.scheduler.state_dict(),
+ "scaler": self.trainer.scaler.state_dict()
+ if self.trainer.cfg.enable_amp
+ else None,
+ "best_metric_value": self.trainer.best_metric_value,
+ },
+ filename + ".tmp",
+ )
+ os.replace(filename + ".tmp", filename)
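+ # writing to a .tmp file first and then os.replace-ing keeps model_last.pth
+ # valid even if the process dies mid-save (os.replace is atomic on POSIX)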
+ if is_best:
+ shutil.copyfile(
+ filename,
+ os.path.join(self.trainer.cfg.save_path, "model", "model_best.pth"),
+ )
+ if self.save_freq and (self.trainer.epoch + 1) % self.save_freq == 0:
+ shutil.copyfile(
+ filename,
+ os.path.join(
+ self.trainer.cfg.save_path,
+ "model",
+ f"epoch_{self.trainer.epoch + 1}.pth",
+ ),
+ )
+
+
+@HOOKS.register_module()
+class CheckpointLoader(HookBase):
+ def __init__(self, keywords="", replacement=None, strict=False):
+ self.keywords = keywords
+ self.replacement = replacement if replacement is not None else keywords
+ self.strict = strict
+
+ def before_train(self):
+ self.trainer.logger.info("=> Loading checkpoint & weight ...")
+ if self.trainer.cfg.weight and os.path.isfile(self.trainer.cfg.weight):
+ self.trainer.logger.info(f"Loading weight at: {self.trainer.cfg.weight}")
+ checkpoint = torch.load(
+ self.trainer.cfg.weight,
+ map_location=lambda storage, loc: storage.cuda(),
+ )
+ self.trainer.logger.info(
+ f"Loading layer weights with keyword: {self.keywords}, "
+ f"replace keyword with: {self.replacement}"
+ )
+ weight = OrderedDict()
+ for key, value in checkpoint["state_dict"].items():
+ if not key.startswith("module."):
+ if comm.get_world_size() > 1:
+ key = "module." + key # xxx.xxx -> module.xxx.xxx
+ # Now every key contains "module.", whether or not DDP is used.
+ if self.keywords in key:
+ key = key.replace(self.keywords, self.replacement)
+ if comm.get_world_size() == 1:
+ key = key[7:] # module.xxx.xxx -> xxx.xxx
+ weight[key] = value
+ load_state_info = self.trainer.model.load_state_dict(
+ weight, strict=self.strict
+ )
+ self.trainer.logger.info(f"Missing keys: {load_state_info[0]}")
+ if self.trainer.cfg.resume:
+ self.trainer.logger.info(
+ f"Resuming train at eval epoch: {checkpoint['epoch']}"
+ )
+ self.trainer.start_epoch = checkpoint["epoch"]
+ self.trainer.best_metric_value = checkpoint["best_metric_value"]
+ self.trainer.optimizer.load_state_dict(checkpoint["optimizer"])
+ self.trainer.scheduler.load_state_dict(checkpoint["scheduler"])
+ if self.trainer.cfg.enable_amp:
+ self.trainer.scaler.load_state_dict(checkpoint["scaler"])
+ else:
+ self.trainer.logger.info(f"No weight found at: {self.trainer.cfg.weight}")
+
+
+@HOOKS.register_module()
+class DataCacheOperator(HookBase):
+ def __init__(self, data_root, split):
+ self.data_root = data_root
+ self.split = split
+ self.data_list = self.get_data_list()
+
+ def get_data_list(self):
+ if isinstance(self.split, str):
+ data_list = glob.glob(os.path.join(self.data_root, self.split, "*.pth"))
+ elif isinstance(self.split, Sequence):
+ data_list = []
+ for split in self.split:
+ data_list += glob.glob(os.path.join(self.data_root, split, "*.pth"))
+ else:
+ raise NotImplementedError
+ return data_list
+
+ def get_cache_name(self, data_path):
+ data_name = data_path.replace(os.path.dirname(self.data_root), "").split(".")[0]
+ return "pointcept" + data_name.replace(os.path.sep, "-")
+
+ def before_train(self):
+ self.trainer.logger.info(
+ f"=> Caching dataset: {self.data_root}, split: {self.split} ..."
+ )
+ if is_main_process():
+ for data_path in self.data_list:
+ cache_name = self.get_cache_name(data_path)
+ data = torch.load(data_path)
+ shared_dict(cache_name, data)
+ synchronize()
+
+
+@HOOKS.register_module()
+class RuntimeProfiler(HookBase):
+ def __init__(
+ self,
+ forward=True,
+ backward=True,
+ interrupt=False,
+ warm_up=2,
+ sort_by="cuda_time_total",
+ row_limit=30,
+ ):
+ self.forward = forward
+ self.backward = backward
+ self.interrupt = interrupt
+ self.warm_up = warm_up
+ self.sort_by = sort_by
+ self.row_limit = row_limit
+
+ def before_train(self):
+ self.trainer.logger.info("Profiling runtime ...")
+ from torch.profiler import profile, record_function, ProfilerActivity
+
+ for i, input_dict in enumerate(self.trainer.train_loader):
+ if i == self.warm_up + 1:
+ break
+ for key in input_dict.keys():
+ if isinstance(input_dict[key], torch.Tensor):
+ input_dict[key] = input_dict[key].cuda(non_blocking=True)
+ if self.forward:
+ with profile(
+ activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+ record_shapes=True,
+ profile_memory=True,
+ with_stack=True,
+ ) as forward_prof:
+ with record_function("model_inference"):
+ output_dict = self.trainer.model(input_dict)
+ else:
+ output_dict = self.trainer.model(input_dict)
+ loss = output_dict["loss"]
+ if self.backward:
+ with profile(
+ activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+ record_shapes=True,
+ profile_memory=True,
+ with_stack=True,
+ ) as backward_prof:
+ with record_function("model_inference"):
+ loss.backward()
+ self.trainer.logger.info(f"Profile: [{i + 1}/{self.warm_up + 1}]")
+ if self.forward:
+ self.trainer.logger.info(
+ "Forward profile: \n"
+ + str(
+ forward_prof.key_averages().table(
+ sort_by=self.sort_by, row_limit=self.row_limit
+ )
+ )
+ )
+ forward_prof.export_chrome_trace(
+ os.path.join(self.trainer.cfg.save_path, "forward_trace.json")
+ )
+
+ if self.backward:
+ self.trainer.logger.info(
+ "Backward profile: \n"
+ + str(
+ backward_prof.key_averages().table(
+ sort_by=self.sort_by, row_limit=self.row_limit
+ )
+ )
+ )
+ backward_prof.export_chrome_trace(
+ os.path.join(self.trainer.cfg.save_path, "backward_trace.json")
+ )
+ if self.interrupt:
+ sys.exit(0)
+
+
+@HOOKS.register_module()
+class RuntimeProfilerV2(HookBase):
+ def __init__(
+ self,
+ interrupt=False,
+ wait=1,
+ warmup=1,
+ active=10,
+ repeat=1,
+ sort_by="cuda_time_total",
+ row_limit=30,
+ ):
+ self.interrupt = interrupt
+ self.wait = wait
+ self.warmup = warmup
+ self.active = active
+ self.repeat = repeat
+ self.sort_by = sort_by
+ self.row_limit = row_limit
+
+ def before_train(self):
+ self.trainer.logger.info("Profiling runtime ...")
+ from torch.profiler import (
+ profile,
+ record_function,
+ ProfilerActivity,
+ schedule,
+ tensorboard_trace_handler,
+ )
+
+ prof = profile(
+ activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+ schedule=schedule(
+ wait=self.wait,
+ warmup=self.warmup,
+ active=self.active,
+ repeat=self.repeat,
+ ),
+ on_trace_ready=tensorboard_trace_handler(self.trainer.cfg.save_path),
+ record_shapes=True,
+ profile_memory=True,
+ with_stack=True,
+ )
+ prof.start()
+ for i, input_dict in enumerate(self.trainer.train_loader):
+ if i >= (self.wait + self.warmup + self.active) * self.repeat:
+ break
+ for key in input_dict.keys():
+ if isinstance(input_dict[key], torch.Tensor):
+ input_dict[key] = input_dict[key].cuda(non_blocking=True)
+ with record_function("model_forward"):
+ output_dict = self.trainer.model(input_dict)
+ loss = output_dict["loss"]
+ with record_function("model_backward"):
+ loss.backward()
+ prof.step()
+ self.trainer.logger.info(
+ f"Profile: [{i + 1}/{(self.wait + self.warmup + self.active) * self.repeat}]"
+ )
+ self.trainer.logger.info(
+ "Profile: \n"
+ + str(
+ prof.key_averages().table(
+ sort_by=self.sort_by, row_limit=self.row_limit
+ )
+ )
+ )
+ prof.stop()
+
+ if self.interrupt:
+ sys.exit(0)
diff --git a/UniRig/src/model/pointcept/engines/launch.py b/UniRig/src/model/pointcept/engines/launch.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2592dabbbc1a398e03ab73e3ace2de9f51fb3a9
--- /dev/null
+++ b/UniRig/src/model/pointcept/engines/launch.py
@@ -0,0 +1,128 @@
+import os
+import logging
+from datetime import timedelta
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+from pointcept.utils import comm
+
+__all__ = ["DEFAULT_TIMEOUT", "launch"]
+
+DEFAULT_TIMEOUT = timedelta(minutes=60)
+
+
+def _find_free_port():
+ import socket
+
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ # Binding to port 0 will cause the OS to find an available port for us
+ sock.bind(("", 0))
+ port = sock.getsockname()[1]
+ sock.close()
+ # NOTE: there is still a chance the port could be taken by other processes.
+ return port
+
+
+def launch(
+ main_func,
+ num_gpus_per_machine,
+ num_machines=1,
+ machine_rank=0,
+ dist_url=None,
+ cfg=(),
+ timeout=DEFAULT_TIMEOUT,
+):
+ """
+ Launch multi-gpu or distributed training.
+ This function must be called on all machines involved in the training.
+ It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine.
+ Args:
+ main_func: a function that will be called by `main_func(*args)`
+ num_gpus_per_machine (int): number of GPUs per machine
+ num_machines (int): the total number of machines
+ machine_rank (int): the rank of this machine
+ dist_url (str): url to connect to for distributed jobs, including protocol
+ e.g. "tcp://127.0.0.1:8686".
+ Can be set to "auto" to automatically select a free port on localhost
+ timeout (timedelta): timeout of the distributed workers
+ cfg (tuple): arguments passed to main_func
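+ Usage sketch (illustrative; the entry point and argument values are
+ assumptions, not part of this file):
+ launch(train_main, num_gpus_per_machine=4, num_machines=1,
+ machine_rank=0, dist_url="auto", cfg=(cfg,))
+ spawns four workers on this machine; in multi-machine jobs every machine
+ calls launch() with the same dist_url and its own machine_rank.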
+ """
+ world_size = num_machines * num_gpus_per_machine
+ if world_size > 1:
+ if dist_url == "auto":
+ assert (
+ num_machines == 1
+ ), "dist_url=auto not supported in multi-machine jobs."
+ port = _find_free_port()
+ dist_url = f"tcp://127.0.0.1:{port}"
+ if num_machines > 1 and dist_url.startswith("file://"):
+ logger = logging.getLogger(__name__)
+ logger.warning(
+ "file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
+ )
+
+ mp.spawn(
+ _distributed_worker,
+ nprocs=num_gpus_per_machine,
+ args=(
+ main_func,
+ world_size,
+ num_gpus_per_machine,
+ machine_rank,
+ dist_url,
+ cfg,
+ timeout,
+ ),
+ daemon=False,
+ )
+ else:
+ main_func(*cfg)
+
+
+def _distributed_worker(
+ local_rank,
+ main_func,
+ world_size,
+ num_gpus_per_machine,
+ machine_rank,
+ dist_url,
+ cfg,
+ timeout=DEFAULT_TIMEOUT,
+):
+ assert (
+ torch.cuda.is_available()
+ ), "cuda is not available. Please check your installation."
+ global_rank = machine_rank * num_gpus_per_machine + local_rank
+ try:
+ dist.init_process_group(
+ backend="NCCL",
+ init_method=dist_url,
+ world_size=world_size,
+ rank=global_rank,
+ timeout=timeout,
+ )
+ except Exception as e:
+ logger = logging.getLogger(__name__)
+ logger.error("Process group URL: {}".format(dist_url))
+ raise e
+
+ # Setup the local process group (which contains ranks within the same machine)
+ assert comm._LOCAL_PROCESS_GROUP is None
+ num_machines = world_size // num_gpus_per_machine
+ for i in range(num_machines):
+ ranks_on_i = list(
+ range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine)
+ )
+ pg = dist.new_group(ranks_on_i)
+ if i == machine_rank:
+ comm._LOCAL_PROCESS_GROUP = pg
+
+ assert num_gpus_per_machine <= torch.cuda.device_count()
+ torch.cuda.set_device(local_rank)
+
+ # synchronize is needed here to prevent a possible timeout after calling init_process_group
+ # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
+ comm.synchronize()
+
+ main_func(*cfg)
diff --git a/UniRig/src/model/pointcept/engines/train.py b/UniRig/src/model/pointcept/engines/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..66964b7774eafa64714806b3988abaa775de403f
--- /dev/null
+++ b/UniRig/src/model/pointcept/engines/train.py
@@ -0,0 +1,492 @@
+import os
+import sys
+import weakref
+import torch
+ # force=True avoids a RuntimeError when the start method was already set
+ torch.multiprocessing.set_start_method('spawn', force=True)
+import torch.nn as nn
+import torch.utils.data
+from functools import partial
+
+if sys.version_info >= (3, 10):
+ from collections.abc import Iterator
+else:
+ from collections import Iterator
+from tensorboardX import SummaryWriter
+
+from .defaults import create_ddp_model, worker_init_fn
+from .hooks import HookBase, build_hooks
+import pointcept.utils.comm as comm
+from pointcept.datasets import build_dataset, point_collate_fn, collate_fn
+from pointcept.models import build_model
+from pointcept.utils.logger import get_root_logger
+from pointcept.utils.optimizer import build_optimizer
+from pointcept.utils.scheduler import build_scheduler
+from pointcept.utils.events import EventStorage
+from pointcept.utils.registry import Registry
+
+from sklearn.preprocessing import QuantileTransformer
+from pointcept.utils.timer import Timer
+
+TRAINERS = Registry("trainers")
+from cuml.cluster.hdbscan import HDBSCAN
+# from sklearn.cluster import HDBSCAN
+import open3d as o3d
+import matplotlib.colors as mcolors
+import numpy as np
+from collections import OrderedDict
+import trimesh
+import pointops
+
+class TrainerBase:
+ def __init__(self) -> None:
+ self.hooks = []
+ self.epoch = 0
+ self.start_epoch = 0
+ self.max_epoch = 0
+ self.max_iter = 0
+ self.comm_info = dict()
+ self.data_iterator: Iterator = enumerate([])
+ self.storage: EventStorage
+ self.writer: SummaryWriter
+ self._iter_timer = Timer()
+
+ def register_hooks(self, hooks) -> None:
+ hooks = build_hooks(hooks)
+ for h in hooks:
+ assert isinstance(h, HookBase)
+ # To avoid circular reference, hooks and trainer cannot own each other.
+ # This normally does not matter, but will cause memory leak if the
+ # involved objects contain __del__:
+ # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
+ h.trainer = weakref.proxy(self)
+ self.hooks.extend(hooks)
+
+ def train(self):
+ with EventStorage() as self.storage:
+ # => before train
+ self.before_train()
+ for self.epoch in range(self.start_epoch, self.max_epoch):
+ # => before epoch
+ self.before_epoch()
+ # => run_epoch
+ for (
+ self.comm_info["iter"],
+ self.comm_info["input_dict"],
+ ) in self.data_iterator:
+ # => before_step
+ self.before_step()
+ # => run_step
+ self.run_step()
+ # => after_step
+ self.after_step()
+ # => after epoch
+ self.after_epoch()
+ # => after train
+ self.after_train()
+
+ def before_train(self):
+ for h in self.hooks:
+ h.before_train()
+
+ def before_epoch(self):
+ for h in self.hooks:
+ h.before_epoch()
+
+ def before_step(self):
+ for h in self.hooks:
+ h.before_step()
+
+ def run_step(self):
+ raise NotImplementedError
+
+ def after_step(self):
+ for h in self.hooks:
+ h.after_step()
+
+ def after_epoch(self):
+ for h in self.hooks:
+ h.after_epoch()
+ self.storage.reset_histories()
+
+ def after_train(self):
+ # Sync GPU before running train hooks
+ comm.synchronize()
+ for h in self.hooks:
+ h.after_train()
+ if comm.is_main_process():
+ self.writer.close()
+
+
+@TRAINERS.register_module("DefaultTrainer")
+class Trainer(TrainerBase):
+ def __init__(self, cfg):
+ super(Trainer, self).__init__()
+ self.epoch = 0
+ self.start_epoch = 0
+ self.max_epoch = cfg.eval_epoch
+ self.best_metric_value = -torch.inf
+ self.logger = get_root_logger(
+ log_file=os.path.join(cfg.save_path, "train.log"),
+ # file_mode="a" if cfg.resume else "w",
+ file_mode="a",
+ )
+ self.logger.info("=> Loading config ...")
+ self.cfg = cfg
+ self.logger.info(f"Save path: {cfg.save_path}")
+ self.logger.info(f"Config:\n{cfg.pretty_text}")
+ self.logger.info("=> Building model ...")
+ self.model = self.build_model()
+ self.logger.info("=> Building writer ...")
+ self.writer = self.build_writer()
+ self.logger.info("=> Building train dataset & dataloader ...")
+ self.train_loader = self.build_train_loader()
+ # self.logger.info("=> Building val dataset & dataloader ...")
+ # self.val_loader = self.build_val_loader()
+ self.logger.info("=> Building optimize, scheduler, scaler(amp) ...")
+ self.optimizer = self.build_optimizer()
+ self.scheduler = self.build_scheduler()
+ self.scaler = self.build_scaler()
+ self.logger.info("=> Building hooks ...")
+ self.register_hooks(self.cfg.hooks)
+
+ # !!!
+ self.model.scale_statistics = nn.Parameter(self.train_loader.scale_3d_statistics)
+ self.model.scale_statistics.requires_grad = False
+ self.model.quantile_transformer = self._get_quantile_func(self.train_loader.scale_3d_statistics)
+ # print(id(self.model))
+ # self.val_scales_list = [0.0, 0.5, 1.0, 1.5, 2.0]
+ self.val_scales_list = self.cfg.val_scales_list
+ self.mesh_voting = self.cfg.mesh_voting
+ self.backbone_weight_path = self.cfg.backbone_weight_path
+
+ def train(self):
+ with EventStorage() as self.storage:
+ # => before train
+ # self.before_train()
+ self.logger.info(">>>>>>>>>>>>>>>> Start Training >>>>>>>>>>>>>>>>")
+ # data_time = self._iter_timer.seconds()
+ # self.trainer.storage.put_scalar("data_time", data_time)
+ # !!! load checkpoint
+ if self.backbone_weight_path is not None:
+ self.logger.info("=> Loading checkpoint & weight ...")
+ if os.path.isfile(self.backbone_weight_path):
+ checkpoint = torch.load(
+ self.backbone_weight_path,
+ map_location=lambda storage, loc: storage.cuda(),
+ )
+ weight = OrderedDict()
+ for key, value in checkpoint["state_dict"].items():
+ if not key.startswith("module."):
+ if comm.get_world_size() > 1:
+ key = "module." + key # xxx.xxx -> module.xxx.xxx
+ # Now every key contains "module.", whether or not DDP is used.
+ # if self.keywords in key:
+ # key = key.replace(self.keywords, self.replacement)
+ if comm.get_world_size() == 1:
+ key = key[7:] # module.xxx.xxx -> xxx.xxx
+ # if key.startswith("backbone."):
+ # key = key[9:] # backbone.xxx.xxx -> xxx.xxx
+ key = "backbone." + key # xxx.xxx -> backbone.xxx.xxx
+ weight[key] = value
+ load_state_info = self.model.load_state_dict(weight, strict=False)
+ self.logger.info(f"Missing keys: {load_state_info[0]}")
+ else:
+ self.logger.info(f"No weight found at: {self.backbone_weight_path}")
+
+ for self.epoch in range(self.start_epoch, self.max_epoch):
+ self.model.train()
+ loss_dict = self.model(self.train_loader.get_data(0))
+ loss = loss_dict["instance_loss"]
+
+ # !!! writer
+ lr = self.optimizer.state_dict()["param_groups"][0]["lr"]
+ self.writer.add_scalar("lr", lr, self.epoch)
+ for key in loss_dict.keys():
+ self.writer.add_scalar(
+ "train/" + key,
+ loss_dict[key].item(),
+ self.epoch,
+ )
+ if self.epoch % 10 == 0:
+ self.logger.info(
+ f"iter: {self.epoch}, total_loss: {loss.item()}, loss_1: {loss_dict['instance_loss_1'].item()}, loss_2: {loss_dict['instance_loss_2'].item()}, loss_3: {loss_dict['instance_loss_3'].item()}"
+ )
+
+ # !!! optimizer
+ self.optimizer.zero_grad()
+ if self.cfg.enable_amp:
+ self.scaler.scale(loss).backward()
+ self.scaler.step(self.optimizer)
+
+ # When AMP is enabled, optimizer.step() is skipped if the loss scale is too
+ # large; stepping the scheduler only when the scale did not shrink avoids
+ # the torch warning about scheduler.step() before optimizer.step().
+ scaler = self.scaler.get_scale()
+ self.scaler.update()
+ if scaler <= self.scaler.get_scale():
+ self.scheduler.step()
+ else:
+ loss.backward()
+ self.optimizer.step()
+ self.scheduler.step()
+
+ # !!! save checkpoint
+ if (self.epoch + 1) % 5000 == 0:
+ filename = os.path.join(self.cfg.save_path, "model", f"{str(self.epoch + 1)}.pth")
+ self.logger.info("Saving checkpoint to: " + filename)
+ torch.save(
+ {
+ "epoch": self.epoch + 1,
+ "state_dict": self.model.state_dict(),
+ },
+ filename + ".tmp",
+ )
+ os.replace(filename + ".tmp", filename)
+ self.eval()
+
+ def eval(self):
+ # val_data = build_dataset(self.cfg.data.val)
+ self.logger.info("=> Loading checkpoint & weight ...")
+ if self.cfg.weight and os.path.isfile(self.cfg.weight):
+ checkpoint = torch.load(
+ self.cfg.weight,
+ map_location=lambda storage, loc: storage.cuda(),
+ )
+ load_state_info = self.model.load_state_dict(checkpoint["state_dict"])
+ self.logger.info(f"Missing keys: {load_state_info[0]}")
+ else:
+ self.logger.info(f"No weight found at: {self.cfg.weight}")
+ self.cfg.weight = "last"
+
+ self.model.eval()
+ save_root = os.path.join(self.cfg.save_path, "vis_pcd", os.path.splitext(os.path.basename(self.cfg.weight))[0])
+ os.makedirs(save_root, exist_ok=True)
+ group_save_root = os.path.join(self.cfg.save_path, "results", os.path.splitext(os.path.basename(self.cfg.weight))[0])
+ os.makedirs(group_save_root, exist_ok=True)
+
+ hex_colors = list(mcolors.CSS4_COLORS.values())
+ rgb_colors = np.array([mcolors.to_rgb(color) for color in hex_colors if color not in ['#000000', '#FFFFFF']])
+ def relative_luminance(color):
+ return 0.2126 * color[0] + 0.7152 * color[1] + 0.0722 * color[2]
+ rgb_colors = [color for color in rgb_colors if (relative_luminance(color) > 0.4 and relative_luminance(color) < 0.8)]
+ np.random.shuffle(rgb_colors)
+ input_dict = self.train_loader.val_data()
+
+ pcd_inverse = self.train_loader.pcd_inverse
+ if self.mesh_voting:
+ mesh = trimesh.load(self.train_loader.mesh_path)
+ if isinstance(mesh, trimesh.Scene):
+ mesh = mesh.dump(concatenate=True)
+ mesh.visual = trimesh.visual.ColorVisuals(mesh=mesh)
+
+ for scale in self.val_scales_list:
+ input_dict["scale"] = scale
+ instance_feat = self.model(input_dict).cpu().detach().numpy()
+
+ clusterer = HDBSCAN(
+ cluster_selection_epsilon=0.1,
+ min_samples=30,
+ min_cluster_size=30,
+ allow_single_cluster=False,
+ ).fit(instance_feat)
+
+ labels = clusterer.labels_
+ invalid_label_mask = labels == -1
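+ # HDBSCAN labels outliers as -1; below, each outlier inherits the label of
+ # its nearest clustered neighbor (1-NN) so that every point is grouped
+ # before visualization and mesh voting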
+ if invalid_label_mask.sum() > 0:
+ if invalid_label_mask.sum() == len(invalid_label_mask):
+ labels = np.zeros_like(labels)
+ else:
+ coord = input_dict["obj"]["coord"].cuda().contiguous().float()
+ valid_coord = coord[~invalid_label_mask]
+ valid_offset = torch.tensor(valid_coord.shape[0]).cuda()
+ invalid_coord = coord[invalid_label_mask]
+ invalid_offset = torch.tensor(invalid_coord.shape[0]).cuda()
+ indices, distances = pointops.knn_query(1, valid_coord, valid_offset, invalid_coord, invalid_offset)
+ indices = indices[:, 0].cpu().numpy()
+ labels[invalid_label_mask] = labels[~invalid_label_mask][indices]
+
+ # np.save(os.path.join(group_save_root, f"{str(scale)}.npy"), labels)
+ save_path = os.path.join(save_root, f"{str(scale)}.ply")
+ coord = input_dict["obj"]["coord"].cpu().numpy()
+ random_color = []
+ for i in range(max(labels) + 1):
+ random_color.append(rgb_colors[i % len(rgb_colors)])
+ random_color.append(np.array([0, 0, 0]))
+ color = [random_color[i] for i in labels]
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(coord)
+ pcd.colors = o3d.utility.Vector3dVector(color)
+ o3d.io.write_point_cloud(save_path, pcd)
+
+ labels = labels[pcd_inverse]
+
+ # print(len(clusterer.labels_))
+ self.logger.info(f"scale_{scale} has {max(labels)+1} groups")
+ # print(min(clusterer.labels_))
+ if self.mesh_voting:
+ face_index = self.train_loader.face_index
+ face_index = face_index[pcd_inverse]
+
+ # Compute votes for each face using NumPy's bincount function
+ # labels = clusterer.labels_
+ num_faces = len(mesh.faces)
+ num_labels = max(labels) + 1
+ votes = np.zeros((num_faces, num_labels), dtype=np.int32)
+ np.add.at(votes, (face_index, labels), 1)
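+ # np.add.at accumulates without buffering, so a face hit by several sampled
+ # points collects one vote per point; e.g. face_index [2, 2, 5] with
+ # labels [1, 1, 0] yields votes[2, 1] == 2 and votes[5, 0] == 1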
+
+ # Find the label with most votes for each face using NumPy's argmax function
+ max_votes_labels = np.argmax(votes, axis=1)
+ # Set the label to -1 for faces that have no corresponding points
+ max_votes_labels[np.all(votes == 0, axis=1)] = -1
+
+ valid_mask = max_votes_labels != -1
+ face_centroids = mesh.triangles_center
+ coord = torch.tensor(face_centroids).cuda().contiguous().float()
+ valid_coord = coord[valid_mask]
+ valid_offset = torch.tensor(valid_coord.shape[0]).cuda()
+ invalid_coord = coord[~valid_mask]
+ invalid_offset = torch.tensor(invalid_coord.shape[0]).cuda()
+ indices, distances = pointops.knn_query(1, valid_coord, valid_offset, invalid_coord, invalid_offset)
+ # # the first column is the point itself
+ # indices = indices[:, 1].cpu().numpy()
+ indices = indices[:, 0].cpu().numpy()
+ mesh_group = max_votes_labels.copy()
+ mesh_group[~valid_mask] = mesh_group[valid_mask][indices]
+
+ np.save(os.path.join(group_save_root, f"mesh_{str(scale)}.npy"), mesh_group)
+
+ # Assign color to each face based on the label with most votes
+ for face, label in enumerate(mesh_group):
+ color = (random_color[label] * 255).astype(np.uint8)
+ color_with_alpha = np.append(color, 255) # Add alpha value
+ mesh.visual.face_colors[face] = color_with_alpha
+
+ # Save the new mesh
+ mesh_save_path = os.path.join(save_root, f"mesh_{str(scale)}.ply")
+ mesh.export(mesh_save_path)
+
+
+ def _get_quantile_func(self, scales: torch.Tensor, distribution="normal"):
+ """
+ Use 3D scale statistics to normalize scales -- use quantile transformer.
+ """
+ scales = scales.flatten()
+ max_grouping_scale = 2
+ scales = scales[(scales > 0) & (scales < max_grouping_scale)]
+
+ scales = scales.detach().cpu().numpy()
+
+ # Calculate quantile transformer
+ quantile_transformer = QuantileTransformer(output_distribution=distribution)
+ quantile_transformer = quantile_transformer.fit(scales.reshape(-1, 1))
+
+ def quantile_transformer_func(scales):
+ # This function acts as a wrapper for QuantileTransformer.
+ # QuantileTransformer expects a numpy array, while we have a torch tensor.
+ return torch.Tensor(
+ quantile_transformer.transform(scales.cpu().numpy())
+ ).to(scales.device)
+
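+ # Illustrative example: with distribution="normal", raw 3D scales such as
+ # [0.02, 0.10, 0.50] are pushed through the fitted empirical CDF onto an
+ # approximately standard-normal axis, so the conditioning signal is
+ # well-spread rather than heavily skewed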
+ return quantile_transformer_func
+
+ def run_step(self):
+ input_dict = self.comm_info["input_dict"]
+ for key in input_dict.keys():
+ if isinstance(input_dict[key], torch.Tensor):
+ input_dict[key] = input_dict[key].cuda(non_blocking=True)
+ with torch.cuda.amp.autocast(enabled=self.cfg.enable_amp):
+ output_dict = self.model(input_dict)
+ loss = output_dict["loss"]
+ self.optimizer.zero_grad()
+ if self.cfg.enable_amp:
+ self.scaler.scale(loss).backward()
+ self.scaler.step(self.optimizer)
+
+ # When AMP is enabled, optimizer.step() is skipped if the loss scale is too
+ # large; stepping the scheduler only when the scale did not shrink avoids
+ # the torch warning about scheduler.step() before optimizer.step().
+ scaler = self.scaler.get_scale()
+ self.scaler.update()
+ if scaler <= self.scaler.get_scale():
+ self.scheduler.step()
+ else:
+ loss.backward()
+ self.optimizer.step()
+ self.scheduler.step()
+ if self.cfg.empty_cache:
+ torch.cuda.empty_cache()
+ self.comm_info["model_output_dict"] = output_dict
+
+ def build_model(self):
+ model = build_model(self.cfg.model)
+ if self.cfg.sync_bn:
+ model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ # logger.info(f"Model: \n{self.model}")
+ self.logger.info(f"Num params: {n_parameters}")
+ model = create_ddp_model(
+ model.cuda(),
+ broadcast_buffers=False,
+ find_unused_parameters=self.cfg.find_unused_parameters,
+ )
+ return model
+
+ def build_writer(self):
+ writer = SummaryWriter(self.cfg.save_path) if comm.is_main_process() else None
+ self.logger.info(f"Tensorboard writer logging dir: {self.cfg.save_path}")
+ return writer
+
+ def build_train_loader(self):
+ self.cfg.data.train.oid = self.cfg.oid
+ self.cfg.data.train.label = self.cfg.label
+ train_data = build_dataset(self.cfg.data.train)
+ return train_data
+
+ def build_val_loader(self):
+ val_loader = None
+ if self.cfg.evaluate:
+ val_data = build_dataset(self.cfg.data.val)
+ if comm.get_world_size() > 1:
+ val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
+ else:
+ val_sampler = None
+ val_loader = torch.utils.data.DataLoader(
+ val_data,
+ batch_size=self.cfg.batch_size_val_per_gpu,
+ shuffle=False,
+ num_workers=self.cfg.num_worker_per_gpu,
+ pin_memory=True,
+ sampler=val_sampler,
+ collate_fn=collate_fn,
+ )
+ return val_loader
+
+ def build_optimizer(self):
+ return build_optimizer(self.cfg.optimizer, self.model, self.cfg.param_dicts)
+
+ def build_scheduler(self):
+ assert hasattr(self, "optimizer")
+ assert hasattr(self, "train_loader")
+ # self.cfg.scheduler.total_steps = len(self.train_loader) * self.cfg.eval_epoch
+ self.cfg.scheduler.total_steps = self.max_epoch
+ return build_scheduler(self.cfg.scheduler, self.optimizer)
+
+ def build_scaler(self):
+ scaler = torch.cuda.amp.GradScaler() if self.cfg.enable_amp else None
+ return scaler
+
+
+@TRAINERS.register_module("MultiDatasetTrainer")
+class MultiDatasetTrainer(Trainer):
+ def build_train_loader(self):
+ from pointcept.datasets import MultiDatasetDataloader
+
+ train_data = build_dataset(self.cfg.data.train)
+ train_loader = MultiDatasetDataloader(
+ train_data,
+ self.cfg.batch_size_per_gpu,
+ self.cfg.num_worker_per_gpu,
+ self.cfg.mix_prob,
+ self.cfg.seed,
+ )
+ self.comm_info["iter_per_epoch"] = len(train_loader)
+ return train_loader
diff --git a/UniRig/src/model/pointcept/models/PTv3Object.py b/UniRig/src/model/pointcept/models/PTv3Object.py
new file mode 100644
index 0000000000000000000000000000000000000000..5006353c34bb95b1c266e420d960a9c288b2ade4
--- /dev/null
+++ b/UniRig/src/model/pointcept/models/PTv3Object.py
@@ -0,0 +1,664 @@
+from functools import partial
+from addict import Dict
+import math
+import torch
+import torch.nn as nn
+import spconv.pytorch as spconv
+import torch_scatter
+from timm.models.layers import DropPath
+from typing import Union
+from einops import rearrange
+
+try:
+ import flash_attn
+except ImportError:
+ flash_attn = None
+
+from .utils.misc import offset2bincount
+from .utils.structure import Point
+from .modules import PointModule, PointSequential
+
+
+class RPE(torch.nn.Module):
+ def __init__(self, patch_size, num_heads):
+ super().__init__()
+ self.patch_size = patch_size
+ self.num_heads = num_heads
+ self.pos_bnd = int((4 * patch_size) ** (1 / 3) * 2)
+ self.rpe_num = 2 * self.pos_bnd + 1
+ self.rpe_table = torch.nn.Parameter(torch.zeros(3 * self.rpe_num, num_heads))
+ torch.nn.init.trunc_normal_(self.rpe_table, std=0.02)
+
+ def forward(self, coord):
+ idx = (
+ coord.clamp(-self.pos_bnd, self.pos_bnd) # clamp into bnd
+ + self.pos_bnd # relative position to positive index
+ + torch.arange(3, device=coord.device) * self.rpe_num # x, y, z stride
+ )
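+ # idx flattens per-axis relative offsets into rows of rpe_table: axis k
+ # (x, y, z) owns rows [k * rpe_num, (k + 1) * rpe_num), and a clamped
+ # offset d in [-pos_bnd, pos_bnd] selects row k * rpe_num + d + pos_bnd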
+ out = self.rpe_table.index_select(0, idx.reshape(-1))
+ out = out.view(idx.shape + (-1,)).sum(3)
+ out = out.permute(0, 3, 1, 2) # (N, K, K, H) -> (N, H, K, K)
+ return out
+
+class QueryKeyNorm(nn.Module):
+ def __init__(self, channels, num_heads):
+ super(QueryKeyNorm, self).__init__()
+ self.num_heads = num_heads
+ self.norm = nn.LayerNorm(channels // num_heads, elementwise_affine=False)
+
+ def forward(self, qkv):
+ H = self.num_heads
+ #qkv = qkv.reshape(-1, 3, H, qkv.shape[1] // H).permute(1, 0, 2, 3)
+ qkv = rearrange(qkv, 'N (S H Ch) -> S N H Ch', H=H, S=3)
+ q, k, v = qkv.unbind(dim=0)
+ # q, k, v: [N, H, C // H]
+ q_norm = self.norm(q)
+ k_norm = self.norm(k)
+
+ # qkv_norm: [3, N, H, C // H]
+ qkv_norm = torch.stack([q_norm, k_norm, v])
+ qkv_norm = rearrange(qkv_norm, 'S N H Ch -> N (S H Ch)')
+ return qkv_norm
+
+class SerializedAttention(PointModule):
+ def __init__(
+ self,
+ channels,
+ num_heads,
+ patch_size,
+ qkv_bias=True,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ order_index=0,
+ enable_rpe=False,
+ enable_flash=True,
+ upcast_attention=True,
+ upcast_softmax=True,
+ enable_qknorm=False,
+ ):
+ super().__init__()
+ assert channels % num_heads == 0, f"channels {channels} must be divisible by num_heads {num_heads}"
+ self.channels = channels
+ self.num_heads = num_heads
+ self.scale = qk_scale or (channels // num_heads) ** -0.5
+ self.order_index = order_index
+ self.upcast_attention = upcast_attention
+ self.upcast_softmax = upcast_softmax
+ self.enable_rpe = enable_rpe
+ self.enable_flash = enable_flash
+ self.enable_qknorm = enable_qknorm
+ if enable_qknorm:
+ self.qknorm = QueryKeyNorm(channels, num_heads)
+ else:
+ print("WARNING: enable_qknorm is False in PTv3Object and training may be fragile")
+ if enable_flash:
+ assert (
+ enable_rpe is False
+ ), "Set enable_rpe to False when enable Flash Attention"
+ assert (
+ upcast_attention is False
+ ), "Set upcast_attention to False when enable Flash Attention"
+ assert (
+ upcast_softmax is False
+ ), "Set upcast_softmax to False when enable Flash Attention"
+ assert flash_attn is not None, "Make sure flash_attn is installed."
+ self.patch_size = patch_size
+ self.attn_drop = attn_drop
+ else:
+ # when flash attention is disabled, we still do not want to use a mask;
+ # consequently, the effective patch size is set to
+ # min(patch_size_max, number of points)
+ self.patch_size_max = patch_size
+ self.patch_size = 0
+ self.attn_drop = torch.nn.Dropout(attn_drop)
+
+ self.qkv = torch.nn.Linear(channels, channels * 3, bias=qkv_bias)
+ self.proj = torch.nn.Linear(channels, channels)
+ self.proj_drop = torch.nn.Dropout(proj_drop)
+ self.softmax = torch.nn.Softmax(dim=-1)
+ self.rpe = RPE(patch_size, num_heads) if self.enable_rpe else None
+
+ @torch.no_grad()
+ def get_rel_pos(self, point, order):
+ K = self.patch_size
+ rel_pos_key = f"rel_pos_{self.order_index}"
+ if rel_pos_key not in point.keys():
+ grid_coord = point.grid_coord[order]
+ grid_coord = grid_coord.reshape(-1, K, 3)
+ point[rel_pos_key] = grid_coord.unsqueeze(2) - grid_coord.unsqueeze(1)
+ return point[rel_pos_key]
+
+ @torch.no_grad()
+ def get_padding_and_inverse(self, point):
+ pad_key = "pad"
+ unpad_key = "unpad"
+ cu_seqlens_key = "cu_seqlens_key"
+ if (
+ pad_key not in point.keys()
+ or unpad_key not in point.keys()
+ or cu_seqlens_key not in point.keys()
+ ):
+ offset = point.offset
+ bincount = offset2bincount(offset)
+ bincount_pad = (
+ torch.div(
+ bincount + self.patch_size - 1,
+ self.patch_size,
+ rounding_mode="trunc",
+ )
+ * self.patch_size
+ )
+ # only pad point when num of points larger than patch_size
+ mask_pad = bincount > self.patch_size
+ bincount_pad = ~mask_pad * bincount + mask_pad * bincount_pad
+ _offset = nn.functional.pad(offset, (1, 0))
+ _offset_pad = nn.functional.pad(torch.cumsum(bincount_pad, dim=0), (1, 0))
+ pad = torch.arange(_offset_pad[-1], device=offset.device)
+ unpad = torch.arange(_offset[-1], device=offset.device)
+ cu_seqlens = []
+ for i in range(len(offset)):
+ unpad[_offset[i] : _offset[i + 1]] += _offset_pad[i] - _offset[i]
+ if bincount[i] != bincount_pad[i]:
+ pad[
+ _offset_pad[i + 1]
+ - self.patch_size
+ + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
+ ] = pad[
+ _offset_pad[i + 1]
+ - 2 * self.patch_size
+ + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
+ - self.patch_size
+ ]
+ pad[_offset_pad[i] : _offset_pad[i + 1]] -= _offset_pad[i] - _offset[i]
+ cu_seqlens.append(
+ torch.arange(
+ _offset_pad[i],
+ _offset_pad[i + 1],
+ step=self.patch_size,
+ dtype=torch.int32,
+ device=offset.device,
+ )
+ )
+ point[pad_key] = pad
+ point[unpad_key] = unpad
+ point[cu_seqlens_key] = nn.functional.pad(
+ torch.concat(cu_seqlens), (0, 1), value=_offset_pad[-1]
+ )
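+ # Illustrative example: with patch_size=4 and per-cloud point counts [6, 3],
+ # the first cloud is padded to 8 by repeating points from its previous patch
+ # and the second (3 <= patch_size) is left unpadded, giving
+ # cu_seqlens = [0, 4, 8, 11] for the varlen attention kernel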
+ return point[pad_key], point[unpad_key], point[cu_seqlens_key]
+
+ def forward(self, point):
+ if not self.enable_flash:
+ self.patch_size = min(
+ offset2bincount(point.offset).min().tolist(), self.patch_size_max
+ )
+
+ H = self.num_heads
+ K = self.patch_size
+ C = self.channels
+
+ pad, unpad, cu_seqlens = self.get_padding_and_inverse(point)
+
+ order = point.serialized_order[self.order_index][pad]
+ inverse = unpad[point.serialized_inverse[self.order_index]]
+
+ # padding and reshape feat and batch for serialized point patch
+ qkv = self.qkv(point.feat)[order]
+ if self.enable_qknorm:
+ qkv = self.qknorm(qkv)
+
+ if not self.enable_flash:
+ # encode and reshape qkv: (N', K, 3, H, C') => (3, N', H, K, C')
+ q, k, v = (
+ qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0)
+ )
+ # attn
+ if self.upcast_attention:
+ q = q.float()
+ k = k.float()
+ attn = (q * self.scale) @ k.transpose(-2, -1) # (N', H, K, K)
+ if self.enable_rpe:
+ attn = attn + self.rpe(self.get_rel_pos(point, order))
+ if self.upcast_softmax:
+ attn = attn.float()
+ attn = self.softmax(attn)
+ attn = self.attn_drop(attn).to(qkv.dtype)
+ feat = (attn @ v).transpose(1, 2).reshape(-1, C)
+ else:
+ feat = flash_attn.flash_attn_varlen_qkvpacked_func(
+ qkv.half().reshape(-1, 3, H, C // H),
+ cu_seqlens,
+ max_seqlen=self.patch_size,
+ dropout_p=self.attn_drop if self.training else 0,
+ softmax_scale=self.scale,
+ ).reshape(-1, C)
+ feat = feat.to(qkv.dtype)
+ feat = feat[inverse]
+
+ # ffn
+ feat = self.proj(feat)
+ feat = self.proj_drop(feat)
+ point.feat = feat
+ return point
+
+
+class MLP(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ hidden_channels=None,
+ out_channels=None,
+ act_layer=nn.GELU,
+ drop=0.0,
+ ):
+ super().__init__()
+ out_channels = out_channels or in_channels
+ hidden_channels = hidden_channels or in_channels
+ self.fc1 = nn.Linear(in_channels, hidden_channels)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_channels, out_channels)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Block(PointModule):
+ def __init__(
+ self,
+ channels,
+ num_heads,
+ patch_size=48,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ drop_path=0.0,
+ norm_layer=nn.LayerNorm,
+ act_layer=nn.GELU,
+ pre_norm=True,
+ order_index=0,
+ cpe_indice_key=None,
+ enable_rpe=False,
+ enable_flash=True,
+ upcast_attention=True,
+ upcast_softmax=True,
+ enable_qknorm=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.pre_norm = pre_norm
+
+ self.cpe = PointSequential(
+ spconv.SubMConv3d(
+ channels,
+ channels,
+ kernel_size=3,
+ bias=True,
+ indice_key=cpe_indice_key,
+ ),
+ nn.Linear(channels, channels),
+ norm_layer(channels),
+ )
+
+ self.norm1 = PointSequential(norm_layer(channels))
+ self.attn = SerializedAttention(
+ channels=channels,
+ patch_size=patch_size,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=proj_drop,
+ order_index=order_index,
+ enable_rpe=enable_rpe,
+ enable_flash=enable_flash,
+ upcast_attention=upcast_attention,
+ upcast_softmax=upcast_softmax,
+ enable_qknorm=enable_qknorm,
+ )
+ self.norm2 = PointSequential(norm_layer(channels))
+ self.mlp = PointSequential(
+ MLP(
+ in_channels=channels,
+ hidden_channels=int(channels * mlp_ratio),
+ out_channels=channels,
+ act_layer=act_layer,
+ drop=proj_drop,
+ )
+ )
+ self.drop_path = PointSequential(
+ DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ )
+
+ def forward(self, point: Point):
+ shortcut = point.feat
+ point = self.cpe(point)
+ point.feat = shortcut + point.feat
+ shortcut = point.feat
+ if self.pre_norm:
+ point = self.norm1(point)
+ point = self.drop_path(self.attn(point))
+ point.feat = shortcut + point.feat
+ if not self.pre_norm:
+ point = self.norm1(point)
+
+ shortcut = point.feat
+ if self.pre_norm:
+ point = self.norm2(point)
+ point = self.drop_path(self.mlp(point))
+ point.feat = shortcut + point.feat
+ if not self.pre_norm:
+ point = self.norm2(point)
+ # point.sparse_conv_feat.replace_feature(point.feat)
+ point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat)
+ return point
+
+
+class SerializedPooling(PointModule):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ stride=2,
+ norm_layer=None,
+ act_layer=None,
+ reduce="max",
+ shuffle_orders=True,
+ traceable=True, # record parent and cluster
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ assert stride == 2 ** (math.ceil(stride) - 1).bit_length() # 2, 4, 8
+ # TODO: add support for grid pooling (arbitrary stride)
+ self.stride = stride
+ assert reduce in ["sum", "mean", "min", "max"]
+ self.reduce = reduce
+ self.shuffle_orders = shuffle_orders
+ self.traceable = traceable
+
+ self.proj = nn.Linear(in_channels, out_channels)
+ if norm_layer is not None:
+ self.norm = PointSequential(norm_layer(out_channels))
+ if act_layer is not None:
+ self.act = PointSequential(act_layer())
+
+ def forward(self, point: Point):
+ pooling_depth = (math.ceil(self.stride) - 1).bit_length()
+ if pooling_depth > point.serialized_depth:
+ pooling_depth = 0
+ assert {
+ "serialized_code",
+ "serialized_order",
+ "serialized_inverse",
+ "serialized_depth",
+ }.issubset(
+ point.keys()
+ ), "Run point.serialization() point cloud before SerializedPooling"
+
+ code = point.serialized_code >> pooling_depth * 3
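+ # dropping 3 bits per pooling level coarsens the space-filling-curve code,
+ # so points sharing the coarser code (one voxel spanning 2**pooling_depth
+ # cells per axis) collapse into a single cluster via torch.unique below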
+ code_, cluster, counts = torch.unique(
+ code[0],
+ sorted=True,
+ return_inverse=True,
+ return_counts=True,
+ )
+ # indices of point sorted by cluster, for torch_scatter.segment_csr
+ _, indices = torch.sort(cluster)
+ # index pointer for sorted point, for torch_scatter.segment_csr
+ idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
+ # head_indices of each cluster, for reduce attr e.g. code, batch
+ head_indices = indices[idx_ptr[:-1]]
+ # generate down code, order, inverse
+ code = code[:, head_indices]
+ order = torch.argsort(code)
+ inverse = torch.zeros_like(order).scatter_(
+ dim=1,
+ index=order,
+ src=torch.arange(0, code.shape[1], device=order.device).repeat(
+ code.shape[0], 1
+ ),
+ )
+
+ if self.shuffle_orders:
+ perm = torch.randperm(code.shape[0])
+ code = code[perm]
+ order = order[perm]
+ inverse = inverse[perm]
+
+ # collect information
+ point_dict = Dict(
+ feat=torch_scatter.segment_csr(
+ self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce
+ ),
+ coord=torch_scatter.segment_csr(
+ point.coord[indices], idx_ptr, reduce="mean"
+ ),
+ grid_coord=point.grid_coord[head_indices] >> pooling_depth,
+ serialized_code=code,
+ serialized_order=order,
+ serialized_inverse=inverse,
+ serialized_depth=point.serialized_depth - pooling_depth,
+ batch=point.batch[head_indices],
+ )
+
+ if "condition" in point.keys():
+ point_dict["condition"] = point.condition
+ if "context" in point.keys():
+ point_dict["context"] = point.context
+
+ if self.traceable:
+ point_dict["pooling_inverse"] = cluster
+ point_dict["pooling_parent"] = point
+ point = Point(point_dict)
+ if self.norm is not None:
+ point = self.norm(point)
+ if self.act is not None:
+ point = self.act(point)
+ point.sparsify()
+ return point
+
+
+class SerializedUnpooling(PointModule):
+ def __init__(
+ self,
+ in_channels,
+ skip_channels,
+ out_channels,
+ norm_layer=None,
+ act_layer=None,
+ traceable=False, # record parent and cluster
+ ):
+ super().__init__()
+ self.proj = PointSequential(nn.Linear(in_channels, out_channels))
+ self.proj_skip = PointSequential(nn.Linear(skip_channels, out_channels))
+
+ if norm_layer is not None:
+ self.proj.add(norm_layer(out_channels))
+ self.proj_skip.add(norm_layer(out_channels))
+
+ if act_layer is not None:
+ self.proj.add(act_layer())
+ self.proj_skip.add(act_layer())
+
+ self.traceable = traceable
+
+ def forward(self, point):
+ assert "pooling_parent" in point.keys()
+ assert "pooling_inverse" in point.keys()
+ parent = point.pop("pooling_parent")
+ inverse = point.pop("pooling_inverse")
+ point = self.proj(point)
+ parent = self.proj_skip(parent)
+ parent.feat = parent.feat + point.feat[inverse]
+
+ if self.traceable:
+ parent["unpooling_parent"] = point
+ return parent
+
+
+class Embedding(PointModule):
+ def __init__(
+ self,
+ in_channels,
+ embed_channels,
+ norm_layer=None,
+ act_layer=None,
+ res_linear=False,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.embed_channels = embed_channels
+
+ # TODO: check remove spconv
+ self.stem = PointSequential(
+ conv=spconv.SubMConv3d(
+ in_channels,
+ embed_channels,
+ kernel_size=5,
+ padding=1,
+ bias=False,
+ indice_key="stem",
+ )
+ )
+ if norm_layer is not None:
+ self.stem.add(norm_layer(embed_channels), name="norm")
+ if act_layer is not None:
+ self.stem.add(act_layer(), name="act")
+
+ if res_linear:
+ self.res_linear = nn.Linear(in_channels, embed_channels)
+ else:
+ self.res_linear = None
+
+ def forward(self, point: Point):
+ if self.res_linear:
+ res_feature = self.res_linear(point.feat)
+ point = self.stem(point)
+ if self.res_linear:
+ point.feat = point.feat + res_feature
+ point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat)
+ return point
+
+
+class PointTransformerV3Object(PointModule):
+ def __init__(
+ self,
+ in_channels=9,
+ order=("z", "z-trans", "hilbert", "hilbert-trans"),
+ stride=(),
+ enc_depths=(3, 3, 3, 6, 16),
+ enc_channels=(32, 64, 128, 256, 384),
+ enc_num_head=(2, 4, 8, 16, 24),
+ enc_patch_size=(1024, 1024, 1024, 1024, 1024),
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ drop_path=0.0,
+ pre_norm=True,
+ shuffle_orders=True,
+ enable_rpe=False,
+ enable_flash=True,
+ upcast_attention=False,
+ upcast_softmax=False,
+ cls_mode=False,
+ enable_qknorm=False,
+ layer_norm=False,
+ res_linear=True,
+ ):
+ super().__init__()
+ self.num_stages = len(enc_depths)
+ self.order = [order] if isinstance(order, str) else order
+ self.cls_mode = cls_mode
+ self.shuffle_orders = shuffle_orders
+
+ # norm layers
+        if layer_norm:
+            bn_layer = nn.LayerNorm
+        else:
+            print("WARNING: using BatchNorm1d in PointTransformerV3Object")
+            bn_layer = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01)
+ ln_layer = nn.LayerNorm
+ # activation layers
+ act_layer = nn.GELU
+
+ self.embedding = Embedding(
+ in_channels=in_channels,
+ embed_channels=enc_channels[0],
+ norm_layer=bn_layer,
+ act_layer=act_layer,
+ res_linear=res_linear,
+ )
+
+ # encoder
+ enc_drop_path = [
+ x.item() for x in torch.linspace(0, drop_path, sum(enc_depths))
+ ]
+ self.enc = PointSequential()
+ for s in range(self.num_stages):
+ enc_drop_path_ = enc_drop_path[
+ sum(enc_depths[:s]) : sum(enc_depths[: s + 1])
+ ]
+ enc = PointSequential()
+ if s > 0:
+ enc.add(nn.Linear(enc_channels[s - 1], enc_channels[s]))
+
+ for i in range(enc_depths[s]):
+ enc.add(
+ Block(
+ channels=enc_channels[s],
+ num_heads=enc_num_head[s],
+ patch_size=enc_patch_size[s],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=proj_drop,
+ drop_path=enc_drop_path_[i],
+ norm_layer=ln_layer,
+ act_layer=act_layer,
+ pre_norm=pre_norm,
+ order_index=i % len(self.order),
+ cpe_indice_key=f"stage{s}",
+ enable_rpe=enable_rpe,
+ enable_flash=enable_flash,
+ upcast_attention=upcast_attention,
+ upcast_softmax=upcast_softmax,
+ enable_qknorm=enable_qknorm,
+ ),
+ name=f"block{i}",
+ )
+ if len(enc) != 0:
+ self.enc.add(module=enc, name=f"enc{s}")
+
+ def forward(self, data_dict, min_coord=None):
+ point = Point(data_dict)
+ point.serialization(order=self.order, shuffle_orders=self.shuffle_orders, min_coord=min_coord)
+ point.sparsify()
+ point = self.embedding(point)
+ point = self.enc(point)
+ return point
+
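+# Minimal usage sketch (hedged: tensor shapes and values are illustrative):
+#   encoder = PointTransformerV3Object(in_channels=9)
+#   data_dict = {
+#       "coord": torch.rand(1024, 3),    # xyz positions
+#       "feat": torch.rand(1024, 9),     # must match in_channels
+#       "offset": torch.tensor([1024]),  # cumulative point counts per sample
+#       "grid_size": 0.01,               # voxel size used for serialization
+#   }
+#   point = encoder(data_dict)           # Point carrying the encoded .feat
+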
+def get_encoder(pretrained_path: Union[str, None]=None, freeze_encoder: bool=False, **kwargs) -> PointTransformerV3Object:
+ point_encoder = PointTransformerV3Object(**kwargs)
+ if pretrained_path is not None:
+ checkpoint = torch.load(pretrained_path)
+ state_dict = checkpoint["state_dict"]
+ state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
+ point_encoder.load_state_dict(state_dict, strict=False)
+    if freeze_encoder:
+ for name, param in point_encoder.named_parameters():
+ if 'res_linear' not in name and 'qknorm' not in name:
+ param.requires_grad = False
+ return point_encoder
\ No newline at end of file
diff --git a/UniRig/src/model/pointcept/models/__init__.py b/UniRig/src/model/pointcept/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7027e601f513c17e2e5c36a6491e928dbfe4b1a
--- /dev/null
+++ b/UniRig/src/model/pointcept/models/__init__.py
@@ -0,0 +1 @@
+from .PTv3Object import PointTransformerV3Object
diff --git a/UniRig/src/model/pointcept/models/modules.py b/UniRig/src/model/pointcept/models/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fb77634cf5f54759ee301d6cffb966ca0076396
--- /dev/null
+++ b/UniRig/src/model/pointcept/models/modules.py
@@ -0,0 +1,83 @@
+import sys
+import torch.nn as nn
+import spconv.pytorch as spconv
+from collections import OrderedDict
+from .utils.structure import Point
+
+
+class PointModule(nn.Module):
+ r"""PointModule
+    Placeholder; every module subclassing this consumes a Point when used in PointSequential.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+
+class PointSequential(PointModule):
+ r"""A sequential container.
+ Modules will be added to it in the order they are passed in the constructor.
+ Alternatively, an ordered dict of modules can also be passed in.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ if len(args) == 1 and isinstance(args[0], OrderedDict):
+ for key, module in args[0].items():
+ self.add_module(key, module)
+ else:
+ for idx, module in enumerate(args):
+ self.add_module(str(idx), module)
+ for name, module in kwargs.items():
+ if sys.version_info < (3, 6):
+ raise ValueError("kwargs only supported in py36+")
+ if name in self._modules:
+ raise ValueError("name exists.")
+ self.add_module(name, module)
+
+ def __getitem__(self, idx):
+ if not (-len(self) <= idx < len(self)):
+ raise IndexError("index {} is out of range".format(idx))
+ if idx < 0:
+ idx += len(self)
+ it = iter(self._modules.values())
+ for i in range(idx):
+ next(it)
+ return next(it)
+
+ def __len__(self):
+ return len(self._modules)
+
+ def add(self, module, name=None):
+ if name is None:
+ name = str(len(self._modules))
+ if name in self._modules:
+ raise KeyError("name exists")
+ self.add_module(name, module)
+
+ def forward(self, input):
+        for module in self._modules.values():
+ # Point module
+ if isinstance(module, PointModule):
+ input = module(input)
+ # Spconv module
+ elif spconv.modules.is_spconv_module(module):
+ if isinstance(input, Point):
+ input.sparse_conv_feat = module(input.sparse_conv_feat)
+ input.feat = input.sparse_conv_feat.features
+ else:
+ input = module(input)
+ # PyTorch module
+ else:
+ if isinstance(input, Point):
+ input.feat = module(input.feat)
+ if "sparse_conv_feat" in input.keys():
+ input.sparse_conv_feat = input.sparse_conv_feat.replace_feature(
+ input.feat
+ )
+ elif isinstance(input, spconv.SparseConvTensor):
+ if input.indices.shape[0] != 0:
+ input = input.replace_feature(module(input.features))
+ else:
+ input = module(input)
+ return input
diff --git a/UniRig/src/model/pointcept/models/utils/__init__.py b/UniRig/src/model/pointcept/models/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..66e6bc0f62993abb3625a9598f54e7775aeb0008
--- /dev/null
+++ b/UniRig/src/model/pointcept/models/utils/__init__.py
@@ -0,0 +1,4 @@
+from .misc import offset2batch, offset2bincount, batch2offset, off_diagonal
+from .checkpoint import checkpoint
+from .serialization import encode, decode
+from .structure import Point
diff --git a/UniRig/src/model/pointcept/models/utils/checkpoint.py b/UniRig/src/model/pointcept/models/utils/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..09b6d307ccf5b6d79135eec9274c9c8a8c29432e
--- /dev/null
+++ b/UniRig/src/model/pointcept/models/utils/checkpoint.py
@@ -0,0 +1,50 @@
+import torch
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad():
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
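+
+
+# Usage sketch (illustrative; assumes a plain nn.Module):
+#   lin = torch.nn.Linear(4, 4)
+#   x = torch.randn(2, 4, requires_grad=True)
+#   y = checkpoint(lambda t: lin(t).relu(), (x,), tuple(lin.parameters()), True)
+#   y.sum().backward()  # activations are recomputed during the backward pass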
diff --git a/UniRig/src/model/pointcept/models/utils/misc.py b/UniRig/src/model/pointcept/models/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb1b282f480f23cba30644edf35477a9358d0e11
--- /dev/null
+++ b/UniRig/src/model/pointcept/models/utils/misc.py
@@ -0,0 +1,28 @@
+import torch
+
+
+@torch.inference_mode()
+def offset2bincount(offset):
+ return torch.diff(
+ offset, prepend=torch.tensor([0], device=offset.device, dtype=torch.long)
+ )
+
+
+@torch.inference_mode()
+def offset2batch(offset):
+ bincount = offset2bincount(offset)
+ return torch.arange(
+ len(bincount), device=offset.device, dtype=torch.long
+ ).repeat_interleave(bincount)
+
+
+@torch.inference_mode()
+def batch2offset(batch):
+ return torch.cumsum(batch.bincount(), dim=0).long()
+
+
+def off_diagonal(x):
+ # return a flattened view of the off-diagonal elements of a square matrix
+ n, m = x.shape
+ assert n == m
+ return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
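+
+
+# Example (illustrative): for two samples with 2 and 3 points,
+#   offset2batch(torch.tensor([2, 5])) -> tensor([0, 0, 1, 1, 1])
+#   batch2offset(torch.tensor([0, 0, 1, 1, 1])) -> tensor([2, 5])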
diff --git a/UniRig/src/model/pointcept/models/utils/serialization/__init__.py b/UniRig/src/model/pointcept/models/utils/serialization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..058c5e1001c76d9c7014bf0bbb824eec4f54f476
--- /dev/null
+++ b/UniRig/src/model/pointcept/models/utils/serialization/__init__.py
@@ -0,0 +1,8 @@
+from .default import (
+ encode,
+ decode,
+ z_order_encode,
+ z_order_decode,
+ hilbert_encode,
+ hilbert_decode,
+)
diff --git a/UniRig/src/model/pointcept/models/utils/serialization/default.py b/UniRig/src/model/pointcept/models/utils/serialization/default.py
new file mode 100644
index 0000000000000000000000000000000000000000..15898b55625fc0e1125db9b713e900892f04176c
--- /dev/null
+++ b/UniRig/src/model/pointcept/models/utils/serialization/default.py
@@ -0,0 +1,59 @@
+import torch
+from .z_order import xyz2key as z_order_encode_
+from .z_order import key2xyz as z_order_decode_
+from .hilbert import encode as hilbert_encode_
+from .hilbert import decode as hilbert_decode_
+
+
+@torch.inference_mode()
+def encode(grid_coord, batch=None, depth=16, order="z"):
+ assert order in {"z", "z-trans", "hilbert", "hilbert-trans"}
+ if order == "z":
+ code = z_order_encode(grid_coord, depth=depth)
+ elif order == "z-trans":
+ code = z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth)
+ elif order == "hilbert":
+ code = hilbert_encode(grid_coord, depth=depth)
+ elif order == "hilbert-trans":
+ code = hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth)
+ else:
+ raise NotImplementedError
+ if batch is not None:
+ batch = batch.long()
+ code = batch << depth * 3 | code
+ return code
+
+
+@torch.inference_mode()
+def decode(code, depth=16, order="z"):
+ assert order in {"z", "hilbert"}
+ batch = code >> depth * 3
+ code = code & ((1 << depth * 3) - 1)
+ if order == "z":
+ grid_coord = z_order_decode(code, depth=depth)
+ elif order == "hilbert":
+ grid_coord = hilbert_decode(code, depth=depth)
+ else:
+ raise NotImplementedError
+ return grid_coord, batch
+
+
+def z_order_encode(grid_coord: torch.Tensor, depth: int = 16):
+ x, y, z = grid_coord[:, 0].long(), grid_coord[:, 1].long(), grid_coord[:, 2].long()
+    # batch support is intentionally disabled here; batched codes are maintained in the Point class
+ code = z_order_encode_(x, y, z, b=None, depth=depth)
+ return code
+
+
+def z_order_decode(code: torch.Tensor, depth):
+ x, y, z = z_order_decode_(code, depth=depth)
+ grid_coord = torch.stack([x, y, z], dim=-1) # (N, 3)
+ return grid_coord
+
+
+def hilbert_encode(grid_coord: torch.Tensor, depth: int = 16):
+ return hilbert_encode_(grid_coord, num_dims=3, num_bits=depth)
+
+
+def hilbert_decode(code: torch.Tensor, depth: int = 16):
+ return hilbert_decode_(code, num_dims=3, num_bits=depth)
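+
+
+# Round-trip sketch (illustrative): for non-negative integer grid coordinates
+# below 2**depth, decode(encode(g, batch, depth, "z"), depth, "z") recovers
+# (g, batch). Note that decode only supports "z" and "hilbert"; the "-trans"
+# variants merely permute the input axes before encoding.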
diff --git a/UniRig/src/model/pointcept/models/utils/serialization/hilbert.py b/UniRig/src/model/pointcept/models/utils/serialization/hilbert.py
new file mode 100644
index 0000000000000000000000000000000000000000..79356577e12179a35341af9b430a1ba8e238b832
--- /dev/null
+++ b/UniRig/src/model/pointcept/models/utils/serialization/hilbert.py
@@ -0,0 +1,295 @@
+import torch
+
+
+def right_shift(binary, k=1, axis=-1):
+ """Right shift an array of binary values.
+
+ Parameters:
+ -----------
+ binary: An ndarray of binary values.
+
+ k: The number of bits to shift. Default 1.
+
+ axis: The axis along which to shift. Default -1.
+
+ Returns:
+ --------
+ Returns an ndarray with zero prepended and the ends truncated, along
+ whatever axis was specified."""
+
+ # If we're shifting the whole thing, just return zeros.
+ if binary.shape[axis] <= k:
+ return torch.zeros_like(binary)
+
+ # Determine the padding pattern.
+ # padding = [(0,0)] * len(binary.shape)
+ # padding[axis] = (k,0)
+
+ # Determine the slicing pattern to eliminate just the last one.
+ slicing = [slice(None)] * len(binary.shape)
+ slicing[axis] = slice(None, -k)
+ shifted = torch.nn.functional.pad(
+ binary[tuple(slicing)], (k, 0), mode="constant", value=0
+ )
+
+ return shifted
+
+
+def binary2gray(binary, axis=-1):
+ """Convert an array of binary values into Gray codes.
+
+ This uses the classic X ^ (X >> 1) trick to compute the Gray code.
+
+ Parameters:
+ -----------
+ binary: An ndarray of binary values.
+
+ axis: The axis along which to compute the gray code. Default=-1.
+
+ Returns:
+ --------
+ Returns an ndarray of Gray codes.
+ """
+ shifted = right_shift(binary, axis=axis)
+
+ # Do the X ^ (X >> 1) trick.
+ gray = torch.logical_xor(binary, shifted)
+
+ return gray
+
+
+def gray2binary(gray, axis=-1):
+ """Convert an array of Gray codes back into binary values.
+
+ Parameters:
+ -----------
+ gray: An ndarray of gray codes.
+
+ axis: The axis along which to perform Gray decoding. Default=-1.
+
+ Returns:
+ --------
+ Returns an ndarray of binary values.
+ """
+
+ # Loop the log2(bits) number of times necessary, with shift and xor.
+ shift = 2 ** (torch.Tensor([gray.shape[axis]]).log2().ceil().int() - 1)
+ while shift > 0:
+ gray = torch.logical_xor(gray, right_shift(gray, shift))
+ shift = torch.div(shift, 2, rounding_mode="floor")
+ return gray
+
+
+def encode(locs, num_dims, num_bits):
+    """Encode an array of locations in a hypercube into Hilbert integers.
+
+ This is a vectorized-ish version of the Hilbert curve implementation by John
+ Skilling as described in:
+
+ Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
+ Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.
+
+ Params:
+ -------
+ locs - An ndarray of locations in a hypercube of num_dims dimensions, in
+             which each dimension runs from 0 to 2**num_bits-1. The shape can
+             be arbitrary, as long as the last dimension has size num_dims.
+
+ num_dims - The dimensionality of the hypercube. Integer.
+
+ num_bits - The number of bits for each dimension. Integer.
+
+ Returns:
+ --------
+ The output is an ndarray of uint64 integers with the same shape as the
+ input, excluding the last dimension, which needs to be num_dims.
+ """
+
+ # Keep around the original shape for later.
+ orig_shape = locs.shape
+ bitpack_mask = 1 << torch.arange(0, 8).to(locs.device)
+ bitpack_mask_rev = bitpack_mask.flip(-1)
+
+ if orig_shape[-1] != num_dims:
+ raise ValueError(
+ """
+ The shape of locs was surprising in that the last dimension was of size
+ %d, but num_dims=%d. These need to be equal.
+ """
+ % (orig_shape[-1], num_dims)
+ )
+
+ if num_dims * num_bits > 63:
+ raise ValueError(
+ """
+ num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
+            into an int64. Are you sure you need that many points on your Hilbert
+ curve?
+ """
+ % (num_dims, num_bits, num_dims * num_bits)
+ )
+
+ # Treat the location integers as 64-bit unsigned and then split them up into
+ # a sequence of uint8s. Preserve the association by dimension.
+ locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1)
+
+ # Now turn these into bits and truncate to num_bits.
+ gray = (
+ locs_uint8.unsqueeze(-1)
+ .bitwise_and(bitpack_mask_rev)
+ .ne(0)
+ .byte()
+ .flatten(-2, -1)[..., -num_bits:]
+ )
+
+ # Run the decoding process the other way.
+ # Iterate forwards through the bits.
+ for bit in range(0, num_bits):
+ # Iterate forwards through the dimensions.
+ for dim in range(0, num_dims):
+ # Identify which ones have this bit active.
+ mask = gray[:, dim, bit]
+
+ # Where this bit is on, invert the 0 dimension for lower bits.
+ gray[:, 0, bit + 1 :] = torch.logical_xor(
+ gray[:, 0, bit + 1 :], mask[:, None]
+ )
+
+ # Where the bit is off, exchange the lower bits with the 0 dimension.
+ to_flip = torch.logical_and(
+ torch.logical_not(mask[:, None]).repeat(1, gray.shape[2] - bit - 1),
+ torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
+ )
+ gray[:, dim, bit + 1 :] = torch.logical_xor(
+ gray[:, dim, bit + 1 :], to_flip
+ )
+ gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)
+
+ # Now flatten out.
+ gray = gray.swapaxes(1, 2).reshape((-1, num_bits * num_dims))
+
+ # Convert Gray back to binary.
+ hh_bin = gray2binary(gray)
+
+ # Pad back out to 64 bits.
+ extra_dims = 64 - num_bits * num_dims
+ padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), "constant", 0)
+
+ # Convert binary values into uint8s.
+ hh_uint8 = (
+ (padded.flip(-1).reshape((-1, 8, 8)) * bitpack_mask)
+ .sum(2)
+ .squeeze()
+ .type(torch.uint8)
+ )
+
+ # Convert uint8s into uint64s.
+ hh_uint64 = hh_uint8.view(torch.int64).squeeze()
+
+ return hh_uint64
+
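+# Example (illustrative): encoding the 8 corners of a 2x2x2 cube with
+# num_dims=3, num_bits=1 yields each key in 0..7 exactly once, ordered along
+# the Hilbert curve; decode() below inverts the mapping.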
+
+def decode(hilberts, num_dims, num_bits):
+ """Decode an array of Hilbert integers into locations in a hypercube.
+
+ This is a vectorized-ish version of the Hilbert curve implementation by John
+ Skilling as described in:
+
+ Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
+ Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.
+
+ Params:
+ -------
+ hilberts - An ndarray of Hilbert integers. Must be an integer dtype and
+ cannot have fewer bits than num_dims * num_bits.
+
+ num_dims - The dimensionality of the hypercube. Integer.
+
+ num_bits - The number of bits for each dimension. Integer.
+
+ Returns:
+ --------
+ The output is an ndarray of unsigned integers with the same shape as hilberts
+ but with an additional dimension of size num_dims.
+ """
+
+ if num_dims * num_bits > 64:
+ raise ValueError(
+ """
+ num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
+ into a uint64. Are you sure you need that many points on your Hilbert
+ curve?
+ """
+            % (num_dims, num_bits, num_dims * num_bits)
+ )
+
+ # Handle the case where we got handed a naked integer.
+ hilberts = torch.atleast_1d(hilberts)
+
+ # Keep around the shape for later.
+ orig_shape = hilberts.shape
+ bitpack_mask = 2 ** torch.arange(0, 8).to(hilberts.device)
+ bitpack_mask_rev = bitpack_mask.flip(-1)
+
+    # Treat each of the hilberts as a sequence of eight uint8s.
+ # This treats all of the inputs as uint64 and makes things uniform.
+ hh_uint8 = (
+ hilberts.ravel().type(torch.int64).view(torch.uint8).reshape((-1, 8)).flip(-1)
+ )
+
+ # Turn these lists of uints into lists of bits and then truncate to the size
+ # we actually need for using Skilling's procedure.
+ hh_bits = (
+ hh_uint8.unsqueeze(-1)
+ .bitwise_and(bitpack_mask_rev)
+ .ne(0)
+ .byte()
+ .flatten(-2, -1)[:, -num_dims * num_bits :]
+ )
+
+ # Take the sequence of bits and Gray-code it.
+ gray = binary2gray(hh_bits)
+
+ # There has got to be a better way to do this.
+ # I could index them differently, but the eventual packbits likes it this way.
+ gray = gray.reshape((-1, num_bits, num_dims)).swapaxes(1, 2)
+
+ # Iterate backwards through the bits.
+ for bit in range(num_bits - 1, -1, -1):
+ # Iterate backwards through the dimensions.
+ for dim in range(num_dims - 1, -1, -1):
+ # Identify which ones have this bit active.
+ mask = gray[:, dim, bit]
+
+ # Where this bit is on, invert the 0 dimension for lower bits.
+ gray[:, 0, bit + 1 :] = torch.logical_xor(
+ gray[:, 0, bit + 1 :], mask[:, None]
+ )
+
+ # Where the bit is off, exchange the lower bits with the 0 dimension.
+ to_flip = torch.logical_and(
+ torch.logical_not(mask[:, None]),
+ torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
+ )
+ gray[:, dim, bit + 1 :] = torch.logical_xor(
+ gray[:, dim, bit + 1 :], to_flip
+ )
+ gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)
+
+ # Pad back out to 64 bits.
+ extra_dims = 64 - num_bits
+ padded = torch.nn.functional.pad(gray, (extra_dims, 0), "constant", 0)
+
+ # Now chop these up into blocks of 8.
+ locs_chopped = padded.flip(-1).reshape((-1, num_dims, 8, 8))
+
+    # Take those blocks and turn them into uint8s.
+ locs_uint8 = (locs_chopped * bitpack_mask).sum(3).squeeze().type(torch.uint8)
+
+ # Finally, treat these as uint64s.
+ flat_locs = locs_uint8.view(torch.int64)
+
+ # Return them in the expected shape.
+ return flat_locs.reshape((*orig_shape, num_dims))
diff --git a/UniRig/src/model/pointcept/models/utils/serialization/z_order.py b/UniRig/src/model/pointcept/models/utils/serialization/z_order.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0237c87c48c32d07fe96d9ecc72d91e1c464099
--- /dev/null
+++ b/UniRig/src/model/pointcept/models/utils/serialization/z_order.py
@@ -0,0 +1,119 @@
+import torch
+from typing import Optional, Union
+
+
+class KeyLUT:
+ def __init__(self):
+ r256 = torch.arange(256, dtype=torch.int64)
+ r512 = torch.arange(512, dtype=torch.int64)
+ zero = torch.zeros(256, dtype=torch.int64)
+ device = torch.device("cpu")
+
+ self._encode = {
+ device: (
+ self.xyz2key(r256, zero, zero, 8),
+ self.xyz2key(zero, r256, zero, 8),
+ self.xyz2key(zero, zero, r256, 8),
+ )
+ }
+ self._decode = {device: self.key2xyz(r512, 9)}
+
+ def encode_lut(self, device=torch.device("cpu")):
+ if device not in self._encode:
+ cpu = torch.device("cpu")
+ self._encode[device] = tuple(e.to(device) for e in self._encode[cpu])
+ return self._encode[device]
+
+ def decode_lut(self, device=torch.device("cpu")):
+ if device not in self._decode:
+ cpu = torch.device("cpu")
+ self._decode[device] = tuple(e.to(device) for e in self._decode[cpu])
+ return self._decode[device]
+
+ def xyz2key(self, x, y, z, depth):
+ key = torch.zeros_like(x)
+ for i in range(depth):
+ mask = 1 << i
+ key = (
+ key
+ | ((x & mask) << (2 * i + 2))
+ | ((y & mask) << (2 * i + 1))
+ | ((z & mask) << (2 * i + 0))
+ )
+ return key
+
+ def key2xyz(self, key, depth):
+ x = torch.zeros_like(key)
+ y = torch.zeros_like(key)
+ z = torch.zeros_like(key)
+ for i in range(depth):
+ x = x | ((key & (1 << (3 * i + 2))) >> (2 * i + 2))
+ y = y | ((key & (1 << (3 * i + 1))) >> (2 * i + 1))
+ z = z | ((key & (1 << (3 * i + 0))) >> (2 * i + 0))
+ return x, y, z
+
+
+_key_lut = KeyLUT()
+
+
+def xyz2key(
+ x: torch.Tensor,
+ y: torch.Tensor,
+ z: torch.Tensor,
+ b: Optional[Union[torch.Tensor, int]] = None,
+ depth: int = 16,
+):
+    r"""Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to shuffled keys
+    based on pre-computed look-up tables. This is much faster than a naive
+    for-loop implementation.
+
+ Args:
+ x (torch.Tensor): The x coordinate.
+ y (torch.Tensor): The y coordinate.
+ z (torch.Tensor): The z coordinate.
+ b (torch.Tensor or int): The batch index of the coordinates, and should be
+ smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of
+ :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`.
+        depth (int): The depth of the shuffled key; must be smaller than 17.
+ """
+
+ EX, EY, EZ = _key_lut.encode_lut(x.device)
+ x, y, z = x.long(), y.long(), z.long()
+
+ mask = 255 if depth > 8 else (1 << depth) - 1
+ key = EX[x & mask] | EY[y & mask] | EZ[z & mask]
+ if depth > 8:
+ mask = (1 << (depth - 8)) - 1
+ key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask]
+ key = key16 << 24 | key
+
+ if b is not None:
+ b = b.long()
+ key = b << 48 | key
+
+ return key
+
+
+def key2xyz(key: torch.Tensor, depth: int = 16):
+ r"""Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates
+ and the batch index based on pre-computed look up tables.
+
+ Args:
+ key (torch.Tensor): The shuffled key.
+        depth (int): The depth of the shuffled key; must be smaller than 17.
+ """
+
+ DX, DY, DZ = _key_lut.decode_lut(key.device)
+ x, y, z = torch.zeros_like(key), torch.zeros_like(key), torch.zeros_like(key)
+
+ b = key >> 48
+ key = key & ((1 << 48) - 1)
+
+ n = (depth + 2) // 3
+ for i in range(n):
+ k = key >> (i * 9) & 511
+ x = x | (DX[k] << (i * 3))
+ y = y | (DY[k] << (i * 3))
+ z = z | (DZ[k] << (i * 3))
+
+ return x, y, z, b
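+
+
+# Round-trip sketch (illustrative):
+#   x = y = z = torch.arange(8)
+#   key = xyz2key(x, y, z, b=None, depth=4)
+#   x2, y2, z2, b = key2xyz(key, depth=4)  # recovers x, y, z; b is all zeros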
diff --git a/UniRig/src/model/pointcept/models/utils/structure.py b/UniRig/src/model/pointcept/models/utils/structure.py
new file mode 100644
index 0000000000000000000000000000000000000000..afef3d7dbe2c3118e67d8eaa4a7e2e5198cb6a5e
--- /dev/null
+++ b/UniRig/src/model/pointcept/models/utils/structure.py
@@ -0,0 +1,210 @@
+import torch
+import spconv.pytorch as spconv
+
+try:
+ import ocnn
+except ImportError:
+ ocnn = None
+from addict import Dict
+
+from .serialization import encode, decode
+from ..utils import offset2batch, batch2offset
+import torch_scatter
+
+class Point(Dict):
+ """
+ Point Structure of Pointcept
+
+    A Point (point cloud) in Pointcept is a dictionary that contains various
+    properties of a batched point cloud. Properties with the following names
+    have specific definitions:
+
+    - "coord": original coordinates of the point cloud;
+    - "grid_coord": grid coordinates for a specific grid size (related to GridSampling);
+
+    Point also supports the following optional attributes:
+    - "offset": if absent, initialized assuming a batch size of 1;
+    - "batch": if absent, initialized assuming a batch size of 1;
+    - "feat": features of the point cloud, the default model input;
+    - "grid_size": grid size of the point cloud (related to GridSampling);
+    (related to Serialization)
+    - "serialized_depth": depth of serialization; 2 ** depth * grid_size describes the maximum point cloud range;
+ - "serialized_code": a list of serialization codes;
+ - "serialized_order": a list of serialization order determined by code;
+ - "serialized_inverse": a list of inverse mapping determined by code;
+ (related to Sparsify: SpConv)
+ - "sparse_shape": Sparse shape for Sparse Conv Tensor;
+ - "sparse_conv_feat": SparseConvTensor init with information provide by Point;
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ # If one of "offset" or "batch" do not exist, generate by the existing one
+ if "batch" not in self.keys() and "offset" in self.keys():
+ self["batch"] = offset2batch(self.offset)
+ elif "offset" not in self.keys() and "batch" in self.keys():
+ self["offset"] = batch2offset(self.batch)
+
+ def serialization(self, order="z", depth=None, shuffle_orders=False, min_coord=None):
+ """
+ Point Cloud Serialization
+
+        relies on ["grid_coord" or "coord" + "grid_size", "batch", "feat"]
+ """
+ assert "batch" in self.keys()
+        if "grid_coord" not in self.keys():
+            # If you don't want to apply GridSampling in data augmentation,
+            # add the following augmentation to your pipeline:
+            # dict(type="Copy", keys_dict={"grid_size": 0.01}),
+            # (adjust `grid_size` to what you want)
+ assert {"grid_size", "coord"}.issubset(self.keys())
+ idx_ptr = torch.nn.functional.pad(self.offset, (1, 0), value=0)
+ if min_coord is None:
+ min_coord = torch_scatter.segment_csr(self.coord, idx_ptr, reduce="min")
+ self["grid_coord"] = torch.div(
+ self.coord - min_coord[self.batch],
+ self.grid_size,
+ rounding_mode="trunc",
+ ).int()
+
+ if depth is None:
+ # Adaptive measure the depth of serialization cube (length = 2 ^ depth)
+ depth = int(self.grid_coord.max()).bit_length()
+ self["serialized_depth"] = depth
+ # Maximum bit length for serialization code is 63 (int64)
+ assert depth * 3 + len(self.offset).bit_length() <= 63
+        # Here we follow OCNN and limit the depth to 16 (48 bits) for the point position.
+        # Even with depth capped at 16, we can encode a 655.36^3 (2^16 * 0.01) meter^3
+        # cube at a grid size of 0.01 meter, which we consider sufficient for now.
+        # The limit could be lifted by optimizing the z-order encoding function if necessary.
+ assert depth <= 16
+
+ # The serialization codes are arranged as following structures:
+ # [Order1 ([n]),
+ # Order2 ([n]),
+ # ...
+ # OrderN ([n])] (k, n)
+ code = [
+ encode(self.grid_coord, self.batch, depth, order=order_) for order_ in order
+ ]
+ code = torch.stack(code)
+ order = torch.argsort(code)
+ inverse = torch.zeros_like(order).scatter_(
+ dim=1,
+ index=order,
+ src=torch.arange(0, code.shape[1], device=order.device).repeat(
+ code.shape[0], 1
+ ),
+ )
+
+ if shuffle_orders:
+ perm = torch.randperm(code.shape[0])
+ code = code[perm]
+ order = order[perm]
+ inverse = inverse[perm]
+
+ self["serialized_code"] = code
+ self["serialized_order"] = order
+ self["serialized_inverse"] = inverse
+
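+    # Shape note (illustrative): with k serialization orders and n points,
+    # serialized_code, serialized_order and serialized_inverse are all (k, n);
+    # order[i] sorts the points along the i-th space-filling curve and
+    # inverse[i] undoes that sort.
+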
+ def sparsify(self, pad=96):
+ """
+        Point Cloud Sparsification
+
+        Point clouds are sparse by nature; here "sparsify" specifically refers
+        to preparing a "spconv.SparseConvTensor" for SpConv.
+
+        relies on ["grid_coord" or "coord" + "grid_size", "batch", "feat"]
+
+        pad: padding margin added to the sparse spatial shape.
+ """
+ assert {"feat", "batch"}.issubset(self.keys())
+        if "grid_coord" not in self.keys():
+            # If you don't want to apply GridSampling in data augmentation,
+            # add the following augmentation to your pipeline:
+            # dict(type="Copy", keys_dict={"grid_size": 0.01}),
+            # (adjust `grid_size` to what you want)
+ assert {"grid_size", "coord"}.issubset(self.keys())
+ idx_ptr = torch.nn.functional.pad(self.offset, (1, 0), value=0)
+ min_coord = torch_scatter.segment_csr(self.coord, idx_ptr, reduce="min")
+ self["grid_coord"] = torch.div(
+ self.coord - min_coord[self.batch],
+ self.grid_size,
+ rounding_mode="trunc",
+ ).int()
+ if "sparse_shape" in self.keys():
+ sparse_shape = self.sparse_shape
+ else:
+ sparse_shape = torch.add(
+ torch.max(self.grid_coord, dim=0).values, pad
+ ).tolist()
+ sparse_conv_feat = spconv.SparseConvTensor(
+ features=self.feat,
+ indices=torch.cat(
+ [self.batch.unsqueeze(-1).int(), self.grid_coord.int()], dim=1
+ ).contiguous(),
+ spatial_shape=sparse_shape,
+ batch_size=self.batch[-1].tolist() + 1,
+ )
+ self["sparse_shape"] = sparse_shape
+ self["sparse_conv_feat"] = sparse_conv_feat
+
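+    # Note (illustrative): the SparseConvTensor indices columns are
+    # [batch_idx, grid_x, grid_y, grid_z]; `pad` adds a safety margin to
+    # sparse_shape so downstream sparse convolutions stay within bounds.
+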
+ def octreetization(self, depth=None, full_depth=None):
+ """
+        Point Cloud Octree Construction
+
+        Generate an octree with OCNN;
+        relies on ["grid_coord", "batch", "feat"]
+ """
+ assert (
+ ocnn is not None
+        ), "Please follow https://github.com/octree-nn/ocnn-pytorch to install ocnn."
+ assert {"grid_coord", "feat", "batch"}.issubset(self.keys())
+ # add 1 to make grid space support shift order
+ if depth is None:
+ if "depth" in self.keys():
+ depth = self.depth
+ else:
+ depth = int(self.grid_coord.max() + 1).bit_length()
+ if full_depth is None:
+ full_depth = 2
+ self["depth"] = depth
+ assert depth <= 16 # maximum in ocnn
+
+ # [0, 2**depth] -> [0, 2] -> [-1, 1]
+ coord = self.grid_coord / 2 ** (self.depth - 1) - 1.0
+ point = ocnn.octree.Points(
+ points=coord,
+ features=self.feat,
+ batch_id=self.batch.unsqueeze(-1),
+ batch_size=self.batch[-1] + 1,
+ )
+ octree = ocnn.octree.Octree(
+ depth=depth,
+ full_depth=full_depth,
+ batch_size=self.batch[-1] + 1,
+ device=coord.device,
+ )
+ octree.build_octree(point)
+ octree.construct_all_neigh()
+ self["octree"] = octree
diff --git a/UniRig/src/model/pointcept/utils/__init__.py b/UniRig/src/model/pointcept/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/UniRig/src/model/pointcept/utils/cache.py b/UniRig/src/model/pointcept/utils/cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..a32f06c644b835925d165bca7060f8b185cf9c54
--- /dev/null
+++ b/UniRig/src/model/pointcept/utils/cache.py
@@ -0,0 +1,56 @@
+"""
+Data Cache Utils
+
+Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
+Please cite our work if the code is helpful to you.
+"""
+
+import os
+
+try:
+    import SharedArray
+except ImportError:
+    # Optional dependency used by shared_array()/shared_dict()
+    SharedArray = None
+
+try:
+ from multiprocessing.shared_memory import ShareableList
+except ImportError:
+ import warnings
+
+    warnings.warn("Please update Python to version >= 3.8 to enable shared_memory")
+import numpy as np
+
+
+def shared_array(name, var=None):
+ if var is not None:
+ # check exist
+ if os.path.exists(f"/dev/shm/{name}"):
+ return SharedArray.attach(f"shm://{name}")
+ # create shared_array
+ data = SharedArray.create(f"shm://{name}", var.shape, dtype=var.dtype)
+ data[...] = var[...]
+ data.flags.writeable = False
+ else:
+ data = SharedArray.attach(f"shm://{name}").copy()
+ return data
+
+
+def shared_dict(name, var=None):
+ name = str(name)
+ assert "." not in name # '.' is used as sep flag
+ data = {}
+ if var is not None:
+ assert isinstance(var, dict)
+ keys = var.keys()
+ # current version only cache np.array
+ keys_valid = []
+ for key in keys:
+ if isinstance(var[key], np.ndarray):
+ keys_valid.append(key)
+ keys = keys_valid
+
+ ShareableList(sequence=keys, name=name + ".keys")
+ for key in keys:
+ if isinstance(var[key], np.ndarray):
+ data[key] = shared_array(name=f"{name}.{key}", var=var[key])
+ else:
+ keys = list(ShareableList(name=name + ".keys"))
+ for key in keys:
+ data[key] = shared_array(name=f"{name}.{key}")
+ return data
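+
+
+# Usage sketch (hypothetical names; requires the optional SharedArray package
+# and Python >= 3.8 for ShareableList):
+#   cache = shared_dict("scene0000", {"coord": np.zeros((100, 3))})  # create
+#   same = shared_dict("scene0000")  # attach from another process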
diff --git a/UniRig/src/model/pointcept/utils/comm.py b/UniRig/src/model/pointcept/utils/comm.py
new file mode 100644
index 0000000000000000000000000000000000000000..69e29e7c690fe0500d3d9a84b6a8749e2f4f4655
--- /dev/null
+++ b/UniRig/src/model/pointcept/utils/comm.py
@@ -0,0 +1,198 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+"""
+This file contains primitives for multi-gpu communication.
+This is useful when doing distributed training.
+Modified from detectron2(https://github.com/facebookresearch/detectron2)
+
+Copyright (c) Xiaoyang Wu (xiaoyang.wu@connect.hku.hk). All Rights Reserved.
+Please cite our work if you use any part of the code.
+"""
+
+import functools
+import numpy as np
+import torch
+import torch.distributed as dist
+
+_LOCAL_PROCESS_GROUP = None
+"""
+A torch process group which only includes processes that are on the same machine as the current process.
+This variable is set when processes are spawned by `launch()` in "engine/launch.py".
+"""
+
+
+def get_world_size() -> int:
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank() -> int:
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def get_local_rank() -> int:
+ """
+ Returns:
+ The rank of the current process within the local (per-machine) process group.
+ """
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ assert (
+ _LOCAL_PROCESS_GROUP is not None
+ ), "Local process group is not created! Please use launch() to spawn processes!"
+ return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
+
+
+def get_local_size() -> int:
+ """
+ Returns:
+ The size of the per-machine process group,
+ i.e. the number of processes per machine.
+ """
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
+
+
+def is_main_process() -> bool:
+ return get_rank() == 0
+
+
+def synchronize():
+ """
+ Helper function to synchronize (barrier) among all processes when
+ using distributed training
+ """
+ if not dist.is_available():
+ return
+ if not dist.is_initialized():
+ return
+ world_size = dist.get_world_size()
+ if world_size == 1:
+ return
+ if dist.get_backend() == dist.Backend.NCCL:
+ # This argument is needed to avoid warnings.
+ # It's valid only for NCCL backend.
+ dist.barrier(device_ids=[torch.cuda.current_device()])
+ else:
+ dist.barrier()
+
+
+@functools.lru_cache()
+def _get_global_gloo_group():
+ """
+ Return a process group based on gloo backend, containing all the ranks
+ The result is cached.
+ """
+ if dist.get_backend() == "nccl":
+ return dist.new_group(backend="gloo")
+ else:
+ return dist.group.WORLD
+
+
+def all_gather(data, group=None):
+ """
+ Run all_gather on arbitrary picklable data (not necessarily tensors).
+ Args:
+ data: any picklable object
+ group: a torch process group. By default, will use a group which
+ contains all ranks on gloo backend.
+ Returns:
+ list[data]: list of data gathered from each rank
+ """
+ if get_world_size() == 1:
+ return [data]
+ if group is None:
+ group = (
+ _get_global_gloo_group()
+ ) # use CPU group by default, to reduce GPU RAM usage.
+ world_size = dist.get_world_size(group)
+ if world_size == 1:
+ return [data]
+
+ output = [None for _ in range(world_size)]
+ dist.all_gather_object(output, data, group=group)
+ return output
+
+
+def gather(data, dst=0, group=None):
+ """
+ Run gather on arbitrary picklable data (not necessarily tensors).
+ Args:
+ data: any picklable object
+ dst (int): destination rank
+ group: a torch process group. By default, will use a group which
+ contains all ranks on gloo backend.
+ Returns:
+ list[data]: on dst, a list of data gathered from each rank. Otherwise,
+ an empty list.
+ """
+ if get_world_size() == 1:
+ return [data]
+ if group is None:
+ group = _get_global_gloo_group()
+ world_size = dist.get_world_size(group=group)
+ if world_size == 1:
+ return [data]
+ rank = dist.get_rank(group=group)
+
+ if rank == dst:
+ output = [None for _ in range(world_size)]
+ dist.gather_object(data, output, dst=dst, group=group)
+ return output
+ else:
+ dist.gather_object(data, None, dst=dst, group=group)
+ return []
+
+
+def shared_random_seed():
+ """
+ Returns:
+ int: a random number that is the same across all workers.
+ If workers need a shared RNG, they can use this shared seed to
+ create one.
+ All workers must call this function, otherwise it will deadlock.
+ """
+ ints = np.random.randint(2**31)
+ all_ints = all_gather(ints)
+ return all_ints[0]
+
+
+def reduce_dict(input_dict, average=True):
+ """
+ Reduce the values in the dictionary from all processes so that process with rank
+ 0 has the reduced results.
+ Args:
+ input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
+ average (bool): whether to do average or sum
+ Returns:
+ a dict with the same keys as input_dict, after reduction.
+ """
+ world_size = get_world_size()
+ if world_size < 2:
+ return input_dict
+ with torch.no_grad():
+ names = []
+ values = []
+ # sort the keys so that they are consistent across processes
+ for k in sorted(input_dict.keys()):
+ names.append(k)
+ values.append(input_dict[k])
+ values = torch.stack(values, dim=0)
+ dist.reduce(values, dst=0)
+ if dist.get_rank() == 0 and average:
+ # only main process gets accumulated, so only divide by
+ # world_size in this case
+ values /= world_size
+ reduced_dict = {k: v for k, v in zip(names, values)}
+ return reduced_dict
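+
+
+# Usage sketch (illustrative, inside a torch.distributed run):
+#   losses = {"loss_cls": loss_cls.detach(), "loss_reg": loss_reg.detach()}
+#   reduced = reduce_dict(losses)  # rank 0 gets values averaged over world_size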
diff --git a/UniRig/src/model/pointcept/utils/config.py b/UniRig/src/model/pointcept/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..0388caf896a88189ff8497cca5ac71b1bd37034f
--- /dev/null
+++ b/UniRig/src/model/pointcept/utils/config.py
@@ -0,0 +1,695 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import ast
+import copy
+import os
+import os.path as osp
+import platform
+import shutil
+import sys
+import tempfile
+import uuid
+import warnings
+from argparse import Action, ArgumentParser
+from collections import abc
+from importlib import import_module
+
+from addict import Dict
+from yapf.yapflib.yapf_api import FormatCode
+
+from .misc import import_modules_from_strings
+from .path import check_file_exist
+
+if platform.system() == "Windows":
+ import regex as re
+else:
+ import re
+
+BASE_KEY = "_base_"
+DELETE_KEY = "_delete_"
+DEPRECATION_KEY = "_deprecation_"
+RESERVED_KEYS = ["filename", "text", "pretty_text"]
+
+
+class ConfigDict(Dict):
+ def __missing__(self, name):
+ raise KeyError(name)
+
+ def __getattr__(self, name):
+ try:
+ value = super(ConfigDict, self).__getattr__(name)
+ except KeyError:
+ ex = AttributeError(
+ f"'{self.__class__.__name__}' object has no " f"attribute '{name}'"
+ )
+ except Exception as e:
+ ex = e
+ else:
+ return value
+ raise ex
+
+
+def add_args(parser, cfg, prefix=""):
+ for k, v in cfg.items():
+ if isinstance(v, str):
+ parser.add_argument("--" + prefix + k)
+ elif isinstance(v, int):
+ parser.add_argument("--" + prefix + k, type=int)
+ elif isinstance(v, float):
+ parser.add_argument("--" + prefix + k, type=float)
+ elif isinstance(v, bool):
+ parser.add_argument("--" + prefix + k, action="store_true")
+ elif isinstance(v, dict):
+ add_args(parser, v, prefix + k + ".")
+ elif isinstance(v, abc.Iterable):
+ parser.add_argument("--" + prefix + k, type=type(v[0]), nargs="+")
+ else:
+ print(f"cannot parse key {prefix + k} of type {type(v)}")
+ return parser
+
+
+class Config:
+ """A facility for config and config files.
+
+ It supports common file formats as configs: python/json/yaml. The interface
+ is the same as a dict object and also allows access config values as
+ attributes.
+
+ Example:
+ >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
+ >>> cfg.a
+ 1
+ >>> cfg.b
+ {'b1': [0, 1]}
+ >>> cfg.b.b1
+ [0, 1]
+ >>> cfg = Config.fromfile('tests/data/config/a.py')
+ >>> cfg.filename
+ "/home/kchen/projects/mmcv/tests/data/config/a.py"
+ >>> cfg.item4
+ 'test'
+ >>> cfg
+ "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
+ "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
+ """
+
+ @staticmethod
+ def _validate_py_syntax(filename):
+ with open(filename, "r", encoding="utf-8") as f:
+ # Setting encoding explicitly to resolve coding issue on windows
+ content = f.read()
+ try:
+ ast.parse(content)
+ except SyntaxError as e:
+ raise SyntaxError(
+ "There are syntax errors in config " f"file {filename}: {e}"
+ )
+
+ @staticmethod
+ def _substitute_predefined_vars(filename, temp_config_name):
+ file_dirname = osp.dirname(filename)
+ file_basename = osp.basename(filename)
+ file_basename_no_extension = osp.splitext(file_basename)[0]
+ file_extname = osp.splitext(filename)[1]
+ support_templates = dict(
+ fileDirname=file_dirname,
+ fileBasename=file_basename,
+ fileBasenameNoExtension=file_basename_no_extension,
+ fileExtname=file_extname,
+ )
+ with open(filename, "r", encoding="utf-8") as f:
+ # Setting encoding explicitly to resolve coding issue on windows
+ config_file = f.read()
+ for key, value in support_templates.items():
+ regexp = r"\{\{\s*" + str(key) + r"\s*\}\}"
+ value = value.replace("\\", "/")
+ config_file = re.sub(regexp, value, config_file)
+ with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file:
+ tmp_config_file.write(config_file)
+
+ @staticmethod
+ def _pre_substitute_base_vars(filename, temp_config_name):
+        """Substitute base variable placeholders with strings so that parsing
+        succeeds."""
+ with open(filename, "r", encoding="utf-8") as f:
+ # Setting encoding explicitly to resolve coding issue on windows
+ config_file = f.read()
+ base_var_dict = {}
+ regexp = r"\{\{\s*" + BASE_KEY + r"\.([\w\.]+)\s*\}\}"
+ base_vars = set(re.findall(regexp, config_file))
+ for base_var in base_vars:
+ randstr = f"_{base_var}_{uuid.uuid4().hex.lower()[:6]}"
+ base_var_dict[randstr] = base_var
+ regexp = r"\{\{\s*" + BASE_KEY + r"\." + base_var + r"\s*\}\}"
+ config_file = re.sub(regexp, f'"{randstr}"', config_file)
+ with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file:
+ tmp_config_file.write(config_file)
+ return base_var_dict
+
+ @staticmethod
+ def _substitute_base_vars(cfg, base_var_dict, base_cfg):
+ """Substitute variable strings to their actual values."""
+ cfg = copy.deepcopy(cfg)
+
+ if isinstance(cfg, dict):
+ for k, v in cfg.items():
+ if isinstance(v, str) and v in base_var_dict:
+ new_v = base_cfg
+ for new_k in base_var_dict[v].split("."):
+ new_v = new_v[new_k]
+ cfg[k] = new_v
+ elif isinstance(v, (list, tuple, dict)):
+ cfg[k] = Config._substitute_base_vars(v, base_var_dict, base_cfg)
+ elif isinstance(cfg, tuple):
+ cfg = tuple(
+ Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg
+ )
+ elif isinstance(cfg, list):
+ cfg = [
+ Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg
+ ]
+ elif isinstance(cfg, str) and cfg in base_var_dict:
+ new_v = base_cfg
+ for new_k in base_var_dict[cfg].split("."):
+ new_v = new_v[new_k]
+ cfg = new_v
+
+ return cfg
+
+ @staticmethod
+ def _file2dict(filename, use_predefined_variables=True):
+ filename = osp.abspath(osp.expanduser(filename))
+ check_file_exist(filename)
+ fileExtname = osp.splitext(filename)[1]
+ if fileExtname not in [".py", ".json", ".yaml", ".yml"]:
+            raise IOError("Only py/yml/yaml/json types are supported now!")
+
+ with tempfile.TemporaryDirectory() as temp_config_dir:
+ temp_config_file = tempfile.NamedTemporaryFile(
+ dir=temp_config_dir, suffix=fileExtname
+ )
+ if platform.system() == "Windows":
+ temp_config_file.close()
+ temp_config_name = osp.basename(temp_config_file.name)
+ # Substitute predefined variables
+ if use_predefined_variables:
+ Config._substitute_predefined_vars(filename, temp_config_file.name)
+ else:
+ shutil.copyfile(filename, temp_config_file.name)
+ # Substitute base variables from placeholders to strings
+ base_var_dict = Config._pre_substitute_base_vars(
+ temp_config_file.name, temp_config_file.name
+ )
+
+ if filename.endswith(".py"):
+ temp_module_name = osp.splitext(temp_config_name)[0]
+ sys.path.insert(0, temp_config_dir)
+ Config._validate_py_syntax(filename)
+ mod = import_module(temp_module_name)
+ sys.path.pop(0)
+ cfg_dict = {
+ name: value
+ for name, value in mod.__dict__.items()
+ if not name.startswith("__")
+ }
+ # delete imported module
+ del sys.modules[temp_module_name]
+ elif filename.endswith((".yml", ".yaml", ".json")):
+ raise NotImplementedError
+ # close temp file
+ temp_config_file.close()
+
+ # check deprecation information
+ if DEPRECATION_KEY in cfg_dict:
+ deprecation_info = cfg_dict.pop(DEPRECATION_KEY)
+ warning_msg = (
+ f"The config file {filename} will be deprecated " "in the future."
+ )
+ if "expected" in deprecation_info:
+ warning_msg += f' Please use {deprecation_info["expected"]} ' "instead."
+ if "reference" in deprecation_info:
+ warning_msg += (
+ " More information can be found at "
+ f'{deprecation_info["reference"]}'
+ )
+ warnings.warn(warning_msg)
+
+ cfg_text = filename + "\n"
+ with open(filename, "r", encoding="utf-8") as f:
+ # Setting encoding explicitly to resolve coding issue on windows
+ cfg_text += f.read()
+
+ if BASE_KEY in cfg_dict:
+ cfg_dir = osp.dirname(filename)
+ base_filename = cfg_dict.pop(BASE_KEY)
+ base_filename = (
+ base_filename if isinstance(base_filename, list) else [base_filename]
+ )
+
+ cfg_dict_list = list()
+ cfg_text_list = list()
+ for f in base_filename:
+ _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
+ cfg_dict_list.append(_cfg_dict)
+ cfg_text_list.append(_cfg_text)
+
+ base_cfg_dict = dict()
+ for c in cfg_dict_list:
+ duplicate_keys = base_cfg_dict.keys() & c.keys()
+ if len(duplicate_keys) > 0:
+ raise KeyError(
+ "Duplicate key is not allowed among bases. "
+ f"Duplicate keys: {duplicate_keys}"
+ )
+ base_cfg_dict.update(c)
+
+ # Substitute base variables from strings to their actual values
+ cfg_dict = Config._substitute_base_vars(
+ cfg_dict, base_var_dict, base_cfg_dict
+ )
+
+ base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
+ cfg_dict = base_cfg_dict
+
+ # merge cfg_text
+ cfg_text_list.append(cfg_text)
+ cfg_text = "\n".join(cfg_text_list)
+
+ return cfg_dict, cfg_text
+
+ @staticmethod
+ def _merge_a_into_b(a, b, allow_list_keys=False):
+ """merge dict ``a`` into dict ``b`` (non-inplace).
+
+ Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
+ in-place modifications.
+
+ Args:
+ a (dict): The source dict to be merged into ``b``.
+            b (dict): The base dict into which keys from ``a`` are merged.
+ allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
+ are allowed in source ``a`` and will replace the element of the
+ corresponding index in b if b is a list. Default: False.
+
+ Returns:
+ dict: The modified dict of ``b`` using ``a``.
+
+ Examples:
+ # Normally merge a into b.
+ >>> Config._merge_a_into_b(
+ ... dict(obj=dict(a=2)), dict(obj=dict(a=1)))
+ {'obj': {'a': 2}}
+
+ # Delete b first and merge a into b.
+ >>> Config._merge_a_into_b(
+ ... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1)))
+ {'obj': {'a': 2}}
+
+ # b is a list
+ >>> Config._merge_a_into_b(
+ ... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True)
+ [{'a': 2}, {'b': 2}]
+ """
+ b = b.copy()
+ for k, v in a.items():
+ if allow_list_keys and k.isdigit() and isinstance(b, list):
+ k = int(k)
+ if len(b) <= k:
+ raise KeyError(f"Index {k} exceeds the length of list {b}")
+ b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
+ elif isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
+ allowed_types = (dict, list) if allow_list_keys else dict
+ if not isinstance(b[k], allowed_types):
+ raise TypeError(
+ f"{k}={v} in child config cannot inherit from base "
+ f"because {k} is a dict in the child config but is of "
+ f"type {type(b[k])} in base config. You may set "
+ f"`{DELETE_KEY}=True` to ignore the base config"
+ )
+ b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
+ else:
+ b[k] = v
+ return b
+
+ @staticmethod
+ def fromfile(filename, use_predefined_variables=True, import_custom_modules=True):
+ cfg_dict, cfg_text = Config._file2dict(filename, use_predefined_variables)
+ if import_custom_modules and cfg_dict.get("custom_imports", None):
+ import_modules_from_strings(**cfg_dict["custom_imports"])
+ return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
+
+ @staticmethod
+ def fromstring(cfg_str, file_format):
+ """Generate config from config str.
+
+ Args:
+ cfg_str (str): Config str.
+ file_format (str): Config file format corresponding to the
+                config str. Only py/yml/yaml/json types are supported now!
+
+ Returns:
+ obj:`Config`: Config obj.
+ """
+ if file_format not in [".py", ".json", ".yaml", ".yml"]:
+            raise IOError("Only py/yml/yaml/json types are supported now!")
+ if file_format != ".py" and "dict(" in cfg_str:
+ # check if users specify a wrong suffix for python
+ warnings.warn('Please check "file_format", the file format may be .py')
+ with tempfile.NamedTemporaryFile(
+ "w", encoding="utf-8", suffix=file_format, delete=False
+ ) as temp_file:
+ temp_file.write(cfg_str)
+ # on windows, previous implementation cause error
+ # see PR 1077 for details
+ cfg = Config.fromfile(temp_file.name)
+ os.remove(temp_file.name)
+ return cfg
+
+ @staticmethod
+ def auto_argparser(description=None):
+ """Generate argparser from config file automatically (experimental)"""
+ partial_parser = ArgumentParser(description=description)
+ partial_parser.add_argument("config", help="config file path")
+ cfg_file = partial_parser.parse_known_args()[0].config
+ cfg = Config.fromfile(cfg_file)
+ parser = ArgumentParser(description=description)
+ parser.add_argument("config", help="config file path")
+ add_args(parser, cfg)
+ return parser, cfg
+
+ def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
+ if cfg_dict is None:
+ cfg_dict = dict()
+ elif not isinstance(cfg_dict, dict):
+ raise TypeError("cfg_dict must be a dict, but " f"got {type(cfg_dict)}")
+ for key in cfg_dict:
+ if key in RESERVED_KEYS:
+ raise KeyError(f"{key} is reserved for config file")
+
+ super(Config, self).__setattr__("_cfg_dict", ConfigDict(cfg_dict))
+ super(Config, self).__setattr__("_filename", filename)
+ if cfg_text:
+ text = cfg_text
+ elif filename:
+ with open(filename, "r") as f:
+ text = f.read()
+ else:
+ text = ""
+ super(Config, self).__setattr__("_text", text)
+
+ @property
+ def filename(self):
+ return self._filename
+
+ @property
+ def text(self):
+ return self._text
+
+ @property
+ def pretty_text(self):
+ indent = 4
+
+ def _indent(s_, num_spaces):
+ s = s_.split("\n")
+ if len(s) == 1:
+ return s_
+ first = s.pop(0)
+ s = [(num_spaces * " ") + line for line in s]
+ s = "\n".join(s)
+ s = first + "\n" + s
+ return s
+
+ def _format_basic_types(k, v, use_mapping=False):
+ if isinstance(v, str):
+ v_str = f"'{v}'"
+ else:
+ v_str = str(v)
+
+ if use_mapping:
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
+ attr_str = f"{k_str}: {v_str}"
+ else:
+ attr_str = f"{str(k)}={v_str}"
+ attr_str = _indent(attr_str, indent)
+
+ return attr_str
+
+ def _format_list(k, v, use_mapping=False):
+ # check if all items in the list are dict
+ if all(isinstance(_, dict) for _ in v):
+ v_str = "[\n"
+ v_str += "\n".join(
+ f"dict({_indent(_format_dict(v_), indent)})," for v_ in v
+ ).rstrip(",")
+ if use_mapping:
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
+ attr_str = f"{k_str}: {v_str}"
+ else:
+ attr_str = f"{str(k)}={v_str}"
+ attr_str = _indent(attr_str, indent) + "]"
+ else:
+ attr_str = _format_basic_types(k, v, use_mapping)
+ return attr_str
+
+ def _contain_invalid_identifier(dict_str):
+ contain_invalid_identifier = False
+ for key_name in dict_str:
+ contain_invalid_identifier |= not str(key_name).isidentifier()
+ return contain_invalid_identifier
+
+ def _format_dict(input_dict, outest_level=False):
+ r = ""
+ s = []
+
+ use_mapping = _contain_invalid_identifier(input_dict)
+ if use_mapping:
+ r += "{"
+ for idx, (k, v) in enumerate(input_dict.items()):
+ is_last = idx >= len(input_dict) - 1
+ end = "" if outest_level or is_last else ","
+ if isinstance(v, dict):
+ v_str = "\n" + _format_dict(v)
+ if use_mapping:
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
+ attr_str = f"{k_str}: dict({v_str}"
+ else:
+ attr_str = f"{str(k)}=dict({v_str}"
+ attr_str = _indent(attr_str, indent) + ")" + end
+ elif isinstance(v, list):
+ attr_str = _format_list(k, v, use_mapping) + end
+ else:
+ attr_str = _format_basic_types(k, v, use_mapping) + end
+
+ s.append(attr_str)
+ r += "\n".join(s)
+ if use_mapping:
+ r += "}"
+ return r
+
+ cfg_dict = self._cfg_dict.to_dict()
+ text = _format_dict(cfg_dict, outest_level=True)
+ # copied from setup.cfg
+ yapf_style = dict(
+ based_on_style="pep8",
+ blank_line_before_nested_class_or_def=True,
+ split_before_expression_after_opening_paren=True,
+ )
+ text, _ = FormatCode(text, style_config=yapf_style)
+
+ return text
+
+ def __repr__(self):
+ return f"Config (path: {self.filename}): {self._cfg_dict.__repr__()}"
+
+ def __len__(self):
+ return len(self._cfg_dict)
+
+ def __getattr__(self, name):
+ return getattr(self._cfg_dict, name)
+
+ def __getitem__(self, name):
+ return self._cfg_dict.__getitem__(name)
+
+ def __setattr__(self, name, value):
+ if isinstance(value, dict):
+ value = ConfigDict(value)
+ self._cfg_dict.__setattr__(name, value)
+
+ def __setitem__(self, name, value):
+ if isinstance(value, dict):
+ value = ConfigDict(value)
+ self._cfg_dict.__setitem__(name, value)
+
+ def __iter__(self):
+ return iter(self._cfg_dict)
+
+ def __getstate__(self):
+ return (self._cfg_dict, self._filename, self._text)
+
+ def __setstate__(self, state):
+ _cfg_dict, _filename, _text = state
+ super(Config, self).__setattr__("_cfg_dict", _cfg_dict)
+ super(Config, self).__setattr__("_filename", _filename)
+ super(Config, self).__setattr__("_text", _text)
+
+ def dump(self, file=None):
+ cfg_dict = super(Config, self).__getattribute__("_cfg_dict").to_dict()
+ if self.filename.endswith(".py"):
+ if file is None:
+ return self.pretty_text
+ else:
+ with open(file, "w", encoding="utf-8") as f:
+ f.write(self.pretty_text)
+ else:
+ import mmcv
+
+ if file is None:
+ file_format = self.filename.split(".")[-1]
+ return mmcv.dump(cfg_dict, file_format=file_format)
+ else:
+ mmcv.dump(cfg_dict, file)
+
+ def merge_from_dict(self, options, allow_list_keys=True):
+ """Merge list into cfg_dict.
+
+ Merge the dict parsed by MultipleKVAction into this cfg.
+
+ Examples:
+ >>> options = {'models.backbone.depth': 50,
+ ... 'models.backbone.with_cp':True}
+ >>> cfg = Config(dict(models=dict(backbone=dict(type='ResNet'))))
+ >>> cfg.merge_from_dict(options)
+ >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
+ >>> assert cfg_dict == dict(
+ ... models=dict(backbone=dict(depth=50, with_cp=True)))
+
+ # Merge list element
+ >>> cfg = Config(dict(pipeline=[
+ ... dict(type='LoadImage'), dict(type='LoadAnnotations')]))
+ >>> options = dict(pipeline={'0': dict(type='SelfLoadImage')})
+ >>> cfg.merge_from_dict(options, allow_list_keys=True)
+ >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
+ >>> assert cfg_dict == dict(pipeline=[
+ ... dict(type='SelfLoadImage'), dict(type='LoadAnnotations')])
+
+ Args:
+ options (dict): dict of configs to merge from.
+ allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
+ are allowed in ``options`` and will replace the element of the
+ corresponding index in the config if the config is a list.
+ Default: True.
+ """
+ option_cfg_dict = {}
+ for full_key, v in options.items():
+ d = option_cfg_dict
+ key_list = full_key.split(".")
+ for subkey in key_list[:-1]:
+ d.setdefault(subkey, ConfigDict())
+ d = d[subkey]
+ subkey = key_list[-1]
+ d[subkey] = v
+
+ cfg_dict = super(Config, self).__getattribute__("_cfg_dict")
+ super(Config, self).__setattr__(
+ "_cfg_dict",
+ Config._merge_a_into_b(
+ option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys
+ ),
+ )
+
+
+class DictAction(Action):
+ """
+    argparse action to split an argument into KEY=VALUE form
+    on the first '=' and append it to a dictionary. List options can
+    be passed as comma-separated values, e.g. 'KEY=V1,V2,V3', or with explicit
+    brackets, e.g. 'KEY=[V1,V2,V3]'. It also supports nested brackets to build
+    list/tuple values, e.g. 'KEY=[(V1,V2),(V3,V4)]'
+ """
+
+ @staticmethod
+ def _parse_int_float_bool(val):
+ try:
+ return int(val)
+ except ValueError:
+ pass
+ try:
+ return float(val)
+ except ValueError:
+ pass
+ if val.lower() in ["true", "false"]:
+            return val.lower() == "true"
+ return val
+
+ @staticmethod
+ def _parse_iterable(val):
+ """Parse iterable values in the string.
+
+ All elements inside '()' or '[]' are treated as iterable values.
+
+ Args:
+ val (str): Value string.
+
+ Returns:
+ list | tuple: The expanded list or tuple from the string.
+
+ Examples:
+ >>> DictAction._parse_iterable('1,2,3')
+ [1, 2, 3]
+ >>> DictAction._parse_iterable('[a, b, c]')
+ ['a', 'b', 'c']
+ >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]')
+ [(1, 2, 3), ['a', 'b'], 'c']
+ """
+
+ def find_next_comma(string):
+ """Find the position of next comma in the string.
+
+ If no ',' is found in the string, return the string length. All
+ chars inside '()' and '[]' are treated as one element and thus ','
+ inside these brackets are ignored.
+ """
+ assert (string.count("(") == string.count(")")) and (
+ string.count("[") == string.count("]")
+ ), f"Imbalanced brackets exist in {string}"
+ end = len(string)
+ for idx, char in enumerate(string):
+ pre = string[:idx]
+ # The string before this ',' is balanced
+ if (
+ (char == ",")
+ and (pre.count("(") == pre.count(")"))
+ and (pre.count("[") == pre.count("]"))
+ ):
+ end = idx
+ break
+ return end
+
+ # Strip ' and " characters and replace whitespace.
+ val = val.strip("'\"").replace(" ", "")
+ is_tuple = False
+ if val.startswith("(") and val.endswith(")"):
+ is_tuple = True
+ val = val[1:-1]
+ elif val.startswith("[") and val.endswith("]"):
+ val = val[1:-1]
+ elif "," not in val:
+ # val is a single value
+ return DictAction._parse_int_float_bool(val)
+
+ values = []
+ while len(val) > 0:
+ comma_idx = find_next_comma(val)
+ element = DictAction._parse_iterable(val[:comma_idx])
+ values.append(element)
+ val = val[comma_idx + 1 :]
+ if is_tuple:
+ values = tuple(values)
+ return values
+
+ def __call__(self, parser, namespace, values, option_string=None):
+ options = {}
+ for kv in values:
+ key, val = kv.split("=", maxsplit=1)
+ options[key] = self._parse_iterable(val)
+ setattr(namespace, self.dest, options)
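+
+# Editor's sketch (not part of the original file): typical way to wire
+# DictAction into a parser; the flag name `--cfg-options` is an assumption.
+#
+#   parser = ArgumentParser()
+#   parser.add_argument("--cfg-options", nargs="+", action=DictAction)
+#   args = parser.parse_args(["--cfg-options", "model.depth=50", "lr=[0.1,0.01]"])
+#   assert args.cfg_options == {"model.depth": 50, "lr": [0.1, 0.01]}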
diff --git a/UniRig/src/model/pointcept/utils/env.py b/UniRig/src/model/pointcept/utils/env.py
new file mode 100644
index 0000000000000000000000000000000000000000..653f007dde5c4a7564e732da88dd47e7d37adf97
--- /dev/null
+++ b/UniRig/src/model/pointcept/utils/env.py
@@ -0,0 +1,36 @@
+"""
+Environment Utils
+
+Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
+Please cite our work if the code is helpful to you.
+"""
+
+import os
+import random
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+
+from datetime import datetime
+
+
+def get_random_seed():
+ seed = (
+ os.getpid()
+ + int(datetime.now().strftime("%S%f"))
+ + int.from_bytes(os.urandom(2), "big")
+ )
+ return seed
+
+
+def set_seed(seed=None):
+ if seed is None:
+ seed = get_random_seed()
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ cudnn.benchmark = False
+ cudnn.deterministic = True
+ os.environ["PYTHONHASHSEED"] = str(seed)
diff --git a/UniRig/src/model/pointcept/utils/events.py b/UniRig/src/model/pointcept/utils/events.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2db40895bce751e2d69482578d460f187cb8d0e
--- /dev/null
+++ b/UniRig/src/model/pointcept/utils/events.py
@@ -0,0 +1,590 @@
+"""
+Events Utils
+
+Modified from Detectron2
+
+Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
+Please cite our work if the code is helpful to you.
+"""
+
+
+import datetime
+import json
+import logging
+import os
+import time
+import torch
+import numpy as np
+
+from typing import List, Optional, Tuple
+from collections import defaultdict
+from contextlib import contextmanager
+
+__all__ = [
+ "get_event_storage",
+ "JSONWriter",
+ "TensorboardXWriter",
+ "CommonMetricPrinter",
+ "EventStorage",
+]
+
+_CURRENT_STORAGE_STACK = []
+
+
+def get_event_storage():
+ """
+ Returns:
+ The :class:`EventStorage` object that's currently being used.
+ Throws an error if no :class:`EventStorage` is currently enabled.
+ """
+ assert len(
+ _CURRENT_STORAGE_STACK
+ ), "get_event_storage() has to be called inside a 'with EventStorage(...)' context!"
+ return _CURRENT_STORAGE_STACK[-1]
+
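+# Editor's sketch (not part of the original file): scalars must be logged
+# inside an EventStorage context so get_event_storage() can resolve it.
+#
+#   with EventStorage(start_iter=0) as storage:
+#       storage.put_scalar("loss", 0.5)
+#       storage.step()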
+
+class EventWriter:
+ """
+ Base class for writers that obtain events from :class:`EventStorage` and process them.
+ """
+
+ def write(self):
+ raise NotImplementedError
+
+ def close(self):
+ pass
+
+
+class JSONWriter(EventWriter):
+ """
+ Write scalars to a json file.
+ It saves scalars as one json per line (instead of a big json) for easy parsing.
+ Examples parsing such a json file:
+ ::
+ $ cat metrics.json | jq -s '.[0:2]'
+ [
+ {
+ "data_time": 0.008433341979980469,
+ "iteration": 19,
+ "loss": 1.9228371381759644,
+ "loss_box_reg": 0.050025828182697296,
+ "loss_classifier": 0.5316952466964722,
+ "loss_mask": 0.7236229181289673,
+ "loss_rpn_box": 0.0856662318110466,
+ "loss_rpn_cls": 0.48198649287223816,
+ "lr": 0.007173333333333333,
+ "time": 0.25401854515075684
+ },
+ {
+ "data_time": 0.007216215133666992,
+ "iteration": 39,
+ "loss": 1.282649278640747,
+ "loss_box_reg": 0.06222952902317047,
+ "loss_classifier": 0.30682939291000366,
+ "loss_mask": 0.6970193982124329,
+ "loss_rpn_box": 0.038663312792778015,
+ "loss_rpn_cls": 0.1471673548221588,
+ "lr": 0.007706666666666667,
+ "time": 0.2490077018737793
+ }
+ ]
+ $ cat metrics.json | jq '.loss_mask'
+ 0.7126231789588928
+ 0.689423680305481
+ 0.6776131987571716
+ ...
+ """
+
+ def __init__(self, json_file, window_size=20):
+ """
+ Args:
+ json_file (str): path to the json file. New data will be appended if the file exists.
+ window_size (int): the window size of median smoothing for the scalars whose
+ `smoothing_hint` are True.
+ """
+ self._file_handle = open(json_file, "a")
+ self._window_size = window_size
+ self._last_write = -1
+
+ def write(self):
+ storage = get_event_storage()
+ to_save = defaultdict(dict)
+
+ for k, (v, iter) in storage.latest_with_smoothing_hint(
+ self._window_size
+ ).items():
+ # keep scalars that have not been written
+ if iter <= self._last_write:
+ continue
+ to_save[iter][k] = v
+ if len(to_save):
+ all_iters = sorted(to_save.keys())
+ self._last_write = max(all_iters)
+
+ for itr, scalars_per_iter in to_save.items():
+ scalars_per_iter["iteration"] = itr
+ self._file_handle.write(json.dumps(scalars_per_iter, sort_keys=True) + "\n")
+ self._file_handle.flush()
+ try:
+ os.fsync(self._file_handle.fileno())
+ except AttributeError:
+ pass
+
+ def close(self):
+ self._file_handle.close()
+
+
+class TensorboardXWriter(EventWriter):
+ """
+ Write all scalars to a tensorboard file.
+ """
+
+ def __init__(self, log_dir: str, window_size: int = 20, **kwargs):
+ """
+ Args:
+ log_dir (str): the directory to save the output events
+ window_size (int): the scalars will be median-smoothed by this window size
+ kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)`
+ """
+ self._window_size = window_size
+ from torch.utils.tensorboard import SummaryWriter
+
+ self._writer = SummaryWriter(log_dir, **kwargs)
+ self._last_write = -1
+
+ def write(self):
+ storage = get_event_storage()
+ new_last_write = self._last_write
+ for k, (v, iter) in storage.latest_with_smoothing_hint(
+ self._window_size
+ ).items():
+ if iter > self._last_write:
+ self._writer.add_scalar(k, v, iter)
+ new_last_write = max(new_last_write, iter)
+ self._last_write = new_last_write
+
+ # storage.put_{image,histogram} is only meant to be used by
+ # tensorboard writer. So we access its internal fields directly from here.
+ if len(storage._vis_data) >= 1:
+ for img_name, img, step_num in storage._vis_data:
+ self._writer.add_image(img_name, img, step_num)
+        # Storage stores all image data and relies on this writer to clear it.
+ # As a result it assumes only one writer will use its image data.
+ # An alternative design is to let storage store limited recent
+ # data (e.g. only the most recent image) that all writers can access.
+ # In that case a writer may not see all image data if its period is long.
+ storage.clear_images()
+
+ if len(storage._histograms) >= 1:
+ for params in storage._histograms:
+ self._writer.add_histogram_raw(**params)
+ storage.clear_histograms()
+
+ def close(self):
+ if hasattr(self, "_writer"): # doesn't exist when the code fails at import
+ self._writer.close()
+
+
+class CommonMetricPrinter(EventWriter):
+ """
+ Print **common** metrics to the terminal, including
+ iteration time, ETA, memory, all losses, and the learning rate.
+ It also applies smoothing using a window of 20 elements.
+ It's meant to print common metrics in common ways.
+ To print something in more customized ways, please implement a similar printer by yourself.
+ """
+
+ def __init__(self, max_iter: Optional[int] = None, window_size: int = 20):
+ """
+ Args:
+ max_iter: the maximum number of iterations to train.
+ Used to compute ETA. If not given, ETA will not be printed.
+ window_size (int): the losses will be median-smoothed by this window size
+ """
+ self.logger = logging.getLogger(__name__)
+ self._max_iter = max_iter
+ self._window_size = window_size
+ self._last_write = (
+ None # (step, time) of last call to write(). Used to compute ETA
+ )
+
+ def _get_eta(self, storage) -> Optional[str]:
+ if self._max_iter is None:
+ return ""
+ iteration = storage.iter
+ try:
+ eta_seconds = storage.history("time").median(1000) * (
+ self._max_iter - iteration - 1
+ )
+ storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
+ return str(datetime.timedelta(seconds=int(eta_seconds)))
+ except KeyError:
+ # estimate eta on our own - more noisy
+ eta_string = None
+ if self._last_write is not None:
+ estimate_iter_time = (time.perf_counter() - self._last_write[1]) / (
+ iteration - self._last_write[0]
+ )
+ eta_seconds = estimate_iter_time * (self._max_iter - iteration - 1)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ self._last_write = (iteration, time.perf_counter())
+ return eta_string
+
+ def write(self):
+ storage = get_event_storage()
+ iteration = storage.iter
+ if iteration == self._max_iter:
+ # This hook only reports training progress (loss, ETA, etc) but not other data,
+ # therefore do not write anything after training succeeds, even if this method
+ # is called.
+ return
+
+ try:
+ data_time = storage.history("data_time").avg(20)
+ except KeyError:
+ # they may not exist in the first few iterations (due to warmup)
+ # or when SimpleTrainer is not used
+ data_time = None
+ try:
+ iter_time = storage.history("time").global_avg()
+ except KeyError:
+ iter_time = None
+ try:
+ lr = "{:.5g}".format(storage.history("lr").latest())
+ except KeyError:
+ lr = "N/A"
+
+ eta_string = self._get_eta(storage)
+
+ if torch.cuda.is_available():
+ max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
+ else:
+ max_mem_mb = None
+
+ # NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
+ self.logger.info(
+ " {eta}iter: {iter} {losses} {time}{data_time}lr: {lr} {memory}".format(
+ eta=f"eta: {eta_string} " if eta_string else "",
+ iter=iteration,
+ losses=" ".join(
+ [
+ "{}: {:.4g}".format(k, v.median(self._window_size))
+ for k, v in storage.histories().items()
+ if "loss" in k
+ ]
+ ),
+ time="time: {:.4f} ".format(iter_time)
+ if iter_time is not None
+ else "",
+ data_time="data_time: {:.4f} ".format(data_time)
+ if data_time is not None
+ else "",
+ lr=lr,
+ memory="max_mem: {:.0f}M".format(max_mem_mb)
+ if max_mem_mb is not None
+ else "",
+ )
+ )
+
+
+class EventStorage:
+ """
+ The user-facing class that provides metric storage functionalities.
+ In the future we may add support for storing / logging other types of data if needed.
+ """
+
+ def __init__(self, start_iter=0):
+ """
+ Args:
+ start_iter (int): the iteration number to start with
+ """
+ self._history = defaultdict(AverageMeter)
+ self._smoothing_hints = {}
+ self._latest_scalars = {}
+ self._iter = start_iter
+ self._current_prefix = ""
+ self._vis_data = []
+ self._histograms = []
+
+ # def put_image(self, img_name, img_tensor):
+ # """
+ # Add an `img_tensor` associated with `img_name`, to be shown on
+ # tensorboard.
+ # Args:
+ # img_name (str): The name of the image to put into tensorboard.
+ # img_tensor (torch.Tensor or numpy.array): An `uint8` or `float`
+ # Tensor of shape `[channel, height, width]` where `channel` is
+ # 3. The image format should be RGB. The elements in img_tensor
+ # can either have values in [0, 1] (float32) or [0, 255] (uint8).
+ # The `img_tensor` will be visualized in tensorboard.
+ # """
+ # self._vis_data.append((img_name, img_tensor, self._iter))
+
+ def put_scalar(self, name, value, n=1, smoothing_hint=False):
+ """
+ Add a scalar `value` to the `HistoryBuffer` associated with `name`.
+ Args:
+ smoothing_hint (bool): a 'hint' on whether this scalar is noisy and should be
+ smoothed when logged. The hint will be accessible through
+ :meth:`EventStorage.smoothing_hints`. A writer may ignore the hint
+                and apply a custom smoothing rule.
+                It defaults to False here; pass True for noisy scalars that only
+                provide a useful signal after smoothing.
+ """
+ name = self._current_prefix + name
+ history = self._history[name]
+ history.update(value, n)
+ self._latest_scalars[name] = (value, self._iter)
+
+ existing_hint = self._smoothing_hints.get(name)
+ if existing_hint is not None:
+ assert (
+ existing_hint == smoothing_hint
+ ), "Scalar {} was put with a different smoothing_hint!".format(name)
+ else:
+ self._smoothing_hints[name] = smoothing_hint
+
+ # def put_scalars(self, *, smoothing_hint=True, **kwargs):
+ # """
+ # Put multiple scalars from keyword arguments.
+ # Examples:
+ # storage.put_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True)
+ # """
+ # for k, v in kwargs.items():
+ # self.put_scalar(k, v, smoothing_hint=smoothing_hint)
+ #
+ # def put_histogram(self, hist_name, hist_tensor, bins=1000):
+ # """
+ # Create a histogram from a tensor.
+ # Args:
+ # hist_name (str): The name of the histogram to put into tensorboard.
+ # hist_tensor (torch.Tensor): A Tensor of arbitrary shape to be converted
+ # into a histogram.
+ # bins (int): Number of histogram bins.
+ # """
+ # ht_min, ht_max = hist_tensor.min().item(), hist_tensor.max().item()
+ #
+ # # Create a histogram with PyTorch
+ # hist_counts = torch.histc(hist_tensor, bins=bins)
+ # hist_edges = torch.linspace(start=ht_min, end=ht_max, steps=bins + 1, dtype=torch.float32)
+ #
+ # # Parameter for the add_histogram_raw function of SummaryWriter
+ # hist_params = dict(
+ # tag=hist_name,
+ # min=ht_min,
+ # max=ht_max,
+ # num=len(hist_tensor),
+ # sum=float(hist_tensor.sum()),
+ # sum_squares=float(torch.sum(hist_tensor**2)),
+ # bucket_limits=hist_edges[1:].tolist(),
+ # bucket_counts=hist_counts.tolist(),
+ # global_step=self._iter,
+ # )
+ # self._histograms.append(hist_params)
+
+ def history(self, name):
+ """
+ Returns:
+ AverageMeter: the history for name
+ """
+ ret = self._history.get(name, None)
+ if ret is None:
+ raise KeyError("No history metric available for {}!".format(name))
+ return ret
+
+ def histories(self):
+ """
+ Returns:
+ dict[name -> HistoryBuffer]: the HistoryBuffer for all scalars
+ """
+ return self._history
+
+ def latest(self):
+ """
+ Returns:
+            dict[str -> (float, int)]: mapping from the name of each scalar to its most
+                recent value and the iteration number at which it was added.
+ """
+ return self._latest_scalars
+
+ def latest_with_smoothing_hint(self, window_size=20):
+ """
+        Similar to :meth:`latest`, but the returned values
+        are either the un-smoothed original latest value
+        or a median over the given window_size,
+        depending on whether the smoothing_hint is True.
+ This provides a default behavior that other writers can use.
+ """
+ result = {}
+ for k, (v, itr) in self._latest_scalars.items():
+ result[k] = (
+ self._history[k].median(window_size) if self._smoothing_hints[k] else v,
+ itr,
+ )
+ return result
+
+ def smoothing_hints(self):
+ """
+ Returns:
+ dict[name -> bool]: the user-provided hint on whether the scalar
+ is noisy and needs smoothing.
+ """
+ return self._smoothing_hints
+
+ def step(self):
+ """
+        Users should either (1) call this function to increment storage.iter when needed, or
+        (2) set `storage.iter` to the correct iteration number before each iteration.
+ The storage will then be able to associate the new data with an iteration number.
+ """
+ self._iter += 1
+
+ @property
+ def iter(self):
+ """
+ Returns:
+ int: The current iteration number. When used together with a trainer,
+ this is ensured to be the same as trainer.iter.
+ """
+ return self._iter
+
+ @iter.setter
+ def iter(self, val):
+ self._iter = int(val)
+
+ @property
+ def iteration(self):
+ # for backward compatibility
+ return self._iter
+
+ def __enter__(self):
+ _CURRENT_STORAGE_STACK.append(self)
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ assert _CURRENT_STORAGE_STACK[-1] == self
+ _CURRENT_STORAGE_STACK.pop()
+
+ @contextmanager
+ def name_scope(self, name):
+ """
+ Yields:
+ A context within which all the events added to this storage
+ will be prefixed by the name scope.
+ """
+ old_prefix = self._current_prefix
+ self._current_prefix = name.rstrip("/") + "/"
+ yield
+ self._current_prefix = old_prefix
+
+ def clear_images(self):
+ """
+ Delete all the stored images for visualization. This should be called
+ after images are written to tensorboard.
+ """
+ self._vis_data = []
+
+ def clear_histograms(self):
+ """
+ Delete all the stored histograms for visualization.
+ This should be called after histograms are written to tensorboard.
+ """
+ self._histograms = []
+
+ def reset_history(self, name):
+ ret = self._history.get(name, None)
+ if ret is None:
+ raise KeyError("No history metric available for {}!".format(name))
+ ret.reset()
+
+ def reset_histories(self):
+ for name in self._history.keys():
+ self._history[name].reset()
+
+
+class AverageMeter:
+ """Computes and stores the average and current value"""
+
+ def __init__(self):
+ self.val = 0
+ self.avg = 0
+ self.total = 0
+ self.count = 0
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.total = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.total += val * n
+ self.count += n
+ self.avg = self.total / self.count
+
+
+class HistoryBuffer:
+ """
+ Track a series of scalar values and provide access to smoothed values over a
+ window or the global average of the series.
+ """
+
+ def __init__(self, max_length: int = 1000000) -> None:
+ """
+ Args:
+ max_length: maximal number of values that can be stored in the
+ buffer. When the capacity of the buffer is exhausted, old
+ values will be removed.
+ """
+ self._max_length: int = max_length
+ self._data: List[Tuple[float, float]] = [] # (value, iteration) pairs
+ self._count: int = 0
+ self._global_avg: float = 0
+
+ def update(self, value: float, iteration: Optional[float] = None) -> None:
+ """
+        Add a new scalar value produced at a certain iteration. If the length
+ of the buffer exceeds self._max_length, the oldest element will be
+ removed from the buffer.
+ """
+ if iteration is None:
+ iteration = self._count
+ if len(self._data) == self._max_length:
+ self._data.pop(0)
+ self._data.append((value, iteration))
+
+ self._count += 1
+ self._global_avg += (value - self._global_avg) / self._count
+
+ def latest(self) -> float:
+ """
+ Return the latest scalar value added to the buffer.
+ """
+ return self._data[-1][0]
+
+ def median(self, window_size: int) -> float:
+ """
+ Return the median of the latest `window_size` values in the buffer.
+ """
+ return np.median([x[0] for x in self._data[-window_size:]])
+
+ def avg(self, window_size: int) -> float:
+ """
+ Return the mean of the latest `window_size` values in the buffer.
+ """
+ return np.mean([x[0] for x in self._data[-window_size:]])
+
+ def global_avg(self) -> float:
+ """
+ Return the mean of all the elements in the buffer. Note that this
+ includes those getting removed due to limited buffer storage.
+ """
+ return self._global_avg
+
+ def values(self) -> List[Tuple[float, float]]:
+ """
+ Returns:
+ list[(number, iteration)]: content of the current buffer.
+ """
+ return self._data
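+
+# Editor's sketch (not part of the original file): HistoryBuffer eviction
+# and smoothing behaviour with illustrative values.
+#
+#   buf = HistoryBuffer(max_length=3)
+#   for v in [1.0, 2.0, 3.0, 4.0]:
+#       buf.update(v)
+#   buf.values()      # [(2.0, 1), (3.0, 2), (4.0, 3)] -- oldest pair evicted
+#   buf.median(2)     # 3.5, median of the last two values
+#   buf.global_avg()  # 2.5, still counts the evicted 1.0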
diff --git a/UniRig/src/model/pointcept/utils/logger.py b/UniRig/src/model/pointcept/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddaf2c5a765c9f1325737c3cbc73e1169f13cdd4
--- /dev/null
+++ b/UniRig/src/model/pointcept/utils/logger.py
@@ -0,0 +1,172 @@
+"""
+Logger Utils
+
+Modified from mmcv
+
+Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
+Please cite our work if the code is helpful to you.
+"""
+
+import logging
+import torch
+import torch.distributed as dist
+
+from termcolor import colored
+
+logger_initialized = {}
+root_status = 0
+
+
+class _ColorfulFormatter(logging.Formatter):
+ def __init__(self, *args, **kwargs):
+ self._root_name = kwargs.pop("root_name") + "."
+ super(_ColorfulFormatter, self).__init__(*args, **kwargs)
+
+ def formatMessage(self, record):
+ log = super(_ColorfulFormatter, self).formatMessage(record)
+ if record.levelno == logging.WARNING:
+ prefix = colored("WARNING", "red", attrs=["blink"])
+ elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
+ prefix = colored("ERROR", "red", attrs=["blink", "underline"])
+ else:
+ return log
+ return prefix + " " + log
+
+
+def get_logger(name, log_file=None, log_level=logging.INFO, file_mode="a", color=False):
+ """Initialize and get a logger by name.
+
+ If the logger has not been initialized, this method will initialize the
+ logger by adding one or two handlers, otherwise the initialized logger will
+ be directly returned. During initialization, a StreamHandler will always be
+ added. If `log_file` is specified and the process rank is 0, a FileHandler
+ will also be added.
+
+ Args:
+ name (str): Logger name.
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the logger.
+        log_level (int): The logger level. Note that only the process of
+            rank 0 is affected; other processes will set the level to
+            "ERROR" and thus be silent most of the time.
+        file_mode (str): The file mode used in opening the log file.
+            Defaults to 'a'.
+        color (bool): Colorful log output. Defaults to False.
+
+ Returns:
+ logging.Logger: The expected logger.
+ """
+ logger = logging.getLogger(name)
+
+ if name in logger_initialized:
+ return logger
+ # handle hierarchical names
+ # e.g., logger "a" is initialized, then logger "a.b" will skip the
+ # initialization since it is a child of "a".
+ for logger_name in logger_initialized:
+ if name.startswith(logger_name):
+ return logger
+
+ logger.propagate = False
+
+ stream_handler = logging.StreamHandler()
+ handlers = [stream_handler]
+
+ if dist.is_available() and dist.is_initialized():
+ rank = dist.get_rank()
+ else:
+ rank = 0
+
+ # only rank 0 will add a FileHandler
+ if rank == 0 and log_file is not None:
+        # The log file is opened in append mode ('a') by default; the
+        # `file_mode` argument lets callers override this behaviour.
+ file_handler = logging.FileHandler(log_file, file_mode)
+ handlers.append(file_handler)
+
+ plain_formatter = logging.Formatter(
+ "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s"
+ )
+ if color:
+ formatter = _ColorfulFormatter(
+ colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
+ datefmt="%m/%d %H:%M:%S",
+ root_name=name,
+ )
+ else:
+ formatter = plain_formatter
+ for handler in handlers:
+ handler.setFormatter(formatter)
+ handler.setLevel(log_level)
+ logger.addHandler(handler)
+
+ if rank == 0:
+ logger.setLevel(log_level)
+ else:
+ logger.setLevel(logging.ERROR)
+
+ logger_initialized[name] = True
+
+ return logger
+
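+# Editor's sketch (not part of the original file): the log file path is
+# hypothetical.
+#
+#   logger = get_logger("pointcept", log_file="train.log", color=True)
+#   logger.info("hello")   # rank 0 logs to the stream and to train.log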
+
+def print_log(msg, logger=None, level=logging.INFO):
+ """Print a log message.
+
+ Args:
+ msg (str): The message to be logged.
+ logger (logging.Logger | str | None): The logger to be used.
+ Some special loggers are:
+ - "silent": no message will be printed.
+ - other str: the logger obtained with `get_root_logger(logger)`.
+ - None: The `print()` method will be used to print log messages.
+ level (int): Logging level. Only available when `logger` is a Logger
+ object or "root".
+ """
+ if logger is None:
+ print(msg)
+ elif isinstance(logger, logging.Logger):
+ logger.log(level, msg)
+ elif logger == "silent":
+ pass
+ elif isinstance(logger, str):
+ _logger = get_logger(logger)
+ _logger.log(level, msg)
+ else:
+ raise TypeError(
+ "logger should be either a logging.Logger object, str, "
+ f'"silent" or None, but got {type(logger)}'
+ )
+
+
+def get_root_logger(log_file=None, log_level=logging.INFO, file_mode="a"):
+ """Get the root logger.
+
+ The logger will be initialized if it has not been initialized. By default a
+ StreamHandler will be added. If `log_file` is specified, a FileHandler will
+ also be added. The name of the root logger is the top-level package name.
+
+ Args:
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the root logger.
+ log_level (int): The root logger level. Note that only the process of
+ rank 0 is affected, while other processes will set the level to
+ "Error" and be silent most of the time.
+        file_mode (str): File mode of the logger ('w' or 'a').
+
+ Returns:
+ logging.Logger: The root logger.
+ """
+ logger = get_logger(
+ name="pointcept", log_file=log_file, log_level=log_level, file_mode=file_mode
+ )
+ return logger
+
+
+def _log_api_usage(identifier: str):
+ """
+ Internal function used to log the usage of different detectron2 components
+ inside facebook's infra.
+ """
+ torch._C._log_api_usage_once("pointcept." + identifier)
diff --git a/UniRig/src/model/pointcept/utils/misc.py b/UniRig/src/model/pointcept/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..d890275ac26e1ac23c88e33d00065050a2230bac
--- /dev/null
+++ b/UniRig/src/model/pointcept/utils/misc.py
@@ -0,0 +1,159 @@
+"""
+Misc
+
+Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
+Please cite our work if the code is helpful to you.
+"""
+
+import os
+import warnings
+from collections import abc
+import numpy as np
+import torch
+from importlib import import_module
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+
+ def __init__(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+def intersection_and_union(output, target, K, ignore_index=-1):
+ # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
+ assert output.ndim in [1, 2, 3]
+ assert output.shape == target.shape
+ output = output.reshape(output.size).copy()
+ target = target.reshape(target.size)
+ output[np.where(target == ignore_index)[0]] = ignore_index
+ intersection = output[np.where(output == target)[0]]
+ area_intersection, _ = np.histogram(intersection, bins=np.arange(K + 1))
+ area_output, _ = np.histogram(output, bins=np.arange(K + 1))
+ area_target, _ = np.histogram(target, bins=np.arange(K + 1))
+ area_union = area_output + area_target - area_intersection
+ return area_intersection, area_union, area_target
+
+
+def intersection_and_union_gpu(output, target, k, ignore_index=-1):
+ # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
+ assert output.dim() in [1, 2, 3]
+ assert output.shape == target.shape
+ output = output.view(-1)
+ target = target.view(-1)
+ output[target == ignore_index] = ignore_index
+ intersection = output[output == target]
+ area_intersection = torch.histc(intersection, bins=k, min=0, max=k - 1)
+ area_output = torch.histc(output, bins=k, min=0, max=k - 1)
+ area_target = torch.histc(target, bins=k, min=0, max=k - 1)
+ area_union = area_output + area_target - area_intersection
+ return area_intersection, area_union, area_target
+
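+# Editor's sketch (not part of the original file): a common mIoU reduction
+# over the returned per-class areas; the epsilon and the pred/gt arrays are
+# assumptions.
+#
+#   inter, union, target = intersection_and_union(pred, gt, K=num_classes)
+#   iou_class = inter / (union + 1e-10)
+#   miou = iou_class.mean()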
+
+def make_dirs(dir_name):
+ if not os.path.exists(dir_name):
+ os.makedirs(dir_name, exist_ok=True)
+
+
+def find_free_port():
+ import socket
+
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ # Binding to port 0 will cause the OS to find an available port for us
+ sock.bind(("", 0))
+ port = sock.getsockname()[1]
+ sock.close()
+ # NOTE: there is still a chance the port could be taken by other processes.
+ return port
+
+
+def is_seq_of(seq, expected_type, seq_type=None):
+ """Check whether it is a sequence of some type.
+
+ Args:
+ seq (Sequence): The sequence to be checked.
+ expected_type (type): Expected type of sequence items.
+ seq_type (type, optional): Expected sequence type.
+
+ Returns:
+ bool: Whether the sequence is valid.
+ """
+ if seq_type is None:
+ exp_seq_type = abc.Sequence
+ else:
+ assert isinstance(seq_type, type)
+ exp_seq_type = seq_type
+ if not isinstance(seq, exp_seq_type):
+ return False
+ for item in seq:
+ if not isinstance(item, expected_type):
+ return False
+ return True
+
+
+def is_str(x):
+ """Whether the input is an string instance.
+
+ Note: This method is deprecated since python 2 is no longer supported.
+ """
+ return isinstance(x, str)
+
+
+def import_modules_from_strings(imports, allow_failed_imports=False):
+ """Import modules from the given list of strings.
+
+ Args:
+ imports (list | str | None): The given module names to be imported.
+ allow_failed_imports (bool): If True, the failed imports will return
+            None. Otherwise, an ImportError is raised. Default: False.
+
+ Returns:
+ list[module] | module | None: The imported modules.
+
+ Examples:
+ >>> osp, sys = import_modules_from_strings(
+ ... ['os.path', 'sys'])
+ >>> import os.path as osp_
+ >>> import sys as sys_
+ >>> assert osp == osp_
+ >>> assert sys == sys_
+ """
+ if not imports:
+ return
+ single_import = False
+ if isinstance(imports, str):
+ single_import = True
+ imports = [imports]
+ if not isinstance(imports, list):
+ raise TypeError(f"custom_imports must be a list but got type {type(imports)}")
+ imported = []
+ for imp in imports:
+ if not isinstance(imp, str):
+ raise TypeError(f"{imp} is of type {type(imp)} and cannot be imported.")
+ try:
+ imported_tmp = import_module(imp)
+ except ImportError:
+ if allow_failed_imports:
+ warnings.warn(f"{imp} failed to import and is ignored.", UserWarning)
+ imported_tmp = None
+ else:
+                raise ImportError(f"Failed to import {imp}")
+ imported.append(imported_tmp)
+ if single_import:
+ imported = imported[0]
+ return imported
diff --git a/UniRig/src/model/pointcept/utils/optimizer.py b/UniRig/src/model/pointcept/utils/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc2a3fe00a7e7ac3751926568d6112a964e13930
--- /dev/null
+++ b/UniRig/src/model/pointcept/utils/optimizer.py
@@ -0,0 +1,67 @@
+"""
+Optimizer
+
+Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
+Please cite our work if the code is helpful to you.
+"""
+
+import torch
+from pointcept.utils.logger import get_root_logger
+from pointcept.utils.registry import Registry
+
+OPTIMIZERS = Registry("optimizers")
+
+
+OPTIMIZERS.register_module(module=torch.optim.SGD, name="SGD")
+OPTIMIZERS.register_module(module=torch.optim.Adam, name="Adam")
+OPTIMIZERS.register_module(module=torch.optim.AdamW, name="AdamW")
+
+
+def build_optimizer(cfg, model, param_dicts=None):
+ if param_dicts is None:
+ # cfg.params = model.parameters()
+ cfg.params = [dict(names=[], params=[], lr=cfg.lr)]
+ for n, p in model.named_parameters():
+ if not p.requires_grad:
+ continue
+ cfg.params[0]["names"].append(n)
+ cfg.params[0]["params"].append(p)
+ else:
+ cfg.params = [dict(names=[], params=[], lr=cfg.lr)]
+ for i in range(len(param_dicts)):
+ param_group = dict(names=[], params=[])
+ if "lr" in param_dicts[i].keys():
+ param_group["lr"] = param_dicts[i].lr
+ if "momentum" in param_dicts[i].keys():
+ param_group["momentum"] = param_dicts[i].momentum
+ if "weight_decay" in param_dicts[i].keys():
+ param_group["weight_decay"] = param_dicts[i].weight_decay
+ cfg.params.append(param_group)
+
+ for n, p in model.named_parameters():
+            # NOTE: parameters with requires_grad=False must be skipped
+ if not p.requires_grad:
+ continue
+ flag = False
+ for i in range(len(param_dicts)):
+ if param_dicts[i].keyword in n:
+ cfg.params[i + 1]["names"].append(n)
+ cfg.params[i + 1]["params"].append(p)
+ flag = True
+ break
+ if not flag:
+ cfg.params[0]["names"].append(n)
+ cfg.params[0]["params"].append(p)
+
+ logger = get_root_logger()
+
+ for i in range(len(cfg.params)):
+ param_names = cfg.params[i].pop("names")
+ message = ""
+ for key in cfg.params[i].keys():
+ if key != "params":
+ message += f" {key}: {cfg.params[i][key]};"
+ logger.info(f"Params Group {i+1} -{message} Params: {param_names}.")
+ return OPTIMIZERS.build(cfg=cfg)
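+
+# Editor's sketch (not part of the original file): illustrative values only.
+# Entries in `param_dicts` match parameter names by substring via `keyword`;
+# ConfigDict here refers to the config utils in this repo, which allow the
+# attribute access build_optimizer relies on.
+#
+#   cfg = ConfigDict(dict(type="AdamW", lr=1e-3, weight_decay=0.05))
+#   param_dicts = [ConfigDict(dict(keyword="backbone", lr=1e-4))]
+#   optimizer = build_optimizer(cfg, model, param_dicts)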
diff --git a/UniRig/src/model/pointcept/utils/path.py b/UniRig/src/model/pointcept/utils/path.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce98fa5fd0dfbf6e1d61e833ecc35fea4ab2782b
--- /dev/null
+++ b/UniRig/src/model/pointcept/utils/path.py
@@ -0,0 +1,103 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import os.path as osp
+from pathlib import Path
+
+from .misc import is_str
+
+
+def is_filepath(x):
+ return is_str(x) or isinstance(x, Path)
+
+
+def fopen(filepath, *args, **kwargs):
+ if is_str(filepath):
+ return open(filepath, *args, **kwargs)
+ elif isinstance(filepath, Path):
+ return filepath.open(*args, **kwargs)
+ raise ValueError("`filepath` should be a string or a Path")
+
+
+def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
+ if not osp.isfile(filename):
+ raise FileNotFoundError(msg_tmpl.format(filename))
+
+
+def mkdir_or_exist(dir_name, mode=0o777):
+ if dir_name == "":
+ return
+ dir_name = osp.expanduser(dir_name)
+ os.makedirs(dir_name, mode=mode, exist_ok=True)
+
+
+def symlink(src, dst, overwrite=True, **kwargs):
+ if os.path.lexists(dst) and overwrite:
+ os.remove(dst)
+ os.symlink(src, dst, **kwargs)
+
+
+def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
+ """Scan a directory to find the interested files.
+
+ Args:
+ dir_path (str | obj:`Path`): Path of the directory.
+ suffix (str | tuple(str), optional): File suffix that we are
+ interested in. Default: None.
+ recursive (bool, optional): If set to True, recursively scan the
+ directory. Default: False.
+ case_sensitive (bool, optional) : If set to False, ignore the case of
+ suffix. Default: True.
+
+ Returns:
+ A generator for all the interested files with relative paths.
+ """
+ if isinstance(dir_path, (str, Path)):
+ dir_path = str(dir_path)
+ else:
+ raise TypeError('"dir_path" must be a string or Path object')
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('"suffix" must be a string or tuple of strings')
+
+ if suffix is not None and not case_sensitive:
+ suffix = (
+ suffix.lower()
+ if isinstance(suffix, str)
+ else tuple(item.lower() for item in suffix)
+ )
+
+ root = dir_path
+
+ def _scandir(dir_path, suffix, recursive, case_sensitive):
+ for entry in os.scandir(dir_path):
+ if not entry.name.startswith(".") and entry.is_file():
+ rel_path = osp.relpath(entry.path, root)
+ _rel_path = rel_path if case_sensitive else rel_path.lower()
+ if suffix is None or _rel_path.endswith(suffix):
+ yield rel_path
+ elif recursive and os.path.isdir(entry.path):
+ # scan recursively if entry.path is a directory
+ yield from _scandir(entry.path, suffix, recursive, case_sensitive)
+
+ return _scandir(dir_path, suffix, recursive, case_sensitive)
+
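+# Editor's sketch (not part of the original file): directory layout is
+# hypothetical.
+#
+#   for rel_path in scandir("configs", suffix=(".py", ".yaml"), recursive=True):
+#       print(rel_path)   # paths relative to "configs", dotfiles skipped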
+
+def find_vcs_root(path, markers=(".git",)):
+ """Finds the root directory (including itself) of specified markers.
+
+ Args:
+ path (str): Path of directory or file.
+ markers (list[str], optional): List of file or directory names.
+
+ Returns:
+        The directory containing one of the markers, or None if not found.
+ """
+ if osp.isfile(path):
+ path = osp.dirname(path)
+
+ prev, cur = None, osp.abspath(osp.expanduser(path))
+ while cur != prev:
+ if any(osp.exists(osp.join(cur, marker)) for marker in markers):
+ return cur
+ prev, cur = cur, osp.split(cur)[0]
+ return None
diff --git a/UniRig/src/model/pointcept/utils/registry.py b/UniRig/src/model/pointcept/utils/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ac308a87d38ff61da14d6b4d5c73b4c68c15a58
--- /dev/null
+++ b/UniRig/src/model/pointcept/utils/registry.py
@@ -0,0 +1,316 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+import warnings
+from functools import partial
+
+from .misc import is_seq_of
+
+
+def build_from_cfg(cfg, registry, default_args=None):
+ """Build a module from configs dict.
+
+ Args:
+ cfg (dict): Config dict. It should at least contain the key "type".
+ registry (:obj:`Registry`): The registry to search the type from.
+ default_args (dict, optional): Default initialization arguments.
+
+ Returns:
+ object: The constructed object.
+ """
+ if not isinstance(cfg, dict):
+ raise TypeError(f"cfg must be a dict, but got {type(cfg)}")
+ if "type" not in cfg:
+ if default_args is None or "type" not in default_args:
+ raise KeyError(
+ '`cfg` or `default_args` must contain the key "type", '
+ f"but got {cfg}\n{default_args}"
+ )
+ if not isinstance(registry, Registry):
+ raise TypeError(
+ "registry must be an mmcv.Registry object, " f"but got {type(registry)}"
+ )
+ if not (isinstance(default_args, dict) or default_args is None):
+ raise TypeError(
+ "default_args must be a dict or None, " f"but got {type(default_args)}"
+ )
+
+ args = cfg.copy()
+
+ if default_args is not None:
+ for name, value in default_args.items():
+ args.setdefault(name, value)
+
+ obj_type = args.pop("type")
+ if isinstance(obj_type, str):
+ obj_cls = registry.get(obj_type)
+ if obj_cls is None:
+ raise KeyError(f"{obj_type} is not in the {registry.name} registry")
+ elif inspect.isclass(obj_type):
+ obj_cls = obj_type
+ else:
+ raise TypeError(f"type must be a str or valid type, but got {type(obj_type)}")
+ try:
+ return obj_cls(**args)
+ except Exception as e:
+ # Normal TypeError does not print class name.
+ raise type(e)(f"{obj_cls.__name__}: {e}")
+
+
+class Registry:
+ """A registry to map strings to classes.
+
+ Registered object could be built from registry.
+ Example:
+ >>> MODELS = Registry('models')
+ >>> @MODELS.register_module()
+ >>> class ResNet:
+ >>> pass
+ >>> resnet = MODELS.build(dict(type='ResNet'))
+
+ Please refer to
+ https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
+ advanced usage.
+
+ Args:
+ name (str): Registry name.
+ build_func(func, optional): Build function to construct instance from
+            Registry; :func:`build_from_cfg` is used if neither ``parent`` nor
+ ``build_func`` is specified. If ``parent`` is specified and
+ ``build_func`` is not given, ``build_func`` will be inherited
+ from ``parent``. Default: None.
+ parent (Registry, optional): Parent registry. The class registered in
+ children registry could be built from parent. Default: None.
+ scope (str, optional): The scope of registry. It is the key to search
+ for children registry. If not specified, scope will be the name of
+ the package where class is defined, e.g. mmdet, mmcls, mmseg.
+ Default: None.
+ """
+
+ def __init__(self, name, build_func=None, parent=None, scope=None):
+ self._name = name
+ self._module_dict = dict()
+ self._children = dict()
+ self._scope = self.infer_scope() if scope is None else scope
+
+ # self.build_func will be set with the following priority:
+ # 1. build_func
+ # 2. parent.build_func
+ # 3. build_from_cfg
+ if build_func is None:
+ if parent is not None:
+ self.build_func = parent.build_func
+ else:
+ self.build_func = build_from_cfg
+ else:
+ self.build_func = build_func
+ if parent is not None:
+ assert isinstance(parent, Registry)
+ parent._add_children(self)
+ self.parent = parent
+ else:
+ self.parent = None
+
+ def __len__(self):
+ return len(self._module_dict)
+
+ def __contains__(self, key):
+ return self.get(key) is not None
+
+ def __repr__(self):
+ format_str = (
+ self.__class__.__name__ + f"(name={self._name}, "
+ f"items={self._module_dict})"
+ )
+ return format_str
+
+ @staticmethod
+ def infer_scope():
+ """Infer the scope of registry.
+
+ The name of the package where registry is defined will be returned.
+
+ Example:
+ # in mmdet/models/backbone/resnet.py
+ >>> MODELS = Registry('models')
+ >>> @MODELS.register_module()
+ >>> class ResNet:
+ >>> pass
+ The scope of ``ResNet`` will be ``mmdet``.
+
+
+ Returns:
+ scope (str): The inferred scope name.
+ """
+        # inspect.stack() traces where this function is called; index 2
+        # is the frame in which `infer_scope()` is invoked
+ filename = inspect.getmodule(inspect.stack()[2][0]).__name__
+ split_filename = filename.split(".")
+ return split_filename[0]
+
+ @staticmethod
+ def split_scope_key(key):
+ """Split scope and key.
+
+ The first scope will be split from key.
+
+ Examples:
+ >>> Registry.split_scope_key('mmdet.ResNet')
+ 'mmdet', 'ResNet'
+ >>> Registry.split_scope_key('ResNet')
+ None, 'ResNet'
+
+ Return:
+ scope (str, None): The first scope.
+ key (str): The remaining key.
+ """
+ split_index = key.find(".")
+ if split_index != -1:
+ return key[:split_index], key[split_index + 1 :]
+ else:
+ return None, key
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def scope(self):
+ return self._scope
+
+ @property
+ def module_dict(self):
+ return self._module_dict
+
+ @property
+ def children(self):
+ return self._children
+
+ def get(self, key):
+ """Get the registry record.
+
+ Args:
+ key (str): The class name in string format.
+
+ Returns:
+ class: The corresponding class.
+ """
+ scope, real_key = self.split_scope_key(key)
+ if scope is None or scope == self._scope:
+ # get from self
+ if real_key in self._module_dict:
+ return self._module_dict[real_key]
+ else:
+ # get from self._children
+ if scope in self._children:
+ return self._children[scope].get(real_key)
+ else:
+ # goto root
+ parent = self.parent
+ while parent.parent is not None:
+ parent = parent.parent
+ return parent.get(key)
+
+ def build(self, *args, **kwargs):
+ return self.build_func(*args, **kwargs, registry=self)
+
+ def _add_children(self, registry):
+ """Add children for a registry.
+
+ The ``registry`` will be added as children based on its scope.
+ The parent registry could build objects from children registry.
+
+ Example:
+ >>> models = Registry('models')
+ >>> mmdet_models = Registry('models', parent=models)
+ >>> @mmdet_models.register_module()
+ >>> class ResNet:
+ >>> pass
+ >>> resnet = models.build(dict(type='mmdet.ResNet'))
+ """
+
+ assert isinstance(registry, Registry)
+ assert registry.scope is not None
+ assert (
+ registry.scope not in self.children
+ ), f"scope {registry.scope} exists in {self.name} registry"
+ self.children[registry.scope] = registry
+
+ def _register_module(self, module_class, module_name=None, force=False):
+ if not inspect.isclass(module_class):
+ raise TypeError("module must be a class, " f"but got {type(module_class)}")
+
+ if module_name is None:
+ module_name = module_class.__name__
+ if isinstance(module_name, str):
+ module_name = [module_name]
+ for name in module_name:
+ if not force and name in self._module_dict:
+ raise KeyError(f"{name} is already registered " f"in {self.name}")
+ self._module_dict[name] = module_class
+
+ def deprecated_register_module(self, cls=None, force=False):
+ warnings.warn(
+ "The old API of register_module(module, force=False) "
+ "is deprecated and will be removed, please use the new API "
+ "register_module(name=None, force=False, module=None) instead."
+ )
+ if cls is None:
+ return partial(self.deprecated_register_module, force=force)
+ self._register_module(cls, force=force)
+ return cls
+
+ def register_module(self, name=None, force=False, module=None):
+ """Register a module.
+
+ A record will be added to `self._module_dict`, whose key is the class
+ name or the specified name, and value is the class itself.
+ It can be used as a decorator or a normal function.
+
+ Example:
+ >>> backbones = Registry('backbone')
+ >>> @backbones.register_module()
+ >>> class ResNet:
+ >>> pass
+
+ >>> backbones = Registry('backbone')
+ >>> @backbones.register_module(name='mnet')
+ >>> class MobileNet:
+ >>> pass
+
+ >>> backbones = Registry('backbone')
+ >>> class ResNet:
+ >>> pass
+ >>> backbones.register_module(ResNet)
+
+ Args:
+ name (str | None): The module name to be registered. If not
+ specified, the class name will be used.
+ force (bool, optional): Whether to override an existing class with
+ the same name. Default: False.
+ module (type): Module class to be registered.
+ """
+ if not isinstance(force, bool):
+ raise TypeError(f"force must be a boolean, but got {type(force)}")
+        # NOTE: This is a workaround to stay compatible with the old API;
+        # it may introduce unexpected bugs.
+ if isinstance(name, type):
+ return self.deprecated_register_module(name, force=force)
+
+ # raise the error ahead of time
+ if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
+ raise TypeError(
+ "name must be either of None, an instance of str or a sequence"
+ f" of str, but got {type(name)}"
+ )
+
+ # use it as a normal method: x.register_module(module=SomeClass)
+ if module is not None:
+ self._register_module(module_class=module, module_name=name, force=force)
+ return module
+
+ # use it as a decorator: @x.register_module()
+ def _register(cls):
+ self._register_module(module_class=cls, module_name=name, force=force)
+ return cls
+
+ return _register
diff --git a/UniRig/src/model/pointcept/utils/scheduler.py b/UniRig/src/model/pointcept/utils/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e2e29fdde2e2668c023af36afdb89e73fb9ce53
--- /dev/null
+++ b/UniRig/src/model/pointcept/utils/scheduler.py
@@ -0,0 +1,147 @@
+"""
+Scheduler
+
+Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
+Please cite our work if the code is helpful to you.
+"""
+
+import torch.optim.lr_scheduler as lr_scheduler
+from .registry import Registry
+
+SCHEDULERS = Registry("schedulers")
+
+
+@SCHEDULERS.register_module()
+class MultiStepLR(lr_scheduler.MultiStepLR):
+ def __init__(
+ self,
+ optimizer,
+ milestones,
+ total_steps,
+ gamma=0.1,
+ last_epoch=-1,
+ verbose=False,
+ ):
+ super().__init__(
+ optimizer=optimizer,
+ milestones=[rate * total_steps for rate in milestones],
+ gamma=gamma,
+ last_epoch=last_epoch,
+ verbose=verbose,
+ )
+
+
+@SCHEDULERS.register_module()
+class MultiStepWithWarmupLR(lr_scheduler.LambdaLR):
+ def __init__(
+ self,
+ optimizer,
+ milestones,
+ total_steps,
+ gamma=0.1,
+ warmup_rate=0.05,
+ warmup_scale=1e-6,
+ last_epoch=-1,
+ verbose=False,
+ ):
+ milestones = [rate * total_steps for rate in milestones]
+
+ def multi_step_with_warmup(s):
+ factor = 1.0
+ for i in range(len(milestones)):
+ if s < milestones[i]:
+ break
+ factor *= gamma
+
+ if s <= warmup_rate * total_steps:
+ warmup_coefficient = 1 - (1 - s / warmup_rate / total_steps) * (
+ 1 - warmup_scale
+ )
+ else:
+ warmup_coefficient = 1.0
+ return warmup_coefficient * factor
+
+ super().__init__(
+ optimizer=optimizer,
+ lr_lambda=multi_step_with_warmup,
+ last_epoch=last_epoch,
+ verbose=verbose,
+ )
+
+
+@SCHEDULERS.register_module()
+class PolyLR(lr_scheduler.LambdaLR):
+ def __init__(self, optimizer, total_steps, power=0.9, last_epoch=-1, verbose=False):
+ super().__init__(
+ optimizer=optimizer,
+ lr_lambda=lambda s: (1 - s / (total_steps + 1)) ** power,
+ last_epoch=last_epoch,
+ verbose=verbose,
+ )
+
+
+@SCHEDULERS.register_module()
+class ExpLR(lr_scheduler.LambdaLR):
+ def __init__(self, optimizer, total_steps, gamma=0.9, last_epoch=-1, verbose=False):
+ super().__init__(
+ optimizer=optimizer,
+ lr_lambda=lambda s: gamma ** (s / total_steps),
+ last_epoch=last_epoch,
+ verbose=verbose,
+ )
+
+
+@SCHEDULERS.register_module()
+class CosineAnnealingLR(lr_scheduler.CosineAnnealingLR):
+ def __init__(self, optimizer, total_steps, eta_min=0, last_epoch=-1, verbose=False):
+ super().__init__(
+ optimizer=optimizer,
+ T_max=total_steps,
+ eta_min=eta_min,
+ last_epoch=last_epoch,
+ verbose=verbose,
+ )
+
+
+@SCHEDULERS.register_module()
+class OneCycleLR(lr_scheduler.OneCycleLR):
+ r"""
+    torch.optim.lr_scheduler.OneCycleLR with ``total_steps`` as the
+    unified step-count argument expected by :func:`build_scheduler`.
+ """
+
+ def __init__(
+ self,
+ optimizer,
+ max_lr,
+ total_steps=None,
+ pct_start=0.3,
+ anneal_strategy="cos",
+ cycle_momentum=True,
+ base_momentum=0.85,
+ max_momentum=0.95,
+ div_factor=25.0,
+ final_div_factor=1e4,
+ three_phase=False,
+ last_epoch=-1,
+ verbose=False,
+ ):
+ super().__init__(
+ optimizer=optimizer,
+ max_lr=max_lr,
+ total_steps=total_steps,
+ pct_start=pct_start,
+ anneal_strategy=anneal_strategy,
+ cycle_momentum=cycle_momentum,
+ base_momentum=base_momentum,
+ max_momentum=max_momentum,
+ div_factor=div_factor,
+ final_div_factor=final_div_factor,
+ three_phase=three_phase,
+ last_epoch=last_epoch,
+ verbose=verbose,
+ )
+
+
+def build_scheduler(cfg, optimizer):
+ cfg.optimizer = optimizer
+ return SCHEDULERS.build(cfg=cfg)
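+
+# Editor's sketch (not part of the original file): illustrative values only.
+# `cfg` must be attribute-assignable (e.g. a ConfigDict) since
+# build_scheduler sets cfg.optimizer before building.
+#
+#   cfg = ConfigDict(dict(type="OneCycleLR", max_lr=1e-3, total_steps=10000))
+#   scheduler = build_scheduler(cfg, optimizer)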
diff --git a/UniRig/src/model/pointcept/utils/timer.py b/UniRig/src/model/pointcept/utils/timer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3de4a16e33c43fe61ea3088f82216fd62eb6e959
--- /dev/null
+++ b/UniRig/src/model/pointcept/utils/timer.py
@@ -0,0 +1,70 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+# -*- coding: utf-8 -*-
+
+from time import perf_counter
+from typing import Optional
+
+
+class Timer:
+ """
+ A timer which computes the time elapsed since the start/reset of the timer.
+ """
+
+ def __init__(self) -> None:
+ self.reset()
+
+ def reset(self) -> None:
+ """
+ Reset the timer.
+ """
+ self._start = perf_counter()
+ self._paused: Optional[float] = None
+ self._total_paused = 0
+ self._count_start = 1
+
+ def pause(self) -> None:
+ """
+ Pause the timer.
+ """
+ if self._paused is not None:
+ raise ValueError("Trying to pause a Timer that is already paused!")
+ self._paused = perf_counter()
+
+ def is_paused(self) -> bool:
+ """
+ Returns:
+ bool: whether the timer is currently paused
+ """
+ return self._paused is not None
+
+ def resume(self) -> None:
+ """
+ Resume the timer.
+ """
+ if self._paused is None:
+ raise ValueError("Trying to resume a Timer that is not paused!")
+ # pyre-fixme[58]: `-` is not supported for operand types `float` and
+ # `Optional[float]`.
+ self._total_paused += perf_counter() - self._paused
+ self._paused = None
+ self._count_start += 1
+
+ def seconds(self) -> float:
+ """
+ Returns:
+ (float): the total number of seconds since the start/reset of the
+ timer, excluding the time when the timer is paused.
+ """
+ if self._paused is not None:
+ end_time: float = self._paused # type: ignore
+ else:
+ end_time = perf_counter()
+ return end_time - self._start - self._total_paused
+
+ def avg_seconds(self) -> float:
+ """
+ Returns:
+ (float): the average number of seconds between every start/reset and
+ pause.
+ """
+ return self.seconds() / self._count_start
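+
+# Editor's sketch (not part of the original file): pausing excludes time
+# from seconds() and increments the start count used by avg_seconds().
+#
+#   timer = Timer()
+#   ...                 # timed work
+#   timer.pause()
+#   ...                 # excluded from timing
+#   timer.resume()
+#   elapsed = timer.seconds()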
diff --git a/UniRig/src/model/pointcept/utils/visualization.py b/UniRig/src/model/pointcept/utils/visualization.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a010dd8289f60119d1bfbccdff65edb908e24f6
--- /dev/null
+++ b/UniRig/src/model/pointcept/utils/visualization.py
@@ -0,0 +1,89 @@
+"""
+Visualization Utils
+
+Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
+Please cite our work if the code is helpful to you.
+"""
+
+import os
+import open3d as o3d
+import numpy as np
+import torch
+
+
+def to_numpy(x):
+ if isinstance(x, torch.Tensor):
+ x = x.clone().detach().cpu().numpy()
+ assert isinstance(x, np.ndarray)
+ return x
+
+
+def save_point_cloud(coord, color=None, file_path="pc.ply", logger=None):
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
+ coord = to_numpy(coord)
+ if color is not None:
+ color = to_numpy(color)
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(coord)
+ pcd.colors = o3d.utility.Vector3dVector(
+ np.ones_like(coord) if color is None else color
+ )
+ o3d.io.write_point_cloud(file_path, pcd)
+ if logger is not None:
+ logger.info(f"Save Point Cloud to: {file_path}")
+
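+# Editor's sketch (not part of the original file): random data for
+# illustration.
+#
+#   coord = np.random.rand(1024, 3)
+#   save_point_cloud(coord, file_path="tmp/pc.ply")   # white point cloud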
+
+def save_bounding_boxes(
+ bboxes_corners, color=(1.0, 0.0, 0.0), file_path="bbox.ply", logger=None
+):
+ bboxes_corners = to_numpy(bboxes_corners)
+ # point list
+ points = bboxes_corners.reshape(-1, 3)
+ # line list
+ box_lines = np.array(
+ [
+ [0, 1],
+ [1, 2],
+ [2, 3],
+ [3, 0],
+ [4, 5],
+ [5, 6],
+ [6, 7],
+            [7, 4],
+ [0, 4],
+ [1, 5],
+ [2, 6],
+ [3, 7],
+ ]
+ )
+ lines = []
+ for i, _ in enumerate(bboxes_corners):
+ lines.append(box_lines + i * 8)
+ lines = np.concatenate(lines)
+ # color list
+ color = np.array([color for _ in range(len(lines))])
+ # generate line set
+ line_set = o3d.geometry.LineSet()
+ line_set.points = o3d.utility.Vector3dVector(points)
+ line_set.lines = o3d.utility.Vector2iVector(lines)
+ line_set.colors = o3d.utility.Vector3dVector(color)
+ o3d.io.write_line_set(file_path, line_set)
+
+ if logger is not None:
+ logger.info(f"Save Boxes to: {file_path}")
+
+
+def save_lines(
+ points, lines, color=(1.0, 0.0, 0.0), file_path="lines.ply", logger=None
+):
+ points = to_numpy(points)
+ lines = to_numpy(lines)
+ colors = np.array([color for _ in range(len(lines))])
+ line_set = o3d.geometry.LineSet()
+ line_set.points = o3d.utility.Vector3dVector(points)
+ line_set.lines = o3d.utility.Vector2iVector(lines)
+ line_set.colors = o3d.utility.Vector3dVector(colors)
+ o3d.io.write_line_set(file_path, line_set)
+
+ if logger is not None:
+ logger.info(f"Save Lines to: {file_path}")
diff --git a/UniRig/src/model/spec.py b/UniRig/src/model/spec.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0201e9f779c802dfbe9ccb957f3e3925aca15cc
--- /dev/null
+++ b/UniRig/src/model/spec.py
@@ -0,0 +1,134 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+import numpy as np
+from numpy import ndarray
+from typing import Dict, Union, List, final
+import lightning.pytorch as pl
+
+from ..data.asset import Asset
+from ..data.augment import Augment
+
+@dataclass
+class ModelInput():
+ # tokens for ar input
+ tokens: Union[ndarray, None]=None
+
+ # pad token
+ pad: Union[int, None]=None
+
+ # vertices(usually sampled), (N, 3)
+ vertices: Union[ndarray, None]=None
+
+ # normals(usually sampled), (N, 3)
+ normals: Union[ndarray, None]=None
+
+ # joints
+ joints: Union[ndarray, None]=None
+
+ # tails
+ tails: Union[ndarray, None]=None
+
+    # asset, for debugging
+    asset: Union[Asset, None]=None
+
+    # augments applied to the asset
+    augments: Union[Augment, None]=None
+
+class ModelSpec(pl.LightningModule, ABC):
+
+ @abstractmethod
+ def __init__(self):
+ super().__init__()
+
+ @final
+ def _process_fn(self, batch: List[ModelInput]) -> List[Dict]:
+ '''
+ Returns
+ cls: List[str]
+
+ path: List[str]
+
+ data_name: List[str]
+
+ joints: shape (B, J, 3), J==max_bones
+
+ tails: shape (B, J, 3)
+
+            parents: shape (B, J), -1 represents no parent (appears only at the 0-th position)
+
+ num_bones: shape (B), the true number of bones
+
+ skin: shape (B, J), padding value==0.
+
+ vertices: (B, N, 3)
+
+ normals: (B, N, 3)
+
+ matrix_local: (B, J, 4, 4), current matrix_local
+
+ pose_matrix: (B, J, 4, 4), for motion loss calculation
+ '''
+ n_batch = self.process_fn(batch)
+ BAN = ['cls', 'path', 'data_name', 'joints', 'tails', 'parents', 'num_bones', 'vertices',
+ 'normals', 'matrix_local', 'pose_matrix', 'num_points', 'origin_vertices',
+ 'origin_vertex_normals', 'origin_face_normals', 'num_faces', 'faces']
+ # skin should be in vertex group
+ max_bones = 0
+ max_points = 0
+ max_faces = 0
+ for b in batch:
+ if b.joints is not None:
+ max_bones = max(max_bones, b.asset.J)
+ max_faces = max(max_faces, b.asset.F)
+ max_points = max(max_points, b.asset.N)
+ self._augments = []
+ self._assets = []
+ for (id, b) in enumerate(batch):
+ for ban in BAN:
+ assert ban not in n_batch[id], f"cannot override `{ban}` in process_fn"
+ n_batch[id]['cls'] = b.asset.cls
+ n_batch[id]['path'] = b.asset.path
+ n_batch[id]['data_name'] = b.asset.data_name
+ if b.asset.joints is not None:
+ n_batch[id]['joints'] = np.pad(b.asset.joints, ((0, max_bones-b.asset.J), (0, 0)), mode='constant', constant_values=0.)
+ n_batch[id]['num_bones'] = b.asset.J
+ if b.asset.tails is not None:
+ n_batch[id]['tails'] = np.pad(b.asset.tails, ((0, max_bones-b.asset.J), (0, 0)), mode='constant', constant_values=0.)
+ if b.asset.parents is not None:
+ parents = b.asset.parents.copy() # cannot put None into dict
+ parents[0] = -1
+ parents = np.pad(parents, (0, max_bones-b.asset.J), 'constant', constant_values=-1)
+ n_batch[id]['parents'] = parents
+ if b.asset.matrix_local is not None:
+ J = b.asset.J
+ matrix_local = np.pad(b.asset.matrix_local, ((0, max_bones-J), (0, 0), (0, 0)), 'constant', constant_values=0.)
+ # set identity to prevent singular matrix in lbs
+ matrix_local[J:, 0, 0] = 1.
+ matrix_local[J:, 1, 1] = 1.
+ matrix_local[J:, 2, 2] = 1.
+ matrix_local[J:, 3, 3] = 1.
+ n_batch[id]['matrix_local'] = matrix_local
+ if b.asset.pose_matrix is not None:
+ J = b.asset.J
+ pose_matrix = np.pad(b.asset.pose_matrix, ((0, max_bones-J), (0, 0), (0, 0)), 'constant', constant_values=0.)
+ pose_matrix[J:, 0, 0] = 1.
+ pose_matrix[J:, 1, 1] = 1.
+ pose_matrix[J:, 2, 2] = 1.
+ pose_matrix[J:, 3, 3] = 1.
+ n_batch[id]['pose_matrix'] = pose_matrix
+ n_batch[id]['vertices'] = b.vertices
+ n_batch[id]['normals'] = b.normals
+ n_batch[id]['num_points'] = b.asset.N
+ n_batch[id]['origin_vertices'] = np.pad(b.asset.vertices, ((0, max_points-b.asset.N), (0, 0)))
+ n_batch[id]['origin_vertex_normals'] = np.pad(b.asset.vertex_normals, ((0, max_points-b.asset.N), (0, 0)))
+ n_batch[id]['num_faces'] = b.asset.F
+ n_batch[id]['origin_faces'] = np.pad(b.asset.faces, ((0, max_faces-b.asset.F), (0, 0)))
+ n_batch[id]['origin_face_normals'] = np.pad(b.asset.face_normals, ((0, max_faces-b.asset.F), (0, 0)))
+ return n_batch
+
+ @abstractmethod
+    def process_fn(self, batch: List[ModelInput]) -> List[Dict]:
+        '''
+        Fetch data from the dataloader and turn each ModelInput into a dict of
+        arrays/tensors, one dict per input.
+        '''
+ pass
\ No newline at end of file
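
A minimal sketch of the ModelSpec contract above: process_fn returns one dict per ModelInput and must not set the reserved keys (cls, path, joints, vertices, ...) that _process_fn fills in afterwards. The subclass name is hypothetical:

class DummyModel(ModelSpec):
    def __init__(self):
        super().__init__()

    def process_fn(self, batch: List[ModelInput]) -> List[Dict]:
        # nothing model-specific to add; _process_fn appends the reserved
        # keys (cls, path, joints, parents, vertices, ...) to each dict
        return [{} for _ in batch]
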
diff --git a/UniRig/src/model/unirig_ar.py b/UniRig/src/model/unirig_ar.py
new file mode 100644
index 0000000000000000000000000000000000000000..260f25923c3c10f310ea0fc061b938d24f9da233
--- /dev/null
+++ b/UniRig/src/model/unirig_ar.py
@@ -0,0 +1,179 @@
+import torch
+from torch import nn, FloatTensor, LongTensor
+import numpy as np
+from torch.nn.functional import pad
+from typing import Dict, List, Union
+from transformers import AutoModelForCausalLM, AutoConfig
+
+from .spec import ModelSpec, ModelInput
+from .parse_encoder import MAP_MESH_ENCODER, get_mesh_encoder
+
+from ..tokenizer.spec import TokenizerSpec, DetokenzeOutput
+from copy import deepcopy
+
+class UniRigAR(ModelSpec):
+
+ def process_fn(self, batch: List[ModelInput]) -> List[Dict]:
+ if batch[0].joints is None: # predict
+ return [{} for _ in range(len(batch))]
+ max_length = 0
+ for b in batch:
+ max_length = max(max_length, b.tokens.shape[0])
+ res = [{
+ 'input_ids': np.pad(b.tokens, ((0, max_length-b.tokens.shape[0])), 'constant', constant_values=b.pad),
+            'attention_mask': np.pad(np.ones(b.tokens.shape[0]), ((0, max_length - b.tokens.shape[0])), 'constant', constant_values=0.),
+ } for b in batch]
+ return res
+
+ def __init__(self, llm, mesh_encoder, **kwargs):
+ super().__init__()
+ self.tokenizer: TokenizerSpec = kwargs.get('tokenizer')
+ self.vocab_size = self.tokenizer.vocab_size
+
+ _d = llm.copy()
+ _d['vocab_size'] = self.tokenizer.vocab_size
+ llm_config = AutoConfig.from_pretrained(**_d)
+ # Force float32 precision for the model
+ llm_config.torch_dtype = torch.float32
+ # Force enable pre_norm
+ llm_config.pre_norm = True
+ self.transformer = AutoModelForCausalLM.from_config(config=llm_config)
+
+ self.hidden_size = llm.hidden_size
+
+ self.mesh_encoder = get_mesh_encoder(**mesh_encoder)
+
+ if (
+ isinstance(self.mesh_encoder, MAP_MESH_ENCODER.michelangelo) or
+ isinstance(self.mesh_encoder, MAP_MESH_ENCODER.michelangelo_encoder)
+ ):
+ self.output_proj = nn.Linear(self.mesh_encoder.width, self.hidden_size)
+ else:
+ raise NotImplementedError()
+
+ def encode_mesh_cond(self, vertices: FloatTensor, normals: FloatTensor) -> FloatTensor:
+ assert not torch.isnan(vertices).any()
+ assert not torch.isnan(normals).any()
+ if (
+ isinstance(self.mesh_encoder, MAP_MESH_ENCODER.michelangelo) or
+ isinstance(self.mesh_encoder, MAP_MESH_ENCODER.michelangelo_encoder)
+ ):
+ if (len(vertices.shape) == 3):
+ shape_embed, latents, token_num, pre_pc = self.mesh_encoder.encode_latents(pc=vertices, feats=normals)
+ else:
+ shape_embed, latents, token_num, pre_pc = self.mesh_encoder.encode_latents(pc=vertices.unsqueeze(0), feats=normals.unsqueeze(0))
+ latents = self.output_proj(latents)
+ return latents
+ else:
+ raise NotImplementedError()
+
+ def training_step(self, batch: Dict) -> Dict[str, FloatTensor]:
+ cond = self.encode_mesh_cond(vertices=batch['vertices'], normals=batch['normals']).to(dtype=self.transformer.dtype)
+ B = cond.shape[0]
+ input_ids: LongTensor = batch['input_ids']
+ inputs_embeds = self.transformer.get_input_embeddings()(input_ids).to(dtype=self.transformer.dtype)
+
+ inputs_embeds = torch.concat([cond, inputs_embeds], dim=1)
+
+ attention_mask = batch['attention_mask']
+ # add attention to condition
+ attention_mask = pad(attention_mask, (cond.shape[1], 0, 0, 0), value=1.)
+ output = self.transformer(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ )
+
+ # (B, L, vocab_size)
+ logit = output.logits[:, cond.shape[1]:].reshape(B, -1, self.vocab_size)
+        # compute loss with labels shifted one token to the right
+        device = logit.device
+        logit = logit[:, :-1]  # (B, n, vocab_size), n = L - 1
+        num_discrete = self.tokenizer.num_discrete
+        s = torch.nn.functional.softmax(logit, dim=-1)
+
+        label = input_ids[:, 1:].clone()  # (B, n)
+        # only coordinate tokens (id < num_discrete) receive the distance loss
+        mask = label < num_discrete
+        dis = torch.arange(num_discrete, device=device).view(1, 1, -1)
+        dis = (dis - label.unsqueeze(2).repeat(1, 1, num_discrete)).type(torch.float32) / num_discrete  # (B, n, num_discrete)
+        dis_loss = (s[:, :, :num_discrete] * torch.abs(dis))[mask].sum() / 50  # mask drops padding; 50 is a fixed scale
+
+ label[attention_mask[:, cond.shape[1] + 1:]==0] = -100
+
+ assert not torch.isnan(logit).any(), logit
+ ce_loss = nn.functional.cross_entropy(logit.permute(0, 2, 1), label)
+ return {
+ 'ce_loss': ce_loss,
+ 'dis_loss': dis_loss,
+ }
+
+ def forward(self, data: Dict):
+ return self.training_step(data=data)
+
+ @torch.no_grad()
+ def generate(
+ self,
+ vertices: FloatTensor,
+ normals: FloatTensor,
+ cls: Union[str, None]=None,
+ **kwargs,
+ ) -> DetokenzeOutput:
+        '''
+        Batched input is not supported; pass a single asset's vertices and normals.
+        '''
+ cond = self.encode_mesh_cond(vertices=vertices, normals=normals).to(dtype=self.transformer.dtype)
+
+ start_tokens = [self.tokenizer.bos]
+
+ if cls is not None:
+ start_tokens.append(self.tokenizer.cls_name_to_token(cls=cls))
+ start_embed = self.transformer.get_input_embeddings()(
+ torch.tensor(start_tokens, dtype=torch.long, device=cond.device).unsqueeze(0)
+ ).to(dtype=self.transformer.dtype)
+ cond = torch.cat([cond, start_embed], dim=1)
+
+ results = self.transformer.generate(
+ inputs_embeds=cond,
+ bos_token_id=self.tokenizer.bos,
+ eos_token_id=self.tokenizer.eos,
+ pad_token_id=self.tokenizer.pad,
+ **kwargs,
+ )
+ output_ids = results[0, :]
+ for token in reversed(start_tokens):
+ output_ids = pad(output_ids, (1, 0), value=token)
+ output_ids = output_ids.detach().cpu().numpy()
+
+ res = self.tokenizer.detokenize(ids=output_ids)
+ return res
+
+ def predict_step(self, batch: Dict, no_cls: bool=False):
+ vertices: FloatTensor = batch['vertices']
+ normals : FloatTensor = batch['normals']
+ paths : List[str] = batch['path']
+ cls = batch['cls']
+ generate_kwargs = deepcopy(batch['generate_kwargs'])
+
+ no_cls = generate_kwargs.get('no_cls', False)
+ use_dir_cls = generate_kwargs.get('use_dir_cls', False)
+ assign_cls = generate_kwargs.get('assign_cls', None)
+
+ generate_kwargs.pop('no_cls', None)
+ generate_kwargs.pop('use_dir_cls', None)
+ generate_kwargs.pop('assign_cls', None)
+
+ if vertices.dim() == 2:
+ vertices = vertices.unsqueeze(0)
+ normals = normals.unsqueeze(0)
+ outputs = []
+ for i in range(vertices.shape[0]):
+ if no_cls:
+ _cls = None
+ elif assign_cls is not None:
+ _cls = assign_cls
+ elif use_dir_cls:
+ _cls = paths[i].removeprefix('./').split('/')[0]
+ else:
+ _cls = cls[i]
+ res = self.generate(vertices=vertices[i], normals=normals[i], cls=_cls, **generate_kwargs)
+ outputs.append(res)
+ return outputs
\ No newline at end of file
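
A hypothetical single-asset inference sketch for UniRigAR.generate; `model` is assumed to be a fully constructed UniRigAR, the class name 'vroid' is illustrative (unknown names fall back to the cls-none token), and max_new_tokens is forwarded to transformers' generate():

import torch

vertices = torch.rand(4096, 3, device='cuda')   # sampled surface points, (N, 3)
normals = torch.nn.functional.normalize(torch.rand(4096, 3, device='cuda'), dim=-1)

res = model.generate(
    vertices=vertices,
    normals=normals,
    cls='vroid',
    max_new_tokens=2048,
)
print(res.joints.shape, len(res.parents))       # (J, 3) joints and J parent ids
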
diff --git a/UniRig/src/model/unirig_skin.py b/UniRig/src/model/unirig_skin.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e4212e1974c6036255bb638995bf2e5bebf676c
--- /dev/null
+++ b/UniRig/src/model/unirig_skin.py
@@ -0,0 +1,440 @@
+import torch
+from torch import nn, FloatTensor, LongTensor, Tensor
+import torch.nn.functional as F
+import numpy as np
+from torch.nn.functional import pad
+from typing import Dict, List
+from transformers import AutoModelForCausalLM, AutoConfig
+import math
+import torch_scatter
+from flash_attn.modules.mha import MHA
+
+from .spec import ModelSpec, ModelInput
+from .parse_encoder import MAP_MESH_ENCODER, get_mesh_encoder
+
+from ..data.utils import linear_blend_skinning
+
+class FrequencyPositionalEmbedding(nn.Module):
+ """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
+ each feature dimension of `x[..., i]` into:
+ [
+ sin(x[..., i]),
+ sin(f_1*x[..., i]),
+ sin(f_2*x[..., i]),
+ ...
+ sin(f_N * x[..., i]),
+ cos(x[..., i]),
+ cos(f_1*x[..., i]),
+ cos(f_2*x[..., i]),
+ ...
+ cos(f_N * x[..., i]),
+ x[..., i] # only present if include_input is True.
+ ], here f_i is the frequency.
+
+    If logspace is True, the frequencies f_i are powers of two: [2^0, 2^1, ..., 2^(num_freqs - 1)];
+    otherwise, they are linearly spaced between 1.0 and 2^(num_freqs - 1). The frequencies are
+    additionally multiplied by pi when include_pi is True.
+
+    Args:
+        num_freqs (int): the number of frequencies, default is 6;
+        logspace (bool): if True, the frequencies are powers of two [2^0, ..., 2^(num_freqs - 1)],
+            otherwise they are linearly spaced between 1.0 and 2^(num_freqs - 1);
+        input_dim (int): the input dimension, default is 3;
+        include_input (bool): include the input tensor or not, default is True;
+        include_pi (bool): multiply the frequencies by pi, default is True.
+
+    Attributes:
+        frequencies (torch.Tensor): the frequency tensor described above;
+
+        out_dim (int): the embedding size; input_dim * (num_freqs * 2 + 1) if include_input is True,
+            otherwise input_dim * num_freqs * 2.
+    """
+
+ def __init__(
+ self,
+ num_freqs: int = 6,
+ logspace: bool = True,
+ input_dim: int = 3,
+ include_input: bool = True,
+ include_pi: bool = True,
+ ) -> None:
+ """The initialization"""
+
+ super().__init__()
+
+ if logspace:
+ frequencies = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)
+ else:
+ frequencies = torch.linspace(
+ 1.0, 2.0 ** (num_freqs - 1), num_freqs, dtype=torch.float32
+ )
+
+ if include_pi:
+ frequencies *= torch.pi
+
+ self.register_buffer("frequencies", frequencies, persistent=False)
+ self.include_input = include_input
+ self.num_freqs = num_freqs
+
+ self.out_dim = self._get_dims(input_dim)
+
+ def _get_dims(self, input_dim):
+ temp = 1 if self.include_input or self.num_freqs == 0 else 0
+ out_dim = input_dim * (self.num_freqs * 2 + temp)
+
+ return out_dim
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward process.
+
+ Args:
+ x: tensor of shape [..., dim]
+
+ Returns:
+ embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
+ where temp is 1 if include_input is True and 0 otherwise.
+ """
+
+ if self.num_freqs > 0:
+ embed = (x[..., None].contiguous() * self.frequencies.to(device=x.device)).view(
+ *x.shape[:-1], -1
+ )
+ if self.include_input:
+ return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
+ else:
+ return torch.cat((embed.sin(), embed.cos()), dim=-1)
+ else:
+ return x
+
+class ResidualCrossAttn(nn.Module):
+ def __init__(self, feat_dim: int, num_heads: int):
+ super().__init__()
+ assert feat_dim % num_heads == 0, "feat_dim must be divisible by num_heads"
+
+ self.norm1 = nn.LayerNorm(feat_dim)
+ self.norm2 = nn.LayerNorm(feat_dim)
+ # self.attention = nn.MultiheadAttention(embed_dim=feat_dim, num_heads=num_heads, batch_first=True)
+ self.attention = MHA(embed_dim=feat_dim, num_heads=num_heads, cross_attn=True)
+ self.ffn = nn.Sequential(
+ nn.Linear(feat_dim, feat_dim * 4),
+ nn.GELU(),
+ nn.Linear(feat_dim * 4, feat_dim),
+ )
+
+ def forward(self, q, kv):
+ residual = q
+ attn_output = self.attention(q, x_kv=kv)
+ x = self.norm1(residual + attn_output)
+ x = self.norm2(x + self.ffn(x))
+ return x
+
+class BoneEncoder(nn.Module):
+ def __init__(
+ self,
+ feat_bone_dim: int,
+ feat_dim: int,
+ embed_dim: int,
+ num_heads: int,
+ num_attn: int,
+ ):
+ super().__init__()
+ self.feat_bone_dim = feat_bone_dim
+ self.feat_dim = feat_dim
+ self.num_heads = num_heads
+ self.num_attn = num_attn
+
+ self.position_embed = FrequencyPositionalEmbedding(input_dim=self.feat_bone_dim)
+
+ self.bone_encoder = nn.Sequential(
+ self.position_embed,
+ nn.Linear(self.position_embed.out_dim, embed_dim),
+ nn.LayerNorm(embed_dim),
+ nn.GELU(),
+ nn.Linear(embed_dim, embed_dim * 4),
+ nn.LayerNorm(embed_dim * 4),
+ nn.GELU(),
+ nn.Linear(embed_dim * 4, feat_dim),
+ nn.LayerNorm(feat_dim),
+ nn.GELU(),
+ )
+ self.attn = nn.ModuleList()
+ for _ in range(self.num_attn):
+ self.attn.append(ResidualCrossAttn(feat_dim, self.num_heads))
+
+ def forward(
+ self,
+ base_bone: FloatTensor,
+ num_bones: LongTensor,
+ parents: LongTensor,
+ min_coord: FloatTensor,
+ global_latents: FloatTensor,
+ ):
+ # base_bone: (B, J, C)
+ B = base_bone.shape[0]
+ J = base_bone.shape[1]
+ x = self.bone_encoder((base_bone-min_coord[:, None, :]).reshape(-1, base_bone.shape[-1])).reshape(B, J, -1)
+
+ latents = torch.cat([x, global_latents], dim=1)
+
+ for (i, attn) in enumerate(self.attn):
+ x = attn(x, latents)
+ return x
+
+class SkinweightPred(nn.Module):
+ def __init__(self, in_dim, mlp_dim):
+ super().__init__()
+ self.net = nn.Sequential(
+ nn.Linear(in_dim, mlp_dim),
+ nn.LayerNorm(mlp_dim),
+ nn.GELU(),
+ nn.Linear(mlp_dim, mlp_dim),
+ nn.LayerNorm(mlp_dim),
+ nn.GELU(),
+ nn.Linear(mlp_dim, mlp_dim),
+ nn.LayerNorm(mlp_dim),
+ nn.GELU(),
+ nn.Linear(mlp_dim, mlp_dim),
+ nn.LayerNorm(mlp_dim),
+ nn.GELU(),
+ nn.Linear(mlp_dim, 1),
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+class UniRigSkin(ModelSpec):
+
+ def process_fn(self, batch: List[ModelInput]) -> List[Dict]:
+ max_bones = 0
+ for b in batch:
+ max_bones = max(max_bones, b.asset.J)
+ res = []
+ current_offset = 0
+ for b in batch:
+ vertex_groups = b.asset.sampled_vertex_groups
+ current_offset += b.vertices.shape[0]
+ # (N, J)
+ voxel_skin = vertex_groups['voxel_skin']
+
+ voxel_skin = np.pad(voxel_skin, ((0, 0), (0, max_bones-b.asset.J)), 'constant', constant_values=0.0)
+
+ # (J, 4, 4)
+ res.append({
+ 'voxel_skin': voxel_skin,
+ 'offset': current_offset,
+ })
+ return res
+
+ def __init__(self, mesh_encoder, global_encoder, **kwargs):
+ super().__init__()
+
+ self.num_train_vertex = kwargs['num_train_vertex']
+ self.feat_dim = kwargs['feat_dim']
+ self.num_heads = kwargs['num_heads']
+ self.grid_size = kwargs['grid_size']
+ self.mlp_dim = kwargs['mlp_dim']
+ self.num_bone_attn = kwargs['num_bone_attn']
+ self.num_mesh_bone_attn = kwargs['num_mesh_bone_attn']
+ self.bone_embed_dim = kwargs['bone_embed_dim']
+ self.voxel_mask = kwargs.get('voxel_mask', 2)
+
+ self.mesh_encoder = get_mesh_encoder(**mesh_encoder)
+ self.global_encoder = get_mesh_encoder(**global_encoder)
+ if isinstance(self.mesh_encoder, MAP_MESH_ENCODER.ptv3obj):
+ self.feat_map = nn.Sequential(
+ nn.Linear(mesh_encoder['enc_channels'][-1], self.feat_dim),
+ nn.LayerNorm(self.feat_dim),
+ nn.GELU(),
+ )
+ else:
+ raise NotImplementedError()
+ if isinstance(self.global_encoder, MAP_MESH_ENCODER.michelangelo_encoder):
+ self.out_proj = nn.Sequential(
+ nn.Linear(self.global_encoder.width, self.feat_dim),
+ nn.LayerNorm(self.feat_dim),
+ nn.GELU(),
+ )
+ else:
+ raise NotImplementedError()
+
+ self.bone_encoder = BoneEncoder(
+ feat_bone_dim=3,
+ feat_dim=self.feat_dim,
+ embed_dim=self.bone_embed_dim,
+ num_heads=self.num_heads,
+ num_attn=self.num_bone_attn,
+ )
+
+ self.downscale = nn.Sequential(
+ nn.Linear(2 * self.num_heads, self.num_heads),
+ nn.LayerNorm(self.num_heads),
+ nn.GELU(),
+ )
+ self.skinweight_pred = SkinweightPred(
+ self.num_heads,
+ self.mlp_dim,
+ )
+
+ self.mesh_bone_attn = nn.ModuleList()
+ self.mesh_bone_attn.extend([
+ ResidualCrossAttn(self.feat_dim, self.num_heads) for _ in range(self.num_mesh_bone_attn)
+ ])
+
+ self.qmesh = nn.Linear(self.feat_dim, self.feat_dim * self.num_heads)
+ self.kmesh = nn.Linear(self.feat_dim, self.feat_dim * self.num_heads)
+
+ self.voxel_skin_embed = nn.Linear(1, self.num_heads)
+ self.voxel_skin_norm = nn.LayerNorm(self.num_heads)
+ self.attn_skin_norm = nn.LayerNorm(self.num_heads)
+
+ def encode_mesh_cond(self, vertices: FloatTensor, normals: FloatTensor) -> FloatTensor:
+ assert not torch.isnan(vertices).any()
+ assert not torch.isnan(normals).any()
+ if isinstance(self.global_encoder, MAP_MESH_ENCODER.michelangelo_encoder):
+ if (len(vertices.shape) == 3):
+ shape_embed, latents, token_num, pre_pc = self.global_encoder.encode_latents(pc=vertices, feats=normals)
+ else:
+ shape_embed, latents, token_num, pre_pc = self.global_encoder.encode_latents(pc=vertices.unsqueeze(0), feats=normals.unsqueeze(0))
+ latents = self.out_proj(latents)
+ return latents
+ else:
+ raise NotImplementedError()
+
+ def _get_predict(self, batch: Dict) -> FloatTensor:
+ '''
+ Return predicted skin.
+ '''
+
+ num_bones: Tensor = batch['num_bones']
+ vertices: FloatTensor = batch['vertices'] # (B, N, 3)
+ normals: FloatTensor = batch['normals']
+ joints: FloatTensor = batch['joints']
+ tails: FloatTensor = batch['tails']
+ voxel_skin: FloatTensor = batch['voxel_skin']
+ parents: LongTensor = batch['parents']
+
+ # turn inputs' dtype into model's dtype
+ dtype = next(self.parameters()).dtype
+ vertices = vertices.type(dtype)
+ normals = normals.type(dtype)
+ joints = joints.type(dtype)
+ tails = tails.type(dtype)
+ voxel_skin = voxel_skin.type(dtype)
+
+ B = vertices.shape[0]
+ N = vertices.shape[1]
+ J = joints.shape[1]
+
+ assert vertices.dim() == 3
+ assert normals.dim() == 3
+
+ part_offset = torch.tensor([(i+1)*N for i in range(B)], dtype=torch.int64, device=vertices.device)
+ idx_ptr = torch.nn.functional.pad(part_offset, (1, 0), value=0)
+ min_coord = torch_scatter.segment_csr(vertices.reshape(-1, 3), idx_ptr, reduce="min")
+
+ pack = []
+ if self.training:
+ train_indices = torch.randperm(N)[:self.num_train_vertex]
+ pack.append(train_indices)
+ else:
+ for i in range((N + self.num_train_vertex - 1) // self.num_train_vertex):
+ pack.append(torch.arange(i*self.num_train_vertex, min((i+1)*self.num_train_vertex, N)))
+
+ # (B, seq_len, feat_dim)
+ global_latents = self.encode_mesh_cond(vertices, normals)
+ bone_feat = self.bone_encoder(
+ base_bone=joints,
+ num_bones=num_bones,
+ parents=parents,
+ min_coord=min_coord,
+ global_latents=global_latents,
+ )
+
+ if isinstance(self.mesh_encoder, MAP_MESH_ENCODER.ptv3obj):
+ feat = torch.cat([vertices, normals, torch.zeros_like(vertices)], dim=-1)
+ ptv3_input = {
+ 'coord': vertices.reshape(-1, 3),
+ 'feat': feat.reshape(-1, 9),
+ 'offset': torch.tensor(batch['offset']),
+ 'grid_size': self.grid_size,
+ }
+ if not self.training:
+ # must cast to float32 to avoid sparse-conv precision bugs
+ with torch.autocast(device_type='cuda', dtype=torch.float32):
+ mesh_feat = self.mesh_encoder(ptv3_input).feat
+ mesh_feat = self.feat_map(mesh_feat).view(B, N, self.feat_dim)
+ else:
+ mesh_feat = self.mesh_encoder(ptv3_input).feat
+ mesh_feat = self.feat_map(mesh_feat).view(B, N, self.feat_dim)
+ mesh_feat = mesh_feat.type(dtype)
+ else:
+ raise NotImplementedError()
+
+ # (B, J + seq_len, feat_dim)
+ latents = torch.cat([bone_feat, global_latents], dim=1)
+ # (B, N, feat_dim)
+ for block in self.mesh_bone_attn:
+ mesh_feat = block(
+ q=mesh_feat,
+ kv=latents,
+ )
+
+ # trans to (B, num_heads, J, feat_dim)
+ bone_feat = self.kmesh(bone_feat).view(B, J, self.num_heads, self.feat_dim).transpose(1, 2)
+
+ skin_pred_list = []
+ if not self.training:
+ skin_mask = voxel_skin.clone()
+ for b in range(B):
+ num = num_bones[b]
+ for i in range(num):
+ p = parents[b, i]
+ if p < 0:
+ continue
+ skin_mask[b, :, p] += skin_mask[b, :, i]
+ for indices in pack:
+ cur_N = len(indices)
+ # trans to (B, num_heads, N, feat_dim)
+ cur_mesh_feat = self.qmesh(mesh_feat[:, indices]).view(B, cur_N, self.num_heads, self.feat_dim).transpose(1, 2)
+
+ # attn_weight shape : (B, num_heads, N, J)
+ attn_weight = F.softmax(torch.bmm(
+ cur_mesh_feat.reshape(B * self.num_heads, cur_N, -1),
+ bone_feat.transpose(-2, -1).reshape(B * self.num_heads, -1, J)
+ ) / math.sqrt(self.feat_dim), dim=-1, dtype=dtype)
+ # (B, num_heads, N, J) -> (B, N, J, num_heads)
+ attn_weight = attn_weight.reshape(B, self.num_heads, cur_N, J).permute(0, 2, 3, 1)
+ attn_weight = self.attn_skin_norm(attn_weight)
+
+ embed_voxel_skin = self.voxel_skin_embed(voxel_skin[:, indices].reshape(B, cur_N, J, 1))
+ embed_voxel_skin = self.voxel_skin_norm(embed_voxel_skin)
+
+ attn_weight = torch.cat([attn_weight, embed_voxel_skin], dim=-1)
+ attn_weight = self.downscale(attn_weight)
+
+ # (B, N, J, num_heads * (1+c)) -> (B, N, J)
+ skin_pred = torch.zeros(B, cur_N, J).to(attn_weight.device, dtype)
+ for i in range(B):
+ # (N*J, C)
+ input_features = attn_weight[i, :, :num_bones[i], :].reshape(-1, attn_weight.shape[-1])
+
+ pred = self.skinweight_pred(input_features).reshape(cur_N, num_bones[i])
+                skin_pred[i, :, :num_bones[i]] = F.softmax(pred, dim=-1)
+ skin_pred_list.append(skin_pred)
+ skin_pred_list = torch.cat(skin_pred_list, dim=1)
+ for i in range(B):
+ n = num_bones[i]
+ skin_pred_list[i, :, :n] = skin_pred_list[i, :, :n] * torch.pow(skin_mask[i, :, :n], self.voxel_mask)
+ skin_pred_list[i, :, :n] = skin_pred_list[i, :, :n] / skin_pred_list[i, :, :n].sum(dim=-1, keepdim=True)
+ return skin_pred_list, torch.cat(pack, dim=0)
+
+ def predict_step(self, batch: Dict):
+ with torch.no_grad():
+ num_bones: Tensor = batch['num_bones']
+
+ skin_pred, _ = self._get_predict(batch=batch)
+ outputs = []
+ for i in range(skin_pred.shape[0]):
+ outputs.append(skin_pred[i, :, :num_bones[i]])
+ return outputs
\ No newline at end of file
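
A small check of the FrequencyPositionalEmbedding defined above: with include_input=True the last dimension grows from input_dim to input_dim * (num_freqs * 2 + 1), here 3 * 13 = 39:

import torch

embed = FrequencyPositionalEmbedding(num_freqs=6, input_dim=3, include_input=True)
x = torch.rand(2, 5, 3)
y = embed(x)                    # concat of x, sin(f_i * x), cos(f_i * x)
print(y.shape, embed.out_dim)   # torch.Size([2, 5, 39]) 39
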
diff --git a/UniRig/src/system/__init__.py b/UniRig/src/system/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/UniRig/src/system/ar.py b/UniRig/src/system/ar.py
new file mode 100644
index 0000000000000000000000000000000000000000..a07d506034b8e0a6251eb0becd865cfb3f2f163a
--- /dev/null
+++ b/UniRig/src/system/ar.py
@@ -0,0 +1,165 @@
+from collections import defaultdict
+import lightning as L
+import os
+import torch
+import numpy as np
+from torch import Tensor
+from typing import Dict, Union, List
+from lightning.pytorch.callbacks import BasePredictionWriter
+
+from numpy import ndarray
+
+from ..data.raw_data import RawData
+from ..data.order import OrderConfig, get_order
+from ..model.spec import ModelSpec
+from ..tokenizer.spec import DetokenzeOutput
+
+class ARSystem(L.LightningModule):
+
+ def __init__(
+ self,
+ steps_per_epoch: int,
+ model: ModelSpec,
+ generate_kwargs: Dict={},
+ output_path: Union[str, None]=None,
+        record_res: bool=False,
+ validate_cast: str='bfloat16',
+ val_interval: Union[int, None]=None,
+ val_start_from: Union[int, None]=None,
+ ):
+ super().__init__()
+ self.save_hyperparameters(ignore="model")
+ self.steps_per_epoch = steps_per_epoch
+ self.model = model
+ self.generate_kwargs = generate_kwargs
+ self.output_path = output_path
+ self.record_res = record_res
+ self.validate_cast = validate_cast
+ self.val_interval = val_interval
+ self.val_start_from = val_start_from
+
+ if self.record_res:
+ assert self.output_path is not None, "record_res is True, but output_path in ar is None"
+
+ def _predict_step(self, batch, batch_idx, dataloader_idx=None):
+ batch['generate_kwargs'] = self.generate_kwargs
+ res = self.model.predict_step(batch)
+ assert isinstance(res, list), f"expect type of prediction from {self.model.__class__} to be a list, found: {type(res)}"
+ return res
+
+ def predict_step(self, batch, batch_idx, dataloader_idx=None):
+ try:
+ prediction: List[DetokenzeOutput] = self._predict_step(batch=batch, batch_idx=batch_idx, dataloader_idx=dataloader_idx)
+ return prediction
+ except Exception as e:
+ print(str(e))
+ return []
+
+class ARWriter(BasePredictionWriter):
+ def __init__(
+ self,
+ output_dir: Union[str, None],
+ order_config: Union[OrderConfig, None]=None,
+ **kwargs
+ ):
+ super().__init__('batch')
+ self.output_dir = output_dir
+ self.npz_dir = kwargs.get('npz_dir', None)
+ self.user_mode = kwargs.get('user_mode', False)
+ self.output_name = kwargs.get('output_name', None) # for a single name
+ self.repeat = kwargs.get('repeat', 1)
+ self.add_num = kwargs.get('add_num', False)
+ self.export_npz = kwargs.get('export_npz', None)
+ self.export_obj = kwargs.get('export_obj', None)
+ self.export_fbx = kwargs.get('export_fbx', None)
+ self.export_pc = kwargs.get('export_pc', None)
+ if order_config is not None:
+ self.order = get_order(config=order_config)
+ else:
+ self.order = None
+
+ self._epoch = 0
+
+ def on_predict_end(self, trainer, pl_module):
+ if self._epoch < self.repeat - 1:
+ print(f"Finished prediction run {self._epoch + 1}/{self.repeat}, starting next run...")
+ self._epoch += 1
+ trainer.predict_dataloader = trainer.datamodule.predict_dataloader()
+ trainer.predict_loop.run()
+
+    def write_on_batch_end(self, trainer, pl_module: ARSystem, prediction: List[DetokenzeOutput], batch_indices, batch, batch_idx, dataloader_idx):
+ assert 'path' in batch
+ paths = batch['path']
+ detokenize_output_list: List[DetokenzeOutput] = prediction
+ vertices = batch['vertices']
+
+ origin_vertices = batch['origin_vertices']
+ origin_vertex_normals = batch['origin_vertex_normals']
+ origin_faces = batch['origin_faces']
+ origin_face_normals = batch['origin_face_normals']
+ num_points = batch['num_points']
+ num_faces = batch['num_faces']
+
+ if isinstance(origin_vertices, torch.Tensor):
+ origin_vertices = origin_vertices.detach().cpu().numpy()
+ if isinstance(origin_vertex_normals, torch.Tensor):
+ origin_vertex_normals = origin_vertex_normals.detach().cpu().numpy()
+ if isinstance(origin_faces, torch.Tensor):
+ origin_faces = origin_faces.detach().cpu().numpy()
+ if isinstance(origin_face_normals, torch.Tensor):
+ origin_face_normals = origin_face_normals.detach().cpu().numpy()
+ if isinstance(num_points, torch.Tensor):
+ num_points = num_points.detach().cpu().numpy()
+ if isinstance(num_faces, torch.Tensor):
+ num_faces = num_faces.detach().cpu().numpy()
+
+ for (id, detokenize_output) in enumerate(detokenize_output_list):
+ assert isinstance(detokenize_output, DetokenzeOutput), f"expect item of the list to be DetokenzeOutput, found: {type(detokenize_output)}"
+ def make_path(save_name: str, suffix: str, trim: bool=False):
+ if trim:
+ path = os.path.relpath(paths[id], self.npz_dir)
+ else:
+ path = paths[id]
+
+ if self.output_dir is not None:
+ path = os.path.join(self.output_dir, path)
+
+ if self.add_num:
+ path = os.path.join(path, f"{save_name}_{self._epoch}.{suffix}")
+ else:
+ path = os.path.join(path, f"{save_name}.{suffix}")
+ return path
+
+ num_p = num_points[id]
+ num_f = num_faces[id]
+
+ raw_data = RawData(
+ vertices=origin_vertices[id, :num_p],
+ vertex_normals=origin_vertex_normals[id, :num_p],
+ faces=origin_faces[id, :num_f],
+ face_normals=origin_face_normals[id, :num_f],
+ joints=detokenize_output.joints,
+ tails=detokenize_output.tails,
+ parents=detokenize_output.parents,
+ skin=None,
+ no_skin=detokenize_output.no_skin,
+ names=detokenize_output.names,
+ matrix_local=None,
+ path=None,
+ cls=detokenize_output.cls,
+ )
+ if not self.user_mode and self.export_npz is not None:
+ print(make_path(self.export_npz, 'npz'))
+ raw_data.save(path=make_path(self.export_npz, 'npz'))
+ if not self.user_mode and self.export_obj is not None:
+ raw_data.export_skeleton(path=make_path(self.export_obj, 'obj'))
+ if not self.user_mode and self.export_pc is not None:
+ raw_data.export_pc(path=make_path(self.export_pc, 'obj'))
+ if self.export_fbx is not None:
+ if not self.user_mode:
+ raw_data.export_fbx(path=make_path(self.export_fbx, 'fbx'))
+ else:
+ if self.output_name is not None:
+ raw_data.export_fbx(path=self.output_name)
+ else:
+ raw_data.export_fbx(path=make_path(self.export_fbx, 'fbx', trim=True))
\ No newline at end of file
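
A sketch of wiring ARWriter into a Lightning predict run; the directory names and file stems are illustrative, and the system/datamodule are assumed to be built elsewhere (lightning imported as L, as in this module):

writer = ARWriter(
    output_dir='results',
    npz_dir='dataset_inference',    # base dir stripped when trim=True
    repeat=1,                       # number of prediction passes
    add_num=False,                  # if True, suffix outputs with the run index
    export_npz='predict_skeleton',  # file stem of the saved .npz
    export_fbx='skeleton',          # file stem of the exported .fbx
)
trainer = L.Trainer(callbacks=[writer])
trainer.predict(system, datamodule=datamodule)
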
diff --git a/UniRig/src/system/parse.py b/UniRig/src/system/parse.py
new file mode 100644
index 0000000000000000000000000000000000000000..114cc2d16b60420ad63ce0a36fcb2ec3f574a51f
--- /dev/null
+++ b/UniRig/src/system/parse.py
@@ -0,0 +1,27 @@
+import torch
+from torch.optim import Optimizer
+from lightning.pytorch import LightningModule
+from lightning.pytorch.callbacks import BasePredictionWriter
+
+from .ar import ARSystem, ARWriter
+from .skin import SkinSystem, SkinWriter
+
+def get_system(**kwargs) -> LightningModule:
+ MAP = {
+ 'ar': ARSystem,
+ 'skin': SkinSystem,
+ }
+ __target__ = kwargs['__target__']
+ assert __target__ in MAP, f"expect: [{','.join(MAP.keys())}], found: {__target__}"
+ del kwargs['__target__']
+ return MAP[__target__](**kwargs)
+
+def get_writer(**kwargs) -> BasePredictionWriter:
+ MAP = {
+ 'ar': ARWriter,
+ 'skin': SkinWriter,
+ }
+ __target__ = kwargs['__target__']
+ assert __target__ in MAP, f"expect: [{','.join(MAP.keys())}], found: {__target__}"
+ del kwargs['__target__']
+ return MAP[__target__](**kwargs)
\ No newline at end of file
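
A minimal sketch of the dispatch above; my_model is a hypothetical, already-constructed ModelSpec:

system = get_system(
    __target__='ar',
    steps_per_epoch=1000,
    model=my_model,
    generate_kwargs={'max_new_tokens': 2048},
)
writer = get_writer(__target__='ar', output_dir='results')
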
diff --git a/UniRig/src/system/skin.py b/UniRig/src/system/skin.py
new file mode 100644
index 0000000000000000000000000000000000000000..4548c3ea69a225ce6e1e5594d63997225aa6482a
--- /dev/null
+++ b/UniRig/src/system/skin.py
@@ -0,0 +1,262 @@
+from collections import defaultdict
+
+import torch.distributed
+import lightning as L
+import os
+import torch
+import numpy as np
+from torch import Tensor, FloatTensor, LongTensor
+from typing import Dict, Union, List, Literal
+from lightning.pytorch.callbacks import BasePredictionWriter
+
+from numpy import ndarray
+from scipy.sparse import csr_matrix
+from scipy.spatial import cKDTree
+
+from ..data.order import OrderConfig, get_order
+from ..data.raw_data import RawSkin, RawData
+from ..data.exporter import Exporter
+from ..model.spec import ModelSpec
+
+class SkinSystem(L.LightningModule):
+
+ def __init__(
+ self,
+ steps_per_epoch: int,
+ model: ModelSpec,
+ output_path: Union[str, None]=None,
+        record_res: bool=False,
+ val_interval: Union[int, None]=None,
+ val_start_from: Union[int, None]=None,
+ ):
+ super().__init__()
+ self.save_hyperparameters(ignore="model")
+ self.steps_per_epoch = steps_per_epoch
+ self.model = model
+ self.output_path = output_path
+ self.record_res = record_res
+ self.val_interval = val_interval
+ self.val_start_from = val_start_from
+
+ if self.record_res:
+ assert self.output_path is not None, "record_res is True, but output_path in skin is None"
+
+ def predict_step(self, batch, batch_idx, dataloader_idx=None):
+ res = self.model.predict_step(batch)
+
+ if isinstance(res, list):
+ return {
+ 'skin_pred': res,
+ }
+ elif isinstance(res, dict):
+ assert 'skin_pred' in res, f"expect key 'skin_pred' in prediction from {self.model.__class__}, found: {res.keys()}"
+ return res
+ else:
+ assert 0, f"expect type of prediction from {self.model.__class__} to be a list or dict, found: {type(res)}"
+
+class SkinWriter(BasePredictionWriter):
+ def __init__(
+ self,
+ output_dir: Union[str, None],
+ save_name: str,
+ order_config: Union[OrderConfig, None]=None,
+ **kwargs
+ ):
+ super().__init__('batch')
+ self.output_dir = output_dir
+ self.npz_dir = kwargs.get('npz_dir', None)
+ self.user_mode = kwargs.get('user_mode', False)
+ self.output_name = kwargs.get('output_name', None) # for a single name
+ self.save_name = save_name
+ self.add_num = kwargs.get('add_num', False)
+        self.export_npz = kwargs.get('export_npz', None)
+        self.export_fbx = kwargs.get('export_fbx', None)
+ if order_config is not None:
+ self.order = get_order(config=order_config)
+ else:
+ self.order = None
+
+ self._epoch = 0
+
+    def write_on_batch_end(self, trainer, pl_module: SkinSystem, prediction: Dict[str, List], batch_indices, batch, batch_idx, dataloader_idx):
+ assert 'path' in batch
+ paths: List[str] = batch['path']
+ data_names: List[str] = batch['data_name']
+ joints: FloatTensor = batch['joints']
+ num_bones: LongTensor = batch['num_bones']
+ num_faces: LongTensor = batch['num_faces']
+ num_points: LongTensor = batch['num_points']
+ tails: FloatTensor = batch['tails']
+ parents_list: LongTensor = batch['parents'] # -1 represents root
+ vertices: FloatTensor = batch['origin_vertices']
+ sampled_vertices: FloatTensor = batch['vertices']
+ faces: LongTensor = batch['origin_faces']
+
+ joints = joints.detach().cpu().numpy()
+ tails = tails.detach().cpu().numpy()
+ parents_list = parents_list.detach().cpu().numpy()
+ num_bones = num_bones.detach().cpu().numpy()
+ num_faces = num_faces.detach().cpu().numpy()
+ vertices = vertices.detach().cpu().numpy()
+ faces = faces.detach().cpu().numpy()
+
+ skin_pred_list: List = prediction['skin_pred']
+ ret_sampled_vertices = prediction.get('sampled_vertices', None)
+ if ret_sampled_vertices is not None:
+ assert isinstance(ret_sampled_vertices, Tensor)
+ sampled_vertices = ret_sampled_vertices
+ if isinstance(sampled_vertices, Tensor):
+ sampled_vertices = sampled_vertices.type(torch.float32).detach().cpu().numpy()
+ for (id, skin_pred) in enumerate(skin_pred_list):
+ if isinstance(skin_pred, Tensor):
+ skin_pred = skin_pred.type(torch.float32).detach().cpu().numpy()
+
+ # TODO: add custom post-processing here
+
+ # resample
+ N = num_points[id]
+ J = num_bones[id]
+ F = num_faces[id]
+ o_vertices = vertices[id, :N]
+
+ _parents = parents_list[id]
+ parents = []
+ for i in range(J):
+ if _parents[i] == -1:
+ parents.append(None)
+ else:
+ parents.append(_parents[i])
+
+ skin_resampled = reskin(
+ sampled_vertices=sampled_vertices[id],
+ vertices=o_vertices,
+ parents=parents,
+ faces=faces[id, :F],
+ sampled_skin=skin_pred,
+ sample_method='median',
+ alpha=2.0,
+ threshold=0.03,
+ )
+
+ def make_path(save_name: str, suffix: str, trim: bool=False):
+ if trim:
+ path = os.path.relpath(paths[id], self.npz_dir)
+ else:
+ path = paths[id]
+
+ if self.output_dir is not None:
+ path = os.path.join(self.output_dir, path)
+
+ if self.add_num:
+ path = os.path.join(path, f"{save_name}_{self._epoch}.{suffix}")
+ else:
+ path = os.path.join(path, f"{save_name}.{suffix}")
+ return path
+
+ raw_data = RawSkin(skin=skin_pred, vertices=sampled_vertices[id], joints=joints[id, :J])
+ if self.export_npz is not None:
+ raw_data.save(path=make_path(self.export_npz, 'npz'))
+ if self.export_fbx is not None:
+ try:
+ exporter = Exporter()
+ names = RawData.load(path=os.path.join(paths[id], data_names[id])).names
+ if names is None:
+ names = [f"bone_{i}" for i in range(J)]
+ if self.user_mode:
+ if self.output_name is not None:
+ path = self.output_name
+ else:
+ path = make_path(self.save_name, 'fbx', trim=True)
+ else:
+ path = make_path(self.export_fbx, 'fbx')
+ exporter._export_fbx(
+ path=path,
+ vertices=o_vertices,
+ joints=joints[id, :J],
+ skin=skin_resampled,
+ parents=parents,
+ names=names,
+ faces=faces[id, :F],
+ group_per_vertex=4,
+ tails=tails[id, :J],
+ use_extrude_bone=False,
+ use_connect_unique_child=False,
+ # do_not_normalize=True,
+ )
+ except Exception as e:
+ print(str(e))
+
+ def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
+ self._epoch += 1
+
+def reskin(
+ sampled_vertices: ndarray,
+ vertices: ndarray,
+ parents: List[Union[None, int]],
+ faces: ndarray,
+ sampled_skin: ndarray,
+ sample_method: Literal['mean', 'median']='mean',
+ **kwargs,
+) -> ndarray:
+ nearest_samples = kwargs.get('nearest_samples', 7)
+ iter_steps = kwargs.get('iter_steps', 1)
+ threshold = kwargs.get('threshold', 0.01)
+ alpha = kwargs.get('alpha', 2)
+
+ assert sample_method in ['mean', 'median']
+
+ N = vertices.shape[0]
+ J = sampled_skin.shape[1]
+ if sample_method == 'mean':
+ tree = cKDTree(sampled_vertices)
+ dis, nearest = tree.query(vertices, k=nearest_samples, p=2)
+ # weighted sum
+ weights = np.exp(-alpha * dis) # (N, nearest_samples)
+ weight_sum = weights.sum(axis=1, keepdims=True)
+ sampled_skin_nearest = sampled_skin[nearest]
+ skin = (sampled_skin_nearest * weights[..., np.newaxis]).sum(axis=1) / weight_sum
+ elif sample_method == 'median':
+ tree = cKDTree(sampled_vertices)
+ dis, nearest = tree.query(vertices, k=nearest_samples, p=2)
+ skin = np.median(sampled_skin[nearest], axis=1)
+ else:
+ assert 0
+
+ # (from, to)
+ edges = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], axis=0)
+ edges = np.concatenate([edges, edges[:, [1, 0]]], axis=0) # (2*F*3, 2)
+
+ # diffusion in neighbours
+ for _ in range(iter_steps):
+ sum_skin = skin.copy()
+ for i in reversed(range(J)):
+ p = parents[i]
+ if p is None:
+ continue
+ sum_skin[:, p] += sum_skin[:, i]
+ # (2*F*3, J)
+ # only transfer from hotter to cooler
+ mask = sum_skin[edges[:, 1]] < sum_skin[edges[:, 0]]
+ neighbor_skin = np.zeros_like(sum_skin) # (N, J)
+ neighbor_co = np.zeros((N, J), dtype=np.float32)
+
+ dis = np.sqrt(((vertices[edges[:, 1]] - vertices[edges[:, 0]])**2).sum(axis=1, keepdims=True))
+ co = np.exp(-dis * alpha)
+
+        # np.add.at accumulates over repeated target indices; a plain
+        # fancy-index += would keep only one contribution per vertex
+        np.add.at(neighbor_skin, edges[:, 1], sum_skin[edges[:, 0]] * co * mask)
+        np.add.at(neighbor_co, edges[:, 1], co * mask)
+
+ sum_skin = (sum_skin + neighbor_skin) / (1. + neighbor_co)
+ for i in range(J):
+ p = parents[i]
+ if p is None:
+ continue
+ sum_skin[:, p] -= sum_skin[:, i]
+ skin = sum_skin / sum_skin.sum(axis=-1, keepdims=True)
+
+ # avoid 0-skin
+ mask = (skin>=threshold).any(axis=-1, keepdims=True)
+    skin[(skin < threshold) & mask] = 0.
+    skin = skin / skin.sum(axis=-1, keepdims=True)
+    return skin
diff --git a/UniRig/src/tokenizer/parse.py b/UniRig/src/tokenizer/parse.py
new file mode 100644
--- /dev/null
+++ b/UniRig/src/tokenizer/parse.py
@@ -0,0 +1,9 @@
+from .spec import TokenizerSpec, TokenizerConfig
+from .tokenizer_part import TokenizerPart
+
+def get_tokenizer(config: TokenizerConfig) -> TokenizerSpec:
+ MAP = {
+ 'tokenizer_part': TokenizerPart,
+ }
+ assert config.method in MAP, f"expect: [{','.join(MAP.keys())}], found: {config.method}"
+ return MAP[config.method](config=config)
\ No newline at end of file
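
A sketch of building a tokenizer from a config; the class name 'articulationxl' and the sizes are illustrative, and it assumes get_order tolerates a None order_config (detokenize explicitly handles self.order being None):

config = TokenizerConfig(
    method='tokenizer_part',
    num_discrete=256,
    continuous_range=(-1.0, 1.0),
    cls_token_id={'articulationxl': 0},
    parts_token_id={},
    order_config=None,
)
tokenizer = get_tokenizer(config=config)
# 256 coord bins + branch/bos/eos/pad + spring + cls-none + 1 cls name
print(tokenizer.vocab_size)  # 263
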
diff --git a/UniRig/src/tokenizer/spec.py b/UniRig/src/tokenizer/spec.py
new file mode 100644
index 0000000000000000000000000000000000000000..8665cbf7a8908e7ab4f9c28b91bb528f474fd751
--- /dev/null
+++ b/UniRig/src/tokenizer/spec.py
@@ -0,0 +1,314 @@
+from abc import ABC, abstractmethod
+from collections import defaultdict
+from typing import Dict, List, Tuple, Union
+
+import numpy as np
+from numpy import ndarray
+
+from dataclasses import dataclass
+
+from ..data.exporter import Exporter
+from ..data.order import OrderConfig, Order, get_order
+
+@dataclass(frozen=True)
+class TokenizerConfig():
+ # which tokenizer to use
+ method: str
+
+ # coord discrete
+ num_discrete: int
+
+ # normalization range
+ continuous_range: Tuple[float, float]
+
+ # cls token id
+ cls_token_id: Dict[str, int]
+
+ # parts token id
+ parts_token_id: Dict[str, int]
+
+ order_config: Union[OrderConfig, None]
+
+ @staticmethod
+ def parse(config) -> 'TokenizerConfig':
+ order_config = config.get('order_config', None)
+
+ return TokenizerConfig(
+ method=config.method,
+ num_discrete=config.num_discrete,
+            continuous_range=config.continuous_range,
+            cls_token_id=config.cls_token_id,
+ parts_token_id=config.get('parts_token_id', {}),
+ order_config=OrderConfig.parse(order_config) if order_config is not None else None,
+ )
+
+@dataclass(frozen=True)
+class TokenizeInput():
+ # (J, 6), (parent position, position)
+ bones: ndarray
+
+    # (J, 3), tails of bones (this is an attribute indicating direction, not bones[i, 3:6]); should NOT be used for non-leaf joints
+ tails: Union[ndarray, None]
+
+ # (B, J), bool, whether there is a branch, always False for root
+ branch: ndarray
+
+ # (J), bool, whether the bone is a leaf node (has no child)
+ is_leaf: ndarray
+
+    # (B, J), bool, whether the bone has no skin
+ no_skin: Union[ndarray, None]
+
+ # string of class in tokenizer
+ cls: Union[str, None]
+
+ # Part token added before the i-th bone. If parts_bias[i] is None, a spring token will be added.
+ parts_bias: Dict[int, Union[str, None]]
+
+ @property
+ def num_bones(self):
+ return self.bones.shape[0]
+
+@dataclass(frozen=True)
+class DetokenzeOutput(Exporter):
+ # original tokens
+ tokens: ndarray
+
+ # (J, 6), (parent position, position)
+ bones: ndarray
+
+ # (J), parent of each bone
+ parents: List[Union[int, None]]
+
+ # (J, 3), tails of bones(this is an attribute to indicate direction, not bones[i, 3:6])
+ tails: Union[ndarray, None]
+
+    # (B, J), bool, whether the bone has no skin
+ no_skin: Union[ndarray, None]
+
+ # string of class in tokenizer
+ cls: Union[str, None]
+
+ # part names in order
+ parts: List[str]
+
+ # names of joints
+ names: Union[None, List[str]]
+
+ # normalization cube
+ continuous_range: Tuple[float, float]
+
+ @property
+ def joints(self):
+ return self.bones[:, 3:]
+
+ @property
+ def p_joints(self):
+ return self.bones[:, :3]
+
+ @property
+ def num_bones(self):
+ return self.bones.shape[0]
+
+ def _get_parents(self) -> List[Union[int, None]]:
+ parents = []
+ for (i, bone) in enumerate(self.bones):
+ p_joint = bone[:3]
+ dis = 999999
+ pid = None
+ for j in reversed(range(i)):
+ n_dis = ((self.bones[j][3:] - p_joint)**2).sum()
+ if n_dis < dis:
+ pid = j
+ dis = n_dis
+ parents.append(pid)
+ return parents
+
+ def export_skeleton(self, path: str):
+ parents = self._get_parents()
+ self._export_skeleton(joints=self.bones[:, 3:], parents=parents, path=path)
+
+ def export_bones(self, path: str):
+        assert self.tails is not None, 'tails is None, cannot export bones'
+ self._export_bones(bones=np.concatenate([self.bones[:, 3:], self.tails], axis=-1), path=path)
+
+ def export_skeleton_sequence(self, path: str):
+ parents = self._get_parents()
+ self._export_skeleton_sequence(joints=self.bones[:, 3:], parents=parents, path=path)
+
+class TokenizerSpec(ABC):
+ """
+ Abstract class for tokenizer
+ """
+
+ def __init__(self, **kwargs):
+ super().__init__()
+ pass
+
+ @abstractmethod
+ def tokenize(self, input: TokenizeInput) -> ndarray:
+ pass
+
+ def detokenize(self, ids: ndarray, **kwargs) -> DetokenzeOutput:
+ raise NotImplementedError("{} has no method 'detokenize'".format(type(self).__name__))
+
+ @abstractmethod
+ def get_require_parts(self) -> List[str]:
+ """All parts token names"""
+ pass
+
+ @abstractmethod
+ def cls_name_to_token(self, cls: str) -> int:
+ """Cls name to token"""
+ pass
+
+ @abstractmethod
+ def part_name_to_token(self, part: str) -> int:
+ """Part name to token"""
+ pass
+
+ @property
+ @abstractmethod
+ def vocab_size(self):
+ """The vocabulary size"""
+ pass
+
+ @property
+ def pad(self):
+ raise NotImplementedError("{} has no attribute 'pad'".format(type(self).__name__))
+
+ @property
+ def bos(self):
+ raise NotImplementedError("{} has no attribute 'bos'".format(type(self).__name__))
+
+ @property
+ def eos(self):
+ raise NotImplementedError("{} has no attribute 'eos'".format(type(self).__name__))
+
+ @property
+ def num_discrete(self):
+ raise NotImplementedError("{} has no attribute 'num_discrete'".format(type(self).__name__))
+
+ @property
+ @abstractmethod
+ def continuous_range(self) -> Tuple[float, float]:
+ pass
+
+def make_skeleton(
+ joints: ndarray,
+ p_joints: ndarray,
+ tails_dict: Dict[int, ndarray],
+ convert_leaf_bones_to_tails: bool,
+ extrude_tail_for_leaf: bool,
+ extrude_tail_for_branch: bool,
+ extrude_scale: float=0.5,
+ strict: bool=False,
+) -> Tuple[ndarray, ndarray, List[int], List[Union[None, int]]]:
+ '''
+ Args:
+ joints: heads of bones
+
+ p_joints: parent position of joints
+
+ tails_dict: tail position of the i-th joint
+
+ convert_leaf_bones_to_tails: remove leaf bones and make them tails of their parents
+
+ extrude_tail_for_leaf: add a tail for leaf bone
+
+ extrude_tail_for_branch: add a tail for joint with multiple children
+
+ extrude_scale: length scale of tail offset
+
+ strict: if true, raise error when there are joints in the same location
+
+ Returns:
+ bones, tails, available_bones_id, parents
+ '''
+    assert not (convert_leaf_bones_to_tails and extrude_tail_for_leaf), 'cannot extrude tail for leaf when convert_leaf_bones_to_tails is True'
+ assert joints.shape[0] == p_joints.shape[0]
+ # build parents
+ bones = [] # (parent_position, position)
+ parents = []
+ for (i, joint) in enumerate(joints):
+ if len(bones) == 0:
+ bones.append(np.concatenate([joint, joint])) # root
+ parents.append(None)
+ continue
+ p_joint = p_joints[i]
+ dis = 999999
+ pid = None
+ for j in reversed(range(i)):
+ n_dis = ((bones[j][3:] - p_joint)**2).sum()
+ if n_dis < dis:
+ pid = j
+ dis = n_dis
+ bones.append(np.concatenate([joints[pid], joint]))
+ parents.append(pid)
+ bones = np.stack(bones)
+
+ children = defaultdict(list)
+ for (i, pid) in enumerate(parents):
+ if pid is None:
+ continue
+ children[pid].append(i)
+
+ available_bones_id = []
+ if convert_leaf_bones_to_tails:
+ for (i, pid) in enumerate(parents):
+ if len(children[i]) != 0:
+ available_bones_id.append(i)
+ continue
+ tails_dict[pid] = bones[i, 3:]
+ else:
+ available_bones_id = [i for i in range(bones.shape[0])]
+
+ # tail for leaf
+ for (i, pid) in enumerate(parents):
+ if len(children[i]) != 0:
+ continue
+ if extrude_tail_for_leaf:
+ d = bones[i, 3:] - bones[pid, 3:]
+ length = np.linalg.norm(d)
+ if strict:
+ assert length > 1e-9, 'two joints in the same point found'
+ elif length <= 1e-9:
+ d = np.array([0., 0., 1.])
+ tails_dict[i] = bones[i, 3:] + d * extrude_scale
+ else:
+ tails_dict[i] = bones[i, 3:]
+
+ # tail for branch
+ for (i, pid) in enumerate(parents):
+ if len(children[i]) <= 1:
+ continue
+ if extrude_tail_for_branch:
+ if pid is None: # root
+ av_len = 0
+ for child in children[i]:
+ av_len += np.linalg.norm(bones[i, 3:] - bones[child, 3:])
+ av_len /= len(children[i])
+ d = bones[i, 3:] + np.array([0., 0., extrude_scale * av_len])
+ else:
+ d = bones[i, 3:] - bones[pid, 3:]
+ length = np.linalg.norm(d)
+ if strict:
+ assert length > 1e-9, 'two joints in the same point found'
+ elif length <= 1e-9:
+ d = np.array([0., 0., 1.])
+ tails_dict[i] = bones[i, 3:] + d * extrude_scale
+ else:
+ tails_dict[i] = bones[i, 3:]
+
+ # assign new tail
+ for (i, pid) in enumerate(parents):
+ if len(children[i]) != 1:
+ continue
+ child = children[i][0]
+ tails_dict[i] = bones[child, 3:]
+
+ tails = []
+ for i in range(bones.shape[0]):
+ tails.append(tails_dict[i])
+ tails = np.stack(tails)
+ return bones, tails, available_bones_id, parents
\ No newline at end of file
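
A tiny sketch of make_skeleton on a three-joint chain: each p_joint is matched to the nearest earlier joint, single-child tails snap to the child's head, and the leaf tail is extruded by extrude_scale (default 0.5) along the bone direction:

import numpy as np

joints = np.array([[0., 0., 0.], [0., 0., 1.], [0., 0., 2.]])
p_joints = np.array([[0., 0., 0.], [0., 0., 0.], [0., 0., 1.]])
bones, tails, available, parents = make_skeleton(
    joints=joints,
    p_joints=p_joints,
    tails_dict={},
    convert_leaf_bones_to_tails=False,
    extrude_tail_for_leaf=True,
    extrude_tail_for_branch=True,
)
print(parents)    # [None, 0, 1]
print(tails[2])   # [0.  0.  2.5] -- leaf tail extruded past the last joint
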
diff --git a/UniRig/src/tokenizer/tokenizer_part.py b/UniRig/src/tokenizer/tokenizer_part.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1bdd6cd8db6e690355584afd90ba97826042856
--- /dev/null
+++ b/UniRig/src/tokenizer/tokenizer_part.py
@@ -0,0 +1,241 @@
+import numpy as np
+from numpy import ndarray
+
+from typing import Dict, Tuple, Union, List
+
+from .spec import TokenizerSpec, TokenizeInput, DetokenzeOutput, TokenizerConfig
+from .spec import make_skeleton
+from ..data.order import get_order
+
+class TokenizerPart(TokenizerSpec):
+ def __init__(
+ self,
+ config: TokenizerConfig,
+ ):
+ super().__init__()
+
+ self._num_discrete = config.num_discrete
+ self._continuous_range = config.continuous_range
+ self.cls_token_id = config.cls_token_id.copy()
+ self.parts_token_id = config.parts_token_id.copy()
+ self.order = get_order(config.order_config)
+ _offset = config.num_discrete
+
+ self.token_id_branch = _offset + 0
+ self.token_id_bos = _offset + 1
+ self.token_id_eos = _offset + 2
+ self.token_id_pad = _offset + 3
+ _offset += 4
+
+ self.token_id_spring = _offset + 0
+ _offset += 1
+
+ assert None not in self.parts_token_id
+ for i in self.parts_token_id:
+ self.parts_token_id[i] += _offset
+ _offset += len(self.parts_token_id)
+
+ self.token_id_cls_none = _offset + 0
+ _offset += 1
+
+ for i in self.cls_token_id:
+ self.cls_token_id[i] += _offset
+ _offset += len(self.cls_token_id)
+
+ self._vocab_size = _offset
+
+ self.parts_token_id_name = [x for x in self.parts_token_id]
+
+ self.part_token_to_name = {v: k for k, v in self.parts_token_id.items()}
+ assert len(self.part_token_to_name) == len(self.parts_token_id), 'names with same token found in parts_token_id'
+ self.part_token_to_name[self.token_id_spring] = None
+
+ self.cls_token_to_name = {v: k for k, v in self.cls_token_id.items()}
+ assert len(self.cls_token_to_name) == len(self.cls_token_id), 'names with same token found in cls_token_id'
+
+ def cls_name_to_token(self, cls: str) -> int:
+ if cls not in self.cls_token_id:
+ return self.token_id_cls_none
+ return self.cls_token_id[cls]
+
+ def part_name_to_token(self, part: str) -> int:
+ assert part in self.parts_token_id, f"do not find part name `{part}` in tokenizer"
+ return self.parts_token_id[part]
+
+ def tokenize(self, input: TokenizeInput) -> ndarray:
+ num_bones = input.num_bones
+ bones = discretize(t=input.bones, continuous_range=self.continuous_range, num_discrete=self.num_discrete)
+ tails = discretize(t=input.tails, continuous_range=self.continuous_range, num_discrete=self.num_discrete)
+
+ branch = input.branch
+ is_leaf = input.is_leaf
+
+ tokens = [self.token_id_bos]
+ if input.cls is None or input.cls not in self.cls_token_id:
+ tokens.append(self.token_id_cls_none)
+ else:
+ tokens.append(self.cls_token_id[input.cls])
+ use_leaf = False
+ for i in range(num_bones):
+ # add parts token id
+ if i in input.parts_bias:
+ part = input.parts_bias[i]
+ if part is None:
+ tokens.append(self.token_id_spring)
+ else:
+ assert part in self.parts_token_id, f"do not find part name {part} in tokenizer {self.__class__}"
+ tokens.append(self.parts_token_id[part])
+ if branch[i]:
+ tokens.append(self.token_id_branch)
+ tokens.append(bones[i, 0])
+ tokens.append(bones[i, 1])
+ tokens.append(bones[i, 2])
+ tokens.append(bones[i, 3])
+ tokens.append(bones[i, 4])
+ tokens.append(bones[i, 5])
+ else:
+ tokens.append(bones[i, 3])
+ tokens.append(bones[i, 4])
+ tokens.append(bones[i, 5])
+ tokens.append(self.token_id_eos)
+ return np.array(tokens, dtype=np.int64)
+
+
+ def detokenize(self, ids: ndarray, **kwargs) -> DetokenzeOutput:
+ assert isinstance(ids, ndarray), 'expect ids to be ndarray'
+ if ids[0] != self.token_id_bos:
+ raise ValueError(f"first token is not bos")
+ trailing_pad = 0
+ while trailing_pad < ids.shape[0] and ids[-trailing_pad-1] == self.token_id_pad:
+ trailing_pad += 1
+ if ids[-1-trailing_pad] != self.token_id_eos:
+ raise ValueError(f"last token is not eos")
+ ids = ids[1:-1-trailing_pad]
+ joints = []
+ p_joints = []
+ tails_dict = {}
+ parts = []
+ i = 0
+ is_branch = False
+        last_joint = None
+        cls = None  # stays None unless a cls token appears in ids
+        num_bones = 0
+ while i < len(ids):
+ if ids[i] < self.num_discrete:
+ if is_branch:
+ p_joint = undiscretize(t=ids[i:i+3], continuous_range=self.continuous_range, num_discrete=self.num_discrete)
+ current_joint = undiscretize(t=ids[i+3:i+6], continuous_range=self.continuous_range, num_discrete=self.num_discrete)
+ joints.append(current_joint)
+ p_joints.append(p_joint)
+ i += 6
+ else:
+ current_joint = undiscretize(t=ids[i:i+3], continuous_range=self.continuous_range, num_discrete=self.num_discrete)
+ joints.append(current_joint)
+ if len(p_joints) == 0: # root
+ p_joints.append(current_joint)
+ p_joint = current_joint
+ else:
+ assert last_joint is not None
+ p_joints.append(last_joint)
+ p_joint = last_joint
+ i += 3
+ if last_joint is not None:
+ tails_dict[num_bones-1] = current_joint
+ last_joint = current_joint
+ num_bones += 1
+ is_branch = False
+ elif ids[i]==self.token_id_branch:
+ is_branch = True
+ last_joint = None
+ i += 1
+ elif ids[i]==self.token_id_spring or ids[i] in self.parts_token_id.values():
+ parts.append(self.part_token_to_name[ids[i]])
+ i += 1
+ elif ids[i] in self.cls_token_id.values():
+ cls = ids[i]
+ i += 1
+ elif ids[i] == self.token_id_cls_none:
+ cls = None
+ i += 1
+ else:
+ raise ValueError(f"unexpected token found: {ids[i]}")
+ joints = np.stack(joints)
+ p_joints = np.stack(p_joints)
+        # leaf tails are not encoded by this tokenizer, so tails must be extruded for leaf and branch joints
+ bones, tails, available_bones_id, parents = make_skeleton(
+ joints=joints,
+ p_joints=p_joints,
+ tails_dict=tails_dict,
+ convert_leaf_bones_to_tails=False,
+ extrude_tail_for_leaf=True,
+ extrude_tail_for_branch=True,
+ )
+ bones = bones[available_bones_id]
+ tails = tails[available_bones_id]
+ if cls in self.cls_token_to_name:
+ cls = self.cls_token_to_name[cls]
+ else:
+ cls = None
+ if self.order is not None:
+ names = self.order.make_names(cls=cls, parts=parts, num_bones=num_bones)
+ else:
+ names = [f"bone_{i}" for i in range(num_bones)]
+ return DetokenzeOutput(
+ tokens=ids,
+ parents=parents,
+ bones=bones,
+ tails=tails,
+ no_skin=None,
+ cls=cls,
+ parts=parts,
+ names=names,
+ continuous_range=self.continuous_range,
+ )
+
+ def get_require_parts(self) -> List[str]:
+ return self.parts_token_id_name
+
+ @property
+ def vocab_size(self):
+ return self._vocab_size
+
+ @property
+ def pad(self):
+ return self.token_id_pad
+
+ @property
+ def bos(self):
+ return self.token_id_bos
+
+ @property
+ def eos(self):
+ return self.token_id_eos
+
+ @property
+ def num_discrete(self):
+ return self._num_discrete
+
+ @property
+ def continuous_range(self) -> Tuple[float, float]:
+ return self._continuous_range
+
+def discretize(
+ t: ndarray,
+ continuous_range: Tuple[float, float],
+ num_discrete: int,
+) -> ndarray:
+ lo, hi = continuous_range
+ assert hi >= lo
+ t = (t - lo) / (hi - lo)
+ t *= num_discrete
+ return np.clip(t.round(), 0, num_discrete - 1).astype(np.int64)
+
+def undiscretize(
+ t: ndarray,
+ continuous_range: Tuple[float, float],
+ num_discrete: int,
+) -> ndarray:
+ lo, hi = continuous_range
+ assert hi >= lo
+ t = t.astype(np.float32) + 0.5
+ t /= num_discrete
+ return t * (hi - lo) + lo
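
A roundtrip sketch of the quantization pair above. discretize rounds the scaled coordinate while undiscretize returns a half-bin offset, so the reconstruction error stays within one bin width, (hi - lo) / num_discrete:

import numpy as np

t = np.array([-0.73, 0.0, 0.42])
ids = discretize(t, continuous_range=(-1.0, 1.0), num_discrete=128)
back = undiscretize(ids, continuous_range=(-1.0, 1.0), num_discrete=128)
assert np.all(np.abs(back - t) <= 2.0 / 128 + 1e-6)  # within one bin width
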