diff --git a/UniRig/.gitattributes b/UniRig/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..520ec4e3434c3d3ba2d43125fe86933ec680c3ba --- /dev/null +++ b/UniRig/.gitattributes @@ -0,0 +1,5 @@ +examples/*.glb filter=lfs diff=lfs merge=lfs -text +examples/*.fbx filter=lfs diff=lfs merge=lfs -text +examples/*.vrm filter=lfs diff=lfs merge=lfs -text +examples/*.FBX filter=lfs diff=lfs merge=lfs -text +examples/*.obj filter=lfs diff=lfs merge=lfs -text diff --git a/UniRig/.gitignore b/UniRig/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..474041aa2dbf654480297bd340b6e5fb3a57d24b --- /dev/null +++ b/UniRig/.gitignore @@ -0,0 +1,59 @@ +# igonore all pychace +**/__pycache__/ +*.py[cod] +*$py.class + +# ignore tmp & output files +_data/ +tmp/ +*.npz +*.blend +*.blend1 +*.blend2 + +# ignore logs +wandb/ +lightning_logs/ +*.log + +# ignore experiments +experiments/ +results/ +dataset_clean/ +logs/ +datalist/ +dataset_inference/ +dataset_inference_clean/ +feature_viz/ + +# Distribution / packaging +dist/ +build/ +*.egg-info/ +*.egg +*.whl + +# Virtual environments +venv/ +env/ +.env/ +.venv/ + +# IDE specific files +.idea/ +.vscode/ +*.swp +*.swo +.DS_Store + +# Jupyter Notebook +.ipynb_checkpoints +*.ipynb + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +coverage.xml +*.cover \ No newline at end of file diff --git a/UniRig/LICENSE b/UniRig/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..a20d482774136ee1858d92964e896031598902b4 --- /dev/null +++ b/UniRig/LICENSE @@ -0,0 +1,21 @@ +# MIT License + +# Copyright (c) 2025 VAST-AI-Research and contributors. + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE diff --git a/UniRig/README.md b/UniRig/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b222e6161e9bf181d27cad9041c9d27c640224a6 --- /dev/null +++ b/UniRig/README.md @@ -0,0 +1,164 @@ +# UniRig: One Model to Rig Them All + +
+ +[![Project Page](https://img.shields.io/badge/🏠-Project%20Page-blue.svg)](https://zjp-shadow.github.io/works/UniRig/) +[![Paper](https://img.shields.io/badge/📑-Paper-green.svg)](https://arxiv.org/abs/2504.12451) +[![Model](https://img.shields.io/badge/🤗-Model-yellow.svg)](https://huggingface.co/VAST-AI/UniRig) + +

![teaser](assets/doc/unirig_teaser.png)

This repository contains the official implementation for the **SIGGRAPH'25 (TOG) UniRig** framework, a unified solution for automatic 3D model rigging, developed by Tsinghua University and [Tripo](https://www.tripo3d.ai).

**Paper:** [One Model to Rig Them All: Diverse Skeleton Rigging with UniRig](https://arxiv.org/abs/2504.12451)

## Overview

Rigging 3D models – creating a skeleton and assigning skinning weights – is a crucial but often complex and time-consuming step in 3D animation. UniRig tackles this challenge by introducing a novel, unified framework leveraging large autoregressive models to automate the process for a diverse range of 3D assets.

Combining UniRig with keyframe animation produces the following results:

| ![devil](assets/doc/devil.gif) | ![dragon](assets/doc/dragon.gif) | ![rabbit](assets/doc/rabbit.gif) |
|:-----------------------------:|:-------------------------------:|:-------------------------------:|

The full UniRig system consists of two main stages:
1. **Skeleton Prediction:** A GPT-like transformer autoregressively predicts a topologically valid skeleton hierarchy using a novel **Skeleton Tree Tokenization** scheme.
2. **Skinning Weight & Attribute Prediction:** A **Bone-Point Cross Attention** mechanism predicts per-vertex skinning weights and relevant bone attributes (e.g., for physics simulation) based on the predicted skeleton and input mesh geometry.

This repository provides the implementation of the full framework, with components being released progressively.

## Key Features (Full UniRig Framework)

* **Unified Model:** Aims to handle diverse model categories (humans, animals, objects) with a single framework.
* **Automated Skeleton Generation:** Predicts topologically valid skeleton structures. **(✅ Available in current release)**
* **Automated Skinning Prediction:** Predicts per-vertex skinning weights. **(✅ Available in current release)**
* **Bone Attribute Prediction:** Predicts attributes like stiffness for physics-based secondary motion. **(⏳ Coming Soon)**
* **High Accuracy & Robustness:** Achieves state-of-the-art results on challenging datasets (as shown in the paper with Rig-XL/VRoid training).
* **Efficient Tokenization:** Uses Skeleton Tree Tokenization for compact representation and efficient processing.
* **Human-in-the-Loop Ready:** Designed to potentially support iterative refinement workflows.

## 🚨 Current Release Status & Roadmap 🚨

We are open-sourcing UniRig progressively. Please note the current status:

**Available Now (Initial Release):**
* ✅ **Code:** Implementation for skeleton and skinning prediction.
* ✅ **Model:** Skeleton & Skinning Prediction checkpoint trained on [**Articulation-XL2.0**](https://huggingface.co/datasets/Seed3D/Articulation-XL2.0). Available on [Hugging Face](https://huggingface.co/VAST-AI/UniRig).

**Planned Future Releases:**
* ⏳ Release of the **Rig-XL** and **VRoid** datasets used in the paper.
* ⏳ Full UniRig model checkpoints (Skeleton + Skinning) trained on Rig-XL/VRoid, replicating the paper's main results.

We appreciate your patience as we prepare these components for release. Follow [VAST-AI-Research](https://github.com/orgs/VAST-AI-Research) announcements for updates!

## Installation

1. **Prerequisites:**
   * Python 3.11
   * PyTorch (tested with version >=2.3.1)

2. **Clone the repository:**
   ```bash
   git clone https://github.com/VAST-AI-Research/UniRig
   cd UniRig
   ```

3. 
**Set up a virtual environment (recommended):**
   ```bash
   conda create -n UniRig python=3.11
   conda activate UniRig
   ```

4. **Install dependencies:**
   ```bash
   python -m pip install torch torchvision
   python -m pip install -r requirements.txt
   python -m pip install spconv-{your-cuda-version}
   python -m pip install torch_scatter torch_cluster -f https://data.pyg.org/whl/torch-{your-torch-version}+{your-cuda-version}.html --no-cache-dir
   python -m pip install numpy==1.26.4
   ```

5. **Download Model Checkpoint:**
   The currently available model checkpoints are hosted on Hugging Face and will typically be downloaded automatically by the provided scripts.

6. **(Optional, for importing/exporting .vrm) Install the Blender addon:**
   The Blender addon is modified from [VRM-Addon-for-Blender](https://github.com/saturday06/VRM-Addon-for-Blender).

   Make sure you are in the root directory of the project, then:
   ```bash
   python -c "import bpy, os; bpy.ops.preferences.addon_install(filepath=os.path.abspath('blender/add-on-vrm-v2.20.77_modified.zip'))"
   ```

## Usage

### Skeleton Prediction (Available Now)

Generate a skeleton for your 3D model using our pre-trained model. The process automatically analyzes the geometry and predicts an appropriate skeletal structure.

```bash
# Process a single file
bash launch/inference/generate_skeleton.sh --input examples/giraffe.glb --output results/giraffe_skeleton.fbx

# Process multiple files in a directory
bash launch/inference/generate_skeleton.sh --input_dir <input_directory> --output_dir <output_directory>

# Try different skeleton variations by changing the random seed
bash launch/inference/generate_skeleton.sh --input examples/giraffe.glb --output results/giraffe_skeleton.fbx --seed 42
```

Supported input formats: `.obj`, `.fbx`, `.glb`, and `.vrm`

### Skinning Weight Prediction (Available Now)

```bash
# Skin a single file
bash launch/inference/generate_skin.sh --input examples/skeleton/giraffe.fbx --output results/giraffe_skin.fbx

# Process multiple files in a directory
bash launch/inference/generate_skin.sh --input_dir <input_directory> --output_dir <output_directory>
```

Note that the command above uses an **edited version** of the skeleton from the skeleton-prediction phase. The results may degrade significantly if the skeleton is inaccurate, for example if tail bones or wing bones are missing, so it is recommended to refine the skeleton before skinning to achieve better results.
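
For orientation, the two steps above chain together on the bundled examples. The sketch below is a minimal end-to-end run on the giraffe asset; the intermediate file names are illustrative, and the optional refinement step assumes you edit the predicted skeleton in a DCC tool such as Blender and re-export it to the same path before skinning.

```bash
# 1. Predict a skeleton for the bundled giraffe example
bash launch/inference/generate_skeleton.sh \
    --input examples/giraffe.glb \
    --output results/giraffe_skeleton.fbx

# 2. (Optional but recommended) refine results/giraffe_skeleton.fbx by hand,
#    e.g. add missing tail or wing bones, and re-export it to the same path.

# 3. Predict skinning weights using the (refined) skeleton as input
bash launch/inference/generate_skin.sh \
    --input results/giraffe_skeleton.fbx \
    --output results/giraffe_skin.fbx
```

The merge step described in the next section then transfers the predicted rig back onto the original model.
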
+ +### Merge the Predicted Results + +Combine the predicted skeleton with your original 3D model to create a fully rigged asset: + +```bash +# Merge skeleton from skeleton prediction +bash launch/inference/merge.sh --source results/giraffe_skeleton.fbx --target examples/giraffe.glb --output results/giraffe_rigged.glb + +# Or merge skin from skin prediction +bash launch/inference/merge.sh --source results/giraffe_skin.fbx --target examples/giraffe.glb --output results/giraffe_rigged.glb +``` + +## Models + +Available models are hosted on the: https://huggingface.co/VAST-AI/UniRig + +## System Requirements + +- CUDA-enabled GPU with at least 8GB VRAM + +## Citation + +``` +@article{zhang2025unirig, + title={One Model to Rig Them All: Diverse Skeleton Rigging with UniRig}, + author={Zhang, Jia-Peng and Pu, Cheng-Feng and Guo, Meng-Hao and Cao, Yan-Pei and Hu, Shi-Min}, + journal={arXiv preprint arXiv:2504.12451}, + year={2025} +} +``` + +## Acknowledgements + +We would like to thank the following open-source projects and research works: + +- [OPT](https://huggingface.co/facebook/opt-350m) for model architecture +- [3DShape2VecSet](https://github.com/1zb/3DShape2VecSet) for 3D shape representation +- [SAMPart3D](https://github.com/Pointcept/SAMPart3D) and [Michelangelo](https://github.com/NeuralCarver/Michelangelo/) for shape encoder implementation +- [Articulation-XL2.0](https://huggingface.co/datasets/Seed3D/Articulation-XL2.0) for a curated dataset + +We are grateful to the broader research community for their open exploration and contributions to the field of 3D generation. diff --git a/UniRig/assets/doc/devil.gif b/UniRig/assets/doc/devil.gif new file mode 100644 index 0000000000000000000000000000000000000000..181a003e5b4a3b61e88bdb9dd672d8e332e20a79 --- /dev/null +++ b/UniRig/assets/doc/devil.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1a6d00b42bd25d6708df9ca5b69900865eb50cf2bfe3cc72d64d87873cd99749 +size 1831961 diff --git a/UniRig/assets/doc/dragon.gif b/UniRig/assets/doc/dragon.gif new file mode 100644 index 0000000000000000000000000000000000000000..4dc0ac2fe070a3e8d9b3e139dffcbb52233300c7 --- /dev/null +++ b/UniRig/assets/doc/dragon.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1459fb925b79b710d9496be5105eed48539942a47763029ae0cc855684c9a0e9 +size 1922869 diff --git a/UniRig/assets/doc/rabbit.gif b/UniRig/assets/doc/rabbit.gif new file mode 100644 index 0000000000000000000000000000000000000000..1322b653d3a1ddbaea58d3bea9621eef486512f2 --- /dev/null +++ b/UniRig/assets/doc/rabbit.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:124cad359767a48a4f5cf84103c5400ceb2815728c7f4c7bab7d57d0944d93a3 +size 732336 diff --git a/UniRig/assets/doc/unirig_teaser.png b/UniRig/assets/doc/unirig_teaser.png new file mode 100644 index 0000000000000000000000000000000000000000..7cb8af8d8b071643c625f54f0a799d6834a4d0cf --- /dev/null +++ b/UniRig/assets/doc/unirig_teaser.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:44b056c9355386b872de584b734c57c823c7e73be7d20257e01e15861204247f +size 12417252 diff --git a/UniRig/blender/add-on-vrm-v2.20.77_modified.zip b/UniRig/blender/add-on-vrm-v2.20.77_modified.zip new file mode 100644 index 0000000000000000000000000000000000000000..76de53479b744097e69c3a0aa2070821f6c43ee5 --- /dev/null +++ b/UniRig/blender/add-on-vrm-v2.20.77_modified.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:8fe1e8b7e31ec602d9b32db33c533b03d68df747837bfad3479e1057bc9937c5 +size 1331571 diff --git a/UniRig/configs/data/quick_inference.yaml b/UniRig/configs/data/quick_inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..25f988646a99d0cdf97a72cb1ae0a7e66d98caa6 --- /dev/null +++ b/UniRig/configs/data/quick_inference.yaml @@ -0,0 +1,16 @@ +input_dataset_dir: &input_dataset_dir ./dataset_inference +output_dataset_dir: &output_dataset_dir ./dataset_inference_clean + +predict_dataset_config: + shuffle: False + batch_size: 1 + num_workers: 1 + pin_memory: False + persistent_workers: False + datapath_config: + input_dataset_dir: *output_dataset_dir + use_prob: False + data_path: + inference: [ + [./dataset_inference_clean/inference_datalist.txt, 1.0], + ] \ No newline at end of file diff --git a/UniRig/configs/model/unirig_ar_350m_1024_81920_float32.yaml b/UniRig/configs/model/unirig_ar_350m_1024_81920_float32.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7b5a83e17433bd6c9b655cd72d3c05f935a3c48a --- /dev/null +++ b/UniRig/configs/model/unirig_ar_350m_1024_81920_float32.yaml @@ -0,0 +1,32 @@ +__target__: unirig_ar +llm: + pretrained_model_name_or_path: facebook/opt-350m + n_positions: 3076 + max_position_embeddings: 3076 + hidden_size: 1024 + word_embed_proj_dim: 1024 + do_layer_norm_before: True + _attn_implementation: flash_attention_2 + +mesh_encoder: + __target__: michelangelo_encoder + pretrained_path: ~ + freeze_encoder: False + device: cpu + dtype: float32 + num_latents: 512 + embed_dim: 64 + point_feats: 3 + num_freqs: 8 + include_pi: False + heads: 8 + width: 512 + num_encoder_layers: 16 + use_ln_post: True + init_scale: 0.25 + qkv_bias: False + use_checkpoint: False + flash: True + supervision_type: sdf + query_method: False + token_num: 1024 \ No newline at end of file diff --git a/UniRig/configs/model/unirig_skin.yaml b/UniRig/configs/model/unirig_skin.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e13e1bb6e7388538af0e9b0e975ee4ad72f3ca3e --- /dev/null +++ b/UniRig/configs/model/unirig_skin.yaml @@ -0,0 +1,52 @@ +__target__: unirig_skin + +num_train_vertex: 512 # increase this for faster speed at the cost of memory +num_heads: 16 +feat_dim: 768 +grid_size: 0.005 +mlp_dim: 512 +num_bone_attn: 8 +num_mesh_bone_attn: 16 +bone_embed_dim: 1024 +voxel_mask: 3.0 + +mesh_encoder: + # vertex groups are handled in model + __target__: ptv3obj + pretrained_path: ~ + freeze_encoder: False + in_channels: 9 + cls_mode: False + shuffle_orders: True + drop_path: 0.0 + upcast_attention: False + upcast_softmax: False + enc_depths: [3, 3, 3, 6, 16] + enc_channels: [32, 64, 128, 256, 384] + enc_num_head: [2, 4, 8, 16, 24] + enable_qknorm: True + layer_norm: False + res_linear: True + +global_encoder: + __target__: michelangelo_encoder + pretrained_path: ~ + freeze_encoder: False + device: cpu + dtype: float32 + num_latents: 512 + embed_dim: 64 + point_feats: 3 + num_freqs: 8 + include_pi: False + heads: 8 + width: 512 + num_encoder_layers: 16 + use_ln_post: True + init_scale: 0.25 + qkv_bias: False + use_checkpoint: False + flash: True + supervision_type: sdf + query_method: False + token_num: 1024 \ No newline at end of file diff --git a/UniRig/configs/skeleton/mixamo.yaml b/UniRig/configs/skeleton/mixamo.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6271e7b9f59eec59ae08ce77138d1a054be382c8 --- /dev/null +++ b/UniRig/configs/skeleton/mixamo.yaml @@ -0,0 +1,59 @@ +parts_order: [body, hand] + 
+parts: + body: [ + mixamorig:Hips, + mixamorig:Spine, + mixamorig:Spine1, + mixamorig:Spine2, + mixamorig:Neck, + mixamorig:Head, + mixamorig:LeftShoulder, + mixamorig:LeftArm, + mixamorig:LeftForeArm, + mixamorig:LeftHand, + mixamorig:RightShoulder, + mixamorig:RightArm, + mixamorig:RightForeArm, + mixamorig:RightHand, + mixamorig:LeftUpLeg, + mixamorig:LeftLeg, + mixamorig:LeftFoot, + mixamorig:LeftToeBase, + mixamorig:RightUpLeg, + mixamorig:RightLeg, + mixamorig:RightFoot, + mixamorig:RightToeBase, + ] + hand: [ + mixamorig:LeftHandThumb1, + mixamorig:LeftHandThumb2, + mixamorig:LeftHandThumb3, + mixamorig:LeftHandIndex1, + mixamorig:LeftHandIndex2, + mixamorig:LeftHandIndex3, + mixamorig:LeftHandMiddle1, + mixamorig:LeftHandMiddle2, + mixamorig:LeftHandMiddle3, + mixamorig:LeftHandRing1, + mixamorig:LeftHandRing2, + mixamorig:LeftHandRing3, + mixamorig:LeftHandPinky1, + mixamorig:LeftHandPinky2, + mixamorig:LeftHandPinky3, + mixamorig:RightHandIndex1, + mixamorig:RightHandIndex2, + mixamorig:RightHandIndex3, + mixamorig:RightHandThumb1, + mixamorig:RightHandThumb2, + mixamorig:RightHandThumb3, + mixamorig:RightHandMiddle1, + mixamorig:RightHandMiddle2, + mixamorig:RightHandMiddle3, + mixamorig:RightHandRing1, + mixamorig:RightHandRing2, + mixamorig:RightHandRing3, + mixamorig:RightHandPinky1, + mixamorig:RightHandPinky2, + mixamorig:RightHandPinky3, + ] \ No newline at end of file diff --git a/UniRig/configs/skeleton/vroid.yaml b/UniRig/configs/skeleton/vroid.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1d6f066e78686d4f3e7d9f50dae8b9ab73922aa9 --- /dev/null +++ b/UniRig/configs/skeleton/vroid.yaml @@ -0,0 +1,59 @@ +parts_order: [body, hand] + +parts: + body: [ + J_Bip_C_Hips, + J_Bip_C_Spine, + J_Bip_C_Chest, + J_Bip_C_UpperChest, + J_Bip_C_Neck, + J_Bip_C_Head, + J_Bip_L_Shoulder, + J_Bip_L_UpperArm, + J_Bip_L_LowerArm, + J_Bip_L_Hand, + J_Bip_R_Shoulder, + J_Bip_R_UpperArm, + J_Bip_R_LowerArm, + J_Bip_R_Hand, + J_Bip_L_UpperLeg, + J_Bip_L_LowerLeg, + J_Bip_L_Foot, + J_Bip_L_ToeBase, + J_Bip_R_UpperLeg, + J_Bip_R_LowerLeg, + J_Bip_R_Foot, + J_Bip_R_ToeBase, + ] + hand: [ + J_Bip_L_Thumb1, + J_Bip_L_Thumb2, + J_Bip_L_Thumb3, + J_Bip_L_Index1, + J_Bip_L_Index2, + J_Bip_L_Index3, + J_Bip_L_Middle1, + J_Bip_L_Middle2, + J_Bip_L_Middle3, + J_Bip_L_Ring1, + J_Bip_L_Ring2, + J_Bip_L_Ring3, + J_Bip_L_Little1, + J_Bip_L_Little2, + J_Bip_L_Little3, + J_Bip_R_Index1, + J_Bip_R_Index2, + J_Bip_R_Index3, + J_Bip_R_Thumb1, + J_Bip_R_Thumb2, + J_Bip_R_Thumb3, + J_Bip_R_Middle1, + J_Bip_R_Middle2, + J_Bip_R_Middle3, + J_Bip_R_Ring1, + J_Bip_R_Ring2, + J_Bip_R_Ring3, + J_Bip_R_Little1, + J_Bip_R_Little2, + J_Bip_R_Little3, + ] \ No newline at end of file diff --git a/UniRig/configs/system/ar_inference_articulationxl.yaml b/UniRig/configs/system/ar_inference_articulationxl.yaml new file mode 100644 index 0000000000000000000000000000000000000000..87fab782cf9dd5a6970e579d1f170c0bcc67b9ba --- /dev/null +++ b/UniRig/configs/system/ar_inference_articulationxl.yaml @@ -0,0 +1,14 @@ +__target__: ar +val_interval: 1 +generate_kwargs: + max_new_tokens: 2048 + num_return_sequences: 1 + num_beams: 15 + do_sample: True + top_k: 5 + top_p: 0.95 + repetition_penalty: 3.0 + temperature: 1.5 # must be a float + no_cls: False + assign_cls: articulationxl + use_dir_cls: False \ No newline at end of file diff --git a/UniRig/configs/system/skin.yaml b/UniRig/configs/system/skin.yaml new file mode 100644 index 0000000000000000000000000000000000000000..782de0e970ece5cfb4fbdaf3078eb2af60ab70cb 
--- /dev/null +++ b/UniRig/configs/system/skin.yaml @@ -0,0 +1,5 @@ +__target__: skin +val_interval: 1 +val_start_from: 1 +output_path: tmp_skin +record_res: True \ No newline at end of file diff --git a/UniRig/configs/task/quick_inference_skeleton_articulationxl_ar_256.yaml b/UniRig/configs/task/quick_inference_skeleton_articulationxl_ar_256.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7477b9c107cb95070510745a8544e894809b57e7 --- /dev/null +++ b/UniRig/configs/task/quick_inference_skeleton_articulationxl_ar_256.yaml @@ -0,0 +1,30 @@ +mode: predict +debug: False +experiment_name: quick_inference_skeleton_articulationxl_ar_256 +resume_from_checkpoint: experiments/skeleton/articulation-xl_quantization_256/model.ckpt + +components: + data: quick_inference + tokenizer: tokenizer_parts_articulationxl_256 + transform: inference_ar_transform + model: unirig_ar_350m_1024_81920_float32 + system: ar_inference_articulationxl + data_name: raw_data.npz + +writer: + __target__: ar + output_dir: ~ # export results into the same input folder + add_num: False + repeat: 1 + export_npz: predict_skeleton + export_obj: skeleton + export_fbx: skeleton + # export_pc: pc + +trainer: + max_epochs: 1 + num_nodes: 1 + devices: 1 + precision: bf16-mixed + accelerator: gpu + strategy: auto \ No newline at end of file diff --git a/UniRig/configs/task/quick_inference_unirig_skin.yaml b/UniRig/configs/task/quick_inference_unirig_skin.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c0797ef545edf5995610f2b8698340222b709f10 --- /dev/null +++ b/UniRig/configs/task/quick_inference_unirig_skin.yaml @@ -0,0 +1,28 @@ +mode: predict +debug: False +experiment_name: quick_inference_skin +resume_from_checkpoint: experiments/skin/articulation-xl/model.ckpt + +components: + data: quick_inference + transform: inference_skin_transform + model: unirig_skin + system: skin + data_name: predict_skeleton.npz # capture data from ar phase + +writer: + __target__: skin + output_dir: results + add_num: False + repeat: 1 + save_name: predict + export_npz: predict_skin # this must be specified if textured results are required + export_fbx: result_fbx + +trainer: + num_nodes: 1 + devices: 1 + precision: bf16-mixed + accelerator: gpu + strategy: auto + inference_mode: True diff --git a/UniRig/configs/tokenizer/tokenizer_parts_articulationxl_256.yaml b/UniRig/configs/tokenizer/tokenizer_parts_articulationxl_256.yaml new file mode 100644 index 0000000000000000000000000000000000000000..570b810058053fdb9212977518310e1fae4719e0 --- /dev/null +++ b/UniRig/configs/tokenizer/tokenizer_parts_articulationxl_256.yaml @@ -0,0 +1,14 @@ +method: tokenizer_part +num_discrete: 256 +continuous_range: [-1, 1] +cls_token_id: + vroid: 0 + mixamo: 1 # this is currently untrained, do not use it + articulationxl: 2 +parts_token_id: + body: 0 + hand: 1 +order_config: + skeleton_path: + vroid: ./configs/skeleton/vroid.yaml + mixamo: ./configs/skeleton/mixamo.yaml \ No newline at end of file diff --git a/UniRig/configs/transform/inference_ar_transform.yaml b/UniRig/configs/transform/inference_ar_transform.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e2ca90f588c5325c5e8a2c42868d7980910131be --- /dev/null +++ b/UniRig/configs/transform/inference_ar_transform.yaml @@ -0,0 +1,30 @@ +sampler_config: &sampler_config + method: mix + num_samples: 65536 + vertex_samples: 8192 + +tail_config: &tail_config + copy_joint_to_tail: False # Be careful ! If tail is important, keep it False !!! 
+ connect_tail_to_unique_son: True + +order_config: &order_config + skeleton_path: + vroid: ./configs/skeleton/vroid.yaml + mixamo: ./configs/skeleton/mixamo.yaml + +vertex_group_config: &vertex_group_config + +validate_transform_config: &validate_transform_config + augment_config: + augment_affine_config: + normalize_into: [-1.0, 1.0] + random_scale_p: 0.0 + random_scale: [1.0, 1.0] + random_shift_p: 0.0 + random_shift: [0.0, 0.0] + tail_config: *tail_config + order_config: *order_config + vertex_group_config: *vertex_group_config + sampler_config: *sampler_config + +predict_transform_config: *validate_transform_config \ No newline at end of file diff --git a/UniRig/configs/transform/inference_skin_transform.yaml b/UniRig/configs/transform/inference_skin_transform.yaml new file mode 100644 index 0000000000000000000000000000000000000000..04284fd6d7fe312a76a0a2f55e31cc5b52e6a899 --- /dev/null +++ b/UniRig/configs/transform/inference_skin_transform.yaml @@ -0,0 +1,32 @@ +sampler_config: &sampler_config + method: mix + num_samples: 32768 + vertex_samples: 8192 + +tail_config: &tail_config + copy_joint_to_tail: False # Be careful ! If tail is important, keep it False !!! + connect_tail_to_unique_son: True + +order_config: &order_config + skeleton_path: + vroid: ./configs/skeleton/vroid.yaml + mixamo: ./configs/skeleton/mixamo.yaml + +predict_transform_config: + augment_config: + augment_affine_config: + normalize_into: [-1.0, 1.0] + tail_config: *tail_config + order_config: *order_config + vertex_group_config: + names: ['voxel_skin'] + kwargs: + voxel_skin: + grid: 196 # increase this for better results + alpha: 0.5 + link_dis: 0.00001 + grid_query: 7 + vertex_query: 1 + grid_weight: 3.0 + # mode: exp + sampler_config: *sampler_config \ No newline at end of file diff --git a/UniRig/examples/bird.glb b/UniRig/examples/bird.glb new file mode 100644 index 0000000000000000000000000000000000000000..c82417b3c66e411ac5b7b26926b5a2329167df60 --- /dev/null +++ b/UniRig/examples/bird.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bb59726f598ab4a4e4c431b9789317f5b2f4252a6fb57364d929f5e1ddd7b5bb +size 8032388 diff --git a/UniRig/examples/giraffe.glb b/UniRig/examples/giraffe.glb new file mode 100644 index 0000000000000000000000000000000000000000..9cec09a7ea93692b61033b9de5f3f7bcab4c790c --- /dev/null +++ b/UniRig/examples/giraffe.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9a947cae00b169345802c08885c1e54313c6ce93c885bacf8e37e8a1f18f9e3b +size 6310044 diff --git a/UniRig/examples/skeleton/bird.fbx b/UniRig/examples/skeleton/bird.fbx new file mode 100644 index 0000000000000000000000000000000000000000..7d075b3cd1093150cdf13a0b8871390e36d6f203 --- /dev/null +++ b/UniRig/examples/skeleton/bird.fbx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:885d432850506ab673d509ae2481067cc623452d5910090e1e15f323b9f83fa2 +size 401084 diff --git a/UniRig/examples/skeleton/giraffe.fbx b/UniRig/examples/skeleton/giraffe.fbx new file mode 100644 index 0000000000000000000000000000000000000000..805b4bc0267f94425f2d8c472ce74f09cf44e048 --- /dev/null +++ b/UniRig/examples/skeleton/giraffe.fbx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:73673a15c8103fbf9a6b39762768ae48e7c9404eab4850a60e4863ee400336fd +size 759180 diff --git a/UniRig/examples/skeleton/tira.fbx b/UniRig/examples/skeleton/tira.fbx new file mode 100644 index 0000000000000000000000000000000000000000..61ad187c7ed805ee6ac521bdf2ec6c53cbe94235 --- /dev/null +++ 
b/UniRig/examples/skeleton/tira.fbx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:57ed6969266d642da2d8a579059462c7cd4a36c9bd7a8415236dcbea36607fee +size 1694668 diff --git a/UniRig/examples/skeleton/tripo_carrot.fbx b/UniRig/examples/skeleton/tripo_carrot.fbx new file mode 100644 index 0000000000000000000000000000000000000000..66e22239154887c8a59dfa07d5ef7708b467f5cb --- /dev/null +++ b/UniRig/examples/skeleton/tripo_carrot.fbx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:50e08d70c7bdfa96eb842aed18564d8486a4622a474dbc0e0ef9304af1d4c6d3 +size 1879420 diff --git a/UniRig/examples/tira.glb b/UniRig/examples/tira.glb new file mode 100644 index 0000000000000000000000000000000000000000..961966ed74868fcaa26aa397cf31dfc82efa7af0 --- /dev/null +++ b/UniRig/examples/tira.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8e5a282d0a99c61a8d2439496057b15b0c6ea02643c16da2dcbec85571157799 +size 32346060 diff --git a/UniRig/examples/tripo_carrot.glb b/UniRig/examples/tripo_carrot.glb new file mode 100644 index 0000000000000000000000000000000000000000..9e334e2549a3576d2996fc468d7a8a337a1eca03 --- /dev/null +++ b/UniRig/examples/tripo_carrot.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3c00b7e6ff1e71a019a02128a5b37e8d2db259a402313ad9b4eaab4f183ae40b +size 9511824 diff --git a/UniRig/launch/inference/extract.sh b/UniRig/launch/inference/extract.sh new file mode 100644 index 0000000000000000000000000000000000000000..42bd86d77504943b42acc2d4db5be30dd6979149 --- /dev/null +++ b/UniRig/launch/inference/extract.sh @@ -0,0 +1,60 @@ +# extract mesh +config="configs/data/quick_inference.yaml" +require_suffix="obj,fbx,FBX,dae,glb,gltf,vrm" +num_runs=1 +force_override="false" +faces_target_count=50000 + +while [[ "$#" -gt 0 ]]; do + case $1 in + --config) config="$2"; shift ;; + --require_suffix) require_suffix="$2"; shift ;; + --num_runs) num_runs="$2"; shift ;; + --force_override) force_override="$2"; shift ;; + --faces_target_count) faces_target_count="$2"; shift ;; + --time) time="$2"; shift ;; + --input) input="$2"; shift ;; + --input_dir) input_dir="$2"; shift ;; + --output_dir) output_dir="$2"; shift ;; + *) echo "Unknown parameter: $1"; exit 1 ;; + esac + shift +done + +# ensure psutil is installed for memory management +pip install psutil --quiet +if [ $? -ne 0 ]; then + echo "Warning: Failed to install psutil. Memory management may not work properly." 
+fi + +# set the time for all processes to use +time=$(date "+%Y_%m_%d_%H_%M_%S") + +for (( i=0; i Box: + if path.endswith('.yaml'): + path = path.removesuffix('.yaml') + path += '.yaml' + print(f"\033[92mload {task} config: {path}\033[0m") + return Box(yaml.safe_load(open(path, 'r'))) + +def nullable_string(val): + if not val: + return None + return val + +if __name__ == "__main__": + torch.set_float32_matmul_precision('high') + + parser = argparse.ArgumentParser() + parser.add_argument("--task", type=str, required=True) + parser.add_argument("--seed", type=int, required=False, default=123, + help="random seed") + parser.add_argument("--input", type=nullable_string, required=False, default=None, + help="a single input file or files splited by comma") + parser.add_argument("--input_dir", type=nullable_string, required=False, default=None, + help="input directory") + parser.add_argument("--output", type=nullable_string, required=False, default=None, + help="filename for a single output") + parser.add_argument("--output_dir", type=nullable_string, required=False, default=None, + help="output directory") + parser.add_argument("--npz_dir", type=nullable_string, required=False, default='tmp', + help="intermediate npz directory") + parser.add_argument("--cls", type=nullable_string, required=False, default=None, + help="class name") + parser.add_argument("--data_name", type=nullable_string, required=False, default=None, + help="npz filename from skeleton phase") + args = parser.parse_args() + + L.seed_everything(args.seed, workers=True) + + task = load('task', args.task) + mode = task.mode + assert mode in ['predict'] + + if args.input is not None or args.input_dir is not None: + assert args.output_dir is not None or args.output is not None, 'output or output_dir must be specified' + assert args.npz_dir is not None, 'npz_dir must be specified' + files = get_files( + data_name=task.components.data_name, + inputs=args.input, + input_dataset_dir=args.input_dir, + output_dataset_dir=args.npz_dir, + force_override=True, + warning=False, + ) + files = [f[1] for f in files] + if len(files) > 1 and args.output is not None: + print("\033[92mwarning: output is specified, but multiple files are detected. 
Output will be written.\033[0m") + datapath = Datapath(files=files, cls=args.cls) + else: + datapath = None + + data_config = load('data', os.path.join('configs/data', task.components.data)) + transform_config = load('transform', os.path.join('configs/transform', task.components.transform)) + + # get tokenizer + tokenizer_config = task.components.get('tokenizer', None) + if tokenizer_config is not None: + tokenizer_config = load('tokenizer', os.path.join('configs/tokenizer', task.components.tokenizer)) + tokenizer_config = TokenizerConfig.parse(config=tokenizer_config) + + # get data name + data_name = task.components.get('data_name', 'raw_data.npz') + if args.data_name is not None: + data_name = args.data_name + + # get predict dataset + predict_dataset_config = data_config.get('predict_dataset_config', None) + if predict_dataset_config is not None: + predict_dataset_config = DatasetConfig.parse(config=predict_dataset_config).split_by_cls() + + # get predict transform + predict_transform_config = transform_config.get('predict_transform_config', None) + if predict_transform_config is not None: + predict_transform_config = TransformConfig.parse(config=predict_transform_config) + + # get model + model_config = task.components.get('model', None) + if model_config is not None: + model_config = load('model', os.path.join('configs/model', model_config)) + if tokenizer_config is not None: + tokenizer = get_tokenizer(config=tokenizer_config) + else: + tokenizer = None + model = get_model(tokenizer=tokenizer, **model_config) + else: + model = None + + # set data + data = UniRigDatasetModule( + process_fn=None if model is None else model._process_fn, + predict_dataset_config=predict_dataset_config, + predict_transform_config=predict_transform_config, + tokenizer_config=tokenizer_config, + debug=False, + data_name=data_name, + datapath=datapath, + cls=args.cls, + ) + + # add call backs + callbacks = [] + + ## get checkpoint callback + checkpoint_config = task.get('checkpoint', None) + if checkpoint_config is not None: + checkpoint_config['dirpath'] = os.path.join('experiments', task.experiment_name) + callbacks.append(ModelCheckpoint(**checkpoint_config)) + + ## get writer callback + writer_config = task.get('writer', None) + if writer_config is not None: + assert predict_transform_config is not None, 'missing predict_transform_config in transform' + if args.output_dir is not None or args.output is not None: + if args.output is not None: + assert args.output.endswith('.fbx'), 'output must be .fbx' + writer_config['npz_dir'] = args.npz_dir + writer_config['output_dir'] = args.output_dir + writer_config['output_name'] = args.output + writer_config['user_mode'] = True + callbacks.append(get_writer(**writer_config, order_config=predict_transform_config.order_config)) + + # get trainer + trainer_config = task.get('trainer', {}) + + # get system + system_config = task.components.get('system', None) + if system_config is not None: + system_config = load('system', os.path.join('configs/system', system_config)) + system = get_system( + **system_config, + model=model, + steps_per_epoch=1, + ) + else: + system = None + + logger = None + + # set ckpt path + resume_from_checkpoint = task.get('resume_from_checkpoint', None) + resume_from_checkpoint = download(resume_from_checkpoint) + trainer = L.Trainer( + callbacks=callbacks, + logger=logger, + **trainer_config, + ) + + if mode == 'predict': + assert resume_from_checkpoint is not None, 'expect resume_from_checkpoint in task' + trainer.predict(system, 
datamodule=data, ckpt_path=resume_from_checkpoint, return_predictions=False) + else: + assert 0 \ No newline at end of file diff --git a/UniRig/src/data/__init__.py b/UniRig/src/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/UniRig/src/data/asset.py b/UniRig/src/data/asset.py new file mode 100644 index 0000000000000000000000000000000000000000..24f56e470d6f7b58e6d42f4a55dd0b93a5cdf63f --- /dev/null +++ b/UniRig/src/data/asset.py @@ -0,0 +1,433 @@ +from collections import defaultdict +from dataclasses import dataclass +import numpy as np +from numpy import ndarray + +from typing import Dict, Union, List, Tuple + +from .order import Order +from .raw_data import RawData +from .exporter import Exporter + +from ..tokenizer.spec import TokenizeInput +from .utils import linear_blend_skinning + +import trimesh + + +@dataclass +class Asset(Exporter): + ''' + Dataclass to handle data parsed from raw data. + ''' + + # data class + cls: str + + # where is this asset from + path: str + + # data file name + data_name: str + + # vertices of the mesh, shape (N, 3), float32 + vertices: ndarray + + # normals of vertices, shape (N, 3), float32 + vertex_normals: ndarray + + # faces of mesh, shape (F, 3), face id starts from 0 to F-1, int64 + faces: ndarray + + # face normal of mesh, shape (F, 3), float32 + face_normals: ndarray + + # joints of bones, shape (J, 3), float32 + joints: Union[ndarray, None]=None + + # tails of joints, shape (J, 3), float32 + tails: Union[ndarray, None]=None + + # skinning of joints, shape (N, J), float32 + skin: Union[ndarray, None]=None + + # whether the joint has skin, bool + no_skin: Union[ndarray, None]=None + + # vertex groups + vertex_groups: Union[Dict[str, ndarray], None]=None + + # parents of joints, None represents no parent(a root joint) + # make sure parent[k] < k + parents: Union[List[Union[int, None]], None]=None + + # names of joints + names: Union[List[str], None]=None + + # sampled vertices, shape (N, 3) + sampled_vertices: Union[ndarray, None]=None + + # sampled normals, shape (N, 3) + sampled_normals: Union[ndarray, None]=None + + # sampled vertex groups, every vertex group should be (N, J) + sampled_vertex_groups: Union[Dict[str, ndarray], None]=None + + # {id: part}, part==None -> a spring token + parts_bias: Union[Dict[int, Union[str, None]], None]=None + + # local coordinate, shape (J, 4, 4) + matrix_local: Union[ndarray, None]=None + + # pose matrix for skinning loss calculation, shape (J, 4, 4) + pose_matrix: Union[ndarray, None]=None + + meta: Union[Dict[str, ...], None]=None + + @property + def N(self): + ''' + number of vertices + ''' + return self.vertices.shape[0] + + @property + def F(self): + ''' + number of faces + ''' + return self.faces.shape[0] + + @property + def J(self): + ''' + number of joints + ''' + return self.joints.shape[0] + + def get_matrix(self, matrix_basis: ndarray, matrix_local: Union[ndarray, None]=None): + ''' + get matrix + + matrix_basis: (J, 4, 4) + ''' + if matrix_local is None: + assert self.joints is not None + matrix_local = self.matrix_local + if matrix_local is None: + matrix_local = np.zeros((self.J, 4, 4)) + matrix_local[:, 0, 0] = 1. + matrix_local[:, 1, 1] = 1. + matrix_local[:, 2, 2] = 1. + matrix_local[:, 3, 3] = 1. 
+ for i in range(self.J): + matrix_local[i, :3, 3] = self.joints[i] + + matrix = np.zeros((self.J, 4, 4)) + for i in range(self.J): + if i==0: + matrix[i] = matrix_local[i] @ matrix_basis[i] + else: + pid = self.parents[i] + matrix_parent = matrix[pid] + matrix_local_parent = matrix_local[pid] + + matrix[i] = ( + matrix_parent @ + (np.linalg.inv(matrix_local_parent) @ matrix_local[i]) @ + matrix_basis[i] + ) + return matrix + + def apply_matrix_basis(self, matrix_basis: ndarray): + ''' + apply a pose to armature + + matrix_basis: (J, 4, 4) + ''' + matrix_local = self.matrix_local + if matrix_local is None: + matrix_local = np.zeros((self.J, 4, 4)) + matrix_local[:, 0, 0] = 1. + matrix_local[:, 1, 1] = 1. + matrix_local[:, 2, 2] = 1. + matrix_local[:, 3, 3] = 1. + for i in range(self.J): + matrix_local[i, :3, 3] = self.joints[i].copy() + + matrix = self.get_matrix(matrix_basis=matrix_basis, matrix_local=matrix_local) + self.joints = matrix[:, :3, 3].copy() + vertices = linear_blend_skinning(self.vertices, matrix_local, matrix, self.skin, pad=1, value=1.) + # update matrix_local + self.matrix_local = matrix.copy() + + # change tails + if self.tails is not None: + t_skin = np.eye(self.J) + self.tails = linear_blend_skinning(self.tails, matrix_local, matrix, t_skin, pad=1, value=1.) + # in accordance with trimesh's normals + mesh = trimesh.Trimesh(vertices=vertices, faces=self.faces, process=False) + self.vertices = vertices + self.vertex_normals = mesh.vertex_normals.copy() + self.face_normals = mesh.face_normals.copy() + + def set_order_by_names(self, new_names: List[str]): + assert len(new_names) == len(self.names) + name_to_id = {name: id for (id, name) in enumerate(self.names)} + new_name_to_id = {name: id for (id, name) in enumerate(new_names)} + perm = [] + new_parents = [] + for (new_id, name) in enumerate(new_names): + perm.append(name_to_id[name]) + pid = self.parents[name_to_id[name]] + if new_id == 0: + assert pid is None, 'first bone is not root bone' + else: + pname = self.names[pid] + pid = new_name_to_id[pname] + assert pid < new_id, 'new order does not form a tree' + new_parents.append(pid) + + if self.joints is not None: + self.joints = self.joints[perm] + self.parents = new_parents + if self.tails is not None: + self.tails = self.tails[perm] + if self.skin is not None: + self.skin = self.skin[:, perm] + if self.no_skin is not None: + self.no_skin = self.no_skin[perm] + if self.matrix_local is not None: + self.matrix_local = self.matrix_local[perm] + self.names = new_names + + def set_order(self, order: Order): + if self.names is None or self.parents is None: + return + new_names, self.parts_bias = order.arrange_names(cls=self.cls, names=self.names, parents=self.parents) + self.set_order_by_names(new_names=new_names) + + def collapse(self, keep: List[str]): + dsu = [i for i in range(self.J)] + + def find(x: int) -> int: + if dsu[x] == x: + return x + y = find(dsu[x]) + dsu[x] = y + return y + + def merge(x: int, y: int): + dsu[find(x)] = find(y) + + if self.tails is not None: + new_tails = self.tails.copy() + else: + new_tails = None + if self.skin is not None: + new_skin = self.skin.copy() + else: + new_skin = None + + if self.no_skin is not None: + new_no_skin = self.no_skin.copy() + else: + new_no_skin = None + + if self.matrix_local is not None: + matrix_local = self.matrix_local.copy() + else: + matrix_local = None + new_names = [] + new_parents = [] + perm = [] + new_name_to_id = {} + tot = 0 + for (i, name) in enumerate(self.names): + if name in keep: + 
new_names.append(name) + new_name_to_id[name] = tot + tot += 1 + perm.append(i) + pid = self.parents[i] + if pid is None: + new_parents.append(None) + else: + pid = find(pid) + new_parents.append(new_name_to_id[self.names[pid]]) + continue + assert i != 0, 'cannot remove root' + id = find(i) + pid = find(self.parents[id]) + # be careful ! + # do not copy tail here because you dont know which child to inherit from + if new_skin is not None: + new_skin[:, pid] += new_skin[:, id] + if new_no_skin is not None: + new_no_skin[pid] &= new_no_skin[id] + merge(id, pid) + + if new_tails is not None: + new_tails = new_tails[perm] + if new_skin is not None: + new_skin = new_skin[:, perm] + if new_no_skin is not None: + new_no_skin = new_no_skin[perm] + if matrix_local is not None: + matrix_local = matrix_local[perm] + + if self.joints is not None: + self.joints = self.joints[perm] + self.parents = new_parents + self.tails = new_tails + self.skin = new_skin + self.no_skin = new_no_skin + self.names = new_names + self.matrix_local = matrix_local + + @staticmethod + def from_raw_data( + raw_data: RawData, + cls: str, + path: str, + data_name: str, + ) -> 'Asset': + ''' + Return an asset initialized from raw data and do transform. + ''' + return Asset( + cls=cls, + path=path, + data_name=data_name, + vertices=raw_data.vertices, + vertex_normals=raw_data.vertex_normals, + faces=raw_data.faces, + face_normals=raw_data.face_normals, + joints=raw_data.joints, + tails=raw_data.tails, + skin=raw_data.skin, + no_skin=raw_data.no_skin, + parents=raw_data.parents, + names=raw_data.names, + matrix_local=raw_data.matrix_local, + meta={}, + ) + + def get_tokenize_input(self) -> TokenizeInput: + children = defaultdict(list) + + for (id, p) in enumerate(self.parents): + if p is not None: + children[p].append(id) + bones = [] + branch = [] + is_leaf = [] + last = None + for i in range(self.J): + is_leaf.append(len(children[i])==0) + if i == 0: + bones.append(np.concatenate([self.joints[i], self.joints[i]])) + branch.append(False) + else: + pid = self.parents[i] + bones.append(np.concatenate([self.joints[pid], self.joints[i]])) + branch.append(pid!=last) + last = i + bones = np.stack(bones) + branch = np.array(branch, dtype=bool) + is_leaf = np.array(is_leaf, dtype=bool) + return TokenizeInput( + bones=bones, + tails=self.tails, + branch=branch, + is_leaf=is_leaf, + no_skin=self.no_skin, + cls=self.cls, + parts_bias=self.parts_bias, + ) + + def export_pc(self, path: str, with_normal: bool=True, normal_size=0.01): + ''' + export point cloud + ''' + vertices = self.vertices + normals = self.vertex_normals + if self.sampled_vertices is not None: + vertices = self.sampled_vertices + normals = self.sampled_normals + if with_normal == False: + normals = None + self._export_pc(vertices=vertices, path=path, vertex_normals=normals, normal_size=normal_size) + + def export_mesh(self, path: str): + ''' + export mesh + ''' + self._export_mesh(vertices=self.vertices, faces=self.faces, path=path) + + def export_skeleton(self, path: str): + ''' + export spring + ''' + self._export_skeleton(joints=self.joints, parents=self.parents, path=path) + + def export_skeleton_sequence(self, path: str): + ''' + export spring + ''' + self._export_skeleton_sequence(joints=self.joints, parents=self.parents, path=path) + + def export_fbx( + self, + path: str, + vertex_group_name: str, + extrude_size: float=0.03, + group_per_vertex: int=-1, + add_root: bool=False, + do_not_normalize: bool=False, + use_extrude_bone: bool=True, + use_connect_unique_child: 
bool=True, + extrude_from_parent: bool=True, + use_tail: bool=False, + use_origin: bool=False, + ): + ''' + export the whole model with skining + ''' + self._export_fbx( + path=path, + vertices=self.vertices if use_origin else self.sampled_vertices, + joints=self.joints, + skin=self.sampled_vertex_groups[vertex_group_name], + parents=self.parents, + names=self.names, + faces=self.faces if use_origin else None, + extrude_size=extrude_size, + group_per_vertex=group_per_vertex, + add_root=add_root, + do_not_normalize=do_not_normalize, + use_extrude_bone=use_extrude_bone, + use_connect_unique_child=use_connect_unique_child, + extrude_from_parent=extrude_from_parent, + tails=self.tails if use_tail else None, + ) + + def export_render(self, path: str, resolution: Tuple[int, int]=[256, 256], use_tail: bool=False): + if use_tail: + assert self.tails is not None + self._export_render( + path=path, + vertices=self.vertices, + faces=self.faces, + bones=np.concatenate([self.joints, self.tails], axis=-1), + resolution=resolution, + ) + else: + pjoints = self.joints[self.parents[1:]] + self._export_render( + path=path, + vertices=self.vertices, + faces=self.faces, + bones=np.concatenate([pjoints, self.joints[1:]], axis=-1), + resolution=resolution, + ) \ No newline at end of file diff --git a/UniRig/src/data/augment.py b/UniRig/src/data/augment.py new file mode 100644 index 0000000000000000000000000000000000000000..80e2e9810384b898ae0762d506974d9c4be0eced --- /dev/null +++ b/UniRig/src/data/augment.py @@ -0,0 +1,152 @@ +from dataclasses import dataclass +from typing import Tuple, Union, List, Dict +from numpy import ndarray +import numpy as np +from abc import ABC, abstractmethod +from scipy.spatial.transform import Rotation as R + +from .spec import ConfigSpec +from .asset import Asset +from .utils import axis_angle_to_matrix + +@dataclass(frozen=True) +class AugmentAffineConfig(ConfigSpec): + # final normalization cube + normalize_into: Tuple[float, float] + + # randomly scale coordinates with probability p + random_scale_p: float + + # scale range (lower, upper) + random_scale: Tuple[float, float] + + # randomly shift coordinates with probability p + random_shift_p: float + + # shift range (lower, upper) + random_shift: Tuple[float, float] + + @classmethod + def parse(cls, config) -> Union['AugmentAffineConfig', None]: + if config is None: + return None + cls.check_keys(config) + return AugmentAffineConfig( + normalize_into=config.normalize_into, + random_scale_p=config.get('random_scale_p', 0.), + random_scale=config.get('random_scale', [1., 1.]), + random_shift_p=config.get('random_shift_p', 0.), + random_shift=config.get('random_shift', [0., 0.]), + ) + +@dataclass(frozen=True) +class AugmentConfig(ConfigSpec): + ''' + Config to handle final easy augmentation of vertices, normals and bones before sampling. 
+ ''' + augment_affine_config: Union[AugmentAffineConfig, None] + + @classmethod + def parse(cls, config) -> 'AugmentConfig': + cls.check_keys(config) + return AugmentConfig( + augment_affine_config=AugmentAffineConfig.parse(config.get('augment_affine_config', None)), + ) + +class Augment(ABC): + ''' + Abstract class for augmentation + ''' + def __init__(self): + pass + + @abstractmethod + def transform(self, asset: Asset, **kwargs): + pass + + @abstractmethod + def inverse(self, asset: Asset): + pass + +class AugmentAffine(Augment): + + def __init__(self, config: AugmentAffineConfig): + super().__init__() + self.config = config + + def _apply(self, v: ndarray, trans: ndarray) -> ndarray: + return np.matmul(v, trans[:3, :3].transpose()) + trans[:3, 3] + + def transform(self, asset: Asset, **kwargs): + bound_min = asset.vertices.min(axis=0) + bound_max = asset.vertices.max(axis=0) + if asset.joints is not None: + joints_bound_min = asset.joints.min(axis=0) + joints_bound_max = asset.joints.max(axis=0) + bound_min = np.minimum(bound_min, joints_bound_min) + bound_max = np.maximum(bound_max, joints_bound_max) + + trans_vertex = np.eye(4, dtype=np.float32) + + trans_vertex = _trans_to_m(-(bound_max + bound_min)/2) @ trans_vertex + + # scale into the cube + normalize_into = self.config.normalize_into + scale = np.max((bound_max - bound_min) / (normalize_into[1] - normalize_into[0])) + trans_vertex = _scale_to_m(1. / scale) @ trans_vertex + + bias = (normalize_into[0] + normalize_into[1]) / 2 + trans_vertex = _trans_to_m(np.array([bias, bias, bias], dtype=np.float32)) @ trans_vertex + + if np.random.rand() < self.config.random_scale_p: + scale = _scale_to_m(np.random.uniform(self.config.random_scale[0], self.config.random_scale[1])) + trans_vertex = scale @ trans_vertex + + if np.random.rand() < self.config.random_shift_p: + l, r = self.config.random_shift + shift = _trans_to_m(np.array([np.random.uniform(l, r), np.random.uniform(l, r), np.random.uniform(l, r)]), dtype=np.float32) + trans_vertex = shift @ trans_vertex + + asset.vertices = self._apply(asset.vertices, trans_vertex) + # do not affect scale in matrix + if asset.matrix_local is not None: + asset.matrix_local[:, :, 3:4] = trans_vertex @ asset.matrix_local[:, :, 3:4] + if asset.pose_matrix is not None: + asset.pose_matrix[:, :, 3:4] = trans_vertex @ asset.pose_matrix[:, :, 3:4] + # do not affect normal here + if asset.joints is not None: + asset.joints = self._apply(asset.joints, trans_vertex) + if asset.tails is not None: + asset.tails = self._apply(asset.tails, trans_vertex) + + self.trans_vertex = trans_vertex + + def inverse(self, asset: Asset): + m = np.linalg.inv(self.trans_vertex) + asset.vertices = self._apply(asset.vertices, m) + if asset.joints is not None: + asset.joints = self._apply(asset.joints, m) + if asset.tails is not None: + asset.tails = self._apply(asset.tails, m) + +def _trans_to_m(v: ndarray): + m = np.eye(4, dtype=np.float32) + m[0:3, 3] = v + return m + +def _scale_to_m(r: ndarray): + m = np.zeros((4, 4), dtype=np.float32) + m[0, 0] = r + m[1, 1] = r + m[2, 2] = r + m[3, 3] = 1. 
+ return m + +def get_augments(config: AugmentConfig) -> Tuple[List[Augment], List[Augment]]: + first_augments = [] # augments before sample + second_augments = [] # augments after sample + augment_affine_config = config.augment_affine_config + + if augment_affine_config is not None: + second_augments.append(AugmentAffine(config=augment_affine_config)) + return first_augments, second_augments \ No newline at end of file diff --git a/UniRig/src/data/datapath.py b/UniRig/src/data/datapath.py new file mode 100644 index 0000000000000000000000000000000000000000..18c9e6ebc1a23569a6ded4462b8b3d2d812b7a88 --- /dev/null +++ b/UniRig/src/data/datapath.py @@ -0,0 +1,149 @@ +from copy import deepcopy +from collections import defaultdict +from dataclasses import dataclass +from typing import Dict, Union, Tuple, List +import numpy as np +from numpy import ndarray +import os +from random import shuffle +from box import Box +from torch.onnx.symbolic_opset11 import index_copy + +from .spec import ConfigSpec + +@dataclass +class DatapathConfig(ConfigSpec): + ''' + Config to handle input data paths. + ''' + # root + input_dataset_dir: str + + # use proportion data sampling + use_prob: bool + + # cls: [(path_1, p_1), ...] + data_path: Dict[str, List[Tuple[str, float]]] + + # how many files to return when using data sampling + num_files: Union[int, None] + + @classmethod + def from_args(cls, **kwargs) -> 'DatapathConfig': + ''' + Make a temporary datapath from user inputs. + ''' + input = kwargs.get('input', None) + output = kwargs.get('output', None) + recursive = kwargs.get('recursive', False) + + + @classmethod + def parse(cls, config) -> 'DatapathConfig': + cls.check_keys(config) + return DatapathConfig( + input_dataset_dir=config.input_dataset_dir, + use_prob=config.get('use_prob', True), + data_path=config.data_path, + num_files=config.get('num_files', None), + ) + + def split_by_cls(self) -> Dict[str, 'DatapathConfig']: + res: Dict[str, DatapathConfig] = {} + for cls in self.data_path: + res[cls] = deepcopy(self) + res[cls].data_path = {cls: self.data_path[cls]} + return res + +class Datapath(): + def __init__( + self, + config: Union[DatapathConfig, None]=None, + files: Union[List[str], None]=None, + cls: Union[str, None]=None, + ): + if config is not None: + self.config = config + self.file_list = [] + cls_probs_first = [] + cls_first = [] + + self.files_by_class: Dict[str, List[Dict]] = defaultdict(list) + self.class_positions: Dict[str, List[int]] = defaultdict(list) + self.cls_probs_second: Dict[str, ndarray] = defaultdict(List) + + for cls in self.config.data_path: + prob = 0. + probs_second = [] + for (path, p) in self.config.data_path[cls]: + prob += p + probs_second.append(p) + with open(path, 'r') as f: + file_items = [] + missing = 0 + for l in f.readlines(): + raw_data_path = os.path.join(self.config.input_dataset_dir, l.strip(), 'raw_data.npz') + if not os.path.exists(raw_data_path): + missing += 1 + continue + file_items.append({ + 'cls': cls, + 'path': os.path.join(self.config.input_dataset_dir, l.strip()), + 'prob': p + }) + assert len(file_items) > 0, f"files in {path} are all missing! 
root: {self.config.input_dataset_dir}" + if missing > 0: + print(f"\033[31m{cls}: {missing} missing files\033[0m") + self.files_by_class[cls].append(file_items) + self.class_positions[cls].append(0) + self.file_list.extend(file_items) + probs_second = np.array(probs_second) + self.cls_probs_second[cls] = probs_second / probs_second.sum() + cls_first.append(cls) + cls_probs_first.append(prob) + cls_probs_first = np.array(cls_probs_first) + self.cls_first: List[str] = cls_first + self.cls_probs_first: Dict[str, List[float]] = cls_probs_first / cls_probs_first.sum() + elif files is not None: + if cls is None: + cls = 'inference' + self.file_list = [{'cls': cls, 'path': file} for file in files] + cls_probs_first = np.array([1.]) + cls_first = [] + + self.files_by_class: Dict[str, List[Dict]] = {cls: self.file_list.copy()} + self.class_positions: Dict[str, List[int]] = {cls: [0]} + self.cls_probs_second: Dict[str, ndarray] = {cls: np.array([1.])} + self.config = Box({'use_prob': False}) + else: + assert(0) + + def __len__(self): + if self.config.use_prob: + assert self.config.num_files is not None, 'num_files is not specified' + return self.config.num_files + return len(self.file_list) + + def __getitem__(self, index) -> Tuple[str, str]: + if self.config.use_prob: + # first sample a class + cls = np.random.choice(self.cls_first, p=self.cls_probs_first) + + # second sample in this class + idx = np.random.choice(len(self.files_by_class[cls]), p=self.cls_probs_second[cls]) + + # get the current position + pos = self.class_positions[cls][idx] + files = self.files_by_class[cls][idx] + + # get the item andd update position + item = files[pos] + self.class_positions[cls][idx] = (pos + 1) % len(files) + if (pos + 1) % len(files) == 0: + shuffle(self.files_by_class[cls][idx]) + else: + item = self.file_list[index] + return (item['cls'], item['path']) + + def get_data(self) -> List[Tuple[str, str]]: + return [self[i] for i in range(len(self))] \ No newline at end of file diff --git a/UniRig/src/data/dataset.py b/UniRig/src/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6dce4c7a8b8b8239300a4b757f291df8bd0dd0fa --- /dev/null +++ b/UniRig/src/data/dataset.py @@ -0,0 +1,231 @@ +from copy import deepcopy +from dataclasses import dataclass +import lightning.pytorch as pl +# from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +import torch +from torch import LongTensor +from torch.utils import data +from torch.utils.data import DataLoader, Dataset +from typing import Dict, List, Tuple, Union, Callable +import os +import numpy as np + +from .raw_data import RawData +from .asset import Asset +from .transform import TransformConfig, transform_asset +from .datapath import DatapathConfig, Datapath +from .spec import ConfigSpec + +from ..tokenizer.spec import TokenizerSpec, TokenizerConfig +from ..tokenizer.parse import get_tokenizer +from ..model.spec import ModelInput + +@dataclass +class DatasetConfig(ConfigSpec): + ''' + Config to handle dataset format. 
+ ''' + # shuffle dataset + shuffle: bool + + # batch size + batch_size: int + + # number of workers + num_workers: int + + # datapath + datapath_config: DatapathConfig + + # use pin memory + pin_memory: bool = True + + # use persistent workers + persistent_workers: bool = True + + @classmethod + def parse(cls, config) -> 'DatapathConfig': + cls.check_keys(config) + return DatasetConfig( + shuffle=config.shuffle, + batch_size=config.batch_size, + num_workers=config.num_workers, + pin_memory=config.pin_memory, + persistent_workers=config.persistent_workers, + datapath_config=DatapathConfig.parse(config.datapath_config), + ) + + def split_by_cls(self) -> Dict[str, 'DatasetConfig']: + res: Dict[str, DatasetConfig] = {} + datapath_config_dict = self.datapath_config.split_by_cls() + for cls in self.datapath_config.data_path: + res[cls] = deepcopy(self) + res[cls].datapath_config = datapath_config_dict[cls] + return res + +class UniRigDatasetModule(pl.LightningDataModule): + def __init__( + self, + process_fn: Union[Callable[[List[ModelInput]], Dict]]=None, + predict_dataset_config: Union[Dict[str, DatasetConfig], None]=None, + predict_transform_config: Union[TransformConfig, None]=None, + tokenizer_config: Union[TokenizerConfig, None]=None, + debug: bool=False, + data_name: str='raw_data.npz', + datapath: Union[Datapath, None]=None, + cls: Union[str, None]=None, + ): + super().__init__() + self.process_fn = process_fn + self.predict_dataset_config = predict_dataset_config + self.predict_transform_config = predict_transform_config + self.tokenizer_config = tokenizer_config + self.debug = debug + self.data_name = data_name + + if debug: + print("\033[31mWARNING: debug mode, dataloader will be extremely slow !!!\033[0m") + + if datapath is not None: + self.train_datapath = None + self.validate_datapath = None + self.predict_datapath = { + cls: deepcopy(datapath), + } + self.predict_dataset_config = { + cls: DatasetConfig( + shuffle=False, + batch_size=1, + num_workers=0, + datapath_config=deepcopy(datapath), + pin_memory=False, + persistent_workers=False, + ) + } + else: + # build predict datapath + if self.predict_dataset_config is not None: + self.predict_datapath = { + cls: Datapath(self.predict_dataset_config[cls].datapath_config) + for cls in self.predict_dataset_config + } + else: + self.predict_datapath = None + + # get tokenizer + if tokenizer_config is None: + self.tokenizer = None + else: + self.tokenizer = get_tokenizer(config=tokenizer_config) + + def prepare_data(self): + pass + + def setup(self, stage=None): + if self.predict_datapath is not None: + self._predict_ds = {} + for cls in self.predict_datapath: + self._predict_ds[cls] = UniRigDataset( + process_fn=self.process_fn, + data=self.predict_datapath[cls].get_data(), + name=f"predict-{cls}", + tokenizer=self.tokenizer, + transform_config=self.predict_transform_config, + debug=self.debug, + data_name=self.data_name, + ) + + def predict_dataloader(self): + if not hasattr(self, "_predict_ds"): + self.setup() + return self._create_dataloader( + dataset=self._predict_ds, + config=self.predict_dataset_config, + is_train=False, + drop_last=False, + ) + + def _create_dataloader( + self, + dataset: Union[Dataset, Dict[str, Dataset]], + config: DatasetConfig, + is_train: bool, + **kwargs, + ) -> Union[DataLoader, Dict[str, DataLoader]]: + def create_single_dataloader(dataset, config: Union[DatasetConfig, Dict[str, DatasetConfig]], **kwargs): + return DataLoader( + dataset, + batch_size=config.batch_size, + shuffle=config.shuffle, + 
num_workers=config.num_workers, + pin_memory=config.pin_memory, + persistent_workers=config.persistent_workers, + collate_fn=dataset.collate_fn, + **kwargs, + ) + if isinstance(dataset, Dict): + return {k: create_single_dataloader(v, config[k], **kwargs) for k, v in dataset.items()} + else: + return create_single_dataloader(dataset, config, **kwargs) + +class UniRigDataset(Dataset): + def __init__( + self, + data: List[Tuple[str, str]], # (cls, part) + name: str, + process_fn: Union[Callable[[List[ModelInput]], Dict]]=None, + tokenizer: Union[TokenizerSpec, None]=None, + transform_config: Union[TransformConfig, None]=None, + debug: bool=False, + data_name: str='raw_data.npz', + ) -> None: + super().__init__() + + self.data = data + self.name = name + self.process_fn = process_fn + self.tokenizer = tokenizer + self.transform_config = transform_config + self.debug = debug + self.data_name = data_name + + if not debug: + assert self.process_fn is not None, 'missing data processing function' + + def __len__(self) -> int: + return len(self.data) + + def __getitem__(self, idx) -> ModelInput: + cls, dir_path = self.data[idx] + raw_data = RawData.load(path=os.path.join(dir_path, self.data_name)) + asset = Asset.from_raw_data(raw_data=raw_data, cls=cls, path=dir_path, data_name=self.data_name) + + first_augments, second_augments = transform_asset( + asset=asset, + transform_config=self.transform_config, + ) + if self.tokenizer is not None and asset.parents is not None: + tokens = self.tokenizer.tokenize(input=asset.get_tokenize_input()) + else: + tokens = None + return ModelInput( + tokens=tokens, + pad=None if self.tokenizer is None else self.tokenizer.pad, + vertices=asset.sampled_vertices.astype(np.float32), + normals=asset.sampled_normals.astype(np.float32), + joints=None if asset.joints is None else asset.joints.astype(np.float32), + tails=None if asset.tails is None else asset.tails.astype(np.float32), + asset=asset, + augments=None, + ) + + def _collate_fn_debug(self, batch): + return batch + + def _collate_fn(self, batch): + return data.dataloader.default_collate(self.process_fn(batch)) + + def collate_fn(self, batch): + if self.debug: + return self._collate_fn_debug(batch) + return self._collate_fn(batch) \ No newline at end of file diff --git a/UniRig/src/data/exporter.py b/UniRig/src/data/exporter.py new file mode 100644 index 0000000000000000000000000000000000000000..c84ff5f2ce1e76fcceab813552a523b41f6207a6 --- /dev/null +++ b/UniRig/src/data/exporter.py @@ -0,0 +1,486 @@ +import numpy as np +from numpy import ndarray +from typing import List, Union, Tuple +from collections import defaultdict +import os + +try: + import open3d as o3d + OPEN3D_EQUIPPED = True +except: + print("do not have open3d") + OPEN3D_EQUIPPED = False + +class Exporter(): + + def _safe_make_dir(self, path): + if os.path.dirname(path) == '': + return + os.makedirs(os.path.dirname(path), exist_ok=True) + + def _export_skeleton(self, joints: ndarray, parents: List[Union[int, None]], path: str): + format = path.split('.')[-1] + assert format in ['obj'] + name = path.removesuffix('.obj') + path = name + ".obj" + self._safe_make_dir(path) + J = joints.shape[0] + with open(path, 'w') as file: + file.write("o spring_joint\n") + _joints = [] + for id in range(J): + pid = parents[id] + if pid is None or pid == -1: + continue + bx, by, bz = joints[id] + ex, ey, ez = joints[pid] + _joints.extend([ + f"v {bx} {bz} {-by}\n", + f"v {ex} {ez} {-ey}\n", + f"v {ex} {ez} {-ey + 0.00001}\n" + ]) + file.writelines(_joints) + + _faces = 
[f"f {id*3+1} {id*3+2} {id*3+3}\n" for id in range(J)] + file.writelines(_faces) + + def _export_bones(self, bones: ndarray, path: str): + format = path.split('.')[-1] + assert format in ['obj'] + name = path.removesuffix('.obj') + path = name + ".obj" + self._safe_make_dir(path) + J = bones.shape[0] + with open(path, 'w') as file: + file.write("o bones\n") + _joints = [] + for bone in bones: + bx, by, bz = bone[:3] + ex, ey, ez = bone[3:] + _joints.extend([ + f"v {bx} {bz} {-by}\n", + f"v {ex} {ez} {-ey}\n", + f"v {ex} {ez} {-ey + 0.00001}\n" + ]) + file.writelines(_joints) + + _faces = [f"f {id*3+1} {id*3+2} {id*3+3}\n" for id in range(J)] + file.writelines(_faces) + + def _export_skeleton_sequence(self, joints: ndarray, parents: List[Union[int, None]], path: str): + format = path.split('.')[-1] + assert format in ['obj'] + name = path.removesuffix('.obj') + path = name + ".obj" + self._safe_make_dir(path) + J = joints.shape[0] + for i in range(J): + file = open(name + f"_{i}.obj", 'w') + file.write("o spring_joint\n") + _joints = [] + for id in range(i + 1): + pid = parents[id] + if pid is None: + continue + bx, by, bz = joints[id] + ex, ey, ez = joints[pid] + _joints.extend([ + f"v {bx} {bz} {-by}\n", + f"v {ex} {ez} {-ey}\n", + f"v {ex} {ez} {-ey + 0.00001}\n" + ]) + file.writelines(_joints) + + _faces = [f"f {id*3+1} {id*3+2} {id*3+3}\n" for id in range(J)] + file.writelines(_faces) + file.close() + + def _export_mesh(self, vertices: ndarray, faces: ndarray, path: str): + format = path.split('.')[-1] + assert format in ['obj', 'ply'] + if path.endswith('ply'): + if not OPEN3D_EQUIPPED: + raise RuntimeError("open3d is not available") + mesh = o3d.geometry.TriangleMesh() + mesh.vertices = o3d.utility.Vector3dVector(vertices) + mesh.triangles = o3d.utility.Vector3iVector(faces) + self._safe_make_dir(path) + o3d.io.write_triangle_mesh(path, mesh) + return + name = path.removesuffix('.obj') + path = name + ".obj" + self._safe_make_dir(path) + with open(path, 'w') as file: + file.write("o mesh\n") + _vertices = [] + for co in vertices: + _vertices.append(f"v {co[0]} {co[2]} {-co[1]}\n") + file.writelines(_vertices) + _faces = [] + for face in faces: + _faces.append(f"f {face[0]+1} {face[1]+1} {face[2]+1}\n") + file.writelines(_faces) + + def _export_pc(self, vertices: ndarray, path: str, vertex_normals: Union[ndarray, None]=None, normal_size: float=0.01): + if path.endswith('.ply'): + if vertex_normals is not None: + print("normal result will not be displayed in .ply format") + name = path.removesuffix('.ply') + path = name + ".ply" + pc = o3d.geometry.PointCloud() + pc.points = o3d.utility.Vector3dVector(vertices) + # segment fault when numpy >= 2.0 !! 
use torch environment + self._safe_make_dir(path) + o3d.io.write_point_cloud(path, pc) + return + name = path.removesuffix('.obj') + path = name + ".obj" + self._safe_make_dir(path) + with open(path, 'w') as file: + file.write("o pc\n") + _vertex = [] + for co in vertices: + _vertex.append(f"v {co[0]} {co[2]} {-co[1]}\n") + file.writelines(_vertex) + if vertex_normals is not None: + new_path = path.replace('.obj', '_normal.obj') + nfile = open(new_path, 'w') + nfile.write("o normal\n") + _normal = [] + for i in range(vertices.shape[0]): + co = vertices[i] + x = vertex_normals[i, 0] + y = vertex_normals[i, 1] + z = vertex_normals[i, 2] + _normal.extend([ + f"v {co[0]} {co[2]} {-co[1]}\n", + f"v {co[0]+0.0001} {co[2]} {-co[1]}\n", + f"v {co[0]+x*normal_size} {co[2]+z*normal_size} {-(co[1]+y*normal_size)}\n", + f"f {i*3+1} {i*3+2} {i*3+3}\n", + ]) + nfile.writelines(_normal) + + def _make_armature( + self, + vertices: Union[ndarray, None], + joints: ndarray, + skin: Union[ndarray, None], + parents: List[Union[int, None]], + names: List[str], + faces: Union[ndarray, None]=None, + extrude_size: float=0.03, + group_per_vertex: int=-1, + add_root: bool=False, + do_not_normalize: bool=False, + use_extrude_bone: bool=True, + use_connect_unique_child: bool=True, + extrude_from_parent: bool=True, + tails: Union[ndarray, None]=None, + ): + import bpy # type: ignore + from mathutils import Vector # type: ignore + + # make collection + collection = bpy.data.collections.new('new_collection') + bpy.context.scene.collection.children.link(collection) + + # make mesh + if vertices is not None: + mesh = bpy.data.meshes.new('mesh') + if faces is None: + faces = [] + mesh.from_pydata(vertices, [], faces) + mesh.update() + + # make object from mesh + object = bpy.data.objects.new('character', mesh) + + # add object to scene collection + collection.objects.link(object) + + # deselect mesh + bpy.ops.object.armature_add(enter_editmode=True) + armature = bpy.data.armatures.get('Armature') + edit_bones = armature.edit_bones + + J = joints.shape[0] + if tails is None: + tails = joints.copy() + tails[:, 2] += extrude_size + connects = [False for _ in range(J)] + children = defaultdict(list) + for i in range(1, J): + children[parents[i]].append(i) + if tails is not None: + if use_extrude_bone: + for i in range(J): + if len(children[i]) != 1 and extrude_from_parent and i != 0: + pjoint = joints[parents[i]] + joint = joints[i] + d = joint - pjoint + if np.linalg.norm(d) < 0.000001: + d = np.array([0., 0., 1.]) # in case son.head == parent.head + else: + d = d / np.linalg.norm(d) + tails[i] = joint + d * extrude_size + if use_connect_unique_child: + for i in range(J): + if len(children[i]) == 1: + child = children[i][0] + tails[i] = joints[child] + if parents[i] is not None and len(children[parents[i]]) == 1: + connects[i] = True + + if add_root: + bone_root = edit_bones.get('Bone') + bone_root.name = 'Root' + bone_root.tail = Vector((joints[0, 0], joints[0, 1], joints[0, 2])) + else: + bone_root = edit_bones.get('Bone') + bone_root.name = names[0] + bone_root.head = Vector((joints[0, 0], joints[0, 1], joints[0, 2])) + bone_root.tail = Vector((joints[0, 0], joints[0, 1], joints[0, 2] + extrude_size)) + + def extrude_bone( + edit_bones, + name: str, + parent_name: str, + head: Tuple[float, float, float], + tail: Tuple[float, float, float], + connect: bool + ): + bone = edit_bones.new(name) + bone.head = Vector((head[0], head[1], head[2])) + bone.tail = Vector((tail[0], tail[1], tail[2])) + bone.name = name + parent_bone = 
edit_bones.get(parent_name) + bone.parent = parent_bone + bone.use_connect = connect + assert not np.isnan(head).any(), f"nan found in head of bone {name}" + assert not np.isnan(tail).any(), f"nan found in tail of bone {name}" + + for i in range(J): + if add_root is False and i==0: + continue + edit_bones = armature.edit_bones + pname = 'Root' if parents[i] is None else names[parents[i]] + extrude_bone(edit_bones, names[i], pname, joints[i], tails[i], connects[i]) + for i in range(J): + bone = edit_bones.get(names[i]) + bone.head = Vector((joints[i, 0], joints[i, 1], joints[i, 2])) + bone.tail = Vector((tails[i, 0], tails[i, 1], tails[i, 2])) + + if vertices is None or skin is None: + return + # must set to object mode to enable parent_set + bpy.ops.object.mode_set(mode='OBJECT') + objects = bpy.data.objects + for o in bpy.context.selected_objects: + o.select_set(False) + ob = objects['character'] + arm = bpy.data.objects['Armature'] + ob.select_set(True) + arm.select_set(True) + bpy.ops.object.parent_set(type='ARMATURE_NAME') + vis = [] + for x in ob.vertex_groups: + vis.append(x.name) + #sparsify + argsorted = np.argsort(-skin, axis=1) + vertex_group_reweight = skin[np.arange(skin.shape[0])[..., None], argsorted] + if group_per_vertex == -1: + group_per_vertex = vertex_group_reweight.shape[-1] + if not do_not_normalize: + vertex_group_reweight = vertex_group_reweight / vertex_group_reweight[..., :group_per_vertex].sum(axis=1)[...,None] + + for v, w in enumerate(skin): + for ii in range(group_per_vertex): + i = argsorted[v, ii] + if i >= J: + continue + n = names[i] + if n not in vis: + continue + ob.vertex_groups[n].add([v], vertex_group_reweight[v, ii], 'REPLACE') + + def _clean_bpy(self): + import bpy # type: ignore + for c in bpy.data.actions: + bpy.data.actions.remove(c) + for c in bpy.data.armatures: + bpy.data.armatures.remove(c) + for c in bpy.data.cameras: + bpy.data.cameras.remove(c) + for c in bpy.data.collections: + bpy.data.collections.remove(c) + for c in bpy.data.images: + bpy.data.images.remove(c) + for c in bpy.data.materials: + bpy.data.materials.remove(c) + for c in bpy.data.meshes: + bpy.data.meshes.remove(c) + for c in bpy.data.objects: + bpy.data.objects.remove(c) + for c in bpy.data.textures: + bpy.data.textures.remove(c) + + def _export_fbx( + self, + path: str, + vertices: Union[ndarray, None], + joints: ndarray, + skin: Union[ndarray, None], + parents: List[Union[int, None]], + names: List[str], + faces: Union[ndarray, None]=None, + extrude_size: float=0.03, + group_per_vertex: int=-1, + add_root: bool=False, + do_not_normalize: bool=False, + use_extrude_bone: bool=True, + use_connect_unique_child: bool=True, + extrude_from_parent: bool=True, + tails: Union[ndarray, None]=None, + ): + ''' + Requires bpy installed + ''' + import bpy # type: ignore + self._safe_make_dir(path) + self._clean_bpy() + self._make_armature( + vertices=vertices, + joints=joints, + skin=skin, + parents=parents, + names=names, + faces=faces, + extrude_size=extrude_size, + group_per_vertex=group_per_vertex, + add_root=add_root, + do_not_normalize=do_not_normalize, + use_extrude_bone=use_extrude_bone, + use_connect_unique_child=use_connect_unique_child, + extrude_from_parent=extrude_from_parent, + tails=tails, + ) + + # always enable add_leaf_bones to keep leaf bones + bpy.ops.export_scene.fbx(filepath=path, check_existing=False, add_leaf_bones=False) + + def _export_render( + self, + path: str, + vertices: Union[ndarray, None], + faces: Union[ndarray, None], + bones: Union[ndarray, None], + 
resolution: Tuple[float, float]=[256, 256], + ): + import bpy # type: ignore + import bpy_extras # type: ignore + from mathutils import Vector # type: ignore + + self._safe_make_dir(path) + # normalize into [-1, 1]^3 + # copied from augment + assert (vertices is not None) or (bones is not None) + bounds = [] + if vertices is not None: + bounds.append(vertices) + if bones is not None: + bounds.append(bones[:, :3]) + bounds.append(bones[:, 3:]) + bounds = np.concatenate(bounds, axis=0) + bound_min = bounds.min(axis=0) + bound_max = bounds.max(axis=0) + + trans_vertex = np.eye(4) + + trans_vertex = _trans_to_m(-(bound_max + bound_min)/2) @ trans_vertex + + # scale into the cube [-1, 1] + scale = np.max((bound_max - bound_min) / 2) + trans_vertex = _scale_to_m(1. / scale) @ trans_vertex + + def _apply(v: ndarray, trans: ndarray) -> ndarray: + return np.matmul(v, trans[:3, :3].transpose()) + trans[:3, 3] + + if vertices is not None: + vertices = _apply(vertices, trans_vertex) + if bones is not None: + bones[:, :3] = _apply(bones[:, :3], trans_vertex) + bones[:, 3:] = _apply(bones[:, 3:], trans_vertex) + + # bpy api calls + self._clean_bpy() + bpy.context.scene.render.engine = 'BLENDER_WORKBENCH' + bpy.context.scene.render.film_transparent = True + bpy.context.scene.display.shading.background_type = 'VIEWPORT' + + collection = bpy.data.collections.new('new_collection') + bpy.context.scene.collection.children.link(collection) + + if vertices is not None: + mesh_data = bpy.data.meshes.new(name="MeshData") + mesh_obj = bpy.data.objects.new(name="MeshObject", object_data=mesh_data) + collection.objects.link(mesh_obj) + + mesh_data.from_pydata((vertices).tolist(), [], faces.tolist()) + mesh_data.update() + + def look_at(camera, point): + direction = point - camera.location + rot_quat = direction.to_track_quat('-Z', 'Y') + camera.rotation_euler = rot_quat.to_euler() + + bpy.ops.object.camera_add(location=(4, -4, 2.5)) + camera = bpy.context.object + camera.data.angle = np.radians(25.0) + look_at(camera, Vector((0, 0, -0.2))) + bpy.context.scene.camera = camera + + bpy.context.scene.render.resolution_x = resolution[0] + bpy.context.scene.render.resolution_y = resolution[1] + bpy.context.scene.render.image_settings.file_format = 'PNG' + bpy.context.scene.render.filepath = path + + bpy.ops.render.render(write_still=True) + # some AI generated code to draw bones over mesh + if bones is not None: + # TODO: do not save image after rendering + from PIL import Image, ImageDraw + img_pil = Image.open(path).convert("RGBA") + draw = ImageDraw.Draw(img_pil) + + from bpy_extras.image_utils import load_image # type: ignore + bpy.context.scene.use_nodes = True + nodes = bpy.context.scene.node_tree.nodes + # nodes.clear() + + img = load_image(path) + image_node = nodes.new(type='CompositorNodeImage') + image_node.image = img + + for i, bone in enumerate(bones): + head, tail = bone[:3], bone[3:] + head_2d = bpy_extras.object_utils.world_to_camera_view(bpy.context.scene, camera, Vector(head)) + tail_2d = bpy_extras.object_utils.world_to_camera_view(bpy.context.scene, camera, Vector(tail)) + + res_x, res_y = resolution + head_pix = (head_2d.x * res_x, (1 - head_2d.y) * res_y) + tail_pix = (tail_2d.x * res_x, (1 - tail_2d.y) * res_y) + draw.line([head_pix, tail_pix], fill=(255, 0, 0, 255), width=1) + img_pil.save(path) + +def _trans_to_m(v: ndarray): + m = np.eye(4) + m[0:3, 3] = v + return m + +def _scale_to_m(r: ndarray): + m = np.zeros((4, 4)) + m[0, 0] = r + m[1, 1] = r + m[2, 2] = r + m[3, 3] = 1. 
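+    # The resulting homogeneous matrix is
+    #   [[r, 0, 0, 0],
+    #    [0, r, 0, 0],
+    #    [0, 0, r, 0],
+    #    [0, 0, 0, 1]]
+    # _export_render composes it with _trans_to_m as
+    #   trans_vertex = _scale_to_m(1. / scale) @ _trans_to_m(-(bound_max + bound_min) / 2)
+    # so that _apply(v, trans_vertex) recenters the scene and scales it into [-1, 1]^3.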
+ return m \ No newline at end of file diff --git a/UniRig/src/data/extract.py b/UniRig/src/data/extract.py new file mode 100644 index 0000000000000000000000000000000000000000..828915d52c6c08fc1b5b72c2c71a465007f61b5f --- /dev/null +++ b/UniRig/src/data/extract.py @@ -0,0 +1,523 @@ +import bpy, os +from collections import defaultdict +from tqdm import tqdm +import numpy as np +from numpy import ndarray +from typing import Dict, Tuple, List, Optional, Union +import trimesh +import fast_simplification +from scipy.spatial import KDTree + +import argparse +import yaml +from box import Box +import os + +from .log import new_entry, add_error, add_warning, new_log, end_log +from .raw_data import RawData + +def load(filepath: str): + old_objs = set(bpy.context.scene.objects) + + if not os.path.exists(filepath): + raise ValueError(f'File {filepath} does not exist !') + + try: + if filepath.endswith(".vrm"): + # enable vrm addon and load vrm model + bpy.ops.preferences.addon_enable(module='vrm') + + bpy.ops.import_scene.vrm( + filepath=filepath, + use_addon_preferences=True, + extract_textures_into_folder=False, + make_new_texture_folder=False, + set_shading_type_to_material_on_import=False, + set_view_transform_to_standard_on_import=True, + set_armature_display_to_wire=True, + set_armature_display_to_show_in_front=True, + set_armature_bone_shape_to_default=True, + disable_bake=True, # customized option for better performance + ) + elif filepath.endswith(".obj"): + bpy.ops.wm.obj_import(filepath=filepath) + elif filepath.endswith(".fbx") or filepath.endswith(".FBX"): + # end bone is removed using remove_dummy_bone + bpy.ops.import_scene.fbx(filepath=filepath, ignore_leaf_bones=False, use_image_search=False) + elif filepath.endswith(".glb") or filepath.endswith(".gltf"): + bpy.ops.import_scene.gltf(filepath=filepath, import_pack_images=False) + elif filepath.endswith(".dae"): + bpy.ops.wm.collada_import(filepath=filepath) + elif filepath.endswith(".blend"): + with bpy.data.libraries.load(filepath) as (data_from, data_to): + data_to.objects = data_from.objects + for obj in data_to.objects: + if obj is not None: + bpy.context.collection.objects.link(obj) + else: + raise ValueError(f"not suported type {filepath}") + except: + raise ValueError(f"failed to load {filepath}") + + armature = [x for x in set(bpy.context.scene.objects)-old_objs if x.type=="ARMATURE"] + if len(armature)==0: + return None + if len(armature)>1: + raise ValueError(f"multiple armatures found") + armature = armature[0] + + armature.select_set(True) + bpy.context.view_layer.objects.active = armature + bpy.ops.object.mode_set(mode='EDIT') + for bone in bpy.data.armatures[0].edit_bones: + bone.roll = 0. # change all roll to 0. 
to prevent weird behaviour + + bpy.ops.object.mode_set(mode='OBJECT') + armature.select_set(False) + + bpy.ops.object.select_all(action='DESELECT') + return armature + +# remove all data in bpy +def clean_bpy(): + # First try to purge orphan data + try: + bpy.ops.outliner.orphans_purge(do_local_ids=True, do_linked_ids=True, do_recursive=True) + except Exception as e: + print(f"Warning: Could not purge orphans: {e}") + + # Then remove all data by type + data_types = [ + bpy.data.actions, + bpy.data.armatures, + bpy.data.cameras, + bpy.data.collections, + bpy.data.curves, + bpy.data.images, + bpy.data.lights, + bpy.data.materials, + bpy.data.meshes, + bpy.data.objects, + bpy.data.textures, + bpy.data.worlds, + bpy.data.node_groups + ] + + for data_collection in data_types: + try: + for item in data_collection: + try: + data_collection.remove(item) + except Exception as e: + print(f"Warning: Could not remove {item.name} from {data_collection}: {e}") + except Exception as e: + print(f"Warning: Error processing {data_collection}: {e}") + + # Force garbage collection to free memory + import gc + gc.collect() + +def get_arranged_bones(armature): + matrix_world = armature.matrix_world + arranged_bones = [] + root = armature.pose.bones[0] + while root.parent is not None: + root = root.parent + Q = [root] + rot = np.array(matrix_world)[:3, :3] + + # dfs and sort + while len(Q) != 0: + b = Q.pop(0) + arranged_bones.append(b) + children = [] + for cb in b.children: + head = rot @ np.array(b.head) + children.append((cb, head[0], head[1], head[2])) + children = sorted(children, key=lambda x: (x[3], x[1], x[2])) + _c = [x[0] for x in children] + Q = _c + Q + return arranged_bones + +def process_mesh(): + meshes = [] + for v in bpy.data.objects: + if v.type == 'MESH': + meshes.append(v) + + _dict_mesh = {} + for obj in meshes: + m = np.array(obj.matrix_world) + matrix_world_rot = m[:3, :3] + matrix_world_bias = m[:3, 3] + rot = matrix_world_rot + total_vertices = len(obj.data.vertices) + vertex = np.zeros((4, total_vertices)) + vertex_normal = np.zeros((total_vertices, 3)) + obj_verts = obj.data.vertices + faces = [] + normals = [] + + for v in obj_verts: + vertex_normal[v.index] = rot @ np.array(v.normal) # be careful ! 
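+            # The normal above only receives the 3x3 linear block (`rot`) of the world
+            # matrix, while the position below gets the full affine transform (rot plus
+            # matrix_world_bias) and is stored as a homogeneous column of the 4 x N
+            # `vertex` array. Note: multiplying normals by the linear block directly is
+            # exact only for rotations and uniform scales; a non-uniformly scaled object
+            # would need the inverse transpose of that block instead.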
+ vv = rot @ v.co + vv = np.array(vv) + matrix_world_bias + vertex[0:3, v.index] = vv + vertex[3][v.index] = 1 # affine coordinate + + for polygon in obj.data.polygons: + edges = polygon.edge_keys + nodes = [] + adj = {} + for edge in edges: + if adj.get(edge[0]) is None: + adj[edge[0]] = [] + adj[edge[0]].append(edge[1]) + if adj.get(edge[1]) is None: + adj[edge[1]] = [] + adj[edge[1]].append(edge[0]) + nodes.append(edge[0]) + nodes.append(edge[1]) + normal = polygon.normal + nodes = list(set(sorted(nodes))) + first = nodes[0] + loop = [] + now = first + vis = {} + while True: + loop.append(now) + vis[now] = True + if vis.get(adj[now][0]) is None: + now = adj[now][0] + elif vis.get(adj[now][1]) is None: + now = adj[now][1] + else: + break + for (second, third) in zip(loop[1:], loop[2:]): + faces.append((first + 1, second + 1, third + 1)) # the cursed +1 + normals.append(rot @ normal) # and the cursed normal of BLENDER + + correct_faces = [] + for (i, face) in enumerate(faces): + normal = normals[i] + v0 = face[0] - 1 + v1 = face[1] - 1 + v2 = face[2] - 1 + v = np.cross( + vertex[:3, v1] - vertex[:3, v0], + vertex[:3, v2] - vertex[:3, v0], + ) + if (v*normal).sum() > 0: + correct_faces.append(face) + else: + correct_faces.append((face[0], face[2], face[1])) + if len(correct_faces) > 0: + _dict_mesh[obj.name] = { + 'vertex': vertex, + 'face': correct_faces, + } + + vertex = np.concatenate([_dict_mesh[name]['vertex'] for name in _dict_mesh], axis=1)[:3, :].transpose() + + total_faces = 0 + now_bias = 0 + for name in _dict_mesh: + total_faces += len(_dict_mesh[name]['face']) + faces = np.zeros((total_faces, 3), dtype=np.int64) + tot = 0 + for name in _dict_mesh: + f = np.array(_dict_mesh[name]['face'], dtype=np.int64) + faces[tot:tot+f.shape[0]] = f + now_bias + now_bias += _dict_mesh[name]['vertex'].shape[1] + tot += f.shape[0] + + return vertex, faces + +def process_armature( + armature, + arranged_bones, +) -> Tuple[np.ndarray, np.ndarray]: + matrix_world = armature.matrix_world + index = {} + + for (id, pbone) in enumerate(arranged_bones): + index[pbone.name] = id + + root = armature.pose.bones[0] + while root.parent is not None: + root = root.parent + m = np.array(matrix_world.to_4x4()) + scale_inv = np.linalg.inv(np.diag(matrix_world.to_scale())) + rot = m[:3, :3] + bias = m[:3, 3] + + s = [] + bpy.ops.object.editmode_toggle() + edit_bones = armature.data.edit_bones + + J = len(arranged_bones) + joints = np.zeros((J, 3), dtype=np.float32) + tails = np.zeros((J, 3), dtype=np.float32) + parents = [] + name_to_id = {} + names = [] + matrix_local_stack = np.zeros((J, 4, 4), dtype=np.float32) + for (id, pbone) in enumerate(arranged_bones): + name = pbone.name + names.append(name) + matrix_local = np.array(pbone.bone.matrix_local) + use_inherit_rotation = pbone.bone.use_inherit_rotation + if use_inherit_rotation == False: + add_warning(f"use_inherit_rotation of bone {name} is False !") + head = rot @ matrix_local[0:3, 3] + bias + s.append(head) + edit_bone = edit_bones.get(name) + tail = rot @ np.array(edit_bone.tail) + bias + + name_to_id[name] = id + joints[id] = head + tails[id] = tail + parents.append(None if pbone.parent not in arranged_bones else name_to_id[pbone.parent.name]) + # remove scale part + matrix_local[:, 3:4] = m @ matrix_local[:, 3:4] + matrix_local[:3, :3] = scale_inv @ matrix_local[:3, :3] + matrix_local_stack[id] = matrix_local + bpy.ops.object.editmode_toggle() + + return joints, tails, parents, names, matrix_local_stack + +def save_raw_data( + path: str, + vertices: 
ndarray, + faces: ndarray, + joints: Union[ndarray, None], + tails: Union[ndarray, None], + parents: Union[List[Union[int, None]], None], + names: Union[List[str], None], + matrix_local: Union[ndarray, None], + target_count: int, +): + mesh = trimesh.Trimesh(vertices=vertices, faces=faces) + vertices = np.array(mesh.vertices, dtype=np.float32) + faces = np.array(mesh.faces, dtype=np.int64) + if faces.shape[0] > target_count: + vertices, faces = fast_simplification.simplify(vertices, faces, target_count=target_count) + mesh = trimesh.Trimesh(vertices=vertices, faces=faces) + + new_vertices = np.array(mesh.vertices, dtype=np.float32) + new_vertex_normals = np.array(mesh.vertex_normals, dtype=np.float32) + new_faces = np.array(mesh.faces, dtype=np.int64) + new_face_normals = np.array(mesh.face_normals, dtype=np.float32) + if joints is not None: + new_joints = np.array(joints, dtype=np.float32) + else: + new_joints = None + raw_data = RawData( + vertices=new_vertices, + vertex_normals=new_vertex_normals, + faces=new_faces, + face_normals=new_face_normals, + joints=new_joints, + tails=tails, + skin=None, + no_skin=None, + parents=parents, + names=names, + matrix_local=matrix_local, + ) + raw_data.check() + raw_data.save(path=path) + +def extract_builtin( + output_folder: str, + target_count: int, + num_runs: int, + id: int, + time: str, + files: List[Union[str, str]], +): + log_path = "./logs" + log_path = os.path.join(log_path, time) + + num_files = len(files) + gap = num_files // num_runs + start = gap * id + end = gap * (id + 1) + if id+1==num_runs: + end = num_files + + files = sorted(files) + if end!=-1: + files = files[:end] + new_log(log_path, f"extract_builtin_{start}_{end}") + tot = 0 + for file in tqdm(files[start:]): + input_file = file[0] + output_dir = file[1] + clean_bpy() + new_entry(input_file) + try: + print(f"Now processing {input_file}...") + + armature = load(input_file) + + print('save to:', output_dir) + os.makedirs(output_dir, exist_ok=True) + + vertices, faces = process_mesh() + if armature is not None: + arranged_bones = get_arranged_bones(armature) + joints, tails, parents, names, matrix_local = process_armature(armature, arranged_bones) + + else: + joints = None + tails = None + parents = None + names = None + matrix_local = None + + save_file = os.path.join(output_dir, 'raw_data.npz') + save_raw_data( + path=save_file, + vertices=vertices, + faces=faces-1, + joints=joints, + tails=tails, + parents=parents, + names=names, + matrix_local=matrix_local, + target_count=target_count, + ) + + tot += 1 + + except ValueError as e: + add_error(str(e)) + print(f"ValueError: {str(e)}") + except RuntimeError as e: + add_error(str(e)) + print(f"RuntimeError: {str(e)}") + except TimeoutError as e: + add_error("time out") + print("TimeoutError: Processing timed out") + except Exception as e: + add_error(f"Unexpected error: {str(e)}") + print(f"Unexpected error: {str(e)}") + end_log() + print(f"{tot} models processed") + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + +def nullable_string(val): + if not val: + return None + return val + +def get_files( + data_name: str, + input_dataset_dir: str, + output_dataset_dir: str, + inputs: Union[str, None]=None, + require_suffix: List[str]=['obj','fbx','FBX','dae','glb','gltf','vrm'], + force_override: bool=False, + warning: bool=True, 
+) -> List[Tuple[str, str]]: + + files = [] # (input_file, output_dir) + if inputs is not None: # specified input file(s) + vis = {} + inputs = inputs.split(',') + for file in inputs: + file_name = file.removeprefix("./") + # remove suffix + file_name = '.'.join(file_name.split('.')[:-1]) + output_dir = os.path.join(output_dataset_dir, file_name) + raw_data_npz = os.path.join(output_dir, data_name) + if not force_override and os.path.exists(raw_data_npz): + continue + if warning and output_dir in vis: + print(f"\033[33mWARNING: duplicate output directory: {output_dir}, you need to rename prefix of files to avoid ambiguity\033[0m") + vis[output_dir] = True + files.append((file, output_dir)) + else: + vis = {} + for root, dirs, f in os.walk(input_dataset_dir): + for file in f: + if file.split('.')[-1] in require_suffix: + file_name = file.removeprefix("./") + # remove suffix + file_name = '.'.join(file_name.split('.')[:-1]) + + output_dir = os.path.join(output_dataset_dir, os.path.relpath(root, input_dataset_dir), file_name) + raw_data_npz = os.path.join(output_dir, data_name) + + # Check if all required files exist + if not force_override and os.path.exists(raw_data_npz): + continue + if warning and output_dir in vis: + print(f"\033[33mWARNING: duplicate output directory: {output_dir}, you need to rename prefix of files to avoid ambiguity\033[0m") + vis[output_dir] = True + files.append((os.path.join(root, file), output_dir)) + + return files + +def parse(): + parser = argparse.ArgumentParser() + parser.add_argument('--config', type=str, required=True) + parser.add_argument('--require_suffix', type=str, required=True) + parser.add_argument('--faces_target_count', type=int, required=True) + parser.add_argument('--num_runs', type=int, required=True) + parser.add_argument('--force_override', type=str2bool, required=True) + parser.add_argument('--id', type=int, required=True) + parser.add_argument('--time', type=str, required=True) + + parser.add_argument('--input', type=nullable_string, required=False, default=None) + parser.add_argument('--input_dir', type=nullable_string, required=False, default=None) + parser.add_argument('--output_dir', type=nullable_string, required=False, default=None) + return parser.parse_args() + +if __name__ == "__main__": + args = parse() + + config = Box(yaml.safe_load(open(args.config, "r"))) + + num_runs = args.num_runs + id = args.id + timestamp = args.time + require_suffix = args.require_suffix.split(',') + force_override = args.force_override + target_count = args.faces_target_count + + if args.input_dir: + config.input_dataset_dir = args.input_dir + if args.output_dir: + config.output_dataset_dir = args.output_dir + + assert config.input_dataset_dir is not None or args.input is None, 'you cannot specify both input and input_dir' + + files = get_files( + data_name='raw_data.npz', + inputs=args.input, + input_dataset_dir=config.input_dataset_dir, + output_dataset_dir=config.output_dataset_dir, + require_suffix=require_suffix, + force_override=force_override, + warning=True, + ) + + extract_builtin( + output_folder=config.output_dataset_dir, + target_count=target_count, + num_runs=num_runs, + id=id, + time=timestamp, + files=files, + ) \ No newline at end of file diff --git a/UniRig/src/data/log.py b/UniRig/src/data/log.py new file mode 100644 index 0000000000000000000000000000000000000000..fcf3d211728aa5548de7eb240cdd914ed01a593c --- /dev/null +++ b/UniRig/src/data/log.py @@ -0,0 +1,50 @@ +import os +from typing import List + +login_time = '' +log_filepath = '' 
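+# Minimal usage sketch (mirrors how extract.py drives this module; the paths and
+# messages below are illustrative only):
+#
+#   new_log("./logs/2024-01-01_00-00-00", "extract_builtin_0_100")  # open the log file
+#   new_entry("examples/bird.glb")                                  # start a per-file entry
+#   add_warning("use_inherit_rotation of bone Hips is False !")
+#   add_error("multiple armatures found")                           # recorded for the current entry
+#   end_log()                                                       # append the end-of-file marker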
+ +class Entry: + def __init__(self, entry_name): + self.entry = entry_name + self.error = None + self.warning = [] + + def have_error(self): + return self.error != None + + def have_warning(self): + return len(self.warning) != 0 + +logs: List[Entry] = [] + +def new_log(path, log_name): + global login_time, log_filepath + log_filepath = os.path.join(path, f"{log_name}.txt") + os.makedirs(path, exist_ok=True) + with open(log_filepath, 'a') as file: + file.write(f"Log: {log_name}\n") + +def end_log(): + global log_filepath + with open(log_filepath, 'a') as file: + file.write(f"End of file\n") + +def new_entry(entry_name): + global log_filepath + print(f"\033[32mNow processing {entry_name}...\033[0m") + logs.append(Entry(entry_name)) + +def add_error(error): + global log_filepath + print(f"\033[31mError found when processing {logs[-1].entry}: {error}\033[0m") + logs[-1].error = error + with open(log_filepath, 'a') as file: + file.write(f"Entry: {logs[-1].entry}, Error: {error}\n") + +def add_warning(warning): + global log_filepath + print(f"\033[33mWarning found when processing {logs[-1].entry}: {warning}\033[0m") + logs[-1].warning.append(warning) + with open(log_filepath, 'a') as file: + file.write(f"Entry: {logs[-1].entry}, Warning: {warning}\n") \ No newline at end of file diff --git a/UniRig/src/data/order.py b/UniRig/src/data/order.py new file mode 100644 index 0000000000000000000000000000000000000000..a33deeed9ae177d6991e7444f31650e4902c7ab5 --- /dev/null +++ b/UniRig/src/data/order.py @@ -0,0 +1,112 @@ +from typing import Dict, List, Tuple, Union +from collections import defaultdict +from dataclasses import dataclass +import yaml +from box import Box + +from .spec import ConfigSpec + +@dataclass +class OrderConfig(ConfigSpec): + ''' + Config to handle bones re-ordering. + ''' + + # {skeleton_name: path} + skeleton_path: Dict[str, str] + + # {cls: {part_name: [bone_name_1, bone_name_2, ...]}} + parts: Dict[str, Dict[str, List[str]]] + + # {cls: parts of bones to be arranged in [part_name_1, part_name_2, ...]} + parts_order: Dict[str, List[str]] + + @classmethod + def parse(cls, config): + cls.check_keys(config) + skeleton_path = config.skeleton_path + parts = {} + parts_order = {} + for (cls, path) in skeleton_path.items(): + assert cls not in parts, 'cls conflicts' + d = Box(yaml.safe_load(open(path, 'r'))) + parts[cls] = d.parts + parts_order[cls] = d.parts_order + return OrderConfig( + skeleton_path=skeleton_path, + parts=parts, + parts_order=parts_order, + ) + +class Order(): + + # {part_name: [bone_name_1, bone_name_2, ...]} + parts: Dict[str, Dict[str, List[str]]] + + # parts of bones to be arranged in [part_name_1, part_name_2, ...] + parts_order: Dict[str, List[str]] + + def __init__(self, config: OrderConfig): + self.parts = config.parts + self.parts_order = config.parts_order + + def part_exists(self, cls: str, part: str, names: List[str]) -> bool: + ''' + Check if part exists. + ''' + if part not in self.parts[cls]: + return False + for name in self.parts[cls][part]: + if name not in names: + return False + return True + + def make_names(self, cls: Union[str, None], parts: List[Union[str, None]], num_bones: int) -> List[str]: + ''' + Get names for specified cls. 
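+        Bone names are collected from the listed parts of `cls` in order (None
+        entries, i.e. spring tokens, are skipped), then padded with placeholder
+        names "bone_{i}" until `num_bones` names exist.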
+ ''' + names = [] + for part in parts: + if part is None: # spring + continue + if cls in self.parts and part in self.parts[cls]: + names.extend(self.parts[cls][part]) + assert len(names) <= num_bones, "number of bones in required skeleton is more than existing bones" + for i in range(len(names), num_bones): + names.append(f"bone_{i}") + return names + + def arrange_names(self, cls: str, names: List[str], parents: List[Union[int, None]]) -> Tuple[List[str], Dict[int, Union[str]]]: + ''' + Arrange names according to required parts order. + ''' + if cls not in self.parts_order: + return names, {0: None} # add a spring token + vis = defaultdict(bool) + name_to_id = {name: i for (i, name) in enumerate(names)} + new_names = [] + parts_bias = {} + for part in self.parts_order[cls]: + if self.part_exists(cls=cls, part=part, names=names): + for name in self.parts[cls][part]: + vis[name] = True + flag = False + for name in self.parts[cls][part]: + pid = parents[name_to_id[name]] + if pid is None: + continue + if not vis[names[pid]]: + flag = True + break + if flag: # incorrect parts order and should immediately add a spring token + break + parts_bias[len(new_names)] = part + new_names.extend(self.parts[cls][part]) + parts_bias[len(new_names)] = None # add a spring token + for name in names: + if name not in new_names: + new_names.append(name) + return new_names, parts_bias + +def get_order(config: OrderConfig) -> Order: + return Order(config=config) \ No newline at end of file diff --git a/UniRig/src/data/raw_data.py b/UniRig/src/data/raw_data.py new file mode 100644 index 0000000000000000000000000000000000000000..286342f6f7d4b47aa31f62908ee6d92b90d13131 --- /dev/null +++ b/UniRig/src/data/raw_data.py @@ -0,0 +1,307 @@ +from dataclasses import dataclass +import numpy as np +from numpy import ndarray + +import os +from typing import Union, List, Tuple + +from .exporter import Exporter + +from ..tokenizer.spec import DetokenzeOutput +from .order import Order + +@dataclass(frozen=True) +class RawData(Exporter): + ''' + Dataclass to handle data from processed model files. 
+ ''' + + # vertices of the mesh, shape (N, 3), float32 + vertices: Union[ndarray, None] + + # normals of vertices, shape (N, 3), float32 + vertex_normals: Union[ndarray, None] + + # faces of mesh, shape (F, 3), face id starts from 0 to F-1, int64 + faces: Union[ndarray, None] + + # face normal of mesh, shape (F, 3), float32 + face_normals: Union[ndarray, None] + + # joints of bones, shape (J, 3), float32 + joints: Union[ndarray, None] + + # tails of joints, shape (J, 3), float32 + tails: Union[ndarray, None] + + # skinning of joints, shape (N, J), float32 + skin: Union[ndarray, None] + + # whether the joint has skin, bool + no_skin: Union[ndarray, None] + + # parents of joints, None represents no parent(a root joint) + # make sure parent[k] < k + parents: Union[List[Union[int, None]], None] + + # names of joints + names: Union[List[str], None] + + # local coordinate + matrix_local: Union[ndarray, None] + + # path to data + path: Union[str, None]=None + + # data cls + cls: Union[str, None]=None + + @staticmethod + def load(path: str) -> 'RawData': + data = np.load(path, allow_pickle=True) + d = {name: data[name][()] for name in data} + d['path'] = path + return RawData(**d) + + def save(self, path: str): + os.makedirs(os.path.dirname(path), exist_ok=True) + np.savez(file=path, **self.__dict__) + + @property + def N(self): + ''' + number of vertices + ''' + return self.vertices.shape[0] + + @property + def F(self): + ''' + number of faces + ''' + return self.faces.shape[0] + + @property + def J(self): + ''' + number of joints + ''' + return self.joints.shape[0] + + def check(self): + if self.names is not None and self.joints is not None: + assert len(self.names) == self.J + if self.names is not None and self.parents is not None: + assert len(self.names) == len(self.parents) + if self.parents is not None: + for (i, pid) in enumerate(self.parents): + if i==0: + assert pid is None + else: + assert pid is not None + assert pid < i + + def export_pc(self, path: str, with_normal: bool=True, normal_size=0.01): + ''' + export point cloud + ''' + if with_normal: + self._export_pc(vertices=self.vertices, path=path, vertex_normals=self.vertex_normals, normal_size=normal_size) + else: + self._export_pc(vertices=self.vertices, path=path, vertex_normals=None, normal_size=normal_size) + + def export_mesh(self, path: str): + ''' + export mesh + ''' + self._export_mesh(vertices=self.vertices, faces=self.faces, path=path) + + def export_skeleton(self, path: str): + ''' + export spring + ''' + self._export_skeleton(joints=self.joints, parents=self.parents, path=path) + + def export_skeleton_sequence(self, path: str): + ''' + export spring + ''' + self._export_skeleton_sequence(joints=self.joints, parents=self.parents, path=path) + + def export_fbx( + self, + path: str, + extrude_size: float=0.03, + group_per_vertex: int=-1, + add_root: bool=False, + do_not_normalize: bool=False, + use_extrude_bone: bool=True, + use_connect_unique_child: bool=True, + extrude_from_parent: bool=True, + use_tail: bool=False, + custom_vertex_group: Union[ndarray, None]=None, + ): + ''' + export the whole model with skining + ''' + self._export_fbx( + path=path, + vertices=self.vertices, + joints=self.joints, + skin=self.skin if custom_vertex_group is None else custom_vertex_group, + parents=self.parents, + names=self.names, + faces=self.faces, + extrude_size=extrude_size, + group_per_vertex=group_per_vertex, + add_root=add_root, + do_not_normalize=do_not_normalize, + use_extrude_bone=use_extrude_bone, + 
use_connect_unique_child=use_connect_unique_child, + extrude_from_parent=extrude_from_parent, + tails=self.tails if use_tail else None, + ) + + def export_render(self, path: str, resolution: Tuple[int, int]=[256, 256]): + self._export_render( + path=path, + vertices=self.vertices, + faces=self.faces, + bones=np.concatenate([self.joints, self.tails], axis=-1), + resolution=resolution, + ) + +@dataclass(frozen=True) +class RawSkeleton(Exporter): + ''' + Dataclass to handle skeleton from AR. + ''' + # joints of bones, shape (J, 3), float32 + joints: Union[ndarray, None] + + # tails of joints, shape (J, 3), float32 + tails: Union[ndarray, None] + + # whether the joint has skin, bool + no_skin: Union[ndarray, None] + + # parents of joints, None represents no parent(a root joint) + # make sure parent[k] < k + parents: Union[List[Union[int, None]], None] + + # names of joints + names: Union[List[str], None] + + @staticmethod + def load(path: str) -> 'RawSkeleton': + data = np.load(path, allow_pickle=True) + return RawSkeleton(**{name: data[name][()] for name in data}) + + def save(self, path: str): + os.makedirs(os.path.dirname(path), exist_ok=True) + np.savez(file=path, **self.__dict__) + + @staticmethod + def from_detokenize_output(res: DetokenzeOutput, order: Union[Order, None]) -> 'RawSkeleton': + J = len(res.bones) + names = order.make_names(cls=res.cls, parts=res.parts, num_bones=J) + joints = res.joints + p_joints = res.p_joints + parents = [] + for (i, joint) in enumerate(joints): + if i == 0: + parents.append(None) + continue + p_joint = p_joints[i] + dis = 999999 + pid = None + for j in reversed(range(i)): + n_dis = ((joints[j] - p_joint)**2).sum() + if n_dis < dis: + pid = j + dis = n_dis + parents.append(pid) + return RawSkeleton( + joints=joints, + tails=res.tails, + no_skin=res.no_skin, + parents=parents, + names=names, + ) + + def export_skeleton(self, path: str): + ''' + export spring + ''' + self._export_skeleton(joints=self.joints, parents=self.parents, path=path) + + def export_skeleton_sequence(self, path: str): + ''' + export spring + ''' + self._export_skeleton_sequence(joints=self.joints, parents=self.parents, path=path) + + def export_fbx( + self, + path: str, + extrude_size: float=0.03, + group_per_vertex: int=-1, + add_root: bool=False, + do_not_normalize: bool=False, + use_extrude_bone: bool=True, + use_connect_unique_child: bool=True, + extrude_from_parent: bool=True, + use_tail: bool=False, + ): + ''' + export the whole model with skining + ''' + self._export_fbx( + path=path, + vertices=None, + joints=self.joints, + skin=None, + parents=self.parents, + names=self.names, + faces=None, + extrude_size=extrude_size, + group_per_vertex=group_per_vertex, + add_root=add_root, + do_not_normalize=do_not_normalize, + use_extrude_bone=use_extrude_bone, + use_connect_unique_child=use_connect_unique_child, + extrude_from_parent=extrude_from_parent, + tails=self.tails if use_tail else None, + ) + + def export_render(self, path: str, resolution: Tuple[int, int]=[256, 256]): + self._export_render( + path=path, + vertices=None, + faces=None, + bones=np.concatenate([self.joints, self.tails], axis=-1), + resolution=resolution, + ) + +@dataclass +class RawSkin(Exporter): + ''' + Dataclass to handle skeleton from AR. 
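+    In practice this stores the predicted skinning weights (shape (J, N)) together
+    with the sampled vertices they correspond to.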
+ ''' + # skin, shape (J, N) + skin: ndarray + + # always sampled, shape (N, 3) + vertices: Union[ndarray, None]=None + + # for future use, shape (J, 3) + joints: Union[ndarray, None]=None + + @staticmethod + def load(path: str) -> 'RawSkin': + data = np.load(path, allow_pickle=True) + return RawSkin(**{name: data[name][()] for name in data}) + + def save(self, path: str): + os.makedirs(os.path.dirname(path), exist_ok=True) + np.savez(file=path, **self.__dict__) \ No newline at end of file diff --git a/UniRig/src/data/sampler.py b/UniRig/src/data/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..919bb14a514747b45afd16eaa9481f7c4798dccb --- /dev/null +++ b/UniRig/src/data/sampler.py @@ -0,0 +1,210 @@ +from typing import List +from heapq import heappush, heappop, heapify +from dataclasses import dataclass +from abc import ABC, abstractmethod +import numpy as np +from numpy import ndarray + +from typing import Dict, Tuple + +from .asset import Asset +from .spec import ConfigSpec + +@dataclass +class SamplerConfig(ConfigSpec): + ''' + Config to handle bones re-ordering. + ''' + # which sampler to use + method: str + + # how many samples in total + num_samples: int + + # how many vertex samples + vertex_samples: int + + # kwargs + kwargs: Dict[str, Dict] + + @classmethod + def parse(cls, config) -> 'SamplerConfig': + cls.check_keys(config) + return SamplerConfig( + method=config.method, + num_samples=config.get('num_samples', 0), + vertex_samples=config.get('vertex_samples', 0), + kwargs=config.get('kwargs', {}), + ) + +@dataclass +class SamplerResult(): + # sampled vertices + vertices: ndarray + + # sampled normals + normals: ndarray + + # sampled vertex groups + vertex_groups: Dict[str, ndarray] + +class Sampler(ABC): + ''' + Abstract class for samplers. + ''' + + def _sample_barycentric( + self, + vertex_group: ndarray, + faces: ndarray, + face_index: ndarray, + random_lengths: ndarray, + ): + v_origins = vertex_group[faces[face_index, 0]] + v_vectors = vertex_group[faces[face_index, 1:]] + v_vectors -= v_origins[:, np.newaxis, :] + + sample_vector = (v_vectors * random_lengths).sum(axis=1) + v_samples = sample_vector + v_origins + return v_samples + + @abstractmethod + def __init__(self, config: SamplerConfig): + pass + + @abstractmethod + def sample( + self, + asset: Asset, + ) -> SamplerResult: + ''' + Return sampled vertices, sampled normals and vertex groups. 
+ ''' + pass + +class SamplerOrigin(Sampler): + def __init__(self, config: SamplerConfig): + super().__init__(config) + self.num_samples = config.num_samples + self.vertex_samples = config.vertex_samples + + def sample( + self, + asset: Asset, + ) -> SamplerResult: + perm = np.random.permutation(asset.vertices.shape[0]) + if asset.vertices.shape[0] < self.num_samples: + m = self.num_samples - asset.vertices.shape[0] + perm = np.concatenate([perm, np.random.randint(0, asset.vertices.shape[0], (m,))]) + perm = perm[:self.num_samples] + n_v = asset.vertices[perm] + n_n = asset.vertex_normals[perm] + n_vg = {name: v[perm] for name, v in asset.vertex_groups.items()} + return SamplerResult( + vertices=n_v, + normals=n_n, + vertex_groups=n_vg, + ) + +class SamplerMix(Sampler): + def __init__(self, config: SamplerConfig): + super().__init__(config) + self.num_samples = config.num_samples + self.vertex_samples = config.vertex_samples + assert self.num_samples >= self.vertex_samples, 'num_samples should >= vertex_samples' + + @property + def mesh_preserve(self): + return self.num_samples==-1 + + def sample( + self, + asset: Asset, + ) -> SamplerResult: + # 1. sample vertices + num_samples = self.num_samples + perm = np.random.permutation(asset.vertices.shape[0]) + vertex_samples = min(self.vertex_samples, asset.vertices.shape[0]) + num_samples -= vertex_samples + perm = perm[:vertex_samples] + n_vertex = asset.vertices[perm] + n_normal = asset.vertex_normals[perm] + n_v = {name: v[perm] for name, v in asset.vertex_groups.items()} + + # 2. sample surface + perm = np.random.permutation(num_samples) + vertex_samples, face_index, random_lengths = sample_surface( + num_samples=num_samples, + vertices=asset.vertices, + faces=asset.faces, + return_weight=True, + ) + vertex_samples = np.concatenate([n_vertex, vertex_samples], axis=0) + normal_samples = np.concatenate([n_normal, asset.face_normals[face_index]], axis=0) + vertex_group_samples = {} + for n, v in asset.vertex_groups.items(): + g = self._sample_barycentric( + vertex_group=v, + faces=asset.faces, + face_index=face_index, + random_lengths=random_lengths, + ) + vertex_group_samples[n] = np.concatenate([n_v[n], g], axis=0) + return SamplerResult( + vertices=vertex_samples, + normals=normal_samples, + vertex_groups=vertex_group_samples, + ) + +def sample_surface( + num_samples: int, + vertices: ndarray, + faces: ndarray, + return_weight: bool=False, +): + ''' + Randomly pick samples according to face area. 
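+    Note: the weight used below is the squared norm of the triangle cross product
+    (proportional to the squared face area rather than the area itself), and the two
+    random barycentric coordinates are reflected back into the triangle whenever
+    their sum exceeds one.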
+ + See sample_surface: https://github.com/mikedh/trimesh/blob/main/trimesh/sample.py + ''' + # get face area + offset_0 = vertices[faces[:, 1]] - vertices[faces[:, 0]] + offset_1 = vertices[faces[:, 2]] - vertices[faces[:, 0]] + face_weight = np.cross(offset_0, offset_1, axis=-1) + face_weight = (face_weight * face_weight).sum(axis=1) + + weight_cum = np.cumsum(face_weight, axis=0) + face_pick = np.random.rand(num_samples) * weight_cum[-1] + face_index = np.searchsorted(weight_cum, face_pick) + + # pull triangles into the form of an origin + 2 vectors + tri_origins = vertices[faces[:, 0]] + tri_vectors = vertices[faces[:, 1:]] + tri_vectors -= np.tile(tri_origins, (1, 2)).reshape((-1, 2, 3)) + + # pull the vectors for the faces we are going to sample from + tri_origins = tri_origins[face_index] + tri_vectors = tri_vectors[face_index] + + # randomly generate two 0-1 scalar components to multiply edge vectors b + random_lengths = np.random.rand(len(tri_vectors), 2, 1) + + random_test = random_lengths.sum(axis=1).reshape(-1) > 1.0 + random_lengths[random_test] -= 1.0 + random_lengths = np.abs(random_lengths) + + sample_vector = (tri_vectors * random_lengths).sum(axis=1) + vertex_samples = sample_vector + tri_origins + if not return_weight: + return vertex_samples + return vertex_samples, face_index, random_lengths + +def get_sampler(config: SamplerConfig) -> Sampler: + method = config.method + if method=='origin': + sampler = SamplerOrigin(config) + elif method=='mix': + sampler = SamplerMix(config) + else: + raise ValueError(f"sampler method {method} not supported") + return sampler \ No newline at end of file diff --git a/UniRig/src/data/spec.py b/UniRig/src/data/spec.py new file mode 100644 index 0000000000000000000000000000000000000000..9331427e1edd1ef69ad87f9b65a29ca0bb71f9f8 --- /dev/null +++ b/UniRig/src/data/spec.py @@ -0,0 +1,15 @@ +from abc import ABC, abstractmethod +from dataclasses import fields + +class ConfigSpec(ABC): + @classmethod + def check_keys(cls, config): + expect = [field.name for field in fields(cls)] + for key in config.keys(): + if key not in expect: + raise ValueError(f"expect names {expect} in {cls.__name__}, found {key}") + + @classmethod + @abstractmethod + def parse(cls, config): + pass diff --git a/UniRig/src/data/tail.py b/UniRig/src/data/tail.py new file mode 100644 index 0000000000000000000000000000000000000000..ba6307ae63c97efdd515ecf715a130319fff1542 --- /dev/null +++ b/UniRig/src/data/tail.py @@ -0,0 +1,50 @@ +from collections import defaultdict +from dataclasses import dataclass +import numpy as np +from numpy import ndarray + +from typing import Tuple + +from .asset import Asset +from .spec import ConfigSpec + +@dataclass +class TailConfig(ConfigSpec): + ''' + Config to handle tails. 
+ ''' + + # copy joints to tails + copy_joint_to_tail: bool + + # if the joint has only one son, then connect tail to son's joint + connect_tail_to_unique_son: bool + + @classmethod + def parse(cls, config) -> 'TailConfig': + cls.check_keys(config) + return TailConfig( + copy_joint_to_tail=config.copy_joint_to_tail, + connect_tail_to_unique_son=config.connect_tail_to_unique_son, + ) + +class Tail(): + + def __init__(self, config: TailConfig): + self.config = config + + def process_tail(self, asset: Asset): + if self.config.copy_joint_to_tail: + assert asset.tails is None, 'copying joints to existing tails is not permitted, please change copy_joint_to_tail to False in transform config' + asset.tails = asset.joints.copy() + if self.config.connect_tail_to_unique_son and asset.tails is not None: + children = defaultdict(list) + for (id, p) in enumerate(asset.parents): + if p is not None: + children[p].append(id) + for i in range(asset.J): + if len(children[i]) == 1: + asset.tails[i] = asset.joints[children[i][0]] + +def get_tail(config: TailConfig) -> Tail: + return Tail(config=config) \ No newline at end of file diff --git a/UniRig/src/data/transform.py b/UniRig/src/data/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..018124f17e1758d41062200ecaccf02a5fbbb6ae --- /dev/null +++ b/UniRig/src/data/transform.py @@ -0,0 +1,107 @@ +from collections import defaultdict +from dataclasses import dataclass +from typing import Union, List, Tuple +from copy import deepcopy + +from .asset import Asset +from .augment import AugmentConfig, Augment, get_augments +from .order import OrderConfig, Order, get_order +from .sampler import SamplerConfig, get_sampler +from .vertex_group import VertexGroupConfig, get_vertex_groups +from .tail import TailConfig, get_tail +from .spec import ConfigSpec + +@dataclass +class TransformConfig(ConfigSpec): + + tail_config: Union[TailConfig, None]=None, + + order_config: Union[OrderConfig, None]=None, + + vertex_group_config: Union[VertexGroupConfig, None]=None, + + augment_config: Union[AugmentConfig, None]=None, + + sampler_config: Union[SamplerConfig, None]=None, + + @classmethod + def parse(cls, config) -> 'TransformConfig': + cls.check_keys(config) + tail_config = config.get('tail_config', None) + order_config = config.get('order_config', None) + vertex_group_config = config.get('vertex_group_config', None) + augment_config = config.get('augment_config', None) + sampler_config = config.get('sampler_config', None) + + if tail_config is not None: + tail_config = TailConfig.parse(config=tail_config) + if order_config is not None: + order_config = OrderConfig.parse(config=order_config) + if vertex_group_config is not None: + vertex_group_config = VertexGroupConfig.parse(config=vertex_group_config) + if augment_config is not None: + augment_config = AugmentConfig.parse(config=augment_config) + if sampler_config is not None: + sampler_config = SamplerConfig.parse(config=sampler_config) + + return TransformConfig( + tail_config=tail_config, + order_config=order_config, + vertex_group_config=vertex_group_config, + augment_config=augment_config, + sampler_config=sampler_config, + ) + +def transform_asset( + asset: Asset, + transform_config: TransformConfig, +) -> Tuple[List[Augment], List[Augment]]: + assert isinstance(transform_config, TransformConfig), f"found {type(transform_config)}" + # 1. 
try processing tails + # TODO: use a better method + if transform_config.tail_config is not None: + tail = get_tail(config=transform_config.tail_config) + tail.process_tail(asset=asset) + + # 2. arrange bones + if transform_config.order_config is not None: + order = get_order(config=transform_config.order_config) + asset.set_order(order=order) + + # 3. collapse must perform first + if transform_config.augment_config: + first_augments, second_augments = get_augments(config=transform_config.augment_config) + else: + first_augments = [] + second_augments = [] + + kwargs = {} + for augment in first_augments: + augment.transform(asset=asset, **kwargs) + + # 4. get vertex groups + if transform_config.vertex_group_config is not None: + vertex_groups = get_vertex_groups(config=transform_config.vertex_group_config) + d = {} + for v in vertex_groups: + d.update(v.get_vertex_group(asset=asset)) + asset.vertex_groups = d + else: + asset.vertex_groups = {} + + # 5. regular augments + for augment in second_augments: + augment.transform(asset=asset, **kwargs) + + # 6. sample + if transform_config.sampler_config is not None: + sampler = get_sampler(config=transform_config.sampler_config) + res = sampler.sample(asset=asset) + asset.sampled_vertices = res.vertices + asset.sampled_normals = res.normals + asset.sampled_vertex_groups = res.vertex_groups + else: + asset.sampled_vertices = asset.vertices.copy() + asset.sampled_normals = asset.vertex_normals.copy() + asset.sampled_vertex_groups = deepcopy(asset.vertex_groups) + return first_augments, second_augments \ No newline at end of file diff --git a/UniRig/src/data/utils.py b/UniRig/src/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..aec48872ce4332d02017f4ffaf3ee8e7aad6348e --- /dev/null +++ b/UniRig/src/data/utils.py @@ -0,0 +1,258 @@ +import torch +import numpy as np +from numpy import ndarray +from torch import Tensor, FloatTensor +from typing import Tuple, Union + +from scipy.spatial.transform import Rotation as R +from scipy.sparse import csc_matrix +import numpy as np + +def quaternion_to_matrix(x, use_4x4=True) -> FloatTensor: + """ + Ref: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#quaternion_to_matrix + """ + if not isinstance(x, Tensor): + quaternions = torch.tensor(x, dtype=torch.float32) + else: + quaternions = x + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + device = quaternions.device + + if use_4x4: + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + torch.zeros(quaternions.shape[:-1], device=device, dtype=torch.float32), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + torch.zeros(quaternions.shape[:-1], device=device, dtype=torch.float32), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + torch.zeros(quaternions.shape[:-1], device=device, dtype=torch.float32), + torch.zeros(quaternions.shape[:-1], device=device, dtype=torch.float32), + torch.zeros(quaternions.shape[:-1], device=device, dtype=torch.float32), + torch.zeros(quaternions.shape[:-1], device=device, dtype=torch.float32), + torch.ones(quaternions.shape[:-1], device=device, dtype=torch.float32), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (4, 4)) + else: + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - 
two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + +def axis_angle_to_quaternion(axis_angle: FloatTensor) -> FloatTensor: + """ + Ref: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#axis_angle_to_quaternion + """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = angles * 0.5 + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 + ) + return quaternions + +def axis_angle_to_matrix(axis_angle: Union[FloatTensor, ndarray]) -> Union[FloatTensor, ndarray]: + """ + Ref: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#axis_angle_to_matrix + """ + if isinstance(axis_angle, FloatTensor): + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + else: + res = np.pad(R.from_rotvec(axis_angle).as_matrix(), ((0, 0), (0, 1), (0, 1)), 'constant', constant_values=((0, 0), (0, 0), (0, 0))) + assert res.ndim == 3 + res[:, -1, -1] = 1 + return res + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + if torch.is_grad_enabled(): + ret[positive_mask] = torch.sqrt(x[positive_mask]) + else: + ret = torch.where(positive_mask, torch.sqrt(x), ret) + return ret + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + +def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(batch_dim + (9,)), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. 
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is small, + # the candidate won't be picked. + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + out = quat_candidates[ + torch.nn.functional.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : + ].reshape(batch_dim + (4,)) + return standardize_quaternion(out) + +def linear_blend_skinning( + vertex: Union[FloatTensor, ndarray], + matrix_local: Union[FloatTensor, ndarray], + matrix: Union[FloatTensor, ndarray], + skin: Union[FloatTensor, ndarray], + pad: int=0, + value: float=0., +) -> Union[FloatTensor, ndarray]: + ''' + Args: + vertex: (B, N, 4-pad) or (N, 4-pad) + matrix_local: (B, J, 4, 4) or (J, 4, 4) + matrix: (B, J, 4, 4) or (J, 4, 4) + skin: (B, N, J) or (N, J), value of pseudo bones should be 0 + Returns: + (B, N, 3) or (N, 3) + ''' + assert vertex.shape[-1] + pad == 4 + if isinstance(vertex, Tensor): + dims = vertex.dim() + elif isinstance(vertex, ndarray): + dims = vertex.ndim + else: + raise NotImplementedError() + if dims == 3: # Case: (B, N, 3+pad) + assert isinstance(vertex, Tensor) + J = matrix_local.shape[1] + # (B, J, 3+pad, N) + offset = ( + matrix_local.inverse() @ + torch.nn.functional.pad(vertex, (0, pad, 0, 0, 0, 0), value=value).unsqueeze(1).transpose(2, 3).repeat(1, J, 1, 1) + ) + # (B, J, 4, N) + per_bone_matrix = matrix @ offset + # (B, J, 4, N) + weighted_per_bone_matrix = skin.transpose(1, 2).unsqueeze(2) * per_bone_matrix + # (B, 3, N) + g = weighted_per_bone_matrix.sum(dim=1) + # (B, 3, N) + final = g[:, 0:3, :] / (skin.transpose(1, 2).sum(dim=1) + 1e-8).unsqueeze(1) + return final.permute(0, 2, 1) + + elif dims == 2: # Case: (N, 3+pad) + if isinstance(vertex, Tensor): + J = matrix_local.shape[0] + offset = ( + matrix_local.inverse() @ + torch.nn.functional.pad(vertex, (0, pad, 0, 0), value=value).unsqueeze(0).transpose(1, 2).repeat(J, 1, 1) + ) + per_bone_matrix = matrix @ offset + weighted_per_bone_matrix = skin.transpose(0, 1).unsqueeze(1) * per_bone_matrix + g = weighted_per_bone_matrix.sum(dim=0) + final = g[0:3, :] / (skin.transpose(0, 1).sum(dim=0) + 1e-8).unsqueeze(0) + return final.permute(1, 0) # Output shape (N, 3) + else: + J = matrix_local.shape[0] + N = vertex.shape[0] + # (4, N) + padded = np.pad(vertex, ((0, 0), (0, pad)), 'constant', constant_values=(0, value)).T + # (J, 4, 4) + trans = matrix @ np.linalg.inv(matrix_local) + weighted_per_bone_matrix = [] + # (J, N) + mask = (skin > 0).T + for i in range(J): + offset = np.zeros((4, N), dtype=np.float32) + offset[:, mask[i]] = (trans[i] @ padded[:, mask[i]]) * skin.T[i, mask[i]] + weighted_per_bone_matrix.append(offset) + weighted_per_bone_matrix = np.stack(weighted_per_bone_matrix) + g = np.sum(weighted_per_bone_matrix, axis=0) + final = g[:3, :] / (np.sum(skin, axis=1) + 1e-8) + return final.T + else: + assert 0, f'unsupported shape: {vertex.shape}' diff 
--git a/UniRig/src/data/vertex_group.py b/UniRig/src/data/vertex_group.py new file mode 100644 index 0000000000000000000000000000000000000000..645e065712cd825452c359f201893ad28656d916 --- /dev/null +++ b/UniRig/src/data/vertex_group.py @@ -0,0 +1,577 @@ +import platform +import os +if platform.system() == "Linux": + os.environ['PYOPENGL_PLATFORM'] = 'egl' + +from typing import Dict, List, Tuple +from dataclasses import dataclass +from collections import defaultdict +from abc import ABC, abstractmethod +import numpy as np +from numpy import ndarray + +from scipy.spatial import cKDTree +from scipy.sparse import csr_matrix +from scipy.sparse.csgraph import shortest_path, connected_components + +from .asset import Asset +from .spec import ConfigSpec + +@dataclass +class VertexGroupConfig(ConfigSpec): + ''' + Config to sample vertex group. + ''' + + # names + names: List[str] + + # kwargs + kwargs: Dict[str, Dict] + + @classmethod + def parse(cls, config) -> 'VertexGroupConfig': + cls.check_keys(config) + return VertexGroupConfig( + names=config.get('names', []), + kwargs=config.get('kwargs', {}), + ) + +class VertexGroup(ABC): + + @abstractmethod + def __init__(self, **kwargs): + pass + + @abstractmethod + def get_vertex_group(self, asset: Asset) -> Dict[str, ndarray]: + pass + +class VertexGroupSkin(VertexGroup): + ''' + Capture skin. + ''' + + def __init__(self, **kwargs): + pass + + def get_vertex_group(self, asset: Asset) -> Dict[str, ndarray]: + return { + 'skin': asset.skin / (asset.skin.sum(axis=-1, keepdims=True) + 1e-6), + } + +class VertexGroupGeodesicDistance(VertexGroup): + ''' + Calculate geodesic distance. + ''' + def __init__(self, **kwargs): + self.deterministic = kwargs.get('deterministic', False) + self.soft_mask = kwargs.get('soft_mask', False) + + def _prepare( + self, + joints: ndarray, # (J, 3) + edges: List[Tuple[int, int]], + ) -> Tuple[ndarray, ndarray]: + J = joints.shape[0] + dis_matrix = np.ones((J, J)) * 100.0 + step_matrix = np.ones((J, J)) * 100.0 + def dis(x: ndarray, y: ndarray): + return np.linalg.norm(x-y) + for i in range(J): + dis_matrix[i, i] = 0. + step_matrix[i, i] = 0. 
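+ # 100.0 acts as an "infinite" placeholder for joint pairs not yet connected; edge lengths and
+ # single-step hops are filled in next, then the Floyd-Warshall pass below shortens them.
+ # e.g. for a 3-joint chain 0-1-2 with unit-length bones, dis_matrix[0, 2] ends up 2.0 and step_matrix[0, 2] ends up 2.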
+ for edge in edges: + dis_matrix[edge[0], edge[1]] = dis(joints[edge[0]], joints[edge[1]]) + dis_matrix[edge[1], edge[0]] = dis(joints[edge[0]], joints[edge[1]]) + step_matrix[edge[0], edge[1]] = 1 + step_matrix[edge[1], edge[0]] = 1 + # floyd + for k in range(J): + dis_matrix = np.minimum(dis_matrix, dis_matrix[:, k][:, np.newaxis] + dis_matrix[k, :][np.newaxis, :]) + step_matrix = np.minimum(step_matrix, step_matrix[:, k][:, np.newaxis] + step_matrix[k, :][np.newaxis, :]) + return dis_matrix, step_matrix + + def get_vertex_group(self, asset: Asset) -> Dict[str, ndarray]: + children = defaultdict(list) + edges = [] + for (id, p) in enumerate(asset.parents): + if p is not None: + edges.append((id, p)) + children[p].append(id) + child = [] + tails = asset.tails.copy() + for id in range(asset.J): + if len(children[id]) == 1: + child.append(children[id][0]) + else: + child.append(id) + if self.deterministic: + tails[id] = asset.joints[id] + child = np.array(child) + dis_matrix, step_matrix = self._prepare( + joints=asset.joints, + edges=edges, + ) + geo_dis, geo_mask = get_geodesic_distance( + vertices=asset.vertices, + joints=asset.joints, + tails=tails, + dis_matrix=dis_matrix, + step_matrix=step_matrix, + child=child, + soft_mask=self.soft_mask, + ) + return { + 'geodesic_distance': geo_dis, + 'geodesic_mask': geo_mask, + } + +class VertexGroupVoxelSkin(VertexGroup): + ''' + Capture voxel skin. + ''' + + def __init__(self, **kwargs): + self.grid = kwargs.get('grid', 64) + self.alpha = kwargs.get('alpha', 0.5) + self.link_dis = kwargs.get('link_dis', 0.00001) + self.grid_query = kwargs.get('grid_query', 27) + self.vertex_query = kwargs.get('vertex_query', 27) + self.grid_weight = kwargs.get('grid_weight', 3.0) + self.mode = kwargs.get('mode', 'square') + + def get_vertex_group(self, asset: Asset) -> Dict[str, ndarray]: + + # normalize into [-1, 1] first + min_vals = np.min(asset.vertices, axis=0) + max_vals = np.max(asset.vertices, axis=0) + + center = (min_vals + max_vals) / 2 + + scale = np.max(max_vals - min_vals) / 2 + + normalized_vertices = (asset.vertices - center) / scale + normalized_joints = (asset.joints - center) / scale + + grid_indices, grid_coords = voxelization( + vertices=normalized_vertices, + faces=asset.faces, + grid=self.grid, + ) + skin = voxel_skin( + grid=self.grid, + grid_coords=grid_coords, + joints=normalized_joints, + vertices=normalized_vertices, + faces=asset.faces, + alpha=self.alpha, + link_dis=self.link_dis, + grid_query=self.grid_query, + vertex_query=self.vertex_query, + grid_weight=self.grid_weight, + mode=self.mode, + ) + skin = np.nan_to_num(skin, nan=0., posinf=0., neginf=0.) 
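+ # NaN/inf weights (possibly from empty or degenerate voxels) are zeroed above before the per-vertex skin is returned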
+ return { + 'voxel_skin': skin, + } + +class VertexGroupMeshPartDistance(VertexGroup): + def __init__(self, **kwargs): + self.part_dim = kwargs['part_dim'] + self.dis_dim = kwargs['dis_dim'] + + def get_vertex_group(self, asset: Asset) -> Dict[str, ndarray]: + tot, vertex_labels, face_labels = find_connected_components(asset.vertices, asset.faces) + # (N, dis_dim) + part_distances = compute_distances_in_components(asset.vertices, asset.faces, vertex_labels, tot, self.dis_dim) + # (tot, part_dim) + spread_vectors = generate_spread_vectors(tot, self.part_dim) + # (N, part_dim): scatter each component's spread vector to the vertices of that component + part_vectors = np.zeros((asset.vertices.shape[0], self.part_dim)) + for i in range(tot): + part_vectors[vertex_labels == i] = spread_vectors[i] + return { + 'num_parts': tot, + 'part_vectors': part_vectors, + 'part_distances': part_distances, + } + +# TODO: move this into a new file +class VertexGroupMeshParts(VertexGroup): + def __init__(self, **kwargs): + pass + + def get_vertex_group(self, asset: Asset) -> Dict[str, ndarray]: + tot, vertex_labels, face_labels = find_connected_components(asset.vertices, asset.faces) + asset.meta['num_parts'] = tot + asset.meta['vertex_labels'] = vertex_labels + asset.meta['face_labels'] = face_labels + return {} + +def get_geodesic_distance( + vertices: ndarray, # (N, 3) + joints: ndarray, # (J, 3) + tails: ndarray, # (J, 3) + dis_matrix: ndarray, # (J, J) + step_matrix: ndarray, # (J, J) + child: ndarray, + eps: float=1e-4, + soft_mask: bool=False, +) -> Tuple[ndarray, ndarray]: + # (J, 3) + offset = tails - joints + inv = (1./(offset * offset + eps).sum(axis=-1))[np.newaxis, ...] + # head + g0 = tails[np.newaxis, ...] - vertices[:, np.newaxis, :] + c0 = (g0 * offset[np.newaxis, ...]).sum(axis=-1) * inv + # tail + g1 = vertices[:, np.newaxis, :] - joints[np.newaxis, ...] + c1 = (g1 * offset[np.newaxis, ...]).sum(axis=-1) * inv + # (N, J) + scale0 = (np.clip(c0, 0., 1.) + eps) / (np.clip(c0, 0., 1.) + np.clip(c1, 0., 1.) + eps * 2) + scale1 = -scale0 + 1 + # (N, J, 3) + nearest = scale0[..., np.newaxis] * joints[np.newaxis, ...] + scale1[..., np.newaxis] * tails[np.newaxis, ...] + # (N, J) + dis = np.linalg.norm(vertices[:, np.newaxis, :] - nearest, axis=-1) + # (N) + index = np.argmin(dis, axis=1) + # (N) + r = np.arange(dis.shape[0]) + # (N, J) + res = ( + dis_matrix[index] * scale0[r[:, np.newaxis], index[:, np.newaxis]] + + dis_matrix[child[index]] * scale1[r[:, np.newaxis], index[:, np.newaxis]] + ) + if soft_mask: + mask = (1.0 - ( + step_matrix[index] * scale0[r[:, np.newaxis], index[:, np.newaxis]] + + step_matrix[child[index]] * scale1[r[:, np.newaxis], index[:, np.newaxis]] + )).clip(0., 1.).astype(np.float32) + else: + mask = (( + step_matrix[index] * scale0[r[:, np.newaxis], index[:, np.newaxis]] + + step_matrix[child[index]] * scale1[r[:, np.newaxis], index[:, np.newaxis]] + ) <= 1.).astype(np.float32) + + # normalize geo dis + row_min = np.min(res, axis=0, keepdims=True) + row_max = np.max(res, axis=0, keepdims=True) + res = (res - row_min) / (row_max - row_min) + res = np.nan_to_num(res, nan=0., posinf=0., neginf=0.)
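+ # min/max above are taken over the vertex axis, so each joint's distance column is normalized to [0, 1] independently; a constant column would divide by zero and is zeroed by nan_to_num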
+ return res, mask + +def get_vertex_groups(config: VertexGroupConfig) -> List[VertexGroup]: + vertex_groups = [] + MAP = { + 'geodesic_distance': VertexGroupGeodesicDistance, + 'skin': VertexGroupSkin, + 'voxel_skin': VertexGroupVoxelSkin, + 'mesh_part_distance': VertexGroupMeshPartDistance, + 'mesh_parts': VertexGroupMeshParts, + } + for name in config.names: + assert name in MAP, f"expect: [{','.join(MAP.keys())}], found: {name}" + vertex_groups.append(MAP[name](**config.kwargs.get(name, {}))) + return vertex_groups + +def voxelization( + vertices: ndarray, + faces: ndarray, + grid: int=256, + scale: float=1.0, +): + import pyrender + znear = 0.05 + zfar = 4.0 + eye_dis = 2.0 # distance from eye to origin + r_faces = np.stack([faces[:, 0], faces[:, 2], faces[:, 1]], axis=-1) + # get zbuffers + mesh = pyrender.Mesh( + primitives=[ + pyrender.Primitive( + positions=vertices, + indices=np.concatenate([faces, r_faces]), # double sided + mode=pyrender.GLTF.TRIANGLES, + ) + ] + ) + scene = pyrender.Scene(bg_color=[0, 0, 0, 0]) + scene.add(mesh) + + camera = pyrender.OrthographicCamera(xmag=scale, ymag=scale, znear=znear, zfar=zfar) + camera_poses = {} + # coordinate: + # see https://pyrender.readthedocs.io/en/latest/examples/cameras.html + camera_poses['+z'] = np.array([ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, eye_dis], + [0, 0, 0, 1], + ], dtype=np.float32) # look at +z (bottom to top) + camera_poses['-z'] = np.array([ + [-1, 0, 0, 0], + [ 0, 1, 0, 0], + [ 0, 0,-1, -eye_dis], + [ 0, 0, 0, 1], + ], dtype=np.float32) # look at -z (top to bottom) + camera_poses['+y'] = np.array([ + [1, 0, 0, 0], + [0, 0,-1, -eye_dis], + [0, 1, 0, 0], + [0, 0, 0, 1], + ], dtype=np.float32) # look at +y (because model is looking at -y)(front to back) + camera_poses['-y'] = np.array([ + [1, 0, 0, 0], + [0, 0, 1, eye_dis], + [0,-1, 0, 0], + [0, 0, 0, 1], + ], dtype=np.float32) # look at -y (back to front) + camera_poses['+x'] = np.array([ + [0, 0,-1, -eye_dis], + [0, 1, 0, 0], + [1, 0, 0, 0], + [0, 0, 0, 1], + ], dtype=np.float32) # look at +x (left to right) + camera_poses['-x'] = np.array([ + [ 0, 0, 1, eye_dis], + [ 0, 1, 0, 0], + [-1, 0, 0, 0], + [ 0, 0, 0, 1], + ], dtype=np.float32) # look at -x (righy to left) + for name, pose in camera_poses.items(): + scene.add(camera, name=name, pose=pose) + camera_nodes = [node for node in scene.get_nodes() if isinstance(node, pyrender.Node) and node.camera is not None] + renderer = pyrender.OffscreenRenderer(viewport_width=grid, viewport_height=grid) + + i, j, k = np.indices((grid, grid, grid)) + grid_indices = np.stack((i.ravel(), j.ravel(), k.ravel()), axis=1, dtype=np.int64) + grid_coords = np.stack((i.ravel(), j.ravel(), grid-1-k.ravel()), axis=1, dtype=np.float32) * 2 / grid - 1.0 + 1.0 / grid # every position is in the middle of the grid + depths = {} + for cam_node in camera_nodes: + # a = time.time() + scene.main_camera_node = cam_node + name = cam_node.name + proj_depth = renderer.render(scene, flags=pyrender.constants.RenderFlags.DEPTH_ONLY | pyrender.constants.RenderFlags.OFFSCREEN) + proj_depth[proj_depth Tuple[int, ndarray]: + ''' + Find connected components of a mesh. 
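+ Two vertices belong to the same component when a chain of shared face edges links them; the adjacency is built as a sparse matrix and handed to scipy's connected_components.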
+ + Returns: + int: number of connected components + ndarray: labels of connected components + ''' + N = vertices.shape[0] + edges = [] + for face in faces: + v0, v1, v2 = face + edges.append([v0, v1]) + edges.append([v1, v2]) + edges.append([v2, v0]) + + edges = np.array(edges) + row = edges[:, 0] + col = edges[:, 1] + data = np.ones(len(edges), dtype=int) + adj_matrix = csr_matrix((data, (row, col)), shape=(N, N)) + adj_matrix = adj_matrix + adj_matrix.T + + tot, vertex_labels = connected_components(adj_matrix, directed=False, return_labels=True) + face_labels = vertex_labels[faces[:, 0]] + return tot, vertex_labels, face_labels + +def compute_distances_in_components(vertices: ndarray, faces: ndarray, vertex_labels: ndarray, tot: int, k: int) -> ndarray: + N = vertices.shape[0] + edges = [] + weights = [] + for face in faces: + v0, v1, v2 = face + w01 = np.linalg.norm(vertices[v0] - vertices[v1]) + w12 = np.linalg.norm(vertices[v1] - vertices[v2]) + w20 = np.linalg.norm(vertices[v2] - vertices[v0]) + edges.extend([[v0, v1], [v1, v2], [v2, v0]]) + weights.extend([w01, w12, w20]) + + edges = np.array(edges) + weights = np.array(weights) + row = edges[:, 0] + col = edges[:, 1] + adj_matrix = csr_matrix((weights, (row, col)), shape=(N, N)) + adj_matrix = adj_matrix + adj_matrix.T + + distance_matrix = np.full((N, k), np.inf) # (N, k) + + for component_id in range(tot): + component_mask = (vertex_labels == component_id) + component_vertices_idx = np.where(component_mask)[0] + n_component = len(component_vertices_idx) + + if n_component == 0: + continue + + if n_component >= k: + sampled_indices = np.random.permutation(n_component)[:k] + else: + sampled_indices = np.concatenate([ + np.random.permutation(n_component), + np.random.randint(0, n_component, k - n_component) + ]) + sampled_vertices = component_vertices_idx[sampled_indices] + + dist_matrix = shortest_path(adj_matrix, indices=sampled_vertices, directed=False) + dist_matrix = dist_matrix[:, component_mask].T + # normalize into [0, 1] + max_value = dist_matrix.max() + min_value = dist_matrix.min() + if max_value < min_value + 1e-6: + dist_matrix[...] = 0. + else: + dist_matrix = (dist_matrix - min_value) / (max_value - min_value) + + distance_matrix[component_mask, :] = dist_matrix + + return distance_matrix + +def generate_spread_vectors(tot: int, dim: int, iterations: int=100, lr: float=1.0) -> ndarray: + if tot <= 0: + return None + + vectors = np.random.randn(tot, dim) + vectors = vectors / np.linalg.norm(vectors, axis=1, keepdims=True) + vectors = np.nan_to_num(vectors, nan=1.0, posinf=1.0, neginf=1.0) + + for _ in range(iterations): + diff = vectors[np.newaxis, :, :] - vectors[:, np.newaxis, :] + norm_sq = np.sum(diff ** 2, axis=2) + weight = 1. / (norm_sq + 1.) 
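+ # closer pairs receive a larger weight (1 / (squared distance + 1)) and therefore dominate the update below; the vectors are re-normalized onto the unit sphere after every iteration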
+ vectors += np.sum(diff * weight[:, :, np.newaxis] * lr, axis=1) + vectors = vectors / np.linalg.norm(vectors, axis=1, keepdims=True) + + return vectors diff --git a/UniRig/src/inference/download.py b/UniRig/src/inference/download.py new file mode 100644 index 0000000000000000000000000000000000000000..febd2bbe62edad62de66f09686605574623fb623 --- /dev/null +++ b/UniRig/src/inference/download.py @@ -0,0 +1,19 @@ +from huggingface_hub import hf_hub_download + +def download(ckpt_name: str) -> str: + MAP = { + 'experiments/skeleton/articulation-xl_quantization_256/model.ckpt': 'skeleton/articulation-xl_quantization_256/model.ckpt', + 'experiments/skin/articulation-xl/model.ckpt': 'skin/articulation-xl/model.ckpt', + } + + try: + if ckpt_name not in MAP: + print(f"not found: {ckpt_name}") + return ckpt_name + return hf_hub_download( + repo_id='VAST-AI/UniRig', + filename=MAP[ckpt_name], + ) + except Exception as e: + print(f"Failed to download {ckpt_name}: {e}") + return ckpt_name \ No newline at end of file diff --git a/UniRig/src/inference/get_list.py b/UniRig/src/inference/get_list.py new file mode 100644 index 0000000000000000000000000000000000000000..3286e91a6ae6da60085e4326e1313ef6bd09c37d --- /dev/null +++ b/UniRig/src/inference/get_list.py @@ -0,0 +1,25 @@ +import os +import argparse +from tqdm import tqdm +from box import Box +import yaml + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--config', type=str) + args = parser.parse_args() + + config = Box(yaml.safe_load(open(args.config, "r"))) + + dataset = config.output_dataset_dir + + paths = [] + for root, dirs, files in tqdm(os.walk(dataset)): + for file in files: + if file == 'raw_data.npz': + paths.append(os.path.relpath(root, dataset)) + tot = len(paths) + os.makedirs(dataset, exist_ok=True) + f = open(os.path.join(dataset, f"inference_datalist.txt"), 'w') + f.writelines('\n'.join(paths)) + f.close() \ No newline at end of file diff --git a/UniRig/src/inference/merge.py b/UniRig/src/inference/merge.py new file mode 100644 index 0000000000000000000000000000000000000000..0b3613cc29e4e246563e9408cc6ce2d1fed614a5 --- /dev/null +++ b/UniRig/src/inference/merge.py @@ -0,0 +1,527 @@ +''' +inject the result in res.npz into model.vrm and exports as res_textured.vrm +''' +import argparse +import yaml +import os +import numpy as np +from numpy import ndarray + +from typing import Tuple, Union, List + +import argparse +from tqdm import tqdm +from box import Box + +from scipy.spatial import cKDTree + +import open3d as o3d +import itertools + +import bpy +from mathutils import Vector + +from ..data.raw_data import RawData, RawSkin +from ..data.extract import process_mesh, process_armature, get_arranged_bones + +def parser(): + parser = argparse.ArgumentParser() + parser.add_argument('--config', type=str) + parser.add_argument('--num_runs', type=int) + parser.add_argument('--id', type=int) + return parser.parse_args() + +def clean_bpy(): + for c in bpy.data.actions: + bpy.data.actions.remove(c) + for c in bpy.data.armatures: + bpy.data.armatures.remove(c) + for c in bpy.data.cameras: + bpy.data.cameras.remove(c) + for c in bpy.data.collections: + bpy.data.collections.remove(c) + for c in bpy.data.images: + bpy.data.images.remove(c) + for c in bpy.data.materials: + bpy.data.materials.remove(c) + for c in bpy.data.meshes: + bpy.data.meshes.remove(c) + for c in bpy.data.objects: + bpy.data.objects.remove(c) + for c in bpy.data.textures: + bpy.data.textures.remove(c) + +def load(filepath: str, 
return_armature: bool=False): + if return_armature: + old_objs = set(bpy.context.scene.objects) + + if not os.path.exists(filepath): + raise ValueError(f'File {filepath} does not exist !') + try: + if filepath.endswith(".vrm"): + # enable vrm addon and load vrm model + bpy.ops.preferences.addon_enable(module='vrm') + + bpy.ops.import_scene.vrm( + filepath=filepath, + use_addon_preferences=True, + extract_textures_into_folder=False, + make_new_texture_folder=False, + set_shading_type_to_material_on_import=False, + set_view_transform_to_standard_on_import=True, + set_armature_display_to_wire=True, + set_armature_display_to_show_in_front=True, + set_armature_bone_shape_to_default=True, + disable_bake=True, # customized option for better performance + ) + elif filepath.endswith(".fbx") or filepath.endswith(".FBX"): + bpy.ops.import_scene.fbx(filepath=filepath, ignore_leaf_bones=False, use_image_search=False) + elif filepath.endswith(".glb") or filepath.endswith(".gltf"): + bpy.ops.import_scene.gltf(filepath=filepath, import_pack_images=False) + elif filepath.endswith(".dae"): + bpy.ops.wm.collada_import(filepath=filepath) + elif filepath.endswith(".blend"): + with bpy.data.libraries.load(filepath) as (data_from, data_to): + data_to.objects = data_from.objects + for obj in data_to.objects: + if obj is not None: + bpy.context.collection.objects.link(obj) + else: + raise ValueError(f"not suported type {filepath}") + except: + raise ValueError(f"failed to load {filepath}") + if return_armature: + armature = [x for x in set(bpy.context.scene.objects)-old_objs if x.type=="ARMATURE"] + if len(armature)==0: + return None + if len(armature)>1: + raise ValueError(f"multiple armatures found") + armature = armature[0] + + armature.select_set(True) + bpy.context.view_layer.objects.active = armature + bpy.ops.object.mode_set(mode='EDIT') + for bone in bpy.data.armatures[0].edit_bones: + bone.roll = 0. # change all roll to 0. 
to prevent weird behaviour + + bpy.ops.object.mode_set(mode='OBJECT') + armature.select_set(False) + + bpy.ops.object.select_all(action='DESELECT') + return armature + +def get_skin(arranged_bones): + meshes = [] + for v in bpy.data.objects: + if v.type == 'MESH': + meshes.append(v) + index = {} + for (id, pbone) in enumerate(arranged_bones): + index[pbone.name] = id + _dict_skin = {} + total_bones = len(arranged_bones) + for obj in meshes: + total_vertices = len(obj.data.vertices) + skin_weight = np.zeros((total_vertices, total_bones)) + obj_group_names = [g.name for g in obj.vertex_groups] + obj_verts = obj.data.vertices + for bone in arranged_bones: + if bone.name not in obj_group_names: + continue + + gidx = obj.vertex_groups[bone.name].index + bone_verts = [v for v in obj_verts if gidx in [g.group for g in v.groups]] + for v in bone_verts: + which = [id for id in range(len(v.groups)) if v.groups[id].group==gidx] + w = v.groups[which[0]].weight + skin_weight[v.index, index[bone.name]] = w + _dict_skin[obj.name] = { + 'skin': skin_weight, + } + + skin = np.concatenate([ + _dict_skin[d]['skin'] for d in _dict_skin + ], axis=0) + return skin + +def axis(a: np.ndarray): + b = np.concatenate([-a[:, 0:1], -a[:, 1:2], a[:, 2:3]], axis=1) + return b + +def get_correct_orientation_kdtree(a: np.ndarray, b: np.ndarray, bones: np.ndarray, num: int=16384) -> np.ndarray: + ''' + a: sampled_vertiecs + b: mesh_vertices + ''' + min_loss = float('inf') + best_transformed = a.copy() + axis_permutations = list(itertools.permutations([0, 1, 2])) + sign_combinations = [(x, y, z) for x in [1, -1] + for y in [1, -1] + for z in [1, -1]] + _bones = bones.copy() + for perm in axis_permutations: + permuted_a = a[np.random.permutation(a.shape[0])[:num]][:, perm] + for signs in sign_combinations: + transformed = permuted_a * np.array(signs) + tree = cKDTree(transformed) + distances, indices = tree.query(b) + current_loss = distances.mean() + if current_loss < min_loss: # prevent from mirroring + min_loss = current_loss + best_transformed = a[:, perm] * np.array(signs) + bones[:, :3] = _bones[:, :3][:, perm] * np.array(signs) + bones[:, 3:] = _bones[:, 3:][:, perm] * np.array(signs) + + return best_transformed, bones + +def denormalize_vertices(mesh_vertices: ndarray, vertices: ndarray, bones: ndarray) -> np.ndarray: + min_vals = np.min(mesh_vertices, axis=0) + max_vals = np.max(mesh_vertices, axis=0) + center = (min_vals + max_vals) / 2 + scale = np.max(max_vals - min_vals) / 2 + denormalized_vertices = vertices * scale + center + denormalized_bones = bones * scale + denormalized_bones[:, :3] += center + denormalized_bones[:, 3:] += center + + return denormalized_vertices, denormalized_bones + +def make_armature( + vertices: ndarray, + bones: ndarray, # (joint, tail) + parents: list[Union[int, None]], + names: list[str], + skin: ndarray, + group_per_vertex: int=4, + add_root: bool=False, + is_vrm: bool=False, +): + context = bpy.context + + mesh_vertices = [] + for ob in bpy.data.objects: + if ob.type != 'MESH': + continue + m = np.array(ob.matrix_world) + matrix_world_rot = m[:3, :3] + matrix_world_bias = m[:3, 3] + for v in ob.data.vertices: + mesh_vertices.append(matrix_world_rot @ np.array(v.co) + matrix_world_bias) + + mesh_vertices = np.stack(mesh_vertices) + vertices, bones = denormalize_vertices(mesh_vertices, vertices, bones) + + bpy.ops.object.add(type="ARMATURE", location=(0, 0, 0)) + armature = context.object + if hasattr(armature.data, 'vrm_addon_extension'): + 
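+ # the VRM Blender add-on attaches vrm_addon_extension to armature data; its presence marks the asset as a VRM humanoid and enables the bone mapping performed at the end of this function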
armature.data.vrm_addon_extension.spec_version = "1.0" + humanoid = armature.data.vrm_addon_extension.vrm1.humanoid + is_vrm = True + bpy.ops.object.mode_set(mode="EDIT") + edit_bones = armature.data.edit_bones + if add_root: + bone_root = edit_bones.new('Root') + bone_root.name = 'Root' + bone_root.head = (0., 0., 0.) + bone_root.tail = (bones[0, 0], bones[0, 1], bones[0, 2]) + + J = len(names) + def extrude_bone( + name: Union[None, str], + parent_name: Union[None, str], + head: Tuple[float, float, float], + tail: Tuple[float, float, float], + ): + bone = edit_bones.new(name) + bone.head = (head[0], head[1], head[2]) + bone.tail = (tail[0], tail[1], tail[2]) + bone.name = name + if parent_name is None: + return + parent_bone = edit_bones.get(parent_name) + bone.parent = parent_bone + bone.use_connect = False # always False currently + + vertices, bones = get_correct_orientation_kdtree(vertices, mesh_vertices, bones) + for i in range(J): + if add_root: + pname = 'Root' if parents[i] is None else names[parents[i]] + else: + pname = None if parents[i] is None else names[parents[i]] + extrude_bone(names[i], pname, bones[i, :3], bones[i, 3:]) + + # must set to object mode to enable parent_set + bpy.ops.object.mode_set(mode='OBJECT') + objects = bpy.data.objects + for o in bpy.context.selected_objects: + o.select_set(False) + + argsorted = np.argsort(-skin, axis=1) + vertex_group_reweight = skin[np.arange(skin.shape[0])[..., None], argsorted] + vertex_group_reweight = vertex_group_reweight / vertex_group_reweight[..., :group_per_vertex].sum(axis=1)[...,None] + vertex_group_reweight = np.nan_to_num(vertex_group_reweight) + tree = cKDTree(vertices) + for ob in objects: + if ob.type != 'MESH': + continue + ob.select_set(True) + armature.select_set(True) + bpy.ops.object.parent_set(type='ARMATURE_NAME') + vis = [] + for x in ob.vertex_groups: + vis.append(x.name) + + n_vertices = [] + m = np.array(ob.matrix_world) + matrix_world_rot = m[:3, :3] + matrix_world_bias = m[:3, 3] + for v in ob.data.vertices: + n_vertices.append(matrix_world_rot @ np.array(v.co) + matrix_world_bias) + n_vertices = np.stack(n_vertices) + + _, index = tree.query(n_vertices) + + for v, co in enumerate(tqdm(n_vertices)): + for ii in range(group_per_vertex): + i = argsorted[index[v], ii] + if i >= len(names): + continue + n = names[i] + if n not in ob.vertex_groups: + continue + + ob.vertex_groups[n].add([v], vertex_group_reweight[index[v], ii], 'REPLACE') + armature.select_set(False) + ob.select_set(False) + + # set vrm bones link + if is_vrm: + armature.data.vrm_addon_extension.spec_version = "1.0" + humanoid.human_bones.hips.node.bone_name = "J_Bip_C_Hips" + humanoid.human_bones.spine.node.bone_name = "J_Bip_C_Spine" + + humanoid.human_bones.chest.node.bone_name = "J_Bip_C_Chest" + humanoid.human_bones.neck.node.bone_name = "J_Bip_C_Neck" + humanoid.human_bones.head.node.bone_name = "J_Bip_C_Head" + humanoid.human_bones.left_upper_leg.node.bone_name = "J_Bip_L_UpperLeg" + humanoid.human_bones.left_lower_leg.node.bone_name = "J_Bip_L_LowerLeg" + humanoid.human_bones.left_foot.node.bone_name = "J_Bip_L_Foot" + humanoid.human_bones.right_upper_leg.node.bone_name = "J_Bip_R_UpperLeg" + humanoid.human_bones.right_lower_leg.node.bone_name = "J_Bip_R_LowerLeg" + humanoid.human_bones.right_foot.node.bone_name = "J_Bip_R_Foot" + humanoid.human_bones.left_upper_arm.node.bone_name = "J_Bip_L_UpperArm" + humanoid.human_bones.left_lower_arm.node.bone_name = "J_Bip_L_LowerArm" + humanoid.human_bones.left_hand.node.bone_name = 
"J_Bip_L_Hand" + humanoid.human_bones.right_upper_arm.node.bone_name = "J_Bip_R_UpperArm" + humanoid.human_bones.right_lower_arm.node.bone_name = "J_Bip_R_LowerArm" + humanoid.human_bones.right_hand.node.bone_name = "J_Bip_R_Hand" + + bpy.ops.vrm.assign_vrm1_humanoid_human_bones_automatically(armature_name="Armature") + +def merge( + path: str, + output_path: str, + vertices: ndarray, + joints: ndarray, + skin: ndarray, + parents: List[Union[None, int]], + names: List[str], + tails: ndarray, + add_root: bool=False, + is_vrm: bool=False, +): + ''' + Merge skin and bone into original file. + ''' + clean_bpy() + try: + load(path) + except Exception as e: + print(f"Failed to load {path}: {e}") + return + for c in bpy.data.armatures: + bpy.data.armatures.remove(c) + + bones = np.concatenate([joints, tails], axis=1) + # if the result is weired, orientation may be wrong + make_armature( + vertices=vertices, + bones=bones, + parents=parents, + names=names, + skin=skin, + group_per_vertex=4, + add_root=add_root, + is_vrm=is_vrm, + ) + + dirpath = os.path.dirname(output_path) + if dirpath != '': + os.makedirs(dirpath, exist_ok=True) + try: + if is_vrm: + bpy.ops.export_scene.vrm(filepath=output_path) + elif output_path.endswith(".fbx") or output_path.endswith(".FBX"): + bpy.ops.export_scene.fbx(filepath=output_path, add_leaf_bones=True) + elif output_path.endswith(".glb") or output_path.endswith(".gltf"): + bpy.ops.export_scene.gltf(filepath=output_path) + elif output_path.endswith(".dae"): + bpy.ops.wm.collada_export(filepath=output_path) + elif output_path.endswith(".blend"): + with bpy.data.libraries.load(output_path) as (data_from, data_to): + data_to.objects = data_from.objects + else: + raise ValueError(f"not suported type {output_path}") + except: + raise ValueError(f"failed to export {output_path}") + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + +def nullable_string(val): + if not val: + return None + return val + +def parse(): + parser = argparse.ArgumentParser() + parser.add_argument('--require_suffix', type=str, required=True) + parser.add_argument('--num_runs', type=int, required=True) + parser.add_argument('--id', type=int, required=True) + parser.add_argument('--data_config', type=str, required=False) + parser.add_argument('--skeleton_config', type=str, required=False) + parser.add_argument('--skin_config', type=str, required=False) + parser.add_argument('--merge_dir', type=str, required=False) + parser.add_argument('--merge_name', type=str, required=False) + parser.add_argument('--add_root', type=str2bool, required=False, default=False) + parser.add_argument('--source', type=nullable_string, required=False, default=None) + parser.add_argument('--target', type=nullable_string, required=False, default=None) + parser.add_argument('--output', type=nullable_string, required=False, default=None) + return parser.parse_args() + +def transfer(source: str, target: str, output: str, add_root: bool=False): + try: + armature = load(filepath=source, return_armature=True) + assert armature is not None + except Exception as e: + print(f"failed to load {source}") + return + vertices, faces = process_mesh() + arranged_bones = get_arranged_bones(armature) + skin = get_skin(arranged_bones) + joints, tails, parents, names, matrix_local = process_armature(armature, arranged_bones) + merge( + 
path=target, + output_path=output, + vertices=vertices, + joints=joints, + skin=skin, + parents=parents, + names=names, + tails=tails, + add_root=add_root, + ) + +if __name__ == "__main__": + args = parse() + + if args.source is not None or args.target is not None: + assert args.source is not None and args.target is not None + transfer(args.source, args.target, args.output, args.add_root) + exit() + + data_config = Box(yaml.safe_load(open(args.data_config, "r"))) + skeleton_config = Box(yaml.safe_load(open(args.skeleton_config, "r"))) + skin_config = Box(yaml.safe_load(open(args.skin_config, "r"))) + + num_runs = args.num_runs + id = args.id + require_suffix = args.require_suffix.split(',') + merge_dir = args.merge_dir + merge_name = args.merge_name + add_root = args.add_root + + input_dataset_dir = data_config.input_dataset_dir + dataset_name = data_config.output_dataset_dir + + skin_output_dataset_dir = skin_config.writer.output_dir + skin_name = skin_config.writer.export_npz + + skeleton_output_dataset_dir = skeleton_config.writer.output_dir + skeleton_name = skeleton_config.writer.export_npz + + def make_path(output_dataset_dir, dataset_name, root, file_name): + if output_dataset_dir is None: + return os.path.join( + dataset_name, + os.path.relpath(root, input_dataset_dir), + file_name, + ) + return os.path.join( + output_dataset_dir, + dataset_name, + os.path.relpath(root, input_dataset_dir), + file_name, + ) + + files = [] + for root, dirs, f in os.walk(input_dataset_dir): + for file in f: + if file.split('.')[-1] in require_suffix: + file_name = file.removeprefix("./") + suffix = file.split('.')[-1] + # remove suffix + file_name = '.'.join(file_name.split('.')[:-1]) + + skin_path = make_path(skin_output_dataset_dir, dataset_name, root, os.path.join(file_name, skin_name+'.npz')) + skeleton_path = make_path(skeleton_output_dataset_dir, dataset_name, root, os.path.join(file_name, skeleton_name+'.npz')) + merge_path = make_path(merge_dir, dataset_name, root, os.path.join(file_name, merge_name+"."+suffix)) + + # check if inference result exists + if os.path.exists(skin_path) and os.path.exists(skeleton_path): + files.append((os.path.join(root, file), skin_path, skeleton_path, merge_path)) + + num_files = len(files) + print("num_files", num_files) + gap = num_files // num_runs + start = gap * id + end = gap * (id + 1) + if id+1==num_runs: + end = num_files + + files = sorted(files) + if end!=-1: + files = files[:end] + tot = 0 + for file in tqdm(files[start:]): + origin_file = file[0] + skin_path = file[1] + skeleton_path = file[2] + merge_file = file[3] + + raw_skin = RawSkin.load(path=skin_path) + raw_data = RawData.load(path=skeleton_path) + + try: + merge( + path=origin_file, + output_path=merge_file, + vertices=raw_skin.vertices, + joints=raw_skin.joints, + skin=raw_skin.skin, + parents=raw_data.parents, + names=raw_data.names, + tails=raw_data.tails, + add_root=add_root, + is_vrm=(raw_data.cls=='vroid'), + ) + except Exception as e: + print(f"failed to merge {origin_file}: {e}") \ No newline at end of file diff --git a/UniRig/src/model/__init__.py b/UniRig/src/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/UniRig/src/model/michelangelo/LICENSE b/UniRig/src/model/michelangelo/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..e72bfddabc15be5718a7cc061ac10e47741d8219 --- /dev/null +++ b/UniRig/src/model/michelangelo/LICENSE @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC 
LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. 
+ + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. 
However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. 
+ + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. 
+ + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. 
Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. 
+ + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. 
For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. 
You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. 
+ + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + Copyright (C) + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". 
+ + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +. \ No newline at end of file diff --git a/UniRig/src/model/michelangelo/__init__.py b/UniRig/src/model/michelangelo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..16d324c6fbb109889666d97ec8aaf96dffbb802b --- /dev/null +++ b/UniRig/src/model/michelangelo/__init__.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# +# This file is part of UniRig. +# +# This file is derived from https://github.com/NeuralCarver/Michelangelo +# +# Copyright (c) https://github.com/NeuralCarver/Michelangelo original authors +# Copyright (c) 2025 VAST-AI-Research and contributors. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . \ No newline at end of file diff --git a/UniRig/src/model/michelangelo/get_model.py b/UniRig/src/model/michelangelo/get_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5b4c5197b2d438c32a2b487ea9b87303bc2dd0b6 --- /dev/null +++ b/UniRig/src/model/michelangelo/get_model.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# +# This file is part of UniRig. +# +# This file is derived from https://github.com/NeuralCarver/Michelangelo +# +# Copyright (c) https://github.com/NeuralCarver/Michelangelo original authors +# Copyright (c) 2025 VAST-AI-Research and contributors. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
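+
+# Usage sketch (illustrative; the hyper-parameters below are assumptions for
+# demonstration, not the shipped UniRig configuration). get_encoder builds an
+# AlignedShapeLatentPerceiver, optionally loads a pretrained state dict, and
+# with freeze_decoder=True freezes the decoder-side modules (geo_decoder,
+# encoder.query, pre_kl, post_kl, transformer) so that only the point-cloud
+# encoder keeps training:
+#
+#   encoder = get_encoder(
+#       pretrained_path=None,               # or a path to a compatible checkpoint
+#       freeze_decoder=True,
+#       device=None, dtype='float32',
+#       num_latents=256, point_feats=3, embed_dim=64,
+#       width=768, heads=12,
+#       num_encoder_layers=8, num_decoder_layers=16,
+#   )
+#   # pc: [B, N, 3] points, feats: [B, N, point_feats] per-point features
+#   shape_embed, kl_embed, posterior, token_num, pre_pc = encoder.encode(pc, feats)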
+ +import torch + +from .models.tsal.sal_perceiver import AlignedShapeLatentPerceiver, ShapeAsLatentPerceiverEncoder + +def get_encoder( + pretrained_path: str=None, + freeze_decoder: bool=False, + **kwargs +) -> AlignedShapeLatentPerceiver: + model = AlignedShapeLatentPerceiver(**kwargs) + if pretrained_path is not None: + state_dict = torch.load(pretrained_path, weights_only=True) + model.load_state_dict(state_dict) + if freeze_decoder: + model.geo_decoder.requires_grad_(False) + model.encoder.query.requires_grad_(False) + model.pre_kl.requires_grad_(False) + model.post_kl.requires_grad_(False) + model.transformer.requires_grad_(False) + return model + +def get_encoder_simplified( + pretrained_path: str=None, + **kwargs +) -> ShapeAsLatentPerceiverEncoder: + model = ShapeAsLatentPerceiverEncoder(**kwargs) + if pretrained_path is not None: + state_dict = torch.load(pretrained_path, weights_only=True) + model.load_state_dict(state_dict) + return model \ No newline at end of file diff --git a/UniRig/src/model/michelangelo/models/__init__.py b/UniRig/src/model/michelangelo/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..16d324c6fbb109889666d97ec8aaf96dffbb802b --- /dev/null +++ b/UniRig/src/model/michelangelo/models/__init__.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# +# This file is part of UniRig. +# +# This file is derived from https://github.com/NeuralCarver/Michelangelo +# +# Copyright (c) https://github.com/NeuralCarver/Michelangelo original authors +# Copyright (c) 2025 VAST-AI-Research and contributors. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . \ No newline at end of file diff --git a/UniRig/src/model/michelangelo/models/modules/__init__.py b/UniRig/src/model/michelangelo/models/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..deb22f87c60542124ec634ddcecc36c9699e5207 --- /dev/null +++ b/UniRig/src/model/michelangelo/models/modules/__init__.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- +# +# This file is part of UniRig. +# +# This file is derived from https://github.com/NeuralCarver/Michelangelo +# +# Copyright (c) https://github.com/NeuralCarver/Michelangelo original authors +# Copyright (c) 2025 VAST-AI-Research and contributors. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
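+
+# Re-exports the gradient-checkpointing helper so sibling packages can simply
+# use `from ..modules import checkpoint`.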
+ +from .checkpoint import checkpoint diff --git a/UniRig/src/model/michelangelo/models/modules/checkpoint.py b/UniRig/src/model/michelangelo/models/modules/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..ab5d51f05e5c7b2818fb8cd0e207bfa2f260415f --- /dev/null +++ b/UniRig/src/model/michelangelo/models/modules/checkpoint.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- +# +# This file is part of UniRig. +# +# This file is derived from https://github.com/NeuralCarver/Michelangelo +# +# Copyright (c) https://github.com/NeuralCarver/Michelangelo original authors +# Copyright (c) 2025 VAST-AI-Research and contributors. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +""" +Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124 +""" + +import torch +from typing import Callable, Iterable, Sequence, Union +from packaging import version + + +def checkpoint( + func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]], + inputs: Sequence[torch.Tensor], + params: Iterable[torch.Tensor], + flag: bool, + use_deepspeed: bool = False +): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. 
+ :param use_deepspeed: if True, use deepspeed + """ + if flag: + if use_deepspeed: + import deepspeed + return deepspeed.checkpointing.checkpoint(func, *inputs) + + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def _get_fwd_decorator(): + if version.parse(torch.__version__) >= version.parse('2.5.0'): + return torch.amp.custom_fwd(device_type='cuda') + else: + return torch.cuda.amp.custom_fwd() + + @staticmethod + def _get_bwd_decorator(): + if version.parse(torch.__version__) >= version.parse('2.5.0'): + return torch.amp.custom_bwd(device_type='cuda') + else: + def custom_bwd(bwd): + return torch.cuda.amp.custom_bwd(bwd=bwd) + return custom_bwd + + @staticmethod + @_get_fwd_decorator() + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + @_get_bwd_decorator() + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads diff --git a/UniRig/src/model/michelangelo/models/modules/embedder.py b/UniRig/src/model/michelangelo/models/modules/embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..6fd01d55c79cb5bf3e357f2ac5bb45d465884195 --- /dev/null +++ b/UniRig/src/model/michelangelo/models/modules/embedder.py @@ -0,0 +1,233 @@ +# -*- coding: utf-8 -*- +# +# This file is part of UniRig. +# +# This file is derived from https://github.com/NeuralCarver/Michelangelo +# +# Copyright (c) https://github.com/NeuralCarver/Michelangelo original authors +# Copyright (c) 2025 VAST-AI-Research and contributors. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import numpy as np +import torch +import torch.nn as nn +import math + +VALID_EMBED_TYPES = ["identity", "fourier", "hashgrid", "sphere_harmonic", "triplane_fourier"] + + +class FourierEmbedder(nn.Module): + """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts + each feature dimension of `x[..., i]` into: + [ + sin(x[..., i]), + sin(f_1*x[..., i]), + sin(f_2*x[..., i]), + ... 
+ sin(f_N * x[..., i]), + cos(x[..., i]), + cos(f_1*x[..., i]), + cos(f_2*x[..., i]), + ... + cos(f_N * x[..., i]), + x[..., i] # only present if include_input is True. + ], here f_i is the frequency. + + Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs]. + If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...]; + Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]. + + Args: + num_freqs (int): the number of frequencies, default is 6; + logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]; + input_dim (int): the input dimension, default is 3; + include_input (bool): include the input tensor or not, default is True. + + Attributes: + frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1); + + out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1), + otherwise, it is input_dim * num_freqs * 2. + + """ + + def __init__(self, + num_freqs: int = 6, + logspace: bool = True, + input_dim: int = 3, + include_input: bool = True, + include_pi: bool = True) -> None: + + """The initialization""" + + super().__init__() + + if logspace: + frequencies = 2.0 ** torch.arange( + num_freqs, + dtype=torch.float32 + ) + else: + frequencies = torch.linspace( + 1.0, + 2.0 ** (num_freqs - 1), + num_freqs, + dtype=torch.float32 + ) + + if include_pi: + frequencies *= torch.pi + + self.register_buffer("frequencies", frequencies, persistent=False) + self.include_input = include_input + self.num_freqs = num_freqs + + self.out_dim = self.get_dims(input_dim) + + def get_dims(self, input_dim): + temp = 1 if self.include_input or self.num_freqs == 0 else 0 + out_dim = input_dim * (self.num_freqs * 2 + temp) + + return out_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ Forward process. + + Args: + x: tensor of shape [..., dim] + + Returns: + embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)] + where temp is 1 if include_input is True and 0 otherwise. 
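+
+            Worked example (illustrative): with num_freqs=2, input_dim=3 and
+            include_input=True, a tensor of shape [B, N, 3] maps to shape
+            [B, N, 3 * (2 * 2 + 1)] = [B, N, 15], laid out as the raw input
+            followed by the sine block and then the cosine block.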
+ """ + + if self.num_freqs > 0: + embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1) + if self.include_input: + return torch.cat((x, embed.sin(), embed.cos()), dim=-1) + else: + return torch.cat((embed.sin(), embed.cos()), dim=-1) + else: + return x + + +class LearnedFourierEmbedder(nn.Module): + """ following @crowsonkb "s lead with learned sinusoidal pos emb """ + """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ + + def __init__(self, in_channels, dim): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + per_channel_dim = half_dim // in_channels + self.weights = nn.Parameter(torch.randn(per_channel_dim)) + + def forward(self, x): + """ + + Args: + x (torch.FloatTensor): [..., c] + + Returns: + x (torch.FloatTensor): [..., d] + """ + + # [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d] + freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1) + fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1) + return fouriered + + +class TriplaneLearnedFourierEmbedder(nn.Module): + def __init__(self, in_channels, dim): + super().__init__() + + self.yz_plane_embedder = LearnedFourierEmbedder(in_channels, dim) + self.xz_plane_embedder = LearnedFourierEmbedder(in_channels, dim) + self.xy_plane_embedder = LearnedFourierEmbedder(in_channels, dim) + + self.out_dim = in_channels + dim + + def forward(self, x): + + yz_embed = self.yz_plane_embedder(x) + xz_embed = self.xz_plane_embedder(x) + xy_embed = self.xy_plane_embedder(x) + + embed = yz_embed + xz_embed + xy_embed + + return embed + + +def sequential_pos_embed(num_len, embed_dim): + assert embed_dim % 2 == 0 + + pos = torch.arange(num_len, dtype=torch.float32) + omega = torch.arange(embed_dim // 2, dtype=torch.float32) + omega /= embed_dim / 2. + omega = 1. / 10000 ** omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + embeddings = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + + return embeddings + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. 
+ """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].to(timesteps.dtype) * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, degree=4, + num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, + log2_hashmap_size=19, desired_resolution=None): + if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1): + return nn.Identity(), input_dim + + elif embed_type == "fourier": + embedder_obj = FourierEmbedder(num_freqs=num_freqs, input_dim=input_dim, + logspace=True, include_input=True) + return embedder_obj, embedder_obj.out_dim + + elif embed_type == "hashgrid": + raise NotImplementedError + + elif embed_type == "sphere_harmonic": + raise NotImplementedError + + else: + raise ValueError(f"{embed_type} is not valid. Currently only supprts {VALID_EMBED_TYPES}") diff --git a/UniRig/src/model/michelangelo/models/modules/transformer_blocks.py b/UniRig/src/model/michelangelo/models/modules/transformer_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..4df460c7dac0ae68dd675b5504e0bfeeee2e6026 --- /dev/null +++ b/UniRig/src/model/michelangelo/models/modules/transformer_blocks.py @@ -0,0 +1,320 @@ +# -*- coding: utf-8 -*- +# +# This file is part of UniRig. +# +# This file is derived from https://github.com/NeuralCarver/Michelangelo +# +# Copyright (c) https://github.com/NeuralCarver/Michelangelo original authors +# Copyright (c) 2025 VAST-AI-Research and contributors. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
+ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional +import os + +from .checkpoint import checkpoint + +def init_linear(l, stddev): + nn.init.normal_(l.weight, std=stddev) + if l.bias is not None: + nn.init.constant_(l.bias, 0.0) + +def flash_attention(q, k, v): + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + out = F.scaled_dot_product_attention(q, k, v) + out = out.transpose(1, 2) + # print("use flash atten 2") + + return out + +class MultiheadAttention(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_ctx: int, + width: int, + heads: int, + init_scale: float, + qkv_bias: bool, + flash: bool = False + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.heads = heads + self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype) + self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) + self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx, flash=flash) + init_linear(self.c_qkv, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x): + x = self.c_qkv(x) + x = checkpoint(self.attention, (x,), (), False) + x = self.c_proj(x) + return x + + +class QKVMultiheadAttention(nn.Module): + def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int, flash: bool = False): + super().__init__() + self.device = device + self.dtype = dtype + self.heads = heads + self.n_ctx = n_ctx + self.flash = flash + + def forward(self, qkv): + bs, n_ctx, width = qkv.shape + attn_ch = width // self.heads // 3 + scale = 1 / math.sqrt(math.sqrt(attn_ch)) + qkv = qkv.view(bs, n_ctx, self.heads, -1) + q, k, v = torch.split(qkv, attn_ch, dim=-1) + + if self.flash: + out = flash_attention(q, k, v) + out = out.reshape(out.shape[0], out.shape[1], -1) + else: + weight = torch.einsum( + "bthc,bshc->bhts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + wdtype = weight.dtype + weight = torch.softmax(weight.float(), dim=-1).type(wdtype) + out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) + + return out + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_ctx: int, + width: int, + heads: int, + init_scale: float = 1.0, + qkv_bias: bool = True, + flash: bool = False, + use_checkpoint: bool = False + ): + super().__init__() + + self.use_checkpoint = use_checkpoint + + self.attn = MultiheadAttention( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash + ) + self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) + self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) + self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype) + + def _forward(self, x: torch.Tensor): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + def forward(self, x: torch.Tensor): + return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) + + +class MultiheadCrossAttention(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + width: int, + heads: int, + init_scale: float, + qkv_bias: bool = True, + flash: bool = False, + n_data: Optional[int] = None, + data_width: Optional[int] 
= None, + ): + super().__init__() + self.n_data = n_data + self.width = width + self.heads = heads + self.data_width = width if data_width is None else data_width + self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype) + self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype) + self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) + self.attention = QKVMultiheadCrossAttention( + device=device, dtype=dtype, heads=heads, n_data=n_data, flash=flash + ) + init_linear(self.c_q, init_scale) + init_linear(self.c_kv, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x, data): + x = self.c_q(x) + data = self.c_kv(data) + x = checkpoint(self.attention, (x, data), (), False) + x = self.c_proj(x) + return x + + +class QKVMultiheadCrossAttention(nn.Module): + def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, + flash: bool = False, n_data: Optional[int] = None): + + super().__init__() + self.device = device + self.dtype = dtype + self.heads = heads + self.n_data = n_data + self.flash = flash + + def forward(self, q, kv): + _, n_ctx, _ = q.shape + bs, n_data, width = kv.shape + attn_ch = width // self.heads // 2 + scale = 1 / math.sqrt(math.sqrt(attn_ch)) + q = q.view(bs, n_ctx, self.heads, -1) + kv = kv.view(bs, n_data, self.heads, -1) + k, v = torch.split(kv, attn_ch, dim=-1) + + if self.flash: + out = flash_attention(q, k, v) + out = out.reshape(out.shape[0], out.shape[1], -1) + else: + weight = torch.einsum( + "bthc,bshc->bhts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + wdtype = weight.dtype + weight = torch.softmax(weight.float(), dim=-1).type(wdtype) + out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) + + return out + + +class ResidualCrossAttentionBlock(nn.Module): + def __init__( + self, + *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + n_data: Optional[int] = None, + width: int, + heads: int, + data_width: Optional[int] = None, + mlp_width_scale: int = 4, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False + ): + super().__init__() + + if data_width is None: + data_width = width + + self.attn = MultiheadCrossAttention( + device=device, + dtype=dtype, + n_data=n_data, + width=width, + heads=heads, + data_width=data_width, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + ) + self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) + self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype) + self.mlp = MLP(device=device, dtype=dtype, width=width, hidden_width_scale=mlp_width_scale, init_scale=init_scale) + self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor, data: torch.Tensor): + x = x + self.attn(self.ln_1(x), self.ln_2(data)) + x = x + self.mlp(self.ln_3(x)) + return x + + +class MLP(nn.Module): + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + width: int, + hidden_width_scale: int = 4, + init_scale: float): + super().__init__() + self.width = width + self.c_fc = nn.Linear(width, width * hidden_width_scale, device=device, dtype=dtype) + self.c_proj = nn.Linear(width * hidden_width_scale, width, device=device, dtype=dtype) + self.gelu = nn.GELU() + init_linear(self.c_fc, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x): + return self.c_proj(self.gelu(self.c_fc(x))) + + +class Transformer(nn.Module): + def __init__( + self, + *, + device: 
Optional[torch.device], + dtype: Optional[torch.dtype], + n_ctx: int, + width: int, + layers: int, + heads: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False, + use_checkpoint: bool = False + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_checkpoint=use_checkpoint + ) + for _ in range(layers) + ] + ) + + def forward(self, x: torch.Tensor): + for block in self.resblocks: + x = block(x) + return x diff --git a/UniRig/src/model/michelangelo/models/tsal/__init__.py b/UniRig/src/model/michelangelo/models/tsal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8d3779a628a7c29f7c0177a252a3b41ef3d6c82c --- /dev/null +++ b/UniRig/src/model/michelangelo/models/tsal/__init__.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# +# This file is part of UniRig. +# +# This file is derived from https://github.com/NeuralCarver/Michelangelo +# +# Copyright (c) https://github.com/NeuralCarver/Michelangelo original authors +# Copyright (c) 2025 VAST-AI-Research and contributors. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . diff --git a/UniRig/src/model/michelangelo/models/tsal/sal_perceiver.py b/UniRig/src/model/michelangelo/models/tsal/sal_perceiver.py new file mode 100644 index 0000000000000000000000000000000000000000..3300fd07c5207cd2fae713942ec154e5d4ff9325 --- /dev/null +++ b/UniRig/src/model/michelangelo/models/tsal/sal_perceiver.py @@ -0,0 +1,621 @@ +# -*- coding: utf-8 -*- +# +# This file is part of UniRig. +# +# This file is derived from https://github.com/NeuralCarver/Michelangelo +# +# Copyright (c) https://github.com/NeuralCarver/Michelangelo original authors +# Copyright (c) 2025 VAST-AI-Research and contributors. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
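+
+# Perceiver-style shape-as-latent models: CrossAttentionEncoder compresses a
+# point cloud into a fixed set of latent tokens (learned queries when
+# query_method=True, otherwise farthest-point-sampled points via
+# torch_cluster.fps), CrossAttentionDecoder queries those latents at 3D
+# positions for occupancy / occupancy-sdf supervision, and
+# ShapeAsLatentPerceiver / AlignedShapeLatentPerceiver wrap them with an
+# optional VAE bottleneck (pre_kl / post_kl) and a latent Transformer.
+# ShapeAsLatentPerceiverEncoder below is a simplified, encoder-only variant.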
+ +import torch +import torch.nn as nn +from typing import Optional, Union +from einops import repeat +import math +from torch_cluster import fps +import random +import time +import numpy as np + +from ..modules import checkpoint +from ..modules.embedder import FourierEmbedder +from ..modules.transformer_blocks import ( + ResidualCrossAttentionBlock, + Transformer +) + +from .tsal_base import ShapeAsLatentModule + + +class CrossAttentionEncoder(nn.Module): + + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + num_latents: int, + fourier_embedder: FourierEmbedder, + point_feats: int, + width: int, + heads: int, + layers: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False, + use_ln_post: bool = False, + use_checkpoint: bool = False, + query_method: bool = False, + use_full_input: bool = True, + token_num: int = 256, + no_query: bool=False): + + super().__init__() + + self.query_method = query_method + self.token_num = token_num + self.use_full_input = use_full_input + + self.use_checkpoint = use_checkpoint + self.num_latents = num_latents + + if no_query: + self.query = None + else: + self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02) + + self.fourier_embedder = fourier_embedder + self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width, device=device, dtype=dtype) + self.cross_attn = ResidualCrossAttentionBlock( + device=device, + dtype=dtype, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + ) + + self.self_attn = Transformer( + device=device, + dtype=dtype, + n_ctx=num_latents, + width=width, + layers=layers, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_checkpoint=False + ) + + if use_ln_post: + self.ln_post = nn.LayerNorm(width, dtype=dtype, device=device) + else: + self.ln_post = None + + def _forward(self, pc, feats): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, C] + + Returns: + + """ + if self.query_method: + token_num = self.num_latents + bs = pc.shape[0] + data = self.fourier_embedder(pc) + if feats is not None: + data = torch.cat([data, feats], dim=-1) + data = self.input_proj(data) + + query = repeat(self.query, "m c -> b m c", b=bs) + + latents = self.cross_attn(query, data) + latents = self.self_attn(latents) + + if self.ln_post is not None: + latents = self.ln_post(latents) + + pre_pc = None + else: + + if isinstance(self.token_num, int): + token_num = self.token_num + else: + token_num = random.choice(self.token_num) + + if self.training: + rng = np.random.default_rng() + else: + rng = np.random.default_rng(seed=0) + ind = rng.choice(pc.shape[1], token_num * 4, replace=token_num * 4 > pc.shape[1]) + + pre_pc = pc[:,ind,:] + pre_feats = feats[:,ind,:] + + + B, N, D = pre_pc.shape + C = pre_feats.shape[-1] + ###### fps + pos = pre_pc.view(B*N, D) + pos_feats = pre_feats.view(B*N, C) + batch = torch.arange(B).to(pc.device) + batch = torch.repeat_interleave(batch, N) + + idx = fps(pos, batch, ratio=1. 
/ 4, random_start=self.training) + + sampled_pc = pos[idx] + sampled_pc = sampled_pc.view(B, -1, 3) + + sampled_feats = pos_feats[idx] + sampled_feats = sampled_feats.view(B, -1, C) + + ###### + if self.use_full_input: + data = self.fourier_embedder(pc) + else: + data = self.fourier_embedder(pre_pc) + + if feats is not None: + if not self.use_full_input: + feats = pre_feats + data = torch.cat([data, feats], dim=-1) + data = self.input_proj(data) + + sampled_data = self.fourier_embedder(sampled_pc) + if feats is not None: + sampled_data = torch.cat([sampled_data, sampled_feats], dim=-1) + sampled_data = self.input_proj(sampled_data) + + latents = self.cross_attn(sampled_data, data) + latents = self.self_attn(latents) + + if self.ln_post is not None: + latents = self.ln_post(latents) + + pre_pc = torch.cat([pre_pc, pre_feats], dim=-1) + + return latents, pc, token_num, pre_pc + + def forward(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, C] + + Returns: + dict + """ + + return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint) + + +class CrossAttentionDecoder(nn.Module): + + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + num_latents: int, + out_channels: int, + fourier_embedder: FourierEmbedder, + width: int, + heads: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False, + use_checkpoint: bool = False, + mlp_width_scale: int = 4, + supervision_type: str = 'occupancy'): + + super().__init__() + + self.use_checkpoint = use_checkpoint + self.fourier_embedder = fourier_embedder + self.supervision_type = supervision_type + + self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype) + + self.cross_attn_decoder = ResidualCrossAttentionBlock( + device=device, + dtype=dtype, + n_data=num_latents, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + mlp_width_scale=mlp_width_scale, + ) + + self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) + self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype) + if self.supervision_type == 'occupancy-sdf': + self.output_proj_sdf = nn.Linear(width, out_channels, device=device, dtype=dtype) + + + + def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor): + if next(self.query_proj.parameters()).dtype == torch.float16: + queries = queries.half() + latents = latents.half() + # print(f"queries: {queries.dtype}, {queries.device}") + # print(f"latents: {latents.dtype}, {latents.device}"z) + queries = self.query_proj(self.fourier_embedder(queries)) + x = self.cross_attn_decoder(queries, latents) + x = self.ln_post(x) + x_1 = self.output_proj(x) + if self.supervision_type == 'occupancy-sdf': + x_2 = self.output_proj_sdf(x) + return x_1, x_2 + else: + return x_1 + + def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor): + return checkpoint(self._forward, (queries, latents), self.parameters(), self.use_checkpoint) + + +class ShapeAsLatentPerceiver(ShapeAsLatentModule): + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + num_latents: int, + point_feats: int = 0, + embed_dim: int = 0, + num_freqs: int = 8, + include_pi: bool = True, + width: int, + heads: int, + num_encoder_layers: int, + num_decoder_layers: int, + decoder_width: Optional[int] = None, + init_scale: float = 
0.25, + qkv_bias: bool = True, + flash: bool = False, + use_ln_post: bool = False, + use_checkpoint: bool = False, + supervision_type: str = 'occupancy', + query_method: bool = False, + token_num: int = 256, + grad_type: str = "numerical", + grad_interval: float = 0.005, + use_full_input: bool = True, + freeze_encoder: bool = False, + decoder_mlp_width_scale: int = 4, + residual_kl: bool = False, + ): + + super().__init__() + + self.use_checkpoint = use_checkpoint + + self.num_latents = num_latents + assert grad_type in ["numerical", "analytical"] + self.grad_type = grad_type + self.grad_interval = grad_interval + self.supervision_type = supervision_type + self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) + + init_scale = init_scale * math.sqrt(1.0 / width) + self.encoder = CrossAttentionEncoder( + device=device, + dtype=dtype, + fourier_embedder=self.fourier_embedder, + num_latents=num_latents, + point_feats=point_feats, + width=width, + heads=heads, + layers=num_encoder_layers, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_ln_post=use_ln_post, + use_checkpoint=use_checkpoint, + query_method=query_method, + use_full_input=use_full_input, + token_num=token_num + ) + + self.embed_dim = embed_dim + self.residual_kl = residual_kl + if decoder_width is None: + decoder_width = width + if embed_dim > 0: + # VAE embed + self.pre_kl = nn.Linear(width, embed_dim * 2, device=device, dtype=dtype) + self.post_kl = nn.Linear(embed_dim, decoder_width, device=device, dtype=dtype) + self.latent_shape = (num_latents, embed_dim) + if self.residual_kl: + assert self.post_kl.out_features % self.post_kl.in_features == 0 + assert self.pre_kl.in_features % self.pre_kl.out_features == 0 + else: + self.latent_shape = (num_latents, width) + + self.transformer = Transformer( + device=device, + dtype=dtype, + n_ctx=num_latents, + width=decoder_width, + layers=num_decoder_layers, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_checkpoint=use_checkpoint + ) + + # geometry decoder + self.geo_decoder = CrossAttentionDecoder( + device=device, + dtype=dtype, + fourier_embedder=self.fourier_embedder, + out_channels=1, + num_latents=num_latents, + width=decoder_width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_checkpoint=use_checkpoint, + supervision_type=supervision_type, + mlp_width_scale=decoder_mlp_width_scale + ) + + if freeze_encoder: + for p in self.encoder.parameters(): + p.requires_grad = False + for p in self.pre_kl.parameters(): + p.requires_grad = False + print("freeze encoder and pre kl") + + def forward(self, + pc: torch.FloatTensor, + feats: torch.FloatTensor, + volume_queries: torch.FloatTensor, + sample_posterior: bool = True): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, C] + volume_queries (torch.FloatTensor): [B, P, 3] + sample_posterior (bool): + + Returns: + logits (torch.FloatTensor): [B, P] + center_pos (torch.FloatTensor): [B, M, 3] + posterior (DiagonalGaussianDistribution or None). 
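+
+        Note: the body chains self.encode -> self.decode -> self.query_geometry;
+        sample_posterior is forwarded to encode and, per the usual VAE
+        convention assumed here, toggles between sampling the KL latent and
+        using the posterior mean.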
+ + """ + + latents, center_pos, posterior = self.encode(pc, feats, sample_posterior=sample_posterior) + + latents = self.decode(latents) + logits = self.query_geometry(volume_queries, latents) + + return logits, center_pos, posterior + + +class AlignedShapeLatentPerceiver(ShapeAsLatentPerceiver): + + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[str], + num_latents: int, + point_feats: int = 0, + embed_dim: int = 0, + num_freqs: int = 8, + include_pi: bool = True, + width: int, + heads: int, + num_encoder_layers: int, + num_decoder_layers: int, + decoder_width: Optional[int] = None, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False, + use_ln_post: bool = False, + use_checkpoint: bool = False, + supervision_type: str = 'occupancy', + grad_type: str = "numerical", + grad_interval: float = 0.005, + query_method: bool = False, + use_full_input: bool = True, + token_num: int = 256, + freeze_encoder: bool = False, + decoder_mlp_width_scale: int = 4, + residual_kl: bool = False, + ): + + MAP_DTYPE = { + 'float32': torch.float32, + 'float16': torch.float16, + 'bfloat16': torch.bfloat16, + } + if dtype is not None: + dtype = MAP_DTYPE[dtype] + super().__init__( + device=device, + dtype=dtype, + num_latents=1 + num_latents, + point_feats=point_feats, + embed_dim=embed_dim, + num_freqs=num_freqs, + include_pi=include_pi, + width=width, + decoder_width=decoder_width, + heads=heads, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_ln_post=use_ln_post, + use_checkpoint=use_checkpoint, + supervision_type=supervision_type, + grad_type=grad_type, + grad_interval=grad_interval, + query_method=query_method, + token_num=token_num, + use_full_input=use_full_input, + freeze_encoder=freeze_encoder, + decoder_mlp_width_scale=decoder_mlp_width_scale, + residual_kl=residual_kl, + ) + + self.width = width + + def encode(self, + pc: torch.FloatTensor, + feats: Optional[torch.FloatTensor] = None, + sample_posterior: bool = True, + only_shape: bool=False): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, c] + sample_posterior (bool): + + Returns: + shape_embed (torch.FloatTensor) + kl_embed (torch.FloatTensor): + posterior (DiagonalGaussianDistribution or None): + """ + + shape_embed, latents, token_num, pre_pc = self.encode_latents(pc, feats) + if only_shape: + return shape_embed + kl_embed, posterior = self.encode_kl_embed(latents, sample_posterior) + + return shape_embed, kl_embed, posterior, token_num, pre_pc + + def encode_latents(self, + pc: torch.FloatTensor, + feats: Optional[torch.FloatTensor] = None): + + x, _, token_num, pre_pc = self.encoder(pc, feats) + + shape_embed = x[:, 0] + # latents = x[:, 1:] + # use all tokens + latents = x + + return shape_embed, latents, token_num, pre_pc + + def forward(self, + pc: torch.FloatTensor, + feats: torch.FloatTensor, + volume_queries: torch.FloatTensor, + sample_posterior: bool = True): + raise NotImplementedError() + +##################################################### +# a simplified verstion of perceiver encoder +##################################################### + +class ShapeAsLatentPerceiverEncoder(ShapeAsLatentModule): + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[Union[torch.dtype, str]], + num_latents: int, + point_feats: int = 0, + embed_dim: int = 0, + num_freqs: int = 8, + include_pi: bool = True, + width: int, + heads: int, 
+ num_encoder_layers: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False, + use_ln_post: bool = False, + use_checkpoint: bool = False, + supervision_type: str = 'occupancy', + query_method: bool = False, + token_num: int = 256, + grad_type: str = "numerical", + grad_interval: float = 0.005, + use_full_input: bool = True, + freeze_encoder: bool = False, + residual_kl: bool = False, + ): + + super().__init__() + + + MAP_DTYPE = { + 'float32': torch.float32, + 'float16': torch.float16, + 'bfloat16': torch.bfloat16, + } + + if dtype is not None and isinstance(dtype, str): + dtype = MAP_DTYPE[dtype] + + self.use_checkpoint = use_checkpoint + + self.num_latents = num_latents + assert grad_type in ["numerical", "analytical"] + self.grad_type = grad_type + self.grad_interval = grad_interval + self.supervision_type = supervision_type + self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) + + init_scale = init_scale * math.sqrt(1.0 / width) + self.encoder = CrossAttentionEncoder( + device=device, + dtype=dtype, + fourier_embedder=self.fourier_embedder, + num_latents=num_latents, + point_feats=point_feats, + width=width, + heads=heads, + layers=num_encoder_layers, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_ln_post=use_ln_post, + use_checkpoint=use_checkpoint, + query_method=query_method, + use_full_input=use_full_input, + token_num=token_num, + no_query=True, + ) + + self.embed_dim = embed_dim + self.residual_kl = residual_kl + if freeze_encoder: + for p in self.encoder.parameters(): + p.requires_grad = False + print("freeze encoder") + self.width = width + + def encode_latents(self, + pc: torch.FloatTensor, + feats: Optional[torch.FloatTensor] = None): + + x, _, token_num, pre_pc = self.encoder(pc, feats) + + shape_embed = x[:, 0] + latents = x + + return shape_embed, latents, token_num, pre_pc + + def forward(self): + raise NotImplementedError() \ No newline at end of file diff --git a/UniRig/src/model/michelangelo/models/tsal/tsal_base.py b/UniRig/src/model/michelangelo/models/tsal/tsal_base.py new file mode 100644 index 0000000000000000000000000000000000000000..c785fe23df80a59ead9f165b7f209fe9721b95d5 --- /dev/null +++ b/UniRig/src/model/michelangelo/models/tsal/tsal_base.py @@ -0,0 +1,141 @@ +# -*- coding: utf-8 -*- +# +# This file is part of UniRig. +# +# This file is derived from https://github.com/NeuralCarver/Michelangelo +# +# Copyright (c) https://github.com/NeuralCarver/Michelangelo original authors +# Copyright (c) 2025 VAST-AI-Research and contributors. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
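+# Note: this module only defines lightweight output containers (Point2MeshOutput,
+# Latent2MeshOutput, AlignedMeshOutput) and abstract base classes; the encode /
+# decode / query_geometry hooks declared here are implemented by the concrete
+# perceiver models (e.g. AlignedShapeLatentPerceiver).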
+ +import torch.nn as nn +from typing import Tuple, List, Optional +import lightning.pytorch as pl + + +class Point2MeshOutput(object): + def __init__(self): + self.mesh_v = None + self.mesh_f = None + self.center = None + self.pc = None + + +class Latent2MeshOutput(object): + + def __init__(self): + self.mesh_v = None + self.mesh_f = None + + +class AlignedMeshOutput(object): + + def __init__(self): + self.mesh_v = None + self.mesh_f = None + self.surface = None + self.image = None + self.text: Optional[str] = None + self.shape_text_similarity: Optional[float] = None + self.shape_image_similarity: Optional[float] = None + + +class ShapeAsLatentPLModule(pl.LightningModule): + latent_shape: Tuple[int] + + def encode(self, surface, *args, **kwargs): + raise NotImplementedError + + def decode(self, z_q, *args, **kwargs): + raise NotImplementedError + + def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]: + raise NotImplementedError + + def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]: + raise NotImplementedError + + +class ShapeAsLatentModule(nn.Module): + latent_shape: Tuple[int, int] + + def __init__(self, *args, **kwargs): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + def decode(self, *args, **kwargs): + raise NotImplementedError + + def query_geometry(self, *args, **kwargs): + raise NotImplementedError + + +class AlignedShapeAsLatentPLModule(pl.LightningModule): + latent_shape: Tuple[int] + + def set_shape_model_only(self): + raise NotImplementedError + + def encode(self, surface, *args, **kwargs): + raise NotImplementedError + + def decode(self, z_q, *args, **kwargs): + raise NotImplementedError + + def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]: + raise NotImplementedError + + def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]: + raise NotImplementedError + + +class AlignedShapeAsLatentModule(nn.Module): + shape_model: ShapeAsLatentModule + latent_shape: Tuple[int, int] + + def __init__(self, *args, **kwargs): + super().__init__() + + def set_shape_model_only(self): + raise NotImplementedError + + def encode_image_embed(self, *args, **kwargs): + raise NotImplementedError + + def encode_text_embed(self, *args, **kwargs): + raise NotImplementedError + + def encode_shape_embed(self, *args, **kwargs): + raise NotImplementedError + + +class TexturedShapeAsLatentModule(nn.Module): + + def __init__(self, *args, **kwargs): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + def decode(self, *args, **kwargs): + raise NotImplementedError + + def query_geometry(self, *args, **kwargs): + raise NotImplementedError + + def query_color(self, *args, **kwargs): + raise NotImplementedError diff --git a/UniRig/src/model/parse.py b/UniRig/src/model/parse.py new file mode 100644 index 0000000000000000000000000000000000000000..be065d388725f5876d6c4ace8a47fe1aec379bb1 --- /dev/null +++ b/UniRig/src/model/parse.py @@ -0,0 +1,14 @@ +from .unirig_ar import UniRigAR +from .unirig_skin import UniRigSkin + +from .spec import ModelSpec + +def get_model(**kwargs) -> ModelSpec: + MAP = { + 'unirig_ar': UniRigAR, + 'unirig_skin': UniRigSkin, + } + __target__ = kwargs['__target__'] + del kwargs['__target__'] + assert __target__ in MAP, f"expect: [{','.join(MAP.keys())}], found: {__target__}" + return MAP[__target__](**kwargs) \ No newline at end of file diff --git a/UniRig/src/model/parse_encoder.py b/UniRig/src/model/parse_encoder.py new file mode 100644 
index 0000000000000000000000000000000000000000..8d51f0f63ced6bc924ccf1ece871e6905d04da8b --- /dev/null +++ b/UniRig/src/model/parse_encoder.py @@ -0,0 +1,28 @@ +from dataclasses import dataclass + +from .michelangelo.get_model import get_encoder as get_encoder_michelangelo +from .michelangelo.get_model import AlignedShapeLatentPerceiver +from .michelangelo.get_model import get_encoder_simplified as get_encoder_michelangelo_encoder +from .michelangelo.get_model import ShapeAsLatentPerceiverEncoder +from .pointcept.models.PTv3Object import get_encoder as get_encoder_ptv3obj +from .pointcept.models.PTv3Object import PointTransformerV3Object + +@dataclass(frozen=True) +class _MAP_MESH_ENCODER: + ptv3obj = PointTransformerV3Object + michelangelo = AlignedShapeLatentPerceiver + michelangelo_encoder = ShapeAsLatentPerceiverEncoder + +MAP_MESH_ENCODER = _MAP_MESH_ENCODER() + + +def get_mesh_encoder(**kwargs): + MAP = { + 'ptv3obj': get_encoder_ptv3obj, + 'michelangelo': get_encoder_michelangelo, + 'michelangelo_encoder': get_encoder_michelangelo_encoder, + } + __target__ = kwargs['__target__'] + del kwargs['__target__'] + assert __target__ in MAP, f"expect: [{','.join(MAP.keys())}], found: {__target__}" + return MAP[__target__](**kwargs) \ No newline at end of file diff --git a/UniRig/src/model/pointcept/LICENSE b/UniRig/src/model/pointcept/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..9293bb11a97a3da9232d22db6cc3005e8df219a7 --- /dev/null +++ b/UniRig/src/model/pointcept/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Pointcept + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/UniRig/src/model/pointcept/README.md b/UniRig/src/model/pointcept/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7585d78c428e3a0b9e593dde0c2b1a958c545d9d --- /dev/null +++ b/UniRig/src/model/pointcept/README.md @@ -0,0 +1 @@ +original repo: https://github.com/Pointcept/SAMPart3D/tree/main \ No newline at end of file diff --git a/UniRig/src/model/pointcept/__init__.py b/UniRig/src/model/pointcept/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/UniRig/src/model/pointcept/datasets/__init__.py b/UniRig/src/model/pointcept/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3e6d08bd5d43f12670958df4687791b5b073d815 --- /dev/null +++ b/UniRig/src/model/pointcept/datasets/__init__.py @@ -0,0 +1,4 @@ +from .builder import build_dataset +from .utils import point_collate_fn, collate_fn + +from .dataset_render_16views import SAMPart3DDataset16Views \ No newline at end of file diff --git a/UniRig/src/model/pointcept/datasets/builder.py b/UniRig/src/model/pointcept/datasets/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..5b01c6defba0a8afc7f835184744dc7f6356fb82 --- /dev/null +++ b/UniRig/src/model/pointcept/datasets/builder.py @@ -0,0 +1,8 @@ +from pointcept.utils.registry import Registry + +DATASETS = Registry("datasets") + + +def build_dataset(cfg): + """Build datasets.""" + return DATASETS.build(cfg) diff --git a/UniRig/src/model/pointcept/datasets/dataset_render_16views.py b/UniRig/src/model/pointcept/datasets/dataset_render_16views.py new file mode 100644 index 0000000000000000000000000000000000000000..4ed6e7ff1c929b4ad4884a76b7b89177e321afc6 --- /dev/null +++ b/UniRig/src/model/pointcept/datasets/dataset_render_16views.py @@ -0,0 +1,438 @@ +import os +os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" +from os.path import join +import glob +import numpy as np +import torch +import trimesh +import json +import cv2 +import pointops +from copy import deepcopy +from torch.utils.data import Dataset +from collections.abc import Sequence +from transformers import pipeline, SamModel +from PIL import Image + +from pointcept.utils.logger import get_root_logger +from pointcept.utils.cache import shared_dict +from .builder import DATASETS +from .transform import Compose, TRANSFORMS +from .sampart3d_util import * + + +@DATASETS.register_module() +class SAMPart3DDataset16Views(Dataset): + + def __init__( + self, + split="train", + data_root="data/scannet", + mesh_root="", + mesh_path_mapping=None, + oid="", + label="", + sample_num=15000, + pixels_per_image=256, + batch_size=90, + transform=None, + loop=1, + extent_scale=10.0 + ): + super(SAMPart3DDataset16Views, self).__init__() + + data_root = os.path.join(data_root, str(oid)) + mesh_path = os.path.join(mesh_root, f"{oid}.glb") + self.data_root = data_root + self.split = split + self.pixels_per_image = pixels_per_image + self.batch_size = batch_size + self.device = 'cuda' + self.logger = get_root_logger() + + self.extent_scale = extent_scale + + self.meta_data = json.load(open(os.path.join(data_root, "meta.json"))) + + # Load mesh and sample pointclouds + self.mesh_path = mesh_path + transform = Compose(transform) + self.load_mesh(mesh_path, transform, sample_num) + + # Prepare SAM masks and depth mapping + if self.split == "train": + + self.prepare_meta_data() + + self.loop = loop + self.data_list = self.get_data_list() + self.logger.info( 
+ "Totally {} x {} samples in {} set.".format( + len(self.data_list), self.loop, split + ) + ) + + def sample_pixel(self, masks, image_height=512, image_width=512): + masks = masks.to(self.device) + indices_batch = torch.zeros((self.batch_size*self.pixels_per_image, 3), device=self.device) + random_imgs = torch.randint(0, len(masks), (self.batch_size,), device=self.device) + for i in range(self.batch_size): + # Find the indices of the valid points in the mask + valid_indices = torch.nonzero(masks[random_imgs[i]], as_tuple=False) + # if len(valid_indices) == 0: + # continue + # Randomly sample from the valid indices + if len(valid_indices) >= self.pixels_per_image: + indices = valid_indices[torch.randint(0, len(valid_indices), (self.pixels_per_image,))] + else: + # Repeat the indices to fill up to pixels_per_image + repeat_times = self.pixels_per_image // len(valid_indices) + 1 + indices = valid_indices.repeat(repeat_times, 1)[:self.pixels_per_image] + + indices_batch[i * self.pixels_per_image : (i + 1) * self.pixels_per_image, 0] = random_imgs[i] + indices_batch[i * self.pixels_per_image : (i + 1) * self.pixels_per_image, 1:] = indices + + return indices_batch + + + def load_mesh(self, mesh_path, transform, sample_num=15000, pcd_path=None): + mesh = trimesh.load(mesh_path) + if isinstance(mesh, trimesh.Scene): + mesh = mesh.dump(concatenate=True) + coord, face_index, color = sample_surface(mesh, count=sample_num, sample_color=True) + color = color[..., :3] + face_normals = mesh.face_normals + normal = face_normals[face_index] + # self.mesh_scale, self.mesh_center_offset = cal_scale(mesh_path) + mesh_scale = self.meta_data["scaling_factor"] + mesh_center_offset = self.meta_data["mesh_offset"] + + object_org_coord = coord.copy() + rotation_matrix = np.array([ + [1, 0, 0], + [0, 0, 1], + [0, -1, 0]]) + object_org_coord = np.dot(object_org_coord, rotation_matrix) + object_org_coord = object_org_coord * mesh_scale + mesh_center_offset + + offset = torch.tensor(coord.shape[0]) + obj = dict(coord=coord, normal=normal, color=color, offset=offset, origin_coord=object_org_coord, face_index=face_index) + obj = transform(obj) + self.object_org_coord = obj["origin_coord"].clone() + self.face_index = obj["face_index"].clone().numpy() + self.pcd_inverse = obj["inverse"].clone().numpy() + # print("object_org_coord", torch.unique(self.object_org_coord, return_counts=True)) + del obj["origin_coord"], obj["face_index"], obj["inverse"] + self.object = obj + + + + def prepare_meta_data(self, data_path=None): + SAM_model = pipeline("mask-generation", model="facebook/sam-vit-huge", device=self.device) + pixel_level_keys_list = [] + scale_list = [] + group_cdf_list = [] + depth_valid_list = [] + mapping_list = [] + mapping_valid_list = [] + object_org_coord = self.object_org_coord.to(self.device).contiguous().float() + obj_offset = torch.tensor(object_org_coord.shape[0]).to(self.device) + + camera_angle_x = self.meta_data['camera_angle_x'] + for i, c2w_opengl in enumerate(self.meta_data["transforms"]): + # print(frame['index']) + c2w_opengl = np.array(c2w_opengl) + self.logger.info(f"Processing frame_{i}") + rgb_path = join(self.data_root, f"render_{i:04d}.webp") + img = np.array(Image.open(rgb_path)) + if img.shape[-1] == 4: + mask_img = img[..., 3] == 0 + img[mask_img] = [255, 255, 255, 255] + img = img[..., :3] + img = Image.fromarray(img.astype('uint8')) + + # Calculate mapping + depth_path = join(self.data_root, f"depth_{i:04d}.exr") + depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED) + depth = depth[..., 
0] + depth_valid = torch.tensor(depth < 65500.0) + org_points = gen_pcd(depth, c2w_opengl, camera_angle_x) + org_points = torch.from_numpy(org_points) + points_tensor = org_points.to(self.device).contiguous().float() + offset = torch.tensor(points_tensor.shape[0]).to(self.device) + indices, distances = pointops.knn_query(1, object_org_coord, obj_offset, points_tensor, offset) + mapping = torch.zeros((depth.shape[0], depth.shape[1]), dtype=torch.int) - 1 + + # Create a mask where distances are less than 0.03 + mask_dis = distances[..., 0] < 0.03 + indices[~mask_dis] = -1 + mapping[depth_valid] = indices.cpu().flatten() + mapping_valid = mapping != -1 + + # Calculate groups + try: + masks = SAM_model(img, points_per_side=32, pred_iou_thresh=0.9, stability_score_thresh=0.9) + masks = masks['masks'] + masks = sorted(masks, key=lambda x: x.sum()) + except: + masks = [] + + # mask filter + masks_filtered = [] + img_valid = ~mask_img + for mask in masks: + valid_ratio = mask[img_valid].sum() / img_valid.sum() + invalid_ratio = mask[mask_img].sum() / mask_img.sum() + if valid_ratio == 0 or invalid_ratio > 0.1: + continue + else: + masks_filtered.append(mask) + pixel_level_keys, scale, mask_cdf = self._calculate_3d_groups(torch.from_numpy(depth), mapping_valid, masks_filtered, points_tensor[mask_dis]) + + pixel_level_keys_list.append(pixel_level_keys) + scale_list.append(scale) + group_cdf_list.append(mask_cdf) + depth_valid_list.append(depth_valid) + mapping_list.append(mapping) + mapping_valid_list.append(mapping_valid) + + self.pixel_level_keys = torch.nested.nested_tensor( + pixel_level_keys_list + ) + self.scale_3d_statistics = torch.cat(scale_list) + self.scale_3d = torch.nested.nested_tensor(scale_list) + self.group_cdf = torch.nested.nested_tensor(group_cdf_list) + self.depth_valid = torch.stack(depth_valid_list) + self.mapping = torch.stack(mapping_list) + self.mapping_valid = torch.stack(mapping_valid_list) + + def _calculate_3d_groups( + self, + depth: torch.Tensor, + valid: torch.Tensor, + masks: torch.Tensor, + point: torch.Tensor, + max_scale: float = 2.0, + ): + """ + Calculate the set of groups and their 3D scale for each pixel, and the cdf. + Returns: + - pixel_level_keys: [H, W, max_masks] + - scale: [num_masks, 1] + - mask_cdf: [H, W, max_masks] + max_masks is the maximum number of masks that was assigned to a pixel in the image, + padded with -1s. mask_cdf does *not* include the -1s. + Refer to the main paper for more details. + """ + image_shape = depth.shape[:2] + depth_valid = valid + point = point.to(self.device) + + def helper_return_no_masks(): + # Fail gracefully when no masks are found. + # Create dummy data (all -1s), which will be ignored later. + # See: `get_loss_dict_group` in `garfield_model.py` + pixel_level_keys = torch.full( + (image_shape[0], image_shape[1], 1), -1, dtype=torch.int + ) + scale = torch.Tensor([0.0]).view(-1, 1) + mask_cdf = torch.full( + (image_shape[0], image_shape[1], 1), 1, dtype=torch.float + ) + return (pixel_level_keys, scale, mask_cdf) + + + # If no masks are found, return dummy data. 
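+        # (helper_return_no_masks() yields all -1 pixel keys and a constant CDF, so frames
+        # without usable SAM masks are effectively ignored by the downstream grouping loss.)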
+ if len(masks) == 0: + return helper_return_no_masks() + + sam_mask = [] + scale = [] + + # For all 2D groups, + # 1) Denoise the masks (through eroding) + all_masks = torch.stack( + # [torch.from_numpy(_["segmentation"]).to(self.device) for _ in masks] + [torch.from_numpy(_).to(self.device) for _ in masks] + ) + # erode all masks using 3x3 kernel + # ignore erode + eroded_masks = torch.conv2d( + all_masks.unsqueeze(1).float(), + torch.full((3, 3), 1.0).view(1, 1, 3, 3).to("cuda"), + padding=1, + ) + eroded_masks = (eroded_masks >= 5).squeeze(1) # (num_masks, H, W) + + # 2) Calculate 3D scale + # Don't include groups with scale > max_scale (likely to be too noisy to be useful) + for i in range(len(masks)): + curr_mask_org = eroded_masks[i] + curr_mask = curr_mask_org[depth_valid] + curr_points = point[curr_mask] + extent = (curr_points.std(dim=0) * self.extent_scale).norm() + if extent.item() < max_scale: + sam_mask.append(curr_mask_org) + scale.append(extent.item()) + + # If no masks are found, after postprocessing, return dummy data. + if len(sam_mask) == 0: + return helper_return_no_masks() + + sam_mask = torch.stack(sam_mask) # (num_masks, H, W) + scale = torch.Tensor(scale).view(-1, 1).to(self.device) # (num_masks, 1) + + # Calculate "pixel level keys", which is a 2D array of shape (H, W, max_masks) + # Each pixel has a list of group indices that it belongs to, in order of increasing scale. + pixel_level_keys = self.create_pixel_mask_array( + sam_mask + ).long() # (H, W, max_masks) + depth_invalid = ~depth_valid + pixel_level_keys[depth_invalid, :] = -1 + + # Calculate group sampling CDF, to bias sampling towards smaller groups + # Be careful to not include -1s in the CDF (padding, or unlabeled pixels) + # Inversely proportional to log of mask size. + mask_inds, counts = torch.unique(pixel_level_keys, return_counts=True) + counts[0] = 0 # don't include -1 + probs = counts / counts.sum() # [-1, 0, ...] + + pixel_shape = pixel_level_keys.shape + if (pixel_level_keys.max()+2) != probs.shape[0]: + pixel_level_keys_new = pixel_level_keys.reshape(-1) + unique_values, inverse_indices = torch.unique(pixel_level_keys_new, return_inverse=True) + pixel_level_keys_new = inverse_indices.reshape(-1) + else: + pixel_level_keys_new = pixel_level_keys.reshape(-1) + 1 + + mask_probs = torch.gather(probs, 0, pixel_level_keys.reshape(-1) + 1).view( + pixel_shape + ) + mask_log_probs = torch.log(mask_probs) + never_masked = mask_log_probs.isinf() + mask_log_probs[never_masked] = 0.0 + mask_log_probs = mask_log_probs / ( + mask_log_probs.sum(dim=-1, keepdim=True) + 1e-6 + ) + mask_cdf = torch.cumsum(mask_log_probs, dim=-1) + mask_cdf[never_masked] = 1.0 + + return (pixel_level_keys.cpu(), scale.cpu(), mask_cdf.cpu()) + + @staticmethod + def create_pixel_mask_array(masks: torch.Tensor): + """ + Create per-pixel data structure for grouping supervision. + pixel_mask_array[x, y] = [m1, m2, ...] means that pixel (x, y) belongs to masks m1, m2, ... + where Area(m1) < Area(m2) < ... (sorted by area). 
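+        `masks` is expected to be a boolean tensor of shape (num_masks, H, W); the returned
+        array has shape (H, W, max_masks) and is padded with -1 wherever a pixel is covered
+        by fewer than max_masks masks.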
+ """ + max_masks = masks.sum(dim=0).max().item() + # print(max_masks) + image_shape = masks.shape[1:] + pixel_mask_array = torch.full( + (max_masks, image_shape[0], image_shape[1]), -1, dtype=torch.int + ).to(masks.device) + + for m, mask in enumerate(masks): + mask_clone = mask.clone() + for i in range(max_masks): + free = pixel_mask_array[i] == -1 + masked_area = mask_clone == 1 + right_index = free & masked_area + if len(pixel_mask_array[i][right_index]) != 0: + pixel_mask_array[i][right_index] = m + mask_clone[right_index] = 0 + pixel_mask_array = pixel_mask_array.permute(1, 2, 0) + + return pixel_mask_array + + def get_data_list(self): + data_list = glob.glob(os.path.join(self.data_root, "*.exr")) + return data_list + + def get_data(self, idx): + indices = self.sample_pixel(self.mapping_valid, 512, 512).long().detach().cpu() + npximg = self.pixels_per_image + img_ind = indices[:, 0] + x_ind = indices[:, 1] + y_ind = indices[:, 2] + + # sampled_imgs = img_ind[::npximg] + mask_id = torch.zeros((indices.shape[0],), device=self.device) + scale = torch.zeros((indices.shape[0],), device=self.device) + mapping = torch.zeros((indices.shape[0],), device=self.device) + + random_vec_sampling = (torch.rand((1,)) * torch.ones((npximg,))).view(-1, 1) + random_vec_densify = (torch.rand((1,)) * torch.ones((npximg,))).view(-1, 1) + + for i in range(0, indices.shape[0], npximg): + img_idx = img_ind[i] + + # calculate mapping + mapping[i : i + npximg] = self.mapping[img_idx][x_ind[i : i + npximg], y_ind[i : i + npximg]] + + # Use `random_vec` to choose a group for each pixel. + per_pixel_index = self.pixel_level_keys[img_idx][ + x_ind[i : i + npximg], y_ind[i : i + npximg] + ] + random_index = torch.sum( + random_vec_sampling.view(-1, 1) + > self.group_cdf[img_idx][x_ind[i : i + npximg], y_ind[i : i + npximg]], + dim=-1, + ) + + # `per_pixel_index` encodes the list of groups that each pixel belongs to. + # If there's only one group, then `per_pixel_index` is a 1D tensor + # -- this will mess up the future `gather` operations. 
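+            # `random_index` picks one group per pixel by comparing a shared random value
+            # against that pixel's CDF; gathering at `random_index` (and at `random_index - 1`
+            # below) gives the selected group and the next-smaller group used for the
+            # interval scale supervision.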
+ if per_pixel_index.shape[-1] == 1: + per_pixel_mask = per_pixel_index.squeeze() + else: + per_pixel_mask = torch.gather( + per_pixel_index, 1, random_index.unsqueeze(-1) + ).squeeze() + per_pixel_mask_ = torch.gather( + per_pixel_index, + 1, + torch.max(random_index.unsqueeze(-1) - 1, torch.Tensor([0]).int()), + ).squeeze() + + mask_id[i : i + npximg] = per_pixel_mask.to(self.device) + + # interval scale supervision + curr_scale = self.scale_3d[img_idx][per_pixel_mask] + curr_scale[random_index == 0] = ( + self.scale_3d[img_idx][per_pixel_mask][random_index == 0] + * random_vec_densify[random_index == 0] + ) + for j in range(1, self.group_cdf[img_idx].shape[-1]): + if (random_index == j).sum() == 0: + continue + curr_scale[random_index == j] = ( + self.scale_3d[img_idx][per_pixel_mask_][random_index == j] + + ( + self.scale_3d[img_idx][per_pixel_mask][random_index == j] + - self.scale_3d[img_idx][per_pixel_mask_][random_index == j] + ) + * random_vec_densify[random_index == j] + ) + scale[i : i + npximg] = curr_scale.squeeze().to(self.device) + + batch = dict() + batch["mask_id"] = mask_id + batch["scale"] = scale + batch["nPxImg"] = npximg + batch["obj"] = self.object + batch["mapping"] = mapping.long() + return batch + + def val_data(self): + return dict(obj=self.object) + + def get_data_name(self, idx): + return os.path.basename(self.data_list[idx % len(self.data_list)]).split(".")[0] + + def __getitem__(self, idx): + return self.get_data(idx % len(self.data_list)) + + def __len__(self): + return len(self.data_list) * self.loop diff --git a/UniRig/src/model/pointcept/datasets/sampart3d_util.py b/UniRig/src/model/pointcept/datasets/sampart3d_util.py new file mode 100644 index 0000000000000000000000000000000000000000..0e11a17ca59027a3750c1ccaa2c9f07a3f73fcab --- /dev/null +++ b/UniRig/src/model/pointcept/datasets/sampart3d_util.py @@ -0,0 +1,152 @@ +import numpy as np +import trimesh +import os +import json +import math +import open3d as o3d +import torch + + +def sample_surface(mesh, count, face_weight=None, sample_color=False, seed=147): + + if face_weight is None: + # len(mesh.faces) float, array of the areas + # of each face of the mesh + face_weight = mesh.area_faces + + # cumulative sum of weights (len(mesh.faces)) + weight_cum = np.cumsum(face_weight) + + # seed the random number generator as requested + random = np.random.default_rng(seed).random + + # last value of cumulative sum is total summed weight/area + face_pick = random(count) * weight_cum[-1] + # get the index of the selected faces + face_index = np.searchsorted(weight_cum, face_pick) + + # pull triangles into the form of an origin + 2 vectors + tri_origins = mesh.vertices[mesh.faces[:, 0]] + tri_vectors = mesh.vertices[mesh.faces[:, 1:]].copy() + tri_vectors -= np.tile(tri_origins, (1, 2)).reshape((-1, 2, 3)) + + # pull the vectors for the faces we are going to sample from + tri_origins = tri_origins[face_index] + tri_vectors = tri_vectors[face_index] + + if sample_color and hasattr(mesh.visual, "uv"): + uv_origins = mesh.visual.uv[mesh.faces[:, 0]] + uv_vectors = mesh.visual.uv[mesh.faces[:, 1:]].copy() + uv_origins_tile = np.tile(uv_origins, (1, 2)).reshape((-1, 2, 2)) + uv_vectors -= uv_origins_tile + uv_origins = uv_origins[face_index] + uv_vectors = uv_vectors[face_index] + + # randomly generate two 0-1 scalar components to multiply edge vectors b + random_lengths = random((len(tri_vectors), 2, 1)) + + # points will be distributed on a quadrilateral if we use 2 0-1 samples + # if the two scalar components sum 
less than 1.0 the point will be + # inside the triangle, so we find vectors longer than 1.0 and + # transform them to be inside the triangle + random_test = random_lengths.sum(axis=1).reshape(-1) > 1.0 + random_lengths[random_test] -= 1.0 + random_lengths = np.abs(random_lengths) + + # multiply triangle edge vectors by the random lengths and sum + sample_vector = (tri_vectors * random_lengths).sum(axis=1) + + # finally, offset by the origin to generate + # (n,3) points in space on the triangle + samples = sample_vector + tri_origins + + if sample_color: + if hasattr(mesh.visual, "uv"): + sample_uv_vector = (uv_vectors * random_lengths).sum(axis=1) + uv_samples = sample_uv_vector + uv_origins + try: + texture = mesh.visual.material.baseColorTexture + except: + texture = mesh.visual.material.image + colors = trimesh.visual.color.uv_to_interpolated_color(uv_samples, texture) + else: + colors = mesh.visual.face_colors[face_index] + + return samples, face_index, colors + + return samples, face_index + + +def get_ray_directions(W, H, fx, fy, cx, cy, use_pixel_centers=True): + pixel_center = 0.5 if use_pixel_centers else 0 + i, j = np.meshgrid( + np.arange(W, dtype=np.float32) + pixel_center, + np.arange(H, dtype=np.float32) + pixel_center, + indexing="xy", + ) + directions = np.stack( + [(i - cx) / fx, -(j - cy) / fy, -np.ones_like(i)], -1 + ) + + return directions + + +def gen_pcd(depth, c2w_opengl, camera_angle_x): + + h, w = depth.shape + + depth_valid = depth < 65500.0 + depth = depth[depth_valid] + focal = ( + 0.5 * w / math.tan(0.5 * camera_angle_x) + ) # scaled focal length + ray_directions = get_ray_directions(w, h, focal, focal, w // 2, h // 2) + points_c = ray_directions[depth_valid] * depth[:, None] + points_c_homo = np.concatenate( + [points_c, np.ones_like(points_c[..., :1])], axis=-1 + ) + org_points = (points_c_homo @ c2w_opengl.T)[..., :3] + + return org_points + + +def save_point_cloud(coord, color=None, file_path="pc.ply", logger=None): + os.makedirs(os.path.dirname(file_path), exist_ok=True) + coord = np.array(coord) + if color is not None: + color = np.array(color) + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(coord) + pcd.colors = o3d.utility.Vector3dVector(np.ones_like(coord) if color is None else color) + o3d.io.write_point_cloud(file_path, pcd) + if logger is not None: + logger.info(f"Save Point Cloud to: {file_path}") + + +def vis_pcd_feat(coord, point_feat, save_path): + class TorchPCA(object): + + def __init__(self, n_components): + self.n_components = n_components + + def fit(self, X): + self.mean_ = X.mean(dim=0) + unbiased = X - self.mean_.unsqueeze(0) + U, S, V = torch.pca_lowrank(unbiased, q=self.n_components, center=False, niter=4) + self.components_ = V.T + self.singular_values_ = S + return self + + def transform(self, X): + t0 = X - self.mean_.unsqueeze(0) + projected = t0 @ self.components_.T + return projected + + fit_pca = TorchPCA(n_components=3).fit(point_feat) + x_red = fit_pca.transform(point_feat) + if isinstance(x_red, np.ndarray): + x_red = torch.from_numpy(x_red) + x_red -= x_red.min(dim=0, keepdim=True).values + x_red /= x_red.max(dim=0, keepdim=True).values + + save_point_cloud(coord.detach().cpu(), x_red.detach().cpu(), save_path) \ No newline at end of file diff --git a/UniRig/src/model/pointcept/datasets/transform.py b/UniRig/src/model/pointcept/datasets/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..ecb11b80fc28ad592e453cb69df23669f413aaa4 --- /dev/null +++ 
b/UniRig/src/model/pointcept/datasets/transform.py @@ -0,0 +1,1143 @@ +import random +import numbers +import scipy +import scipy.ndimage +import scipy.interpolate +import scipy.stats +import numpy as np +import torch +import copy +from collections.abc import Sequence, Mapping + +from pointcept.utils.registry import Registry + +TRANSFORMS = Registry("transforms") + + +@TRANSFORMS.register_module() +class Collect(object): + def __init__(self, keys, offset_keys_dict=None, **kwargs): + """ + e.g. Collect(keys=[coord], feat_keys=[coord, color]) + """ + if offset_keys_dict is None: + offset_keys_dict = dict(offset="coord") + self.keys = keys + self.offset_keys = offset_keys_dict + self.kwargs = kwargs + + def __call__(self, data_dict): + data = dict() + if isinstance(self.keys, str): + self.keys = [self.keys] + for key in self.keys: + data[key] = data_dict[key] + for key, value in self.offset_keys.items(): + data[key] = torch.tensor([data_dict[value].shape[0]]) + for name, keys in self.kwargs.items(): + name = name.replace("_keys", "") + assert isinstance(keys, Sequence) + data[name] = torch.cat([data_dict[key].float() for key in keys], dim=1) + return data + + +@TRANSFORMS.register_module() +class Copy(object): + def __init__(self, keys_dict=None): + if keys_dict is None: + keys_dict = dict(coord="origin_coord", segment="origin_segment") + self.keys_dict = keys_dict + + def __call__(self, data_dict): + for key, value in self.keys_dict.items(): + if isinstance(data_dict[key], np.ndarray): + data_dict[value] = data_dict[key].copy() + elif isinstance(data_dict[key], torch.Tensor): + data_dict[value] = data_dict[key].clone().detach() + else: + data_dict[value] = copy.deepcopy(data_dict[key]) + return data_dict + + +@TRANSFORMS.register_module() +class ToTensor(object): + def __call__(self, data): + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, str): + # note that str is also a kind of sequence, judgement should before sequence + return data + elif isinstance(data, int): + return torch.LongTensor([data]) + elif isinstance(data, float): + return torch.FloatTensor([data]) + elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, bool): + return torch.from_numpy(data) + elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, np.integer): + return torch.from_numpy(data).long() + elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, np.floating): + return torch.from_numpy(data).float() + elif isinstance(data, Mapping): + result = {sub_key: self(item) for sub_key, item in data.items()} + return result + elif isinstance(data, Sequence): + result = [self(item) for item in data] + return result + else: + raise TypeError(f"type {type(data)} cannot be converted to tensor.") + + +@TRANSFORMS.register_module() +class Add(object): + def __init__(self, keys_dict=None): + if keys_dict is None: + keys_dict = dict() + self.keys_dict = keys_dict + + def __call__(self, data_dict): + for key, value in self.keys_dict.items(): + data_dict[key] = value + return data_dict + + +@TRANSFORMS.register_module() +class NormalizeColor(object): + def __call__(self, data_dict): + if "color" in data_dict.keys(): + data_dict["color"] = data_dict["color"] / 127.5 - 1 + return data_dict + + +@TRANSFORMS.register_module() +class NormalizeCoord(object): + def __call__(self, data_dict): + if "coord" in data_dict.keys(): + # modified from pointnet2 + centroid = np.mean(data_dict["coord"], axis=0) + data_dict["coord"] -= centroid + m = np.max(np.sqrt(np.sum(data_dict["coord"] ** 2, 
axis=1))) + data_dict["coord"] = data_dict["coord"] / m + return data_dict + + +@TRANSFORMS.register_module() +class PositiveShift(object): + def __call__(self, data_dict): + if "coord" in data_dict.keys(): + coord_min = np.min(data_dict["coord"], 0) + data_dict["coord"] -= coord_min + return data_dict + + +@TRANSFORMS.register_module() +class CenterShift(object): + def __init__(self, apply_z=True): + self.apply_z = apply_z + + def __call__(self, data_dict): + if "coord" in data_dict.keys(): + x_min, y_min, z_min = data_dict["coord"].min(axis=0) + x_max, y_max, _ = data_dict["coord"].max(axis=0) + if self.apply_z: + shift = [(x_min + x_max) / 2, (y_min + y_max) / 2, z_min] + else: + shift = [(x_min + x_max) / 2, (y_min + y_max) / 2, 0] + data_dict["coord"] -= shift + return data_dict + + +@TRANSFORMS.register_module() +class RandomShift(object): + def __init__(self, shift=((-0.2, 0.2), (-0.2, 0.2), (0, 0))): + self.shift = shift + + def __call__(self, data_dict): + if "coord" in data_dict.keys(): + shift_x = np.random.uniform(self.shift[0][0], self.shift[0][1]) + shift_y = np.random.uniform(self.shift[1][0], self.shift[1][1]) + shift_z = np.random.uniform(self.shift[2][0], self.shift[2][1]) + data_dict["coord"] += [shift_x, shift_y, shift_z] + return data_dict + + +@TRANSFORMS.register_module() +class PointClip(object): + def __init__(self, point_cloud_range=(-80, -80, -3, 80, 80, 1)): + self.point_cloud_range = point_cloud_range + + def __call__(self, data_dict): + if "coord" in data_dict.keys(): + data_dict["coord"] = np.clip( + data_dict["coord"], + a_min=self.point_cloud_range[:3], + a_max=self.point_cloud_range[3:], + ) + return data_dict + + +@TRANSFORMS.register_module() +class RandomDropout(object): + def __init__(self, dropout_ratio=0.2, dropout_application_ratio=0.5): + """ + upright_axis: axis index among x,y,z, i.e. 2 for z + """ + self.dropout_ratio = dropout_ratio + self.dropout_application_ratio = dropout_application_ratio + + def __call__(self, data_dict): + if random.random() < self.dropout_application_ratio: + n = len(data_dict["coord"]) + idx = np.random.choice(n, int(n * (1 - self.dropout_ratio)), replace=False) + if "sampled_index" in data_dict: + # for ScanNet data efficient, we need to make sure labeled point is sampled. 
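+            # keep every labeled point by merging `sampled_index` into the surviving indices,
+            # then remap `sampled_index` to positions within the new subset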
+ idx = np.unique(np.append(idx, data_dict["sampled_index"])) + mask = np.zeros_like(data_dict["segment"]).astype(bool) + mask[data_dict["sampled_index"]] = True + data_dict["sampled_index"] = np.where(mask[idx])[0] + if "coord" in data_dict.keys(): + data_dict["coord"] = data_dict["coord"][idx] + if "color" in data_dict.keys(): + data_dict["color"] = data_dict["color"][idx] + if "normal" in data_dict.keys(): + data_dict["normal"] = data_dict["normal"][idx] + if "strength" in data_dict.keys(): + data_dict["strength"] = data_dict["strength"][idx] + if "segment" in data_dict.keys(): + data_dict["segment"] = data_dict["segment"][idx] + if "instance" in data_dict.keys(): + data_dict["instance"] = data_dict["instance"][idx] + return data_dict + + +@TRANSFORMS.register_module() +class RandomRotate(object): + def __init__(self, angle=None, center=None, axis="z", always_apply=False, p=0.5): + self.angle = [-1, 1] if angle is None else angle + self.axis = axis + self.always_apply = always_apply + self.p = p if not self.always_apply else 1 + self.center = center + + def __call__(self, data_dict): + if random.random() > self.p: + return data_dict + angle = np.random.uniform(self.angle[0], self.angle[1]) * np.pi + rot_cos, rot_sin = np.cos(angle), np.sin(angle) + if self.axis == "x": + rot_t = np.array([[1, 0, 0], [0, rot_cos, -rot_sin], [0, rot_sin, rot_cos]]) + elif self.axis == "y": + rot_t = np.array([[rot_cos, 0, rot_sin], [0, 1, 0], [-rot_sin, 0, rot_cos]]) + elif self.axis == "z": + rot_t = np.array([[rot_cos, -rot_sin, 0], [rot_sin, rot_cos, 0], [0, 0, 1]]) + else: + raise NotImplementedError + if "coord" in data_dict.keys(): + if self.center is None: + x_min, y_min, z_min = data_dict["coord"].min(axis=0) + x_max, y_max, z_max = data_dict["coord"].max(axis=0) + center = [(x_min + x_max) / 2, (y_min + y_max) / 2, (z_min + z_max) / 2] + else: + center = self.center + data_dict["coord"] -= center + data_dict["coord"] = np.dot(data_dict["coord"], np.transpose(rot_t)) + data_dict["coord"] += center + if "normal" in data_dict.keys(): + data_dict["normal"] = np.dot(data_dict["normal"], np.transpose(rot_t)) + return data_dict + + +@TRANSFORMS.register_module() +class RandomRotateTargetAngle(object): + def __init__( + self, angle=(1 / 2, 1, 3 / 2), center=None, axis="z", always_apply=False, p=0.75 + ): + self.angle = angle + self.axis = axis + self.always_apply = always_apply + self.p = p if not self.always_apply else 1 + self.center = center + + def __call__(self, data_dict): + if random.random() > self.p: + return data_dict + angle = np.random.choice(self.angle) * np.pi + rot_cos, rot_sin = np.cos(angle), np.sin(angle) + if self.axis == "x": + rot_t = np.array([[1, 0, 0], [0, rot_cos, -rot_sin], [0, rot_sin, rot_cos]]) + elif self.axis == "y": + rot_t = np.array([[rot_cos, 0, rot_sin], [0, 1, 0], [-rot_sin, 0, rot_cos]]) + elif self.axis == "z": + rot_t = np.array([[rot_cos, -rot_sin, 0], [rot_sin, rot_cos, 0], [0, 0, 1]]) + else: + raise NotImplementedError + if "coord" in data_dict.keys(): + if self.center is None: + x_min, y_min, z_min = data_dict["coord"].min(axis=0) + x_max, y_max, z_max = data_dict["coord"].max(axis=0) + center = [(x_min + x_max) / 2, (y_min + y_max) / 2, (z_min + z_max) / 2] + else: + center = self.center + data_dict["coord"] -= center + data_dict["coord"] = np.dot(data_dict["coord"], np.transpose(rot_t)) + data_dict["coord"] += center + if "normal" in data_dict.keys(): + data_dict["normal"] = np.dot(data_dict["normal"], np.transpose(rot_t)) + return data_dict + + 
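+# Illustrative sketch: the transforms in this module are registered in TRANSFORMS and are
+# normally built from a config list and chained with Compose, roughly like
+#
+#   aug = Compose([
+#       dict(type="CenterShift", apply_z=True),
+#       dict(type="RandomRotate", angle=[-1, 1], axis="z", p=0.5),
+#       dict(type="RandomScale", scale=[0.9, 1.1]),
+#       dict(type="NormalizeColor"),
+#   ])
+#   data_dict = aug(dict(coord=coord, color=color, normal=normal))
+#
+# The exact config keys follow the TRANSFORMS registry build convention; see the Compose
+# helper in this module.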
+@TRANSFORMS.register_module() +class RandomScale(object): + def __init__(self, scale=None, anisotropic=False): + self.scale = scale if scale is not None else [0.95, 1.05] + self.anisotropic = anisotropic + + def __call__(self, data_dict): + if "coord" in data_dict.keys(): + scale = np.random.uniform( + self.scale[0], self.scale[1], 3 if self.anisotropic else 1 + ) + data_dict["coord"] *= scale + return data_dict + + +@TRANSFORMS.register_module() +class RandomFlip(object): + def __init__(self, p=0.5): + self.p = p + + def __call__(self, data_dict): + if np.random.rand() < self.p: + if "coord" in data_dict.keys(): + data_dict["coord"][:, 0] = -data_dict["coord"][:, 0] + if "normal" in data_dict.keys(): + data_dict["normal"][:, 0] = -data_dict["normal"][:, 0] + if np.random.rand() < self.p: + if "coord" in data_dict.keys(): + data_dict["coord"][:, 1] = -data_dict["coord"][:, 1] + if "normal" in data_dict.keys(): + data_dict["normal"][:, 1] = -data_dict["normal"][:, 1] + return data_dict + + +@TRANSFORMS.register_module() +class RandomJitter(object): + def __init__(self, sigma=0.01, clip=0.05): + assert clip > 0 + self.sigma = sigma + self.clip = clip + + def __call__(self, data_dict): + if "coord" in data_dict.keys(): + jitter = np.clip( + self.sigma * np.random.randn(data_dict["coord"].shape[0], 3), + -self.clip, + self.clip, + ) + data_dict["coord"] += jitter + return data_dict + + +@TRANSFORMS.register_module() +class ClipGaussianJitter(object): + def __init__(self, scalar=0.02, store_jitter=False): + self.scalar = scalar + self.mean = np.mean(3) + self.cov = np.identity(3) + self.quantile = 1.96 + self.store_jitter = store_jitter + + def __call__(self, data_dict): + if "coord" in data_dict.keys(): + jitter = np.random.multivariate_normal( + self.mean, self.cov, data_dict["coord"].shape[0] + ) + jitter = self.scalar * np.clip(jitter / 1.96, -1, 1) + data_dict["coord"] += jitter + if self.store_jitter: + data_dict["jitter"] = jitter + return data_dict + + +@TRANSFORMS.register_module() +class ChromaticAutoContrast(object): + def __init__(self, p=0.2, blend_factor=None): + self.p = p + self.blend_factor = blend_factor + + def __call__(self, data_dict): + if "color" in data_dict.keys() and np.random.rand() < self.p: + lo = np.min(data_dict["color"], 0, keepdims=True) + hi = np.max(data_dict["color"], 0, keepdims=True) + scale = 255 / (hi - lo) + contrast_feat = (data_dict["color"][:, :3] - lo) * scale + blend_factor = ( + np.random.rand() if self.blend_factor is None else self.blend_factor + ) + data_dict["color"][:, :3] = (1 - blend_factor) * data_dict["color"][ + :, :3 + ] + blend_factor * contrast_feat + return data_dict + + +@TRANSFORMS.register_module() +class ChromaticTranslation(object): + def __init__(self, p=0.95, ratio=0.05): + self.p = p + self.ratio = ratio + + def __call__(self, data_dict): + if "color" in data_dict.keys() and np.random.rand() < self.p: + tr = (np.random.rand(1, 3) - 0.5) * 255 * 2 * self.ratio + data_dict["color"][:, :3] = np.clip(tr + data_dict["color"][:, :3], 0, 255) + return data_dict + + +@TRANSFORMS.register_module() +class ChromaticJitter(object): + def __init__(self, p=0.95, std=0.005): + self.p = p + self.std = std + + def __call__(self, data_dict): + if "color" in data_dict.keys() and np.random.rand() < self.p: + noise = np.random.randn(data_dict["color"].shape[0], 3) + noise *= self.std * 255 + data_dict["color"][:, :3] = np.clip( + noise + data_dict["color"][:, :3], 0, 255 + ) + return data_dict + + +@TRANSFORMS.register_module() +class 
RandomColorGrayScale(object): + def __init__(self, p): + self.p = p + + @staticmethod + def rgb_to_grayscale(color, num_output_channels=1): + if color.shape[-1] < 3: + raise TypeError( + "Input color should have at least 3 dimensions, but found {}".format( + color.shape[-1] + ) + ) + + if num_output_channels not in (1, 3): + raise ValueError("num_output_channels should be either 1 or 3") + + r, g, b = color[..., 0], color[..., 1], color[..., 2] + gray = (0.2989 * r + 0.587 * g + 0.114 * b).astype(color.dtype) + gray = np.expand_dims(gray, axis=-1) + + if num_output_channels == 3: + gray = np.broadcast_to(gray, color.shape) + + return gray + + def __call__(self, data_dict): + if np.random.rand() < self.p: + data_dict["color"] = self.rgb_to_grayscale(data_dict["color"], 3) + return data_dict + + +@TRANSFORMS.register_module() +class RandomColorJitter(object): + """ + Random Color Jitter for 3D point cloud (refer torchvision) + """ + + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, p=0.95): + self.brightness = self._check_input(brightness, "brightness") + self.contrast = self._check_input(contrast, "contrast") + self.saturation = self._check_input(saturation, "saturation") + self.hue = self._check_input( + hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False + ) + self.p = p + + @staticmethod + def _check_input( + value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True + ): + if isinstance(value, numbers.Number): + if value < 0: + raise ValueError( + "If {} is a single number, it must be non negative.".format(name) + ) + value = [center - float(value), center + float(value)] + if clip_first_on_zero: + value[0] = max(value[0], 0.0) + elif isinstance(value, (tuple, list)) and len(value) == 2: + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError("{} values should be between {}".format(name, bound)) + else: + raise TypeError( + "{} should be a single number or a list/tuple with length 2.".format( + name + ) + ) + + # if value is 0 or (1., 1.) for brightness/contrast/saturation + # or (0., 0.) 
for hue, do nothing + if value[0] == value[1] == center: + value = None + return value + + @staticmethod + def blend(color1, color2, ratio): + ratio = float(ratio) + bound = 255.0 + return ( + (ratio * color1 + (1.0 - ratio) * color2) + .clip(0, bound) + .astype(color1.dtype) + ) + + @staticmethod + def rgb2hsv(rgb): + r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2] + maxc = np.max(rgb, axis=-1) + minc = np.min(rgb, axis=-1) + eqc = maxc == minc + cr = maxc - minc + s = cr / (np.ones_like(maxc) * eqc + maxc * (1 - eqc)) + cr_divisor = np.ones_like(maxc) * eqc + cr * (1 - eqc) + rc = (maxc - r) / cr_divisor + gc = (maxc - g) / cr_divisor + bc = (maxc - b) / cr_divisor + + hr = (maxc == r) * (bc - gc) + hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc) + hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc) + h = hr + hg + hb + h = (h / 6.0 + 1.0) % 1.0 + return np.stack((h, s, maxc), axis=-1) + + @staticmethod + def hsv2rgb(hsv): + h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2] + i = np.floor(h * 6.0) + f = (h * 6.0) - i + i = i.astype(np.int32) + + p = np.clip((v * (1.0 - s)), 0.0, 1.0) + q = np.clip((v * (1.0 - s * f)), 0.0, 1.0) + t = np.clip((v * (1.0 - s * (1.0 - f))), 0.0, 1.0) + i = i % 6 + mask = np.expand_dims(i, axis=-1) == np.arange(6) + + a1 = np.stack((v, q, p, p, t, v), axis=-1) + a2 = np.stack((t, v, v, q, p, p), axis=-1) + a3 = np.stack((p, p, t, v, v, q), axis=-1) + a4 = np.stack((a1, a2, a3), axis=-1) + + return np.einsum("...na, ...nab -> ...nb", mask.astype(hsv.dtype), a4) + + def adjust_brightness(self, color, brightness_factor): + if brightness_factor < 0: + raise ValueError( + "brightness_factor ({}) is not non-negative.".format(brightness_factor) + ) + + return self.blend(color, np.zeros_like(color), brightness_factor) + + def adjust_contrast(self, color, contrast_factor): + if contrast_factor < 0: + raise ValueError( + "contrast_factor ({}) is not non-negative.".format(contrast_factor) + ) + mean = np.mean(RandomColorGrayScale.rgb_to_grayscale(color)) + return self.blend(color, mean, contrast_factor) + + def adjust_saturation(self, color, saturation_factor): + if saturation_factor < 0: + raise ValueError( + "saturation_factor ({}) is not non-negative.".format(saturation_factor) + ) + gray = RandomColorGrayScale.rgb_to_grayscale(color) + return self.blend(color, gray, saturation_factor) + + def adjust_hue(self, color, hue_factor): + if not (-0.5 <= hue_factor <= 0.5): + raise ValueError( + "hue_factor ({}) is not in [-0.5, 0.5].".format(hue_factor) + ) + orig_dtype = color.dtype + hsv = self.rgb2hsv(color / 255.0) + h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2] + h = (h + hue_factor) % 1.0 + hsv = np.stack((h, s, v), axis=-1) + color_hue_adj = (self.hsv2rgb(hsv) * 255.0).astype(orig_dtype) + return color_hue_adj + + @staticmethod + def get_params(brightness, contrast, saturation, hue): + fn_idx = torch.randperm(4) + b = ( + None + if brightness is None + else np.random.uniform(brightness[0], brightness[1]) + ) + c = None if contrast is None else np.random.uniform(contrast[0], contrast[1]) + s = ( + None + if saturation is None + else np.random.uniform(saturation[0], saturation[1]) + ) + h = None if hue is None else np.random.uniform(hue[0], hue[1]) + return fn_idx, b, c, s, h + + def __call__(self, data_dict): + ( + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ) = self.get_params(self.brightness, self.contrast, self.saturation, self.hue) + + for fn_id in fn_idx: + if ( + fn_id == 0 + and brightness_factor is not None + and 
np.random.rand() < self.p + ): + data_dict["color"] = self.adjust_brightness( + data_dict["color"], brightness_factor + ) + elif ( + fn_id == 1 and contrast_factor is not None and np.random.rand() < self.p + ): + data_dict["color"] = self.adjust_contrast( + data_dict["color"], contrast_factor + ) + elif ( + fn_id == 2 + and saturation_factor is not None + and np.random.rand() < self.p + ): + data_dict["color"] = self.adjust_saturation( + data_dict["color"], saturation_factor + ) + elif fn_id == 3 and hue_factor is not None and np.random.rand() < self.p: + data_dict["color"] = self.adjust_hue(data_dict["color"], hue_factor) + return data_dict + + +@TRANSFORMS.register_module() +class HueSaturationTranslation(object): + @staticmethod + def rgb_to_hsv(rgb): + # Translated from source of colorsys.rgb_to_hsv + # r,g,b should be a numpy arrays with values between 0 and 255 + # rgb_to_hsv returns an array of floats between 0.0 and 1.0. + rgb = rgb.astype("float") + hsv = np.zeros_like(rgb) + # in case an RGBA array was passed, just copy the A channel + hsv[..., 3:] = rgb[..., 3:] + r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2] + maxc = np.max(rgb[..., :3], axis=-1) + minc = np.min(rgb[..., :3], axis=-1) + hsv[..., 2] = maxc + mask = maxc != minc + hsv[mask, 1] = (maxc - minc)[mask] / maxc[mask] + rc = np.zeros_like(r) + gc = np.zeros_like(g) + bc = np.zeros_like(b) + rc[mask] = (maxc - r)[mask] / (maxc - minc)[mask] + gc[mask] = (maxc - g)[mask] / (maxc - minc)[mask] + bc[mask] = (maxc - b)[mask] / (maxc - minc)[mask] + hsv[..., 0] = np.select( + [r == maxc, g == maxc], [bc - gc, 2.0 + rc - bc], default=4.0 + gc - rc + ) + hsv[..., 0] = (hsv[..., 0] / 6.0) % 1.0 + return hsv + + @staticmethod + def hsv_to_rgb(hsv): + # Translated from source of colorsys.hsv_to_rgb + # h,s should be a numpy arrays with values between 0.0 and 1.0 + # v should be a numpy array with values between 0.0 and 255.0 + # hsv_to_rgb returns an array of uints between 0 and 255. 
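+        # vectorised port of colorsys.hsv_to_rgb: `conditions` below selects, per element,
+        # which of v / q / p / t feeds each RGB channel according to the hue sextant `i`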
+ rgb = np.empty_like(hsv) + rgb[..., 3:] = hsv[..., 3:] + h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2] + i = (h * 6.0).astype("uint8") + f = (h * 6.0) - i + p = v * (1.0 - s) + q = v * (1.0 - s * f) + t = v * (1.0 - s * (1.0 - f)) + i = i % 6 + conditions = [s == 0.0, i == 1, i == 2, i == 3, i == 4, i == 5] + rgb[..., 0] = np.select(conditions, [v, q, p, p, t, v], default=v) + rgb[..., 1] = np.select(conditions, [v, v, v, q, p, p], default=t) + rgb[..., 2] = np.select(conditions, [v, p, t, v, v, q], default=p) + return rgb.astype("uint8") + + def __init__(self, hue_max=0.5, saturation_max=0.2): + self.hue_max = hue_max + self.saturation_max = saturation_max + + def __call__(self, data_dict): + if "color" in data_dict.keys(): + # Assume color[:, :3] is rgb + hsv = HueSaturationTranslation.rgb_to_hsv(data_dict["color"][:, :3]) + hue_val = (np.random.rand() - 0.5) * 2 * self.hue_max + sat_ratio = 1 + (np.random.rand() - 0.5) * 2 * self.saturation_max + hsv[..., 0] = np.remainder(hue_val + hsv[..., 0] + 1, 1) + hsv[..., 1] = np.clip(sat_ratio * hsv[..., 1], 0, 1) + data_dict["color"][:, :3] = np.clip( + HueSaturationTranslation.hsv_to_rgb(hsv), 0, 255 + ) + return data_dict + + +@TRANSFORMS.register_module() +class RandomColorDrop(object): + def __init__(self, p=0.2, color_augment=0.0): + self.p = p + self.color_augment = color_augment + + def __call__(self, data_dict): + if "color" in data_dict.keys() and np.random.rand() < self.p: + data_dict["color"] *= self.color_augment + return data_dict + + def __repr__(self): + return "RandomColorDrop(color_augment: {}, p: {})".format( + self.color_augment, self.p + ) + + +@TRANSFORMS.register_module() +class ElasticDistortion(object): + def __init__(self, distortion_params=None): + self.distortion_params = ( + [[0.2, 0.4], [0.8, 1.6]] if distortion_params is None else distortion_params + ) + + @staticmethod + def elastic_distortion(coords, granularity, magnitude): + """ + Apply elastic distortion on sparse coordinate space. + pointcloud: numpy array of (number of points, at least 3 spatial dims) + granularity: size of the noise grid (in same scale[m/cm] as the voxel grid) + magnitude: noise multiplier + """ + blurx = np.ones((3, 1, 1, 1)).astype("float32") / 3 + blury = np.ones((1, 3, 1, 1)).astype("float32") / 3 + blurz = np.ones((1, 1, 3, 1)).astype("float32") / 3 + coords_min = coords.min(0) + + # Create Gaussian noise tensor of the size given by granularity. + noise_dim = ((coords - coords_min).max(0) // granularity).astype(int) + 3 + noise = np.random.randn(*noise_dim, 3).astype(np.float32) + + # Smoothing. + for _ in range(2): + noise = scipy.ndimage.filters.convolve( + noise, blurx, mode="constant", cval=0 + ) + noise = scipy.ndimage.filters.convolve( + noise, blury, mode="constant", cval=0 + ) + noise = scipy.ndimage.filters.convolve( + noise, blurz, mode="constant", cval=0 + ) + + # Trilinear interpolate noise filters for each spatial dimensions. 
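+        # Build one interpolation axis per spatial dimension spanning the noise grid, then
+        # evaluate the smoothed noise at every input coordinate and displace the points by
+        # it, scaled by `magnitude`.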
+ ax = [ + np.linspace(d_min, d_max, d) + for d_min, d_max, d in zip( + coords_min - granularity, + coords_min + granularity * (noise_dim - 2), + noise_dim, + ) + ] + interp = scipy.interpolate.RegularGridInterpolator( + ax, noise, bounds_error=False, fill_value=0 + ) + coords += interp(coords) * magnitude + return coords + + def __call__(self, data_dict): + if "coord" in data_dict.keys() and self.distortion_params is not None: + if random.random() < 0.95: + for granularity, magnitude in self.distortion_params: + data_dict["coord"] = self.elastic_distortion( + data_dict["coord"], granularity, magnitude + ) + return data_dict + + +@TRANSFORMS.register_module() +class GridSample(object): + def __init__( + self, + grid_size=0.05, + hash_type="fnv", + mode="train", + keys=("coord", "color", "normal", "segment"), + return_inverse=False, + return_grid_coord=False, + return_min_coord=False, + return_displacement=False, + project_displacement=False, + return_idx=False, + ): + self.grid_size = grid_size + self.hash = self.fnv_hash_vec if hash_type == "fnv" else self.ravel_hash_vec + assert mode in ["train", "test"] + self.mode = mode + self.keys = keys + self.return_inverse = return_inverse + self.return_grid_coord = return_grid_coord + self.return_min_coord = return_min_coord + self.return_displacement = return_displacement + self.project_displacement = project_displacement + self.retrun_idx = return_idx + + def __call__(self, data_dict): + assert "coord" in data_dict.keys() + scaled_coord = data_dict["coord"] / np.array(self.grid_size) + grid_coord = np.floor(scaled_coord).astype(int) + min_coord = grid_coord.min(0) + grid_coord -= min_coord + scaled_coord -= min_coord + min_coord = min_coord * np.array(self.grid_size) + key = self.hash(grid_coord) + idx_sort = np.argsort(key) + key_sort = key[idx_sort] + _, inverse, count = np.unique(key_sort, return_inverse=True, return_counts=True) + if self.mode == "train": # train mode + idx_select = ( + np.cumsum(np.insert(count, 0, 0)[0:-1]) + + np.random.randint(0, count.max(), count.size) % count + ) + idx_unique = idx_sort[idx_select] + if "sampled_index" in data_dict: + # for ScanNet data efficient, we need to make sure labeled point is sampled. 
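+                # Append the labeled indices to the sampled voxels, then remap
+                # sampled_index to its positions inside idx_unique via a boolean mask.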
+ idx_unique = np.unique( + np.append(idx_unique, data_dict["sampled_index"]) + ) + mask = np.zeros_like(data_dict["segment"]).astype(bool) + mask[data_dict["sampled_index"]] = True + data_dict["sampled_index"] = np.where(mask[idx_unique])[0] + if self.return_inverse: + data_dict["inverse"] = np.zeros_like(inverse) + data_dict["inverse"][idx_sort] = inverse + if self.return_grid_coord: + data_dict["grid_coord"] = grid_coord[idx_unique] + if self.return_min_coord: + data_dict["min_coord"] = min_coord.reshape([1, 3]) + if self.return_displacement: + displacement = ( + scaled_coord - grid_coord - 0.5 + ) # [0, 1] -> [-0.5, 0.5] displacement to center + if self.project_displacement: + displacement = np.sum( + displacement * data_dict["normal"], axis=-1, keepdims=True + ) + data_dict["displacement"] = displacement[idx_unique] + for key in self.keys: + data_dict[key] = data_dict[key][idx_unique] + if self.retrun_idx: + data_dict["idx_unique"] = idx_unique + return data_dict + + elif self.mode == "test": # test mode + data_part_list = [] + for i in range(count.max()): + idx_select = np.cumsum(np.insert(count, 0, 0)[0:-1]) + i % count + idx_part = idx_sort[idx_select] + data_part = dict(index=idx_part) + if self.return_inverse: + data_dict["inverse"] = np.zeros_like(inverse) + data_dict["inverse"][idx_sort] = inverse + if self.return_grid_coord: + data_part["grid_coord"] = grid_coord[idx_part] + if self.return_min_coord: + data_part["min_coord"] = min_coord.reshape([1, 3]) + if self.return_displacement: + displacement = ( + scaled_coord - grid_coord - 0.5 + ) # [0, 1] -> [-0.5, 0.5] displacement to center + if self.project_displacement: + displacement = np.sum( + displacement * data_dict["normal"], axis=-1, keepdims=True + ) + data_dict["displacement"] = displacement[idx_part] + for key in data_dict.keys(): + if key in self.keys: + data_part[key] = data_dict[key][idx_part] + else: + data_part[key] = data_dict[key] + data_part_list.append(data_part) + return data_part_list + else: + raise NotImplementedError + + @staticmethod + def ravel_hash_vec(arr): + """ + Ravel the coordinates after subtracting the min coordinates. 
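+        Keys are accumulated column by column, each time scaled by the extent of
+        the next column (Fortran-style indexing), so every occupied grid cell
+        maps to a distinct integer key.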
+ """ + assert arr.ndim == 2 + arr = arr.copy() + arr -= arr.min(0) + arr = arr.astype(np.uint64, copy=False) + arr_max = arr.max(0).astype(np.uint64) + 1 + + keys = np.zeros(arr.shape[0], dtype=np.uint64) + # Fortran style indexing + for j in range(arr.shape[1] - 1): + keys += arr[:, j] + keys *= arr_max[j + 1] + keys += arr[:, -1] + return keys + + @staticmethod + def fnv_hash_vec(arr): + """ + FNV64-1A + """ + assert arr.ndim == 2 + # Floor first for negative coordinates + arr = arr.copy() + arr = arr.astype(np.uint64, copy=False) + hashed_arr = np.uint64(14695981039346656037) * np.ones( + arr.shape[0], dtype=np.uint64 + ) + for j in range(arr.shape[1]): + hashed_arr *= np.uint64(1099511628211) + hashed_arr = np.bitwise_xor(hashed_arr, arr[:, j]) + return hashed_arr + + +@TRANSFORMS.register_module() +class SphereCrop(object): + def __init__(self, point_max=80000, sample_rate=None, mode="random"): + self.point_max = point_max + self.sample_rate = sample_rate + assert mode in ["random", "center", "all"] + self.mode = mode + + def __call__(self, data_dict): + point_max = ( + int(self.sample_rate * data_dict["coord"].shape[0]) + if self.sample_rate is not None + else self.point_max + ) + + assert "coord" in data_dict.keys() + if self.mode == "all": + # TODO: Optimize + if "index" not in data_dict.keys(): + data_dict["index"] = np.arange(data_dict["coord"].shape[0]) + data_part_list = [] + # coord_list, color_list, dist2_list, idx_list, offset_list = [], [], [], [], [] + if data_dict["coord"].shape[0] > point_max: + coord_p, idx_uni = np.random.rand( + data_dict["coord"].shape[0] + ) * 1e-3, np.array([]) + while idx_uni.size != data_dict["index"].shape[0]: + init_idx = np.argmin(coord_p) + dist2 = np.sum( + np.power(data_dict["coord"] - data_dict["coord"][init_idx], 2), + 1, + ) + idx_crop = np.argsort(dist2)[:point_max] + + data_crop_dict = dict() + if "coord" in data_dict.keys(): + data_crop_dict["coord"] = data_dict["coord"][idx_crop] + if "grid_coord" in data_dict.keys(): + data_crop_dict["grid_coord"] = data_dict["grid_coord"][idx_crop] + if "normal" in data_dict.keys(): + data_crop_dict["normal"] = data_dict["normal"][idx_crop] + if "color" in data_dict.keys(): + data_crop_dict["color"] = data_dict["color"][idx_crop] + if "displacement" in data_dict.keys(): + data_crop_dict["displacement"] = data_dict["displacement"][ + idx_crop + ] + if "strength" in data_dict.keys(): + data_crop_dict["strength"] = data_dict["strength"][idx_crop] + data_crop_dict["weight"] = dist2[idx_crop] + data_crop_dict["index"] = data_dict["index"][idx_crop] + data_part_list.append(data_crop_dict) + + delta = np.square( + 1 - data_crop_dict["weight"] / np.max(data_crop_dict["weight"]) + ) + coord_p[idx_crop] += delta + idx_uni = np.unique( + np.concatenate((idx_uni, data_crop_dict["index"])) + ) + else: + data_crop_dict = data_dict.copy() + data_crop_dict["weight"] = np.zeros(data_dict["coord"].shape[0]) + data_crop_dict["index"] = data_dict["index"] + data_part_list.append(data_crop_dict) + return data_part_list + # mode is "random" or "center" + elif data_dict["coord"].shape[0] > point_max: + if self.mode == "random": + center = data_dict["coord"][ + np.random.randint(data_dict["coord"].shape[0]) + ] + elif self.mode == "center": + center = data_dict["coord"][data_dict["coord"].shape[0] // 2] + else: + raise NotImplementedError + idx_crop = np.argsort(np.sum(np.square(data_dict["coord"] - center), 1))[ + :point_max + ] + if "coord" in data_dict.keys(): + data_dict["coord"] = data_dict["coord"][idx_crop] + if 
"origin_coord" in data_dict.keys(): + data_dict["origin_coord"] = data_dict["origin_coord"][idx_crop] + if "grid_coord" in data_dict.keys(): + data_dict["grid_coord"] = data_dict["grid_coord"][idx_crop] + if "color" in data_dict.keys(): + data_dict["color"] = data_dict["color"][idx_crop] + if "normal" in data_dict.keys(): + data_dict["normal"] = data_dict["normal"][idx_crop] + if "segment" in data_dict.keys(): + data_dict["segment"] = data_dict["segment"][idx_crop] + if "instance" in data_dict.keys(): + data_dict["instance"] = data_dict["instance"][idx_crop] + if "displacement" in data_dict.keys(): + data_dict["displacement"] = data_dict["displacement"][idx_crop] + if "strength" in data_dict.keys(): + data_dict["strength"] = data_dict["strength"][idx_crop] + return data_dict + + +@TRANSFORMS.register_module() +class ShufflePoint(object): + def __call__(self, data_dict): + assert "coord" in data_dict.keys() + shuffle_index = np.arange(data_dict["coord"].shape[0]) + np.random.shuffle(shuffle_index) + if "coord" in data_dict.keys(): + data_dict["coord"] = data_dict["coord"][shuffle_index] + if "grid_coord" in data_dict.keys(): + data_dict["grid_coord"] = data_dict["grid_coord"][shuffle_index] + if "displacement" in data_dict.keys(): + data_dict["displacement"] = data_dict["displacement"][shuffle_index] + if "color" in data_dict.keys(): + data_dict["color"] = data_dict["color"][shuffle_index] + if "normal" in data_dict.keys(): + data_dict["normal"] = data_dict["normal"][shuffle_index] + if "segment" in data_dict.keys(): + data_dict["segment"] = data_dict["segment"][shuffle_index] + if "instance" in data_dict.keys(): + data_dict["instance"] = data_dict["instance"][shuffle_index] + return data_dict + + +@TRANSFORMS.register_module() +class CropBoundary(object): + def __call__(self, data_dict): + assert "segment" in data_dict + segment = data_dict["segment"].flatten() + mask = (segment != 0) * (segment != 1) + if "coord" in data_dict.keys(): + data_dict["coord"] = data_dict["coord"][mask] + if "grid_coord" in data_dict.keys(): + data_dict["grid_coord"] = data_dict["grid_coord"][mask] + if "color" in data_dict.keys(): + data_dict["color"] = data_dict["color"][mask] + if "normal" in data_dict.keys(): + data_dict["normal"] = data_dict["normal"][mask] + if "segment" in data_dict.keys(): + data_dict["segment"] = data_dict["segment"][mask] + if "instance" in data_dict.keys(): + data_dict["instance"] = data_dict["instance"][mask] + return data_dict + + +@TRANSFORMS.register_module() +class ContrastiveViewsGenerator(object): + def __init__( + self, + view_keys=("coord", "color", "normal", "origin_coord"), + view_trans_cfg=None, + ): + self.view_keys = view_keys + self.view_trans = Compose(view_trans_cfg) + + def __call__(self, data_dict): + view1_dict = dict() + view2_dict = dict() + for key in self.view_keys: + view1_dict[key] = data_dict[key].copy() + view2_dict[key] = data_dict[key].copy() + view1_dict = self.view_trans(view1_dict) + view2_dict = self.view_trans(view2_dict) + for key, value in view1_dict.items(): + data_dict["view1_" + key] = value + for key, value in view2_dict.items(): + data_dict["view2_" + key] = value + return data_dict + + +@TRANSFORMS.register_module() +class InstanceParser(object): + def __init__(self, segment_ignore_index=(-1, 0, 1), instance_ignore_index=-1): + self.segment_ignore_index = segment_ignore_index + self.instance_ignore_index = instance_ignore_index + + def __call__(self, data_dict): + coord = data_dict["coord"] + segment = data_dict["segment"] + instance = 
data_dict["instance"] + mask = ~np.in1d(segment, self.segment_ignore_index) + # mapping ignored instance to ignore index + instance[~mask] = self.instance_ignore_index + # reorder left instance + unique, inverse = np.unique(instance[mask], return_inverse=True) + instance_num = len(unique) + instance[mask] = inverse + # init instance information + centroid = np.ones((coord.shape[0], 3)) * self.instance_ignore_index + bbox = np.ones((instance_num, 8)) * self.instance_ignore_index + vacancy = [ + index for index in self.segment_ignore_index if index >= 0 + ] # vacate class index + + for instance_id in range(instance_num): + mask_ = instance == instance_id + coord_ = coord[mask_] + bbox_min = coord_.min(0) + bbox_max = coord_.max(0) + bbox_centroid = coord_.mean(0) + bbox_center = (bbox_max + bbox_min) / 2 + bbox_size = bbox_max - bbox_min + bbox_theta = np.zeros(1, dtype=coord_.dtype) + bbox_class = np.array([segment[mask_][0]], dtype=coord_.dtype) + # shift class index to fill vacate class index caused by segment ignore index + bbox_class -= np.greater(bbox_class, vacancy).sum() + + centroid[mask_] = bbox_centroid + bbox[instance_id] = np.concatenate( + [bbox_center, bbox_size, bbox_theta, bbox_class] + ) # 3 + 3 + 1 + 1 = 8 + data_dict["instance"] = instance + data_dict["instance_centroid"] = centroid + data_dict["bbox"] = bbox + return data_dict + + +class Compose(object): + def __init__(self, cfg=None): + self.cfg = cfg if cfg is not None else [] + self.transforms = [] + for t_cfg in self.cfg: + self.transforms.append(TRANSFORMS.build(t_cfg)) + + def __call__(self, data_dict): + for t in self.transforms: + data_dict = t(data_dict) + return data_dict diff --git a/UniRig/src/model/pointcept/datasets/utils.py b/UniRig/src/model/pointcept/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2d24e17bf619306adf14f6746530829502b8e546 --- /dev/null +++ b/UniRig/src/model/pointcept/datasets/utils.py @@ -0,0 +1,58 @@ +import random +from collections.abc import Mapping, Sequence +import numpy as np +import torch +# import trimesh +from torch.utils.data.dataloader import default_collate + + +def collate_fn_dummy(batch): + return batch + + +def collate_fn(batch): + """ + collate function for point cloud which support dict and list, + 'coord' is necessary to determine 'offset' + """ + if not isinstance(batch, Sequence): + raise TypeError(f"{batch.dtype} is not supported.") + + if isinstance(batch[0], torch.Tensor): + return torch.cat(list(batch)) + elif isinstance(batch[0], str): + # str is also a kind of Sequence, judgement should before Sequence + return list(batch) + elif isinstance(batch[0], Sequence): + for data in batch: + data.append(torch.tensor([data[0].shape[0]])) + batch = [collate_fn(samples) for samples in zip(*batch)] + batch[-1] = torch.cumsum(batch[-1], dim=0).int() + return batch + elif isinstance(batch[0], Mapping): + batch = {key: collate_fn([d[key] for d in batch]) for key in batch[0]} + for key in batch.keys(): + if "offset" in key: + batch[key] = torch.cumsum(batch[key], dim=0) + return batch + else: + return default_collate(batch) + + +def point_collate_fn(batch, mix_prob=0): + assert isinstance( + batch[0], Mapping + ) # currently, only support input_dict, rather than input_list + batch = collate_fn(batch) + if "offset" in batch.keys(): + # Mix3d (https://arxiv.org/pdf/2110.02210.pdf) + if random.random() < mix_prob: + batch["offset"] = torch.cat( + [batch["offset"][1:-1:2], batch["offset"][-1].unsqueeze(0)], dim=0 + ) + return batch + + +def 
gaussian_kernel(dist2: np.array, a: float = 1, c: float = 5): + return a * np.exp(-dist2 / (2 * c**2)) + diff --git a/UniRig/src/model/pointcept/engines/__init__.py b/UniRig/src/model/pointcept/engines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/UniRig/src/model/pointcept/engines/defaults.py b/UniRig/src/model/pointcept/engines/defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..832bbcac47851c7ee033b63f3374022dea5e49dc --- /dev/null +++ b/UniRig/src/model/pointcept/engines/defaults.py @@ -0,0 +1,143 @@ +import os +import sys +import argparse +import multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel + + +import pointcept.utils.comm as comm +from pointcept.utils.env import get_random_seed, set_seed +from pointcept.utils.config import Config, DictAction + + +def create_ddp_model(model, *, fp16_compression=False, **kwargs): + """ + Create a DistributedDataParallel model if there are >1 processes. + Args: + model: a torch.nn.Module + fp16_compression: add fp16 compression hooks to the ddp object. + See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook + kwargs: other arguments of :module:`torch.nn.parallel.DistributedDataParallel`. + """ + if comm.get_world_size() == 1: + return model + # kwargs['find_unused_parameters'] = True + if "device_ids" not in kwargs: + kwargs["device_ids"] = [comm.get_local_rank()] + if "output_device" not in kwargs: + kwargs["output_device"] = [comm.get_local_rank()] + ddp = DistributedDataParallel(model, **kwargs) + if fp16_compression: + from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks + + ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook) + return ddp + + +def worker_init_fn(worker_id, num_workers, rank, seed): + """Worker init func for dataloader. + + The seed of each worker equals to num_worker * rank + worker_id + user_seed + + Args: + worker_id (int): Worker id. + num_workers (int): Number of workers. + rank (int): The rank of current process. + seed (int): The random seed to use. + """ + + worker_seed = num_workers * rank + worker_id + seed + set_seed(worker_seed) + + +def default_argument_parser(epilog=None): + parser = argparse.ArgumentParser( + epilog=epilog + or f""" + Examples: + Run on single machine: + $ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml + Change some config options: + $ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001 + Run on multiple machines: + (machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url [--other-flags] + (machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url [--other-flags] + """, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--config-file", default="", metavar="FILE", help="path to config file" + ) + parser.add_argument( + "--num-gpus", type=int, default=1, help="number of gpus *per machine*" + ) + parser.add_argument( + "--num-machines", type=int, default=1, help="total number of machines" + ) + parser.add_argument( + "--machine-rank", + type=int, + default=0, + help="the rank of this machine (unique per machine)", + ) + # PyTorch still may leave orphan processes in multi-gpu training. + # Therefore we use a deterministic way to obtain port, + # so that users are aware of orphan processes by seeing the port occupied. 
+ # port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14 + parser.add_argument( + "--dist-url", + # default="tcp://127.0.0.1:{}".format(port), + default="auto", + help="initialization URL for pytorch distributed backend. See " + "https://pytorch.org/docs/stable/distributed.html for details.", + ) + parser.add_argument( + "--options", nargs="+", action=DictAction, help="custom options" + ) + return parser + + +def default_config_parser(file_path, options): + # config name protocol: dataset_name/model_name-exp_name + if os.path.isfile(file_path): + cfg = Config.fromfile(file_path) + else: + sep = file_path.find("-") + cfg = Config.fromfile(os.path.join(file_path[:sep], file_path[sep + 1 :])) + + if options is not None: + cfg.merge_from_dict(options) + + if cfg.seed is None: + cfg.seed = get_random_seed() + + cfg.data.train.loop = cfg.epoch // cfg.eval_epoch + + os.makedirs(os.path.join(cfg.save_path, "model"), exist_ok=True) + if not cfg.resume: + cfg.dump(os.path.join(cfg.save_path, "config.py")) + return cfg + + +def default_setup(cfg): + # scalar by world size + world_size = comm.get_world_size() + cfg.num_worker = cfg.num_worker if cfg.num_worker is not None else mp.cpu_count() + cfg.num_worker_per_gpu = cfg.num_worker // world_size + assert cfg.batch_size % world_size == 0 + assert cfg.batch_size_val is None or cfg.batch_size_val % world_size == 0 + assert cfg.batch_size_test is None or cfg.batch_size_test % world_size == 0 + cfg.batch_size_per_gpu = cfg.batch_size // world_size + cfg.batch_size_val_per_gpu = ( + cfg.batch_size_val // world_size if cfg.batch_size_val is not None else 1 + ) + cfg.batch_size_test_per_gpu = ( + cfg.batch_size_test // world_size if cfg.batch_size_test is not None else 1 + ) + # update data loop + assert cfg.epoch % cfg.eval_epoch == 0 + # settle random seed + rank = comm.get_rank() + seed = None if cfg.seed is None else cfg.seed * cfg.num_worker_per_gpu + rank + set_seed(seed) + return cfg diff --git a/UniRig/src/model/pointcept/engines/eval.py b/UniRig/src/model/pointcept/engines/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..c5ebea23c91df213d90d0113987bd13597f08230 --- /dev/null +++ b/UniRig/src/model/pointcept/engines/eval.py @@ -0,0 +1,407 @@ +import os +import sys +import weakref +import torch +torch.multiprocessing.set_start_method('spawn') +import torch.nn as nn +import torch.utils.data +from functools import partial + +if sys.version_info >= (3, 10): + from collections.abc import Iterator +else: + from collections import Iterator +from tensorboardX import SummaryWriter + +from .defaults import create_ddp_model, worker_init_fn +from .hooks import HookBase, build_hooks +import pointcept.utils.comm as comm +from pointcept.datasets import build_dataset, point_collate_fn, collate_fn +from pointcept.models import build_model +from pointcept.utils.logger import get_root_logger +from pointcept.utils.optimizer import build_optimizer +from pointcept.utils.scheduler import build_scheduler +from pointcept.utils.events import EventStorage +from pointcept.utils.registry import Registry + +from sklearn.preprocessing import QuantileTransformer +from pointcept.utils.timer import Timer + +TRAINERS = Registry("trainers") +from cuml.cluster.hdbscan import HDBSCAN +# from sklearn.cluster import HDBSCAN +import open3d as o3d +import matplotlib.colors as mcolors +import numpy as np +from collections import OrderedDict +import trimesh +import pointops + +class TrainerBase: + def __init__(self) -> None: + 
self.hooks = [] + self.epoch = 0 + self.start_epoch = 0 + self.max_epoch = 0 + self.max_iter = 0 + self.comm_info = dict() + self.data_iterator: Iterator = enumerate([]) + self.storage: EventStorage + self.writer: SummaryWriter + self._iter_timer = Timer() + + def register_hooks(self, hooks) -> None: + hooks = build_hooks(hooks) + for h in hooks: + assert isinstance(h, HookBase) + # To avoid circular reference, hooks and trainer cannot own each other. + # This normally does not matter, but will cause memory leak if the + # involved objects contain __del__: + # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/ + h.trainer = weakref.proxy(self) + self.hooks.extend(hooks) + + def train(self): + with EventStorage() as self.storage: + # => before train + self.before_train() + for self.epoch in range(self.start_epoch, self.max_epoch): + # => before epoch + self.before_epoch() + # => run_epoch + for ( + self.comm_info["iter"], + self.comm_info["input_dict"], + ) in self.data_iterator: + # => before_step + self.before_step() + # => run_step + self.run_step() + # => after_step + self.after_step() + # => after epoch + self.after_epoch() + # => after train + self.after_train() + + def before_train(self): + for h in self.hooks: + h.before_train() + + def before_epoch(self): + for h in self.hooks: + h.before_epoch() + + def before_step(self): + for h in self.hooks: + h.before_step() + + def run_step(self): + raise NotImplementedError + + def after_step(self): + for h in self.hooks: + h.after_step() + + def after_epoch(self): + for h in self.hooks: + h.after_epoch() + self.storage.reset_histories() + + def after_train(self): + # Sync GPU before running train hooks + comm.synchronize() + for h in self.hooks: + h.after_train() + if comm.is_main_process(): + self.writer.close() + + +@TRAINERS.register_module("DefaultTrainer") +class Trainer(TrainerBase): + def __init__(self, cfg): + super(Trainer, self).__init__() + self.epoch = 0 + self.start_epoch = 0 + self.max_epoch = cfg.eval_epoch + self.best_metric_value = -torch.inf + self.logger = get_root_logger( + log_file=os.path.join(cfg.save_path, "train.log"), + # file_mode="a" if cfg.resume else "w", + file_mode="a", + ) + self.logger.info("=> Loading config ...") + self.cfg = cfg + self.logger.info(f"Save path: {cfg.save_path}") + self.logger.info(f"Config:\n{cfg.pretty_text}") + self.logger.info("=> Building model ...") + self.model = self.build_model() + self.logger.info("=> Building val dataset & dataloader ...") + self.train_loader = self.build_train_loader() + self.logger.info("=> Building hooks ...") + self.register_hooks(self.cfg.hooks) + + # !!! + self.val_scales_list = self.cfg.val_scales_list + self.mesh_voting = self.cfg.mesh_voting + self.backbone_weight_path = self.cfg.backbone_weight_path + + + def eval(self): + # val_data = build_dataset(self.cfg.data.val) + self.logger.info("=> Loading checkpoint & weight ...") + if self.backbone_weight_path != None: + self.logger.info("=> Loading checkpoint of pretrained backbone") + if os.path.isfile(self.backbone_weight_path): + checkpoint = torch.load( + self.backbone_weight_path, + map_location=lambda storage, loc: storage.cuda(), + ) + weight = OrderedDict() + for key, value in checkpoint["state_dict"].items(): + if not key.startswith("module."): + if comm.get_world_size() > 1: + key = "module." + key # xxx.xxx -> module.xxx.xxx + # Now all keys contain "module." no matter DDP or not. 
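+                    # The checkpoint stores plain encoder weights; prefixing every
+                    # key with "backbone." maps them onto the wrapped model's
+                    # backbone submodule before load_state_dict(strict=False).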
+ # if self.keywords in key: + # key = key.replace(self.keywords, self.replacement) + if comm.get_world_size() == 1: + key = key[7:] # module.xxx.xxx -> xxx.xxx + # if key.startswith("backbone."): + # key = key[9:] # backbone.xxx.xxx -> xxx.xxx + key = "backbone." + key # xxx.xxx -> backbone.xxx.xxx + weight[key] = value + load_state_info = self.model.load_state_dict(weight, strict=False) + self.logger.info(f"Missing keys: {load_state_info[0]}") + else: + self.logger.info(f"No weight found at: {self.backbone_weight_path}") + + if self.cfg.weight and os.path.isfile(self.cfg.weight): + checkpoint = torch.load( + self.cfg.weight, + map_location=lambda storage, loc: storage.cuda(), + ) + load_state_info = self.model.load_state_dict(checkpoint["state_dict"], strict=False) + self.logger.info(f"Missing keys: {load_state_info[0]}") + scale_statistics = checkpoint["state_dict"]["scale_statistics"] + self.model.quantile_transformer = self._get_quantile_func(scale_statistics) + else: + self.logger.info(f"No weight found at: {self.cfg.weight}") + self.cfg.weight = "last" + + self.model.eval() + save_root = os.path.join(self.cfg.save_path, "vis_pcd", os.path.splitext(os.path.basename(self.cfg.weight))[0]) + os.makedirs(save_root, exist_ok=True) + group_save_root = os.path.join(self.cfg.save_path, "results", os.path.splitext(os.path.basename(self.cfg.weight))[0]) + os.makedirs(group_save_root, exist_ok=True) + + hex_colors = list(mcolors.CSS4_COLORS.values()) + rgb_colors = np.array([mcolors.to_rgb(color) for color in hex_colors if color not in ['#000000', '#FFFFFF']]) + def relative_luminance(color): + return 0.2126 * color[0] + 0.7152 * color[1] + 0.0722 * color[2] + rgb_colors = [color for color in rgb_colors if (relative_luminance(color) > 0.4 and relative_luminance(color) < 0.8)] + np.random.shuffle(rgb_colors) + input_dict = self.train_loader.val_data() + + pcd_inverse = self.train_loader.pcd_inverse + if self.mesh_voting: + mesh = trimesh.load(self.train_loader.mesh_path) + if isinstance(mesh, trimesh.Scene): + mesh = mesh.dump(concatenate=True) + mesh.visual = trimesh.visual.ColorVisuals(mesh=mesh) + + for scale in self.val_scales_list: + input_dict["scale"] = scale + instance_feat = self.model(input_dict).cpu().detach().numpy() + + clusterer = HDBSCAN( + cluster_selection_epsilon=0.1, + min_samples=30, + min_cluster_size=30, + allow_single_cluster=False, + ).fit(instance_feat) + + labels = clusterer.labels_ + invalid_label_mask = labels == -1 + if invalid_label_mask.sum() > 0: + if invalid_label_mask.sum() == len(invalid_label_mask): + labels = np.zeros_like(labels) + else: + coord = input_dict["obj"]["coord"].cuda().contiguous().float() + valid_coord = coord[~invalid_label_mask] + valid_offset = torch.tensor(valid_coord.shape[0]).cuda() + invalid_coord = coord[invalid_label_mask] + invalid_offset = torch.tensor(invalid_coord.shape[0]).cuda() + indices, distances = pointops.knn_query(1, valid_coord, valid_offset, invalid_coord, invalid_offset) + indices = indices[:, 0].cpu().numpy() + labels[invalid_label_mask] = labels[~invalid_label_mask][indices] + + + # np.save(os.path.join(group_save_root, f"{str(scale)}.npy"), labels) + save_path = os.path.join(save_root, f"{str(scale)}.ply") + coord = input_dict["obj"]["coord"].cpu().numpy() + random_color = [] + for i in range(max(labels) + 1): + random_color.append(rgb_colors[i % len(rgb_colors)]) + random_color.append(np.array([0, 0, 0])) + color = [random_color[i] for i in labels] + pcd = o3d.geometry.PointCloud() + pcd.points = 
o3d.utility.Vector3dVector(coord) + pcd.colors = o3d.utility.Vector3dVector(color) + o3d.io.write_point_cloud(save_path, pcd) + + labels = labels[pcd_inverse] + + # print(len(clusterer.labels_)) + self.logger.info(f"scale_{scale} has {max(labels)+1} groups") + if self.mesh_voting: + face_index = self.train_loader.face_index + face_index = face_index[pcd_inverse] + + # Compute votes for each face using NumPy's bincount function + # labels = clusterer.labels_ + num_faces = len(mesh.faces) + num_labels = max(labels) + 1 + votes = np.zeros((num_faces, num_labels), dtype=np.int32) + np.add.at(votes, (face_index, labels), 1) + + # Find the label with most votes for each face using NumPy's argmax function + max_votes_labels = np.argmax(votes, axis=1) + # Set the label to -1 for faces that have no corresponding points + max_votes_labels[np.all(votes == 0, axis=1)] = -1 + + valid_mask = max_votes_labels != -1 + face_centroids = mesh.triangles_center + coord = torch.tensor(face_centroids).cuda().contiguous().float() + valid_coord = coord[valid_mask] + valid_offset = torch.tensor(valid_coord.shape[0]).cuda() + invalid_coord = coord[~valid_mask] + invalid_offset = torch.tensor(invalid_coord.shape[0]).cuda() + indices, distances = pointops.knn_query(1, valid_coord, valid_offset, invalid_coord, invalid_offset) + # # the first column is the point itself + # indices = indices[:, 1].cpu().numpy() + indices = indices[:, 0].cpu().numpy() + mesh_group = max_votes_labels.copy() + mesh_group[~valid_mask] = mesh_group[valid_mask][indices] + + np.save(os.path.join(group_save_root, f"mesh_{str(scale)}.npy"), mesh_group) + + # Assign color to each face based on the label with most votes + for face, label in enumerate(mesh_group): + color = (random_color[label] * 255).astype(np.uint8) + color_with_alpha = np.append(color, 255) # Add alpha value + mesh.visual.face_colors[face] = color_with_alpha + + # Save the new mesh + mesh_save_path = os.path.join(save_root, f"mesh_{str(scale)}.ply") + mesh.export(mesh_save_path) + + + def _get_quantile_func(self, scales: torch.Tensor, distribution="normal"): + """ + Use 3D scale statistics to normalize scales -- use quantile transformer. + """ + scales = scales.flatten() + max_grouping_scale = 2 + scales = scales[(scales > 0) & (scales < max_grouping_scale)] + + scales = scales.detach().cpu().numpy() + + # Calculate quantile transformer + quantile_transformer = QuantileTransformer(output_distribution=distribution) + quantile_transformer = quantile_transformer.fit(scales.reshape(-1, 1)) + + def quantile_transformer_func(scales): + # This function acts as a wrapper for QuantileTransformer. + # QuantileTransformer expects a numpy array, while we have a torch tensor. + return torch.Tensor( + quantile_transformer.transform(scales.cpu().numpy()) + ).to(scales.device) + + return quantile_transformer_func + + def run_step(self): + input_dict = self.comm_info["input_dict"] + for key in input_dict.keys(): + if isinstance(input_dict[key], torch.Tensor): + input_dict[key] = input_dict[key].cuda(non_blocking=True) + with torch.cuda.amp.autocast(enabled=self.cfg.enable_amp): + output_dict = self.model(input_dict) + loss = output_dict["loss"] + self.optimizer.zero_grad() + if self.cfg.enable_amp: + self.scaler.scale(loss).backward() + self.scaler.step(self.optimizer) + + # When enable amp, optimizer.step call are skipped if the loss scaling factor is too large. + # Fix torch warning scheduler step before optimizer step. 
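+                # GradScaler shrinks its scale when the optimizer step was skipped
+                # (inf/NaN gradients), so only advance the scheduler when the scale
+                # did not decrease, i.e. the optimizer actually stepped.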
+ scaler = self.scaler.get_scale() + self.scaler.update() + if scaler <= self.scaler.get_scale(): + self.scheduler.step() + else: + loss.backward() + self.optimizer.step() + self.scheduler.step() + if self.cfg.empty_cache: + torch.cuda.empty_cache() + self.comm_info["model_output_dict"] = output_dict + + def build_model(self): + model = build_model(self.cfg.model) + if self.cfg.sync_bn: + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + # logger.info(f"Model: \n{self.model}") + self.logger.info(f"Num params: {n_parameters}") + model = create_ddp_model( + model.cuda(), + broadcast_buffers=False, + find_unused_parameters=self.cfg.find_unused_parameters, + ) + return model + + def build_writer(self): + writer = SummaryWriter(self.cfg.save_path) if comm.is_main_process() else None + self.logger.info(f"Tensorboard writer logging dir: {self.cfg.save_path}") + return writer + + def build_train_loader(self): + self.cfg.data.train.split = "val" + self.cfg.data.train.oid = self.cfg.oid + self.cfg.data.train.label = self.cfg.label + train_data = build_dataset(self.cfg.data.train) + return train_data + + def build_val_loader(self): + val_loader = None + if self.cfg.evaluate: + val_data = build_dataset(self.cfg.data.val) + if comm.get_world_size() > 1: + val_sampler = torch.utils.data.distributed.DistributedSampler(val_data) + else: + val_sampler = None + val_loader = torch.utils.data.DataLoader( + val_data, + batch_size=self.cfg.batch_size_val_per_gpu, + shuffle=False, + num_workers=self.cfg.num_worker_per_gpu, + pin_memory=True, + sampler=val_sampler, + collate_fn=collate_fn, + ) + return val_loader + + def build_optimizer(self): + return build_optimizer(self.cfg.optimizer, self.model, self.cfg.param_dicts) + + def build_scheduler(self): + assert hasattr(self, "optimizer") + assert hasattr(self, "train_loader") + # self.cfg.scheduler.total_steps = len(self.train_loader) * self.cfg.eval_epoch + self.cfg.scheduler.total_steps = self.max_epoch + return build_scheduler(self.cfg.scheduler, self.optimizer) + + def build_scaler(self): + scaler = torch.cuda.amp.GradScaler() if self.cfg.enable_amp else None + return scaler diff --git a/UniRig/src/model/pointcept/engines/hooks/__init__.py b/UniRig/src/model/pointcept/engines/hooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8434169b30ac1563d5f60fa8fd89c61979c20568 --- /dev/null +++ b/UniRig/src/model/pointcept/engines/hooks/__init__.py @@ -0,0 +1,6 @@ +from .default import HookBase +from .misc import * +from .evaluator import * +# from .partseg import * + +from .builder import build_hooks diff --git a/UniRig/src/model/pointcept/engines/hooks/builder.py b/UniRig/src/model/pointcept/engines/hooks/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..834258b8f03c5451569a11cd3eaf299f9234fae8 --- /dev/null +++ b/UniRig/src/model/pointcept/engines/hooks/builder.py @@ -0,0 +1,11 @@ +from pointcept.utils.registry import Registry + + +HOOKS = Registry("hooks") + + +def build_hooks(cfg): + hooks = [] + for hook_cfg in cfg: + hooks.append(HOOKS.build(hook_cfg)) + return hooks diff --git a/UniRig/src/model/pointcept/engines/hooks/default.py b/UniRig/src/model/pointcept/engines/hooks/default.py new file mode 100644 index 0000000000000000000000000000000000000000..0ccd1f0fd3bd34215274c65b92e6a0755166d313 --- /dev/null +++ b/UniRig/src/model/pointcept/engines/hooks/default.py @@ -0,0 +1,24 @@ +class HookBase: + """ + Base class 
for hooks that can be registered with :class:`TrainerBase`. + """ + + trainer = None # A weak reference to the trainer object. + + def before_train(self): + pass + + def before_epoch(self): + pass + + def before_step(self): + pass + + def after_step(self): + pass + + def after_epoch(self): + pass + + def after_train(self): + pass diff --git a/UniRig/src/model/pointcept/engines/hooks/evaluator.py b/UniRig/src/model/pointcept/engines/hooks/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b06ef6c553ae291e1a87c4a31ed56859a1a83c --- /dev/null +++ b/UniRig/src/model/pointcept/engines/hooks/evaluator.py @@ -0,0 +1,574 @@ +import numpy as np +import torch +import torch.distributed as dist +import pointops +from uuid import uuid4 + +import pointcept.utils.comm as comm +from pointcept.utils.misc import intersection_and_union_gpu + +from .default import HookBase +from .builder import HOOKS + + +@HOOKS.register_module() +class ClsEvaluator(HookBase): + def after_epoch(self): + if self.trainer.cfg.evaluate: + self.eval() + + def eval(self): + self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>") + self.trainer.model.eval() + for i, input_dict in enumerate(self.trainer.val_loader): + for key in input_dict.keys(): + if isinstance(input_dict[key], torch.Tensor): + input_dict[key] = input_dict[key].cuda(non_blocking=True) + with torch.no_grad(): + output_dict = self.trainer.model(input_dict) + output = output_dict["cls_logits"] + loss = output_dict["loss"] + pred = output.max(1)[1] + label = input_dict["category"] + intersection, union, target = intersection_and_union_gpu( + pred, + label, + self.trainer.cfg.data.num_classes, + self.trainer.cfg.data.ignore_index, + ) + if comm.get_world_size() > 1: + dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce( + target + ) + intersection, union, target = ( + intersection.cpu().numpy(), + union.cpu().numpy(), + target.cpu().numpy(), + ) + # Here there is no need to sync since sync happened in dist.all_reduce + self.trainer.storage.put_scalar("val_intersection", intersection) + self.trainer.storage.put_scalar("val_union", union) + self.trainer.storage.put_scalar("val_target", target) + self.trainer.storage.put_scalar("val_loss", loss.item()) + self.trainer.logger.info( + "Test: [{iter}/{max_iter}] " + "Loss {loss:.4f} ".format( + iter=i + 1, max_iter=len(self.trainer.val_loader), loss=loss.item() + ) + ) + loss_avg = self.trainer.storage.history("val_loss").avg + intersection = self.trainer.storage.history("val_intersection").total + union = self.trainer.storage.history("val_union").total + target = self.trainer.storage.history("val_target").total + iou_class = intersection / (union + 1e-10) + acc_class = intersection / (target + 1e-10) + m_iou = np.mean(iou_class) + m_acc = np.mean(acc_class) + all_acc = sum(intersection) / (sum(target) + 1e-10) + self.trainer.logger.info( + "Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.".format( + m_iou, m_acc, all_acc + ) + ) + for i in range(self.trainer.cfg.data.num_classes): + self.trainer.logger.info( + "Class_{idx}-{name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format( + idx=i, + name=self.trainer.cfg.data.names[i], + iou=iou_class[i], + accuracy=acc_class[i], + ) + ) + current_epoch = self.trainer.epoch + 1 + if self.trainer.writer is not None: + self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch) + self.trainer.writer.add_scalar("val/mIoU", m_iou, current_epoch) + self.trainer.writer.add_scalar("val/mAcc", m_acc, 
current_epoch) + self.trainer.writer.add_scalar("val/allAcc", all_acc, current_epoch) + self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<") + self.trainer.comm_info["current_metric_value"] = all_acc # save for saver + self.trainer.comm_info["current_metric_name"] = "allAcc" # save for saver + + def after_train(self): + self.trainer.logger.info( + "Best {}: {:.4f}".format("allAcc", self.trainer.best_metric_value) + ) + + +@HOOKS.register_module() +class SemSegEvaluator(HookBase): + def after_epoch(self): + if self.trainer.cfg.evaluate: + self.eval() + + def eval(self): + self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>") + self.trainer.model.eval() + for i, input_dict in enumerate(self.trainer.val_loader): + for key in input_dict.keys(): + if isinstance(input_dict[key], torch.Tensor): + input_dict[key] = input_dict[key].cuda(non_blocking=True) + with torch.no_grad(): + output_dict = self.trainer.model(input_dict) + output = output_dict["seg_logits"] + loss = output_dict["loss"] + pred = output.max(1)[1] + segment = input_dict["segment"] + if "origin_coord" in input_dict.keys(): + idx, _ = pointops.knn_query( + 1, + input_dict["coord"].float(), + input_dict["offset"].int(), + input_dict["origin_coord"].float(), + input_dict["origin_offset"].int(), + ) + pred = pred[idx.flatten().long()] + segment = input_dict["origin_segment"] + intersection, union, target = intersection_and_union_gpu( + pred, + segment, + self.trainer.cfg.data.num_classes, + self.trainer.cfg.data.ignore_index, + ) + if comm.get_world_size() > 1: + dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce( + target + ) + intersection, union, target = ( + intersection.cpu().numpy(), + union.cpu().numpy(), + target.cpu().numpy(), + ) + # Here there is no need to sync since sync happened in dist.all_reduce + self.trainer.storage.put_scalar("val_intersection", intersection) + self.trainer.storage.put_scalar("val_union", union) + self.trainer.storage.put_scalar("val_target", target) + self.trainer.storage.put_scalar("val_loss", loss.item()) + info = "Test: [{iter}/{max_iter}] ".format( + iter=i + 1, max_iter=len(self.trainer.val_loader) + ) + if "origin_coord" in input_dict.keys(): + info = "Interp. 
" + info + self.trainer.logger.info( + info + + "Loss {loss:.4f} ".format( + iter=i + 1, max_iter=len(self.trainer.val_loader), loss=loss.item() + ) + ) + loss_avg = self.trainer.storage.history("val_loss").avg + intersection = self.trainer.storage.history("val_intersection").total + union = self.trainer.storage.history("val_union").total + target = self.trainer.storage.history("val_target").total + iou_class = intersection / (union + 1e-10) + acc_class = intersection / (target + 1e-10) + m_iou = np.mean(iou_class) + m_acc = np.mean(acc_class) + all_acc = sum(intersection) / (sum(target) + 1e-10) + self.trainer.logger.info( + "Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.".format( + m_iou, m_acc, all_acc + ) + ) + for i in range(self.trainer.cfg.data.num_classes): + self.trainer.logger.info( + "Class_{idx}-{name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format( + idx=i, + name=self.trainer.cfg.data.names[i], + iou=iou_class[i], + accuracy=acc_class[i], + ) + ) + current_epoch = self.trainer.epoch + 1 + if self.trainer.writer is not None: + self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch) + self.trainer.writer.add_scalar("val/mIoU", m_iou, current_epoch) + self.trainer.writer.add_scalar("val/mAcc", m_acc, current_epoch) + self.trainer.writer.add_scalar("val/allAcc", all_acc, current_epoch) + self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<") + self.trainer.comm_info["current_metric_value"] = m_iou # save for saver + self.trainer.comm_info["current_metric_name"] = "mIoU" # save for saver + + def after_train(self): + self.trainer.logger.info( + "Best {}: {:.4f}".format("mIoU", self.trainer.best_metric_value) + ) + + +@HOOKS.register_module() +class InsSegEvaluator(HookBase): + def __init__(self, segment_ignore_index=(-1,), instance_ignore_index=-1): + self.segment_ignore_index = segment_ignore_index + self.instance_ignore_index = instance_ignore_index + + self.valid_class_names = None # update in before train + self.overlaps = np.append(np.arange(0.5, 0.95, 0.05), 0.25) + self.min_region_sizes = 100 + self.distance_threshes = float("inf") + self.distance_confs = -float("inf") + + def before_train(self): + self.valid_class_names = [ + self.trainer.cfg.data.names[i] + for i in range(self.trainer.cfg.data.num_classes) + if i not in self.segment_ignore_index + ] + + def after_epoch(self): + if self.trainer.cfg.evaluate: + self.eval() + + def associate_instances(self, pred, segment, instance): + segment = segment.cpu().numpy() + instance = instance.cpu().numpy() + void_mask = np.in1d(segment, self.segment_ignore_index) + + assert ( + pred["pred_classes"].shape[0] + == pred["pred_scores"].shape[0] + == pred["pred_masks"].shape[0] + ) + assert pred["pred_masks"].shape[1] == segment.shape[0] == instance.shape[0] + # get gt instances + gt_instances = dict() + for i in range(self.trainer.cfg.data.num_classes): + if i not in self.segment_ignore_index: + gt_instances[self.trainer.cfg.data.names[i]] = [] + instance_ids, idx, counts = np.unique( + instance, return_index=True, return_counts=True + ) + segment_ids = segment[idx] + for i in range(len(instance_ids)): + if instance_ids[i] == self.instance_ignore_index: + continue + if segment_ids[i] in self.segment_ignore_index: + continue + gt_inst = dict() + gt_inst["instance_id"] = instance_ids[i] + gt_inst["segment_id"] = segment_ids[i] + gt_inst["dist_conf"] = 0.0 + gt_inst["med_dist"] = -1.0 + gt_inst["vert_count"] = counts[i] + gt_inst["matched_pred"] = [] + 
gt_instances[self.trainer.cfg.data.names[segment_ids[i]]].append(gt_inst) + + # get pred instances and associate with gt + pred_instances = dict() + for i in range(self.trainer.cfg.data.num_classes): + if i not in self.segment_ignore_index: + pred_instances[self.trainer.cfg.data.names[i]] = [] + instance_id = 0 + for i in range(len(pred["pred_classes"])): + if pred["pred_classes"][i] in self.segment_ignore_index: + continue + pred_inst = dict() + pred_inst["uuid"] = uuid4() + pred_inst["instance_id"] = instance_id + pred_inst["segment_id"] = pred["pred_classes"][i] + pred_inst["confidence"] = pred["pred_scores"][i] + pred_inst["mask"] = np.not_equal(pred["pred_masks"][i], 0) + pred_inst["vert_count"] = np.count_nonzero(pred_inst["mask"]) + pred_inst["void_intersection"] = np.count_nonzero( + np.logical_and(void_mask, pred_inst["mask"]) + ) + if pred_inst["vert_count"] < self.min_region_sizes: + continue # skip if empty + segment_name = self.trainer.cfg.data.names[pred_inst["segment_id"]] + matched_gt = [] + for gt_idx, gt_inst in enumerate(gt_instances[segment_name]): + intersection = np.count_nonzero( + np.logical_and( + instance == gt_inst["instance_id"], pred_inst["mask"] + ) + ) + if intersection > 0: + gt_inst_ = gt_inst.copy() + pred_inst_ = pred_inst.copy() + gt_inst_["intersection"] = intersection + pred_inst_["intersection"] = intersection + matched_gt.append(gt_inst_) + gt_inst["matched_pred"].append(pred_inst_) + pred_inst["matched_gt"] = matched_gt + pred_instances[segment_name].append(pred_inst) + instance_id += 1 + return gt_instances, pred_instances + + def evaluate_matches(self, scenes): + overlaps = self.overlaps + min_region_sizes = [self.min_region_sizes] + dist_threshes = [self.distance_threshes] + dist_confs = [self.distance_confs] + + # results: class x overlap + ap_table = np.zeros( + (len(dist_threshes), len(self.valid_class_names), len(overlaps)), float + ) + for di, (min_region_size, distance_thresh, distance_conf) in enumerate( + zip(min_region_sizes, dist_threshes, dist_confs) + ): + for oi, overlap_th in enumerate(overlaps): + pred_visited = {} + for scene in scenes: + for _ in scene["pred"]: + for label_name in self.valid_class_names: + for p in scene["pred"][label_name]: + if "uuid" in p: + pred_visited[p["uuid"]] = False + for li, label_name in enumerate(self.valid_class_names): + y_true = np.empty(0) + y_score = np.empty(0) + hard_false_negatives = 0 + has_gt = False + has_pred = False + for scene in scenes: + pred_instances = scene["pred"][label_name] + gt_instances = scene["gt"][label_name] + # filter groups in ground truth + gt_instances = [ + gt + for gt in gt_instances + if gt["vert_count"] >= min_region_size + and gt["med_dist"] <= distance_thresh + and gt["dist_conf"] >= distance_conf + ] + if gt_instances: + has_gt = True + if pred_instances: + has_pred = True + + cur_true = np.ones(len(gt_instances)) + cur_score = np.ones(len(gt_instances)) * (-float("inf")) + cur_match = np.zeros(len(gt_instances), dtype=bool) + # collect matches + for gti, gt in enumerate(gt_instances): + found_match = False + for pred in gt["matched_pred"]: + # greedy assignments + if pred_visited[pred["uuid"]]: + continue + overlap = float(pred["intersection"]) / ( + gt["vert_count"] + + pred["vert_count"] + - pred["intersection"] + ) + if overlap > overlap_th: + confidence = pred["confidence"] + # if already have a prediction for this gt, + # the prediction with the lower score is automatically a false positive + if cur_match[gti]: + max_score = max(cur_score[gti], confidence) 
+ min_score = min(cur_score[gti], confidence) + cur_score[gti] = max_score + # append false positive + cur_true = np.append(cur_true, 0) + cur_score = np.append(cur_score, min_score) + cur_match = np.append(cur_match, True) + # otherwise set score + else: + found_match = True + cur_match[gti] = True + cur_score[gti] = confidence + pred_visited[pred["uuid"]] = True + if not found_match: + hard_false_negatives += 1 + # remove non-matched ground truth instances + cur_true = cur_true[cur_match] + cur_score = cur_score[cur_match] + + # collect non-matched predictions as false positive + for pred in pred_instances: + found_gt = False + for gt in pred["matched_gt"]: + overlap = float(gt["intersection"]) / ( + gt["vert_count"] + + pred["vert_count"] + - gt["intersection"] + ) + if overlap > overlap_th: + found_gt = True + break + if not found_gt: + num_ignore = pred["void_intersection"] + for gt in pred["matched_gt"]: + if gt["segment_id"] in self.segment_ignore_index: + num_ignore += gt["intersection"] + # small ground truth instances + if ( + gt["vert_count"] < min_region_size + or gt["med_dist"] > distance_thresh + or gt["dist_conf"] < distance_conf + ): + num_ignore += gt["intersection"] + proportion_ignore = ( + float(num_ignore) / pred["vert_count"] + ) + # if not ignored append false positive + if proportion_ignore <= overlap_th: + cur_true = np.append(cur_true, 0) + confidence = pred["confidence"] + cur_score = np.append(cur_score, confidence) + + # append to overall results + y_true = np.append(y_true, cur_true) + y_score = np.append(y_score, cur_score) + + # compute average precision + if has_gt and has_pred: + # compute precision recall curve first + + # sorting and cumsum + score_arg_sort = np.argsort(y_score) + y_score_sorted = y_score[score_arg_sort] + y_true_sorted = y_true[score_arg_sort] + y_true_sorted_cumsum = np.cumsum(y_true_sorted) + + # unique thresholds + (thresholds, unique_indices) = np.unique( + y_score_sorted, return_index=True + ) + num_prec_recall = len(unique_indices) + 1 + + # prepare precision recall + num_examples = len(y_score_sorted) + # https://github.com/ScanNet/ScanNet/pull/26 + # all predictions are non-matched but also all of them are ignored and not counted as FP + # y_true_sorted_cumsum is empty + # num_true_examples = y_true_sorted_cumsum[-1] + num_true_examples = ( + y_true_sorted_cumsum[-1] + if len(y_true_sorted_cumsum) > 0 + else 0 + ) + precision = np.zeros(num_prec_recall) + recall = np.zeros(num_prec_recall) + + # deal with the first point + y_true_sorted_cumsum = np.append(y_true_sorted_cumsum, 0) + # deal with remaining + for idx_res, idx_scores in enumerate(unique_indices): + cumsum = y_true_sorted_cumsum[idx_scores - 1] + tp = num_true_examples - cumsum + fp = num_examples - idx_scores - tp + fn = cumsum + hard_false_negatives + p = float(tp) / (tp + fp) + r = float(tp) / (tp + fn) + precision[idx_res] = p + recall[idx_res] = r + + # first point in curve is artificial + precision[-1] = 1.0 + recall[-1] = 0.0 + + # compute average of precision-recall curve + recall_for_conv = np.copy(recall) + recall_for_conv = np.append(recall_for_conv[0], recall_for_conv) + recall_for_conv = np.append(recall_for_conv, 0.0) + + stepWidths = np.convolve( + recall_for_conv, [-0.5, 0, 0.5], "valid" + ) + # integrate is now simply a dot product + ap_current = np.dot(precision, stepWidths) + + elif has_gt: + ap_current = 0.0 + else: + ap_current = float("nan") + ap_table[di, li, oi] = ap_current + d_inf = 0 + o50 = np.where(np.isclose(self.overlaps, 0.5)) + o25 = 
np.where(np.isclose(self.overlaps, 0.25)) + oAllBut25 = np.where(np.logical_not(np.isclose(self.overlaps, 0.25))) + ap_scores = dict() + ap_scores["all_ap"] = np.nanmean(ap_table[d_inf, :, oAllBut25]) + ap_scores["all_ap_50%"] = np.nanmean(ap_table[d_inf, :, o50]) + ap_scores["all_ap_25%"] = np.nanmean(ap_table[d_inf, :, o25]) + ap_scores["classes"] = {} + for li, label_name in enumerate(self.valid_class_names): + ap_scores["classes"][label_name] = {} + ap_scores["classes"][label_name]["ap"] = np.average( + ap_table[d_inf, li, oAllBut25] + ) + ap_scores["classes"][label_name]["ap50%"] = np.average( + ap_table[d_inf, li, o50] + ) + ap_scores["classes"][label_name]["ap25%"] = np.average( + ap_table[d_inf, li, o25] + ) + return ap_scores + + def eval(self): + self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>") + self.trainer.model.eval() + scenes = [] + for i, input_dict in enumerate(self.trainer.val_loader): + assert ( + len(input_dict["offset"]) == 1 + ) # currently only support bs 1 for each GPU + for key in input_dict.keys(): + if isinstance(input_dict[key], torch.Tensor): + input_dict[key] = input_dict[key].cuda(non_blocking=True) + with torch.no_grad(): + output_dict = self.trainer.model(input_dict) + + loss = output_dict["loss"] + + segment = input_dict["segment"] + instance = input_dict["instance"] + # map to origin + if "origin_coord" in input_dict.keys(): + idx, _ = pointops.knn_query( + 1, + input_dict["coord"].float(), + input_dict["offset"].int(), + input_dict["origin_coord"].float(), + input_dict["origin_offset"].int(), + ) + idx = idx.cpu().flatten().long() + output_dict["pred_masks"] = output_dict["pred_masks"][:, idx] + segment = input_dict["origin_segment"] + instance = input_dict["origin_instance"] + + gt_instances, pred_instance = self.associate_instances( + output_dict, segment, instance + ) + scenes.append(dict(gt=gt_instances, pred=pred_instance)) + + self.trainer.storage.put_scalar("val_loss", loss.item()) + self.trainer.logger.info( + "Test: [{iter}/{max_iter}] " + "Loss {loss:.4f} ".format( + iter=i + 1, max_iter=len(self.trainer.val_loader), loss=loss.item() + ) + ) + + loss_avg = self.trainer.storage.history("val_loss").avg + comm.synchronize() + scenes_sync = comm.gather(scenes, dst=0) + scenes = [scene for scenes_ in scenes_sync for scene in scenes_] + ap_scores = self.evaluate_matches(scenes) + all_ap = ap_scores["all_ap"] + all_ap_50 = ap_scores["all_ap_50%"] + all_ap_25 = ap_scores["all_ap_25%"] + self.trainer.logger.info( + "Val result: mAP/AP50/AP25 {:.4f}/{:.4f}/{:.4f}.".format( + all_ap, all_ap_50, all_ap_25 + ) + ) + for i, label_name in enumerate(self.valid_class_names): + ap = ap_scores["classes"][label_name]["ap"] + ap_50 = ap_scores["classes"][label_name]["ap50%"] + ap_25 = ap_scores["classes"][label_name]["ap25%"] + self.trainer.logger.info( + "Class_{idx}-{name} Result: AP/AP50/AP25 {AP:.4f}/{AP50:.4f}/{AP25:.4f}".format( + idx=i, name=label_name, AP=ap, AP50=ap_50, AP25=ap_25 + ) + ) + current_epoch = self.trainer.epoch + 1 + if self.trainer.writer is not None: + self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch) + self.trainer.writer.add_scalar("val/mAP", all_ap, current_epoch) + self.trainer.writer.add_scalar("val/AP50", all_ap_50, current_epoch) + self.trainer.writer.add_scalar("val/AP25", all_ap_25, current_epoch) + self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<") + self.trainer.comm_info["current_metric_value"] = all_ap_50 # save for saver + 
self.trainer.comm_info["current_metric_name"] = "AP50" # save for saver diff --git a/UniRig/src/model/pointcept/engines/hooks/misc.py b/UniRig/src/model/pointcept/engines/hooks/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..995e3646bc24abbfd2af40678f239306d5da6181 --- /dev/null +++ b/UniRig/src/model/pointcept/engines/hooks/misc.py @@ -0,0 +1,429 @@ +import sys +import glob +import os +import shutil +import time +import torch +import torch.utils.data +from collections import OrderedDict + +if sys.version_info >= (3, 10): + from collections.abc import Sequence +else: + from collections import Sequence +from pointcept.utils.timer import Timer +from pointcept.utils.comm import is_main_process, synchronize, get_world_size +from pointcept.utils.cache import shared_dict + +import pointcept.utils.comm as comm +# from pointcept.engines.test import TESTERS + +from .default import HookBase +from .builder import HOOKS + + +@HOOKS.register_module() +class IterationTimer(HookBase): + def __init__(self, warmup_iter=1): + self._warmup_iter = warmup_iter + self._start_time = time.perf_counter() + self._iter_timer = Timer() + self._remain_iter = 0 + + def before_train(self): + self._start_time = time.perf_counter() + self._remain_iter = self.trainer.max_epoch * len(self.trainer.train_loader) + + def before_epoch(self): + self._iter_timer.reset() + + def before_step(self): + data_time = self._iter_timer.seconds() + self.trainer.storage.put_scalar("data_time", data_time) + + def after_step(self): + batch_time = self._iter_timer.seconds() + self._iter_timer.reset() + self.trainer.storage.put_scalar("batch_time", batch_time) + self._remain_iter -= 1 + remain_time = self._remain_iter * self.trainer.storage.history("batch_time").avg + t_m, t_s = divmod(remain_time, 60) + t_h, t_m = divmod(t_m, 60) + remain_time = "{:02d}:{:02d}:{:02d}".format(int(t_h), int(t_m), int(t_s)) + if "iter_info" in self.trainer.comm_info.keys(): + info = ( + "Data {data_time_val:.3f} ({data_time_avg:.3f}) " + "Batch {batch_time_val:.3f} ({batch_time_avg:.3f}) " + "Remain {remain_time} ".format( + data_time_val=self.trainer.storage.history("data_time").val, + data_time_avg=self.trainer.storage.history("data_time").avg, + batch_time_val=self.trainer.storage.history("batch_time").val, + batch_time_avg=self.trainer.storage.history("batch_time").avg, + remain_time=remain_time, + ) + ) + self.trainer.comm_info["iter_info"] += info + if self.trainer.comm_info["iter"] <= self._warmup_iter: + self.trainer.storage.history("data_time").reset() + self.trainer.storage.history("batch_time").reset() + + +@HOOKS.register_module() +class InformationWriter(HookBase): + def __init__(self): + self.curr_iter = 0 + self.model_output_keys = [] + + def before_train(self): + self.trainer.comm_info["iter_info"] = "" + self.curr_iter = self.trainer.start_epoch * len(self.trainer.train_loader) + + def before_step(self): + self.curr_iter += 1 + # MSC pretrain do not have offset information. 
Comment the code for support MSC + # info = "Train: [{epoch}/{max_epoch}][{iter}/{max_iter}] " \ + # "Scan {batch_size} ({points_num}) ".format( + # epoch=self.trainer.epoch + 1, max_epoch=self.trainer.max_epoch, + # iter=self.trainer.comm_info["iter"], max_iter=len(self.trainer.train_loader), + # batch_size=len(self.trainer.comm_info["input_dict"]["offset"]), + # points_num=self.trainer.comm_info["input_dict"]["offset"][-1] + # ) + info = "Train: [{epoch}/{max_epoch}][{iter}/{max_iter}] ".format( + epoch=self.trainer.epoch + 1, + max_epoch=self.trainer.max_epoch, + iter=self.trainer.comm_info["iter"] + 1, + max_iter=len(self.trainer.train_loader), + ) + self.trainer.comm_info["iter_info"] += info + + def after_step(self): + if "model_output_dict" in self.trainer.comm_info.keys(): + model_output_dict = self.trainer.comm_info["model_output_dict"] + self.model_output_keys = model_output_dict.keys() + for key in self.model_output_keys: + self.trainer.storage.put_scalar(key, model_output_dict[key].item()) + + for key in self.model_output_keys: + self.trainer.comm_info["iter_info"] += "{key}: {value:.4f} ".format( + key=key, value=self.trainer.storage.history(key).val + ) + lr = self.trainer.optimizer.state_dict()["param_groups"][0]["lr"] + self.trainer.comm_info["iter_info"] += "Lr: {lr:.5f}".format(lr=lr) + self.trainer.logger.info(self.trainer.comm_info["iter_info"]) + self.trainer.comm_info["iter_info"] = "" # reset iter info + if self.trainer.writer is not None: + self.trainer.writer.add_scalar("lr", lr, self.curr_iter) + for key in self.model_output_keys: + self.trainer.writer.add_scalar( + "train_batch/" + key, + self.trainer.storage.history(key).val, + self.curr_iter, + ) + + def after_epoch(self): + epoch_info = "Train result: " + for key in self.model_output_keys: + epoch_info += "{key}: {value:.4f} ".format( + key=key, value=self.trainer.storage.history(key).avg + ) + self.trainer.logger.info(epoch_info) + if self.trainer.writer is not None: + for key in self.model_output_keys: + self.trainer.writer.add_scalar( + "train/" + key, + self.trainer.storage.history(key).avg, + self.trainer.epoch + 1, + ) + + +@HOOKS.register_module() +class CheckpointSaver(HookBase): + def __init__(self, save_freq=None): + self.save_freq = save_freq # None or int, None indicate only save model last + + def after_epoch(self): + if is_main_process(): + is_best = False + if self.trainer.cfg.evaluate: + current_metric_value = self.trainer.comm_info["current_metric_value"] + current_metric_name = self.trainer.comm_info["current_metric_name"] + if current_metric_value > self.trainer.best_metric_value: + self.trainer.best_metric_value = current_metric_value + is_best = True + self.trainer.logger.info( + "Best validation {} updated to: {:.4f}".format( + current_metric_name, current_metric_value + ) + ) + self.trainer.logger.info( + "Currently Best {}: {:.4f}".format( + current_metric_name, self.trainer.best_metric_value + ) + ) + + filename = os.path.join( + self.trainer.cfg.save_path, "model", "model_last.pth" + ) + self.trainer.logger.info("Saving checkpoint to: " + filename) + torch.save( + { + "epoch": self.trainer.epoch + 1, + "state_dict": self.trainer.model.state_dict(), + "optimizer": self.trainer.optimizer.state_dict(), + "scheduler": self.trainer.scheduler.state_dict(), + "scaler": self.trainer.scaler.state_dict() + if self.trainer.cfg.enable_amp + else None, + "best_metric_value": self.trainer.best_metric_value, + }, + filename + ".tmp", + ) + os.replace(filename + ".tmp", filename) + if is_best: + 
shutil.copyfile( + filename, + os.path.join(self.trainer.cfg.save_path, "model", "model_best.pth"), + ) + if self.save_freq and (self.trainer.epoch + 1) % self.save_freq == 0: + shutil.copyfile( + filename, + os.path.join( + self.trainer.cfg.save_path, + "model", + f"epoch_{self.trainer.epoch + 1}.pth", + ), + ) + + +@HOOKS.register_module() +class CheckpointLoader(HookBase): + def __init__(self, keywords="", replacement=None, strict=False): + self.keywords = keywords + self.replacement = replacement if replacement is not None else keywords + self.strict = strict + + def before_train(self): + self.trainer.logger.info("=> Loading checkpoint & weight ...") + if self.trainer.cfg.weight and os.path.isfile(self.trainer.cfg.weight): + self.trainer.logger.info(f"Loading weight at: {self.trainer.cfg.weight}") + checkpoint = torch.load( + self.trainer.cfg.weight, + map_location=lambda storage, loc: storage.cuda(), + ) + self.trainer.logger.info( + f"Loading layer weights with keyword: {self.keywords}, " + f"replace keyword with: {self.replacement}" + ) + weight = OrderedDict() + for key, value in checkpoint["state_dict"].items(): + if not key.startswith("module."): + if comm.get_world_size() > 1: + key = "module." + key # xxx.xxx -> module.xxx.xxx + # Now all keys contain "module." no matter DDP or not. + if self.keywords in key: + key = key.replace(self.keywords, self.replacement) + if comm.get_world_size() == 1: + key = key[7:] # module.xxx.xxx -> xxx.xxx + weight[key] = value + load_state_info = self.trainer.model.load_state_dict( + weight, strict=self.strict + ) + self.trainer.logger.info(f"Missing keys: {load_state_info[0]}") + if self.trainer.cfg.resume: + self.trainer.logger.info( + f"Resuming train at eval epoch: {checkpoint['epoch']}" + ) + self.trainer.start_epoch = checkpoint["epoch"] + self.trainer.best_metric_value = checkpoint["best_metric_value"] + self.trainer.optimizer.load_state_dict(checkpoint["optimizer"]) + self.trainer.scheduler.load_state_dict(checkpoint["scheduler"]) + if self.trainer.cfg.enable_amp: + self.trainer.scaler.load_state_dict(checkpoint["scaler"]) + else: + self.trainer.logger.info(f"No weight found at: {self.trainer.cfg.weight}") + + +@HOOKS.register_module() +class DataCacheOperator(HookBase): + def __init__(self, data_root, split): + self.data_root = data_root + self.split = split + self.data_list = self.get_data_list() + + def get_data_list(self): + if isinstance(self.split, str): + data_list = glob.glob(os.path.join(self.data_root, self.split, "*.pth")) + elif isinstance(self.split, Sequence): + data_list = [] + for split in self.split: + data_list += glob.glob(os.path.join(self.data_root, split, "*.pth")) + else: + raise NotImplementedError + return data_list + + def get_cache_name(self, data_path): + data_name = data_path.replace(os.path.dirname(self.data_root), "").split(".")[0] + return "pointcept" + data_name.replace(os.path.sep, "-") + + def before_train(self): + self.trainer.logger.info( + f"=> Caching dataset: {self.data_root}, split: {self.split} ..." 
+ ) + if is_main_process(): + for data_path in self.data_list: + cache_name = self.get_cache_name(data_path) + data = torch.load(data_path) + shared_dict(cache_name, data) + synchronize() + + +@HOOKS.register_module() +class RuntimeProfiler(HookBase): + def __init__( + self, + forward=True, + backward=True, + interrupt=False, + warm_up=2, + sort_by="cuda_time_total", + row_limit=30, + ): + self.forward = forward + self.backward = backward + self.interrupt = interrupt + self.warm_up = warm_up + self.sort_by = sort_by + self.row_limit = row_limit + + def before_train(self): + self.trainer.logger.info("Profiling runtime ...") + from torch.profiler import profile, record_function, ProfilerActivity + + for i, input_dict in enumerate(self.trainer.train_loader): + if i == self.warm_up + 1: + break + for key in input_dict.keys(): + if isinstance(input_dict[key], torch.Tensor): + input_dict[key] = input_dict[key].cuda(non_blocking=True) + if self.forward: + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=True, + profile_memory=True, + with_stack=True, + ) as forward_prof: + with record_function("model_inference"): + output_dict = self.trainer.model(input_dict) + else: + output_dict = self.trainer.model(input_dict) + loss = output_dict["loss"] + if self.backward: + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=True, + profile_memory=True, + with_stack=True, + ) as backward_prof: + with record_function("model_inference"): + loss.backward() + self.trainer.logger.info(f"Profile: [{i + 1}/{self.warm_up + 1}]") + if self.forward: + self.trainer.logger.info( + "Forward profile: \n" + + str( + forward_prof.key_averages().table( + sort_by=self.sort_by, row_limit=self.row_limit + ) + ) + ) + forward_prof.export_chrome_trace( + os.path.join(self.trainer.cfg.save_path, "forward_trace.json") + ) + + if self.backward: + self.trainer.logger.info( + "Backward profile: \n" + + str( + backward_prof.key_averages().table( + sort_by=self.sort_by, row_limit=self.row_limit + ) + ) + ) + backward_prof.export_chrome_trace( + os.path.join(self.trainer.cfg.save_path, "backward_trace.json") + ) + if self.interrupt: + sys.exit(0) + + +@HOOKS.register_module() +class RuntimeProfilerV2(HookBase): + def __init__( + self, + interrupt=False, + wait=1, + warmup=1, + active=10, + repeat=1, + sort_by="cuda_time_total", + row_limit=30, + ): + self.interrupt = interrupt + self.wait = wait + self.warmup = warmup + self.active = active + self.repeat = repeat + self.sort_by = sort_by + self.row_limit = row_limit + + def before_train(self): + self.trainer.logger.info("Profiling runtime ...") + from torch.profiler import ( + profile, + record_function, + ProfilerActivity, + schedule, + tensorboard_trace_handler, + ) + + prof = profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=schedule( + wait=self.wait, + warmup=self.warmup, + active=self.active, + repeat=self.repeat, + ), + on_trace_ready=tensorboard_trace_handler(self.trainer.cfg.save_path), + record_shapes=True, + profile_memory=True, + with_stack=True, + ) + prof.start() + for i, input_dict in enumerate(self.trainer.train_loader): + if i >= (self.wait + self.warmup + self.active) * self.repeat: + break + for key in input_dict.keys(): + if isinstance(input_dict[key], torch.Tensor): + input_dict[key] = input_dict[key].cuda(non_blocking=True) + with record_function("model_forward"): + output_dict = self.trainer.model(input_dict) + loss = output_dict["loss"] + with 
record_function("model_backward"): + loss.backward() + prof.step() + self.trainer.logger.info( + f"Profile: [{i + 1}/{(self.wait + self.warmup + self.active) * self.repeat}]" + ) + self.trainer.logger.info( + "Profile: \n" + + str( + prof.key_averages().table( + sort_by=self.sort_by, row_limit=self.row_limit + ) + ) + ) + prof.stop() + + if self.interrupt: + sys.exit(0) diff --git a/UniRig/src/model/pointcept/engines/launch.py b/UniRig/src/model/pointcept/engines/launch.py new file mode 100644 index 0000000000000000000000000000000000000000..e2592dabbbc1a398e03ab73e3ace2de9f51fb3a9 --- /dev/null +++ b/UniRig/src/model/pointcept/engines/launch.py @@ -0,0 +1,128 @@ +import os +import logging +from datetime import timedelta +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from pointcept.utils import comm + +__all__ = ["DEFAULT_TIMEOUT", "launch"] + +DEFAULT_TIMEOUT = timedelta(minutes=60) + + +def _find_free_port(): + import socket + + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # Binding to port 0 will cause the OS to find an available port for us + sock.bind(("", 0)) + port = sock.getsockname()[1] + sock.close() + # NOTE: there is still a chance the port could be taken by other processes. + return port + + +def launch( + main_func, + num_gpus_per_machine, + num_machines=1, + machine_rank=0, + dist_url=None, + cfg=(), + timeout=DEFAULT_TIMEOUT, +): + """ + Launch multi-gpu or distributed training. + This function must be called on all machines involved in the training. + It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine. + Args: + main_func: a function that will be called by `main_func(*args)` + num_gpus_per_machine (int): number of GPUs per machine + num_machines (int): the total number of machines + machine_rank (int): the rank of this machine + dist_url (str): url to connect to for distributed jobs, including protocol + e.g. "tcp://127.0.0.1:8686". + Can be set to "auto" to automatically select a free port on localhost + timeout (timedelta): timeout of the distributed workers + args (tuple): arguments passed to main_func + """ + world_size = num_machines * num_gpus_per_machine + if world_size > 1: + if dist_url == "auto": + assert ( + num_machines == 1 + ), "dist_url=auto not supported in multi-machine jobs." + port = _find_free_port() + dist_url = f"tcp://127.0.0.1:{port}" + if num_machines > 1 and dist_url.startswith("file://"): + logger = logging.getLogger(__name__) + logger.warning( + "file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://" + ) + + mp.spawn( + _distributed_worker, + nprocs=num_gpus_per_machine, + args=( + main_func, + world_size, + num_gpus_per_machine, + machine_rank, + dist_url, + cfg, + timeout, + ), + daemon=False, + ) + else: + main_func(*cfg) + + +def _distributed_worker( + local_rank, + main_func, + world_size, + num_gpus_per_machine, + machine_rank, + dist_url, + cfg, + timeout=DEFAULT_TIMEOUT, +): + assert ( + torch.cuda.is_available() + ), "cuda is not available. Please check your installation." 
+ global_rank = machine_rank * num_gpus_per_machine + local_rank + try: + dist.init_process_group( + backend="NCCL", + init_method=dist_url, + world_size=world_size, + rank=global_rank, + timeout=timeout, + ) + except Exception as e: + logger = logging.getLogger(__name__) + logger.error("Process group URL: {}".format(dist_url)) + raise e + + # Setup the local process group (which contains ranks within the same machine) + assert comm._LOCAL_PROCESS_GROUP is None + num_machines = world_size // num_gpus_per_machine + for i in range(num_machines): + ranks_on_i = list( + range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine) + ) + pg = dist.new_group(ranks_on_i) + if i == machine_rank: + comm._LOCAL_PROCESS_GROUP = pg + + assert num_gpus_per_machine <= torch.cuda.device_count() + torch.cuda.set_device(local_rank) + + # synchronize is needed here to prevent a possible timeout after calling init_process_group + # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172 + comm.synchronize() + + main_func(*cfg) diff --git a/UniRig/src/model/pointcept/engines/train.py b/UniRig/src/model/pointcept/engines/train.py new file mode 100644 index 0000000000000000000000000000000000000000..66964b7774eafa64714806b3988abaa775de403f --- /dev/null +++ b/UniRig/src/model/pointcept/engines/train.py @@ -0,0 +1,492 @@ +import os +import sys +import weakref +import torch +torch.multiprocessing.set_start_method('spawn') +import torch.nn as nn +import torch.utils.data +from functools import partial + +if sys.version_info >= (3, 10): + from collections.abc import Iterator +else: + from collections import Iterator +from tensorboardX import SummaryWriter + +from .defaults import create_ddp_model, worker_init_fn +from .hooks import HookBase, build_hooks +import pointcept.utils.comm as comm +from pointcept.datasets import build_dataset, point_collate_fn, collate_fn +from pointcept.models import build_model +from pointcept.utils.logger import get_root_logger +from pointcept.utils.optimizer import build_optimizer +from pointcept.utils.scheduler import build_scheduler +from pointcept.utils.events import EventStorage +from pointcept.utils.registry import Registry + +from sklearn.preprocessing import QuantileTransformer +from pointcept.utils.timer import Timer + +TRAINERS = Registry("trainers") +from cuml.cluster.hdbscan import HDBSCAN +# from sklearn.cluster import HDBSCAN +import open3d as o3d +import matplotlib.colors as mcolors +import numpy as np +from collections import OrderedDict +import trimesh +import pointops + +class TrainerBase: + def __init__(self) -> None: + self.hooks = [] + self.epoch = 0 + self.start_epoch = 0 + self.max_epoch = 0 + self.max_iter = 0 + self.comm_info = dict() + self.data_iterator: Iterator = enumerate([]) + self.storage: EventStorage + self.writer: SummaryWriter + self._iter_timer = Timer() + + def register_hooks(self, hooks) -> None: + hooks = build_hooks(hooks) + for h in hooks: + assert isinstance(h, HookBase) + # To avoid circular reference, hooks and trainer cannot own each other. 
+ # This normally does not matter, but will cause memory leak if the + # involved objects contain __del__: + # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/ + h.trainer = weakref.proxy(self) + self.hooks.extend(hooks) + + def train(self): + with EventStorage() as self.storage: + # => before train + self.before_train() + for self.epoch in range(self.start_epoch, self.max_epoch): + # => before epoch + self.before_epoch() + # => run_epoch + for ( + self.comm_info["iter"], + self.comm_info["input_dict"], + ) in self.data_iterator: + # => before_step + self.before_step() + # => run_step + self.run_step() + # => after_step + self.after_step() + # => after epoch + self.after_epoch() + # => after train + self.after_train() + + def before_train(self): + for h in self.hooks: + h.before_train() + + def before_epoch(self): + for h in self.hooks: + h.before_epoch() + + def before_step(self): + for h in self.hooks: + h.before_step() + + def run_step(self): + raise NotImplementedError + + def after_step(self): + for h in self.hooks: + h.after_step() + + def after_epoch(self): + for h in self.hooks: + h.after_epoch() + self.storage.reset_histories() + + def after_train(self): + # Sync GPU before running train hooks + comm.synchronize() + for h in self.hooks: + h.after_train() + if comm.is_main_process(): + self.writer.close() + + +@TRAINERS.register_module("DefaultTrainer") +class Trainer(TrainerBase): + def __init__(self, cfg): + super(Trainer, self).__init__() + self.epoch = 0 + self.start_epoch = 0 + self.max_epoch = cfg.eval_epoch + self.best_metric_value = -torch.inf + self.logger = get_root_logger( + log_file=os.path.join(cfg.save_path, "train.log"), + # file_mode="a" if cfg.resume else "w", + file_mode="a", + ) + self.logger.info("=> Loading config ...") + self.cfg = cfg + self.logger.info(f"Save path: {cfg.save_path}") + self.logger.info(f"Config:\n{cfg.pretty_text}") + self.logger.info("=> Building model ...") + self.model = self.build_model() + self.logger.info("=> Building writer ...") + self.writer = self.build_writer() + self.logger.info("=> Building train dataset & dataloader ...") + self.train_loader = self.build_train_loader() + # self.logger.info("=> Building val dataset & dataloader ...") + # self.val_loader = self.build_val_loader() + self.logger.info("=> Building optimize, scheduler, scaler(amp) ...") + self.optimizer = self.build_optimizer() + self.scheduler = self.build_scheduler() + self.scaler = self.build_scaler() + self.logger.info("=> Building hooks ...") + self.register_hooks(self.cfg.hooks) + + # !!! + self.model.scale_statistics = nn.Parameter(self.train_loader.scale_3d_statistics) + self.model.scale_statistics.requires_grad = False + self.model.quantile_transformer = self._get_quantile_func(self.train_loader.scale_3d_statistics) + # print(id(self.model)) + # self.val_scales_list = [0.0, 0.5, 1.0, 1.5, 2.0] + self.val_scales_list = self.cfg.val_scales_list + self.mesh_voting = self.cfg.mesh_voting + self.backbone_weight_path = self.cfg.backbone_weight_path + + def train(self): + with EventStorage() as self.storage: + # => before train + # self.before_train() + self.logger.info(">>>>>>>>>>>>>>>> Start Training >>>>>>>>>>>>>>>>") + # data_time = self._iter_timer.seconds() + # self.trainer.storage.put_scalar("data_time", data_time) + # !!! 
load checkpoint + if self.backbone_weight_path != None: + self.logger.info("=> Loading checkpoint & weight ...") + if os.path.isfile(self.backbone_weight_path): + checkpoint = torch.load( + self.backbone_weight_path, + map_location=lambda storage, loc: storage.cuda(), + ) + weight = OrderedDict() + for key, value in checkpoint["state_dict"].items(): + if not key.startswith("module."): + if comm.get_world_size() > 1: + key = "module." + key # xxx.xxx -> module.xxx.xxx + # Now all keys contain "module." no matter DDP or not. + # if self.keywords in key: + # key = key.replace(self.keywords, self.replacement) + if comm.get_world_size() == 1: + key = key[7:] # module.xxx.xxx -> xxx.xxx + # if key.startswith("backbone."): + # key = key[9:] # backbone.xxx.xxx -> xxx.xxx + key = "backbone." + key # xxx.xxx -> backbone.xxx.xxx + weight[key] = value + load_state_info = self.model.load_state_dict(weight, strict=False) + self.logger.info(f"Missing keys: {load_state_info[0]}") + else: + self.logger.info(f"No weight found at: {self.backbone_weight_path}") + + for self.epoch in range(self.start_epoch, self.max_epoch): + self.model.train() + loss_dict = self.model(self.train_loader.get_data(0)) + loss = loss_dict["instance_loss"] + + # !!! writer + lr = self.optimizer.state_dict()["param_groups"][0]["lr"] + self.writer.add_scalar("lr", lr, self.epoch) + for key in loss_dict.keys(): + self.writer.add_scalar( + "train/" + key, + loss_dict[key].item(), + self.epoch, + ) + if self.epoch % 10 == 0: + self.logger.info( + f"iter: {self.epoch}, total_loss: {loss.item()}, loss_1: {loss_dict['instance_loss_1'].item()}, loss_2: {loss_dict['instance_loss_2'].item()}, loss_3: {loss_dict['instance_loss_3'].item()}" + ) + + # !!! optimizer + self.optimizer.zero_grad() + if self.cfg.enable_amp: + self.scaler.scale(loss).backward() + self.scaler.step(self.optimizer) + + # When enable amp, optimizer.step call are skipped if the loss scaling factor is too large. + # Fix torch warning scheduler step before optimizer step. + scaler = self.scaler.get_scale() + self.scaler.update() + if scaler <= self.scaler.get_scale(): + self.scheduler.step() + else: + loss.backward() + self.optimizer.step() + self.scheduler.step() + + # !!! 
save checkpoint + if (self.epoch + 1) % 5000 == 0: + filename = os.path.join(self.cfg.save_path, "model", f"{str(self.epoch + 1)}.pth") + self.logger.info("Saving checkpoint to: " + filename) + torch.save( + { + "epoch": self.epoch + 1, + "state_dict": self.model.state_dict(), + }, + filename + ".tmp", + ) + os.replace(filename + ".tmp", filename) + self.eval() + + def eval(self): + # val_data = build_dataset(self.cfg.data.val) + self.logger.info("=> Loading checkpoint & weight ...") + if self.cfg.weight and os.path.isfile(self.cfg.weight): + checkpoint = torch.load( + self.cfg.weight, + map_location=lambda storage, loc: storage.cuda(), + ) + load_state_info = self.model.load_state_dict(checkpoint["state_dict"]) + self.logger.info(f"Missing keys: {load_state_info[0]}") + else: + self.logger.info(f"No weight found at: {self.cfg.weight}") + self.cfg.weight = "last" + + self.model.eval() + save_root = os.path.join(self.cfg.save_path, "vis_pcd", os.path.splitext(os.path.basename(self.cfg.weight))[0]) + os.makedirs(save_root, exist_ok=True) + group_save_root = os.path.join(self.cfg.save_path, "results", os.path.splitext(os.path.basename(self.cfg.weight))[0]) + os.makedirs(group_save_root, exist_ok=True) + + hex_colors = list(mcolors.CSS4_COLORS.values()) + rgb_colors = np.array([mcolors.to_rgb(color) for color in hex_colors if color not in ['#000000', '#FFFFFF']]) + def relative_luminance(color): + return 0.2126 * color[0] + 0.7152 * color[1] + 0.0722 * color[2] + rgb_colors = [color for color in rgb_colors if (relative_luminance(color) > 0.4 and relative_luminance(color) < 0.8)] + np.random.shuffle(rgb_colors) + input_dict = self.train_loader.val_data() + + pcd_inverse = self.train_loader.pcd_inverse + if self.mesh_voting: + mesh = trimesh.load(self.train_loader.mesh_path) + if isinstance(mesh, trimesh.Scene): + mesh = mesh.dump(concatenate=True) + mesh.visual = trimesh.visual.ColorVisuals(mesh=mesh) + + for scale in self.val_scales_list: + input_dict["scale"] = scale + instance_feat = self.model(input_dict).cpu().detach().numpy() + + clusterer = HDBSCAN( + cluster_selection_epsilon=0.1, + min_samples=30, + min_cluster_size=30, + allow_single_cluster=False, + ).fit(instance_feat) + + labels = clusterer.labels_ + invalid_label_mask = labels == -1 + if invalid_label_mask.sum() > 0: + if invalid_label_mask.sum() == len(invalid_label_mask): + labels = np.zeros_like(labels) + else: + coord = input_dict["obj"]["coord"].cuda().contiguous().float() + valid_coord = coord[~invalid_label_mask] + valid_offset = torch.tensor(valid_coord.shape[0]).cuda() + invalid_coord = coord[invalid_label_mask] + invalid_offset = torch.tensor(invalid_coord.shape[0]).cuda() + indices, distances = pointops.knn_query(1, valid_coord, valid_offset, invalid_coord, invalid_offset) + indices = indices[:, 0].cpu().numpy() + labels[invalid_label_mask] = labels[~invalid_label_mask][indices] + + + # np.save(os.path.join(group_save_root, f"{str(scale)}.npy"), labels) + save_path = os.path.join(save_root, f"{str(scale)}.ply") + coord = input_dict["obj"]["coord"].cpu().numpy() + random_color = [] + for i in range(max(labels) + 1): + random_color.append(rgb_colors[i % len(rgb_colors)]) + random_color.append(np.array([0, 0, 0])) + color = [random_color[i] for i in labels] + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(coord) + pcd.colors = o3d.utility.Vector3dVector(color) + o3d.io.write_point_cloud(save_path, pcd) + + labels = labels[pcd_inverse] + + # print(len(clusterer.labels_)) + 
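# Side note (illustrative only): the block above re-labels HDBSCAN noise points
# (label == -1) with the label of their nearest clustered neighbour via
# pointops.knn_query on the GPU. A CPU-only equivalent with scikit-learn, handy
# for sanity-checking the behaviour on small inputs:
import numpy as np
from sklearn.neighbors import NearestNeighbors

def fill_noise_labels(coord: np.ndarray, labels: np.ndarray) -> np.ndarray:
    noise = labels == -1
    if noise.all():                       # everything is noise -> single group
        return np.zeros_like(labels)
    if noise.any():
        nn = NearestNeighbors(n_neighbors=1).fit(coord[~noise])
        _, idx = nn.kneighbors(coord[noise])
        labels = labels.copy()
        labels[noise] = labels[~noise][idx[:, 0]]
    return labels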
self.logger.info(f"scale_{scale} has {max(labels)+1} groups") + # print(min(clusterer.labels_)) + if self.mesh_voting: + face_index = self.train_loader.face_index + face_index = face_index[pcd_inverse] + + # Compute votes for each face using NumPy's bincount function + # labels = clusterer.labels_ + num_faces = len(mesh.faces) + num_labels = max(labels) + 1 + votes = np.zeros((num_faces, num_labels), dtype=np.int32) + np.add.at(votes, (face_index, labels), 1) + + # Find the label with most votes for each face using NumPy's argmax function + max_votes_labels = np.argmax(votes, axis=1) + # Set the label to -1 for faces that have no corresponding points + max_votes_labels[np.all(votes == 0, axis=1)] = -1 + + valid_mask = max_votes_labels != -1 + face_centroids = mesh.triangles_center + coord = torch.tensor(face_centroids).cuda().contiguous().float() + valid_coord = coord[valid_mask] + valid_offset = torch.tensor(valid_coord.shape[0]).cuda() + invalid_coord = coord[~valid_mask] + invalid_offset = torch.tensor(invalid_coord.shape[0]).cuda() + indices, distances = pointops.knn_query(1, valid_coord, valid_offset, invalid_coord, invalid_offset) + # # the first column is the point itself + # indices = indices[:, 1].cpu().numpy() + indices = indices[:, 0].cpu().numpy() + mesh_group = max_votes_labels.copy() + mesh_group[~valid_mask] = mesh_group[valid_mask][indices] + + np.save(os.path.join(group_save_root, f"mesh_{str(scale)}.npy"), mesh_group) + + # Assign color to each face based on the label with most votes + for face, label in enumerate(mesh_group): + color = (random_color[label] * 255).astype(np.uint8) + color_with_alpha = np.append(color, 255) # Add alpha value + mesh.visual.face_colors[face] = color_with_alpha + + # Save the new mesh + mesh_save_path = os.path.join(save_root, f"mesh_{str(scale)}.ply") + mesh.export(mesh_save_path) + + + def _get_quantile_func(self, scales: torch.Tensor, distribution="normal"): + """ + Use 3D scale statistics to normalize scales -- use quantile transformer. + """ + scales = scales.flatten() + max_grouping_scale = 2 + scales = scales[(scales > 0) & (scales < max_grouping_scale)] + + scales = scales.detach().cpu().numpy() + + # Calculate quantile transformer + quantile_transformer = QuantileTransformer(output_distribution=distribution) + quantile_transformer = quantile_transformer.fit(scales.reshape(-1, 1)) + + def quantile_transformer_func(scales): + # This function acts as a wrapper for QuantileTransformer. + # QuantileTransformer expects a numpy array, while we have a torch tensor. + return torch.Tensor( + quantile_transformer.transform(scales.cpu().numpy()) + ).to(scales.device) + + return quantile_transformer_func + + def run_step(self): + input_dict = self.comm_info["input_dict"] + for key in input_dict.keys(): + if isinstance(input_dict[key], torch.Tensor): + input_dict[key] = input_dict[key].cuda(non_blocking=True) + with torch.cuda.amp.autocast(enabled=self.cfg.enable_amp): + output_dict = self.model(input_dict) + loss = output_dict["loss"] + self.optimizer.zero_grad() + if self.cfg.enable_amp: + self.scaler.scale(loss).backward() + self.scaler.step(self.optimizer) + + # When enable amp, optimizer.step call are skipped if the loss scaling factor is too large. + # Fix torch warning scheduler step before optimizer step. 
+ scaler = self.scaler.get_scale() + self.scaler.update() + if scaler <= self.scaler.get_scale(): + self.scheduler.step() + else: + loss.backward() + self.optimizer.step() + self.scheduler.step() + if self.cfg.empty_cache: + torch.cuda.empty_cache() + self.comm_info["model_output_dict"] = output_dict + + def build_model(self): + model = build_model(self.cfg.model) + if self.cfg.sync_bn: + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + # logger.info(f"Model: \n{self.model}") + self.logger.info(f"Num params: {n_parameters}") + model = create_ddp_model( + model.cuda(), + broadcast_buffers=False, + find_unused_parameters=self.cfg.find_unused_parameters, + ) + return model + + def build_writer(self): + writer = SummaryWriter(self.cfg.save_path) if comm.is_main_process() else None + self.logger.info(f"Tensorboard writer logging dir: {self.cfg.save_path}") + return writer + + def build_train_loader(self): + self.cfg.data.train.oid = self.cfg.oid + self.cfg.data.train.label = self.cfg.label + train_data = build_dataset(self.cfg.data.train) + return train_data + + def build_val_loader(self): + val_loader = None + if self.cfg.evaluate: + val_data = build_dataset(self.cfg.data.val) + if comm.get_world_size() > 1: + val_sampler = torch.utils.data.distributed.DistributedSampler(val_data) + else: + val_sampler = None + val_loader = torch.utils.data.DataLoader( + val_data, + batch_size=self.cfg.batch_size_val_per_gpu, + shuffle=False, + num_workers=self.cfg.num_worker_per_gpu, + pin_memory=True, + sampler=val_sampler, + collate_fn=collate_fn, + ) + return val_loader + + def build_optimizer(self): + return build_optimizer(self.cfg.optimizer, self.model, self.cfg.param_dicts) + + def build_scheduler(self): + assert hasattr(self, "optimizer") + assert hasattr(self, "train_loader") + # self.cfg.scheduler.total_steps = len(self.train_loader) * self.cfg.eval_epoch + self.cfg.scheduler.total_steps = self.max_epoch + return build_scheduler(self.cfg.scheduler, self.optimizer) + + def build_scaler(self): + scaler = torch.cuda.amp.GradScaler() if self.cfg.enable_amp else None + return scaler + + +@TRAINERS.register_module("MultiDatasetTrainer") +class MultiDatasetTrainer(Trainer): + def build_train_loader(self): + from pointcept.datasets import MultiDatasetDataloader + + train_data = build_dataset(self.cfg.data.train) + train_loader = MultiDatasetDataloader( + train_data, + self.cfg.batch_size_per_gpu, + self.cfg.num_worker_per_gpu, + self.cfg.mix_prob, + self.cfg.seed, + ) + self.comm_info["iter_per_epoch"] = len(train_loader) + return train_loader diff --git a/UniRig/src/model/pointcept/models/PTv3Object.py b/UniRig/src/model/pointcept/models/PTv3Object.py new file mode 100644 index 0000000000000000000000000000000000000000..5006353c34bb95b1c266e420d960a9c288b2ade4 --- /dev/null +++ b/UniRig/src/model/pointcept/models/PTv3Object.py @@ -0,0 +1,664 @@ +from functools import partial +from addict import Dict +import math +import torch +import torch.nn as nn +import spconv.pytorch as spconv +import torch_scatter +from timm.models.layers import DropPath +from typing import Union +from einops import rearrange + +try: + import flash_attn +except ImportError: + flash_attn = None + +from .utils.misc import offset2bincount +from .utils.structure import Point +from .modules import PointModule, PointSequential + + +class RPE(torch.nn.Module): + def __init__(self, patch_size, num_heads): + super().__init__() + self.patch_size = patch_size + 
self.num_heads = num_heads + self.pos_bnd = int((4 * patch_size) ** (1 / 3) * 2) + self.rpe_num = 2 * self.pos_bnd + 1 + self.rpe_table = torch.nn.Parameter(torch.zeros(3 * self.rpe_num, num_heads)) + torch.nn.init.trunc_normal_(self.rpe_table, std=0.02) + + def forward(self, coord): + idx = ( + coord.clamp(-self.pos_bnd, self.pos_bnd) # clamp into bnd + + self.pos_bnd # relative position to positive index + + torch.arange(3, device=coord.device) * self.rpe_num # x, y, z stride + ) + out = self.rpe_table.index_select(0, idx.reshape(-1)) + out = out.view(idx.shape + (-1,)).sum(3) + out = out.permute(0, 3, 1, 2) # (N, K, K, H) -> (N, H, K, K) + return out + +class QueryKeyNorm(nn.Module): + def __init__(self, channels, num_heads): + super(QueryKeyNorm, self).__init__() + self.num_heads = num_heads + self.norm = nn.LayerNorm(channels // num_heads, elementwise_affine=False) + + def forward(self, qkv): + H = self.num_heads + #qkv = qkv.reshape(-1, 3, H, qkv.shape[1] // H).permute(1, 0, 2, 3) + qkv = rearrange(qkv, 'N (S H Ch) -> S N H Ch', H=H, S=3) + q, k, v = qkv.unbind(dim=0) + # q, k, v: [N, H, C // H] + q_norm = self.norm(q) + k_norm = self.norm(k) + + # qkv_norm: [3, N, H, C // H] + qkv_norm = torch.stack([q_norm, k_norm, v]) + qkv_norm = rearrange(qkv_norm, 'S N H Ch -> N (S H Ch)') + return qkv_norm + +class SerializedAttention(PointModule): + def __init__( + self, + channels, + num_heads, + patch_size, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + order_index=0, + enable_rpe=False, + enable_flash=True, + upcast_attention=True, + upcast_softmax=True, + enable_qknorm=False, + ): + super().__init__() + assert channels % num_heads == 0, f"channels {channels} must be divisible by num_heads {num_heads}" + self.channels = channels + self.num_heads = num_heads + self.scale = qk_scale or (channels // num_heads) ** -0.5 + self.order_index = order_index + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.enable_rpe = enable_rpe + self.enable_flash = enable_flash + self.enable_qknorm = enable_qknorm + if enable_qknorm: + self.qknorm = QueryKeyNorm(channels, num_heads) + else: + print("WARNING: enable_qknorm is False in PTv3Object and training may be fragile") + if enable_flash: + assert ( + enable_rpe is False + ), "Set enable_rpe to False when enable Flash Attention" + assert ( + upcast_attention is False + ), "Set upcast_attention to False when enable Flash Attention" + assert ( + upcast_softmax is False + ), "Set upcast_softmax to False when enable Flash Attention" + assert flash_attn is not None, "Make sure flash_attn is installed." 
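# (Aside on the RPE lookup defined above: relative grid offsets are clamped to
# [-pos_bnd, pos_bnd], shifted to non-negative indices, and each of x/y/z gets its
# own block of rpe_num rows in rpe_table, so one index_select serves all three axes.
# Toy numbers with patch_size=48, the default used by Block below:)
import torch

patch_size = 48
pos_bnd = int((4 * patch_size) ** (1 / 3) * 2)            # 11
rpe_num = 2 * pos_bnd + 1                                 # 23 rows per axis

rel = torch.tensor([[[-30, 0, 5]]])                       # one relative (dx, dy, dz)
idx = rel.clamp(-pos_bnd, pos_bnd) + pos_bnd + torch.arange(3) * rpe_num
print(idx)                                                # tensor([[[ 0, 34, 62]]])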
+ self.patch_size = patch_size + self.attn_drop = attn_drop + else: + # when disable flash attention, we still don't want to use mask + # consequently, patch size will auto set to the + # min number of patch_size_max and number of points + self.patch_size_max = patch_size + self.patch_size = 0 + self.attn_drop = torch.nn.Dropout(attn_drop) + + self.qkv = torch.nn.Linear(channels, channels * 3, bias=qkv_bias) + self.proj = torch.nn.Linear(channels, channels) + self.proj_drop = torch.nn.Dropout(proj_drop) + self.softmax = torch.nn.Softmax(dim=-1) + self.rpe = RPE(patch_size, num_heads) if self.enable_rpe else None + + @torch.no_grad() + def get_rel_pos(self, point, order): + K = self.patch_size + rel_pos_key = f"rel_pos_{self.order_index}" + if rel_pos_key not in point.keys(): + grid_coord = point.grid_coord[order] + grid_coord = grid_coord.reshape(-1, K, 3) + point[rel_pos_key] = grid_coord.unsqueeze(2) - grid_coord.unsqueeze(1) + return point[rel_pos_key] + + @torch.no_grad() + def get_padding_and_inverse(self, point): + pad_key = "pad" + unpad_key = "unpad" + cu_seqlens_key = "cu_seqlens_key" + if ( + pad_key not in point.keys() + or unpad_key not in point.keys() + or cu_seqlens_key not in point.keys() + ): + offset = point.offset + bincount = offset2bincount(offset) + bincount_pad = ( + torch.div( + bincount + self.patch_size - 1, + self.patch_size, + rounding_mode="trunc", + ) + * self.patch_size + ) + # only pad point when num of points larger than patch_size + mask_pad = bincount > self.patch_size + bincount_pad = ~mask_pad * bincount + mask_pad * bincount_pad + _offset = nn.functional.pad(offset, (1, 0)) + _offset_pad = nn.functional.pad(torch.cumsum(bincount_pad, dim=0), (1, 0)) + pad = torch.arange(_offset_pad[-1], device=offset.device) + unpad = torch.arange(_offset[-1], device=offset.device) + cu_seqlens = [] + for i in range(len(offset)): + unpad[_offset[i] : _offset[i + 1]] += _offset_pad[i] - _offset[i] + if bincount[i] != bincount_pad[i]: + pad[ + _offset_pad[i + 1] + - self.patch_size + + (bincount[i] % self.patch_size) : _offset_pad[i + 1] + ] = pad[ + _offset_pad[i + 1] + - 2 * self.patch_size + + (bincount[i] % self.patch_size) : _offset_pad[i + 1] + - self.patch_size + ] + pad[_offset_pad[i] : _offset_pad[i + 1]] -= _offset_pad[i] - _offset[i] + cu_seqlens.append( + torch.arange( + _offset_pad[i], + _offset_pad[i + 1], + step=self.patch_size, + dtype=torch.int32, + device=offset.device, + ) + ) + point[pad_key] = pad + point[unpad_key] = unpad + point[cu_seqlens_key] = nn.functional.pad( + torch.concat(cu_seqlens), (0, 1), value=_offset_pad[-1] + ) + return point[pad_key], point[unpad_key], point[cu_seqlens_key] + + def forward(self, point): + if not self.enable_flash: + self.patch_size = min( + offset2bincount(point.offset).min().tolist(), self.patch_size_max + ) + + H = self.num_heads + K = self.patch_size + C = self.channels + + pad, unpad, cu_seqlens = self.get_padding_and_inverse(point) + + order = point.serialized_order[self.order_index][pad] + inverse = unpad[point.serialized_inverse[self.order_index]] + + # padding and reshape feat and batch for serialized point patch + qkv = self.qkv(point.feat)[order] + if self.enable_qknorm: + qkv = self.qknorm(qkv) + + if not self.enable_flash: + # encode and reshape qkv: (N', K, 3, H, C') => (3, N', H, K, C') + q, k, v = ( + qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0) + ) + # attn + if self.upcast_attention: + q = q.float() + k = k.float() + attn = (q * self.scale) @ k.transpose(-2, -1) # (N', H, K, 
K) + if self.enable_rpe: + attn = attn + self.rpe(self.get_rel_pos(point, order)) + if self.upcast_softmax: + attn = attn.float() + attn = self.softmax(attn) + attn = self.attn_drop(attn).to(qkv.dtype) + feat = (attn @ v).transpose(1, 2).reshape(-1, C) + else: + feat = flash_attn.flash_attn_varlen_qkvpacked_func( + qkv.half().reshape(-1, 3, H, C // H), + cu_seqlens, + max_seqlen=self.patch_size, + dropout_p=self.attn_drop if self.training else 0, + softmax_scale=self.scale, + ).reshape(-1, C) + feat = feat.to(qkv.dtype) + feat = feat[inverse] + + # ffn + feat = self.proj(feat) + feat = self.proj_drop(feat) + point.feat = feat + return point + + +class MLP(nn.Module): + def __init__( + self, + in_channels, + hidden_channels=None, + out_channels=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_channels = out_channels or in_channels + hidden_channels = hidden_channels or in_channels + self.fc1 = nn.Linear(in_channels, hidden_channels) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_channels, out_channels) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Block(PointModule): + def __init__( + self, + channels, + num_heads, + patch_size=48, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + act_layer=nn.GELU, + pre_norm=True, + order_index=0, + cpe_indice_key=None, + enable_rpe=False, + enable_flash=True, + upcast_attention=True, + upcast_softmax=True, + enable_qknorm=False, + ): + super().__init__() + self.channels = channels + self.pre_norm = pre_norm + + self.cpe = PointSequential( + spconv.SubMConv3d( + channels, + channels, + kernel_size=3, + bias=True, + indice_key=cpe_indice_key, + ), + nn.Linear(channels, channels), + norm_layer(channels), + ) + + self.norm1 = PointSequential(norm_layer(channels)) + self.attn = SerializedAttention( + channels=channels, + patch_size=patch_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=proj_drop, + order_index=order_index, + enable_rpe=enable_rpe, + enable_flash=enable_flash, + upcast_attention=upcast_attention, + upcast_softmax=upcast_softmax, + enable_qknorm=enable_qknorm, + ) + self.norm2 = PointSequential(norm_layer(channels)) + self.mlp = PointSequential( + MLP( + in_channels=channels, + hidden_channels=int(channels * mlp_ratio), + out_channels=channels, + act_layer=act_layer, + drop=proj_drop, + ) + ) + self.drop_path = PointSequential( + DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + ) + + def forward(self, point: Point): + shortcut = point.feat + point = self.cpe(point) + point.feat = shortcut + point.feat + shortcut = point.feat + if self.pre_norm: + point = self.norm1(point) + point = self.drop_path(self.attn(point)) + point.feat = shortcut + point.feat + if not self.pre_norm: + point = self.norm1(point) + + shortcut = point.feat + if self.pre_norm: + point = self.norm2(point) + point = self.drop_path(self.mlp(point)) + point.feat = shortcut + point.feat + if not self.pre_norm: + point = self.norm2(point) + # point.sparse_conv_feat.replace_feature(point.feat) + point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat) + return point + + +class SerializedPooling(PointModule): + def __init__( + self, + in_channels, + out_channels, + stride=2, + norm_layer=None, + act_layer=None, + reduce="max", + shuffle_orders=True, + 
traceable=True, # record parent and cluster + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + assert stride == 2 ** (math.ceil(stride) - 1).bit_length() # 2, 4, 8 + # TODO: add support to grid pool (any stride) + self.stride = stride + assert reduce in ["sum", "mean", "min", "max"] + self.reduce = reduce + self.shuffle_orders = shuffle_orders + self.traceable = traceable + + self.proj = nn.Linear(in_channels, out_channels) + if norm_layer is not None: + self.norm = PointSequential(norm_layer(out_channels)) + if act_layer is not None: + self.act = PointSequential(act_layer()) + + def forward(self, point: Point): + pooling_depth = (math.ceil(self.stride) - 1).bit_length() + if pooling_depth > point.serialized_depth: + pooling_depth = 0 + assert { + "serialized_code", + "serialized_order", + "serialized_inverse", + "serialized_depth", + }.issubset( + point.keys() + ), "Run point.serialization() point cloud before SerializedPooling" + + code = point.serialized_code >> pooling_depth * 3 + code_, cluster, counts = torch.unique( + code[0], + sorted=True, + return_inverse=True, + return_counts=True, + ) + # indices of point sorted by cluster, for torch_scatter.segment_csr + _, indices = torch.sort(cluster) + # index pointer for sorted point, for torch_scatter.segment_csr + idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)]) + # head_indices of each cluster, for reduce attr e.g. code, batch + head_indices = indices[idx_ptr[:-1]] + # generate down code, order, inverse + code = code[:, head_indices] + order = torch.argsort(code) + inverse = torch.zeros_like(order).scatter_( + dim=1, + index=order, + src=torch.arange(0, code.shape[1], device=order.device).repeat( + code.shape[0], 1 + ), + ) + + if self.shuffle_orders: + perm = torch.randperm(code.shape[0]) + code = code[perm] + order = order[perm] + inverse = inverse[perm] + + # collect information + point_dict = Dict( + feat=torch_scatter.segment_csr( + self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce + ), + coord=torch_scatter.segment_csr( + point.coord[indices], idx_ptr, reduce="mean" + ), + grid_coord=point.grid_coord[head_indices] >> pooling_depth, + serialized_code=code, + serialized_order=order, + serialized_inverse=inverse, + serialized_depth=point.serialized_depth - pooling_depth, + batch=point.batch[head_indices], + ) + + if "condition" in point.keys(): + point_dict["condition"] = point.condition + if "context" in point.keys(): + point_dict["context"] = point.context + + if self.traceable: + point_dict["pooling_inverse"] = cluster + point_dict["pooling_parent"] = point + point = Point(point_dict) + if self.norm is not None: + point = self.norm(point) + if self.act is not None: + point = self.act(point) + point.sparsify() + return point + + +class SerializedUnpooling(PointModule): + def __init__( + self, + in_channels, + skip_channels, + out_channels, + norm_layer=None, + act_layer=None, + traceable=False, # record parent and cluster + ): + super().__init__() + self.proj = PointSequential(nn.Linear(in_channels, out_channels)) + self.proj_skip = PointSequential(nn.Linear(skip_channels, out_channels)) + + if norm_layer is not None: + self.proj.add(norm_layer(out_channels)) + self.proj_skip.add(norm_layer(out_channels)) + + if act_layer is not None: + self.proj.add(act_layer()) + self.proj_skip.add(act_layer()) + + self.traceable = traceable + + def forward(self, point): + assert "pooling_parent" in point.keys() + assert "pooling_inverse" in point.keys() + parent = 
point.pop("pooling_parent") + inverse = point.pop("pooling_inverse") + point = self.proj(point) + parent = self.proj_skip(parent) + parent.feat = parent.feat + point.feat[inverse] + + if self.traceable: + parent["unpooling_parent"] = point + return parent + + +class Embedding(PointModule): + def __init__( + self, + in_channels, + embed_channels, + norm_layer=None, + act_layer=None, + res_linear=False, + ): + super().__init__() + self.in_channels = in_channels + self.embed_channels = embed_channels + + # TODO: check remove spconv + self.stem = PointSequential( + conv=spconv.SubMConv3d( + in_channels, + embed_channels, + kernel_size=5, + padding=1, + bias=False, + indice_key="stem", + ) + ) + if norm_layer is not None: + self.stem.add(norm_layer(embed_channels), name="norm") + if act_layer is not None: + self.stem.add(act_layer(), name="act") + + if res_linear: + self.res_linear = nn.Linear(in_channels, embed_channels) + else: + self.res_linear = None + + def forward(self, point: Point): + if self.res_linear: + res_feature = self.res_linear(point.feat) + point = self.stem(point) + if self.res_linear: + point.feat = point.feat + res_feature + point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat) + return point + + +class PointTransformerV3Object(PointModule): + def __init__( + self, + in_channels=9, + order=("z", "z-trans", "hilbert", "hilbert-trans"), + stride=(), + enc_depths=(3, 3, 3, 6, 16), + enc_channels=(32, 64, 128, 256, 384), + enc_num_head=(2, 4, 8, 16, 24), + enc_patch_size=(1024, 1024, 1024, 1024, 1024), + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + drop_path=0.0, + pre_norm=True, + shuffle_orders=True, + enable_rpe=False, + enable_flash=True, + upcast_attention=False, + upcast_softmax=False, + cls_mode=False, + enable_qknorm=False, + layer_norm=False, + res_linear=True, + ): + super().__init__() + self.num_stages = len(enc_depths) + self.order = [order] if isinstance(order, str) else order + self.cls_mode = cls_mode + self.shuffle_orders = shuffle_orders + + # norm layers + if layer_norm: + bn_layer = partial(nn.LayerNorm) + else: + print("WARNING: use BatchNorm in ptv3obj !!!") + bn_layer = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01) + ln_layer = nn.LayerNorm + # activation layers + act_layer = nn.GELU + + self.embedding = Embedding( + in_channels=in_channels, + embed_channels=enc_channels[0], + norm_layer=bn_layer, + act_layer=act_layer, + res_linear=res_linear, + ) + + # encoder + enc_drop_path = [ + x.item() for x in torch.linspace(0, drop_path, sum(enc_depths)) + ] + self.enc = PointSequential() + for s in range(self.num_stages): + enc_drop_path_ = enc_drop_path[ + sum(enc_depths[:s]) : sum(enc_depths[: s + 1]) + ] + enc = PointSequential() + if s > 0: + enc.add(nn.Linear(enc_channels[s - 1], enc_channels[s])) + + for i in range(enc_depths[s]): + enc.add( + Block( + channels=enc_channels[s], + num_heads=enc_num_head[s], + patch_size=enc_patch_size[s], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=proj_drop, + drop_path=enc_drop_path_[i], + norm_layer=ln_layer, + act_layer=act_layer, + pre_norm=pre_norm, + order_index=i % len(self.order), + cpe_indice_key=f"stage{s}", + enable_rpe=enable_rpe, + enable_flash=enable_flash, + upcast_attention=upcast_attention, + upcast_softmax=upcast_softmax, + enable_qknorm=enable_qknorm, + ), + name=f"block{i}", + ) + if len(enc) != 0: + self.enc.add(module=enc, name=f"enc{s}") + + + def forward(self, data_dict, 
min_coord=None): + point = Point(data_dict) + point.serialization(order=self.order, shuffle_orders=self.shuffle_orders, min_coord=min_coord) + point.sparsify() + point = self.embedding(point) + point = self.enc(point) + return point + +def get_encoder(pretrained_path: Union[str, None]=None, freeze_encoder: bool=False, **kwargs) -> PointTransformerV3Object: + point_encoder = PointTransformerV3Object(**kwargs) + if pretrained_path is not None: + checkpoint = torch.load(pretrained_path) + state_dict = checkpoint["state_dict"] + state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} + point_encoder.load_state_dict(state_dict, strict=False) + if freeze_encoder is True: + for name, param in point_encoder.named_parameters(): + if 'res_linear' not in name and 'qknorm' not in name: + param.requires_grad = False + return point_encoder \ No newline at end of file diff --git a/UniRig/src/model/pointcept/models/__init__.py b/UniRig/src/model/pointcept/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f7027e601f513c17e2e5c36a6491e928dbfe4b1a --- /dev/null +++ b/UniRig/src/model/pointcept/models/__init__.py @@ -0,0 +1 @@ +from .PTv3Object import PointTransformerV3Object diff --git a/UniRig/src/model/pointcept/models/modules.py b/UniRig/src/model/pointcept/models/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..2fb77634cf5f54759ee301d6cffb966ca0076396 --- /dev/null +++ b/UniRig/src/model/pointcept/models/modules.py @@ -0,0 +1,83 @@ +import sys +import torch.nn as nn +import spconv.pytorch as spconv +from collections import OrderedDict +from .utils.structure import Point + + +class PointModule(nn.Module): + r"""PointModule + placeholder, all module subclass from this will take Point in PointSequential. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class PointSequential(PointModule): + r"""A sequential container. + Modules will be added to it in the order they are passed in the constructor. + Alternatively, an ordered dict of modules can also be passed in. 
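    Example (illustrative only; assumes `point` is a Point instance)::

        seq = PointSequential(
            nn.Linear(32, 64),
            nn.GELU(),
        )
        seq.add(nn.Linear(64, 32), name="proj")
        point = seq(point)   # plain nn.Modules are applied to point.feat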
+ """ + + def __init__(self, *args, **kwargs): + super().__init__() + if len(args) == 1 and isinstance(args[0], OrderedDict): + for key, module in args[0].items(): + self.add_module(key, module) + else: + for idx, module in enumerate(args): + self.add_module(str(idx), module) + for name, module in kwargs.items(): + if sys.version_info < (3, 6): + raise ValueError("kwargs only supported in py36+") + if name in self._modules: + raise ValueError("name exists.") + self.add_module(name, module) + + def __getitem__(self, idx): + if not (-len(self) <= idx < len(self)): + raise IndexError("index {} is out of range".format(idx)) + if idx < 0: + idx += len(self) + it = iter(self._modules.values()) + for i in range(idx): + next(it) + return next(it) + + def __len__(self): + return len(self._modules) + + def add(self, module, name=None): + if name is None: + name = str(len(self._modules)) + if name in self._modules: + raise KeyError("name exists") + self.add_module(name, module) + + def forward(self, input): + for k, module in self._modules.items(): + # Point module + if isinstance(module, PointModule): + input = module(input) + # Spconv module + elif spconv.modules.is_spconv_module(module): + if isinstance(input, Point): + input.sparse_conv_feat = module(input.sparse_conv_feat) + input.feat = input.sparse_conv_feat.features + else: + input = module(input) + # PyTorch module + else: + if isinstance(input, Point): + input.feat = module(input.feat) + if "sparse_conv_feat" in input.keys(): + input.sparse_conv_feat = input.sparse_conv_feat.replace_feature( + input.feat + ) + elif isinstance(input, spconv.SparseConvTensor): + if input.indices.shape[0] != 0: + input = input.replace_feature(module(input.features)) + else: + input = module(input) + return input diff --git a/UniRig/src/model/pointcept/models/utils/__init__.py b/UniRig/src/model/pointcept/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..66e6bc0f62993abb3625a9598f54e7775aeb0008 --- /dev/null +++ b/UniRig/src/model/pointcept/models/utils/__init__.py @@ -0,0 +1,4 @@ +from .misc import offset2batch, offset2bincount, batch2offset, off_diagonal +from .checkpoint import checkpoint +from .serialization import encode, decode +from .structure import Point diff --git a/UniRig/src/model/pointcept/models/utils/checkpoint.py b/UniRig/src/model/pointcept/models/utils/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..09b6d307ccf5b6d79135eec9274c9c8a8c29432e --- /dev/null +++ b/UniRig/src/model/pointcept/models/utils/checkpoint.py @@ -0,0 +1,50 @@ +import torch + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. 
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) diff --git a/UniRig/src/model/pointcept/models/utils/misc.py b/UniRig/src/model/pointcept/models/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..fb1b282f480f23cba30644edf35477a9358d0e11 --- /dev/null +++ b/UniRig/src/model/pointcept/models/utils/misc.py @@ -0,0 +1,28 @@ +import torch + + +@torch.inference_mode() +def offset2bincount(offset): + return torch.diff( + offset, prepend=torch.tensor([0], device=offset.device, dtype=torch.long) + ) + + +@torch.inference_mode() +def offset2batch(offset): + bincount = offset2bincount(offset) + return torch.arange( + len(bincount), device=offset.device, dtype=torch.long + ).repeat_interleave(bincount) + + +@torch.inference_mode() +def batch2offset(batch): + return torch.cumsum(batch.bincount(), dim=0).long() + + +def off_diagonal(x): + # return a flattened view of the off-diagonal elements of a square matrix + n, m = x.shape + assert n == m + return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() diff --git a/UniRig/src/model/pointcept/models/utils/serialization/__init__.py b/UniRig/src/model/pointcept/models/utils/serialization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..058c5e1001c76d9c7014bf0bbb824eec4f54f476 --- /dev/null +++ b/UniRig/src/model/pointcept/models/utils/serialization/__init__.py @@ -0,0 +1,8 @@ +from .default import ( + encode, + decode, + z_order_encode, + z_order_decode, + hilbert_encode, + hilbert_decode, +) diff --git a/UniRig/src/model/pointcept/models/utils/serialization/default.py b/UniRig/src/model/pointcept/models/utils/serialization/default.py new file mode 100644 index 0000000000000000000000000000000000000000..15898b55625fc0e1125db9b713e900892f04176c --- /dev/null +++ b/UniRig/src/model/pointcept/models/utils/serialization/default.py @@ -0,0 +1,59 @@ +import torch +from .z_order import xyz2key as z_order_encode_ +from .z_order import key2xyz as z_order_decode_ +from .hilbert import encode as hilbert_encode_ +from .hilbert import decode as hilbert_decode_ + + +@torch.inference_mode() +def encode(grid_coord, batch=None, depth=16, order="z"): + assert order in {"z", "z-trans", "hilbert", "hilbert-trans"} + if order == "z": + code = z_order_encode(grid_coord, depth=depth) + elif order == "z-trans": + code = z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth) + elif order == "hilbert": + code = hilbert_encode(grid_coord, depth=depth) + elif order == "hilbert-trans": + code = hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth) + else: + raise NotImplementedError + if batch is not None: + batch = 
batch.long() + code = batch << depth * 3 | code + return code + + +@torch.inference_mode() +def decode(code, depth=16, order="z"): + assert order in {"z", "hilbert"} + batch = code >> depth * 3 + code = code & ((1 << depth * 3) - 1) + if order == "z": + grid_coord = z_order_decode(code, depth=depth) + elif order == "hilbert": + grid_coord = hilbert_decode(code, depth=depth) + else: + raise NotImplementedError + return grid_coord, batch + + +def z_order_encode(grid_coord: torch.Tensor, depth: int = 16): + x, y, z = grid_coord[:, 0].long(), grid_coord[:, 1].long(), grid_coord[:, 2].long() + # we block the support to batch, maintain batched code in Point class + code = z_order_encode_(x, y, z, b=None, depth=depth) + return code + + +def z_order_decode(code: torch.Tensor, depth): + x, y, z = z_order_decode_(code, depth=depth) + grid_coord = torch.stack([x, y, z], dim=-1) # (N, 3) + return grid_coord + + +def hilbert_encode(grid_coord: torch.Tensor, depth: int = 16): + return hilbert_encode_(grid_coord, num_dims=3, num_bits=depth) + + +def hilbert_decode(code: torch.Tensor, depth: int = 16): + return hilbert_decode_(code, num_dims=3, num_bits=depth) diff --git a/UniRig/src/model/pointcept/models/utils/serialization/hilbert.py b/UniRig/src/model/pointcept/models/utils/serialization/hilbert.py new file mode 100644 index 0000000000000000000000000000000000000000..79356577e12179a35341af9b430a1ba8e238b832 --- /dev/null +++ b/UniRig/src/model/pointcept/models/utils/serialization/hilbert.py @@ -0,0 +1,295 @@ +import torch + + +def right_shift(binary, k=1, axis=-1): + """Right shift an array of binary values. + + Parameters: + ----------- + binary: An ndarray of binary values. + + k: The number of bits to shift. Default 1. + + axis: The axis along which to shift. Default -1. + + Returns: + -------- + Returns an ndarray with zero prepended and the ends truncated, along + whatever axis was specified.""" + + # If we're shifting the whole thing, just return zeros. + if binary.shape[axis] <= k: + return torch.zeros_like(binary) + + # Determine the padding pattern. + # padding = [(0,0)] * len(binary.shape) + # padding[axis] = (k,0) + + # Determine the slicing pattern to eliminate just the last one. + slicing = [slice(None)] * len(binary.shape) + slicing[axis] = slice(None, -k) + shifted = torch.nn.functional.pad( + binary[tuple(slicing)], (k, 0), mode="constant", value=0 + ) + + return shifted + + +def binary2gray(binary, axis=-1): + """Convert an array of binary values into Gray codes. + + This uses the classic X ^ (X >> 1) trick to compute the Gray code. + + Parameters: + ----------- + binary: An ndarray of binary values. + + axis: The axis along which to compute the gray code. Default=-1. + + Returns: + -------- + Returns an ndarray of Gray codes. + """ + shifted = right_shift(binary, axis=axis) + + # Do the X ^ (X >> 1) trick. + gray = torch.logical_xor(binary, shifted) + + return gray + + +def gray2binary(gray, axis=-1): + """Convert an array of Gray codes back into binary values. + + Parameters: + ----------- + gray: An ndarray of gray codes. + + axis: The axis along which to perform Gray decoding. Default=-1. + + Returns: + -------- + Returns an ndarray of binary values. + """ + + # Loop the log2(bits) number of times necessary, with shift and xor. 
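+    # Worked example (illustrative), with bits ordered MSB-first along `axis`:
+    #   binary 1100 -> gray 1010 via binary2gray; the loop below undoes this:
+    #   after the shift-2 pass we have 1000, after the shift-1 pass 1100 again.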
+ shift = 2 ** (torch.Tensor([gray.shape[axis]]).log2().ceil().int() - 1) + while shift > 0: + gray = torch.logical_xor(gray, right_shift(gray, shift)) + shift = torch.div(shift, 2, rounding_mode="floor") + return gray + + +def encode(locs, num_dims, num_bits): + """Decode an array of locations in a hypercube into a Hilbert integer. + + This is a vectorized-ish version of the Hilbert curve implementation by John + Skilling as described in: + + Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference + Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics. + + Params: + ------- + locs - An ndarray of locations in a hypercube of num_dims dimensions, in + which each dimension runs from 0 to 2**num_bits-1. The shape can + be arbitrary, as long as the last dimension of the same has size + num_dims. + + num_dims - The dimensionality of the hypercube. Integer. + + num_bits - The number of bits for each dimension. Integer. + + Returns: + -------- + The output is an ndarray of uint64 integers with the same shape as the + input, excluding the last dimension, which needs to be num_dims. + """ + + # Keep around the original shape for later. + orig_shape = locs.shape + bitpack_mask = 1 << torch.arange(0, 8).to(locs.device) + bitpack_mask_rev = bitpack_mask.flip(-1) + + if orig_shape[-1] != num_dims: + raise ValueError( + """ + The shape of locs was surprising in that the last dimension was of size + %d, but num_dims=%d. These need to be equal. + """ + % (orig_shape[-1], num_dims) + ) + + if num_dims * num_bits > 63: + raise ValueError( + """ + num_dims=%d and num_bits=%d for %d bits total, which can't be encoded + into a int64. Are you sure you need that many points on your Hilbert + curve? + """ + % (num_dims, num_bits, num_dims * num_bits) + ) + + # Treat the location integers as 64-bit unsigned and then split them up into + # a sequence of uint8s. Preserve the association by dimension. + locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1) + + # Now turn these into bits and truncate to num_bits. + gray = ( + locs_uint8.unsqueeze(-1) + .bitwise_and(bitpack_mask_rev) + .ne(0) + .byte() + .flatten(-2, -1)[..., -num_bits:] + ) + + # Run the decoding process the other way. + # Iterate forwards through the bits. + for bit in range(0, num_bits): + # Iterate forwards through the dimensions. + for dim in range(0, num_dims): + # Identify which ones have this bit active. + mask = gray[:, dim, bit] + + # Where this bit is on, invert the 0 dimension for lower bits. + gray[:, 0, bit + 1 :] = torch.logical_xor( + gray[:, 0, bit + 1 :], mask[:, None] + ) + + # Where the bit is off, exchange the lower bits with the 0 dimension. + to_flip = torch.logical_and( + torch.logical_not(mask[:, None]).repeat(1, gray.shape[2] - bit - 1), + torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]), + ) + gray[:, dim, bit + 1 :] = torch.logical_xor( + gray[:, dim, bit + 1 :], to_flip + ) + gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip) + + # Now flatten out. + gray = gray.swapaxes(1, 2).reshape((-1, num_bits * num_dims)) + + # Convert Gray back to binary. + hh_bin = gray2binary(gray) + + # Pad back out to 64 bits. + extra_dims = 64 - num_bits * num_dims + padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), "constant", 0) + + # Convert binary values into uint8s. + hh_uint8 = ( + (padded.flip(-1).reshape((-1, 8, 8)) * bitpack_mask) + .sum(2) + .squeeze() + .type(torch.uint8) + ) + + # Convert uint8s into uint64s. 
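+    # Each group of eight uint8 values assembled above is reinterpreted as a
+    # single int64 by `.view(torch.int64)`; `decode` below starts by undoing
+    # exactly this packing.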
+ hh_uint64 = hh_uint8.view(torch.int64).squeeze() + + return hh_uint64 + + +def decode(hilberts, num_dims, num_bits): + """Decode an array of Hilbert integers into locations in a hypercube. + + This is a vectorized-ish version of the Hilbert curve implementation by John + Skilling as described in: + + Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference + Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics. + + Params: + ------- + hilberts - An ndarray of Hilbert integers. Must be an integer dtype and + cannot have fewer bits than num_dims * num_bits. + + num_dims - The dimensionality of the hypercube. Integer. + + num_bits - The number of bits for each dimension. Integer. + + Returns: + -------- + The output is an ndarray of unsigned integers with the same shape as hilberts + but with an additional dimension of size num_dims. + """ + + if num_dims * num_bits > 64: + raise ValueError( + """ + num_dims=%d and num_bits=%d for %d bits total, which can't be encoded + into a uint64. Are you sure you need that many points on your Hilbert + curve? + """ + % (num_dims, num_bits) + ) + + # Handle the case where we got handed a naked integer. + hilberts = torch.atleast_1d(hilberts) + + # Keep around the shape for later. + orig_shape = hilberts.shape + bitpack_mask = 2 ** torch.arange(0, 8).to(hilberts.device) + bitpack_mask_rev = bitpack_mask.flip(-1) + + # Treat each of the hilberts as a s equence of eight uint8. + # This treats all of the inputs as uint64 and makes things uniform. + hh_uint8 = ( + hilberts.ravel().type(torch.int64).view(torch.uint8).reshape((-1, 8)).flip(-1) + ) + + # Turn these lists of uints into lists of bits and then truncate to the size + # we actually need for using Skilling's procedure. + hh_bits = ( + hh_uint8.unsqueeze(-1) + .bitwise_and(bitpack_mask_rev) + .ne(0) + .byte() + .flatten(-2, -1)[:, -num_dims * num_bits :] + ) + + # Take the sequence of bits and Gray-code it. + gray = binary2gray(hh_bits) + + # There has got to be a better way to do this. + # I could index them differently, but the eventual packbits likes it this way. + gray = gray.reshape((-1, num_bits, num_dims)).swapaxes(1, 2) + + # Iterate backwards through the bits. + for bit in range(num_bits - 1, -1, -1): + # Iterate backwards through the dimensions. + for dim in range(num_dims - 1, -1, -1): + # Identify which ones have this bit active. + mask = gray[:, dim, bit] + + # Where this bit is on, invert the 0 dimension for lower bits. + gray[:, 0, bit + 1 :] = torch.logical_xor( + gray[:, 0, bit + 1 :], mask[:, None] + ) + + # Where the bit is off, exchange the lower bits with the 0 dimension. + to_flip = torch.logical_and( + torch.logical_not(mask[:, None]), + torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]), + ) + gray[:, dim, bit + 1 :] = torch.logical_xor( + gray[:, dim, bit + 1 :], to_flip + ) + gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip) + + # Pad back out to 64 bits. + extra_dims = 64 - num_bits + padded = torch.nn.functional.pad(gray, (extra_dims, 0), "constant", 0) + + # Now chop these up into blocks of 8. + locs_chopped = padded.flip(-1).reshape((-1, num_dims, 8, 8)) + + # Take those blocks and turn them unto uint8s. + # from IPython import embed; embed() + locs_uint8 = (locs_chopped * bitpack_mask).sum(3).squeeze().type(torch.uint8) + + # Finally, treat these as uint64s. + flat_locs = locs_uint8.view(torch.int64) + + # Return them in the expected shape. 
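+    # Illustrative round trip (a minimal usage sketch):
+    #   locs = torch.randint(0, 2 ** 10, (5, 3))
+    #   codes = encode(locs, num_dims=3, num_bits=10)
+    #   assert (decode(codes, num_dims=3, num_bits=10) == locs).all()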
+ return flat_locs.reshape((*orig_shape, num_dims)) diff --git a/UniRig/src/model/pointcept/models/utils/serialization/z_order.py b/UniRig/src/model/pointcept/models/utils/serialization/z_order.py new file mode 100644 index 0000000000000000000000000000000000000000..b0237c87c48c32d07fe96d9ecc72d91e1c464099 --- /dev/null +++ b/UniRig/src/model/pointcept/models/utils/serialization/z_order.py @@ -0,0 +1,119 @@ +import torch +from typing import Optional, Union + + +class KeyLUT: + def __init__(self): + r256 = torch.arange(256, dtype=torch.int64) + r512 = torch.arange(512, dtype=torch.int64) + zero = torch.zeros(256, dtype=torch.int64) + device = torch.device("cpu") + + self._encode = { + device: ( + self.xyz2key(r256, zero, zero, 8), + self.xyz2key(zero, r256, zero, 8), + self.xyz2key(zero, zero, r256, 8), + ) + } + self._decode = {device: self.key2xyz(r512, 9)} + + def encode_lut(self, device=torch.device("cpu")): + if device not in self._encode: + cpu = torch.device("cpu") + self._encode[device] = tuple(e.to(device) for e in self._encode[cpu]) + return self._encode[device] + + def decode_lut(self, device=torch.device("cpu")): + if device not in self._decode: + cpu = torch.device("cpu") + self._decode[device] = tuple(e.to(device) for e in self._decode[cpu]) + return self._decode[device] + + def xyz2key(self, x, y, z, depth): + key = torch.zeros_like(x) + for i in range(depth): + mask = 1 << i + key = ( + key + | ((x & mask) << (2 * i + 2)) + | ((y & mask) << (2 * i + 1)) + | ((z & mask) << (2 * i + 0)) + ) + return key + + def key2xyz(self, key, depth): + x = torch.zeros_like(key) + y = torch.zeros_like(key) + z = torch.zeros_like(key) + for i in range(depth): + x = x | ((key & (1 << (3 * i + 2))) >> (2 * i + 2)) + y = y | ((key & (1 << (3 * i + 1))) >> (2 * i + 1)) + z = z | ((key & (1 << (3 * i + 0))) >> (2 * i + 0)) + return x, y, z + + +_key_lut = KeyLUT() + + +def xyz2key( + x: torch.Tensor, + y: torch.Tensor, + z: torch.Tensor, + b: Optional[Union[torch.Tensor, int]] = None, + depth: int = 16, +): + r"""Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys + based on pre-computed look up tables. The speed of this function is much + faster than the method based on for-loop. + + Args: + x (torch.Tensor): The x coordinate. + y (torch.Tensor): The y coordinate. + z (torch.Tensor): The z coordinate. + b (torch.Tensor or int): The batch index of the coordinates, and should be + smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of + :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`. + depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17). + """ + + EX, EY, EZ = _key_lut.encode_lut(x.device) + x, y, z = x.long(), y.long(), z.long() + + mask = 255 if depth > 8 else (1 << depth) - 1 + key = EX[x & mask] | EY[y & mask] | EZ[z & mask] + if depth > 8: + mask = (1 << (depth - 8)) - 1 + key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask] + key = key16 << 24 | key + + if b is not None: + b = b.long() + key = b << 48 | key + + return key + + +def key2xyz(key: torch.Tensor, depth: int = 16): + r"""Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates + and the batch index based on pre-computed look up tables. + + Args: + key (torch.Tensor): The shuffled key. + depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17). 
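+
+        Example (illustrative round trip):
+            >>> x, y, z = torch.tensor([3]), torch.tensor([5]), torch.tensor([6])
+            >>> key = xyz2key(x, y, z, depth=4)
+            >>> key2xyz(key, depth=4)[:3]  # recovers (x, y, z); the 4th output is the batch index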
+ """ + + DX, DY, DZ = _key_lut.decode_lut(key.device) + x, y, z = torch.zeros_like(key), torch.zeros_like(key), torch.zeros_like(key) + + b = key >> 48 + key = key & ((1 << 48) - 1) + + n = (depth + 2) // 3 + for i in range(n): + k = key >> (i * 9) & 511 + x = x | (DX[k] << (i * 3)) + y = y | (DY[k] << (i * 3)) + z = z | (DZ[k] << (i * 3)) + + return x, y, z, b diff --git a/UniRig/src/model/pointcept/models/utils/structure.py b/UniRig/src/model/pointcept/models/utils/structure.py new file mode 100644 index 0000000000000000000000000000000000000000..afef3d7dbe2c3118e67d8eaa4a7e2e5198cb6a5e --- /dev/null +++ b/UniRig/src/model/pointcept/models/utils/structure.py @@ -0,0 +1,210 @@ +import torch +import spconv.pytorch as spconv + +try: + import ocnn +except ImportError: + ocnn = None +from addict import Dict + +from .serialization import encode, decode +from ..utils import offset2batch, batch2offset +import torch_scatter + +class Point(Dict): + """ + Point Structure of Pointcept + + A Point (point cloud) in Pointcept is a dictionary that contains various properties of + a batched point cloud. The property with the following names have a specific definition + as follows: + + - "coord": original coordinate of point cloud; + - "grid_coord": grid coordinate for specific grid size (related to GridSampling); + Point also support the following optional attributes: + - "offset": if not exist, initialized as batch size is 1; + - "batch": if not exist, initialized as batch size is 1; + - "feat": feature of point cloud, default input of model; + - "grid_size": Grid size of point cloud (related to GridSampling); + (related to Serialization) + - "serialized_depth": depth of serialization, 2 ** depth * grid_size describe the maximum of point cloud range; + - "serialized_code": a list of serialization codes; + - "serialized_order": a list of serialization order determined by code; + - "serialized_inverse": a list of inverse mapping determined by code; + (related to Sparsify: SpConv) + - "sparse_shape": Sparse shape for Sparse Conv Tensor; + - "sparse_conv_feat": SparseConvTensor init with information provide by Point; + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # If one of "offset" or "batch" do not exist, generate by the existing one + if "batch" not in self.keys() and "offset" in self.keys(): + self["batch"] = offset2batch(self.offset) + elif "offset" not in self.keys() and "batch" in self.keys(): + self["offset"] = batch2offset(self.batch) + + def serialization(self, order="z", depth=None, shuffle_orders=False, min_coord=None): + """ + Point Cloud Serialization + + relay on ["grid_coord" or "coord" + "grid_size", "batch", "feat"] + """ + assert "batch" in self.keys() + # if "grid_coord" not in self.keys(): + # # if you don't want to operate GridSampling in data augmentation, + # # please add the following augmentation into your pipline: + # # dict(type="Copy", keys_dict={"grid_size": 0.01}), + # # (adjust `grid_size` to what your want) + # assert {"grid_size", "coord"}.issubset(self.keys()) + # self["grid_coord"] = torch.div( + # self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc" + # ).int() + if "grid_coord" not in self.keys(): + # if you don't want to operate GridSampling in data augmentation, + # please add the following augmentation into your pipline: + # dict(type="Copy", keys_dict={"grid_size": 0.01}), + # (adjust `grid_size` to what your want) + assert {"grid_size", "coord"}.issubset(self.keys()) + idx_ptr = 
torch.nn.functional.pad(self.offset, (1, 0), value=0)
+            if min_coord is None:
+                min_coord = torch_scatter.segment_csr(self.coord, idx_ptr, reduce="min")
+            self["grid_coord"] = torch.div(
+                self.coord - min_coord[self.batch],
+                self.grid_size,
+                rounding_mode="trunc",
+            ).int()
+
+        # print(self.grid_coord.max())
+        # print(int(self.grid_coord.max()).bit_length())
+
+        if depth is None:
+            # Adaptively measure the depth of the serialization cube (side length = 2 ** depth)
+            depth = int(self.grid_coord.max()).bit_length()
+        self["serialized_depth"] = depth
+        # Maximum bit length for a serialization code is 63 (int64)
+        assert depth * 3 + len(self.offset).bit_length() <= 63
+        # Here we follow OCNN and cap the depth for the point position at 16 (48 bits).
+        # Even with depth capped at 16, we can encode a 655.36^3 (2^16 * 0.01) meter^3
+        # cube at a grid size of 0.01 meter, which we consider sufficient for the current stage.
+        # The limit can be lifted by optimizing the z-order encoding function if necessary.
+        assert depth <= 16
+
+        # The serialization codes are arranged as the following structure:
+        # [Order1 ([n]),
+        #  Order2 ([n]),
+        #  ...
+        #  OrderN ([n])] (k, n)
+        code = [
+            encode(self.grid_coord, self.batch, depth, order=order_) for order_ in order
+        ]
+        code = torch.stack(code)
+        order = torch.argsort(code)
+        inverse = torch.zeros_like(order).scatter_(
+            dim=1,
+            index=order,
+            src=torch.arange(0, code.shape[1], device=order.device).repeat(
+                code.shape[0], 1
+            ),
+        )
+
+        if shuffle_orders:
+            perm = torch.randperm(code.shape[0])
+            code = code[perm]
+            order = order[perm]
+            inverse = inverse[perm]
+
+        self["serialized_code"] = code
+        self["serialized_order"] = order
+        self["serialized_inverse"] = inverse
+
+    def sparsify(self, pad=96):
+        """
+        Point Cloud Sparsification
+
+        Point clouds are sparse; here "sparsify" specifically refers to preparing
+        an "spconv.SparseConvTensor" for SpConv.
+
+        Relies on ["grid_coord" or "coord" + "grid_size", "batch", "feat"].
+
+        pad: padding added to the sparse spatial shape.
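+
+        Illustrative usage (a minimal sketch; xyz / feat / n are placeholders):
+            point = Point(
+                coord=xyz,                   # (n, 3) float coordinates
+                feat=feat,                   # (n, c) per-point features
+                offset=torch.tensor([n]),    # single batch of n points
+                grid_size=0.01,
+            )
+            point.serialization(order=["z", "hilbert"])
+            point.sparsify()
+            out = point.sparse_conv_feat     # spconv.SparseConvTensor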
+ """ + assert {"feat", "batch"}.issubset(self.keys()) + # if "grid_coord" not in self.keys(): + # # if you don't want to operate GridSampling in data augmentation, + # # please add the following augmentation into your pipline: + # # dict(type="Copy", keys_dict={"grid_size": 0.01}), + # # (adjust `grid_size` to what your want) + # assert {"grid_size", "coord"}.issubset(self.keys()) + # self["grid_coord"] = torch.div( + # self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc" + # ).int() + if "grid_coord" not in self.keys(): + # if you don't want to operate GridSampling in data augmentation, + # please add the following augmentation into your pipline: + # dict(type="Copy", keys_dict={"grid_size": 0.01}), + # (adjust `grid_size` to what your want) + assert {"grid_size", "coord"}.issubset(self.keys()) + idx_ptr = torch.nn.functional.pad(self.offset, (1, 0), value=0) + min_coord = torch_scatter.segment_csr(self.coord, idx_ptr, reduce="min") + self["grid_coord"] = torch.div( + self.coord - min_coord[self.batch], + self.grid_size, + rounding_mode="trunc", + ).int() + if "sparse_shape" in self.keys(): + sparse_shape = self.sparse_shape + else: + sparse_shape = torch.add( + torch.max(self.grid_coord, dim=0).values, pad + ).tolist() + sparse_conv_feat = spconv.SparseConvTensor( + features=self.feat, + indices=torch.cat( + [self.batch.unsqueeze(-1).int(), self.grid_coord.int()], dim=1 + ).contiguous(), + spatial_shape=sparse_shape, + batch_size=self.batch[-1].tolist() + 1, + ) + self["sparse_shape"] = sparse_shape + self["sparse_conv_feat"] = sparse_conv_feat + + def octreetization(self, depth=None, full_depth=None): + """ + Point Cloud Octreelization + + Generate octree with OCNN + relay on ["grid_coord", "batch", "feat"] + """ + assert ( + ocnn is not None + ), "Please follow https://github.com/octree-nn/ocnn-pytorch install ocnn." + assert {"grid_coord", "feat", "batch"}.issubset(self.keys()) + # add 1 to make grid space support shift order + if depth is None: + if "depth" in self.keys(): + depth = self.depth + else: + depth = int(self.grid_coord.max() + 1).bit_length() + if full_depth is None: + full_depth = 2 + self["depth"] = depth + assert depth <= 16 # maximum in ocnn + + # [0, 2**depth] -> [0, 2] -> [-1, 1] + coord = self.grid_coord / 2 ** (self.depth - 1) - 1.0 + point = ocnn.octree.Points( + points=coord, + features=self.feat, + batch_id=self.batch.unsqueeze(-1), + batch_size=self.batch[-1] + 1, + ) + octree = ocnn.octree.Octree( + depth=depth, + full_depth=full_depth, + batch_size=self.batch[-1] + 1, + device=coord.device, + ) + octree.build_octree(point) + octree.construct_all_neigh() + self["octree"] = octree diff --git a/UniRig/src/model/pointcept/utils/__init__.py b/UniRig/src/model/pointcept/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/UniRig/src/model/pointcept/utils/cache.py b/UniRig/src/model/pointcept/utils/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..a32f06c644b835925d165bca7060f8b185cf9c54 --- /dev/null +++ b/UniRig/src/model/pointcept/utils/cache.py @@ -0,0 +1,56 @@ +""" +Data Cache Utils + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. 
+""" + +import os +# import SharedArray + +try: + from multiprocessing.shared_memory import ShareableList +except ImportError: + import warnings + + warnings.warn("Please update python version >= 3.8 to enable shared_memory") +import numpy as np + + +def shared_array(name, var=None): + if var is not None: + # check exist + if os.path.exists(f"/dev/shm/{name}"): + return SharedArray.attach(f"shm://{name}") + # create shared_array + data = SharedArray.create(f"shm://{name}", var.shape, dtype=var.dtype) + data[...] = var[...] + data.flags.writeable = False + else: + data = SharedArray.attach(f"shm://{name}").copy() + return data + + +def shared_dict(name, var=None): + name = str(name) + assert "." not in name # '.' is used as sep flag + data = {} + if var is not None: + assert isinstance(var, dict) + keys = var.keys() + # current version only cache np.array + keys_valid = [] + for key in keys: + if isinstance(var[key], np.ndarray): + keys_valid.append(key) + keys = keys_valid + + ShareableList(sequence=keys, name=name + ".keys") + for key in keys: + if isinstance(var[key], np.ndarray): + data[key] = shared_array(name=f"{name}.{key}", var=var[key]) + else: + keys = list(ShareableList(name=name + ".keys")) + for key in keys: + data[key] = shared_array(name=f"{name}.{key}") + return data diff --git a/UniRig/src/model/pointcept/utils/comm.py b/UniRig/src/model/pointcept/utils/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..69e29e7c690fe0500d3d9a84b6a8749e2f4f4655 --- /dev/null +++ b/UniRig/src/model/pointcept/utils/comm.py @@ -0,0 +1,198 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +""" +This file contains primitives for multi-gpu communication. +This is useful when doing distributed training. +Modified from detectron2(https://github.com/facebookresearch/detectron2) + +Copyright (c) Xiaoyang Wu (xiaoyang.wu@connect.hku.hk). All Rights Reserved. +Please cite our work if you use any part of the code. +""" + +import functools +import numpy as np +import torch +import torch.distributed as dist + +_LOCAL_PROCESS_GROUP = None +""" +A torch process group which only includes processes that on the same machine as the current process. +This variable is set when processes are spawned by `launch()` in "engine/launch.py". +""" + + +def get_world_size() -> int: + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank() -> int: + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + + +def get_local_rank() -> int: + """ + Returns: + The rank of the current process within the local (per-machine) process group. + """ + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + assert ( + _LOCAL_PROCESS_GROUP is not None + ), "Local process group is not created! Please use launch() to spawn processes!" + return dist.get_rank(group=_LOCAL_PROCESS_GROUP) + + +def get_local_size() -> int: + """ + Returns: + The size of the per-machine process group, + i.e. the number of processes per machine. 
+ """ + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) + + +def is_main_process() -> bool: + return get_rank() == 0 + + +def synchronize(): + """ + Helper function to synchronize (barrier) among all processes when + using distributed training + """ + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + if world_size == 1: + return + if dist.get_backend() == dist.Backend.NCCL: + # This argument is needed to avoid warnings. + # It's valid only for NCCL backend. + dist.barrier(device_ids=[torch.cuda.current_device()]) + else: + dist.barrier() + + +@functools.lru_cache() +def _get_global_gloo_group(): + """ + Return a process group based on gloo backend, containing all the ranks + The result is cached. + """ + if dist.get_backend() == "nccl": + return dist.new_group(backend="gloo") + else: + return dist.group.WORLD + + +def all_gather(data, group=None): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors). + Args: + data: any picklable object + group: a torch process group. By default, will use a group which + contains all ranks on gloo backend. + Returns: + list[data]: list of data gathered from each rank + """ + if get_world_size() == 1: + return [data] + if group is None: + group = ( + _get_global_gloo_group() + ) # use CPU group by default, to reduce GPU RAM usage. + world_size = dist.get_world_size(group) + if world_size == 1: + return [data] + + output = [None for _ in range(world_size)] + dist.all_gather_object(output, data, group=group) + return output + + +def gather(data, dst=0, group=None): + """ + Run gather on arbitrary picklable data (not necessarily tensors). + Args: + data: any picklable object + dst (int): destination rank + group: a torch process group. By default, will use a group which + contains all ranks on gloo backend. + Returns: + list[data]: on dst, a list of data gathered from each rank. Otherwise, + an empty list. + """ + if get_world_size() == 1: + return [data] + if group is None: + group = _get_global_gloo_group() + world_size = dist.get_world_size(group=group) + if world_size == 1: + return [data] + rank = dist.get_rank(group=group) + + if rank == dst: + output = [None for _ in range(world_size)] + dist.gather_object(data, output, dst=dst, group=group) + return output + else: + dist.gather_object(data, None, dst=dst, group=group) + return [] + + +def shared_random_seed(): + """ + Returns: + int: a random number that is the same across all workers. + If workers need a shared RNG, they can use this shared seed to + create one. + All workers must call this function, otherwise it will deadlock. + """ + ints = np.random.randint(2**31) + all_ints = all_gather(ints) + return all_ints[0] + + +def reduce_dict(input_dict, average=True): + """ + Reduce the values in the dictionary from all processes so that process with rank + 0 has the reduced results. + Args: + input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. + average (bool): whether to do average or sum + Returns: + a dict with the same keys as input_dict, after reduction. 
+ """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.reduce(values, dst=0) + if dist.get_rank() == 0 and average: + # only main process gets accumulated, so only divide by + # world_size in this case + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict diff --git a/UniRig/src/model/pointcept/utils/config.py b/UniRig/src/model/pointcept/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..0388caf896a88189ff8497cca5ac71b1bd37034f --- /dev/null +++ b/UniRig/src/model/pointcept/utils/config.py @@ -0,0 +1,695 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import ast +import copy +import os +import os.path as osp +import platform +import shutil +import sys +import tempfile +import uuid +import warnings +from argparse import Action, ArgumentParser +from collections import abc +from importlib import import_module + +from addict import Dict +from yapf.yapflib.yapf_api import FormatCode + +from .misc import import_modules_from_strings +from .path import check_file_exist + +if platform.system() == "Windows": + import regex as re +else: + import re + +BASE_KEY = "_base_" +DELETE_KEY = "_delete_" +DEPRECATION_KEY = "_deprecation_" +RESERVED_KEYS = ["filename", "text", "pretty_text"] + + +class ConfigDict(Dict): + def __missing__(self, name): + raise KeyError(name) + + def __getattr__(self, name): + try: + value = super(ConfigDict, self).__getattr__(name) + except KeyError: + ex = AttributeError( + f"'{self.__class__.__name__}' object has no " f"attribute '{name}'" + ) + except Exception as e: + ex = e + else: + return value + raise ex + + +def add_args(parser, cfg, prefix=""): + for k, v in cfg.items(): + if isinstance(v, str): + parser.add_argument("--" + prefix + k) + elif isinstance(v, int): + parser.add_argument("--" + prefix + k, type=int) + elif isinstance(v, float): + parser.add_argument("--" + prefix + k, type=float) + elif isinstance(v, bool): + parser.add_argument("--" + prefix + k, action="store_true") + elif isinstance(v, dict): + add_args(parser, v, prefix + k + ".") + elif isinstance(v, abc.Iterable): + parser.add_argument("--" + prefix + k, type=type(v[0]), nargs="+") + else: + print(f"cannot parse key {prefix + k} of type {type(v)}") + return parser + + +class Config: + """A facility for config and config files. + + It supports common file formats as configs: python/json/yaml. The interface + is the same as a dict object and also allows access config values as + attributes. 
+ + Example: + >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) + >>> cfg.a + 1 + >>> cfg.b + {'b1': [0, 1]} + >>> cfg.b.b1 + [0, 1] + >>> cfg = Config.fromfile('tests/data/config/a.py') + >>> cfg.filename + "/home/kchen/projects/mmcv/tests/data/config/a.py" + >>> cfg.item4 + 'test' + >>> cfg + "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: " + "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}" + """ + + @staticmethod + def _validate_py_syntax(filename): + with open(filename, "r", encoding="utf-8") as f: + # Setting encoding explicitly to resolve coding issue on windows + content = f.read() + try: + ast.parse(content) + except SyntaxError as e: + raise SyntaxError( + "There are syntax errors in config " f"file {filename}: {e}" + ) + + @staticmethod + def _substitute_predefined_vars(filename, temp_config_name): + file_dirname = osp.dirname(filename) + file_basename = osp.basename(filename) + file_basename_no_extension = osp.splitext(file_basename)[0] + file_extname = osp.splitext(filename)[1] + support_templates = dict( + fileDirname=file_dirname, + fileBasename=file_basename, + fileBasenameNoExtension=file_basename_no_extension, + fileExtname=file_extname, + ) + with open(filename, "r", encoding="utf-8") as f: + # Setting encoding explicitly to resolve coding issue on windows + config_file = f.read() + for key, value in support_templates.items(): + regexp = r"\{\{\s*" + str(key) + r"\s*\}\}" + value = value.replace("\\", "/") + config_file = re.sub(regexp, value, config_file) + with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file: + tmp_config_file.write(config_file) + + @staticmethod + def _pre_substitute_base_vars(filename, temp_config_name): + """Substitute base variable placehoders to string, so that parsing + would work.""" + with open(filename, "r", encoding="utf-8") as f: + # Setting encoding explicitly to resolve coding issue on windows + config_file = f.read() + base_var_dict = {} + regexp = r"\{\{\s*" + BASE_KEY + r"\.([\w\.]+)\s*\}\}" + base_vars = set(re.findall(regexp, config_file)) + for base_var in base_vars: + randstr = f"_{base_var}_{uuid.uuid4().hex.lower()[:6]}" + base_var_dict[randstr] = base_var + regexp = r"\{\{\s*" + BASE_KEY + r"\." 
+ base_var + r"\s*\}\}" + config_file = re.sub(regexp, f'"{randstr}"', config_file) + with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file: + tmp_config_file.write(config_file) + return base_var_dict + + @staticmethod + def _substitute_base_vars(cfg, base_var_dict, base_cfg): + """Substitute variable strings to their actual values.""" + cfg = copy.deepcopy(cfg) + + if isinstance(cfg, dict): + for k, v in cfg.items(): + if isinstance(v, str) and v in base_var_dict: + new_v = base_cfg + for new_k in base_var_dict[v].split("."): + new_v = new_v[new_k] + cfg[k] = new_v + elif isinstance(v, (list, tuple, dict)): + cfg[k] = Config._substitute_base_vars(v, base_var_dict, base_cfg) + elif isinstance(cfg, tuple): + cfg = tuple( + Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg + ) + elif isinstance(cfg, list): + cfg = [ + Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg + ] + elif isinstance(cfg, str) and cfg in base_var_dict: + new_v = base_cfg + for new_k in base_var_dict[cfg].split("."): + new_v = new_v[new_k] + cfg = new_v + + return cfg + + @staticmethod + def _file2dict(filename, use_predefined_variables=True): + filename = osp.abspath(osp.expanduser(filename)) + check_file_exist(filename) + fileExtname = osp.splitext(filename)[1] + if fileExtname not in [".py", ".json", ".yaml", ".yml"]: + raise IOError("Only py/yml/yaml/json type are supported now!") + + with tempfile.TemporaryDirectory() as temp_config_dir: + temp_config_file = tempfile.NamedTemporaryFile( + dir=temp_config_dir, suffix=fileExtname + ) + if platform.system() == "Windows": + temp_config_file.close() + temp_config_name = osp.basename(temp_config_file.name) + # Substitute predefined variables + if use_predefined_variables: + Config._substitute_predefined_vars(filename, temp_config_file.name) + else: + shutil.copyfile(filename, temp_config_file.name) + # Substitute base variables from placeholders to strings + base_var_dict = Config._pre_substitute_base_vars( + temp_config_file.name, temp_config_file.name + ) + + if filename.endswith(".py"): + temp_module_name = osp.splitext(temp_config_name)[0] + sys.path.insert(0, temp_config_dir) + Config._validate_py_syntax(filename) + mod = import_module(temp_module_name) + sys.path.pop(0) + cfg_dict = { + name: value + for name, value in mod.__dict__.items() + if not name.startswith("__") + } + # delete imported module + del sys.modules[temp_module_name] + elif filename.endswith((".yml", ".yaml", ".json")): + raise NotImplementedError + # close temp file + temp_config_file.close() + + # check deprecation information + if DEPRECATION_KEY in cfg_dict: + deprecation_info = cfg_dict.pop(DEPRECATION_KEY) + warning_msg = ( + f"The config file {filename} will be deprecated " "in the future." + ) + if "expected" in deprecation_info: + warning_msg += f' Please use {deprecation_info["expected"]} ' "instead." 
+ if "reference" in deprecation_info: + warning_msg += ( + " More information can be found at " + f'{deprecation_info["reference"]}' + ) + warnings.warn(warning_msg) + + cfg_text = filename + "\n" + with open(filename, "r", encoding="utf-8") as f: + # Setting encoding explicitly to resolve coding issue on windows + cfg_text += f.read() + + if BASE_KEY in cfg_dict: + cfg_dir = osp.dirname(filename) + base_filename = cfg_dict.pop(BASE_KEY) + base_filename = ( + base_filename if isinstance(base_filename, list) else [base_filename] + ) + + cfg_dict_list = list() + cfg_text_list = list() + for f in base_filename: + _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f)) + cfg_dict_list.append(_cfg_dict) + cfg_text_list.append(_cfg_text) + + base_cfg_dict = dict() + for c in cfg_dict_list: + duplicate_keys = base_cfg_dict.keys() & c.keys() + if len(duplicate_keys) > 0: + raise KeyError( + "Duplicate key is not allowed among bases. " + f"Duplicate keys: {duplicate_keys}" + ) + base_cfg_dict.update(c) + + # Substitute base variables from strings to their actual values + cfg_dict = Config._substitute_base_vars( + cfg_dict, base_var_dict, base_cfg_dict + ) + + base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict) + cfg_dict = base_cfg_dict + + # merge cfg_text + cfg_text_list.append(cfg_text) + cfg_text = "\n".join(cfg_text_list) + + return cfg_dict, cfg_text + + @staticmethod + def _merge_a_into_b(a, b, allow_list_keys=False): + """merge dict ``a`` into dict ``b`` (non-inplace). + + Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid + in-place modifications. + + Args: + a (dict): The source dict to be merged into ``b``. + b (dict): The origin dict to be fetch keys from ``a``. + allow_list_keys (bool): If True, int string keys (e.g. '0', '1') + are allowed in source ``a`` and will replace the element of the + corresponding index in b if b is a list. Default: False. + + Returns: + dict: The modified dict of ``b`` using ``a``. + + Examples: + # Normally merge a into b. + >>> Config._merge_a_into_b( + ... dict(obj=dict(a=2)), dict(obj=dict(a=1))) + {'obj': {'a': 2}} + + # Delete b first and merge a into b. + >>> Config._merge_a_into_b( + ... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1))) + {'obj': {'a': 2}} + + # b is a list + >>> Config._merge_a_into_b( + ... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True) + [{'a': 2}, {'b': 2}] + """ + b = b.copy() + for k, v in a.items(): + if allow_list_keys and k.isdigit() and isinstance(b, list): + k = int(k) + if len(b) <= k: + raise KeyError(f"Index {k} exceeds the length of list {b}") + b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) + elif isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False): + allowed_types = (dict, list) if allow_list_keys else dict + if not isinstance(b[k], allowed_types): + raise TypeError( + f"{k}={v} in child config cannot inherit from base " + f"because {k} is a dict in the child config but is of " + f"type {type(b[k])} in base config. 
You may set " + f"`{DELETE_KEY}=True` to ignore the base config" + ) + b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) + else: + b[k] = v + return b + + @staticmethod + def fromfile(filename, use_predefined_variables=True, import_custom_modules=True): + cfg_dict, cfg_text = Config._file2dict(filename, use_predefined_variables) + if import_custom_modules and cfg_dict.get("custom_imports", None): + import_modules_from_strings(**cfg_dict["custom_imports"]) + return Config(cfg_dict, cfg_text=cfg_text, filename=filename) + + @staticmethod + def fromstring(cfg_str, file_format): + """Generate config from config str. + + Args: + cfg_str (str): Config str. + file_format (str): Config file format corresponding to the + config str. Only py/yml/yaml/json type are supported now! + + Returns: + obj:`Config`: Config obj. + """ + if file_format not in [".py", ".json", ".yaml", ".yml"]: + raise IOError("Only py/yml/yaml/json type are supported now!") + if file_format != ".py" and "dict(" in cfg_str: + # check if users specify a wrong suffix for python + warnings.warn('Please check "file_format", the file format may be .py') + with tempfile.NamedTemporaryFile( + "w", encoding="utf-8", suffix=file_format, delete=False + ) as temp_file: + temp_file.write(cfg_str) + # on windows, previous implementation cause error + # see PR 1077 for details + cfg = Config.fromfile(temp_file.name) + os.remove(temp_file.name) + return cfg + + @staticmethod + def auto_argparser(description=None): + """Generate argparser from config file automatically (experimental)""" + partial_parser = ArgumentParser(description=description) + partial_parser.add_argument("config", help="config file path") + cfg_file = partial_parser.parse_known_args()[0].config + cfg = Config.fromfile(cfg_file) + parser = ArgumentParser(description=description) + parser.add_argument("config", help="config file path") + add_args(parser, cfg) + return parser, cfg + + def __init__(self, cfg_dict=None, cfg_text=None, filename=None): + if cfg_dict is None: + cfg_dict = dict() + elif not isinstance(cfg_dict, dict): + raise TypeError("cfg_dict must be a dict, but " f"got {type(cfg_dict)}") + for key in cfg_dict: + if key in RESERVED_KEYS: + raise KeyError(f"{key} is reserved for config file") + + super(Config, self).__setattr__("_cfg_dict", ConfigDict(cfg_dict)) + super(Config, self).__setattr__("_filename", filename) + if cfg_text: + text = cfg_text + elif filename: + with open(filename, "r") as f: + text = f.read() + else: + text = "" + super(Config, self).__setattr__("_text", text) + + @property + def filename(self): + return self._filename + + @property + def text(self): + return self._text + + @property + def pretty_text(self): + indent = 4 + + def _indent(s_, num_spaces): + s = s_.split("\n") + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(num_spaces * " ") + line for line in s] + s = "\n".join(s) + s = first + "\n" + s + return s + + def _format_basic_types(k, v, use_mapping=False): + if isinstance(v, str): + v_str = f"'{v}'" + else: + v_str = str(v) + + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f"{k_str}: {v_str}" + else: + attr_str = f"{str(k)}={v_str}" + attr_str = _indent(attr_str, indent) + + return attr_str + + def _format_list(k, v, use_mapping=False): + # check if all items in the list are dict + if all(isinstance(_, dict) for _ in v): + v_str = "[\n" + v_str += "\n".join( + f"dict({_indent(_format_dict(v_), indent)})," for v_ in v + ).rstrip(",") + if use_mapping: + k_str = f"'{k}'" if 
isinstance(k, str) else str(k) + attr_str = f"{k_str}: {v_str}" + else: + attr_str = f"{str(k)}={v_str}" + attr_str = _indent(attr_str, indent) + "]" + else: + attr_str = _format_basic_types(k, v, use_mapping) + return attr_str + + def _contain_invalid_identifier(dict_str): + contain_invalid_identifier = False + for key_name in dict_str: + contain_invalid_identifier |= not str(key_name).isidentifier() + return contain_invalid_identifier + + def _format_dict(input_dict, outest_level=False): + r = "" + s = [] + + use_mapping = _contain_invalid_identifier(input_dict) + if use_mapping: + r += "{" + for idx, (k, v) in enumerate(input_dict.items()): + is_last = idx >= len(input_dict) - 1 + end = "" if outest_level or is_last else "," + if isinstance(v, dict): + v_str = "\n" + _format_dict(v) + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f"{k_str}: dict({v_str}" + else: + attr_str = f"{str(k)}=dict({v_str}" + attr_str = _indent(attr_str, indent) + ")" + end + elif isinstance(v, list): + attr_str = _format_list(k, v, use_mapping) + end + else: + attr_str = _format_basic_types(k, v, use_mapping) + end + + s.append(attr_str) + r += "\n".join(s) + if use_mapping: + r += "}" + return r + + cfg_dict = self._cfg_dict.to_dict() + text = _format_dict(cfg_dict, outest_level=True) + # copied from setup.cfg + yapf_style = dict( + based_on_style="pep8", + blank_line_before_nested_class_or_def=True, + split_before_expression_after_opening_paren=True, + ) + # text, _ = FormatCode(text, style_config=yapf_style, verify=True) + text, _ = FormatCode(text, style_config=yapf_style) + + return text + + def __repr__(self): + return f"Config (path: {self.filename}): {self._cfg_dict.__repr__()}" + + def __len__(self): + return len(self._cfg_dict) + + def __getattr__(self, name): + return getattr(self._cfg_dict, name) + + def __getitem__(self, name): + return self._cfg_dict.__getitem__(name) + + def __setattr__(self, name, value): + if isinstance(value, dict): + value = ConfigDict(value) + self._cfg_dict.__setattr__(name, value) + + def __setitem__(self, name, value): + if isinstance(value, dict): + value = ConfigDict(value) + self._cfg_dict.__setitem__(name, value) + + def __iter__(self): + return iter(self._cfg_dict) + + def __getstate__(self): + return (self._cfg_dict, self._filename, self._text) + + def __setstate__(self, state): + _cfg_dict, _filename, _text = state + super(Config, self).__setattr__("_cfg_dict", _cfg_dict) + super(Config, self).__setattr__("_filename", _filename) + super(Config, self).__setattr__("_text", _text) + + def dump(self, file=None): + cfg_dict = super(Config, self).__getattribute__("_cfg_dict").to_dict() + if self.filename.endswith(".py"): + if file is None: + return self.pretty_text + else: + with open(file, "w", encoding="utf-8") as f: + f.write(self.pretty_text) + else: + import mmcv + + if file is None: + file_format = self.filename.split(".")[-1] + return mmcv.dump(cfg_dict, file_format=file_format) + else: + mmcv.dump(cfg_dict, file) + + def merge_from_dict(self, options, allow_list_keys=True): + """Merge list into cfg_dict. + + Merge the dict parsed by MultipleKVAction into this cfg. + + Examples: + >>> options = {'models.backbone.depth': 50, + ... 'models.backbone.with_cp':True} + >>> cfg = Config(dict(models=dict(backbone=dict(type='ResNet')))) + >>> cfg.merge_from_dict(options) + >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') + >>> assert cfg_dict == dict( + ... 
models=dict(backbone=dict(depth=50, with_cp=True))) + + # Merge list element + >>> cfg = Config(dict(pipeline=[ + ... dict(type='LoadImage'), dict(type='LoadAnnotations')])) + >>> options = dict(pipeline={'0': dict(type='SelfLoadImage')}) + >>> cfg.merge_from_dict(options, allow_list_keys=True) + >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') + >>> assert cfg_dict == dict(pipeline=[ + ... dict(type='SelfLoadImage'), dict(type='LoadAnnotations')]) + + Args: + options (dict): dict of configs to merge from. + allow_list_keys (bool): If True, int string keys (e.g. '0', '1') + are allowed in ``options`` and will replace the element of the + corresponding index in the config if the config is a list. + Default: True. + """ + option_cfg_dict = {} + for full_key, v in options.items(): + d = option_cfg_dict + key_list = full_key.split(".") + for subkey in key_list[:-1]: + d.setdefault(subkey, ConfigDict()) + d = d[subkey] + subkey = key_list[-1] + d[subkey] = v + + cfg_dict = super(Config, self).__getattribute__("_cfg_dict") + super(Config, self).__setattr__( + "_cfg_dict", + Config._merge_a_into_b( + option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys + ), + ) + + +class DictAction(Action): + """ + argparse action to split an argument into KEY=VALUE form + on the first = and append to a dictionary. List options can + be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit + brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build + list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]' + """ + + @staticmethod + def _parse_int_float_bool(val): + try: + return int(val) + except ValueError: + pass + try: + return float(val) + except ValueError: + pass + if val.lower() in ["true", "false"]: + return True if val.lower() == "true" else False + return val + + @staticmethod + def _parse_iterable(val): + """Parse iterable values in the string. + + All elements inside '()' or '[]' are treated as iterable values. + + Args: + val (str): Value string. + + Returns: + list | tuple: The expanded list or tuple from the string. + + Examples: + >>> DictAction._parse_iterable('1,2,3') + [1, 2, 3] + >>> DictAction._parse_iterable('[a, b, c]') + ['a', 'b', 'c'] + >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]') + [(1, 2, 3), ['a', 'b'], 'c'] + """ + + def find_next_comma(string): + """Find the position of next comma in the string. + + If no ',' is found in the string, return the string length. All + chars inside '()' and '[]' are treated as one element and thus ',' + inside these brackets are ignored. + """ + assert (string.count("(") == string.count(")")) and ( + string.count("[") == string.count("]") + ), f"Imbalanced brackets exist in {string}" + end = len(string) + for idx, char in enumerate(string): + pre = string[:idx] + # The string before this ',' is balanced + if ( + (char == ",") + and (pre.count("(") == pre.count(")")) + and (pre.count("[") == pre.count("]")) + ): + end = idx + break + return end + + # Strip ' and " characters and replace whitespace. 
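+        # For example, " [(1, 2), [a, b], c] " becomes "[(1,2),[a,b],c]" after the
+        # strip/replace below, before being split on top-level commas.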
+ val = val.strip("'\"").replace(" ", "") + is_tuple = False + if val.startswith("(") and val.endswith(")"): + is_tuple = True + val = val[1:-1] + elif val.startswith("[") and val.endswith("]"): + val = val[1:-1] + elif "," not in val: + # val is a single value + return DictAction._parse_int_float_bool(val) + + values = [] + while len(val) > 0: + comma_idx = find_next_comma(val) + element = DictAction._parse_iterable(val[:comma_idx]) + values.append(element) + val = val[comma_idx + 1 :] + if is_tuple: + values = tuple(values) + return values + + def __call__(self, parser, namespace, values, option_string=None): + options = {} + for kv in values: + key, val = kv.split("=", maxsplit=1) + options[key] = self._parse_iterable(val) + setattr(namespace, self.dest, options) diff --git a/UniRig/src/model/pointcept/utils/env.py b/UniRig/src/model/pointcept/utils/env.py new file mode 100644 index 0000000000000000000000000000000000000000..653f007dde5c4a7564e732da88dd47e7d37adf97 --- /dev/null +++ b/UniRig/src/model/pointcept/utils/env.py @@ -0,0 +1,36 @@ +""" +Environment Utils + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +import os +import random +import numpy as np +import torch +import torch.backends.cudnn as cudnn + +from datetime import datetime + + +def get_random_seed(): + seed = ( + os.getpid() + + int(datetime.now().strftime("%S%f")) + + int.from_bytes(os.urandom(2), "big") + ) + return seed + + +def set_seed(seed=None): + if seed is None: + seed = get_random_seed() + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + cudnn.benchmark = False + cudnn.deterministic = True + os.environ["PYTHONHASHSEED"] = str(seed) diff --git a/UniRig/src/model/pointcept/utils/events.py b/UniRig/src/model/pointcept/utils/events.py new file mode 100644 index 0000000000000000000000000000000000000000..d2db40895bce751e2d69482578d460f187cb8d0e --- /dev/null +++ b/UniRig/src/model/pointcept/utils/events.py @@ -0,0 +1,590 @@ +""" +Events Utils + +Modified from Detectron2 + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + + +import datetime +import json +import logging +import os +import time +import torch +import numpy as np + +from typing import List, Optional, Tuple +from collections import defaultdict +from contextlib import contextmanager + +__all__ = [ + "get_event_storage", + "JSONWriter", + "TensorboardXWriter", + "CommonMetricPrinter", + "EventStorage", +] + +_CURRENT_STORAGE_STACK = [] + + +def get_event_storage(): + """ + Returns: + The :class:`EventStorage` object that's currently being used. + Throws an error if no :class:`EventStorage` is currently enabled. + """ + assert len( + _CURRENT_STORAGE_STACK + ), "get_event_storage() has to be called inside a 'with EventStorage(...)' context!" + return _CURRENT_STORAGE_STACK[-1] + + +class EventWriter: + """ + Base class for writers that obtain events from :class:`EventStorage` and process them. + """ + + def write(self): + raise NotImplementedError + + def close(self): + pass + + +class JSONWriter(EventWriter): + """ + Write scalars to a json file. + It saves scalars as one json per line (instead of a big json) for easy parsing. 
+ Examples parsing such a json file: + :: + $ cat metrics.json | jq -s '.[0:2]' + [ + { + "data_time": 0.008433341979980469, + "iteration": 19, + "loss": 1.9228371381759644, + "loss_box_reg": 0.050025828182697296, + "loss_classifier": 0.5316952466964722, + "loss_mask": 0.7236229181289673, + "loss_rpn_box": 0.0856662318110466, + "loss_rpn_cls": 0.48198649287223816, + "lr": 0.007173333333333333, + "time": 0.25401854515075684 + }, + { + "data_time": 0.007216215133666992, + "iteration": 39, + "loss": 1.282649278640747, + "loss_box_reg": 0.06222952902317047, + "loss_classifier": 0.30682939291000366, + "loss_mask": 0.6970193982124329, + "loss_rpn_box": 0.038663312792778015, + "loss_rpn_cls": 0.1471673548221588, + "lr": 0.007706666666666667, + "time": 0.2490077018737793 + } + ] + $ cat metrics.json | jq '.loss_mask' + 0.7126231789588928 + 0.689423680305481 + 0.6776131987571716 + ... + """ + + def __init__(self, json_file, window_size=20): + """ + Args: + json_file (str): path to the json file. New data will be appended if the file exists. + window_size (int): the window size of median smoothing for the scalars whose + `smoothing_hint` are True. + """ + self._file_handle = open(json_file, "a") + self._window_size = window_size + self._last_write = -1 + + def write(self): + storage = get_event_storage() + to_save = defaultdict(dict) + + for k, (v, iter) in storage.latest_with_smoothing_hint( + self._window_size + ).items(): + # keep scalars that have not been written + if iter <= self._last_write: + continue + to_save[iter][k] = v + if len(to_save): + all_iters = sorted(to_save.keys()) + self._last_write = max(all_iters) + + for itr, scalars_per_iter in to_save.items(): + scalars_per_iter["iteration"] = itr + self._file_handle.write(json.dumps(scalars_per_iter, sort_keys=True) + "\n") + self._file_handle.flush() + try: + os.fsync(self._file_handle.fileno()) + except AttributeError: + pass + + def close(self): + self._file_handle.close() + + +class TensorboardXWriter(EventWriter): + """ + Write all scalars to a tensorboard file. + """ + + def __init__(self, log_dir: str, window_size: int = 20, **kwargs): + """ + Args: + log_dir (str): the directory to save the output events + window_size (int): the scalars will be median-smoothed by this window size + kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)` + """ + self._window_size = window_size + from torch.utils.tensorboard import SummaryWriter + + self._writer = SummaryWriter(log_dir, **kwargs) + self._last_write = -1 + + def write(self): + storage = get_event_storage() + new_last_write = self._last_write + for k, (v, iter) in storage.latest_with_smoothing_hint( + self._window_size + ).items(): + if iter > self._last_write: + self._writer.add_scalar(k, v, iter) + new_last_write = max(new_last_write, iter) + self._last_write = new_last_write + + # storage.put_{image,histogram} is only meant to be used by + # tensorboard writer. So we access its internal fields directly from here. + if len(storage._vis_data) >= 1: + for img_name, img, step_num in storage._vis_data: + self._writer.add_image(img_name, img, step_num) + # Storage stores all image data and rely on this writer to clear them. + # As a result it assumes only one writer will use its image data. + # An alternative design is to let storage store limited recent + # data (e.g. only the most recent image) that all writers can access. + # In that case a writer may not see all image data if its period is long. 
+ storage.clear_images() + + if len(storage._histograms) >= 1: + for params in storage._histograms: + self._writer.add_histogram_raw(**params) + storage.clear_histograms() + + def close(self): + if hasattr(self, "_writer"): # doesn't exist when the code fails at import + self._writer.close() + + +class CommonMetricPrinter(EventWriter): + """ + Print **common** metrics to the terminal, including + iteration time, ETA, memory, all losses, and the learning rate. + It also applies smoothing using a window of 20 elements. + It's meant to print common metrics in common ways. + To print something in more customized ways, please implement a similar printer by yourself. + """ + + def __init__(self, max_iter: Optional[int] = None, window_size: int = 20): + """ + Args: + max_iter: the maximum number of iterations to train. + Used to compute ETA. If not given, ETA will not be printed. + window_size (int): the losses will be median-smoothed by this window size + """ + self.logger = logging.getLogger(__name__) + self._max_iter = max_iter + self._window_size = window_size + self._last_write = ( + None # (step, time) of last call to write(). Used to compute ETA + ) + + def _get_eta(self, storage) -> Optional[str]: + if self._max_iter is None: + return "" + iteration = storage.iter + try: + eta_seconds = storage.history("time").median(1000) * ( + self._max_iter - iteration - 1 + ) + storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False) + return str(datetime.timedelta(seconds=int(eta_seconds))) + except KeyError: + # estimate eta on our own - more noisy + eta_string = None + if self._last_write is not None: + estimate_iter_time = (time.perf_counter() - self._last_write[1]) / ( + iteration - self._last_write[0] + ) + eta_seconds = estimate_iter_time * (self._max_iter - iteration - 1) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + self._last_write = (iteration, time.perf_counter()) + return eta_string + + def write(self): + storage = get_event_storage() + iteration = storage.iter + if iteration == self._max_iter: + # This hook only reports training progress (loss, ETA, etc) but not other data, + # therefore do not write anything after training succeeds, even if this method + # is called. 
+ return + + try: + data_time = storage.history("data_time").avg(20) + except KeyError: + # they may not exist in the first few iterations (due to warmup) + # or when SimpleTrainer is not used + data_time = None + try: + iter_time = storage.history("time").global_avg() + except KeyError: + iter_time = None + try: + lr = "{:.5g}".format(storage.history("lr").latest()) + except KeyError: + lr = "N/A" + + eta_string = self._get_eta(storage) + + if torch.cuda.is_available(): + max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 + else: + max_mem_mb = None + + # NOTE: max_mem is parsed by grep in "dev/parse_results.sh" + self.logger.info( + " {eta}iter: {iter} {losses} {time}{data_time}lr: {lr} {memory}".format( + eta=f"eta: {eta_string} " if eta_string else "", + iter=iteration, + losses=" ".join( + [ + "{}: {:.4g}".format(k, v.median(self._window_size)) + for k, v in storage.histories().items() + if "loss" in k + ] + ), + time="time: {:.4f} ".format(iter_time) + if iter_time is not None + else "", + data_time="data_time: {:.4f} ".format(data_time) + if data_time is not None + else "", + lr=lr, + memory="max_mem: {:.0f}M".format(max_mem_mb) + if max_mem_mb is not None + else "", + ) + ) + + +class EventStorage: + """ + The user-facing class that provides metric storage functionalities. + In the future we may add support for storing / logging other types of data if needed. + """ + + def __init__(self, start_iter=0): + """ + Args: + start_iter (int): the iteration number to start with + """ + self._history = defaultdict(AverageMeter) + self._smoothing_hints = {} + self._latest_scalars = {} + self._iter = start_iter + self._current_prefix = "" + self._vis_data = [] + self._histograms = [] + + # def put_image(self, img_name, img_tensor): + # """ + # Add an `img_tensor` associated with `img_name`, to be shown on + # tensorboard. + # Args: + # img_name (str): The name of the image to put into tensorboard. + # img_tensor (torch.Tensor or numpy.array): An `uint8` or `float` + # Tensor of shape `[channel, height, width]` where `channel` is + # 3. The image format should be RGB. The elements in img_tensor + # can either have values in [0, 1] (float32) or [0, 255] (uint8). + # The `img_tensor` will be visualized in tensorboard. + # """ + # self._vis_data.append((img_name, img_tensor, self._iter)) + + def put_scalar(self, name, value, n=1, smoothing_hint=False): + """ + Add a scalar `value` to the `HistoryBuffer` associated with `name`. + Args: + smoothing_hint (bool): a 'hint' on whether this scalar is noisy and should be + smoothed when logged. The hint will be accessible through + :meth:`EventStorage.smoothing_hints`. A writer may ignore the hint + and apply custom smoothing rule. + It defaults to True because most scalars we save need to be smoothed to + provide any useful signal. + """ + name = self._current_prefix + name + history = self._history[name] + history.update(value, n) + self._latest_scalars[name] = (value, self._iter) + + existing_hint = self._smoothing_hints.get(name) + if existing_hint is not None: + assert ( + existing_hint == smoothing_hint + ), "Scalar {} was put with a different smoothing_hint!".format(name) + else: + self._smoothing_hints[name] = smoothing_hint + + # def put_scalars(self, *, smoothing_hint=True, **kwargs): + # """ + # Put multiple scalars from keyword arguments. 
+ # Examples: + # storage.put_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True) + # """ + # for k, v in kwargs.items(): + # self.put_scalar(k, v, smoothing_hint=smoothing_hint) + # + # def put_histogram(self, hist_name, hist_tensor, bins=1000): + # """ + # Create a histogram from a tensor. + # Args: + # hist_name (str): The name of the histogram to put into tensorboard. + # hist_tensor (torch.Tensor): A Tensor of arbitrary shape to be converted + # into a histogram. + # bins (int): Number of histogram bins. + # """ + # ht_min, ht_max = hist_tensor.min().item(), hist_tensor.max().item() + # + # # Create a histogram with PyTorch + # hist_counts = torch.histc(hist_tensor, bins=bins) + # hist_edges = torch.linspace(start=ht_min, end=ht_max, steps=bins + 1, dtype=torch.float32) + # + # # Parameter for the add_histogram_raw function of SummaryWriter + # hist_params = dict( + # tag=hist_name, + # min=ht_min, + # max=ht_max, + # num=len(hist_tensor), + # sum=float(hist_tensor.sum()), + # sum_squares=float(torch.sum(hist_tensor**2)), + # bucket_limits=hist_edges[1:].tolist(), + # bucket_counts=hist_counts.tolist(), + # global_step=self._iter, + # ) + # self._histograms.append(hist_params) + + def history(self, name): + """ + Returns: + AverageMeter: the history for name + """ + ret = self._history.get(name, None) + if ret is None: + raise KeyError("No history metric available for {}!".format(name)) + return ret + + def histories(self): + """ + Returns: + dict[name -> HistoryBuffer]: the HistoryBuffer for all scalars + """ + return self._history + + def latest(self): + """ + Returns: + dict[str -> (float, int)]: mapping from the name of each scalar to the most + recent value and the iteration number its added. + """ + return self._latest_scalars + + def latest_with_smoothing_hint(self, window_size=20): + """ + Similar to :meth:`latest`, but the returned values + are either the un-smoothed original latest value, + or a median of the given window_size, + depend on whether the smoothing_hint is True. + This provides a default behavior that other writers can use. + """ + result = {} + for k, (v, itr) in self._latest_scalars.items(): + result[k] = ( + self._history[k].median(window_size) if self._smoothing_hints[k] else v, + itr, + ) + return result + + def smoothing_hints(self): + """ + Returns: + dict[name -> bool]: the user-provided hint on whether the scalar + is noisy and needs smoothing. + """ + return self._smoothing_hints + + def step(self): + """ + User should either: (1) Call this function to increment storage.iter when needed. Or + (2) Set `storage.iter` to the correct iteration number before each iteration. + The storage will then be able to associate the new data with an iteration number. + """ + self._iter += 1 + + @property + def iter(self): + """ + Returns: + int: The current iteration number. When used together with a trainer, + this is ensured to be the same as trainer.iter. + """ + return self._iter + + @iter.setter + def iter(self, val): + self._iter = int(val) + + @property + def iteration(self): + # for backward compatibility + return self._iter + + def __enter__(self): + _CURRENT_STORAGE_STACK.append(self) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + assert _CURRENT_STORAGE_STACK[-1] == self + _CURRENT_STORAGE_STACK.pop() + + @contextmanager + def name_scope(self, name): + """ + Yields: + A context within which all the events added to this storage + will be prefixed by the name scope. 
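A hedged usage sketch showing how `EventStorage` and the writers above are typically wired inside a training loop; the writer class name is assumed here to be `JSONWriter` (only its methods are visible above), and the loss values are invented. Note that this variant of `EventStorage` keeps `AverageMeter` objects in `_history`, which have no `median` method, so the sketch leaves `smoothing_hint` at its default of `False`; a `True` hint would make `latest_with_smoothing_hint` fail.

```python
max_iter = 100
writer = JSONWriter("metrics.json", window_size=20)   # assumed class name

with EventStorage(start_iter=0) as storage:           # pushes onto _CURRENT_STORAGE_STACK
    for it in range(max_iter):
        loss = 1.0 / (it + 1)                         # dummy metric
        storage.put_scalar("total_loss", loss)        # smoothing_hint defaults to False here
        storage.put_scalar("lr", 1e-3)
        if (it + 1) % 20 == 0:
            writer.write()                            # pulls the storage via get_event_storage()
        storage.step()

writer.close()
```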
+ """ + old_prefix = self._current_prefix + self._current_prefix = name.rstrip("/") + "/" + yield + self._current_prefix = old_prefix + + def clear_images(self): + """ + Delete all the stored images for visualization. This should be called + after images are written to tensorboard. + """ + self._vis_data = [] + + def clear_histograms(self): + """ + Delete all the stored histograms for visualization. + This should be called after histograms are written to tensorboard. + """ + self._histograms = [] + + def reset_history(self, name): + ret = self._history.get(name, None) + if ret is None: + raise KeyError("No history metric available for {}!".format(name)) + ret.reset() + + def reset_histories(self): + for name in self._history.keys(): + self._history[name].reset() + + +class AverageMeter: + """Computes and stores the average and current value""" + + def __init__(self): + self.val = 0 + self.avg = 0 + self.total = 0 + self.count = 0 + + def reset(self): + self.val = 0 + self.avg = 0 + self.total = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.total += val * n + self.count += n + self.avg = self.total / self.count + + +class HistoryBuffer: + """ + Track a series of scalar values and provide access to smoothed values over a + window or the global average of the series. + """ + + def __init__(self, max_length: int = 1000000) -> None: + """ + Args: + max_length: maximal number of values that can be stored in the + buffer. When the capacity of the buffer is exhausted, old + values will be removed. + """ + self._max_length: int = max_length + self._data: List[Tuple[float, float]] = [] # (value, iteration) pairs + self._count: int = 0 + self._global_avg: float = 0 + + def update(self, value: float, iteration: Optional[float] = None) -> None: + """ + Add a new scalar value produced at certain iteration. If the length + of the buffer exceeds self._max_length, the oldest element will be + removed from the buffer. + """ + if iteration is None: + iteration = self._count + if len(self._data) == self._max_length: + self._data.pop(0) + self._data.append((value, iteration)) + + self._count += 1 + self._global_avg += (value - self._global_avg) / self._count + + def latest(self) -> float: + """ + Return the latest scalar value added to the buffer. + """ + return self._data[-1][0] + + def median(self, window_size: int) -> float: + """ + Return the median of the latest `window_size` values in the buffer. + """ + return np.median([x[0] for x in self._data[-window_size:]]) + + def avg(self, window_size: int) -> float: + """ + Return the mean of the latest `window_size` values in the buffer. + """ + return np.mean([x[0] for x in self._data[-window_size:]]) + + def global_avg(self) -> float: + """ + Return the mean of all the elements in the buffer. Note that this + includes those getting removed due to limited buffer storage. + """ + return self._global_avg + + def values(self) -> List[Tuple[float, float]]: + """ + Returns: + list[(number, iteration)]: content of the current buffer. + """ + return self._data diff --git a/UniRig/src/model/pointcept/utils/logger.py b/UniRig/src/model/pointcept/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..ddaf2c5a765c9f1325737c3cbc73e1169f13cdd4 --- /dev/null +++ b/UniRig/src/model/pointcept/utils/logger.py @@ -0,0 +1,172 @@ +""" +Logger Utils + +Modified from mmcv + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. 
+""" + +import logging +import torch +import torch.distributed as dist + +from termcolor import colored + +logger_initialized = {} +root_status = 0 + + +class _ColorfulFormatter(logging.Formatter): + def __init__(self, *args, **kwargs): + self._root_name = kwargs.pop("root_name") + "." + super(_ColorfulFormatter, self).__init__(*args, **kwargs) + + def formatMessage(self, record): + log = super(_ColorfulFormatter, self).formatMessage(record) + if record.levelno == logging.WARNING: + prefix = colored("WARNING", "red", attrs=["blink"]) + elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: + prefix = colored("ERROR", "red", attrs=["blink", "underline"]) + else: + return log + return prefix + " " + log + + +def get_logger(name, log_file=None, log_level=logging.INFO, file_mode="a", color=False): + """Initialize and get a logger by name. + + If the logger has not been initialized, this method will initialize the + logger by adding one or two handlers, otherwise the initialized logger will + be directly returned. During initialization, a StreamHandler will always be + added. If `log_file` is specified and the process rank is 0, a FileHandler + will also be added. + + Args: + name (str): Logger name. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the logger. + log_level (int): The logger level. Note that only the process of + rank 0 is affected, and other processes will set the level to + "Error" thus be silent most of the time. + file_mode (str): The file mode used in opening log file. + Defaults to 'a'. + color (bool): Colorful log output. Defaults to True + + Returns: + logging.Logger: The expected logger. + """ + logger = logging.getLogger(name) + + if name in logger_initialized: + return logger + # handle hierarchical names + # e.g., logger "a" is initialized, then logger "a.b" will skip the + # initialization since it is a child of "a". + for logger_name in logger_initialized: + if name.startswith(logger_name): + return logger + + logger.propagate = False + + stream_handler = logging.StreamHandler() + handlers = [stream_handler] + + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank() + else: + rank = 0 + + # only rank 0 will add a FileHandler + if rank == 0 and log_file is not None: + # Here, the default behaviour of the official logger is 'a'. Thus, we + # provide an interface to change the file mode to the default + # behaviour. + file_handler = logging.FileHandler(log_file, file_mode) + handlers.append(file_handler) + + plain_formatter = logging.Formatter( + "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s" + ) + if color: + formatter = _ColorfulFormatter( + colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", + datefmt="%m/%d %H:%M:%S", + root_name=name, + ) + else: + formatter = plain_formatter + for handler in handlers: + handler.setFormatter(formatter) + handler.setLevel(log_level) + logger.addHandler(handler) + + if rank == 0: + logger.setLevel(log_level) + else: + logger.setLevel(logging.ERROR) + + logger_initialized[name] = True + + return logger + + +def print_log(msg, logger=None, level=logging.INFO): + """Print a log message. + + Args: + msg (str): The message to be logged. + logger (logging.Logger | str | None): The logger to be used. + Some special loggers are: + - "silent": no message will be printed. + - other str: the logger obtained with `get_root_logger(logger)`. + - None: The `print()` method will be used to print log messages. 
+ level (int): Logging level. Only available when `logger` is a Logger + object or "root". + """ + if logger is None: + print(msg) + elif isinstance(logger, logging.Logger): + logger.log(level, msg) + elif logger == "silent": + pass + elif isinstance(logger, str): + _logger = get_logger(logger) + _logger.log(level, msg) + else: + raise TypeError( + "logger should be either a logging.Logger object, str, " + f'"silent" or None, but got {type(logger)}' + ) + + +def get_root_logger(log_file=None, log_level=logging.INFO, file_mode="a"): + """Get the root logger. + + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. The name of the root logger is the top-level package name. + + Args: + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + file_mode (str): File Mode of logger. (w or a) + + Returns: + logging.Logger: The root logger. + """ + logger = get_logger( + name="pointcept", log_file=log_file, log_level=log_level, file_mode=file_mode + ) + return logger + + +def _log_api_usage(identifier: str): + """ + Internal function used to log the usage of different detectron2 components + inside facebook's infra. + """ + torch._C._log_api_usage_once("pointcept." + identifier) diff --git a/UniRig/src/model/pointcept/utils/misc.py b/UniRig/src/model/pointcept/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..d890275ac26e1ac23c88e33d00065050a2230bac --- /dev/null +++ b/UniRig/src/model/pointcept/utils/misc.py @@ -0,0 +1,159 @@ +""" +Misc + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +import os +import warnings +from collections import abc +import numpy as np +import torch +from importlib import import_module + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def intersection_and_union(output, target, K, ignore_index=-1): + # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1. + assert output.ndim in [1, 2, 3] + assert output.shape == target.shape + output = output.reshape(output.size).copy() + target = target.reshape(target.size) + output[np.where(target == ignore_index)[0]] = ignore_index + intersection = output[np.where(output == target)[0]] + area_intersection, _ = np.histogram(intersection, bins=np.arange(K + 1)) + area_output, _ = np.histogram(output, bins=np.arange(K + 1)) + area_target, _ = np.histogram(target, bins=np.arange(K + 1)) + area_union = area_output + area_target - area_intersection + return area_intersection, area_union, area_target + + +def intersection_and_union_gpu(output, target, k, ignore_index=-1): + # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1. 
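The areas returned by `intersection_and_union` above are per-class histograms, so per-class IoU and mIoU follow directly from them. A short, hedged example with a 3-class toy prediction:

```python
import numpy as np
# assumes intersection_and_union defined above

pred   = np.array([0, 1, 1, 2, 2, 2])       # predicted labels
target = np.array([0, 1, 2, 2, 2, 1])       # ground-truth labels
K = 3                                        # number of classes

inter, union, _ = intersection_and_union(pred, target, K)
iou = inter / np.maximum(union, 1)           # guard against empty classes
print(iou)                                    # [1.0, 0.333..., 0.5]
print(iou.mean())                             # mIoU ~= 0.611
```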
+ assert output.dim() in [1, 2, 3] + assert output.shape == target.shape + output = output.view(-1) + target = target.view(-1) + output[target == ignore_index] = ignore_index + intersection = output[output == target] + area_intersection = torch.histc(intersection, bins=k, min=0, max=k - 1) + area_output = torch.histc(output, bins=k, min=0, max=k - 1) + area_target = torch.histc(target, bins=k, min=0, max=k - 1) + area_union = area_output + area_target - area_intersection + return area_intersection, area_union, area_target + + +def make_dirs(dir_name): + if not os.path.exists(dir_name): + os.makedirs(dir_name, exist_ok=True) + + +def find_free_port(): + import socket + + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # Binding to port 0 will cause the OS to find an available port for us + sock.bind(("", 0)) + port = sock.getsockname()[1] + sock.close() + # NOTE: there is still a chance the port could be taken by other processes. + return port + + +def is_seq_of(seq, expected_type, seq_type=None): + """Check whether it is a sequence of some type. + + Args: + seq (Sequence): The sequence to be checked. + expected_type (type): Expected type of sequence items. + seq_type (type, optional): Expected sequence type. + + Returns: + bool: Whether the sequence is valid. + """ + if seq_type is None: + exp_seq_type = abc.Sequence + else: + assert isinstance(seq_type, type) + exp_seq_type = seq_type + if not isinstance(seq, exp_seq_type): + return False + for item in seq: + if not isinstance(item, expected_type): + return False + return True + + +def is_str(x): + """Whether the input is an string instance. + + Note: This method is deprecated since python 2 is no longer supported. + """ + return isinstance(x, str) + + +def import_modules_from_strings(imports, allow_failed_imports=False): + """Import modules from the given list of strings. + + Args: + imports (list | str | None): The given module names to be imported. + allow_failed_imports (bool): If True, the failed imports will return + None. Otherwise, an ImportError is raise. Default: False. + + Returns: + list[module] | module | None: The imported modules. + + Examples: + >>> osp, sys = import_modules_from_strings( + ... ['os.path', 'sys']) + >>> import os.path as osp_ + >>> import sys as sys_ + >>> assert osp == osp_ + >>> assert sys == sys_ + """ + if not imports: + return + single_import = False + if isinstance(imports, str): + single_import = True + imports = [imports] + if not isinstance(imports, list): + raise TypeError(f"custom_imports must be a list but got type {type(imports)}") + imported = [] + for imp in imports: + if not isinstance(imp, str): + raise TypeError(f"{imp} is of type {type(imp)} and cannot be imported.") + try: + imported_tmp = import_module(imp) + except ImportError: + if allow_failed_imports: + warnings.warn(f"{imp} failed to import and is ignored.", UserWarning) + imported_tmp = None + else: + raise ImportError + imported.append(imported_tmp) + if single_import: + imported = imported[0] + return imported diff --git a/UniRig/src/model/pointcept/utils/optimizer.py b/UniRig/src/model/pointcept/utils/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..dc2a3fe00a7e7ac3751926568d6112a964e13930 --- /dev/null +++ b/UniRig/src/model/pointcept/utils/optimizer.py @@ -0,0 +1,67 @@ +""" +Optimizer + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. 
+""" + +import torch +from pointcept.utils.logger import get_root_logger +from pointcept.utils.registry import Registry + +OPTIMIZERS = Registry("optimizers") + + +OPTIMIZERS.register_module(module=torch.optim.SGD, name="SGD") +OPTIMIZERS.register_module(module=torch.optim.Adam, name="Adam") +OPTIMIZERS.register_module(module=torch.optim.AdamW, name="AdamW") + + +def build_optimizer(cfg, model, param_dicts=None): + if param_dicts is None: + # cfg.params = model.parameters() + cfg.params = [dict(names=[], params=[], lr=cfg.lr)] + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + cfg.params[0]["names"].append(n) + cfg.params[0]["params"].append(p) + else: + cfg.params = [dict(names=[], params=[], lr=cfg.lr)] + for i in range(len(param_dicts)): + param_group = dict(names=[], params=[]) + if "lr" in param_dicts[i].keys(): + param_group["lr"] = param_dicts[i].lr + if "momentum" in param_dicts[i].keys(): + param_group["momentum"] = param_dicts[i].momentum + if "weight_decay" in param_dicts[i].keys(): + param_group["weight_decay"] = param_dicts[i].weight_decay + cfg.params.append(param_group) + + for n, p in model.named_parameters(): + # !!! requires_grad is a must + if not p.requires_grad: + continue + flag = False + for i in range(len(param_dicts)): + if param_dicts[i].keyword in n: + cfg.params[i + 1]["names"].append(n) + cfg.params[i + 1]["params"].append(p) + flag = True + break + if not flag: + cfg.params[0]["names"].append(n) + cfg.params[0]["params"].append(p) + + logger = get_root_logger() + + for i in range(len(cfg.params)): + param_names = cfg.params[i].pop("names") + message = "" + for key in cfg.params[i].keys(): + if key != "params": + message += f" {key}: {cfg.params[i][key]};" + logger.info(f"Params Group {i+1} -{message} Params: {param_names}.") + # print(111) + # exit(0) + return OPTIMIZERS.build(cfg=cfg) diff --git a/UniRig/src/model/pointcept/utils/path.py b/UniRig/src/model/pointcept/utils/path.py new file mode 100644 index 0000000000000000000000000000000000000000..ce98fa5fd0dfbf6e1d61e833ecc35fea4ab2782b --- /dev/null +++ b/UniRig/src/model/pointcept/utils/path.py @@ -0,0 +1,103 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +from pathlib import Path + +from .misc import is_str + + +def is_filepath(x): + return is_str(x) or isinstance(x, Path) + + +def fopen(filepath, *args, **kwargs): + if is_str(filepath): + return open(filepath, *args, **kwargs) + elif isinstance(filepath, Path): + return filepath.open(*args, **kwargs) + raise ValueError("`filepath` should be a string or a Path") + + +def check_file_exist(filename, msg_tmpl='file "{}" does not exist'): + if not osp.isfile(filename): + raise FileNotFoundError(msg_tmpl.format(filename)) + + +def mkdir_or_exist(dir_name, mode=0o777): + if dir_name == "": + return + dir_name = osp.expanduser(dir_name) + os.makedirs(dir_name, mode=mode, exist_ok=True) + + +def symlink(src, dst, overwrite=True, **kwargs): + if os.path.lexists(dst) and overwrite: + os.remove(dst) + os.symlink(src, dst, **kwargs) + + +def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True): + """Scan a directory to find the interested files. + + Args: + dir_path (str | obj:`Path`): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + case_sensitive (bool, optional) : If set to False, ignore the case of + suffix. 
Default: True. + + Returns: + A generator for all the interested files with relative paths. + """ + if isinstance(dir_path, (str, Path)): + dir_path = str(dir_path) + else: + raise TypeError('"dir_path" must be a string or Path object') + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + if suffix is not None and not case_sensitive: + suffix = ( + suffix.lower() + if isinstance(suffix, str) + else tuple(item.lower() for item in suffix) + ) + + root = dir_path + + def _scandir(dir_path, suffix, recursive, case_sensitive): + for entry in os.scandir(dir_path): + if not entry.name.startswith(".") and entry.is_file(): + rel_path = osp.relpath(entry.path, root) + _rel_path = rel_path if case_sensitive else rel_path.lower() + if suffix is None or _rel_path.endswith(suffix): + yield rel_path + elif recursive and os.path.isdir(entry.path): + # scan recursively if entry.path is a directory + yield from _scandir(entry.path, suffix, recursive, case_sensitive) + + return _scandir(dir_path, suffix, recursive, case_sensitive) + + +def find_vcs_root(path, markers=(".git",)): + """Finds the root directory (including itself) of specified markers. + + Args: + path (str): Path of directory or file. + markers (list[str], optional): List of file or directory names. + + Returns: + The directory contained one of the markers or None if not found. + """ + if osp.isfile(path): + path = osp.dirname(path) + + prev, cur = None, osp.abspath(osp.expanduser(path)) + while cur != prev: + if any(osp.exists(osp.join(cur, marker)) for marker in markers): + return cur + prev, cur = cur, osp.split(cur)[0] + return None diff --git a/UniRig/src/model/pointcept/utils/registry.py b/UniRig/src/model/pointcept/utils/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..7ac308a87d38ff61da14d6b4d5c73b4c68c15a58 --- /dev/null +++ b/UniRig/src/model/pointcept/utils/registry.py @@ -0,0 +1,316 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import inspect +import warnings +from functools import partial + +from .misc import is_seq_of + + +def build_from_cfg(cfg, registry, default_args=None): + """Build a module from configs dict. + + Args: + cfg (dict): Config dict. It should at least contain the key "type". + registry (:obj:`Registry`): The registry to search the type from. + default_args (dict, optional): Default initialization arguments. + + Returns: + object: The constructed object. 
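`scandir` and `find_vcs_root` from the path utilities above compose naturally for locating project files. A small hedged usage sketch (the paths are illustrative):

```python
# assumes scandir and find_vcs_root from the path utilities above
root = find_vcs_root(".")                     # nearest ancestor containing .git, or None
if root is not None:
    # lazily iterate over all Python files below the repository root
    for rel_path in scandir(root, suffix=".py", recursive=True):
        print(rel_path)
```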
+ """ + if not isinstance(cfg, dict): + raise TypeError(f"cfg must be a dict, but got {type(cfg)}") + if "type" not in cfg: + if default_args is None or "type" not in default_args: + raise KeyError( + '`cfg` or `default_args` must contain the key "type", ' + f"but got {cfg}\n{default_args}" + ) + if not isinstance(registry, Registry): + raise TypeError( + "registry must be an mmcv.Registry object, " f"but got {type(registry)}" + ) + if not (isinstance(default_args, dict) or default_args is None): + raise TypeError( + "default_args must be a dict or None, " f"but got {type(default_args)}" + ) + + args = cfg.copy() + + if default_args is not None: + for name, value in default_args.items(): + args.setdefault(name, value) + + obj_type = args.pop("type") + if isinstance(obj_type, str): + obj_cls = registry.get(obj_type) + if obj_cls is None: + raise KeyError(f"{obj_type} is not in the {registry.name} registry") + elif inspect.isclass(obj_type): + obj_cls = obj_type + else: + raise TypeError(f"type must be a str or valid type, but got {type(obj_type)}") + try: + return obj_cls(**args) + except Exception as e: + # Normal TypeError does not print class name. + raise type(e)(f"{obj_cls.__name__}: {e}") + + +class Registry: + """A registry to map strings to classes. + + Registered object could be built from registry. + Example: + >>> MODELS = Registry('models') + >>> @MODELS.register_module() + >>> class ResNet: + >>> pass + >>> resnet = MODELS.build(dict(type='ResNet')) + + Please refer to + https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for + advanced usage. + + Args: + name (str): Registry name. + build_func(func, optional): Build function to construct instance from + Registry, func:`build_from_cfg` is used if neither ``parent`` or + ``build_func`` is specified. If ``parent`` is specified and + ``build_func`` is not given, ``build_func`` will be inherited + from ``parent``. Default: None. + parent (Registry, optional): Parent registry. The class registered in + children registry could be built from parent. Default: None. + scope (str, optional): The scope of registry. It is the key to search + for children registry. If not specified, scope will be the name of + the package where class is defined, e.g. mmdet, mmcls, mmseg. + Default: None. + """ + + def __init__(self, name, build_func=None, parent=None, scope=None): + self._name = name + self._module_dict = dict() + self._children = dict() + self._scope = self.infer_scope() if scope is None else scope + + # self.build_func will be set with the following priority: + # 1. build_func + # 2. parent.build_func + # 3. build_from_cfg + if build_func is None: + if parent is not None: + self.build_func = parent.build_func + else: + self.build_func = build_from_cfg + else: + self.build_func = build_func + if parent is not None: + assert isinstance(parent, Registry) + parent._add_children(self) + self.parent = parent + else: + self.parent = None + + def __len__(self): + return len(self._module_dict) + + def __contains__(self, key): + return self.get(key) is not None + + def __repr__(self): + format_str = ( + self.__class__.__name__ + f"(name={self._name}, " + f"items={self._module_dict})" + ) + return format_str + + @staticmethod + def infer_scope(): + """Infer the scope of registry. + + The name of the package where registry is defined will be returned. 
+ + Example: + # in mmdet/models/backbone/resnet.py + >>> MODELS = Registry('models') + >>> @MODELS.register_module() + >>> class ResNet: + >>> pass + The scope of ``ResNet`` will be ``mmdet``. + + + Returns: + scope (str): The inferred scope name. + """ + # inspect.stack() trace where this function is called, the index-2 + # indicates the frame where `infer_scope()` is called + filename = inspect.getmodule(inspect.stack()[2][0]).__name__ + split_filename = filename.split(".") + return split_filename[0] + + @staticmethod + def split_scope_key(key): + """Split scope and key. + + The first scope will be split from key. + + Examples: + >>> Registry.split_scope_key('mmdet.ResNet') + 'mmdet', 'ResNet' + >>> Registry.split_scope_key('ResNet') + None, 'ResNet' + + Return: + scope (str, None): The first scope. + key (str): The remaining key. + """ + split_index = key.find(".") + if split_index != -1: + return key[:split_index], key[split_index + 1 :] + else: + return None, key + + @property + def name(self): + return self._name + + @property + def scope(self): + return self._scope + + @property + def module_dict(self): + return self._module_dict + + @property + def children(self): + return self._children + + def get(self, key): + """Get the registry record. + + Args: + key (str): The class name in string format. + + Returns: + class: The corresponding class. + """ + scope, real_key = self.split_scope_key(key) + if scope is None or scope == self._scope: + # get from self + if real_key in self._module_dict: + return self._module_dict[real_key] + else: + # get from self._children + if scope in self._children: + return self._children[scope].get(real_key) + else: + # goto root + parent = self.parent + while parent.parent is not None: + parent = parent.parent + return parent.get(key) + + def build(self, *args, **kwargs): + return self.build_func(*args, **kwargs, registry=self) + + def _add_children(self, registry): + """Add children for a registry. + + The ``registry`` will be added as children based on its scope. + The parent registry could build objects from children registry. + + Example: + >>> models = Registry('models') + >>> mmdet_models = Registry('models', parent=models) + >>> @mmdet_models.register_module() + >>> class ResNet: + >>> pass + >>> resnet = models.build(dict(type='mmdet.ResNet')) + """ + + assert isinstance(registry, Registry) + assert registry.scope is not None + assert ( + registry.scope not in self.children + ), f"scope {registry.scope} exists in {self.name} registry" + self.children[registry.scope] = registry + + def _register_module(self, module_class, module_name=None, force=False): + if not inspect.isclass(module_class): + raise TypeError("module must be a class, " f"but got {type(module_class)}") + + if module_name is None: + module_name = module_class.__name__ + if isinstance(module_name, str): + module_name = [module_name] + for name in module_name: + if not force and name in self._module_dict: + raise KeyError(f"{name} is already registered " f"in {self.name}") + self._module_dict[name] = module_class + + def deprecated_register_module(self, cls=None, force=False): + warnings.warn( + "The old API of register_module(module, force=False) " + "is deprecated and will be removed, please use the new API " + "register_module(name=None, force=False, module=None) instead." 
+ ) + if cls is None: + return partial(self.deprecated_register_module, force=force) + self._register_module(cls, force=force) + return cls + + def register_module(self, name=None, force=False, module=None): + """Register a module. + + A record will be added to `self._module_dict`, whose key is the class + name or the specified name, and value is the class itself. + It can be used as a decorator or a normal function. + + Example: + >>> backbones = Registry('backbone') + >>> @backbones.register_module() + >>> class ResNet: + >>> pass + + >>> backbones = Registry('backbone') + >>> @backbones.register_module(name='mnet') + >>> class MobileNet: + >>> pass + + >>> backbones = Registry('backbone') + >>> class ResNet: + >>> pass + >>> backbones.register_module(ResNet) + + Args: + name (str | None): The module name to be registered. If not + specified, the class name will be used. + force (bool, optional): Whether to override an existing class with + the same name. Default: False. + module (type): Module class to be registered. + """ + if not isinstance(force, bool): + raise TypeError(f"force must be a boolean, but got {type(force)}") + # NOTE: This is a walkaround to be compatible with the old api, + # while it may introduce unexpected bugs. + if isinstance(name, type): + return self.deprecated_register_module(name, force=force) + + # raise the error ahead of time + if not (name is None or isinstance(name, str) or is_seq_of(name, str)): + raise TypeError( + "name must be either of None, an instance of str or a sequence" + f" of str, but got {type(name)}" + ) + + # use it as a normal method: x.register_module(module=SomeClass) + if module is not None: + self._register_module(module_class=module, module_name=name, force=force) + return module + + # use it as a decorator: @x.register_module() + def _register(cls): + self._register_module(module_class=cls, module_name=name, force=force) + return cls + + return _register diff --git a/UniRig/src/model/pointcept/utils/scheduler.py b/UniRig/src/model/pointcept/utils/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..3e2e29fdde2e2668c023af36afdb89e73fb9ce53 --- /dev/null +++ b/UniRig/src/model/pointcept/utils/scheduler.py @@ -0,0 +1,147 @@ +""" +Scheduler + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. 
+""" + +import torch.optim.lr_scheduler as lr_scheduler +from .registry import Registry + +SCHEDULERS = Registry("schedulers") + + +@SCHEDULERS.register_module() +class MultiStepLR(lr_scheduler.MultiStepLR): + def __init__( + self, + optimizer, + milestones, + total_steps, + gamma=0.1, + last_epoch=-1, + verbose=False, + ): + super().__init__( + optimizer=optimizer, + milestones=[rate * total_steps for rate in milestones], + gamma=gamma, + last_epoch=last_epoch, + verbose=verbose, + ) + + +@SCHEDULERS.register_module() +class MultiStepWithWarmupLR(lr_scheduler.LambdaLR): + def __init__( + self, + optimizer, + milestones, + total_steps, + gamma=0.1, + warmup_rate=0.05, + warmup_scale=1e-6, + last_epoch=-1, + verbose=False, + ): + milestones = [rate * total_steps for rate in milestones] + + def multi_step_with_warmup(s): + factor = 1.0 + for i in range(len(milestones)): + if s < milestones[i]: + break + factor *= gamma + + if s <= warmup_rate * total_steps: + warmup_coefficient = 1 - (1 - s / warmup_rate / total_steps) * ( + 1 - warmup_scale + ) + else: + warmup_coefficient = 1.0 + return warmup_coefficient * factor + + super().__init__( + optimizer=optimizer, + lr_lambda=multi_step_with_warmup, + last_epoch=last_epoch, + verbose=verbose, + ) + + +@SCHEDULERS.register_module() +class PolyLR(lr_scheduler.LambdaLR): + def __init__(self, optimizer, total_steps, power=0.9, last_epoch=-1, verbose=False): + super().__init__( + optimizer=optimizer, + lr_lambda=lambda s: (1 - s / (total_steps + 1)) ** power, + last_epoch=last_epoch, + verbose=verbose, + ) + + +@SCHEDULERS.register_module() +class ExpLR(lr_scheduler.LambdaLR): + def __init__(self, optimizer, total_steps, gamma=0.9, last_epoch=-1, verbose=False): + super().__init__( + optimizer=optimizer, + lr_lambda=lambda s: gamma ** (s / total_steps), + last_epoch=last_epoch, + verbose=verbose, + ) + + +@SCHEDULERS.register_module() +class CosineAnnealingLR(lr_scheduler.CosineAnnealingLR): + def __init__(self, optimizer, total_steps, eta_min=0, last_epoch=-1, verbose=False): + super().__init__( + optimizer=optimizer, + T_max=total_steps, + eta_min=eta_min, + last_epoch=last_epoch, + verbose=verbose, + ) + + +@SCHEDULERS.register_module() +class OneCycleLR(lr_scheduler.OneCycleLR): + r""" + torch.optim.lr_scheduler.OneCycleLR, Block total_steps + """ + + def __init__( + self, + optimizer, + max_lr, + total_steps=None, + pct_start=0.3, + anneal_strategy="cos", + cycle_momentum=True, + base_momentum=0.85, + max_momentum=0.95, + div_factor=25.0, + final_div_factor=1e4, + three_phase=False, + last_epoch=-1, + verbose=False, + ): + super().__init__( + optimizer=optimizer, + max_lr=max_lr, + total_steps=total_steps, + pct_start=pct_start, + anneal_strategy=anneal_strategy, + cycle_momentum=cycle_momentum, + base_momentum=base_momentum, + max_momentum=max_momentum, + div_factor=div_factor, + final_div_factor=final_div_factor, + three_phase=three_phase, + last_epoch=last_epoch, + verbose=verbose, + ) + + +def build_scheduler(cfg, optimizer): + cfg.optimizer = optimizer + return SCHEDULERS.build(cfg=cfg) diff --git a/UniRig/src/model/pointcept/utils/timer.py b/UniRig/src/model/pointcept/utils/timer.py new file mode 100644 index 0000000000000000000000000000000000000000..3de4a16e33c43fe61ea3088f82216fd62eb6e959 --- /dev/null +++ b/UniRig/src/model/pointcept/utils/timer.py @@ -0,0 +1,70 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
+# -*- coding: utf-8 -*- + +from time import perf_counter +from typing import Optional + + +class Timer: + """ + A timer which computes the time elapsed since the start/reset of the timer. + """ + + def __init__(self) -> None: + self.reset() + + def reset(self) -> None: + """ + Reset the timer. + """ + self._start = perf_counter() + self._paused: Optional[float] = None + self._total_paused = 0 + self._count_start = 1 + + def pause(self) -> None: + """ + Pause the timer. + """ + if self._paused is not None: + raise ValueError("Trying to pause a Timer that is already paused!") + self._paused = perf_counter() + + def is_paused(self) -> bool: + """ + Returns: + bool: whether the timer is currently paused + """ + return self._paused is not None + + def resume(self) -> None: + """ + Resume the timer. + """ + if self._paused is None: + raise ValueError("Trying to resume a Timer that is not paused!") + # pyre-fixme[58]: `-` is not supported for operand types `float` and + # `Optional[float]`. + self._total_paused += perf_counter() - self._paused + self._paused = None + self._count_start += 1 + + def seconds(self) -> float: + """ + Returns: + (float): the total number of seconds since the start/reset of the + timer, excluding the time when the timer is paused. + """ + if self._paused is not None: + end_time: float = self._paused # type: ignore + else: + end_time = perf_counter() + return end_time - self._start - self._total_paused + + def avg_seconds(self) -> float: + """ + Returns: + (float): the average number of seconds between every start/reset and + pause. + """ + return self.seconds() / self._count_start diff --git a/UniRig/src/model/pointcept/utils/visualization.py b/UniRig/src/model/pointcept/utils/visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..7a010dd8289f60119d1bfbccdff65edb908e24f6 --- /dev/null +++ b/UniRig/src/model/pointcept/utils/visualization.py @@ -0,0 +1,89 @@ +""" +Visualization Utils + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. 
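A quick hedged illustration of the `Timer` semantics above: paused intervals are excluded from `seconds()`, and `avg_seconds()` divides by the number of start/resume segments.

```python
from time import sleep
# assumes the Timer class above

t = Timer()
sleep(0.2)
t.pause()
sleep(0.1)                            # excluded from the measurement
t.resume()
sleep(0.2)
print(round(t.seconds(), 1))          # ~0.4
print(round(t.avg_seconds(), 1))      # ~0.2  (two start/resume segments)
```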
+""" + +import os +import open3d as o3d +import numpy as np +import torch + + +def to_numpy(x): + if isinstance(x, torch.Tensor): + x = x.clone().detach().cpu().numpy() + assert isinstance(x, np.ndarray) + return x + + +def save_point_cloud(coord, color=None, file_path="pc.ply", logger=None): + os.makedirs(os.path.dirname(file_path), exist_ok=True) + coord = to_numpy(coord) + if color is not None: + color = to_numpy(color) + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(coord) + pcd.colors = o3d.utility.Vector3dVector( + np.ones_like(coord) if color is None else color + ) + o3d.io.write_point_cloud(file_path, pcd) + if logger is not None: + logger.info(f"Save Point Cloud to: {file_path}") + + +def save_bounding_boxes( + bboxes_corners, color=(1.0, 0.0, 0.0), file_path="bbox.ply", logger=None +): + bboxes_corners = to_numpy(bboxes_corners) + # point list + points = bboxes_corners.reshape(-1, 3) + # line list + box_lines = np.array( + [ + [0, 1], + [1, 2], + [2, 3], + [3, 0], + [4, 5], + [5, 6], + [6, 7], + [7, 0], + [0, 4], + [1, 5], + [2, 6], + [3, 7], + ] + ) + lines = [] + for i, _ in enumerate(bboxes_corners): + lines.append(box_lines + i * 8) + lines = np.concatenate(lines) + # color list + color = np.array([color for _ in range(len(lines))]) + # generate line set + line_set = o3d.geometry.LineSet() + line_set.points = o3d.utility.Vector3dVector(points) + line_set.lines = o3d.utility.Vector2iVector(lines) + line_set.colors = o3d.utility.Vector3dVector(color) + o3d.io.write_line_set(file_path, line_set) + + if logger is not None: + logger.info(f"Save Boxes to: {file_path}") + + +def save_lines( + points, lines, color=(1.0, 0.0, 0.0), file_path="lines.ply", logger=None +): + points = to_numpy(points) + lines = to_numpy(lines) + colors = np.array([color for _ in range(len(lines))]) + line_set = o3d.geometry.LineSet() + line_set.points = o3d.utility.Vector3dVector(points) + line_set.lines = o3d.utility.Vector2iVector(lines) + line_set.colors = o3d.utility.Vector3dVector(colors) + o3d.io.write_line_set(file_path, line_set) + + if logger is not None: + logger.info(f"Save Lines to: {file_path}") diff --git a/UniRig/src/model/spec.py b/UniRig/src/model/spec.py new file mode 100644 index 0000000000000000000000000000000000000000..a0201e9f779c802dfbe9ccb957f3e3925aca15cc --- /dev/null +++ b/UniRig/src/model/spec.py @@ -0,0 +1,134 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +import numpy as np +from numpy import ndarray +from typing import Dict, Union, List, final +import lightning.pytorch as pl + +from ..data.asset import Asset +from ..data.augment import Augment + +@dataclass +class ModelInput(): + # tokens for ar input + tokens: Union[ndarray, None]=None + + # pad token + pad: Union[int, None]=None + + # vertices(usually sampled), (N, 3) + vertices: Union[ndarray, None]=None + + # normals(usually sampled), (N, 3) + normals: Union[ndarray, None]=None + + # joints + joints: Union[ndarray, None]=None + + # tails + tails: Union[ndarray, None]=None + + # assets for debug usage + asset: Union[Asset, None]=None + + # augments asset used + augments: Union[Augment, None]=None + +class ModelSpec(pl.LightningModule, ABC): + + @abstractmethod + def __init__(self): + super().__init__() + + @final + def _process_fn(self, batch: List[ModelInput]) -> List[Dict]: + ''' + Returns + cls: List[str] + + path: List[str] + + data_name: List[str] + + joints: shape (B, J, 3), J==max_bones + + tails: shape (B, J, 3) + + parents: shape (B, J), -1 represents no 
parent(should always appear at 0-th position) + + num_bones: shape (B), the true number of bones + + skin: shape (B, J), padding value==0. + + vertices: (B, N, 3) + + normals: (B, N, 3) + + matrix_local: (B, J, 4, 4), current matrix_local + + pose_matrix: (B, J, 4, 4), for motion loss calculation + ''' + n_batch = self.process_fn(batch) + BAN = ['cls', 'path', 'data_name', 'joints', 'tails', 'parents', 'num_bones', 'vertices', + 'normals', 'matrix_local', 'pose_matrix', 'num_points', 'origin_vertices', + 'origin_vertex_normals', 'origin_face_normals', 'num_faces', 'faces'] + # skin should be in vertex group + max_bones = 0 + max_points = 0 + max_faces = 0 + for b in batch: + if b.joints is not None: + max_bones = max(max_bones, b.asset.J) + max_faces = max(max_faces, b.asset.F) + max_points = max(max_points, b.asset.N) + self._augments = [] + self._assets = [] + for (id, b) in enumerate(batch): + for ban in BAN: + assert ban not in n_batch[id], f"cannot override `{ban}` in process_fn" + n_batch[id]['cls'] = b.asset.cls + n_batch[id]['path'] = b.asset.path + n_batch[id]['data_name'] = b.asset.data_name + if b.asset.joints is not None: + n_batch[id]['joints'] = np.pad(b.asset.joints, ((0, max_bones-b.asset.J), (0, 0)), mode='constant', constant_values=0.) + n_batch[id]['num_bones'] = b.asset.J + if b.asset.tails is not None: + n_batch[id]['tails'] = np.pad(b.asset.tails, ((0, max_bones-b.asset.J), (0, 0)), mode='constant', constant_values=0.) + if b.asset.parents is not None: + parents = b.asset.parents.copy() # cannot put None into dict + parents[0] = -1 + parents = np.pad(parents, (0, max_bones-b.asset.J), 'constant', constant_values=-1) + n_batch[id]['parents'] = parents + if b.asset.matrix_local is not None: + J = b.asset.J + matrix_local = np.pad(b.asset.matrix_local, ((0, max_bones-J), (0, 0), (0, 0)), 'constant', constant_values=0.) + # set identity to prevent singular matrix in lbs + matrix_local[J:, 0, 0] = 1. + matrix_local[J:, 1, 1] = 1. + matrix_local[J:, 2, 2] = 1. + matrix_local[J:, 3, 3] = 1. + n_batch[id]['matrix_local'] = matrix_local + if b.asset.pose_matrix is not None: + J = b.asset.J + pose_matrix = np.pad(b.asset.pose_matrix, ((0, max_bones-J), (0, 0), (0, 0)), 'constant', constant_values=0.) + pose_matrix[J:, 0, 0] = 1. + pose_matrix[J:, 1, 1] = 1. + pose_matrix[J:, 2, 2] = 1. + pose_matrix[J:, 3, 3] = 1. + n_batch[id]['pose_matrix'] = pose_matrix + n_batch[id]['vertices'] = b.vertices + n_batch[id]['normals'] = b.normals + n_batch[id]['num_points'] = b.asset.N + n_batch[id]['origin_vertices'] = np.pad(b.asset.vertices, ((0, max_points-b.asset.N), (0, 0))) + n_batch[id]['origin_vertex_normals'] = np.pad(b.asset.vertex_normals, ((0, max_points-b.asset.N), (0, 0))) + n_batch[id]['num_faces'] = b.asset.F + n_batch[id]['origin_faces'] = np.pad(b.asset.faces, ((0, max_faces-b.asset.F), (0, 0))) + n_batch[id]['origin_face_normals'] = np.pad(b.asset.face_normals, ((0, max_faces-b.asset.F), (0, 0))) + return n_batch + + @abstractmethod + def process_fn(self, batch: List[ModelInput]) -> Dict: + ''' + Fetch data from dataloader and turn it into Tensor objects. 
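The batching convention documented in `_process_fn` pads every per-asset array to the largest bone count in the batch, sets the root's parent to `-1`, pads `parents` with `-1`, and records the true length in `num_bones` so each sample can be un-padded downstream. A small hedged sketch of that round trip with invented shapes:

```python
import numpy as np

max_bones = 6                                    # largest J in the (hypothetical) batch
J = 4                                            # true bone count of this asset
joints  = np.random.rand(J, 3).astype(np.float32)
parents = np.array([0, 0, 1, 2])                 # root's parent will be set to -1

padded_joints  = np.pad(joints, ((0, max_bones - J), (0, 0)), constant_values=0.0)
padded_parents = parents.copy()
padded_parents[0] = -1
padded_parents = np.pad(padded_parents, (0, max_bones - J), constant_values=-1)

# Downstream code recovers the real bones via num_bones:
num_bones = J
assert np.allclose(padded_joints[:num_bones], joints)
assert (padded_parents[:num_bones] == np.array([-1, 0, 1, 2])).all()
```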
+ ''' + pass \ No newline at end of file diff --git a/UniRig/src/model/unirig_ar.py b/UniRig/src/model/unirig_ar.py new file mode 100644 index 0000000000000000000000000000000000000000..260f25923c3c10f310ea0fc061b938d24f9da233 --- /dev/null +++ b/UniRig/src/model/unirig_ar.py @@ -0,0 +1,179 @@ +import torch +from torch import nn, FloatTensor, LongTensor +import numpy as np +from torch.nn.functional import pad +from typing import Dict, List, Union +from transformers import AutoModelForCausalLM, AutoConfig + +from .spec import ModelSpec, ModelInput +from .parse_encoder import MAP_MESH_ENCODER, get_mesh_encoder + +from ..tokenizer.spec import TokenizerSpec, DetokenzeOutput +from copy import deepcopy + +class UniRigAR(ModelSpec): + + def process_fn(self, batch: List[ModelInput]) -> List[Dict]: + if batch[0].joints is None: # predict + return [{} for _ in range(len(batch))] + max_length = 0 + for b in batch: + max_length = max(max_length, b.tokens.shape[0]) + res = [{ + 'input_ids': np.pad(b.tokens, ((0, max_length-b.tokens.shape[0])), 'constant', constant_values=b.pad), + 'attention_mask': np.pad(torch.ones(b.tokens.shape[0]), ((0, max_length - b.tokens.shape[0])), 'constant', constant_values=0.), + } for b in batch] + return res + + def __init__(self, llm, mesh_encoder, **kwargs): + super().__init__() + self.tokenizer: TokenizerSpec = kwargs.get('tokenizer') + self.vocab_size = self.tokenizer.vocab_size + + _d = llm.copy() + _d['vocab_size'] = self.tokenizer.vocab_size + llm_config = AutoConfig.from_pretrained(**_d) + # Force float32 precision for the model + llm_config.torch_dtype = torch.float32 + # Force enable pre_norm + llm_config.pre_norm = True + self.transformer = AutoModelForCausalLM.from_config(config=llm_config) + + self.hidden_size = llm.hidden_size + + self.mesh_encoder = get_mesh_encoder(**mesh_encoder) + + if ( + isinstance(self.mesh_encoder, MAP_MESH_ENCODER.michelangelo) or + isinstance(self.mesh_encoder, MAP_MESH_ENCODER.michelangelo_encoder) + ): + self.output_proj = nn.Linear(self.mesh_encoder.width, self.hidden_size) + else: + raise NotImplementedError() + + def encode_mesh_cond(self, vertices: FloatTensor, normals: FloatTensor) -> FloatTensor: + assert not torch.isnan(vertices).any() + assert not torch.isnan(normals).any() + if ( + isinstance(self.mesh_encoder, MAP_MESH_ENCODER.michelangelo) or + isinstance(self.mesh_encoder, MAP_MESH_ENCODER.michelangelo_encoder) + ): + if (len(vertices.shape) == 3): + shape_embed, latents, token_num, pre_pc = self.mesh_encoder.encode_latents(pc=vertices, feats=normals) + else: + shape_embed, latents, token_num, pre_pc = self.mesh_encoder.encode_latents(pc=vertices.unsqueeze(0), feats=normals.unsqueeze(0)) + latents = self.output_proj(latents) + return latents + else: + raise NotImplementedError() + + def training_step(self, batch: Dict) -> Dict[str, FloatTensor]: + cond = self.encode_mesh_cond(vertices=batch['vertices'], normals=batch['normals']).to(dtype=self.transformer.dtype) + B = cond.shape[0] + input_ids: LongTensor = batch['input_ids'] + inputs_embeds = self.transformer.get_input_embeddings()(input_ids).to(dtype=self.transformer.dtype) + + inputs_embeds = torch.concat([cond, inputs_embeds], dim=1) + + attention_mask = batch['attention_mask'] + # add attention to condition + attention_mask = pad(attention_mask, (cond.shape[1], 0, 0, 0), value=1.) 
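In `training_step`, the mesh-condition latents are concatenated in front of the token embeddings, so the attention mask must gain a block of ones on the left of the sequence dimension. The `(cond.shape[1], 0, 0, 0)` pad spec does exactly that, since `torch.nn.functional.pad` consumes the spec last dimension first. A tiny standalone check with invented sizes:

```python
import torch
from torch.nn.functional import pad

cond_len = 3                                    # number of mesh-condition tokens (made up)
attention_mask = torch.tensor([[1., 1., 1., 0., 0.],
                               [1., 1., 1., 1., 1.]])   # (B=2, L=5)

# (left, right) for the sequence dim, then (top, bottom) for the batch dim
extended = pad(attention_mask, (cond_len, 0, 0, 0), value=1.)
print(extended.shape)    # torch.Size([2, 8])
print(extended[0])       # tensor([1., 1., 1., 1., 1., 1., 0., 0.])
```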
+ output = self.transformer( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + ) + + # (B, L, vocab_size) + logit = output.logits[:, cond.shape[1]:].reshape(B, -1, self.vocab_size) + # compute loss with shift one-token right + device = logit.device # (B, n, num_discrete) + logit = logit[:, :-1] # (B, n) + num_discrete = self.tokenizer.num_discrete + s = torch.nn.functional.softmax(logit, dim=-1) + + label = input_ids[:, 1:].clone() # (B, n) + mask = label < num_discrete + dis = torch.arange(num_discrete, device=device).view(1, 1, -1) # (B, n, num_discrete) + dis = (dis - label.unsqueeze(2).repeat(1, 1, num_discrete)).type(torch.float32) / num_discrete + dis_loss = (s[:, :, :num_discrete] * torch.abs(dis))[mask].sum() / 50 # ignore padding loss + + label[attention_mask[:, cond.shape[1] + 1:]==0] = -100 + + assert not torch.isnan(logit).any(), logit + ce_loss = nn.functional.cross_entropy(logit.permute(0, 2, 1), label) + return { + 'ce_loss': ce_loss, + 'dis_loss': dis_loss, + } + + def forward(self, data: Dict): + return self.training_step(data=data) + + @torch.no_grad() + def generate( + self, + vertices: FloatTensor, + normals: FloatTensor, + cls: Union[str, None]=None, + **kwargs, + ) -> DetokenzeOutput: + ''' + Do not support batch! + ''' + cond = self.encode_mesh_cond(vertices=vertices, normals=normals).to(dtype=self.transformer.dtype) + + start_tokens = [self.tokenizer.bos] + + if cls is not None: + start_tokens.append(self.tokenizer.cls_name_to_token(cls=cls)) + start_embed = self.transformer.get_input_embeddings()( + torch.tensor(start_tokens, dtype=torch.long, device=cond.device).unsqueeze(0) + ).to(dtype=self.transformer.dtype) + cond = torch.cat([cond, start_embed], dim=1) + + results = self.transformer.generate( + inputs_embeds=cond, + bos_token_id=self.tokenizer.bos, + eos_token_id=self.tokenizer.eos, + pad_token_id=self.tokenizer.pad, + **kwargs, + ) + output_ids = results[0, :] + for token in reversed(start_tokens): + output_ids = pad(output_ids, (1, 0), value=token) + output_ids = output_ids.detach().cpu().numpy() + + res = self.tokenizer.detokenize(ids=output_ids) + return res + + def predict_step(self, batch: Dict, no_cls: bool=False): + vertices: FloatTensor = batch['vertices'] + normals : FloatTensor = batch['normals'] + paths : List[str] = batch['path'] + cls = batch['cls'] + generate_kwargs = deepcopy(batch['generate_kwargs']) + + no_cls = generate_kwargs.get('no_cls', False) + use_dir_cls = generate_kwargs.get('use_dir_cls', False) + assign_cls = generate_kwargs.get('assign_cls', None) + + generate_kwargs.pop('no_cls', None) + generate_kwargs.pop('use_dir_cls', None) + generate_kwargs.pop('assign_cls', None) + + if vertices.dim() == 2: + vertices = vertices.unsqueeze(0) + normals = normals.unsqueeze(0) + outputs = [] + for i in range(vertices.shape[0]): + if no_cls: + _cls = None + elif assign_cls is not None: + _cls = assign_cls + elif use_dir_cls: + _cls = paths[i].removeprefix('./').split('/')[0] + else: + _cls = cls[i] + res = self.generate(vertices=vertices[i], normals=normals[i], cls=_cls, **generate_kwargs) + outputs.append(res) + return outputs \ No newline at end of file diff --git a/UniRig/src/model/unirig_skin.py b/UniRig/src/model/unirig_skin.py new file mode 100644 index 0000000000000000000000000000000000000000..2e4212e1974c6036255bb638995bf2e5bebf676c --- /dev/null +++ b/UniRig/src/model/unirig_skin.py @@ -0,0 +1,440 @@ +import torch +from torch import nn, FloatTensor, LongTensor, Tensor +import torch.nn.functional as F +import numpy as 
np +from torch.nn.functional import pad +from typing import Dict, List +from transformers import AutoModelForCausalLM, AutoConfig +import math +import torch_scatter +from flash_attn.modules.mha import MHA + +from .spec import ModelSpec, ModelInput +from .parse_encoder import MAP_MESH_ENCODER, get_mesh_encoder + +from ..data.utils import linear_blend_skinning + +class FrequencyPositionalEmbedding(nn.Module): + """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts + each feature dimension of `x[..., i]` into: + [ + sin(x[..., i]), + sin(f_1*x[..., i]), + sin(f_2*x[..., i]), + ... + sin(f_N * x[..., i]), + cos(x[..., i]), + cos(f_1*x[..., i]), + cos(f_2*x[..., i]), + ... + cos(f_N * x[..., i]), + x[..., i] # only present if include_input is True. + ], here f_i is the frequency. + + Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs]. + If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...]; + Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]. + + Args: + num_freqs (int): the number of frequencies, default is 6; + logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]; + input_dim (int): the input dimension, default is 3; + include_input (bool): include the input tensor or not, default is True. + + Attributes: + frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1); + + out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1), + otherwise, it is input_dim * num_freqs * 2. + + """ + + def __init__( + self, + num_freqs: int = 6, + logspace: bool = True, + input_dim: int = 3, + include_input: bool = True, + include_pi: bool = True, + ) -> None: + """The initialization""" + + super().__init__() + + if logspace: + frequencies = 2.0 ** torch.arange(num_freqs, dtype=torch.float32) + else: + frequencies = torch.linspace( + 1.0, 2.0 ** (num_freqs - 1), num_freqs, dtype=torch.float32 + ) + + if include_pi: + frequencies *= torch.pi + + self.register_buffer("frequencies", frequencies, persistent=False) + self.include_input = include_input + self.num_freqs = num_freqs + + self.out_dim = self._get_dims(input_dim) + + def _get_dims(self, input_dim): + temp = 1 if self.include_input or self.num_freqs == 0 else 0 + out_dim = input_dim * (self.num_freqs * 2 + temp) + + return out_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward process. + + Args: + x: tensor of shape [..., dim] + + Returns: + embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)] + where temp is 1 if include_input is True and 0 otherwise. 
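Following the docstring, with `include_input=True` the embedding width is `input_dim * (num_freqs * 2 + 1)`; for the default `num_freqs=6` and 3-D inputs that is 39 channels. A quick hedged shape check, assuming the class above is in scope:

```python
import torch
# assumes FrequencyPositionalEmbedding from this module

embed = FrequencyPositionalEmbedding(num_freqs=6, input_dim=3, include_input=True)
x = torch.rand(4, 1024, 3)            # e.g. a batch of per-bone coordinates
y = embed(x)

print(embed.out_dim)                  # 3 * (6 * 2 + 1) = 39
print(y.shape)                        # torch.Size([4, 1024, 39])
```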
+ """ + + if self.num_freqs > 0: + embed = (x[..., None].contiguous() * self.frequencies.to(device=x.device)).view( + *x.shape[:-1], -1 + ) + if self.include_input: + return torch.cat((x, embed.sin(), embed.cos()), dim=-1) + else: + return torch.cat((embed.sin(), embed.cos()), dim=-1) + else: + return x + +class ResidualCrossAttn(nn.Module): + def __init__(self, feat_dim: int, num_heads: int): + super().__init__() + assert feat_dim % num_heads == 0, "feat_dim must be divisible by num_heads" + + self.norm1 = nn.LayerNorm(feat_dim) + self.norm2 = nn.LayerNorm(feat_dim) + # self.attention = nn.MultiheadAttention(embed_dim=feat_dim, num_heads=num_heads, batch_first=True) + self.attention = MHA(embed_dim=feat_dim, num_heads=num_heads, cross_attn=True) + self.ffn = nn.Sequential( + nn.Linear(feat_dim, feat_dim * 4), + nn.GELU(), + nn.Linear(feat_dim * 4, feat_dim), + ) + + def forward(self, q, kv): + residual = q + attn_output = self.attention(q, x_kv=kv) + x = self.norm1(residual + attn_output) + x = self.norm2(x + self.ffn(x)) + return x + +class BoneEncoder(nn.Module): + def __init__( + self, + feat_bone_dim: int, + feat_dim: int, + embed_dim: int, + num_heads: int, + num_attn: int, + ): + super().__init__() + self.feat_bone_dim = feat_bone_dim + self.feat_dim = feat_dim + self.num_heads = num_heads + self.num_attn = num_attn + + self.position_embed = FrequencyPositionalEmbedding(input_dim=self.feat_bone_dim) + + self.bone_encoder = nn.Sequential( + self.position_embed, + nn.Linear(self.position_embed.out_dim, embed_dim), + nn.LayerNorm(embed_dim), + nn.GELU(), + nn.Linear(embed_dim, embed_dim * 4), + nn.LayerNorm(embed_dim * 4), + nn.GELU(), + nn.Linear(embed_dim * 4, feat_dim), + nn.LayerNorm(feat_dim), + nn.GELU(), + ) + self.attn = nn.ModuleList() + for _ in range(self.num_attn): + self.attn.append(ResidualCrossAttn(feat_dim, self.num_heads)) + + def forward( + self, + base_bone: FloatTensor, + num_bones: LongTensor, + parents: LongTensor, + min_coord: FloatTensor, + global_latents: FloatTensor, + ): + # base_bone: (B, J, C) + B = base_bone.shape[0] + J = base_bone.shape[1] + x = self.bone_encoder((base_bone-min_coord[:, None, :]).reshape(-1, base_bone.shape[-1])).reshape(B, J, -1) + + latents = torch.cat([x, global_latents], dim=1) + + for (i, attn) in enumerate(self.attn): + x = attn(x, latents) + return x + +class SkinweightPred(nn.Module): + def __init__(self, in_dim, mlp_dim): + super().__init__() + self.net = nn.Sequential( + nn.Linear(in_dim, mlp_dim), + nn.LayerNorm(mlp_dim), + nn.GELU(), + nn.Linear(mlp_dim, mlp_dim), + nn.LayerNorm(mlp_dim), + nn.GELU(), + nn.Linear(mlp_dim, mlp_dim), + nn.LayerNorm(mlp_dim), + nn.GELU(), + nn.Linear(mlp_dim, mlp_dim), + nn.LayerNorm(mlp_dim), + nn.GELU(), + nn.Linear(mlp_dim, 1), + ) + + def forward(self, x): + return self.net(x) + +class UniRigSkin(ModelSpec): + + def process_fn(self, batch: List[ModelInput]) -> List[Dict]: + max_bones = 0 + for b in batch: + max_bones = max(max_bones, b.asset.J) + res = [] + current_offset = 0 + for b in batch: + vertex_groups = b.asset.sampled_vertex_groups + current_offset += b.vertices.shape[0] + # (N, J) + voxel_skin = vertex_groups['voxel_skin'] + + voxel_skin = np.pad(voxel_skin, ((0, 0), (0, max_bones-b.asset.J)), 'constant', constant_values=0.0) + + # (J, 4, 4) + res.append({ + 'voxel_skin': voxel_skin, + 'offset': current_offset, + }) + return res + + def __init__(self, mesh_encoder, global_encoder, **kwargs): + super().__init__() + + self.num_train_vertex = kwargs['num_train_vertex'] + self.feat_dim 
= kwargs['feat_dim'] + self.num_heads = kwargs['num_heads'] + self.grid_size = kwargs['grid_size'] + self.mlp_dim = kwargs['mlp_dim'] + self.num_bone_attn = kwargs['num_bone_attn'] + self.num_mesh_bone_attn = kwargs['num_mesh_bone_attn'] + self.bone_embed_dim = kwargs['bone_embed_dim'] + self.voxel_mask = kwargs.get('voxel_mask', 2) + + self.mesh_encoder = get_mesh_encoder(**mesh_encoder) + self.global_encoder = get_mesh_encoder(**global_encoder) + if isinstance(self.mesh_encoder, MAP_MESH_ENCODER.ptv3obj): + self.feat_map = nn.Sequential( + nn.Linear(mesh_encoder['enc_channels'][-1], self.feat_dim), + nn.LayerNorm(self.feat_dim), + nn.GELU(), + ) + else: + raise NotImplementedError() + if isinstance(self.global_encoder, MAP_MESH_ENCODER.michelangelo_encoder): + self.out_proj = nn.Sequential( + nn.Linear(self.global_encoder.width, self.feat_dim), + nn.LayerNorm(self.feat_dim), + nn.GELU(), + ) + else: + raise NotImplementedError() + + self.bone_encoder = BoneEncoder( + feat_bone_dim=3, + feat_dim=self.feat_dim, + embed_dim=self.bone_embed_dim, + num_heads=self.num_heads, + num_attn=self.num_bone_attn, + ) + + self.downscale = nn.Sequential( + nn.Linear(2 * self.num_heads, self.num_heads), + nn.LayerNorm(self.num_heads), + nn.GELU(), + ) + self.skinweight_pred = SkinweightPred( + self.num_heads, + self.mlp_dim, + ) + + self.mesh_bone_attn = nn.ModuleList() + self.mesh_bone_attn.extend([ + ResidualCrossAttn(self.feat_dim, self.num_heads) for _ in range(self.num_mesh_bone_attn) + ]) + + self.qmesh = nn.Linear(self.feat_dim, self.feat_dim * self.num_heads) + self.kmesh = nn.Linear(self.feat_dim, self.feat_dim * self.num_heads) + + self.voxel_skin_embed = nn.Linear(1, self.num_heads) + self.voxel_skin_norm = nn.LayerNorm(self.num_heads) + self.attn_skin_norm = nn.LayerNorm(self.num_heads) + + def encode_mesh_cond(self, vertices: FloatTensor, normals: FloatTensor) -> FloatTensor: + assert not torch.isnan(vertices).any() + assert not torch.isnan(normals).any() + if isinstance(self.global_encoder, MAP_MESH_ENCODER.michelangelo_encoder): + if (len(vertices.shape) == 3): + shape_embed, latents, token_num, pre_pc = self.global_encoder.encode_latents(pc=vertices, feats=normals) + else: + shape_embed, latents, token_num, pre_pc = self.global_encoder.encode_latents(pc=vertices.unsqueeze(0), feats=normals.unsqueeze(0)) + latents = self.out_proj(latents) + return latents + else: + raise NotImplementedError() + + def _get_predict(self, batch: Dict) -> FloatTensor: + ''' + Return predicted skin. 
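+        Returns:
+            skin_pred: (B, N', J) predicted skinning weights, where N' is
+                num_train_vertex (a random vertex subset) during training and the
+                full vertex count at inference, processed chunk by chunk.
+            indices: (N',) the vertex indices the predictions correspond to.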
+ ''' + + num_bones: Tensor = batch['num_bones'] + vertices: FloatTensor = batch['vertices'] # (B, N, 3) + normals: FloatTensor = batch['normals'] + joints: FloatTensor = batch['joints'] + tails: FloatTensor = batch['tails'] + voxel_skin: FloatTensor = batch['voxel_skin'] + parents: LongTensor = batch['parents'] + + # turn inputs' dtype into model's dtype + dtype = next(self.parameters()).dtype + vertices = vertices.type(dtype) + normals = normals.type(dtype) + joints = joints.type(dtype) + tails = tails.type(dtype) + voxel_skin = voxel_skin.type(dtype) + + B = vertices.shape[0] + N = vertices.shape[1] + J = joints.shape[1] + + assert vertices.dim() == 3 + assert normals.dim() == 3 + + part_offset = torch.tensor([(i+1)*N for i in range(B)], dtype=torch.int64, device=vertices.device) + idx_ptr = torch.nn.functional.pad(part_offset, (1, 0), value=0) + min_coord = torch_scatter.segment_csr(vertices.reshape(-1, 3), idx_ptr, reduce="min") + + pack = [] + if self.training: + train_indices = torch.randperm(N)[:self.num_train_vertex] + pack.append(train_indices) + else: + for i in range((N + self.num_train_vertex - 1) // self.num_train_vertex): + pack.append(torch.arange(i*self.num_train_vertex, min((i+1)*self.num_train_vertex, N))) + + # (B, seq_len, feat_dim) + global_latents = self.encode_mesh_cond(vertices, normals) + bone_feat = self.bone_encoder( + base_bone=joints, + num_bones=num_bones, + parents=parents, + min_coord=min_coord, + global_latents=global_latents, + ) + + if isinstance(self.mesh_encoder, MAP_MESH_ENCODER.ptv3obj): + feat = torch.cat([vertices, normals, torch.zeros_like(vertices)], dim=-1) + ptv3_input = { + 'coord': vertices.reshape(-1, 3), + 'feat': feat.reshape(-1, 9), + 'offset': torch.tensor(batch['offset']), + 'grid_size': self.grid_size, + } + if not self.training: + # must cast to float32 to avoid sparse-conv precision bugs + with torch.autocast(device_type='cuda', dtype=torch.float32): + mesh_feat = self.mesh_encoder(ptv3_input).feat + mesh_feat = self.feat_map(mesh_feat).view(B, N, self.feat_dim) + else: + mesh_feat = self.mesh_encoder(ptv3_input).feat + mesh_feat = self.feat_map(mesh_feat).view(B, N, self.feat_dim) + mesh_feat = mesh_feat.type(dtype) + else: + raise NotImplementedError() + + # (B, J + seq_len, feat_dim) + latents = torch.cat([bone_feat, global_latents], dim=1) + # (B, N, feat_dim) + for block in self.mesh_bone_attn: + mesh_feat = block( + q=mesh_feat, + kv=latents, + ) + + # trans to (B, num_heads, J, feat_dim) + bone_feat = self.kmesh(bone_feat).view(B, J, self.num_heads, self.feat_dim).transpose(1, 2) + + skin_pred_list = [] + if not self.training: + skin_mask = voxel_skin.clone() + for b in range(B): + num = num_bones[b] + for i in range(num): + p = parents[b, i] + if p < 0: + continue + skin_mask[b, :, p] += skin_mask[b, :, i] + for indices in pack: + cur_N = len(indices) + # trans to (B, num_heads, N, feat_dim) + cur_mesh_feat = self.qmesh(mesh_feat[:, indices]).view(B, cur_N, self.num_heads, self.feat_dim).transpose(1, 2) + + # attn_weight shape : (B, num_heads, N, J) + attn_weight = F.softmax(torch.bmm( + cur_mesh_feat.reshape(B * self.num_heads, cur_N, -1), + bone_feat.transpose(-2, -1).reshape(B * self.num_heads, -1, J) + ) / math.sqrt(self.feat_dim), dim=-1, dtype=dtype) + # (B, num_heads, N, J) -> (B, N, J, num_heads) + attn_weight = attn_weight.reshape(B, self.num_heads, cur_N, J).permute(0, 2, 3, 1) + attn_weight = self.attn_skin_norm(attn_weight) + + embed_voxel_skin = self.voxel_skin_embed(voxel_skin[:, indices].reshape(B, cur_N, J, 
1)) + embed_voxel_skin = self.voxel_skin_norm(embed_voxel_skin) + + attn_weight = torch.cat([attn_weight, embed_voxel_skin], dim=-1) + attn_weight = self.downscale(attn_weight) + + # (B, N, J, num_heads * (1+c)) -> (B, N, J) + skin_pred = torch.zeros(B, cur_N, J).to(attn_weight.device, dtype) + for i in range(B): + # (N*J, C) + input_features = attn_weight[i, :, :num_bones[i], :].reshape(-1, attn_weight.shape[-1]) + + pred = self.skinweight_pred(input_features).reshape(cur_N, num_bones[i]) + skin_pred[i, :, :num_bones[i]] = F.softmax(pred) + skin_pred_list.append(skin_pred) + skin_pred_list = torch.cat(skin_pred_list, dim=1) + for i in range(B): + n = num_bones[i] + skin_pred_list[i, :, :n] = skin_pred_list[i, :, :n] * torch.pow(skin_mask[i, :, :n], self.voxel_mask) + skin_pred_list[i, :, :n] = skin_pred_list[i, :, :n] / skin_pred_list[i, :, :n].sum(dim=-1, keepdim=True) + return skin_pred_list, torch.cat(pack, dim=0) + + def predict_step(self, batch: Dict): + with torch.no_grad(): + num_bones: Tensor = batch['num_bones'] + + skin_pred, _ = self._get_predict(batch=batch) + outputs = [] + for i in range(skin_pred.shape[0]): + outputs.append(skin_pred[i, :, :num_bones[i]]) + return outputs \ No newline at end of file diff --git a/UniRig/src/system/__init__.py b/UniRig/src/system/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/UniRig/src/system/ar.py b/UniRig/src/system/ar.py new file mode 100644 index 0000000000000000000000000000000000000000..a07d506034b8e0a6251eb0becd865cfb3f2f163a --- /dev/null +++ b/UniRig/src/system/ar.py @@ -0,0 +1,165 @@ +from collections import defaultdict +import lightning as L +import os +import torch +import numpy as np +from torch import Tensor +from typing import Dict, Union, List +from lightning.pytorch.callbacks import BasePredictionWriter + +from numpy import ndarray + +from ..data.raw_data import RawData +from ..data.order import OrderConfig, get_order +from ..model.spec import ModelSpec +from ..tokenizer.spec import DetokenzeOutput + +class ARSystem(L.LightningModule): + + def __init__( + self, + steps_per_epoch: int, + model: ModelSpec, + generate_kwargs: Dict={}, + output_path: Union[str, None]=None, + record_res: Union[bool]=False, + validate_cast: str='bfloat16', + val_interval: Union[int, None]=None, + val_start_from: Union[int, None]=None, + ): + super().__init__() + self.save_hyperparameters(ignore="model") + self.steps_per_epoch = steps_per_epoch + self.model = model + self.generate_kwargs = generate_kwargs + self.output_path = output_path + self.record_res = record_res + self.validate_cast = validate_cast + self.val_interval = val_interval + self.val_start_from = val_start_from + + if self.record_res: + assert self.output_path is not None, "record_res is True, but output_path in ar is None" + + def _predict_step(self, batch, batch_idx, dataloader_idx=None): + batch['generate_kwargs'] = self.generate_kwargs + res = self.model.predict_step(batch) + assert isinstance(res, list), f"expect type of prediction from {self.model.__class__} to be a list, found: {type(res)}" + return res + + def predict_step(self, batch, batch_idx, dataloader_idx=None): + try: + prediction: List[DetokenzeOutput] = self._predict_step(batch=batch, batch_idx=batch_idx, dataloader_idx=dataloader_idx) + return prediction + except Exception as e: + print(str(e)) + return [] + +class ARWriter(BasePredictionWriter): + def __init__( + self, + output_dir: Union[str, None], + order_config: 
Union[OrderConfig, None]=None, + **kwargs + ): + super().__init__('batch') + self.output_dir = output_dir + self.npz_dir = kwargs.get('npz_dir', None) + self.user_mode = kwargs.get('user_mode', False) + self.output_name = kwargs.get('output_name', None) # for a single name + self.repeat = kwargs.get('repeat', 1) + self.add_num = kwargs.get('add_num', False) + self.export_npz = kwargs.get('export_npz', None) + self.export_obj = kwargs.get('export_obj', None) + self.export_fbx = kwargs.get('export_fbx', None) + self.export_pc = kwargs.get('export_pc', None) + if order_config is not None: + self.order = get_order(config=order_config) + else: + self.order = None + + self._epoch = 0 + + def on_predict_end(self, trainer, pl_module): + if self._epoch < self.repeat - 1: + print(f"Finished prediction run {self._epoch + 1}/{self.repeat}, starting next run...") + self._epoch += 1 + trainer.predict_dataloader = trainer.datamodule.predict_dataloader() + trainer.predict_loop.run() + + def write_on_batch_end(self, trainer, pl_module: ARSystem, prediction: List[Dict], batch_indices, batch, batch_idx, dataloader_idx): + assert 'path' in batch + paths = batch['path'] + detokenize_output_list: List[DetokenzeOutput] = prediction + vertices = batch['vertices'] + + origin_vertices = batch['origin_vertices'] + origin_vertex_normals = batch['origin_vertex_normals'] + origin_faces = batch['origin_faces'] + origin_face_normals = batch['origin_face_normals'] + num_points = batch['num_points'] + num_faces = batch['num_faces'] + + if isinstance(origin_vertices, torch.Tensor): + origin_vertices = origin_vertices.detach().cpu().numpy() + if isinstance(origin_vertex_normals, torch.Tensor): + origin_vertex_normals = origin_vertex_normals.detach().cpu().numpy() + if isinstance(origin_faces, torch.Tensor): + origin_faces = origin_faces.detach().cpu().numpy() + if isinstance(origin_face_normals, torch.Tensor): + origin_face_normals = origin_face_normals.detach().cpu().numpy() + if isinstance(num_points, torch.Tensor): + num_points = num_points.detach().cpu().numpy() + if isinstance(num_faces, torch.Tensor): + num_faces = num_faces.detach().cpu().numpy() + + for (id, detokenize_output) in enumerate(detokenize_output_list): + assert isinstance(detokenize_output, DetokenzeOutput), f"expect item of the list to be DetokenzeOutput, found: {type(detokenize_output)}" + def make_path(save_name: str, suffix: str, trim: bool=False): + if trim: + path = os.path.relpath(paths[id], self.npz_dir) + else: + path = paths[id] + + if self.output_dir is not None: + path = os.path.join(self.output_dir, path) + + if self.add_num: + path = os.path.join(path, f"{save_name}_{self._epoch}.{suffix}") + else: + path = os.path.join(path, f"{save_name}.{suffix}") + return path + + num_p = num_points[id] + num_f = num_faces[id] + + raw_data = RawData( + vertices=origin_vertices[id, :num_p], + vertex_normals=origin_vertex_normals[id, :num_p], + faces=origin_faces[id, :num_f], + face_normals=origin_face_normals[id, :num_f], + joints=detokenize_output.joints, + tails=detokenize_output.tails, + parents=detokenize_output.parents, + skin=None, + no_skin=detokenize_output.no_skin, + names=detokenize_output.names, + matrix_local=None, + path=None, + cls=detokenize_output.cls, + ) + if not self.user_mode and self.export_npz is not None: + print(make_path(self.export_npz, 'npz')) + raw_data.save(path=make_path(self.export_npz, 'npz')) + if not self.user_mode and self.export_obj is not None: + raw_data.export_skeleton(path=make_path(self.export_obj, 'obj')) + if not 
self.user_mode and self.export_pc is not None: + raw_data.export_pc(path=make_path(self.export_pc, 'obj')) + if self.export_fbx is not None: + if not self.user_mode: + raw_data.export_fbx(path=make_path(self.export_fbx, 'fbx')) + else: + if self.output_name is not None: + raw_data.export_fbx(path=self.output_name) + else: + raw_data.export_fbx(path=make_path(self.export_fbx, 'fbx', trim=True)) \ No newline at end of file diff --git a/UniRig/src/system/parse.py b/UniRig/src/system/parse.py new file mode 100644 index 0000000000000000000000000000000000000000..114cc2d16b60420ad63ce0a36fcb2ec3f574a51f --- /dev/null +++ b/UniRig/src/system/parse.py @@ -0,0 +1,27 @@ +import torch +from torch.optim import Optimizer +from lightning.pytorch import LightningModule +from lightning.pytorch.callbacks import BasePredictionWriter + +from .ar import ARSystem, ARWriter +from .skin import SkinSystem, SkinWriter + +def get_system(**kwargs) -> LightningModule: + MAP = { + 'ar': ARSystem, + 'skin': SkinSystem, + } + __target__ = kwargs['__target__'] + assert __target__ in MAP, f"expect: [{','.join(MAP.keys())}], found: {__target__}" + del kwargs['__target__'] + return MAP[__target__](**kwargs) + +def get_writer(**kwargs) -> BasePredictionWriter: + MAP = { + 'ar': ARWriter, + 'skin': SkinWriter, + } + __target__ = kwargs['__target__'] + assert __target__ in MAP, f"expect: [{','.join(MAP.keys())}], found: {__target__}" + del kwargs['__target__'] + return MAP[__target__](**kwargs) \ No newline at end of file diff --git a/UniRig/src/system/skin.py b/UniRig/src/system/skin.py new file mode 100644 index 0000000000000000000000000000000000000000..4548c3ea69a225ce6e1e5594d63997225aa6482a --- /dev/null +++ b/UniRig/src/system/skin.py @@ -0,0 +1,262 @@ +from collections import defaultdict + +import torch.distributed +import lightning as L +import os +import torch +import numpy as np +from torch import Tensor, FloatTensor, LongTensor +from typing import Dict, Union, List, Literal +from lightning.pytorch.callbacks import BasePredictionWriter + +from numpy import ndarray +from scipy.sparse import csr_matrix +from scipy.spatial import cKDTree + +from ..data.order import OrderConfig, get_order +from ..data.raw_data import RawSkin, RawData +from ..data.exporter import Exporter +from ..model.spec import ModelSpec + +class SkinSystem(L.LightningModule): + + def __init__( + self, + steps_per_epoch: int, + model: ModelSpec, + output_path: Union[str, None]=None, + record_res: Union[bool]=False, + val_interval: Union[int, None]=None, + val_start_from: Union[int, None]=None, + ): + super().__init__() + self.save_hyperparameters(ignore="model") + self.steps_per_epoch = steps_per_epoch + self.model = model + self.output_path = output_path + self.record_res = record_res + self.val_interval = val_interval + self.val_start_from = val_start_from + + if self.record_res: + assert self.output_path is not None, "record_res is True, but output_path in skin is None" + + def predict_step(self, batch, batch_idx, dataloader_idx=None): + res = self.model.predict_step(batch) + + if isinstance(res, list): + return { + 'skin_pred': res, + } + elif isinstance(res, dict): + assert 'skin_pred' in res, f"expect key 'skin_pred' in prediction from {self.model.__class__}, found: {res.keys()}" + return res + else: + assert 0, f"expect type of prediction from {self.model.__class__} to be a list or dict, found: {type(res)}" + +class SkinWriter(BasePredictionWriter): + def __init__( + self, + output_dir: Union[str, None], + save_name: str, + order_config: 
Union[OrderConfig, None]=None, + **kwargs + ): + super().__init__('batch') + self.output_dir = output_dir + self.npz_dir = kwargs.get('npz_dir', None) + self.user_mode = kwargs.get('user_mode', False) + self.output_name = kwargs.get('output_name', None) # for a single name + self.save_name = save_name + self.add_num = kwargs.get('add_num', False) + self.export_npz = kwargs.get('export_npz', True) + self.export_fbx = kwargs.get('export_fbx', False) + if order_config is not None: + self.order = get_order(config=order_config) + else: + self.order = None + + self._epoch = 0 + + def write_on_batch_end(self, trainer, pl_module: SkinSystem, prediction: List[Dict], batch_indices, batch, batch_idx, dataloader_idx): + assert 'path' in batch + paths: List[str] = batch['path'] + data_names: List[str] = batch['data_name'] + joints: FloatTensor = batch['joints'] + num_bones: LongTensor = batch['num_bones'] + num_faces: LongTensor = batch['num_faces'] + num_points: LongTensor = batch['num_points'] + tails: FloatTensor = batch['tails'] + parents_list: LongTensor = batch['parents'] # -1 represents root + vertices: FloatTensor = batch['origin_vertices'] + sampled_vertices: FloatTensor = batch['vertices'] + faces: LongTensor = batch['origin_faces'] + + joints = joints.detach().cpu().numpy() + tails = tails.detach().cpu().numpy() + parents_list = parents_list.detach().cpu().numpy() + num_bones = num_bones.detach().cpu().numpy() + num_faces = num_faces.detach().cpu().numpy() + vertices = vertices.detach().cpu().numpy() + faces = faces.detach().cpu().numpy() + + skin_pred_list: List = prediction['skin_pred'] + ret_sampled_vertices = prediction.get('sampled_vertices', None) + if ret_sampled_vertices is not None: + assert isinstance(ret_sampled_vertices, Tensor) + sampled_vertices = ret_sampled_vertices + if isinstance(sampled_vertices, Tensor): + sampled_vertices = sampled_vertices.type(torch.float32).detach().cpu().numpy() + for (id, skin_pred) in enumerate(skin_pred_list): + if isinstance(skin_pred, Tensor): + skin_pred = skin_pred.type(torch.float32).detach().cpu().numpy() + + # TODO: add custom post-processing here + + # resample + N = num_points[id] + J = num_bones[id] + F = num_faces[id] + o_vertices = vertices[id, :N] + + _parents = parents_list[id] + parents = [] + for i in range(J): + if _parents[i] == -1: + parents.append(None) + else: + parents.append(_parents[i]) + + skin_resampled = reskin( + sampled_vertices=sampled_vertices[id], + vertices=o_vertices, + parents=parents, + faces=faces[id, :F], + sampled_skin=skin_pred, + sample_method='median', + alpha=2.0, + threshold=0.03, + ) + + def make_path(save_name: str, suffix: str, trim: bool=False): + if trim: + path = os.path.relpath(paths[id], self.npz_dir) + else: + path = paths[id] + + if self.output_dir is not None: + path = os.path.join(self.output_dir, path) + + if self.add_num: + path = os.path.join(path, f"{save_name}_{self._epoch}.{suffix}") + else: + path = os.path.join(path, f"{save_name}.{suffix}") + return path + + raw_data = RawSkin(skin=skin_pred, vertices=sampled_vertices[id], joints=joints[id, :J]) + if self.export_npz is not None: + raw_data.save(path=make_path(self.export_npz, 'npz')) + if self.export_fbx is not None: + try: + exporter = Exporter() + names = RawData.load(path=os.path.join(paths[id], data_names[id])).names + if names is None: + names = [f"bone_{i}" for i in range(J)] + if self.user_mode: + if self.output_name is not None: + path = self.output_name + else: + path = make_path(self.save_name, 'fbx', trim=True) + else: + 
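+                        # default (non user) mode: write the FBX under output_dir/<asset path>/<export_fbx>.fbx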
path = make_path(self.export_fbx, 'fbx') + exporter._export_fbx( + path=path, + vertices=o_vertices, + joints=joints[id, :J], + skin=skin_resampled, + parents=parents, + names=names, + faces=faces[id, :F], + group_per_vertex=4, + tails=tails[id, :J], + use_extrude_bone=False, + use_connect_unique_child=False, + # do_not_normalize=True, + ) + except Exception as e: + print(str(e)) + + def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices): + self._epoch += 1 + +def reskin( + sampled_vertices: ndarray, + vertices: ndarray, + parents: List[Union[None, int]], + faces: ndarray, + sampled_skin: ndarray, + sample_method: Literal['mean', 'median']='mean', + **kwargs, +) -> ndarray: + nearest_samples = kwargs.get('nearest_samples', 7) + iter_steps = kwargs.get('iter_steps', 1) + threshold = kwargs.get('threshold', 0.01) + alpha = kwargs.get('alpha', 2) + + assert sample_method in ['mean', 'median'] + + N = vertices.shape[0] + J = sampled_skin.shape[1] + if sample_method == 'mean': + tree = cKDTree(sampled_vertices) + dis, nearest = tree.query(vertices, k=nearest_samples, p=2) + # weighted sum + weights = np.exp(-alpha * dis) # (N, nearest_samples) + weight_sum = weights.sum(axis=1, keepdims=True) + sampled_skin_nearest = sampled_skin[nearest] + skin = (sampled_skin_nearest * weights[..., np.newaxis]).sum(axis=1) / weight_sum + elif sample_method == 'median': + tree = cKDTree(sampled_vertices) + dis, nearest = tree.query(vertices, k=nearest_samples, p=2) + skin = np.median(sampled_skin[nearest], axis=1) + else: + assert 0 + + # (from, to) + edges = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], axis=0) + edges = np.concatenate([edges, edges[:, [1, 0]]], axis=0) # (2*F*3, 2) + + # diffusion in neighbours + for _ in range(iter_steps): + sum_skin = skin.copy() + for i in reversed(range(J)): + p = parents[i] + if p is None: + continue + sum_skin[:, p] += sum_skin[:, i] + # (2*F*3, J) + # only transfer from hotter to cooler + mask = sum_skin[edges[:, 1]] < sum_skin[edges[:, 0]] + neighbor_skin = np.zeros_like(sum_skin) # (N, J) + neighbor_co = np.zeros((N, J), dtype=np.float32) + + dis = np.sqrt(((vertices[edges[:, 1]] - vertices[edges[:, 0]])**2).sum(axis=1, keepdims=True)) + co = np.exp(-dis * alpha) + + neighbor_skin[edges[:, 1]] += sum_skin[edges[:, 0]] * co * mask + neighbor_co[edges[:, 1]] += co * mask + + sum_skin = (sum_skin + neighbor_skin) / (1. 
+ neighbor_co) + for i in range(J): + p = parents[i] + if p is None: + continue + sum_skin[:, p] -= sum_skin[:, i] + skin = sum_skin / sum_skin.sum(axis=-1, keepdims=True) + + # avoid 0-skin + mask = (skin>=threshold).any(axis=-1, keepdims=True) + skin[(skin TokenizerSpec: + MAP = { + 'tokenizer_part': TokenizerPart, + } + assert config.method in MAP, f"expect: [{','.join(MAP.keys())}], found: {config.method}" + return MAP[config.method](config=config) \ No newline at end of file diff --git a/UniRig/src/tokenizer/spec.py b/UniRig/src/tokenizer/spec.py new file mode 100644 index 0000000000000000000000000000000000000000..8665cbf7a8908e7ab4f9c28b91bb528f474fd751 --- /dev/null +++ b/UniRig/src/tokenizer/spec.py @@ -0,0 +1,314 @@ +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Dict + +import numpy as np +from numpy import ndarray + +from typing import Union, List, Tuple +from dataclasses import dataclass + +from ..data.exporter import Exporter +from ..data.order import OrderConfig, Order, get_order + +@dataclass(frozen=True) +class TokenizerConfig(): + # which tokenizer to use + method: str + + # coord discrete + num_discrete: int + + # normalization range + continuous_range: Tuple[float, float] + + # cls token id + cls_token_id: Dict[str, int] + + # parts token id + parts_token_id: Dict[str, int] + + order_config: Union[OrderConfig, None] + + @staticmethod + def parse(config) -> 'TokenizerConfig': + order_config = config.get('order_config', None) + + return TokenizerConfig( + method=config.method, + num_discrete=config.num_discrete, + continuous_range=config.continuous_range,cls_token_id=config.cls_token_id, + parts_token_id=config.get('parts_token_id', {}), + order_config=OrderConfig.parse(order_config) if order_config is not None else None, + ) + +@dataclass(frozen=True) +class TokenizeInput(): + # (J, 6), (parent position, position) + bones: ndarray + + # (J, 3), tails of bones(this is an attribute to indicate direction, not bones[i, 3:6]). Should NOT be used for non-leaf joints. + tails: Union[ndarray, None] + + # (B, J), bool, whether there is a branch, always False for root + branch: ndarray + + # (J), bool, whether the bone is a leaf node (has no child) + is_leaf: ndarray + + # (B, J), bool, whether the bone has skin + no_skin: Union[ndarray, None] + + # string of class in tokenizer + cls: Union[str, None] + + # Part token added before the i-th bone. If parts_bias[i] is None, a spring token will be added. 
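+    # Illustrative example with hypothetical part names: parts_bias = {0: 'body', 25: None}
+    # adds the 'body' part token before bone 0 and a spring token before bone 25.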
+ parts_bias: Dict[int, Union[str, None]] + + @property + def num_bones(self): + return self.bones.shape[0] + +@dataclass(frozen=True) +class DetokenzeOutput(Exporter): + # original tokens + tokens: ndarray + + # (J, 6), (parent position, position) + bones: ndarray + + # (J), parent of each bone + parents: List[Union[int, None]] + + # (J, 3), tails of bones(this is an attribute to indicate direction, not bones[i, 3:6]) + tails: Union[ndarray, None] + + # (B, J), bool, whether the bone has skin + no_skin: Union[ndarray, None] + + # string of class in tokenizer + cls: Union[str, None] + + # part names in order + parts: List[str] + + # names of joints + names: Union[None, List[str]] + + # normalization cube + continuous_range: Tuple[float, float] + + @property + def joints(self): + return self.bones[:, 3:] + + @property + def p_joints(self): + return self.bones[:, :3] + + @property + def num_bones(self): + return self.bones.shape[0] + + def _get_parents(self) -> List[Union[int, None]]: + parents = [] + for (i, bone) in enumerate(self.bones): + p_joint = bone[:3] + dis = 999999 + pid = None + for j in reversed(range(i)): + n_dis = ((self.bones[j][3:] - p_joint)**2).sum() + if n_dis < dis: + pid = j + dis = n_dis + parents.append(pid) + return parents + + def export_skeleton(self, path: str): + parents = self._get_parents() + self._export_skeleton(joints=self.bones[:, 3:], parents=parents, path=path) + + def export_bones(self, path: str): + assert self.tails is not None, 'tails is None, cannot exporrt bones' + self._export_bones(bones=np.concatenate([self.bones[:, 3:], self.tails], axis=-1), path=path) + + def export_skeleton_sequence(self, path: str): + parents = self._get_parents() + self._export_skeleton_sequence(joints=self.bones[:, 3:], parents=parents, path=path) + +class TokenizerSpec(ABC): + """ + Abstract class for tokenizer + """ + + def __init__(self, **kwargs): + super().__init__() + pass + + @abstractmethod + def tokenize(self, input: TokenizeInput) -> ndarray: + pass + + def detokenize(self, ids: ndarray, **kwargs) -> DetokenzeOutput: + raise NotImplementedError("{} has no method 'detokenize'".format(type(self).__name__)) + + @abstractmethod + def get_require_parts(self) -> List[str]: + """All parts token names""" + pass + + @abstractmethod + def cls_name_to_token(self, cls: str) -> int: + """Cls name to token""" + pass + + @abstractmethod + def part_name_to_token(self, part: str) -> int: + """Part name to token""" + pass + + @property + @abstractmethod + def vocab_size(self): + """The vocabulary size""" + pass + + @property + def pad(self): + raise NotImplementedError("{} has no attribute 'pad'".format(type(self).__name__)) + + @property + def bos(self): + raise NotImplementedError("{} has no attribute 'bos'".format(type(self).__name__)) + + @property + def eos(self): + raise NotImplementedError("{} has no attribute 'eos'".format(type(self).__name__)) + + @property + def num_discrete(self): + raise NotImplementedError("{} has no attribute 'num_discrete'".format(type(self).__name__)) + + @property + @abstractmethod + def continuous_range(self) -> Tuple[float, float]: + pass + +def make_skeleton( + joints: ndarray, + p_joints: ndarray, + tails_dict: Dict[int, ndarray], + convert_leaf_bones_to_tails: bool, + extrude_tail_for_leaf: bool, + extrude_tail_for_branch: bool, + extrude_scale: float=0.5, + strict: bool=False, +) -> Tuple[ndarray, ndarray, List[int], List[Union[None, int]]]: + ''' + Args: + joints: heads of bones + + p_joints: parent position of joints + + tails_dict: tail 
position of the i-th joint + + convert_leaf_bones_to_tails: remove leaf bones and make them tails of their parents + + extrude_tail_for_leaf: add a tail for leaf bone + + extrude_tail_for_branch: add a tail for joint with multiple children + + extrude_scale: length scale of tail offset + + strict: if true, raise error when there are joints in the same location + + Returns: + bones, tails, available_bones_id, parents + ''' + assert (convert_leaf_bones_to_tails & extrude_tail_for_leaf)==False, 'cannot extrude tail for leaf when convert_leaf_bones_to_tails is True' + assert joints.shape[0] == p_joints.shape[0] + # build parents + bones = [] # (parent_position, position) + parents = [] + for (i, joint) in enumerate(joints): + if len(bones) == 0: + bones.append(np.concatenate([joint, joint])) # root + parents.append(None) + continue + p_joint = p_joints[i] + dis = 999999 + pid = None + for j in reversed(range(i)): + n_dis = ((bones[j][3:] - p_joint)**2).sum() + if n_dis < dis: + pid = j + dis = n_dis + bones.append(np.concatenate([joints[pid], joint])) + parents.append(pid) + bones = np.stack(bones) + + children = defaultdict(list) + for (i, pid) in enumerate(parents): + if pid is None: + continue + children[pid].append(i) + + available_bones_id = [] + if convert_leaf_bones_to_tails: + for (i, pid) in enumerate(parents): + if len(children[i]) != 0: + available_bones_id.append(i) + continue + tails_dict[pid] = bones[i, 3:] + else: + available_bones_id = [i for i in range(bones.shape[0])] + + # tail for leaf + for (i, pid) in enumerate(parents): + if len(children[i]) != 0: + continue + if extrude_tail_for_leaf: + d = bones[i, 3:] - bones[pid, 3:] + length = np.linalg.norm(d) + if strict: + assert length > 1e-9, 'two joints in the same point found' + elif length <= 1e-9: + d = np.array([0., 0., 1.]) + tails_dict[i] = bones[i, 3:] + d * extrude_scale + else: + tails_dict[i] = bones[i, 3:] + + # tail for branch + for (i, pid) in enumerate(parents): + if len(children[i]) <= 1: + continue + if extrude_tail_for_branch: + if pid is None: # root + av_len = 0 + for child in children[i]: + av_len += np.linalg.norm(bones[i, 3:] - bones[child, 3:]) + av_len /= len(children[i]) + d = bones[i, 3:] + np.array([0., 0., extrude_scale * av_len]) + else: + d = bones[i, 3:] - bones[pid, 3:] + length = np.linalg.norm(d) + if strict: + assert length > 1e-9, 'two joints in the same point found' + elif length <= 1e-9: + d = np.array([0., 0., 1.]) + tails_dict[i] = bones[i, 3:] + d * extrude_scale + else: + tails_dict[i] = bones[i, 3:] + + # assign new tail + for (i, pid) in enumerate(parents): + if len(children[i]) != 1: + continue + child = children[i][0] + tails_dict[i] = bones[child, 3:] + + tails = [] + for i in range(bones.shape[0]): + tails.append(tails_dict[i]) + tails = np.stack(tails) + return bones, tails, available_bones_id, parents \ No newline at end of file diff --git a/UniRig/src/tokenizer/tokenizer_part.py b/UniRig/src/tokenizer/tokenizer_part.py new file mode 100644 index 0000000000000000000000000000000000000000..b1bdd6cd8db6e690355584afd90ba97826042856 --- /dev/null +++ b/UniRig/src/tokenizer/tokenizer_part.py @@ -0,0 +1,241 @@ +import numpy as np +from numpy import ndarray + +from typing import Dict, Tuple, Union, List + +from .spec import TokenizerSpec, TokenizeInput, DetokenzeOutput, TokenizerConfig +from .spec import make_skeleton +from ..data.order import get_order + +class TokenizerPart(TokenizerSpec): + def __init__( + self, + config: TokenizerConfig, + ): + super().__init__() + + 
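+        # Token id layout built below, in increasing order: discrete coordinate
+        # tokens [0, num_discrete), then branch / bos / eos / pad, a spring token,
+        # the part tokens from parts_token_id, a "no class" token, and finally the
+        # class tokens from cls_token_id; vocab_size is the total.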
self._num_discrete = config.num_discrete + self._continuous_range = config.continuous_range + self.cls_token_id = config.cls_token_id.copy() + self.parts_token_id = config.parts_token_id.copy() + self.order = get_order(config.order_config) + _offset = config.num_discrete + + self.token_id_branch = _offset + 0 + self.token_id_bos = _offset + 1 + self.token_id_eos = _offset + 2 + self.token_id_pad = _offset + 3 + _offset += 4 + + self.token_id_spring = _offset + 0 + _offset += 1 + + assert None not in self.parts_token_id + for i in self.parts_token_id: + self.parts_token_id[i] += _offset + _offset += len(self.parts_token_id) + + self.token_id_cls_none = _offset + 0 + _offset += 1 + + for i in self.cls_token_id: + self.cls_token_id[i] += _offset + _offset += len(self.cls_token_id) + + self._vocab_size = _offset + + self.parts_token_id_name = [x for x in self.parts_token_id] + + self.part_token_to_name = {v: k for k, v in self.parts_token_id.items()} + assert len(self.part_token_to_name) == len(self.parts_token_id), 'names with same token found in parts_token_id' + self.part_token_to_name[self.token_id_spring] = None + + self.cls_token_to_name = {v: k for k, v in self.cls_token_id.items()} + assert len(self.cls_token_to_name) == len(self.cls_token_id), 'names with same token found in cls_token_id' + + def cls_name_to_token(self, cls: str) -> int: + if cls not in self.cls_token_id: + return self.token_id_cls_none + return self.cls_token_id[cls] + + def part_name_to_token(self, part: str) -> int: + assert part in self.parts_token_id, f"do not find part name `{part}` in tokenizer" + return self.parts_token_id[part] + + def tokenize(self, input: TokenizeInput) -> ndarray: + num_bones = input.num_bones + bones = discretize(t=input.bones, continuous_range=self.continuous_range, num_discrete=self.num_discrete) + tails = discretize(t=input.tails, continuous_range=self.continuous_range, num_discrete=self.num_discrete) + + branch = input.branch + is_leaf = input.is_leaf + + tokens = [self.token_id_bos] + if input.cls is None or input.cls not in self.cls_token_id: + tokens.append(self.token_id_cls_none) + else: + tokens.append(self.cls_token_id[input.cls]) + use_leaf = False + for i in range(num_bones): + # add parts token id + if i in input.parts_bias: + part = input.parts_bias[i] + if part is None: + tokens.append(self.token_id_spring) + else: + assert part in self.parts_token_id, f"do not find part name {part} in tokenizer {self.__class__}" + tokens.append(self.parts_token_id[part]) + if branch[i]: + tokens.append(self.token_id_branch) + tokens.append(bones[i, 0]) + tokens.append(bones[i, 1]) + tokens.append(bones[i, 2]) + tokens.append(bones[i, 3]) + tokens.append(bones[i, 4]) + tokens.append(bones[i, 5]) + else: + tokens.append(bones[i, 3]) + tokens.append(bones[i, 4]) + tokens.append(bones[i, 5]) + tokens.append(self.token_id_eos) + return np.array(tokens, dtype=np.int64) + + + def detokenize(self, ids: ndarray, **kwargs) -> DetokenzeOutput: + assert isinstance(ids, ndarray), 'expect ids to be ndarray' + if ids[0] != self.token_id_bos: + raise ValueError(f"first token is not bos") + trailing_pad = 0 + while trailing_pad < ids.shape[0] and ids[-trailing_pad-1] == self.token_id_pad: + trailing_pad += 1 + if ids[-1-trailing_pad] != self.token_id_eos: + raise ValueError(f"last token is not eos") + ids = ids[1:-1-trailing_pad] + joints = [] + p_joints = [] + tails_dict = {} + parts = [] + i = 0 + is_branch = False + last_joint = None + num_bones = 0 + while i < len(ids): + if ids[i] < self.num_discrete: 
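+                # tokens below num_discrete are discretized coordinates: after a
+                # branch token the next 6 values encode (parent position, position),
+                # otherwise 3 values give the joint position and the parent is the
+                # previously decoded joint (or the joint itself for the root)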
+ if is_branch: + p_joint = undiscretize(t=ids[i:i+3], continuous_range=self.continuous_range, num_discrete=self.num_discrete) + current_joint = undiscretize(t=ids[i+3:i+6], continuous_range=self.continuous_range, num_discrete=self.num_discrete) + joints.append(current_joint) + p_joints.append(p_joint) + i += 6 + else: + current_joint = undiscretize(t=ids[i:i+3], continuous_range=self.continuous_range, num_discrete=self.num_discrete) + joints.append(current_joint) + if len(p_joints) == 0: # root + p_joints.append(current_joint) + p_joint = current_joint + else: + assert last_joint is not None + p_joints.append(last_joint) + p_joint = last_joint + i += 3 + if last_joint is not None: + tails_dict[num_bones-1] = current_joint + last_joint = current_joint + num_bones += 1 + is_branch = False + elif ids[i]==self.token_id_branch: + is_branch = True + last_joint = None + i += 1 + elif ids[i]==self.token_id_spring or ids[i] in self.parts_token_id.values(): + parts.append(self.part_token_to_name[ids[i]]) + i += 1 + elif ids[i] in self.cls_token_id.values(): + cls = ids[i] + i += 1 + elif ids[i] == self.token_id_cls_none: + cls = None + i += 1 + else: + raise ValueError(f"unexpected token found: {ids[i]}") + joints = np.stack(joints) + p_joints = np.stack(p_joints) + # leaf is ignored in this tokenizer so need to extrude tails for leaf and branch + bones, tails, available_bones_id, parents = make_skeleton( + joints=joints, + p_joints=p_joints, + tails_dict=tails_dict, + convert_leaf_bones_to_tails=False, + extrude_tail_for_leaf=True, + extrude_tail_for_branch=True, + ) + bones = bones[available_bones_id] + tails = tails[available_bones_id] + if cls in self.cls_token_to_name: + cls = self.cls_token_to_name[cls] + else: + cls = None + if self.order is not None: + names = self.order.make_names(cls=cls, parts=parts, num_bones=num_bones) + else: + names = [f"bone_{i}" for i in range(num_bones)] + return DetokenzeOutput( + tokens=ids, + parents=parents, + bones=bones, + tails=tails, + no_skin=None, + cls=cls, + parts=parts, + names=names, + continuous_range=self.continuous_range, + ) + + def get_require_parts(self) -> List[str]: + return self.parts_token_id_name + + @property + def vocab_size(self): + return self._vocab_size + + @property + def pad(self): + return self.token_id_pad + + @property + def bos(self): + return self.token_id_bos + + @property + def eos(self): + return self.token_id_eos + + @property + def num_discrete(self): + return self._num_discrete + + @property + def continuous_range(self) -> Tuple[float, float]: + return self._continuous_range + +def discretize( + t: ndarray, + continuous_range: Tuple[float, float], + num_discrete: int, +) -> ndarray: + lo, hi = continuous_range + assert hi >= lo + t = (t - lo) / (hi - lo) + t *= num_discrete + return np.clip(t.round(), 0, num_discrete - 1).astype(np.int64) + +def undiscretize( + t: ndarray, + continuous_range: Tuple[float, float], + num_discrete: int, +) -> ndarray: + lo, hi = continuous_range + assert hi >= lo + t = t.astype(np.float32) + 0.5 + t /= num_discrete + return t * (hi - lo) + lo
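+
+# Minimal usage sketch (not part of the original module; appended for illustration).
+# It assumes it sits at the end of tokenizer_part.py so that `np`, `discretize` and
+# `undiscretize` are in scope, and uses assumed values for continuous_range and
+# num_discrete. It checks that coordinates survive the discretize -> undiscretize
+# round trip to within one quantization bin of width (hi - lo) / num_discrete.
+if __name__ == '__main__':
+    _coords = np.array([[-0.5, 0.0, 0.75]], dtype=np.float32)
+    _lo, _hi, _num_discrete = -1.0, 1.0, 256
+    _ids = discretize(t=_coords, continuous_range=(_lo, _hi), num_discrete=_num_discrete)
+    _recon = undiscretize(t=_ids, continuous_range=(_lo, _hi), num_discrete=_num_discrete)
+    _bin_width = (_hi - _lo) / _num_discrete
+    # every recovered coordinate lies strictly within one bin of the original value
+    assert float(np.abs(_recon - _coords).max()) < _bin_width
+    print('max round-trip error:', float(np.abs(_recon - _coords).max()), 'bin width:', _bin_width)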