jkorstad committed
Commit f499d3b · Parent(s): 8c30895

Correctly add UniRig source files

This view is limited to 50 files because it contains too many changes.
Files changed (50):
  1. UniRig/.gitattributes +5 -0
  2. UniRig/.gitignore +59 -0
  3. UniRig/LICENSE +21 -0
  4. UniRig/README.md +164 -0
  5. UniRig/assets/doc/devil.gif +3 -0
  6. UniRig/assets/doc/dragon.gif +3 -0
  7. UniRig/assets/doc/rabbit.gif +3 -0
  8. UniRig/assets/doc/unirig_teaser.png +3 -0
  9. UniRig/blender/add-on-vrm-v2.20.77_modified.zip +3 -0
  10. UniRig/configs/data/quick_inference.yaml +16 -0
  11. UniRig/configs/model/unirig_ar_350m_1024_81920_float32.yaml +32 -0
  12. UniRig/configs/model/unirig_skin.yaml +52 -0
  13. UniRig/configs/skeleton/mixamo.yaml +59 -0
  14. UniRig/configs/skeleton/vroid.yaml +59 -0
  15. UniRig/configs/system/ar_inference_articulationxl.yaml +14 -0
  16. UniRig/configs/system/skin.yaml +5 -0
  17. UniRig/configs/task/quick_inference_skeleton_articulationxl_ar_256.yaml +30 -0
  18. UniRig/configs/task/quick_inference_unirig_skin.yaml +28 -0
  19. UniRig/configs/tokenizer/tokenizer_parts_articulationxl_256.yaml +14 -0
  20. UniRig/configs/transform/inference_ar_transform.yaml +30 -0
  21. UniRig/configs/transform/inference_skin_transform.yaml +32 -0
  22. UniRig/examples/bird.glb +3 -0
  23. UniRig/examples/giraffe.glb +3 -0
  24. UniRig/examples/skeleton/bird.fbx +3 -0
  25. UniRig/examples/skeleton/giraffe.fbx +3 -0
  26. UniRig/examples/skeleton/tira.fbx +3 -0
  27. UniRig/examples/skeleton/tripo_carrot.fbx +3 -0
  28. UniRig/examples/tira.glb +3 -0
  29. UniRig/examples/tripo_carrot.glb +3 -0
  30. UniRig/launch/inference/extract.sh +60 -0
  31. UniRig/launch/inference/generate_skeleton.sh +81 -0
  32. UniRig/launch/inference/generate_skin.sh +84 -0
  33. UniRig/launch/inference/merge.sh +33 -0
  34. UniRig/requirements.txt +15 -0
  35. UniRig/run.py +186 -0
  36. UniRig/src/data/__init__.py +0 -0
  37. UniRig/src/data/asset.py +433 -0
  38. UniRig/src/data/augment.py +152 -0
  39. UniRig/src/data/datapath.py +149 -0
  40. UniRig/src/data/dataset.py +231 -0
  41. UniRig/src/data/exporter.py +486 -0
  42. UniRig/src/data/extract.py +523 -0
  43. UniRig/src/data/log.py +50 -0
  44. UniRig/src/data/order.py +112 -0
  45. UniRig/src/data/raw_data.py +307 -0
  46. UniRig/src/data/sampler.py +210 -0
  47. UniRig/src/data/spec.py +15 -0
  48. UniRig/src/data/tail.py +50 -0
  49. UniRig/src/data/transform.py +107 -0
  50. UniRig/src/data/utils.py +258 -0
UniRig/.gitattributes ADDED
@@ -0,0 +1,5 @@
+ examples/*.glb filter=lfs diff=lfs merge=lfs -text
+ examples/*.fbx filter=lfs diff=lfs merge=lfs -text
+ examples/*.vrm filter=lfs diff=lfs merge=lfs -text
+ examples/*.FBX filter=lfs diff=lfs merge=lfs -text
+ examples/*.obj filter=lfs diff=lfs merge=lfs -text
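Entries like these are what `git lfs track` writes into `.gitattributes`. As a hypothetical sketch, tracking one more binary format under `examples/` (the `.mtl` pattern below is illustrative, not part of this commit) would look like:

```bash
# hypothetical: track material files with LFS as well
git lfs track "examples/*.mtl"
git add .gitattributes
```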
UniRig/.gitignore ADDED
@@ -0,0 +1,59 @@
+ # ignore all pycache
+ **/__pycache__/
+ *.py[cod]
+ *$py.class
+
+ # ignore tmp & output files
+ _data/
+ tmp/
+ *.npz
+ *.blend
+ *.blend1
+ *.blend2
+
+ # ignore logs
+ wandb/
+ lightning_logs/
+ *.log
+
+ # ignore experiments
+ experiments/
+ results/
+ dataset_clean/
+ logs/
+ datalist/
+ dataset_inference/
+ dataset_inference_clean/
+ feature_viz/
+
+ # Distribution / packaging
+ dist/
+ build/
+ *.egg-info/
+ *.egg
+ *.whl
+
+ # Virtual environments
+ venv/
+ env/
+ .env/
+ .venv/
+
+ # IDE specific files
+ .idea/
+ .vscode/
+ *.swp
+ *.swo
+ .DS_Store
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+ *.ipynb
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .coverage
+ .coverage.*
+ coverage.xml
+ *.cover
UniRig/LICENSE ADDED
@@ -0,0 +1,21 @@
+ # MIT License
+
+ # Copyright (c) 2025 VAST-AI-Research and contributors.
+
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
UniRig/README.md ADDED
@@ -0,0 +1,164 @@
+ # UniRig: One Model to Rig Them All
+
+ <div align="center">
+
+ [![Project Page](https://img.shields.io/badge/🏠-Project%20Page-blue.svg)](https://zjp-shadow.github.io/works/UniRig/)
+ [![Paper](https://img.shields.io/badge/📑-Paper-green.svg)](https://arxiv.org/abs/2504.12451)
+ [![Model](https://img.shields.io/badge/🤗-Model-yellow.svg)](https://huggingface.co/VAST-AI/UniRig)
+
+ </div>
+
+ ![teaser](assets/doc/unirig_teaser.png)
+
+ This repository contains the official implementation of the **SIGGRAPH'25 (TOG) UniRig** framework, a unified solution for automatic 3D model rigging, developed by Tsinghua University and [Tripo](https://www.tripo3d.ai).
+
+ **Paper:** [One Model to Rig Them All: Diverse Skeleton Rigging with UniRig](https://arxiv.org/abs/2504.12451)
+
+ ## Overview
+
+ Rigging 3D models – creating a skeleton and assigning skinning weights – is a crucial but often complex and time-consuming step in 3D animation. UniRig tackles this challenge with a unified framework that leverages large autoregressive models to automate the process for a diverse range of 3D assets.
+
+ Combining UniRig with keyframe animation produces the following results:
+
+ | ![devil](assets/doc/devil.gif) | ![dragon](assets/doc/dragon.gif) | ![rabbit](assets/doc/rabbit.gif) |
+ |:-----------------------------:|:-------------------------------:|:-------------------------------:|
+
+ The full UniRig system consists of two main stages:
+ 1. **Skeleton Prediction:** A GPT-like transformer autoregressively predicts a topologically valid skeleton hierarchy using a novel **Skeleton Tree Tokenization** scheme.
+ 2. **Skinning Weight & Attribute Prediction:** A **Bone-Point Cross Attention** mechanism predicts per-vertex skinning weights and relevant bone attributes (e.g., for physics simulation) based on the predicted skeleton and input mesh geometry.
+
+ This repository provides the code for the full framework, with components being released progressively.
+
+ ## Key Features (Full UniRig Framework)
+
+ * **Unified Model:** Aims to handle diverse model categories (humans, animals, objects) with a single framework.
+ * **Automated Skeleton Generation:** Predicts topologically valid skeleton structures. **(✅ Available in current release)**
+ * **Automated Skinning Prediction:** Predicts per-vertex skinning weights. **(✅ Available in current release)**
+ * **Bone Attribute Prediction:** Predicts attributes like stiffness for physics-based secondary motion. **(⏳ Coming Soon)**
+ * **High Accuracy & Robustness:** Achieves state-of-the-art results on challenging datasets (as shown in the paper with Rig-XL/VRoid training).
+ * **Efficient Tokenization:** Uses Skeleton Tree Tokenization for compact representation and efficient processing.
+ * **Human-in-the-Loop Ready:** Designed to potentially support iterative refinement workflows.
+
+ ## 🚨 Current Release Status & Roadmap 🚨
+
+ We are open-sourcing UniRig progressively. Please note the current status:
+
+ **Available Now (Initial Release):**
+ * ✅ **Code:** Implementation for skeleton and skinning prediction.
+ * ✅ **Model:** Skeleton & skinning prediction checkpoint trained on [**Articulation-XL2.0**](https://huggingface.co/datasets/Seed3D/Articulation-XL2.0), available on [Hugging Face](https://huggingface.co/VAST-AI/UniRig).
+
+ **Planned Future Releases:**
+ * ⏳ Release of the **Rig-XL** and **VRoid** datasets used in the paper.
+ * ⏳ Full UniRig model checkpoints (skeleton + skinning) trained on Rig-XL/VRoid, replicating the paper's main results.
+
+ We appreciate your patience as we prepare these components for release. Follow [VAST-AI-Research](https://github.com/orgs/VAST-AI-Research) announcements for updates!
+
+ ## Installation
+
+ 1. **Prerequisites:**
+    * Python 3.11
+    * PyTorch (tested with version >= 2.3.1)
+
+ 2. **Clone the repository:**
+    ```bash
+    git clone https://github.com/VAST-AI-Research/UniRig
+    cd UniRig
+    ```
+
+ 3. **Set up a virtual environment (recommended):**
+    ```bash
+    conda create -n UniRig python=3.11
+    conda activate UniRig
+    ```
+
+ 4. **Install dependencies** (a concrete pinned example follows this list):
+    ```bash
+    python -m pip install torch torchvision
+    python -m pip install -r requirements.txt
+    python -m pip install spconv-{your-cuda-version}
+    python -m pip install torch_scatter torch_cluster -f https://data.pyg.org/whl/torch-{your-torch-version}+{your-cuda-version}.html --no-cache-dir
+    python -m pip install numpy==1.26.4
+    ```
+
+ 5. **Download the model checkpoint:**
+    The currently available model checkpoints are hosted on Hugging Face and are typically downloaded automatically by the provided scripts.
+
+ 6. **(Optional, for importing/exporting .vrm) Install the Blender add-on:**
+    The add-on is modified from [VRM-Addon-for-Blender](https://github.com/saturday06/VRM-Addon-for-Blender).
+
+    Make sure you are in the root directory of the project, then:
+    ```bash
+    python -c "import bpy, os; bpy.ops.preferences.addon_install(filepath=os.path.abspath('blender/add-on-vrm-v2.20.77_modified.zip'))"
+    ```
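The `spconv`, `torch_scatter`, and `torch_cluster` wheels above are CUDA- and PyTorch-version-specific. As a minimal sketch of filling in the placeholders (assuming CUDA 12.1 and a PyTorch 2.3.x install; both are illustrative choices, not requirements), the version-dependent lines could become:

```bash
# Assumed environment: CUDA 12.1, PyTorch 2.3.x; adjust both suffixes to your setup.
python -m pip install spconv-cu121
python -m pip install torch_scatter torch_cluster -f https://data.pyg.org/whl/torch-2.3.0+cu121.html --no-cache-dir
```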
+
+ ## Usage
+
+ ### Skeleton Prediction (Available Now)
+
+ Generate a skeleton for your 3D model using our pre-trained model. The process automatically analyzes the geometry and predicts an appropriate skeletal structure.
+
+ ```bash
+ # Process a single file
+ bash launch/inference/generate_skeleton.sh --input examples/giraffe.glb --output results/giraffe_skeleton.fbx
+
+ # Process multiple files in a directory
+ bash launch/inference/generate_skeleton.sh --input_dir <your_input_directory> --output_dir <your_output_directory>
+
+ # Try different skeleton variations by changing the random seed
+ bash launch/inference/generate_skeleton.sh --input examples/giraffe.glb --output results/giraffe_skeleton.fbx --seed 42
+ ```
+
+ Supported input formats: `.obj`, `.fbx`, `.glb`, and `.vrm`
+
+ ### Skinning Weight Prediction (Available Now)
+ ```bash
+ # Skin a single file
+ bash launch/inference/generate_skin.sh --input examples/skeleton/giraffe.fbx --output results/giraffe_skin.fbx
+
+ # Process multiple files in a directory
+ bash launch/inference/generate_skin.sh --input_dir <your_input_directory> --output_dir <your_output_directory>
+ ```
+
+ Note that the command above takes an **edited version** of the skeleton from the skeleton-prediction phase. The results may degrade significantly if the skeleton is inaccurate (for example, if tail bones or wing bones are missing), so it is recommended to refine the skeleton before skinning to achieve better results.
+
+ ### Merge the Predicted Results
+
+ Combine the predicted skeleton or skinning with your original 3D model to create a fully rigged asset (a complete end-to-end example follows the code block):
+
+ ```bash
+ # Merge skeleton from skeleton prediction
+ bash launch/inference/merge.sh --source results/giraffe_skeleton.fbx --target examples/giraffe.glb --output results/giraffe_rigged.glb
+
+ # Or merge skin from skin prediction
+ bash launch/inference/merge.sh --source results/giraffe_skin.fbx --target examples/giraffe.glb --output results/giraffe_rigged.glb
+ ```
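Putting the three steps together, a minimal end-to-end pass over the bundled giraffe example might look like the following (output paths are illustrative):

```bash
# 1. Predict a skeleton for the raw mesh.
bash launch/inference/generate_skeleton.sh --input examples/giraffe.glb --output results/giraffe_skeleton.fbx
# 2. Optionally refine the skeleton in a DCC tool, then predict skinning weights for it.
bash launch/inference/generate_skin.sh --input results/giraffe_skeleton.fbx --output results/giraffe_skin.fbx
# 3. Merge the rigged result back into the original textured model.
bash launch/inference/merge.sh --source results/giraffe_skin.fbx --target examples/giraffe.glb --output results/giraffe_rigged.glb
```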
+
+ ## Models
+
+ Available models are hosted on Hugging Face: https://huggingface.co/VAST-AI/UniRig
+
+ ## System Requirements
+
+ - CUDA-enabled GPU with at least 8GB VRAM
+
+ ## Citation
+
+ ```
+ @article{zhang2025unirig,
+   title={One Model to Rig Them All: Diverse Skeleton Rigging with UniRig},
+   author={Zhang, Jia-Peng and Pu, Cheng-Feng and Guo, Meng-Hao and Cao, Yan-Pei and Hu, Shi-Min},
+   journal={arXiv preprint arXiv:2504.12451},
+   year={2025}
+ }
+ ```
+
+ ## Acknowledgements
+
+ We would like to thank the following open-source projects and research works:
+
+ - [OPT](https://huggingface.co/facebook/opt-350m) for the model architecture
+ - [3DShape2VecSet](https://github.com/1zb/3DShape2VecSet) for 3D shape representation
+ - [SAMPart3D](https://github.com/Pointcept/SAMPart3D) and [Michelangelo](https://github.com/NeuralCarver/Michelangelo/) for the shape encoder implementation
+ - [Articulation-XL2.0](https://huggingface.co/datasets/Seed3D/Articulation-XL2.0) for the curated dataset
+
+ We are grateful to the broader research community for their open exploration and contributions to the field of 3D generation.
UniRig/assets/doc/devil.gif ADDED

Git LFS Details

  • SHA256: 1a6d00b42bd25d6708df9ca5b69900865eb50cf2bfe3cc72d64d87873cd99749
  • Pointer size: 132 Bytes
  • Size of remote file: 1.83 MB
UniRig/assets/doc/dragon.gif ADDED

Git LFS Details

  • SHA256: 1459fb925b79b710d9496be5105eed48539942a47763029ae0cc855684c9a0e9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.92 MB
UniRig/assets/doc/rabbit.gif ADDED

Git LFS Details

  • SHA256: 124cad359767a48a4f5cf84103c5400ceb2815728c7f4c7bab7d57d0944d93a3
  • Pointer size: 131 Bytes
  • Size of remote file: 732 kB
UniRig/assets/doc/unirig_teaser.png ADDED

Git LFS Details

  • SHA256: 44b056c9355386b872de584b734c57c823c7e73be7d20257e01e15861204247f
  • Pointer size: 133 Bytes
  • Size of remote file: 12.4 MB
UniRig/blender/add-on-vrm-v2.20.77_modified.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8fe1e8b7e31ec602d9b32db33c533b03d68df747837bfad3479e1057bc9937c5
+ size 1331571
UniRig/configs/data/quick_inference.yaml ADDED
@@ -0,0 +1,16 @@
+ input_dataset_dir: &input_dataset_dir ./dataset_inference
+ output_dataset_dir: &output_dataset_dir ./dataset_inference_clean
+
+ predict_dataset_config:
+   shuffle: False
+   batch_size: 1
+   num_workers: 1
+   pin_memory: False
+   persistent_workers: False
+   datapath_config:
+     input_dataset_dir: *output_dataset_dir
+     use_prob: False
+     data_path:
+       inference: [
+         [./dataset_inference_clean/inference_datalist.txt, 1.0],
+       ]
UniRig/configs/model/unirig_ar_350m_1024_81920_float32.yaml ADDED
@@ -0,0 +1,32 @@
+ __target__: unirig_ar
+ llm:
+   pretrained_model_name_or_path: facebook/opt-350m
+   n_positions: 3076
+   max_position_embeddings: 3076
+   hidden_size: 1024
+   word_embed_proj_dim: 1024
+   do_layer_norm_before: True
+   _attn_implementation: flash_attention_2
+
+ mesh_encoder:
+   __target__: michelangelo_encoder
+   pretrained_path: ~
+   freeze_encoder: False
+   device: cpu
+   dtype: float32
+   num_latents: 512
+   embed_dim: 64
+   point_feats: 3
+   num_freqs: 8
+   include_pi: False
+   heads: 8
+   width: 512
+   num_encoder_layers: 16
+   use_ln_post: True
+   init_scale: 0.25
+   qkv_bias: False
+   use_checkpoint: False
+   flash: True
+   supervision_type: sdf
+   query_method: False
+   token_num: 1024
UniRig/configs/model/unirig_skin.yaml ADDED
@@ -0,0 +1,52 @@
+ __target__: unirig_skin
+
+ num_train_vertex: 512 # increase this for faster speed at the cost of memory
+ num_heads: 16
+ feat_dim: 768
+ grid_size: 0.005
+ mlp_dim: 512
+ num_bone_attn: 8
+ num_mesh_bone_attn: 16
+ bone_embed_dim: 1024
+ voxel_mask: 3.0
+
+ mesh_encoder:
+   # vertex groups are handled in the model
+   __target__: ptv3obj
+   pretrained_path: ~
+   freeze_encoder: False
+   in_channels: 9
+   cls_mode: False
+   shuffle_orders: True
+   drop_path: 0.0
+   upcast_attention: False
+   upcast_softmax: False
+   enc_depths: [3, 3, 3, 6, 16]
+   enc_channels: [32, 64, 128, 256, 384]
+   enc_num_head: [2, 4, 8, 16, 24]
+   enable_qknorm: True
+   layer_norm: False
+   res_linear: True
+
+ global_encoder:
+   __target__: michelangelo_encoder
+   pretrained_path: ~
+   freeze_encoder: False
+   device: cpu
+   dtype: float32
+   num_latents: 512
+   embed_dim: 64
+   point_feats: 3
+   num_freqs: 8
+   include_pi: False
+   heads: 8
+   width: 512
+   num_encoder_layers: 16
+   use_ln_post: True
+   init_scale: 0.25
+   qkv_bias: False
+   use_checkpoint: False
+   flash: True
+   supervision_type: sdf
+   query_method: False
+   token_num: 1024
UniRig/configs/skeleton/mixamo.yaml ADDED
@@ -0,0 +1,59 @@
+ parts_order: [body, hand]
+
+ parts:
+   body: [
+     mixamorig:Hips,
+     mixamorig:Spine,
+     mixamorig:Spine1,
+     mixamorig:Spine2,
+     mixamorig:Neck,
+     mixamorig:Head,
+     mixamorig:LeftShoulder,
+     mixamorig:LeftArm,
+     mixamorig:LeftForeArm,
+     mixamorig:LeftHand,
+     mixamorig:RightShoulder,
+     mixamorig:RightArm,
+     mixamorig:RightForeArm,
+     mixamorig:RightHand,
+     mixamorig:LeftUpLeg,
+     mixamorig:LeftLeg,
+     mixamorig:LeftFoot,
+     mixamorig:LeftToeBase,
+     mixamorig:RightUpLeg,
+     mixamorig:RightLeg,
+     mixamorig:RightFoot,
+     mixamorig:RightToeBase,
+   ]
+   hand: [
+     mixamorig:LeftHandThumb1,
+     mixamorig:LeftHandThumb2,
+     mixamorig:LeftHandThumb3,
+     mixamorig:LeftHandIndex1,
+     mixamorig:LeftHandIndex2,
+     mixamorig:LeftHandIndex3,
+     mixamorig:LeftHandMiddle1,
+     mixamorig:LeftHandMiddle2,
+     mixamorig:LeftHandMiddle3,
+     mixamorig:LeftHandRing1,
+     mixamorig:LeftHandRing2,
+     mixamorig:LeftHandRing3,
+     mixamorig:LeftHandPinky1,
+     mixamorig:LeftHandPinky2,
+     mixamorig:LeftHandPinky3,
+     mixamorig:RightHandIndex1,
+     mixamorig:RightHandIndex2,
+     mixamorig:RightHandIndex3,
+     mixamorig:RightHandThumb1,
+     mixamorig:RightHandThumb2,
+     mixamorig:RightHandThumb3,
+     mixamorig:RightHandMiddle1,
+     mixamorig:RightHandMiddle2,
+     mixamorig:RightHandMiddle3,
+     mixamorig:RightHandRing1,
+     mixamorig:RightHandRing2,
+     mixamorig:RightHandRing3,
+     mixamorig:RightHandPinky1,
+     mixamorig:RightHandPinky2,
+     mixamorig:RightHandPinky3,
+   ]
UniRig/configs/skeleton/vroid.yaml ADDED
@@ -0,0 +1,59 @@
+ parts_order: [body, hand]
+
+ parts:
+   body: [
+     J_Bip_C_Hips,
+     J_Bip_C_Spine,
+     J_Bip_C_Chest,
+     J_Bip_C_UpperChest,
+     J_Bip_C_Neck,
+     J_Bip_C_Head,
+     J_Bip_L_Shoulder,
+     J_Bip_L_UpperArm,
+     J_Bip_L_LowerArm,
+     J_Bip_L_Hand,
+     J_Bip_R_Shoulder,
+     J_Bip_R_UpperArm,
+     J_Bip_R_LowerArm,
+     J_Bip_R_Hand,
+     J_Bip_L_UpperLeg,
+     J_Bip_L_LowerLeg,
+     J_Bip_L_Foot,
+     J_Bip_L_ToeBase,
+     J_Bip_R_UpperLeg,
+     J_Bip_R_LowerLeg,
+     J_Bip_R_Foot,
+     J_Bip_R_ToeBase,
+   ]
+   hand: [
+     J_Bip_L_Thumb1,
+     J_Bip_L_Thumb2,
+     J_Bip_L_Thumb3,
+     J_Bip_L_Index1,
+     J_Bip_L_Index2,
+     J_Bip_L_Index3,
+     J_Bip_L_Middle1,
+     J_Bip_L_Middle2,
+     J_Bip_L_Middle3,
+     J_Bip_L_Ring1,
+     J_Bip_L_Ring2,
+     J_Bip_L_Ring3,
+     J_Bip_L_Little1,
+     J_Bip_L_Little2,
+     J_Bip_L_Little3,
+     J_Bip_R_Index1,
+     J_Bip_R_Index2,
+     J_Bip_R_Index3,
+     J_Bip_R_Thumb1,
+     J_Bip_R_Thumb2,
+     J_Bip_R_Thumb3,
+     J_Bip_R_Middle1,
+     J_Bip_R_Middle2,
+     J_Bip_R_Middle3,
+     J_Bip_R_Ring1,
+     J_Bip_R_Ring2,
+     J_Bip_R_Ring3,
+     J_Bip_R_Little1,
+     J_Bip_R_Little2,
+     J_Bip_R_Little3,
+   ]
UniRig/configs/system/ar_inference_articulationxl.yaml ADDED
@@ -0,0 +1,14 @@
+ __target__: ar
+ val_interval: 1
+ generate_kwargs:
+   max_new_tokens: 2048
+   num_return_sequences: 1
+   num_beams: 15
+   do_sample: True
+   top_k: 5
+   top_p: 0.95
+   repetition_penalty: 3.0
+   temperature: 1.5 # must be a float
+ no_cls: False
+ assign_cls: articulationxl
+ use_dir_cls: False
UniRig/configs/system/skin.yaml ADDED
@@ -0,0 +1,5 @@
+ __target__: skin
+ val_interval: 1
+ val_start_from: 1
+ output_path: tmp_skin
+ record_res: True
UniRig/configs/task/quick_inference_skeleton_articulationxl_ar_256.yaml ADDED
@@ -0,0 +1,30 @@
+ mode: predict
+ debug: False
+ experiment_name: quick_inference_skeleton_articulationxl_ar_256
+ resume_from_checkpoint: experiments/skeleton/articulation-xl_quantization_256/model.ckpt
+
+ components:
+   data: quick_inference
+   tokenizer: tokenizer_parts_articulationxl_256
+   transform: inference_ar_transform
+   model: unirig_ar_350m_1024_81920_float32
+   system: ar_inference_articulationxl
+   data_name: raw_data.npz
+
+ writer:
+   __target__: ar
+   output_dir: ~ # export results into the same input folder
+   add_num: False
+   repeat: 1
+   export_npz: predict_skeleton
+   export_obj: skeleton
+   export_fbx: skeleton
+   # export_pc: pc
+
+ trainer:
+   max_epochs: 1
+   num_nodes: 1
+   devices: 1
+   precision: bf16-mixed
+   accelerator: gpu
+   strategy: auto
UniRig/configs/task/quick_inference_unirig_skin.yaml ADDED
@@ -0,0 +1,28 @@
+ mode: predict
+ debug: False
+ experiment_name: quick_inference_skin
+ resume_from_checkpoint: experiments/skin/articulation-xl/model.ckpt
+
+ components:
+   data: quick_inference
+   transform: inference_skin_transform
+   model: unirig_skin
+   system: skin
+   data_name: predict_skeleton.npz # capture data from the ar phase
+
+ writer:
+   __target__: skin
+   output_dir: results
+   add_num: False
+   repeat: 1
+   save_name: predict
+   export_npz: predict_skin # this must be specified if textured results are required
+   export_fbx: result_fbx
+
+ trainer:
+   num_nodes: 1
+   devices: 1
+   precision: bf16-mixed
+   accelerator: gpu
+   strategy: auto
+   inference_mode: True
UniRig/configs/tokenizer/tokenizer_parts_articulationxl_256.yaml ADDED
@@ -0,0 +1,14 @@
+ method: tokenizer_part
+ num_discrete: 256
+ continuous_range: [-1, 1]
+ cls_token_id:
+   vroid: 0
+   mixamo: 1 # this is currently untrained, do not use it
+   articulationxl: 2
+ parts_token_id:
+   body: 0
+   hand: 1
+ order_config:
+   skeleton_path:
+     vroid: ./configs/skeleton/vroid.yaml
+     mixamo: ./configs/skeleton/mixamo.yaml
UniRig/configs/transform/inference_ar_transform.yaml ADDED
@@ -0,0 +1,30 @@
+ sampler_config: &sampler_config
+   method: mix
+   num_samples: 65536
+   vertex_samples: 8192
+
+ tail_config: &tail_config
+   copy_joint_to_tail: False # Be careful! If the tail is important, keep this False!
+   connect_tail_to_unique_son: True
+
+ order_config: &order_config
+   skeleton_path:
+     vroid: ./configs/skeleton/vroid.yaml
+     mixamo: ./configs/skeleton/mixamo.yaml
+
+ vertex_group_config: &vertex_group_config
+
+ validate_transform_config: &validate_transform_config
+   augment_config:
+     augment_affine_config:
+       normalize_into: [-1.0, 1.0]
+       random_scale_p: 0.0
+       random_scale: [1.0, 1.0]
+       random_shift_p: 0.0
+       random_shift: [0.0, 0.0]
+   tail_config: *tail_config
+   order_config: *order_config
+   vertex_group_config: *vertex_group_config
+   sampler_config: *sampler_config
+
+ predict_transform_config: *validate_transform_config
UniRig/configs/transform/inference_skin_transform.yaml ADDED
@@ -0,0 +1,32 @@
+ sampler_config: &sampler_config
+   method: mix
+   num_samples: 32768
+   vertex_samples: 8192
+
+ tail_config: &tail_config
+   copy_joint_to_tail: False # Be careful! If the tail is important, keep this False!
+   connect_tail_to_unique_son: True
+
+ order_config: &order_config
+   skeleton_path:
+     vroid: ./configs/skeleton/vroid.yaml
+     mixamo: ./configs/skeleton/mixamo.yaml
+
+ predict_transform_config:
+   augment_config:
+     augment_affine_config:
+       normalize_into: [-1.0, 1.0]
+   tail_config: *tail_config
+   order_config: *order_config
+   vertex_group_config:
+     names: ['voxel_skin']
+     kwargs:
+       voxel_skin:
+         grid: 196 # increase this for better results
+         alpha: 0.5
+         link_dis: 0.00001
+         grid_query: 7
+         vertex_query: 1
+         grid_weight: 3.0
+         # mode: exp
+   sampler_config: *sampler_config
UniRig/examples/bird.glb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bb59726f598ab4a4e4c431b9789317f5b2f4252a6fb57364d929f5e1ddd7b5bb
+ size 8032388
UniRig/examples/giraffe.glb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9a947cae00b169345802c08885c1e54313c6ce93c885bacf8e37e8a1f18f9e3b
+ size 6310044
UniRig/examples/skeleton/bird.fbx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:885d432850506ab673d509ae2481067cc623452d5910090e1e15f323b9f83fa2
+ size 401084
UniRig/examples/skeleton/giraffe.fbx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:73673a15c8103fbf9a6b39762768ae48e7c9404eab4850a60e4863ee400336fd
+ size 759180
UniRig/examples/skeleton/tira.fbx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:57ed6969266d642da2d8a579059462c7cd4a36c9bd7a8415236dcbea36607fee
+ size 1694668
UniRig/examples/skeleton/tripo_carrot.fbx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:50e08d70c7bdfa96eb842aed18564d8486a4622a474dbc0e0ef9304af1d4c6d3
+ size 1879420
UniRig/examples/tira.glb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8e5a282d0a99c61a8d2439496057b15b0c6ea02643c16da2dcbec85571157799
+ size 32346060
UniRig/examples/tripo_carrot.glb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3c00b7e6ff1e71a019a02128a5b37e8d2db259a402313ad9b4eaab4f183ae40b
+ size 9511824
UniRig/launch/inference/extract.sh ADDED
@@ -0,0 +1,60 @@
+ # extract mesh
+ config="configs/data/quick_inference.yaml"
+ require_suffix="obj,fbx,FBX,dae,glb,gltf,vrm"
+ num_runs=1
+ force_override="false"
+ faces_target_count=50000
+
+ while [[ "$#" -gt 0 ]]; do
+     case $1 in
+         --config) config="$2"; shift ;;
+         --require_suffix) require_suffix="$2"; shift ;;
+         --num_runs) num_runs="$2"; shift ;;
+         --force_override) force_override="$2"; shift ;;
+         --faces_target_count) faces_target_count="$2"; shift ;;
+         --time) time="$2"; shift ;;
+         --input) input="$2"; shift ;;
+         --input_dir) input_dir="$2"; shift ;;
+         --output_dir) output_dir="$2"; shift ;;
+         *) echo "Unknown parameter: $1"; exit 1 ;;
+     esac
+     shift
+ done
+
+ # ensure psutil is installed for memory management
+ pip install psutil --quiet
+ if [ $? -ne 0 ]; then
+     echo "Warning: Failed to install psutil. Memory management may not work properly."
+ fi
+
+ # set the time for all processes to use
+ time=$(date "+%Y_%m_%d_%H_%M_%S")
+
+ for (( i=0; i<num_runs; i++ ))
+ do
+     cmd=" \
+         python -m src.data.extract \
+         --config=$config \
+         --require_suffix=$require_suffix \
+         --force_override=$force_override \
+         --num_runs=$num_runs \
+         --id=$i \
+         --time=$time \
+         --faces_target_count=$faces_target_count \
+     "
+     if [ -n "$input" ]; then
+         cmd="$cmd --input=$input"
+     fi
+     if [ -n "$input_dir" ]; then
+         cmd="$cmd --input_dir=$input_dir"
+     fi
+     if [ -n "$output_dir" ]; then
+         cmd="$cmd --output_dir=$output_dir"
+     fi
+     cmd="$cmd &"
+     eval $cmd
+ done
+
+ wait
+
+ echo "done"
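This helper is normally invoked by the `generate_*.sh` scripts, but it can also be run on its own. A minimal direct invocation, using only the flags parsed above (paths are illustrative), might be:

```bash
# Extract a preprocessed .npz for one example mesh into tmp/.
bash launch/inference/extract.sh \
    --input examples/bird.glb \
    --output_dir tmp \
    --faces_target_count 50000
```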
UniRig/launch/inference/generate_skeleton.sh ADDED
@@ -0,0 +1,81 @@
+ # generate skeleton
+ config="configs/data/quick_inference.yaml"
+ require_suffix="obj,fbx,FBX,dae,glb,gltf,vrm"
+ num_runs=1
+ force_override="false"
+ faces_target_count=50000
+ skeleton_task="configs/task/quick_inference_skeleton_articulationxl_ar_256.yaml"
+ add_root="false"
+ seed=12345
+ npz_dir="tmp"
+
+ while [[ "$#" -gt 0 ]]; do
+     case $1 in
+         --config) config="$2"; shift ;;
+         --require_suffix) require_suffix="$2"; shift ;;
+         --num_runs) num_runs="$2"; shift ;;
+         --force_override) force_override="$2"; shift ;;
+         --faces_target_count) faces_target_count="$2"; shift ;;
+         --skeleton_task) skeleton_task="$2"; shift ;;
+         --add_root) add_root="$2"; shift ;;
+         --seed) seed="$2"; shift ;;
+         --input) input="$2"; shift ;;
+         --input_dir) input_dir="$2"; shift ;;
+         --output_dir) output_dir="$2"; shift ;;
+         --output) output="$2"; shift ;;
+         *) echo "Unknown parameter: $1"; exit 1 ;;
+     esac
+     shift
+ done
+
+ # 1. extract mesh
+ cmd=" \
+     bash ./launch/inference/extract.sh \
+     --config $config \
+     --require_suffix $require_suffix \
+     --force_override $force_override \
+     --num_runs $num_runs \
+     --faces_target_count $faces_target_count \
+ "
+ if [ -n "$input" ]; then
+     cmd="$cmd --input $input"
+ fi
+ if [ -n "$input_dir" ]; then
+     cmd="$cmd --input_dir $input_dir"
+ fi
+ if [ -n "$npz_dir" ]; then
+     cmd="$cmd --output_dir $npz_dir"
+ fi
+
+ cmd="$cmd &"
+ eval $cmd
+
+ wait
+
+ # 2. inference skeleton
+ cmd="\
+     python run.py \
+     --task=$skeleton_task \
+     --seed=$seed \
+ "
+ if [ -n "$input" ]; then
+     cmd="$cmd --input=$input"
+ fi
+ if [ -n "$input_dir" ]; then
+     cmd="$cmd --input_dir=$input_dir"
+ fi
+ if [ -n "$output" ]; then
+     cmd="$cmd --output=$output"
+ fi
+ if [ -n "$output_dir" ]; then
+     cmd="$cmd --output_dir=$output_dir"
+ fi
+ if [ -n "$npz_dir" ]; then
+     cmd="$cmd --npz_dir=$npz_dir"
+ fi
+
+ eval $cmd
+
+ wait
+
+ echo "done"
UniRig/launch/inference/generate_skin.sh ADDED
@@ -0,0 +1,84 @@
+ # generate skin
+ config="configs/data/quick_inference.yaml"
+ require_suffix="obj,fbx,FBX,dae,glb,gltf,vrm"
+ num_runs=1
+ force_override="true"
+ faces_target_count=50000
+ skin_task="configs/task/quick_inference_unirig_skin.yaml"
+ seed=12345
+ npz_dir="tmp"
+ data_name="raw_data.npz"
+
+ while [[ "$#" -gt 0 ]]; do
+     case $1 in
+         --config) config="$2"; shift ;;
+         --require_suffix) require_suffix="$2"; shift ;;
+         --num_runs) num_runs="$2"; shift ;;
+         --force_override) force_override="$2"; shift ;;
+         --faces_target_count) faces_target_count="$2"; shift ;;
+         --skin_task) skin_task="$2"; shift ;;
+         --seed) seed="$2"; shift ;;
+         --input) input="$2"; shift ;;
+         --input_dir) input_dir="$2"; shift ;;
+         --output_dir) output_dir="$2"; shift ;;
+         --output) output="$2"; shift ;;
+         --data_name) data_name="$2"; shift ;;
+         *) echo "Unknown parameter: $1"; exit 1 ;;
+     esac
+     shift
+ done
+
+ # 1. extract mesh
+ cmd=" \
+     bash ./launch/inference/extract.sh \
+     --config $config \
+     --require_suffix $require_suffix \
+     --force_override $force_override \
+     --num_runs $num_runs \
+     --faces_target_count $faces_target_count \
+ "
+ if [ -n "$input" ]; then
+     cmd="$cmd --input $input"
+ fi
+ if [ -n "$input_dir" ]; then
+     cmd="$cmd --input_dir $input_dir"
+ fi
+ if [ -n "$npz_dir" ]; then
+     cmd="$cmd --output_dir $npz_dir"
+ fi
+
+ cmd="$cmd &"
+ eval $cmd
+
+ wait
+
+ # 2. inference skin
+ cmd="\
+     python run.py \
+     --task=$skin_task \
+     --seed=$seed \
+ "
+ if [ -n "$input" ]; then
+     cmd="$cmd --input=$input"
+ fi
+ if [ -n "$input_dir" ]; then
+     cmd="$cmd --input_dir=$input_dir"
+ fi
+ if [ -n "$output" ]; then
+     cmd="$cmd --output=$output"
+ fi
+ if [ -n "$output_dir" ]; then
+     cmd="$cmd --output_dir=$output_dir"
+ fi
+ if [ -n "$npz_dir" ]; then
+     cmd="$cmd --npz_dir=$npz_dir"
+ fi
+ if [ -n "$data_name" ]; then
+     cmd="$cmd --data_name=$data_name"
+ fi
+
+ eval $cmd
+
+ wait
+
+ echo "done"
UniRig/launch/inference/merge.sh ADDED
@@ -0,0 +1,33 @@
+ # merge the predicted rig back into the original (textured) model
+ require_suffix="obj,fbx,FBX,dae,glb,gltf,vrm"
+ source=""
+ target=""
+ output=""
+
+ while [[ "$#" -gt 0 ]]; do
+     case $1 in
+         --require_suffix) require_suffix="$2"; shift ;;
+         --source) source="$2"; shift ;;
+         --target) target="$2"; shift ;;
+         --output) output="$2"; shift ;;
+         *) echo "Unknown parameter: $1"; exit 1 ;;
+     esac
+     shift
+ done
+
+ cmd=" \
+     python -m src.inference.merge \
+     --require_suffix=$require_suffix \
+     --num_runs=1 \
+     --id=0 \
+     --source=$source \
+     --target=$target \
+     --output=$output \
+ "
+
+ cmd="$cmd &"
+ eval $cmd
+
+ wait
+
+ echo "done"
UniRig/requirements.txt ADDED
@@ -0,0 +1,15 @@
+ transformers
+ python-box
+ einops
+ omegaconf
+ pytorch_lightning
+ lightning
+ addict
+ timm
+ fast-simplification
+ bpy==4.2
+ flash_attn
+ trimesh
+ open3d
+ pyrender
+ huggingface_hub
UniRig/run.py ADDED
@@ -0,0 +1,186 @@
+ import argparse
+ import yaml
+ from box import Box
+ import os
+ import torch
+ import lightning as L
+ from lightning.pytorch.callbacks import ModelCheckpoint, Callback
+ from typing import List
+ from math import ceil
+ import numpy as np
+ from lightning.pytorch.strategies import FSDPStrategy, DDPStrategy
+ from src.inference.download import download
+
+ from src.data.asset import Asset
+ from src.data.extract import get_files
+ from src.data.dataset import UniRigDatasetModule, DatasetConfig, ModelInput
+ from src.data.datapath import Datapath
+ from src.data.transform import TransformConfig
+ from src.tokenizer.spec import TokenizerConfig
+ from src.tokenizer.parse import get_tokenizer
+ from src.model.parse import get_model
+ from src.system.parse import get_system, get_writer
+
+ from tqdm import tqdm
+ import time
+
+ def load(task: str, path: str) -> Box:
+     if path.endswith('.yaml'):
+         path = path.removesuffix('.yaml')
+     path += '.yaml'
+     print(f"\033[92mload {task} config: {path}\033[0m")
+     return Box(yaml.safe_load(open(path, 'r')))
+
+ def nullable_string(val):
+     if not val:
+         return None
+     return val
+
+ if __name__ == "__main__":
+     torch.set_float32_matmul_precision('high')
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--task", type=str, required=True)
+     parser.add_argument("--seed", type=int, required=False, default=123,
+                         help="random seed")
+     parser.add_argument("--input", type=nullable_string, required=False, default=None,
+                         help="a single input file, or multiple files separated by commas")
+     parser.add_argument("--input_dir", type=nullable_string, required=False, default=None,
+                         help="input directory")
+     parser.add_argument("--output", type=nullable_string, required=False, default=None,
+                         help="filename for a single output")
+     parser.add_argument("--output_dir", type=nullable_string, required=False, default=None,
+                         help="output directory")
+     parser.add_argument("--npz_dir", type=nullable_string, required=False, default='tmp',
+                         help="intermediate npz directory")
+     parser.add_argument("--cls", type=nullable_string, required=False, default=None,
+                         help="class name")
+     parser.add_argument("--data_name", type=nullable_string, required=False, default=None,
+                         help="npz filename from the skeleton phase")
+     args = parser.parse_args()
+
+     L.seed_everything(args.seed, workers=True)
+
+     task = load('task', args.task)
+     mode = task.mode
+     assert mode in ['predict']
+
+     if args.input is not None or args.input_dir is not None:
+         assert args.output_dir is not None or args.output is not None, 'output or output_dir must be specified'
+         assert args.npz_dir is not None, 'npz_dir must be specified'
+         files = get_files(
+             data_name=task.components.data_name,
+             inputs=args.input,
+             input_dataset_dir=args.input_dir,
+             output_dataset_dir=args.npz_dir,
+             force_override=True,
+             warning=False,
+         )
+         files = [f[1] for f in files]
+         if len(files) > 1 and args.output is not None:
+             print("\033[92mwarning: output is specified, but multiple files are detected. Output will be written.\033[0m")
+         datapath = Datapath(files=files, cls=args.cls)
+     else:
+         datapath = None
+
+     data_config = load('data', os.path.join('configs/data', task.components.data))
+     transform_config = load('transform', os.path.join('configs/transform', task.components.transform))
+
+     # get tokenizer
+     tokenizer_config = task.components.get('tokenizer', None)
+     if tokenizer_config is not None:
+         tokenizer_config = load('tokenizer', os.path.join('configs/tokenizer', task.components.tokenizer))
+         tokenizer_config = TokenizerConfig.parse(config=tokenizer_config)
+
+     # get data name
+     data_name = task.components.get('data_name', 'raw_data.npz')
+     if args.data_name is not None:
+         data_name = args.data_name
+
+     # get predict dataset
+     predict_dataset_config = data_config.get('predict_dataset_config', None)
+     if predict_dataset_config is not None:
+         predict_dataset_config = DatasetConfig.parse(config=predict_dataset_config).split_by_cls()
+
+     # get predict transform
+     predict_transform_config = transform_config.get('predict_transform_config', None)
+     if predict_transform_config is not None:
+         predict_transform_config = TransformConfig.parse(config=predict_transform_config)
+
+     # get model
+     model_config = task.components.get('model', None)
+     if model_config is not None:
+         model_config = load('model', os.path.join('configs/model', model_config))
+         if tokenizer_config is not None:
+             tokenizer = get_tokenizer(config=tokenizer_config)
+         else:
+             tokenizer = None
+         model = get_model(tokenizer=tokenizer, **model_config)
+     else:
+         model = None
+
+     # set data
+     data = UniRigDatasetModule(
+         process_fn=None if model is None else model._process_fn,
+         predict_dataset_config=predict_dataset_config,
+         predict_transform_config=predict_transform_config,
+         tokenizer_config=tokenizer_config,
+         debug=False,
+         data_name=data_name,
+         datapath=datapath,
+         cls=args.cls,
+     )
+
+     # add callbacks
+     callbacks = []
+
+     ## get checkpoint callback
+     checkpoint_config = task.get('checkpoint', None)
+     if checkpoint_config is not None:
+         checkpoint_config['dirpath'] = os.path.join('experiments', task.experiment_name)
+         callbacks.append(ModelCheckpoint(**checkpoint_config))
+
+     ## get writer callback
+     writer_config = task.get('writer', None)
+     if writer_config is not None:
+         assert predict_transform_config is not None, 'missing predict_transform_config in transform'
+         if args.output_dir is not None or args.output is not None:
+             if args.output is not None:
+                 assert args.output.endswith('.fbx'), 'output must be .fbx'
+             writer_config['npz_dir'] = args.npz_dir
+             writer_config['output_dir'] = args.output_dir
+             writer_config['output_name'] = args.output
+             writer_config['user_mode'] = True
+         callbacks.append(get_writer(**writer_config, order_config=predict_transform_config.order_config))
+
+     # get trainer
+     trainer_config = task.get('trainer', {})
+
+     # get system
+     system_config = task.components.get('system', None)
+     if system_config is not None:
+         system_config = load('system', os.path.join('configs/system', system_config))
+         system = get_system(
+             **system_config,
+             model=model,
+             steps_per_epoch=1,
+         )
+     else:
+         system = None
+
+     logger = None
+
+     # set ckpt path
+     resume_from_checkpoint = task.get('resume_from_checkpoint', None)
+     resume_from_checkpoint = download(resume_from_checkpoint)
+     trainer = L.Trainer(
+         callbacks=callbacks,
+         logger=logger,
+         **trainer_config,
+     )
+
+     if mode == 'predict':
+         assert resume_from_checkpoint is not None, 'expect resume_from_checkpoint in task'
+         trainer.predict(system, datamodule=data, ckpt_path=resume_from_checkpoint, return_predictions=False)
+     else:
+         assert 0
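The launch scripts are thin wrappers around this entry point, so it can also be called directly with a task config once the mesh-extraction step has populated `--npz_dir`. A sketch of a direct skeleton-prediction run, using only the flags defined in the argparse block above (paths illustrative):

```bash
python run.py \
    --task configs/task/quick_inference_skeleton_articulationxl_ar_256.yaml \
    --seed 42 \
    --input examples/giraffe.glb \
    --output results/giraffe_skeleton.fbx \
    --npz_dir tmp
```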
UniRig/src/data/__init__.py ADDED
File without changes
UniRig/src/data/asset.py ADDED
@@ -0,0 +1,433 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ from dataclasses import dataclass
3
+ import numpy as np
4
+ from numpy import ndarray
5
+
6
+ from typing import Dict, Union, List, Tuple
7
+
8
+ from .order import Order
9
+ from .raw_data import RawData
10
+ from .exporter import Exporter
11
+
12
+ from ..tokenizer.spec import TokenizeInput
13
+ from .utils import linear_blend_skinning
14
+
15
+ import trimesh
16
+
17
+
18
+ @dataclass
19
+ class Asset(Exporter):
20
+ '''
21
+ Dataclass to handle data parsed from raw data.
22
+ '''
23
+
24
+ # data class
25
+ cls: str
26
+
27
+ # where is this asset from
28
+ path: str
29
+
30
+ # data file name
31
+ data_name: str
32
+
33
+ # vertices of the mesh, shape (N, 3), float32
34
+ vertices: ndarray
35
+
36
+ # normals of vertices, shape (N, 3), float32
37
+ vertex_normals: ndarray
38
+
39
+ # faces of mesh, shape (F, 3), face id starts from 0 to F-1, int64
40
+ faces: ndarray
41
+
42
+ # face normal of mesh, shape (F, 3), float32
43
+ face_normals: ndarray
44
+
45
+ # joints of bones, shape (J, 3), float32
46
+ joints: Union[ndarray, None]=None
47
+
48
+ # tails of joints, shape (J, 3), float32
49
+ tails: Union[ndarray, None]=None
50
+
51
+ # skinning of joints, shape (N, J), float32
52
+ skin: Union[ndarray, None]=None
53
+
54
+ # whether the joint has skin, bool
55
+ no_skin: Union[ndarray, None]=None
56
+
57
+ # vertex groups
58
+ vertex_groups: Union[Dict[str, ndarray], None]=None
59
+
60
+ # parents of joints, None represents no parent(a root joint)
61
+ # make sure parent[k] < k
62
+ parents: Union[List[Union[int, None]], None]=None
63
+
64
+ # names of joints
65
+ names: Union[List[str], None]=None
66
+
67
+ # sampled vertices, shape (N, 3)
68
+ sampled_vertices: Union[ndarray, None]=None
69
+
70
+ # sampled normals, shape (N, 3)
71
+ sampled_normals: Union[ndarray, None]=None
72
+
73
+ # sampled vertex groups, every vertex group should be (N, J)
74
+ sampled_vertex_groups: Union[Dict[str, ndarray], None]=None
75
+
76
+ # {id: part}, part==None -> a spring token
77
+ parts_bias: Union[Dict[int, Union[str, None]], None]=None
78
+
79
+ # local coordinate, shape (J, 4, 4)
80
+ matrix_local: Union[ndarray, None]=None
81
+
82
+ # pose matrix for skinning loss calculation, shape (J, 4, 4)
83
+ pose_matrix: Union[ndarray, None]=None
84
+
85
+ meta: Union[Dict[str, ...], None]=None
86
+
87
+ @property
88
+ def N(self):
89
+ '''
90
+ number of vertices
91
+ '''
92
+ return self.vertices.shape[0]
93
+
94
+ @property
95
+ def F(self):
96
+ '''
97
+ number of faces
98
+ '''
99
+ return self.faces.shape[0]
100
+
101
+ @property
102
+ def J(self):
103
+ '''
104
+ number of joints
105
+ '''
106
+ return self.joints.shape[0]
107
+
108
+ def get_matrix(self, matrix_basis: ndarray, matrix_local: Union[ndarray, None]=None):
109
+ '''
110
+ get matrix
111
+
112
+ matrix_basis: (J, 4, 4)
113
+ '''
114
+ if matrix_local is None:
115
+ assert self.joints is not None
116
+ matrix_local = self.matrix_local
117
+ if matrix_local is None:
118
+ matrix_local = np.zeros((self.J, 4, 4))
119
+ matrix_local[:, 0, 0] = 1.
120
+ matrix_local[:, 1, 1] = 1.
121
+ matrix_local[:, 2, 2] = 1.
122
+ matrix_local[:, 3, 3] = 1.
123
+ for i in range(self.J):
124
+ matrix_local[i, :3, 3] = self.joints[i]
125
+
126
+ matrix = np.zeros((self.J, 4, 4))
127
+ for i in range(self.J):
128
+ if i==0:
129
+ matrix[i] = matrix_local[i] @ matrix_basis[i]
130
+ else:
131
+ pid = self.parents[i]
132
+ matrix_parent = matrix[pid]
133
+ matrix_local_parent = matrix_local[pid]
134
+
135
+ matrix[i] = (
136
+ matrix_parent @
137
+ (np.linalg.inv(matrix_local_parent) @ matrix_local[i]) @
138
+ matrix_basis[i]
139
+ )
140
+ return matrix
141
+
142
+ def apply_matrix_basis(self, matrix_basis: ndarray):
143
+ '''
144
+ apply a pose to armature
145
+
146
+ matrix_basis: (J, 4, 4)
147
+ '''
148
+ matrix_local = self.matrix_local
149
+ if matrix_local is None:
150
+ matrix_local = np.zeros((self.J, 4, 4))
151
+ matrix_local[:, 0, 0] = 1.
152
+ matrix_local[:, 1, 1] = 1.
153
+ matrix_local[:, 2, 2] = 1.
154
+ matrix_local[:, 3, 3] = 1.
155
+ for i in range(self.J):
156
+ matrix_local[i, :3, 3] = self.joints[i].copy()
157
+
158
+ matrix = self.get_matrix(matrix_basis=matrix_basis, matrix_local=matrix_local)
159
+ self.joints = matrix[:, :3, 3].copy()
160
+ vertices = linear_blend_skinning(self.vertices, matrix_local, matrix, self.skin, pad=1, value=1.)
161
+ # update matrix_local
162
+ self.matrix_local = matrix.copy()
163
+
164
+ # change tails
165
+ if self.tails is not None:
166
+ t_skin = np.eye(self.J)
167
+ self.tails = linear_blend_skinning(self.tails, matrix_local, matrix, t_skin, pad=1, value=1.)
168
+ # in accordance with trimesh's normals
169
+ mesh = trimesh.Trimesh(vertices=vertices, faces=self.faces, process=False)
170
+ self.vertices = vertices
171
+ self.vertex_normals = mesh.vertex_normals.copy()
172
+ self.face_normals = mesh.face_normals.copy()
173
+
174
+ def set_order_by_names(self, new_names: List[str]):
175
+ assert len(new_names) == len(self.names)
176
+ name_to_id = {name: id for (id, name) in enumerate(self.names)}
177
+ new_name_to_id = {name: id for (id, name) in enumerate(new_names)}
178
+ perm = []
179
+ new_parents = []
180
+ for (new_id, name) in enumerate(new_names):
181
+ perm.append(name_to_id[name])
182
+ pid = self.parents[name_to_id[name]]
183
+ if new_id == 0:
184
+ assert pid is None, 'first bone is not root bone'
185
+ else:
186
+ pname = self.names[pid]
187
+ pid = new_name_to_id[pname]
188
+ assert pid < new_id, 'new order does not form a tree'
189
+ new_parents.append(pid)
190
+
191
+ if self.joints is not None:
192
+ self.joints = self.joints[perm]
193
+ self.parents = new_parents
194
+ if self.tails is not None:
195
+ self.tails = self.tails[perm]
196
+ if self.skin is not None:
197
+ self.skin = self.skin[:, perm]
198
+ if self.no_skin is not None:
199
+ self.no_skin = self.no_skin[perm]
200
+ if self.matrix_local is not None:
201
+ self.matrix_local = self.matrix_local[perm]
202
+ self.names = new_names
203
+
204
+ def set_order(self, order: Order):
205
+ if self.names is None or self.parents is None:
206
+ return
207
+ new_names, self.parts_bias = order.arrange_names(cls=self.cls, names=self.names, parents=self.parents)
208
+ self.set_order_by_names(new_names=new_names)
209
+
210
+ def collapse(self, keep: List[str]):
211
+ dsu = [i for i in range(self.J)]
212
+
213
+ def find(x: int) -> int:
214
+ if dsu[x] == x:
215
+ return x
216
+ y = find(dsu[x])
217
+ dsu[x] = y
218
+ return y
219
+
220
+ def merge(x: int, y: int):
221
+ dsu[find(x)] = find(y)
222
+
223
+ if self.tails is not None:
224
+ new_tails = self.tails.copy()
225
+ else:
226
+ new_tails = None
227
+ if self.skin is not None:
228
+ new_skin = self.skin.copy()
229
+ else:
230
+ new_skin = None
231
+
232
+ if self.no_skin is not None:
233
+ new_no_skin = self.no_skin.copy()
234
+ else:
235
+ new_no_skin = None
236
+
237
+ if self.matrix_local is not None:
238
+ matrix_local = self.matrix_local.copy()
239
+ else:
240
+ matrix_local = None
241
+ new_names = []
242
+ new_parents = []
243
+ perm = []
244
+ new_name_to_id = {}
245
+ tot = 0
246
+ for (i, name) in enumerate(self.names):
247
+ if name in keep:
248
+ new_names.append(name)
249
+ new_name_to_id[name] = tot
250
+ tot += 1
251
+ perm.append(i)
252
+ pid = self.parents[i]
253
+ if pid is None:
254
+ new_parents.append(None)
255
+ else:
256
+ pid = find(pid)
257
+ new_parents.append(new_name_to_id[self.names[pid]])
258
+ continue
259
+ assert i != 0, 'cannot remove root'
260
+ id = find(i)
261
+ pid = find(self.parents[id])
262
+ # be careful !
263
+ # do not copy tail here because you dont know which child to inherit from
264
+ if new_skin is not None:
265
+ new_skin[:, pid] += new_skin[:, id]
266
+ if new_no_skin is not None:
267
+ new_no_skin[pid] &= new_no_skin[id]
268
+ merge(id, pid)
269
+
270
+ if new_tails is not None:
271
+ new_tails = new_tails[perm]
272
+ if new_skin is not None:
273
+ new_skin = new_skin[:, perm]
274
+ if new_no_skin is not None:
275
+ new_no_skin = new_no_skin[perm]
276
+ if matrix_local is not None:
277
+ matrix_local = matrix_local[perm]
278
+
279
+ if self.joints is not None:
280
+ self.joints = self.joints[perm]
281
+ self.parents = new_parents
282
+ self.tails = new_tails
283
+ self.skin = new_skin
284
+ self.no_skin = new_no_skin
285
+ self.names = new_names
286
+ self.matrix_local = matrix_local
287
+
288
+ @staticmethod
289
+ def from_raw_data(
290
+ raw_data: RawData,
291
+ cls: str,
292
+ path: str,
293
+ data_name: str,
294
+ ) -> 'Asset':
295
+ '''
296
+ Return an asset initialized from raw data and do transform.
297
+ '''
298
+ return Asset(
299
+ cls=cls,
300
+ path=path,
301
+ data_name=data_name,
302
+ vertices=raw_data.vertices,
303
+ vertex_normals=raw_data.vertex_normals,
304
+ faces=raw_data.faces,
305
+ face_normals=raw_data.face_normals,
306
+ joints=raw_data.joints,
307
+ tails=raw_data.tails,
308
+ skin=raw_data.skin,
309
+ no_skin=raw_data.no_skin,
310
+ parents=raw_data.parents,
311
+ names=raw_data.names,
312
+ matrix_local=raw_data.matrix_local,
313
+ meta={},
314
+ )
315
+
316
+ def get_tokenize_input(self) -> TokenizeInput:
317
+ children = defaultdict(list)
318
+
319
+ for (id, p) in enumerate(self.parents):
320
+ if p is not None:
321
+ children[p].append(id)
322
+ bones = []
323
+ branch = []
324
+ is_leaf = []
325
+ last = None
326
+ for i in range(self.J):
327
+ is_leaf.append(len(children[i])==0)
328
+ if i == 0:
329
+ bones.append(np.concatenate([self.joints[i], self.joints[i]]))
330
+ branch.append(False)
331
+ else:
332
+ pid = self.parents[i]
333
+ bones.append(np.concatenate([self.joints[pid], self.joints[i]]))
334
+ branch.append(pid!=last)
335
+ last = i
336
+ bones = np.stack(bones)
337
+ branch = np.array(branch, dtype=bool)
338
+ is_leaf = np.array(is_leaf, dtype=bool)
339
+ return TokenizeInput(
340
+ bones=bones,
341
+ tails=self.tails,
342
+ branch=branch,
343
+ is_leaf=is_leaf,
344
+ no_skin=self.no_skin,
345
+ cls=self.cls,
346
+ parts_bias=self.parts_bias,
347
+ )
348
+
349
+ def export_pc(self, path: str, with_normal: bool=True, normal_size=0.01):
350
+ '''
351
+ export point cloud
352
+ '''
353
+ vertices = self.vertices
354
+ normals = self.vertex_normals
355
+ if self.sampled_vertices is not None:
356
+ vertices = self.sampled_vertices
357
+ normals = self.sampled_normals
358
+ if with_normal == False:
359
+ normals = None
360
+ self._export_pc(vertices=vertices, path=path, vertex_normals=normals, normal_size=normal_size)
361
+
362
+ def export_mesh(self, path: str):
363
+ '''
364
+ export mesh
365
+ '''
366
+ self._export_mesh(vertices=self.vertices, faces=self.faces, path=path)
367
+
368
+ def export_skeleton(self, path: str):
369
+ '''
370
+ export spring
371
+ '''
372
+ self._export_skeleton(joints=self.joints, parents=self.parents, path=path)
373
+
374
+ def export_skeleton_sequence(self, path: str):
375
+ '''
376
+ export spring
377
+ '''
378
+ self._export_skeleton_sequence(joints=self.joints, parents=self.parents, path=path)
379
+
380
+ def export_fbx(
381
+ self,
382
+ path: str,
383
+ vertex_group_name: str,
384
+ extrude_size: float=0.03,
385
+ group_per_vertex: int=-1,
386
+ add_root: bool=False,
387
+ do_not_normalize: bool=False,
388
+ use_extrude_bone: bool=True,
389
+ use_connect_unique_child: bool=True,
390
+ extrude_from_parent: bool=True,
391
+ use_tail: bool=False,
392
+ use_origin: bool=False,
393
+ ):
394
+ '''
395
+ export the whole model with skining
396
+ '''
397
+ self._export_fbx(
398
+ path=path,
399
+ vertices=self.vertices if use_origin else self.sampled_vertices,
400
+ joints=self.joints,
401
+ skin=self.sampled_vertex_groups[vertex_group_name],
402
+ parents=self.parents,
403
+ names=self.names,
404
+ faces=self.faces if use_origin else None,
405
+ extrude_size=extrude_size,
406
+ group_per_vertex=group_per_vertex,
407
+ add_root=add_root,
408
+ do_not_normalize=do_not_normalize,
409
+ use_extrude_bone=use_extrude_bone,
410
+ use_connect_unique_child=use_connect_unique_child,
411
+ extrude_from_parent=extrude_from_parent,
412
+ tails=self.tails if use_tail else None,
413
+ )
414
+
415
+ def export_render(self, path: str, resolution: Tuple[int, int]=[256, 256], use_tail: bool=False):
416
+ if use_tail:
417
+ assert self.tails is not None
418
+ self._export_render(
419
+ path=path,
420
+ vertices=self.vertices,
421
+ faces=self.faces,
422
+ bones=np.concatenate([self.joints, self.tails], axis=-1),
423
+ resolution=resolution,
424
+ )
425
+ else:
426
+ pjoints = self.joints[self.parents[1:]]
427
+ self._export_render(
428
+ path=path,
429
+ vertices=self.vertices,
430
+ faces=self.faces,
431
+ bones=np.concatenate([pjoints, self.joints[1:]], axis=-1),
432
+ resolution=resolution,
433
+ )
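These export helpers are thin wrappers over the `Exporter` mixin (see `src/data/exporter.py` below). A minimal, hedged usage sketch: it assumes an `Asset` that has already been sampled, since `export_fbx` reads `self.sampled_vertex_groups`, and the vertex-group key `'skin'` is hypothetical.

    # sketch only, not part of the diff: dump inspection artifacts from a prepared Asset
    asset.export_mesh('tmp/mesh.obj')          # triangle mesh as .obj
    asset.export_skeleton('tmp/skeleton.obj')  # bones drawn as thin triangles
    asset.export_pc('tmp/pc.obj')              # sampled point cloud (plus tmp/pc_normal.obj)
    asset.export_fbx(
        path='tmp/rigged.fbx',
        vertex_group_name='skin',              # hypothetical key into sampled_vertex_groups
        use_origin=True,                       # export original vertices/faces, not samples
    )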
UniRig/src/data/augment.py ADDED
@@ -0,0 +1,152 @@
1
+ from dataclasses import dataclass
2
+ from typing import Tuple, Union, List, Dict
3
+ from numpy import ndarray
4
+ import numpy as np
5
+ from abc import ABC, abstractmethod
6
+ from scipy.spatial.transform import Rotation as R
7
+
8
+ from .spec import ConfigSpec
9
+ from .asset import Asset
10
+ from .utils import axis_angle_to_matrix
11
+
12
+ @dataclass(frozen=True)
13
+ class AugmentAffineConfig(ConfigSpec):
14
+ # final normalization cube
15
+ normalize_into: Tuple[float, float]
16
+
17
+ # randomly scale coordinates with probability p
18
+ random_scale_p: float
19
+
20
+ # scale range (lower, upper)
21
+ random_scale: Tuple[float, float]
22
+
23
+ # randomly shift coordinates with probability p
24
+ random_shift_p: float
25
+
26
+ # shift range (lower, upper)
27
+ random_shift: Tuple[float, float]
28
+
29
+ @classmethod
30
+ def parse(cls, config) -> Union['AugmentAffineConfig', None]:
31
+ if config is None:
32
+ return None
33
+ cls.check_keys(config)
34
+ return AugmentAffineConfig(
35
+ normalize_into=config.normalize_into,
36
+ random_scale_p=config.get('random_scale_p', 0.),
37
+ random_scale=config.get('random_scale', [1., 1.]),
38
+ random_shift_p=config.get('random_shift_p', 0.),
39
+ random_shift=config.get('random_shift', [0., 0.]),
40
+ )
41
+
42
+ @dataclass(frozen=True)
43
+ class AugmentConfig(ConfigSpec):
44
+ '''
45
+ Config for the final, lightweight augmentation of vertices, normals and bones before sampling.
46
+ '''
47
+ augment_affine_config: Union[AugmentAffineConfig, None]
48
+
49
+ @classmethod
50
+ def parse(cls, config) -> 'AugmentConfig':
51
+ cls.check_keys(config)
52
+ return AugmentConfig(
53
+ augment_affine_config=AugmentAffineConfig.parse(config.get('augment_affine_config', None)),
54
+ )
55
+
56
+ class Augment(ABC):
57
+ '''
58
+ Abstract class for augmentation
59
+ '''
60
+ def __init__(self):
61
+ pass
62
+
63
+ @abstractmethod
64
+ def transform(self, asset: Asset, **kwargs):
65
+ pass
66
+
67
+ @abstractmethod
68
+ def inverse(self, asset: Asset):
69
+ pass
70
+
71
+ class AugmentAffine(Augment):
72
+
73
+ def __init__(self, config: AugmentAffineConfig):
74
+ super().__init__()
75
+ self.config = config
76
+
77
+ def _apply(self, v: ndarray, trans: ndarray) -> ndarray:
78
+ return np.matmul(v, trans[:3, :3].transpose()) + trans[:3, 3]
79
+
80
+ def transform(self, asset: Asset, **kwargs):
81
+ bound_min = asset.vertices.min(axis=0)
82
+ bound_max = asset.vertices.max(axis=0)
83
+ if asset.joints is not None:
84
+ joints_bound_min = asset.joints.min(axis=0)
85
+ joints_bound_max = asset.joints.max(axis=0)
86
+ bound_min = np.minimum(bound_min, joints_bound_min)
87
+ bound_max = np.maximum(bound_max, joints_bound_max)
88
+
89
+ trans_vertex = np.eye(4, dtype=np.float32)
90
+
91
+ trans_vertex = _trans_to_m(-(bound_max + bound_min)/2) @ trans_vertex
92
+
93
+ # scale into the cube
94
+ normalize_into = self.config.normalize_into
95
+ scale = np.max((bound_max - bound_min) / (normalize_into[1] - normalize_into[0]))
96
+ trans_vertex = _scale_to_m(1. / scale) @ trans_vertex
97
+
98
+ bias = (normalize_into[0] + normalize_into[1]) / 2
99
+ trans_vertex = _trans_to_m(np.array([bias, bias, bias], dtype=np.float32)) @ trans_vertex
100
+
101
+ if np.random.rand() < self.config.random_scale_p:
102
+ scale = _scale_to_m(np.random.uniform(self.config.random_scale[0], self.config.random_scale[1]))
103
+ trans_vertex = scale @ trans_vertex
104
+
105
+ if np.random.rand() < self.config.random_shift_p:
106
+ l, r = self.config.random_shift
107
+ shift = _trans_to_m(np.array([np.random.uniform(l, r), np.random.uniform(l, r), np.random.uniform(l, r)], dtype=np.float32))
108
+ trans_vertex = shift @ trans_vertex
109
+
110
+ asset.vertices = self._apply(asset.vertices, trans_vertex)
111
+ # do not affect scale in matrix
112
+ if asset.matrix_local is not None:
113
+ asset.matrix_local[:, :, 3:4] = trans_vertex @ asset.matrix_local[:, :, 3:4]
114
+ if asset.pose_matrix is not None:
115
+ asset.pose_matrix[:, :, 3:4] = trans_vertex @ asset.pose_matrix[:, :, 3:4]
116
+ # do not affect normal here
117
+ if asset.joints is not None:
118
+ asset.joints = self._apply(asset.joints, trans_vertex)
119
+ if asset.tails is not None:
120
+ asset.tails = self._apply(asset.tails, trans_vertex)
121
+
122
+ self.trans_vertex = trans_vertex
123
+
124
+ def inverse(self, asset: Asset):
125
+ m = np.linalg.inv(self.trans_vertex)
126
+ asset.vertices = self._apply(asset.vertices, m)
127
+ if asset.joints is not None:
128
+ asset.joints = self._apply(asset.joints, m)
129
+ if asset.tails is not None:
130
+ asset.tails = self._apply(asset.tails, m)
131
+
132
+ def _trans_to_m(v: ndarray):
133
+ m = np.eye(4, dtype=np.float32)
134
+ m[0:3, 3] = v
135
+ return m
136
+
137
+ def _scale_to_m(r: float):
138
+ m = np.zeros((4, 4), dtype=np.float32)
139
+ m[0, 0] = r
140
+ m[1, 1] = r
141
+ m[2, 2] = r
142
+ m[3, 3] = 1.
143
+ return m
144
+
145
+ def get_augments(config: AugmentConfig) -> Tuple[List[Augment], List[Augment]]:
146
+ first_augments = [] # augments before sample
147
+ second_augments = [] # augments after sample
148
+ augment_affine_config = config.augment_affine_config
149
+
150
+ if augment_affine_config is not None:
151
+ second_augments.append(AugmentAffine(config=augment_affine_config))
152
+ return first_augments, second_augments
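To make the affine pipeline concrete, a hedged sketch of how these pieces compose; the `Box` config mirrors the transform yamls, and the run is deterministic because `random_scale_p` and `random_shift_p` default to 0. The `asset` variable stands in for an `Asset` with vertices (and possibly joints) set.

    from box import Box
    from src.data.augment import AugmentConfig, get_augments  # import path assumed

    config = AugmentConfig.parse(Box({
        'augment_affine_config': {'normalize_into': [-1.0, 1.0]},
    }))
    first, second = get_augments(config)  # AugmentAffine lands in the post-sample list
    for aug in second:
        aug.transform(asset)              # normalizes into the [-1, 1] cube
    for aug in second:
        aug.inverse(asset)                # undoes the stored affine transform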
UniRig/src/data/datapath.py ADDED
@@ -0,0 +1,149 @@
1
+ from copy import deepcopy
2
+ from collections import defaultdict
3
+ from dataclasses import dataclass
4
+ from typing import Dict, Union, Tuple, List
5
+ import numpy as np
6
+ from numpy import ndarray
7
+ import os
8
+ from random import shuffle
9
+ from box import Box
11
+
12
+ from .spec import ConfigSpec
13
+
14
+ @dataclass
15
+ class DatapathConfig(ConfigSpec):
16
+ '''
17
+ Config to handle input data paths.
18
+ '''
19
+ # root
20
+ input_dataset_dir: str
21
+
22
+ # use proportion data sampling
23
+ use_prob: bool
24
+
25
+ # cls: [(path_1, p_1), ...]
26
+ data_path: Dict[str, List[Tuple[str, float]]]
27
+
28
+ # how many files to return when using data sampling
29
+ num_files: Union[int, None]
30
+
31
+ @classmethod
32
+ def from_args(cls, **kwargs) -> 'DatapathConfig':
33
+ '''
34
+ Make a temporary datapath from user inputs.
35
+ '''
36
+ input = kwargs.get('input', None)
37
+ output = kwargs.get('output', None)
38
+ recursive = kwargs.get('recursive', False)
39
+
40
+
41
+ @classmethod
42
+ def parse(cls, config) -> 'DatapathConfig':
43
+ cls.check_keys(config)
44
+ return DatapathConfig(
45
+ input_dataset_dir=config.input_dataset_dir,
46
+ use_prob=config.get('use_prob', True),
47
+ data_path=config.data_path,
48
+ num_files=config.get('num_files', None),
49
+ )
50
+
51
+ def split_by_cls(self) -> Dict[str, 'DatapathConfig']:
52
+ res: Dict[str, DatapathConfig] = {}
53
+ for cls in self.data_path:
54
+ res[cls] = deepcopy(self)
55
+ res[cls].data_path = {cls: self.data_path[cls]}
56
+ return res
57
+
58
+ class Datapath():
59
+ def __init__(
60
+ self,
61
+ config: Union[DatapathConfig, None]=None,
62
+ files: Union[List[str], None]=None,
63
+ cls: Union[str, None]=None,
64
+ ):
65
+ if config is not None:
66
+ self.config = config
67
+ self.file_list = []
68
+ cls_probs_first = []
69
+ cls_first = []
70
+
71
+ self.files_by_class: Dict[str, List[Dict]] = defaultdict(list)
72
+ self.class_positions: Dict[str, List[int]] = defaultdict(list)
73
+ self.cls_probs_second: Dict[str, ndarray] = defaultdict(list)
74
+
75
+ for cls in self.config.data_path:
76
+ prob = 0.
77
+ probs_second = []
78
+ for (path, p) in self.config.data_path[cls]:
79
+ prob += p
80
+ probs_second.append(p)
81
+ with open(path, 'r') as f:
82
+ file_items = []
83
+ missing = 0
84
+ for l in f.readlines():
85
+ raw_data_path = os.path.join(self.config.input_dataset_dir, l.strip(), 'raw_data.npz')
86
+ if not os.path.exists(raw_data_path):
87
+ missing += 1
88
+ continue
89
+ file_items.append({
90
+ 'cls': cls,
91
+ 'path': os.path.join(self.config.input_dataset_dir, l.strip()),
92
+ 'prob': p
93
+ })
94
+ assert len(file_items) > 0, f"files in {path} are all missing! root: {self.config.input_dataset_dir}"
95
+ if missing > 0:
96
+ print(f"\033[31m{cls}: {missing} missing files\033[0m")
97
+ self.files_by_class[cls].append(file_items)
98
+ self.class_positions[cls].append(0)
99
+ self.file_list.extend(file_items)
100
+ probs_second = np.array(probs_second)
101
+ self.cls_probs_second[cls] = probs_second / probs_second.sum()
102
+ cls_first.append(cls)
103
+ cls_probs_first.append(prob)
104
+ cls_probs_first = np.array(cls_probs_first)
105
+ self.cls_first: List[str] = cls_first
106
+ self.cls_probs_first: ndarray = cls_probs_first / cls_probs_first.sum()
107
+ elif files is not None:
108
+ if cls is None:
109
+ cls = 'inference'
110
+ self.file_list = [{'cls': cls, 'path': file} for file in files]
111
+ cls_probs_first = np.array([1.])
112
+ cls_first = []
113
+
114
+ self.files_by_class: Dict[str, List[Dict]] = {cls: self.file_list.copy()}
115
+ self.class_positions: Dict[str, List[int]] = {cls: [0]}
116
+ self.cls_probs_second: Dict[str, ndarray] = {cls: np.array([1.])}
117
+ self.config = Box({'use_prob': False})
118
+ else:
119
+ assert False, 'either config or files must be provided'
120
+
121
+ def __len__(self):
122
+ if self.config.use_prob:
123
+ assert self.config.num_files is not None, 'num_files is not specified'
124
+ return self.config.num_files
125
+ return len(self.file_list)
126
+
127
+ def __getitem__(self, index) -> Tuple[str, str]:
128
+ if self.config.use_prob:
129
+ # first sample a class
130
+ cls = np.random.choice(self.cls_first, p=self.cls_probs_first)
131
+
132
+ # second sample in this class
133
+ idx = np.random.choice(len(self.files_by_class[cls]), p=self.cls_probs_second[cls])
134
+
135
+ # get the current position
136
+ pos = self.class_positions[cls][idx]
137
+ files = self.files_by_class[cls][idx]
138
+
139
+ # get the item and update position
140
+ item = files[pos]
141
+ self.class_positions[cls][idx] = (pos + 1) % len(files)
142
+ if (pos + 1) % len(files) == 0:
143
+ shuffle(self.files_by_class[cls][idx])
144
+ else:
145
+ item = self.file_list[index]
146
+ return (item['cls'], item['path'])
147
+
148
+ def get_data(self) -> List[Tuple[str, str]]:
149
+ return [self[i] for i in range(len(self))]
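For quick inference the class is usually built from an explicit file list instead of a parsed config; a small sketch with a hypothetical path:

    from src.data.datapath import Datapath  # import path assumed

    dp = Datapath(files=['dataset_inference/bird'], cls='inference')
    print(len(dp))        # 1; explicit file lists disable probability sampling
    for cls, path in dp.get_data():
        print(cls, path)  # inference dataset_inference/bird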
UniRig/src/data/dataset.py ADDED
@@ -0,0 +1,231 @@
1
+ from copy import deepcopy
2
+ from dataclasses import dataclass
3
+ import lightning.pytorch as pl
4
+ # from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
5
+ import torch
6
+ from torch import LongTensor
7
+ from torch.utils import data
8
+ from torch.utils.data import DataLoader, Dataset
9
+ from typing import Dict, List, Tuple, Union, Callable
10
+ import os
11
+ import numpy as np
12
+
13
+ from .raw_data import RawData
14
+ from .asset import Asset
15
+ from .transform import TransformConfig, transform_asset
16
+ from .datapath import DatapathConfig, Datapath
17
+ from .spec import ConfigSpec
18
+
19
+ from ..tokenizer.spec import TokenizerSpec, TokenizerConfig
20
+ from ..tokenizer.parse import get_tokenizer
21
+ from ..model.spec import ModelInput
22
+
23
+ @dataclass
24
+ class DatasetConfig(ConfigSpec):
25
+ '''
26
+ Config to handle dataset format.
27
+ '''
28
+ # shuffle dataset
29
+ shuffle: bool
30
+
31
+ # batch size
32
+ batch_size: int
33
+
34
+ # number of workers
35
+ num_workers: int
36
+
37
+ # datapath
38
+ datapath_config: DatapathConfig
39
+
40
+ # use pin memory
41
+ pin_memory: bool = True
42
+
43
+ # use persistent workers
44
+ persistent_workers: bool = True
45
+
46
+ @classmethod
47
+ def parse(cls, config) -> 'DatasetConfig':
48
+ cls.check_keys(config)
49
+ return DatasetConfig(
50
+ shuffle=config.shuffle,
51
+ batch_size=config.batch_size,
52
+ num_workers=config.num_workers,
53
+ pin_memory=config.pin_memory,
54
+ persistent_workers=config.persistent_workers,
55
+ datapath_config=DatapathConfig.parse(config.datapath_config),
56
+ )
57
+
58
+ def split_by_cls(self) -> Dict[str, 'DatasetConfig']:
59
+ res: Dict[str, DatasetConfig] = {}
60
+ datapath_config_dict = self.datapath_config.split_by_cls()
61
+ for cls in self.datapath_config.data_path:
62
+ res[cls] = deepcopy(self)
63
+ res[cls].datapath_config = datapath_config_dict[cls]
64
+ return res
65
+
66
+ class UniRigDatasetModule(pl.LightningDataModule):
67
+ def __init__(
68
+ self,
69
+ process_fn: Union[Callable[[List[ModelInput]], Dict], None]=None,
70
+ predict_dataset_config: Union[Dict[str, DatasetConfig], None]=None,
71
+ predict_transform_config: Union[TransformConfig, None]=None,
72
+ tokenizer_config: Union[TokenizerConfig, None]=None,
73
+ debug: bool=False,
74
+ data_name: str='raw_data.npz',
75
+ datapath: Union[Datapath, None]=None,
76
+ cls: Union[str, None]=None,
77
+ ):
78
+ super().__init__()
79
+ self.process_fn = process_fn
80
+ self.predict_dataset_config = predict_dataset_config
81
+ self.predict_transform_config = predict_transform_config
82
+ self.tokenizer_config = tokenizer_config
83
+ self.debug = debug
84
+ self.data_name = data_name
85
+
86
+ if debug:
87
+ print("\033[31mWARNING: debug mode, dataloader will be extremely slow !!!\033[0m")
88
+
89
+ if datapath is not None:
90
+ self.train_datapath = None
91
+ self.validate_datapath = None
92
+ self.predict_datapath = {
93
+ cls: deepcopy(datapath),
94
+ }
95
+ self.predict_dataset_config = {
96
+ cls: DatasetConfig(
97
+ shuffle=False,
98
+ batch_size=1,
99
+ num_workers=0,
100
+ datapath_config=deepcopy(datapath),
101
+ pin_memory=False,
102
+ persistent_workers=False,
103
+ )
104
+ }
105
+ else:
106
+ # build predict datapath
107
+ if self.predict_dataset_config is not None:
108
+ self.predict_datapath = {
109
+ cls: Datapath(self.predict_dataset_config[cls].datapath_config)
110
+ for cls in self.predict_dataset_config
111
+ }
112
+ else:
113
+ self.predict_datapath = None
114
+
115
+ # get tokenizer
116
+ if tokenizer_config is None:
117
+ self.tokenizer = None
118
+ else:
119
+ self.tokenizer = get_tokenizer(config=tokenizer_config)
120
+
121
+ def prepare_data(self):
122
+ pass
123
+
124
+ def setup(self, stage=None):
125
+ if self.predict_datapath is not None:
126
+ self._predict_ds = {}
127
+ for cls in self.predict_datapath:
128
+ self._predict_ds[cls] = UniRigDataset(
129
+ process_fn=self.process_fn,
130
+ data=self.predict_datapath[cls].get_data(),
131
+ name=f"predict-{cls}",
132
+ tokenizer=self.tokenizer,
133
+ transform_config=self.predict_transform_config,
134
+ debug=self.debug,
135
+ data_name=self.data_name,
136
+ )
137
+
138
+ def predict_dataloader(self):
139
+ if not hasattr(self, "_predict_ds"):
140
+ self.setup()
141
+ return self._create_dataloader(
142
+ dataset=self._predict_ds,
143
+ config=self.predict_dataset_config,
144
+ is_train=False,
145
+ drop_last=False,
146
+ )
147
+
148
+ def _create_dataloader(
149
+ self,
150
+ dataset: Union[Dataset, Dict[str, Dataset]],
151
+ config: DatasetConfig,
152
+ is_train: bool,
153
+ **kwargs,
154
+ ) -> Union[DataLoader, Dict[str, DataLoader]]:
155
+ def create_single_dataloader(dataset, config: Union[DatasetConfig, Dict[str, DatasetConfig]], **kwargs):
156
+ return DataLoader(
157
+ dataset,
158
+ batch_size=config.batch_size,
159
+ shuffle=config.shuffle,
160
+ num_workers=config.num_workers,
161
+ pin_memory=config.pin_memory,
162
+ persistent_workers=config.persistent_workers,
163
+ collate_fn=dataset.collate_fn,
164
+ **kwargs,
165
+ )
166
+ if isinstance(dataset, Dict):
167
+ return {k: create_single_dataloader(v, config[k], **kwargs) for k, v in dataset.items()}
168
+ else:
169
+ return create_single_dataloader(dataset, config, **kwargs)
170
+
171
+ class UniRigDataset(Dataset):
172
+ def __init__(
173
+ self,
174
+ data: List[Tuple[str, str]], # (cls, part)
175
+ name: str,
176
+ process_fn: Union[Callable[[List[ModelInput]], Dict], None]=None,
177
+ tokenizer: Union[TokenizerSpec, None]=None,
178
+ transform_config: Union[TransformConfig, None]=None,
179
+ debug: bool=False,
180
+ data_name: str='raw_data.npz',
181
+ ) -> None:
182
+ super().__init__()
183
+
184
+ self.data = data
185
+ self.name = name
186
+ self.process_fn = process_fn
187
+ self.tokenizer = tokenizer
188
+ self.transform_config = transform_config
189
+ self.debug = debug
190
+ self.data_name = data_name
191
+
192
+ if not debug:
193
+ assert self.process_fn is not None, 'missing data processing function'
194
+
195
+ def __len__(self) -> int:
196
+ return len(self.data)
197
+
198
+ def __getitem__(self, idx) -> ModelInput:
199
+ cls, dir_path = self.data[idx]
200
+ raw_data = RawData.load(path=os.path.join(dir_path, self.data_name))
201
+ asset = Asset.from_raw_data(raw_data=raw_data, cls=cls, path=dir_path, data_name=self.data_name)
202
+
203
+ first_augments, second_augments = transform_asset(
204
+ asset=asset,
205
+ transform_config=self.transform_config,
206
+ )
207
+ if self.tokenizer is not None and asset.parents is not None:
208
+ tokens = self.tokenizer.tokenize(input=asset.get_tokenize_input())
209
+ else:
210
+ tokens = None
211
+ return ModelInput(
212
+ tokens=tokens,
213
+ pad=None if self.tokenizer is None else self.tokenizer.pad,
214
+ vertices=asset.sampled_vertices.astype(np.float32),
215
+ normals=asset.sampled_normals.astype(np.float32),
216
+ joints=None if asset.joints is None else asset.joints.astype(np.float32),
217
+ tails=None if asset.tails is None else asset.tails.astype(np.float32),
218
+ asset=asset,
219
+ augments=None,
220
+ )
221
+
222
+ def _collate_fn_debug(self, batch):
223
+ return batch
224
+
225
+ def _collate_fn(self, batch):
226
+ return data.dataloader.default_collate(self.process_fn(batch))
227
+
228
+ def collate_fn(self, batch):
229
+ if self.debug:
230
+ return self._collate_fn_debug(batch)
231
+ return self._collate_fn(batch)
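A hedged sketch of wiring the module up for prediction, following the `datapath` branch of the constructor; `my_process_fn` and the two config objects stand in for values built elsewhere in the pipeline:

    from src.data.datapath import Datapath            # import paths assumed
    from src.data.dataset import UniRigDatasetModule

    datapath = Datapath(files=['dataset_inference/bird'], cls='inference')
    dm = UniRigDatasetModule(
        process_fn=my_process_fn,                   # List[ModelInput] -> Dict
        predict_transform_config=transform_config,  # a parsed TransformConfig
        tokenizer_config=tokenizer_config,          # or None to skip tokenization
        datapath=datapath,
        cls='inference',
    )
    dm.setup()
    loader = dm.predict_dataloader()['inference']   # one DataLoader per cls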
UniRig/src/data/exporter.py ADDED
@@ -0,0 +1,486 @@
1
+ import numpy as np
2
+ from numpy import ndarray
3
+ from typing import List, Union, Tuple
4
+ from collections import defaultdict
5
+ import os
6
+
7
+ try:
8
+ import open3d as o3d
9
+ OPEN3D_EQUIPPED = True
10
+ except ImportError:
11
+ print("open3d is not available; .ply export is disabled")
12
+ OPEN3D_EQUIPPED = False
13
+
14
+ class Exporter():
15
+
16
+ def _safe_make_dir(self, path):
17
+ if os.path.dirname(path) == '':
18
+ return
19
+ os.makedirs(os.path.dirname(path), exist_ok=True)
20
+
21
+ def _export_skeleton(self, joints: ndarray, parents: List[Union[int, None]], path: str):
22
+ format = path.split('.')[-1]
23
+ assert format in ['obj']
24
+ name = path.removesuffix('.obj')
25
+ path = name + ".obj"
26
+ self._safe_make_dir(path)
27
+ J = joints.shape[0]
28
+ with open(path, 'w') as file:
29
+ file.write("o spring_joint\n")
30
+ _joints = []
31
+ for id in range(J):
32
+ pid = parents[id]
33
+ if pid is None or pid == -1:
34
+ continue
35
+ bx, by, bz = joints[id]
36
+ ex, ey, ez = joints[pid]
37
+ _joints.extend([
38
+ f"v {bx} {bz} {-by}\n",
39
+ f"v {ex} {ez} {-ey}\n",
40
+ f"v {ex} {ez} {-ey + 0.00001}\n"
41
+ ])
42
+ file.writelines(_joints)
43
+
44
+ _faces = [f"f {id*3+1} {id*3+2} {id*3+3}\n" for id in range(len(_joints) // 3)] # one face per emitted bone; the root emits no geometry
45
+ file.writelines(_faces)
46
+
47
+ def _export_bones(self, bones: ndarray, path: str):
48
+ format = path.split('.')[-1]
49
+ assert format in ['obj']
50
+ name = path.removesuffix('.obj')
51
+ path = name + ".obj"
52
+ self._safe_make_dir(path)
53
+ J = bones.shape[0]
54
+ with open(path, 'w') as file:
55
+ file.write("o bones\n")
56
+ _joints = []
57
+ for bone in bones:
58
+ bx, by, bz = bone[:3]
59
+ ex, ey, ez = bone[3:]
60
+ _joints.extend([
61
+ f"v {bx} {bz} {-by}\n",
62
+ f"v {ex} {ez} {-ey}\n",
63
+ f"v {ex} {ez} {-ey + 0.00001}\n"
64
+ ])
65
+ file.writelines(_joints)
66
+
67
+ _faces = [f"f {id*3+1} {id*3+2} {id*3+3}\n" for id in range(J)]
68
+ file.writelines(_faces)
69
+
70
+ def _export_skeleton_sequence(self, joints: ndarray, parents: List[Union[int, None]], path: str):
71
+ format = path.split('.')[-1]
72
+ assert format in ['obj']
73
+ name = path.removesuffix('.obj')
74
+ path = name + ".obj"
75
+ self._safe_make_dir(path)
76
+ J = joints.shape[0]
77
+ for i in range(J):
78
+ file = open(name + f"_{i}.obj", 'w')
79
+ file.write("o spring_joint\n")
80
+ _joints = []
81
+ for id in range(i + 1):
82
+ pid = parents[id]
83
+ if pid is None:
84
+ continue
85
+ bx, by, bz = joints[id]
86
+ ex, ey, ez = joints[pid]
87
+ _joints.extend([
88
+ f"v {bx} {bz} {-by}\n",
89
+ f"v {ex} {ez} {-ey}\n",
90
+ f"v {ex} {ez} {-ey + 0.00001}\n"
91
+ ])
92
+ file.writelines(_joints)
93
+
94
+ _faces = [f"f {id*3+1} {id*3+2} {id*3+3}\n" for id in range(len(_joints) // 3)] # one face per emitted bone
95
+ file.writelines(_faces)
96
+ file.close()
97
+
98
+ def _export_mesh(self, vertices: ndarray, faces: ndarray, path: str):
99
+ format = path.split('.')[-1]
100
+ assert format in ['obj', 'ply']
101
+ if path.endswith('ply'):
102
+ if not OPEN3D_EQUIPPED:
103
+ raise RuntimeError("open3d is not available")
104
+ mesh = o3d.geometry.TriangleMesh()
105
+ mesh.vertices = o3d.utility.Vector3dVector(vertices)
106
+ mesh.triangles = o3d.utility.Vector3iVector(faces)
107
+ self._safe_make_dir(path)
108
+ o3d.io.write_triangle_mesh(path, mesh)
109
+ return
110
+ name = path.removesuffix('.obj')
111
+ path = name + ".obj"
112
+ self._safe_make_dir(path)
113
+ with open(path, 'w') as file:
114
+ file.write("o mesh\n")
115
+ _vertices = []
116
+ for co in vertices:
117
+ _vertices.append(f"v {co[0]} {co[2]} {-co[1]}\n")
118
+ file.writelines(_vertices)
119
+ _faces = []
120
+ for face in faces:
121
+ _faces.append(f"f {face[0]+1} {face[1]+1} {face[2]+1}\n")
122
+ file.writelines(_faces)
123
+
124
+ def _export_pc(self, vertices: ndarray, path: str, vertex_normals: Union[ndarray, None]=None, normal_size: float=0.01):
125
+ if path.endswith('.ply'):
+ if not OPEN3D_EQUIPPED:
+ raise RuntimeError("open3d is not available")
126
+ if vertex_normals is not None:
127
+ print("normal result will not be displayed in .ply format")
128
+ name = path.removesuffix('.ply')
129
+ path = name + ".ply"
130
+ pc = o3d.geometry.PointCloud()
131
+ pc.points = o3d.utility.Vector3dVector(vertices)
132
+ # NOTE: open3d point-cloud export may segfault with numpy >= 2.0; run inside the torch environment
133
+ self._safe_make_dir(path)
134
+ o3d.io.write_point_cloud(path, pc)
135
+ return
136
+ name = path.removesuffix('.obj')
137
+ path = name + ".obj"
138
+ self._safe_make_dir(path)
139
+ with open(path, 'w') as file:
140
+ file.write("o pc\n")
141
+ _vertex = []
142
+ for co in vertices:
143
+ _vertex.append(f"v {co[0]} {co[2]} {-co[1]}\n")
144
+ file.writelines(_vertex)
145
+ if vertex_normals is not None:
146
+ new_path = path.replace('.obj', '_normal.obj')
147
+ nfile = open(new_path, 'w')
148
+ nfile.write("o normal\n")
149
+ _normal = []
150
+ for i in range(vertices.shape[0]):
151
+ co = vertices[i]
152
+ x = vertex_normals[i, 0]
153
+ y = vertex_normals[i, 1]
154
+ z = vertex_normals[i, 2]
155
+ _normal.extend([
156
+ f"v {co[0]} {co[2]} {-co[1]}\n",
157
+ f"v {co[0]+0.0001} {co[2]} {-co[1]}\n",
158
+ f"v {co[0]+x*normal_size} {co[2]+z*normal_size} {-(co[1]+y*normal_size)}\n",
159
+ f"f {i*3+1} {i*3+2} {i*3+3}\n",
160
+ ])
161
+ nfile.writelines(_normal)
162
+
163
+ def _make_armature(
164
+ self,
165
+ vertices: Union[ndarray, None],
166
+ joints: ndarray,
167
+ skin: Union[ndarray, None],
168
+ parents: List[Union[int, None]],
169
+ names: List[str],
170
+ faces: Union[ndarray, None]=None,
171
+ extrude_size: float=0.03,
172
+ group_per_vertex: int=-1,
173
+ add_root: bool=False,
174
+ do_not_normalize: bool=False,
175
+ use_extrude_bone: bool=True,
176
+ use_connect_unique_child: bool=True,
177
+ extrude_from_parent: bool=True,
178
+ tails: Union[ndarray, None]=None,
179
+ ):
180
+ import bpy # type: ignore
181
+ from mathutils import Vector # type: ignore
182
+
183
+ # make collection
184
+ collection = bpy.data.collections.new('new_collection')
185
+ bpy.context.scene.collection.children.link(collection)
186
+
187
+ # make mesh
188
+ if vertices is not None:
189
+ mesh = bpy.data.meshes.new('mesh')
190
+ if faces is None:
191
+ faces = []
192
+ mesh.from_pydata(vertices, [], faces)
193
+ mesh.update()
194
+
195
+ # make object from mesh
196
+ object = bpy.data.objects.new('character', mesh)
197
+
198
+ # add object to scene collection
199
+ collection.objects.link(object)
200
+
201
+ # deselect mesh
202
+ bpy.ops.object.armature_add(enter_editmode=True)
203
+ armature = bpy.data.armatures.get('Armature')
204
+ edit_bones = armature.edit_bones
205
+
206
+ J = joints.shape[0]
207
+ if tails is None:
208
+ tails = joints.copy()
209
+ tails[:, 2] += extrude_size
210
+ connects = [False for _ in range(J)]
211
+ children = defaultdict(list)
212
+ for i in range(1, J):
213
+ children[parents[i]].append(i)
214
+ if tails is not None:
215
+ if use_extrude_bone:
216
+ for i in range(J):
217
+ if len(children[i]) != 1 and extrude_from_parent and i != 0:
218
+ pjoint = joints[parents[i]]
219
+ joint = joints[i]
220
+ d = joint - pjoint
221
+ if np.linalg.norm(d) < 0.000001:
222
+ d = np.array([0., 0., 1.]) # in case son.head == parent.head
223
+ else:
224
+ d = d / np.linalg.norm(d)
225
+ tails[i] = joint + d * extrude_size
226
+ if use_connect_unique_child:
227
+ for i in range(J):
228
+ if len(children[i]) == 1:
229
+ child = children[i][0]
230
+ tails[i] = joints[child]
231
+ if parents[i] is not None and len(children[parents[i]]) == 1:
232
+ connects[i] = True
233
+
234
+ if add_root:
235
+ bone_root = edit_bones.get('Bone')
236
+ bone_root.name = 'Root'
237
+ bone_root.tail = Vector((joints[0, 0], joints[0, 1], joints[0, 2]))
238
+ else:
239
+ bone_root = edit_bones.get('Bone')
240
+ bone_root.name = names[0]
241
+ bone_root.head = Vector((joints[0, 0], joints[0, 1], joints[0, 2]))
242
+ bone_root.tail = Vector((joints[0, 0], joints[0, 1], joints[0, 2] + extrude_size))
243
+
244
+ def extrude_bone(
245
+ edit_bones,
246
+ name: str,
247
+ parent_name: str,
248
+ head: Tuple[float, float, float],
249
+ tail: Tuple[float, float, float],
250
+ connect: bool
251
+ ):
252
+ bone = edit_bones.new(name)
253
+ bone.head = Vector((head[0], head[1], head[2]))
254
+ bone.tail = Vector((tail[0], tail[1], tail[2]))
255
+ bone.name = name
256
+ parent_bone = edit_bones.get(parent_name)
257
+ bone.parent = parent_bone
258
+ bone.use_connect = connect
259
+ assert not np.isnan(head).any(), f"nan found in head of bone {name}"
260
+ assert not np.isnan(tail).any(), f"nan found in tail of bone {name}"
261
+
262
+ for i in range(J):
263
+ if add_root is False and i==0:
264
+ continue
265
+ edit_bones = armature.edit_bones
266
+ pname = 'Root' if parents[i] is None else names[parents[i]]
267
+ extrude_bone(edit_bones, names[i], pname, joints[i], tails[i], connects[i])
268
+ for i in range(J):
269
+ bone = edit_bones.get(names[i])
270
+ bone.head = Vector((joints[i, 0], joints[i, 1], joints[i, 2]))
271
+ bone.tail = Vector((tails[i, 0], tails[i, 1], tails[i, 2]))
272
+
273
+ if vertices is None or skin is None:
274
+ return
275
+ # must set to object mode to enable parent_set
276
+ bpy.ops.object.mode_set(mode='OBJECT')
277
+ objects = bpy.data.objects
278
+ for o in bpy.context.selected_objects:
279
+ o.select_set(False)
280
+ ob = objects['character']
281
+ arm = bpy.data.objects['Armature']
282
+ ob.select_set(True)
283
+ arm.select_set(True)
284
+ bpy.ops.object.parent_set(type='ARMATURE_NAME')
285
+ vis = []
286
+ for x in ob.vertex_groups:
287
+ vis.append(x.name)
288
+ #sparsify
289
+ argsorted = np.argsort(-skin, axis=1)
290
+ vertex_group_reweight = skin[np.arange(skin.shape[0])[..., None], argsorted]
291
+ if group_per_vertex == -1:
292
+ group_per_vertex = vertex_group_reweight.shape[-1]
293
+ if not do_not_normalize:
294
+ vertex_group_reweight = vertex_group_reweight / vertex_group_reweight[..., :group_per_vertex].sum(axis=1)[...,None]
295
+
296
+ for v, w in enumerate(skin):
297
+ for ii in range(group_per_vertex):
298
+ i = argsorted[v, ii]
299
+ if i >= J:
300
+ continue
301
+ n = names[i]
302
+ if n not in vis:
303
+ continue
304
+ ob.vertex_groups[n].add([v], vertex_group_reweight[v, ii], 'REPLACE')
305
+
306
+ def _clean_bpy(self):
307
+ import bpy # type: ignore
308
+ for c in bpy.data.actions:
309
+ bpy.data.actions.remove(c)
310
+ for c in bpy.data.armatures:
311
+ bpy.data.armatures.remove(c)
312
+ for c in bpy.data.cameras:
313
+ bpy.data.cameras.remove(c)
314
+ for c in bpy.data.collections:
315
+ bpy.data.collections.remove(c)
316
+ for c in bpy.data.images:
317
+ bpy.data.images.remove(c)
318
+ for c in bpy.data.materials:
319
+ bpy.data.materials.remove(c)
320
+ for c in bpy.data.meshes:
321
+ bpy.data.meshes.remove(c)
322
+ for c in bpy.data.objects:
323
+ bpy.data.objects.remove(c)
324
+ for c in bpy.data.textures:
325
+ bpy.data.textures.remove(c)
326
+
327
+ def _export_fbx(
328
+ self,
329
+ path: str,
330
+ vertices: Union[ndarray, None],
331
+ joints: ndarray,
332
+ skin: Union[ndarray, None],
333
+ parents: List[Union[int, None]],
334
+ names: List[str],
335
+ faces: Union[ndarray, None]=None,
336
+ extrude_size: float=0.03,
337
+ group_per_vertex: int=-1,
338
+ add_root: bool=False,
339
+ do_not_normalize: bool=False,
340
+ use_extrude_bone: bool=True,
341
+ use_connect_unique_child: bool=True,
342
+ extrude_from_parent: bool=True,
343
+ tails: Union[ndarray, None]=None,
344
+ ):
345
+ '''
346
+ Requires bpy installed
347
+ '''
348
+ import bpy # type: ignore
349
+ self._safe_make_dir(path)
350
+ self._clean_bpy()
351
+ self._make_armature(
352
+ vertices=vertices,
353
+ joints=joints,
354
+ skin=skin,
355
+ parents=parents,
356
+ names=names,
357
+ faces=faces,
358
+ extrude_size=extrude_size,
359
+ group_per_vertex=group_per_vertex,
360
+ add_root=add_root,
361
+ do_not_normalize=do_not_normalize,
362
+ use_extrude_bone=use_extrude_bone,
363
+ use_connect_unique_child=use_connect_unique_child,
364
+ extrude_from_parent=extrude_from_parent,
365
+ tails=tails,
366
+ )
367
+
368
+ # keep add_leaf_bones disabled so the exporter does not append extra leaf bones
369
+ bpy.ops.export_scene.fbx(filepath=path, check_existing=False, add_leaf_bones=False)
370
+
371
+ def _export_render(
372
+ self,
373
+ path: str,
374
+ vertices: Union[ndarray, None],
375
+ faces: Union[ndarray, None],
376
+ bones: Union[ndarray, None],
377
+ resolution: Tuple[int, int]=(256, 256),
378
+ ):
379
+ import bpy # type: ignore
380
+ import bpy_extras # type: ignore
381
+ from mathutils import Vector # type: ignore
382
+
383
+ self._safe_make_dir(path)
384
+ # normalize into [-1, 1]^3
385
+ # copied from augment
386
+ assert (vertices is not None) or (bones is not None)
387
+ bounds = []
388
+ if vertices is not None:
389
+ bounds.append(vertices)
390
+ if bones is not None:
391
+ bounds.append(bones[:, :3])
392
+ bounds.append(bones[:, 3:])
393
+ bounds = np.concatenate(bounds, axis=0)
394
+ bound_min = bounds.min(axis=0)
395
+ bound_max = bounds.max(axis=0)
396
+
397
+ trans_vertex = np.eye(4)
398
+
399
+ trans_vertex = _trans_to_m(-(bound_max + bound_min)/2) @ trans_vertex
400
+
401
+ # scale into the cube [-1, 1]
402
+ scale = np.max((bound_max - bound_min) / 2)
403
+ trans_vertex = _scale_to_m(1. / scale) @ trans_vertex
404
+
405
+ def _apply(v: ndarray, trans: ndarray) -> ndarray:
406
+ return np.matmul(v, trans[:3, :3].transpose()) + trans[:3, 3]
407
+
408
+ if vertices is not None:
409
+ vertices = _apply(vertices, trans_vertex)
410
+ if bones is not None:
411
+ bones[:, :3] = _apply(bones[:, :3], trans_vertex)
412
+ bones[:, 3:] = _apply(bones[:, 3:], trans_vertex)
413
+
414
+ # bpy api calls
415
+ self._clean_bpy()
416
+ bpy.context.scene.render.engine = 'BLENDER_WORKBENCH'
417
+ bpy.context.scene.render.film_transparent = True
418
+ bpy.context.scene.display.shading.background_type = 'VIEWPORT'
419
+
420
+ collection = bpy.data.collections.new('new_collection')
421
+ bpy.context.scene.collection.children.link(collection)
422
+
423
+ if vertices is not None:
424
+ mesh_data = bpy.data.meshes.new(name="MeshData")
425
+ mesh_obj = bpy.data.objects.new(name="MeshObject", object_data=mesh_data)
426
+ collection.objects.link(mesh_obj)
427
+
428
+ mesh_data.from_pydata((vertices).tolist(), [], faces.tolist())
429
+ mesh_data.update()
430
+
431
+ def look_at(camera, point):
432
+ direction = point - camera.location
433
+ rot_quat = direction.to_track_quat('-Z', 'Y')
434
+ camera.rotation_euler = rot_quat.to_euler()
435
+
436
+ bpy.ops.object.camera_add(location=(4, -4, 2.5))
437
+ camera = bpy.context.object
438
+ camera.data.angle = np.radians(25.0)
439
+ look_at(camera, Vector((0, 0, -0.2)))
440
+ bpy.context.scene.camera = camera
441
+
442
+ bpy.context.scene.render.resolution_x = resolution[0]
443
+ bpy.context.scene.render.resolution_y = resolution[1]
444
+ bpy.context.scene.render.image_settings.file_format = 'PNG'
445
+ bpy.context.scene.render.filepath = path
446
+
447
+ bpy.ops.render.render(write_still=True)
448
+ # overlay the bones on the rendered image
449
+ if bones is not None:
450
+ # TODO: do not save image after rendering
451
+ from PIL import Image, ImageDraw
452
+ img_pil = Image.open(path).convert("RGBA")
453
+ draw = ImageDraw.Draw(img_pil)
454
+
455
+ from bpy_extras.image_utils import load_image # type: ignore
456
+ bpy.context.scene.use_nodes = True
457
+ nodes = bpy.context.scene.node_tree.nodes
458
+ # nodes.clear()
459
+
460
+ img = load_image(path)
461
+ image_node = nodes.new(type='CompositorNodeImage')
462
+ image_node.image = img
463
+
464
+ for i, bone in enumerate(bones):
465
+ head, tail = bone[:3], bone[3:]
466
+ head_2d = bpy_extras.object_utils.world_to_camera_view(bpy.context.scene, camera, Vector(head))
467
+ tail_2d = bpy_extras.object_utils.world_to_camera_view(bpy.context.scene, camera, Vector(tail))
468
+
469
+ res_x, res_y = resolution
470
+ head_pix = (head_2d.x * res_x, (1 - head_2d.y) * res_y)
471
+ tail_pix = (tail_2d.x * res_x, (1 - tail_2d.y) * res_y)
472
+ draw.line([head_pix, tail_pix], fill=(255, 0, 0, 255), width=1)
473
+ img_pil.save(path)
474
+
475
+ def _trans_to_m(v: ndarray):
476
+ m = np.eye(4)
477
+ m[0:3, 3] = v
478
+ return m
479
+
480
+ def _scale_to_m(r: float):
481
+ m = np.zeros((4, 4))
482
+ m[0, 0] = r
483
+ m[1, 1] = r
484
+ m[2, 2] = r
485
+ m[3, 3] = 1.
486
+ return m
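`Exporter` is consumed as a mixin by `Asset` above, but the obj writers are self-contained; a minimal standalone sketch (output paths hypothetical; the fbx/render paths additionally need a Python with `bpy` available):

    import numpy as np
    from src.data.exporter import Exporter  # import path assumed

    exporter = Exporter()
    joints = np.array([[0., 0., 0.], [0., 0., 1.], [0., 0., 2.]], dtype=np.float32)
    parents = [None, 0, 1]                  # the root's parent is None
    exporter._export_skeleton(joints=joints, parents=parents, path='tmp/skel.obj')

    vertices = np.array([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]], dtype=np.float32)
    faces = np.array([[0, 1, 2]], dtype=np.int64)
    exporter._export_mesh(vertices=vertices, faces=faces, path='tmp/mesh.obj')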
UniRig/src/data/extract.py ADDED
@@ -0,0 +1,523 @@
1
+ import bpy, os
2
+ from collections import defaultdict
3
+ from tqdm import tqdm
4
+ import numpy as np
5
+ from numpy import ndarray
6
+ from typing import Dict, Tuple, List, Optional, Union
7
+ import trimesh
8
+ import fast_simplification
9
+ from scipy.spatial import KDTree
10
+
11
+ import argparse
12
+ import yaml
13
+ from box import Box
15
+
16
+ from .log import new_entry, add_error, add_warning, new_log, end_log
17
+ from .raw_data import RawData
18
+
19
+ def load(filepath: str):
20
+ old_objs = set(bpy.context.scene.objects)
21
+
22
+ if not os.path.exists(filepath):
23
+ raise ValueError(f'File {filepath} does not exist !')
24
+
25
+ try:
26
+ if filepath.endswith(".vrm"):
27
+ # enable vrm addon and load vrm model
28
+ bpy.ops.preferences.addon_enable(module='vrm')
29
+
30
+ bpy.ops.import_scene.vrm(
31
+ filepath=filepath,
32
+ use_addon_preferences=True,
33
+ extract_textures_into_folder=False,
34
+ make_new_texture_folder=False,
35
+ set_shading_type_to_material_on_import=False,
36
+ set_view_transform_to_standard_on_import=True,
37
+ set_armature_display_to_wire=True,
38
+ set_armature_display_to_show_in_front=True,
39
+ set_armature_bone_shape_to_default=True,
40
+ disable_bake=True, # customized option for better performance
41
+ )
42
+ elif filepath.endswith(".obj"):
43
+ bpy.ops.wm.obj_import(filepath=filepath)
44
+ elif filepath.endswith(".fbx") or filepath.endswith(".FBX"):
45
+ # end bone is removed using remove_dummy_bone
46
+ bpy.ops.import_scene.fbx(filepath=filepath, ignore_leaf_bones=False, use_image_search=False)
47
+ elif filepath.endswith(".glb") or filepath.endswith(".gltf"):
48
+ bpy.ops.import_scene.gltf(filepath=filepath, import_pack_images=False)
49
+ elif filepath.endswith(".dae"):
50
+ bpy.ops.wm.collada_import(filepath=filepath)
51
+ elif filepath.endswith(".blend"):
52
+ with bpy.data.libraries.load(filepath) as (data_from, data_to):
53
+ data_to.objects = data_from.objects
54
+ for obj in data_to.objects:
55
+ if obj is not None:
56
+ bpy.context.collection.objects.link(obj)
57
+ else:
58
+ raise ValueError(f"not suported type {filepath}")
59
+ except Exception as e:
60
+ raise ValueError(f"failed to load {filepath}: {e}")
61
+
62
+ armature = [x for x in set(bpy.context.scene.objects)-old_objs if x.type=="ARMATURE"]
63
+ if len(armature)==0:
64
+ return None
65
+ if len(armature)>1:
66
+ raise ValueError(f"multiple armatures found")
67
+ armature = armature[0]
68
+
69
+ armature.select_set(True)
70
+ bpy.context.view_layer.objects.active = armature
71
+ bpy.ops.object.mode_set(mode='EDIT')
72
+ for bone in armature.data.edit_bones:
73
+ bone.roll = 0. # zero all bone rolls to avoid orientation-dependent behaviour
74
+
75
+ bpy.ops.object.mode_set(mode='OBJECT')
76
+ armature.select_set(False)
77
+
78
+ bpy.ops.object.select_all(action='DESELECT')
79
+ return armature
80
+
81
+ # remove all data in bpy
82
+ def clean_bpy():
83
+ # First try to purge orphan data
84
+ try:
85
+ bpy.ops.outliner.orphans_purge(do_local_ids=True, do_linked_ids=True, do_recursive=True)
86
+ except Exception as e:
87
+ print(f"Warning: Could not purge orphans: {e}")
88
+
89
+ # Then remove all data by type
90
+ data_types = [
91
+ bpy.data.actions,
92
+ bpy.data.armatures,
93
+ bpy.data.cameras,
94
+ bpy.data.collections,
95
+ bpy.data.curves,
96
+ bpy.data.images,
97
+ bpy.data.lights,
98
+ bpy.data.materials,
99
+ bpy.data.meshes,
100
+ bpy.data.objects,
101
+ bpy.data.textures,
102
+ bpy.data.worlds,
103
+ bpy.data.node_groups
104
+ ]
105
+
106
+ for data_collection in data_types:
107
+ try:
108
+ for item in data_collection:
109
+ try:
110
+ data_collection.remove(item)
111
+ except Exception as e:
112
+ print(f"Warning: Could not remove {item.name} from {data_collection}: {e}")
113
+ except Exception as e:
114
+ print(f"Warning: Error processing {data_collection}: {e}")
115
+
116
+ # Force garbage collection to free memory
117
+ import gc
118
+ gc.collect()
119
+
120
+ def get_arranged_bones(armature):
121
+ matrix_world = armature.matrix_world
122
+ arranged_bones = []
123
+ root = armature.pose.bones[0]
124
+ while root.parent is not None:
125
+ root = root.parent
126
+ Q = [root]
127
+ rot = np.array(matrix_world)[:3, :3]
128
+
129
+ # dfs and sort
130
+ while len(Q) != 0:
131
+ b = Q.pop(0)
132
+ arranged_bones.append(b)
133
+ children = []
134
+ for cb in b.children:
135
+ head = rot @ np.array(b.head)
136
+ children.append((cb, head[0], head[1], head[2]))
137
+ children = sorted(children, key=lambda x: (x[3], x[1], x[2]))
138
+ _c = [x[0] for x in children]
139
+ Q = _c + Q
140
+ return arranged_bones
141
+
142
+ def process_mesh():
143
+ meshes = []
144
+ for v in bpy.data.objects:
145
+ if v.type == 'MESH':
146
+ meshes.append(v)
147
+
148
+ _dict_mesh = {}
149
+ for obj in meshes:
150
+ m = np.array(obj.matrix_world)
151
+ matrix_world_rot = m[:3, :3]
152
+ matrix_world_bias = m[:3, 3]
153
+ rot = matrix_world_rot
154
+ total_vertices = len(obj.data.vertices)
155
+ vertex = np.zeros((4, total_vertices))
156
+ vertex_normal = np.zeros((total_vertices, 3))
157
+ obj_verts = obj.data.vertices
158
+ faces = []
159
+ normals = []
160
+
161
+ for v in obj_verts:
162
+ vertex_normal[v.index] = rot @ np.array(v.normal) # rotate normals into world space; do not translate them
163
+ vv = rot @ v.co
164
+ vv = np.array(vv) + matrix_world_bias
165
+ vertex[0:3, v.index] = vv
166
+ vertex[3][v.index] = 1 # homogeneous coordinate
167
+
168
+ for polygon in obj.data.polygons:
169
+ edges = polygon.edge_keys
170
+ nodes = []
171
+ adj = {}
172
+ for edge in edges:
173
+ if adj.get(edge[0]) is None:
174
+ adj[edge[0]] = []
175
+ adj[edge[0]].append(edge[1])
176
+ if adj.get(edge[1]) is None:
177
+ adj[edge[1]] = []
178
+ adj[edge[1]].append(edge[0])
179
+ nodes.append(edge[0])
180
+ nodes.append(edge[1])
181
+ normal = polygon.normal
182
+ nodes = sorted(set(nodes))
183
+ first = nodes[0]
184
+ loop = []
185
+ now = first
186
+ vis = {}
187
+ while True:
188
+ loop.append(now)
189
+ vis[now] = True
190
+ if vis.get(adj[now][0]) is None:
191
+ now = adj[now][0]
192
+ elif vis.get(adj[now][1]) is None:
193
+ now = adj[now][1]
194
+ else:
195
+ break
196
+ for (second, third) in zip(loop[1:], loop[2:]):
197
+ faces.append((first + 1, second + 1, third + 1)) # 1-based indices, shifted back to 0-based before saving
198
+ normals.append(rot @ normal) # Blender face normals are object-space; rotate into world space
199
+
200
+ correct_faces = []
201
+ for (i, face) in enumerate(faces):
202
+ normal = normals[i]
203
+ v0 = face[0] - 1
204
+ v1 = face[1] - 1
205
+ v2 = face[2] - 1
206
+ v = np.cross(
207
+ vertex[:3, v1] - vertex[:3, v0],
208
+ vertex[:3, v2] - vertex[:3, v0],
209
+ )
210
+ if (v*normal).sum() > 0:
211
+ correct_faces.append(face)
212
+ else:
213
+ correct_faces.append((face[0], face[2], face[1]))
214
+ if len(correct_faces) > 0:
215
+ _dict_mesh[obj.name] = {
216
+ 'vertex': vertex,
217
+ 'face': correct_faces,
218
+ }
219
+
220
+ vertex = np.concatenate([_dict_mesh[name]['vertex'] for name in _dict_mesh], axis=1)[:3, :].transpose()
221
+
222
+ total_faces = 0
223
+ now_bias = 0
224
+ for name in _dict_mesh:
225
+ total_faces += len(_dict_mesh[name]['face'])
226
+ faces = np.zeros((total_faces, 3), dtype=np.int64)
227
+ tot = 0
228
+ for name in _dict_mesh:
229
+ f = np.array(_dict_mesh[name]['face'], dtype=np.int64)
230
+ faces[tot:tot+f.shape[0]] = f + now_bias
231
+ now_bias += _dict_mesh[name]['vertex'].shape[1]
232
+ tot += f.shape[0]
233
+
234
+ return vertex, faces
235
+
236
+ def process_armature(
237
+ armature,
238
+ arranged_bones,
239
+ ) -> Tuple[np.ndarray, np.ndarray]:
240
+ matrix_world = armature.matrix_world
241
+ index = {}
242
+
243
+ for (id, pbone) in enumerate(arranged_bones):
244
+ index[pbone.name] = id
245
+
246
+ root = armature.pose.bones[0]
247
+ while root.parent is not None:
248
+ root = root.parent
249
+ m = np.array(matrix_world.to_4x4())
250
+ scale_inv = np.linalg.inv(np.diag(matrix_world.to_scale()))
251
+ rot = m[:3, :3]
252
+ bias = m[:3, 3]
253
+
254
+ s = []
255
+ bpy.ops.object.editmode_toggle()
256
+ edit_bones = armature.data.edit_bones
257
+
258
+ J = len(arranged_bones)
259
+ joints = np.zeros((J, 3), dtype=np.float32)
260
+ tails = np.zeros((J, 3), dtype=np.float32)
261
+ parents = []
262
+ name_to_id = {}
263
+ names = []
264
+ matrix_local_stack = np.zeros((J, 4, 4), dtype=np.float32)
265
+ for (id, pbone) in enumerate(arranged_bones):
266
+ name = pbone.name
267
+ names.append(name)
268
+ matrix_local = np.array(pbone.bone.matrix_local)
269
+ use_inherit_rotation = pbone.bone.use_inherit_rotation
270
+ if use_inherit_rotation == False:
271
+ add_warning(f"use_inherit_rotation of bone {name} is False !")
272
+ head = rot @ matrix_local[0:3, 3] + bias
273
+ s.append(head)
274
+ edit_bone = edit_bones.get(name)
275
+ tail = rot @ np.array(edit_bone.tail) + bias
276
+
277
+ name_to_id[name] = id
278
+ joints[id] = head
279
+ tails[id] = tail
280
+ parents.append(None if pbone.parent not in arranged_bones else name_to_id[pbone.parent.name])
281
+ # remove scale part
282
+ matrix_local[:, 3:4] = m @ matrix_local[:, 3:4]
283
+ matrix_local[:3, :3] = scale_inv @ matrix_local[:3, :3]
284
+ matrix_local_stack[id] = matrix_local
285
+ bpy.ops.object.editmode_toggle()
286
+
287
+ return joints, tails, parents, names, matrix_local_stack
288
+
289
+ def save_raw_data(
290
+ path: str,
291
+ vertices: ndarray,
292
+ faces: ndarray,
293
+ joints: Union[ndarray, None],
294
+ tails: Union[ndarray, None],
295
+ parents: Union[List[Union[int, None]], None],
296
+ names: Union[List[str], None],
297
+ matrix_local: Union[ndarray, None],
298
+ target_count: int,
299
+ ):
300
+ mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
301
+ vertices = np.array(mesh.vertices, dtype=np.float32)
302
+ faces = np.array(mesh.faces, dtype=np.int64)
303
+ if faces.shape[0] > target_count:
304
+ vertices, faces = fast_simplification.simplify(vertices, faces, target_count=target_count)
305
+ mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
306
+
307
+ new_vertices = np.array(mesh.vertices, dtype=np.float32)
308
+ new_vertex_normals = np.array(mesh.vertex_normals, dtype=np.float32)
309
+ new_faces = np.array(mesh.faces, dtype=np.int64)
310
+ new_face_normals = np.array(mesh.face_normals, dtype=np.float32)
311
+ if joints is not None:
312
+ new_joints = np.array(joints, dtype=np.float32)
313
+ else:
314
+ new_joints = None
315
+ raw_data = RawData(
316
+ vertices=new_vertices,
317
+ vertex_normals=new_vertex_normals,
318
+ faces=new_faces,
319
+ face_normals=new_face_normals,
320
+ joints=new_joints,
321
+ tails=tails,
322
+ skin=None,
323
+ no_skin=None,
324
+ parents=parents,
325
+ names=names,
326
+ matrix_local=matrix_local,
327
+ )
328
+ raw_data.check()
329
+ raw_data.save(path=path)
330
+
331
+ def extract_builtin(
332
+ output_folder: str,
333
+ target_count: int,
334
+ num_runs: int,
335
+ id: int,
336
+ time: str,
337
+ files: List[Union[str, str]],
338
+ ):
339
+ log_path = "./logs"
340
+ log_path = os.path.join(log_path, time)
341
+
342
+ num_files = len(files)
343
+ gap = num_files // num_runs
344
+ start = gap * id
345
+ end = gap * (id + 1)
346
+ if id+1==num_runs:
347
+ end = num_files
348
+
349
+ files = sorted(files)
350
+ if end!=-1:
351
+ files = files[:end]
352
+ new_log(log_path, f"extract_builtin_{start}_{end}")
353
+ tot = 0
354
+ for file in tqdm(files[start:]):
355
+ input_file = file[0]
356
+ output_dir = file[1]
357
+ clean_bpy()
358
+ new_entry(input_file)
359
+ try:
360
+ print(f"Now processing {input_file}...")
361
+
362
+ armature = load(input_file)
363
+
364
+ print('save to:', output_dir)
365
+ os.makedirs(output_dir, exist_ok=True)
366
+
367
+ vertices, faces = process_mesh()
368
+ if armature is not None:
369
+ arranged_bones = get_arranged_bones(armature)
370
+ joints, tails, parents, names, matrix_local = process_armature(armature, arranged_bones)
371
+
372
+ else:
373
+ joints = None
374
+ tails = None
375
+ parents = None
376
+ names = None
377
+ matrix_local = None
378
+
379
+ save_file = os.path.join(output_dir, 'raw_data.npz')
380
+ save_raw_data(
381
+ path=save_file,
382
+ vertices=vertices,
383
+ faces=faces-1,
384
+ joints=joints,
385
+ tails=tails,
386
+ parents=parents,
387
+ names=names,
388
+ matrix_local=matrix_local,
389
+ target_count=target_count,
390
+ )
391
+
392
+ tot += 1
393
+
394
+ except ValueError as e:
395
+ add_error(str(e))
396
+ print(f"ValueError: {str(e)}")
397
+ except RuntimeError as e:
398
+ add_error(str(e))
399
+ print(f"RuntimeError: {str(e)}")
400
+ except TimeoutError as e:
401
+ add_error("time out")
402
+ print("TimeoutError: Processing timed out")
403
+ except Exception as e:
404
+ add_error(f"Unexpected error: {str(e)}")
405
+ print(f"Unexpected error: {str(e)}")
406
+ end_log()
407
+ print(f"{tot} models processed")
408
+
409
+ def str2bool(v):
410
+ if isinstance(v, bool):
411
+ return v
412
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
413
+ return True
414
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
415
+ return False
416
+ else:
417
+ raise argparse.ArgumentTypeError('Boolean value expected.')
418
+
419
+ def nullable_string(val):
420
+ if not val:
421
+ return None
422
+ return val
423
+
424
+ def get_files(
425
+ data_name: str,
426
+ input_dataset_dir: str,
427
+ output_dataset_dir: str,
428
+ inputs: Union[str, None]=None,
429
+ require_suffix: List[str]=['obj','fbx','FBX','dae','glb','gltf','vrm'],
430
+ force_override: bool=False,
431
+ warning: bool=True,
432
+ ) -> List[Tuple[str, str]]:
433
+
434
+ files = [] # (input_file, output_dir)
435
+ if inputs is not None: # specified input file(s)
436
+ vis = {}
437
+ inputs = inputs.split(',')
438
+ for file in inputs:
439
+ file_name = file.removeprefix("./")
440
+ # remove suffix
441
+ file_name = '.'.join(file_name.split('.')[:-1])
442
+ output_dir = os.path.join(output_dataset_dir, file_name)
443
+ raw_data_npz = os.path.join(output_dir, data_name)
444
+ if not force_override and os.path.exists(raw_data_npz):
445
+ continue
446
+ if warning and output_dir in vis:
447
+ print(f"\033[33mWARNING: duplicate output directory: {output_dir}, you need to rename prefix of files to avoid ambiguity\033[0m")
448
+ vis[output_dir] = True
449
+ files.append((file, output_dir))
450
+ else:
451
+ vis = {}
452
+ for root, dirs, f in os.walk(input_dataset_dir):
453
+ for file in f:
454
+ if file.split('.')[-1] in require_suffix:
455
+ file_name = file.removeprefix("./")
456
+ # remove suffix
457
+ file_name = '.'.join(file_name.split('.')[:-1])
458
+
459
+ output_dir = os.path.join(output_dataset_dir, os.path.relpath(root, input_dataset_dir), file_name)
460
+ raw_data_npz = os.path.join(output_dir, data_name)
461
+
462
+ # Check if all required files exist
463
+ if not force_override and os.path.exists(raw_data_npz):
464
+ continue
465
+ if warning and output_dir in vis:
466
+ print(f"\033[33mWARNING: duplicate output directory: {output_dir}, you need to rename prefix of files to avoid ambiguity\033[0m")
467
+ vis[output_dir] = True
468
+ files.append((os.path.join(root, file), output_dir))
469
+
470
+ return files
471
+
472
+ def parse():
473
+ parser = argparse.ArgumentParser()
474
+ parser.add_argument('--config', type=str, required=True)
475
+ parser.add_argument('--require_suffix', type=str, required=True)
476
+ parser.add_argument('--faces_target_count', type=int, required=True)
477
+ parser.add_argument('--num_runs', type=int, required=True)
478
+ parser.add_argument('--force_override', type=str2bool, required=True)
479
+ parser.add_argument('--id', type=int, required=True)
480
+ parser.add_argument('--time', type=str, required=True)
481
+
482
+ parser.add_argument('--input', type=nullable_string, required=False, default=None)
483
+ parser.add_argument('--input_dir', type=nullable_string, required=False, default=None)
484
+ parser.add_argument('--output_dir', type=nullable_string, required=False, default=None)
485
+ return parser.parse_args()
486
+
487
+ if __name__ == "__main__":
488
+ args = parse()
489
+
490
+ config = Box(yaml.safe_load(open(args.config, "r")))
491
+
492
+ num_runs = args.num_runs
493
+ id = args.id
494
+ timestamp = args.time
495
+ require_suffix = args.require_suffix.split(',')
496
+ force_override = args.force_override
497
+ target_count = args.faces_target_count
498
+
499
+ if args.input_dir:
500
+ config.input_dataset_dir = args.input_dir
501
+ if args.output_dir:
502
+ config.output_dataset_dir = args.output_dir
503
+
504
+ assert config.input_dataset_dir is not None or args.input is not None, 'either input or input_dir must be specified'
505
+
506
+ files = get_files(
507
+ data_name='raw_data.npz',
508
+ inputs=args.input,
509
+ input_dataset_dir=config.input_dataset_dir,
510
+ output_dataset_dir=config.output_dataset_dir,
511
+ require_suffix=require_suffix,
512
+ force_override=force_override,
513
+ warning=True,
514
+ )
515
+
516
+ extract_builtin(
517
+ output_folder=config.output_dataset_dir,
518
+ target_count=target_count,
519
+ num_runs=num_runs,
520
+ id=id,
521
+ time=timestamp,
522
+ files=files,
523
+ )
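Besides this CLI entry point, the two helpers compose directly from Python; a hedged sketch (paths and the face budget are illustrative, and the module must run inside a Python that can `import bpy`):

    from src.data.extract import get_files, extract_builtin  # import path assumed

    files = get_files(
        data_name='raw_data.npz',
        input_dataset_dir='examples',
        output_dataset_dir='dataset_inference',
        require_suffix=['glb', 'fbx'],
        force_override=False,
    )
    extract_builtin(
        output_folder='dataset_inference',
        target_count=50000,      # mesh-simplification face budget
        num_runs=1,              # a single worker processes the whole list
        id=0,
        time='20240101_000000',  # only used to name the log directory
        files=files,
    )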
UniRig/src/data/log.py ADDED
@@ -0,0 +1,50 @@
1
+ import os
2
+ from typing import List
3
+
4
+ login_time = ''
5
+ log_filepath = ''
6
+
7
+ class Entry:
8
+ def __init__(self, entry_name):
9
+ self.entry = entry_name
10
+ self.error = None
11
+ self.warning = []
12
+
13
+ def have_error(self):
14
+ return self.error is not None
15
+
16
+ def have_warning(self):
17
+ return len(self.warning) != 0
18
+
19
+ logs: List[Entry] = []
20
+
21
+ def new_log(path, log_name):
22
+ global login_time, log_filepath
23
+ log_filepath = os.path.join(path, f"{log_name}.txt")
24
+ os.makedirs(path, exist_ok=True)
25
+ with open(log_filepath, 'a') as file:
26
+ file.write(f"Log: {log_name}\n")
27
+
28
+ def end_log():
29
+ global log_filepath
30
+ with open(log_filepath, 'a') as file:
31
+ file.write(f"End of file\n")
32
+
33
+ def new_entry(entry_name):
34
+ global log_filepath
35
+ print(f"\033[32mNow processing {entry_name}...\033[0m")
36
+ logs.append(Entry(entry_name))
37
+
38
+ def add_error(error):
39
+ global log_filepath
40
+ print(f"\033[31mError found when processing {logs[-1].entry}: {error}\033[0m")
41
+ logs[-1].error = error
42
+ with open(log_filepath, 'a') as file:
43
+ file.write(f"Entry: {logs[-1].entry}, Error: {error}\n")
44
+
45
+ def add_warning(warning):
46
+ global log_filepath
47
+ print(f"\033[33mWarning found when processing {logs[-1].entry}: {warning}\033[0m")
48
+ logs[-1].warning.append(warning)
49
+ with open(log_filepath, 'a') as file:
50
+ file.write(f"Entry: {logs[-1].entry}, Warning: {warning}\n")
UniRig/src/data/order.py ADDED
@@ -0,0 +1,112 @@
1
+ from typing import Dict, List, Tuple, Union
2
+ from collections import defaultdict
3
+ from dataclasses import dataclass
4
+ import yaml
5
+ from box import Box
6
+
7
+ from .spec import ConfigSpec
8
+
9
+ @dataclass
10
+ class OrderConfig(ConfigSpec):
11
+ '''
12
+ Config to handle bones re-ordering.
13
+ '''
14
+
15
+ # {skeleton_name: path}
16
+ skeleton_path: Dict[str, str]
17
+
18
+ # {cls: {part_name: [bone_name_1, bone_name_2, ...]}}
19
+ parts: Dict[str, Dict[str, List[str]]]
20
+
21
+ # {cls: parts of bones to be arranged in [part_name_1, part_name_2, ...]}
22
+ parts_order: Dict[str, List[str]]
23
+
24
+ @classmethod
25
+ def parse(cls, config):
26
+ cls.check_keys(config)
27
+ skeleton_path = config.skeleton_path
28
+ parts = {}
29
+ parts_order = {}
30
+ for (skeleton_cls, path) in skeleton_path.items():
31
+ assert skeleton_cls not in parts, 'skeleton cls conflicts'
32
+ d = Box(yaml.safe_load(open(path, 'r')))
33
+ parts[skeleton_cls] = d.parts
34
+ parts_order[skeleton_cls] = d.parts_order
35
+ return OrderConfig(
36
+ skeleton_path=skeleton_path,
37
+ parts=parts,
38
+ parts_order=parts_order,
39
+ )
40
+
41
+ class Order():
42
+
43
+ # {cls: {part_name: [bone_name_1, bone_name_2, ...]}}
44
+ parts: Dict[str, Dict[str, List[str]]]
45
+
46
+ # {cls: parts of bones to be arranged in [part_name_1, part_name_2, ...]}
47
+ parts_order: Dict[str, List[str]]
48
+
49
+ def __init__(self, config: OrderConfig):
50
+ self.parts = config.parts
51
+ self.parts_order = config.parts_order
52
+
53
+ def part_exists(self, cls: str, part: str, names: List[str]) -> bool:
54
+ '''
55
+ Check if part exists.
56
+ '''
57
+ if part not in self.parts[cls]:
58
+ return False
59
+ for name in self.parts[cls][part]:
60
+ if name not in names:
61
+ return False
62
+ return True
63
+
64
+ def make_names(self, cls: Union[str, None], parts: List[Union[str, None]], num_bones: int) -> List[str]:
65
+ '''
66
+ Get names for specified cls.
67
+ '''
68
+ names = []
69
+ for part in parts:
70
+ if part is None: # spring
71
+ continue
72
+ if cls in self.parts and part in self.parts[cls]:
73
+ names.extend(self.parts[cls][part])
74
+ assert len(names) <= num_bones, "the required skeleton contains more bones than the model provides"
75
+ for i in range(len(names), num_bones):
76
+ names.append(f"bone_{i}")
77
+ return names
78
+
79
+ def arrange_names(self, cls: str, names: List[str], parents: List[Union[int, None]]) -> Tuple[List[str], Dict[int, Union[str, None]]]:
80
+ '''
81
+ Arrange names according to required parts order.
82
+ '''
83
+ if cls not in self.parts_order:
84
+ return names, {0: None} # add a spring token
85
+ vis = defaultdict(bool)
86
+ name_to_id = {name: i for (i, name) in enumerate(names)}
87
+ new_names = []
88
+ parts_bias = {}
89
+ for part in self.parts_order[cls]:
90
+ if self.part_exists(cls=cls, part=part, names=names):
91
+ for name in self.parts[cls][part]:
92
+ vis[name] = True
93
+ flag = False
94
+ for name in self.parts[cls][part]:
95
+ pid = parents[name_to_id[name]]
96
+ if pid is None:
97
+ continue
98
+ if not vis[names[pid]]:
99
+ flag = True
100
+ break
101
+ if flag: # incorrect parts order and should immediately add a spring token
102
+ break
103
+ parts_bias[len(new_names)] = part
104
+ new_names.extend(self.parts[cls][part])
105
+ parts_bias[len(new_names)] = None # add a spring token
106
+ for name in names:
107
+ if name not in new_names:
108
+ new_names.append(name)
109
+ return new_names, parts_bias
110
+
111
+ def get_order(config: OrderConfig) -> Order:
112
+ return Order(config=config)
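A small sketch (illustrative only; bone and part names are made up) of how arrange_names groups recognized parts first and terminates them with a spring token:

order = Order(OrderConfig(
    skeleton_path={},                              # unused when parts are supplied directly
    parts={'vroid': {'body': ['hips', 'spine']}},
    parts_order={'vroid': ['body']},
))
new_names, parts_bias = order.arrange_names(
    cls='vroid',
    names=['hips', 'spine', 'extra_bone'],
    parents=[None, 0, 1],
)
# new_names  -> ['hips', 'spine', 'extra_bone']
# parts_bias -> {0: 'body', 2: None}   (None marks the spring-token position)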
UniRig/src/data/raw_data.py ADDED
@@ -0,0 +1,307 @@
1
+ from dataclasses import dataclass
2
+ import numpy as np
3
+ from numpy import ndarray
4
+
5
+ import os
6
+ from typing import Union, List, Tuple
7
+
8
+ from .exporter import Exporter
9
+
10
+ from ..tokenizer.spec import DetokenzeOutput
11
+ from .order import Order
12
+
13
+ @dataclass(frozen=True)
14
+ class RawData(Exporter):
15
+ '''
16
+ Dataclass to handle data from processed model files.
17
+ '''
18
+
19
+ # vertices of the mesh, shape (N, 3), float32
20
+ vertices: Union[ndarray, None]
21
+
22
+ # normals of vertices, shape (N, 3), float32
23
+ vertex_normals: Union[ndarray, None]
24
+
25
+ # faces of mesh, shape (F, 3), face id starts from 0 to F-1, int64
26
+ faces: Union[ndarray, None]
27
+
28
+ # face normal of mesh, shape (F, 3), float32
29
+ face_normals: Union[ndarray, None]
30
+
31
+ # joints of bones, shape (J, 3), float32
32
+ joints: Union[ndarray, None]
33
+
34
+ # tails of joints, shape (J, 3), float32
35
+ tails: Union[ndarray, None]
36
+
37
+ # skinning of joints, shape (N, J), float32
38
+ skin: Union[ndarray, None]
39
+
40
+ # whether each joint has no skin weights (True = no skin), bool
41
+ no_skin: Union[ndarray, None]
42
+
43
+ # parents of joints, None represents no parent (a root joint)
44
+ # make sure parent[k] < k
45
+ parents: Union[List[Union[int, None]], None]
46
+
47
+ # names of joints
48
+ names: Union[List[str], None]
49
+
50
+ # local coordinate
51
+ matrix_local: Union[ndarray, None]
52
+
53
+ # path to data
54
+ path: Union[str, None]=None
55
+
56
+ # data cls
57
+ cls: Union[str, None]=None
58
+
59
+ @staticmethod
60
+ def load(path: str) -> 'RawData':
61
+ data = np.load(path, allow_pickle=True)
62
+ d = {name: data[name][()] for name in data}
63
+ d['path'] = path
64
+ return RawData(**d)
65
+
66
+ def save(self, path: str):
67
+ os.makedirs(os.path.dirname(path), exist_ok=True)
68
+ np.savez(file=path, **self.__dict__)
69
+
70
+ @property
71
+ def N(self):
72
+ '''
73
+ number of vertices
74
+ '''
75
+ return self.vertices.shape[0]
76
+
77
+ @property
78
+ def F(self):
79
+ '''
80
+ number of faces
81
+ '''
82
+ return self.faces.shape[0]
83
+
84
+ @property
85
+ def J(self):
86
+ '''
87
+ number of joints
88
+ '''
89
+ return self.joints.shape[0]
90
+
91
+ def check(self):
92
+ if self.names is not None and self.joints is not None:
93
+ assert len(self.names) == self.J
94
+ if self.names is not None and self.parents is not None:
95
+ assert len(self.names) == len(self.parents)
96
+ if self.parents is not None:
97
+ for (i, pid) in enumerate(self.parents):
98
+ if i==0:
99
+ assert pid is None
100
+ else:
101
+ assert pid is not None
102
+ assert pid < i
103
+
104
+ def export_pc(self, path: str, with_normal: bool=True, normal_size=0.01):
105
+ '''
106
+ export point cloud
107
+ '''
108
+ if with_normal:
109
+ self._export_pc(vertices=self.vertices, path=path, vertex_normals=self.vertex_normals, normal_size=normal_size)
110
+ else:
111
+ self._export_pc(vertices=self.vertices, path=path, vertex_normals=None, normal_size=normal_size)
112
+
113
+ def export_mesh(self, path: str):
114
+ '''
115
+ export mesh
116
+ '''
117
+ self._export_mesh(vertices=self.vertices, faces=self.faces, path=path)
118
+
119
+ def export_skeleton(self, path: str):
120
+ '''
121
+ export skeleton
122
+ '''
123
+ self._export_skeleton(joints=self.joints, parents=self.parents, path=path)
124
+
125
+ def export_skeleton_sequence(self, path: str):
126
+ '''
127
+ export skeleton sequence
128
+ '''
129
+ self._export_skeleton_sequence(joints=self.joints, parents=self.parents, path=path)
130
+
131
+ def export_fbx(
132
+ self,
133
+ path: str,
134
+ extrude_size: float=0.03,
135
+ group_per_vertex: int=-1,
136
+ add_root: bool=False,
137
+ do_not_normalize: bool=False,
138
+ use_extrude_bone: bool=True,
139
+ use_connect_unique_child: bool=True,
140
+ extrude_from_parent: bool=True,
141
+ use_tail: bool=False,
142
+ custom_vertex_group: Union[ndarray, None]=None,
143
+ ):
144
+ '''
145
+ export the whole model with skinning
146
+ '''
147
+ self._export_fbx(
148
+ path=path,
149
+ vertices=self.vertices,
150
+ joints=self.joints,
151
+ skin=self.skin if custom_vertex_group is None else custom_vertex_group,
152
+ parents=self.parents,
153
+ names=self.names,
154
+ faces=self.faces,
155
+ extrude_size=extrude_size,
156
+ group_per_vertex=group_per_vertex,
157
+ add_root=add_root,
158
+ do_not_normalize=do_not_normalize,
159
+ use_extrude_bone=use_extrude_bone,
160
+ use_connect_unique_child=use_connect_unique_child,
161
+ extrude_from_parent=extrude_from_parent,
162
+ tails=self.tails if use_tail else None,
163
+ )
164
+
165
+ def export_render(self, path: str, resolution: Tuple[int, int]=(256, 256)):
166
+ self._export_render(
167
+ path=path,
168
+ vertices=self.vertices,
169
+ faces=self.faces,
170
+ bones=np.concatenate([self.joints, self.tails], axis=-1),
171
+ resolution=resolution,
172
+ )
173
+
174
+ @dataclass(frozen=True)
175
+ class RawSkeleton(Exporter):
176
+ '''
177
+ Dataclass to handle skeleton from AR.
178
+ '''
179
+ # joints of bones, shape (J, 3), float32
180
+ joints: Union[ndarray, None]
181
+
182
+ # tails of joints, shape (J, 3), float32
183
+ tails: Union[ndarray, None]
184
+
185
+ # whether each joint has no skin weights (True = no skin), bool
186
+ no_skin: Union[ndarray, None]
187
+
188
+ # parents of joints, None represents no parent (a root joint)
189
+ # make sure parent[k] < k
190
+ parents: Union[List[Union[int, None]], None]
191
+
192
+ # names of joints
193
+ names: Union[List[str], None]
194
+
195
+ @staticmethod
196
+ def load(path: str) -> 'RawSkeleton':
197
+ data = np.load(path, allow_pickle=True)
198
+ return RawSkeleton(**{name: data[name][()] for name in data})
199
+
200
+ def save(self, path: str):
201
+ os.makedirs(os.path.dirname(path), exist_ok=True)
202
+ np.savez(file=path, **self.__dict__)
203
+
204
+ @staticmethod
205
+ def from_detokenize_output(res: DetokenzeOutput, order: Union[Order, None]) -> 'RawSkeleton':
206
+ J = len(res.bones)
207
+ names = order.make_names(cls=res.cls, parts=res.parts, num_bones=J)
208
+ joints = res.joints
209
+ p_joints = res.p_joints
210
+ parents = []
211
+ for (i, joint) in enumerate(joints):
212
+ if i == 0:
213
+ parents.append(None)
214
+ continue
215
+ p_joint = p_joints[i]
216
+ dis = 999999
217
+ pid = None
218
+ for j in reversed(range(i)):
219
+ n_dis = ((joints[j] - p_joint)**2).sum()
220
+ if n_dis < dis:
221
+ pid = j
222
+ dis = n_dis
223
+ parents.append(pid)
224
+ return RawSkeleton(
225
+ joints=joints,
226
+ tails=res.tails,
227
+ no_skin=res.no_skin,
228
+ parents=parents,
229
+ names=names,
230
+ )
231
+
232
+ def export_skeleton(self, path: str):
233
+ '''
234
+ export skeleton
235
+ '''
236
+ self._export_skeleton(joints=self.joints, parents=self.parents, path=path)
237
+
238
+ def export_skeleton_sequence(self, path: str):
239
+ '''
240
+ export skeleton sequence
241
+ '''
242
+ self._export_skeleton_sequence(joints=self.joints, parents=self.parents, path=path)
243
+
244
+ def export_fbx(
245
+ self,
246
+ path: str,
247
+ extrude_size: float=0.03,
248
+ group_per_vertex: int=-1,
249
+ add_root: bool=False,
250
+ do_not_normalize: bool=False,
251
+ use_extrude_bone: bool=True,
252
+ use_connect_unique_child: bool=True,
253
+ extrude_from_parent: bool=True,
254
+ use_tail: bool=False,
255
+ ):
256
+ '''
257
+ export the whole model with skinning
258
+ '''
259
+ self._export_fbx(
260
+ path=path,
261
+ vertices=None,
262
+ joints=self.joints,
263
+ skin=None,
264
+ parents=self.parents,
265
+ names=self.names,
266
+ faces=None,
267
+ extrude_size=extrude_size,
268
+ group_per_vertex=group_per_vertex,
269
+ add_root=add_root,
270
+ do_not_normalize=do_not_normalize,
271
+ use_extrude_bone=use_extrude_bone,
272
+ use_connect_unique_child=use_connect_unique_child,
273
+ extrude_from_parent=extrude_from_parent,
274
+ tails=self.tails if use_tail else None,
275
+ )
276
+
277
+ def export_render(self, path: str, resolution: Tuple[int, int]=(256, 256)):
278
+ self._export_render(
279
+ path=path,
280
+ vertices=None,
281
+ faces=None,
282
+ bones=np.concatenate([self.joints, self.tails], axis=-1),
283
+ resolution=resolution,
284
+ )
285
+
286
+ @dataclass
287
+ class RawSkin(Exporter):
288
+ '''
289
+ Dataclass to handle skinning weights.
290
+ '''
291
+ # skin, shape (J, N)
292
+ skin: ndarray
293
+
294
+ # always sampled, shape (N, 3)
295
+ vertices: Union[ndarray, None]=None
296
+
297
+ # for future use, shape (J, 3)
298
+ joints: Union[ndarray, None]=None
299
+
300
+ @staticmethod
301
+ def load(path: str) -> 'RawSkin':
302
+ data = np.load(path, allow_pickle=True)
303
+ return RawSkin(**{name: data[name][()] for name in data})
304
+
305
+ def save(self, path: str):
306
+ os.makedirs(os.path.dirname(path), exist_ok=True)
307
+ np.savez(file=path, **self.__dict__)
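A round-trip sketch (illustrative; the paths are hypothetical) for the dataclasses above, using the raw_data.npz produced by extract.py:

raw = RawData.load('dataset_inference/giraffe/raw_data.npz')  # hypothetical location
raw.check()                                   # validates parent ordering (parents[k] < k)
print(raw.N, raw.F, raw.J)                    # vertex / face / joint counts
raw.export_mesh('tmp/giraffe_mesh.obj')
raw.export_skeleton('tmp/giraffe_skeleton.obj')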
UniRig/src/data/sampler.py ADDED
@@ -0,0 +1,210 @@
1
+ from typing import List
2
+ from heapq import heappush, heappop, heapify
3
+ from dataclasses import dataclass
4
+ from abc import ABC, abstractmethod
5
+ import numpy as np
6
+ from numpy import ndarray
7
+
8
+ from typing import Dict, Tuple
9
+
10
+ from .asset import Asset
11
+ from .spec import ConfigSpec
12
+
13
+ @dataclass
14
+ class SamplerConfig(ConfigSpec):
15
+ '''
16
+ Config to handle point sampling.
17
+ '''
18
+ # which sampler to use
19
+ method: str
20
+
21
+ # how many samples in total
22
+ num_samples: int
23
+
24
+ # how many vertex samples
25
+ vertex_samples: int
26
+
27
+ # kwargs
28
+ kwargs: Dict[str, Dict]
29
+
30
+ @classmethod
31
+ def parse(cls, config) -> 'SamplerConfig':
32
+ cls.check_keys(config)
33
+ return SamplerConfig(
34
+ method=config.method,
35
+ num_samples=config.get('num_samples', 0),
36
+ vertex_samples=config.get('vertex_samples', 0),
37
+ kwargs=config.get('kwargs', {}),
38
+ )
39
+
40
+ @dataclass
41
+ class SamplerResult():
42
+ # sampled vertices
43
+ vertices: ndarray
44
+
45
+ # sampled normals
46
+ normals: ndarray
47
+
48
+ # sampled vertex groups
49
+ vertex_groups: Dict[str, ndarray]
50
+
51
+ class Sampler(ABC):
52
+ '''
53
+ Abstract class for samplers.
54
+ '''
55
+
56
+ def _sample_barycentric(
57
+ self,
58
+ vertex_group: ndarray,
59
+ faces: ndarray,
60
+ face_index: ndarray,
61
+ random_lengths: ndarray,
62
+ ):
63
+ v_origins = vertex_group[faces[face_index, 0]]
64
+ v_vectors = vertex_group[faces[face_index, 1:]]
65
+ v_vectors -= v_origins[:, np.newaxis, :]
66
+
67
+ sample_vector = (v_vectors * random_lengths).sum(axis=1)
68
+ v_samples = sample_vector + v_origins
69
+ return v_samples
70
+
71
+ @abstractmethod
72
+ def __init__(self, config: SamplerConfig):
73
+ pass
74
+
75
+ @abstractmethod
76
+ def sample(
77
+ self,
78
+ asset: Asset,
79
+ ) -> SamplerResult:
80
+ '''
81
+ Return sampled vertices, sampled normals and vertex groups.
82
+ '''
83
+ pass
84
+
85
+ class SamplerOrigin(Sampler):
86
+ def __init__(self, config: SamplerConfig):
87
+ super().__init__(config)
88
+ self.num_samples = config.num_samples
89
+ self.vertex_samples = config.vertex_samples
90
+
91
+ def sample(
92
+ self,
93
+ asset: Asset,
94
+ ) -> SamplerResult:
95
+ perm = np.random.permutation(asset.vertices.shape[0])
96
+ if asset.vertices.shape[0] < self.num_samples:
97
+ m = self.num_samples - asset.vertices.shape[0]
98
+ perm = np.concatenate([perm, np.random.randint(0, asset.vertices.shape[0], (m,))])
99
+ perm = perm[:self.num_samples]
100
+ n_v = asset.vertices[perm]
101
+ n_n = asset.vertex_normals[perm]
102
+ n_vg = {name: v[perm] for name, v in asset.vertex_groups.items()}
103
+ return SamplerResult(
104
+ vertices=n_v,
105
+ normals=n_n,
106
+ vertex_groups=n_vg,
107
+ )
108
+
109
+ class SamplerMix(Sampler):
110
+ def __init__(self, config: SamplerConfig):
111
+ super().__init__(config)
112
+ self.num_samples = config.num_samples
113
+ self.vertex_samples = config.vertex_samples
114
+ assert self.num_samples >= self.vertex_samples, 'num_samples should be >= vertex_samples'
115
+
116
+ @property
117
+ def mesh_preserve(self):
118
+ return self.num_samples==-1
119
+
120
+ def sample(
121
+ self,
122
+ asset: Asset,
123
+ ) -> SamplerResult:
124
+ # 1. sample vertices
125
+ num_samples = self.num_samples
126
+ perm = np.random.permutation(asset.vertices.shape[0])
127
+ vertex_samples = min(self.vertex_samples, asset.vertices.shape[0])
128
+ num_samples -= vertex_samples
129
+ perm = perm[:vertex_samples]
130
+ n_vertex = asset.vertices[perm]
131
+ n_normal = asset.vertex_normals[perm]
132
+ n_v = {name: v[perm] for name, v in asset.vertex_groups.items()}
133
+
134
+ # 2. sample surface
136
+ vertex_samples, face_index, random_lengths = sample_surface(
137
+ num_samples=num_samples,
138
+ vertices=asset.vertices,
139
+ faces=asset.faces,
140
+ return_weight=True,
141
+ )
142
+ vertex_samples = np.concatenate([n_vertex, vertex_samples], axis=0)
143
+ normal_samples = np.concatenate([n_normal, asset.face_normals[face_index]], axis=0)
144
+ vertex_group_samples = {}
145
+ for n, v in asset.vertex_groups.items():
146
+ g = self._sample_barycentric(
147
+ vertex_group=v,
148
+ faces=asset.faces,
149
+ face_index=face_index,
150
+ random_lengths=random_lengths,
151
+ )
152
+ vertex_group_samples[n] = np.concatenate([n_v[n], g], axis=0)
153
+ return SamplerResult(
154
+ vertices=vertex_samples,
155
+ normals=normal_samples,
156
+ vertex_groups=vertex_group_samples,
157
+ )
158
+
159
+ def sample_surface(
160
+ num_samples: int,
161
+ vertices: ndarray,
162
+ faces: ndarray,
163
+ return_weight: bool=False,
164
+ ):
165
+ '''
166
+ Randomly pick samples according to face area.
167
+
168
+ See sample_surface: https://github.com/mikedh/trimesh/blob/main/trimesh/sample.py
169
+ '''
170
+ # get face area
171
+ offset_0 = vertices[faces[:, 1]] - vertices[faces[:, 0]]
172
+ offset_1 = vertices[faces[:, 2]] - vertices[faces[:, 0]]
173
+ face_weight = np.cross(offset_0, offset_1, axis=-1)
174
+ face_weight = np.sqrt((face_weight * face_weight).sum(axis=1))  # norm, not squared norm: the cross-product length is twice the face area, so this keeps sampling area-weighted
175
+
176
+ weight_cum = np.cumsum(face_weight, axis=0)
177
+ face_pick = np.random.rand(num_samples) * weight_cum[-1]
178
+ face_index = np.searchsorted(weight_cum, face_pick)
179
+
180
+ # pull triangles into the form of an origin + 2 vectors
181
+ tri_origins = vertices[faces[:, 0]]
182
+ tri_vectors = vertices[faces[:, 1:]]
183
+ tri_vectors -= np.tile(tri_origins, (1, 2)).reshape((-1, 2, 3))
184
+
185
+ # pull the vectors for the faces we are going to sample from
186
+ tri_origins = tri_origins[face_index]
187
+ tri_vectors = tri_vectors[face_index]
188
+
189
+ # randomly generate two 0-1 scalar components to multiply edge vectors by
190
+ random_lengths = np.random.rand(len(tri_vectors), 2, 1)
191
+
192
+ random_test = random_lengths.sum(axis=1).reshape(-1) > 1.0
193
+ random_lengths[random_test] -= 1.0
194
+ random_lengths = np.abs(random_lengths)
195
+
196
+ sample_vector = (tri_vectors * random_lengths).sum(axis=1)
197
+ vertex_samples = sample_vector + tri_origins
198
+ if not return_weight:
199
+ return vertex_samples
200
+ return vertex_samples, face_index, random_lengths
201
+
202
+ def get_sampler(config: SamplerConfig) -> Sampler:
203
+ method = config.method
204
+ if method=='origin':
205
+ sampler = SamplerOrigin(config)
206
+ elif method=='mix':
207
+ sampler = SamplerMix(config)
208
+ else:
209
+ raise ValueError(f"sampler method {method} not supported")
210
+ return sampler
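A self-contained sketch (illustrative only) of sample_surface on a single triangle:

import numpy as np

vertices = np.array([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]])
faces = np.array([[0, 1, 2]])
samples, face_index, random_lengths = sample_surface(
    num_samples=8, vertices=vertices, faces=faces, return_weight=True,
)
assert samples.shape == (8, 3)      # barycentric samples inside the triangle
assert (face_index == 0).all()      # only one face to choose from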
UniRig/src/data/spec.py ADDED
@@ -0,0 +1,15 @@
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import fields
3
+
4
+ class ConfigSpec(ABC):
5
+ @classmethod
6
+ def check_keys(cls, config):
7
+ expect = [field.name for field in fields(cls)]
8
+ for key in config.keys():
9
+ if key not in expect:
10
+ raise ValueError(f"expect names {expect} in {cls.__name__}, found {key}")
11
+
12
+ @classmethod
13
+ @abstractmethod
14
+ def parse(cls, config):
15
+ pass
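A minimal subclass sketch (illustrative only) showing the contract: check_keys rejects any config key that is not a dataclass field:

from dataclasses import dataclass
from box import Box

@dataclass
class DemoConfig(ConfigSpec):
    method: str

    @classmethod
    def parse(cls, config) -> 'DemoConfig':
        cls.check_keys(config)
        return DemoConfig(method=config.method)

DemoConfig.parse(Box({'method': 'mix'}))    # ok
# DemoConfig.parse(Box({'typo': 'mix'}))    # raises ValueError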
UniRig/src/data/tail.py ADDED
@@ -0,0 +1,50 @@
1
+ from collections import defaultdict
2
+ from dataclasses import dataclass
3
+ import numpy as np
4
+ from numpy import ndarray
5
+
6
+ from typing import Tuple
7
+
8
+ from .asset import Asset
9
+ from .spec import ConfigSpec
10
+
11
+ @dataclass
12
+ class TailConfig(ConfigSpec):
13
+ '''
14
+ Config to handle tails.
15
+ '''
16
+
17
+ # copy joints to tails
18
+ copy_joint_to_tail: bool
19
+
20
+ # if the joint has exactly one child, connect its tail to that child's joint
21
+ connect_tail_to_unique_son: bool
22
+
23
+ @classmethod
24
+ def parse(cls, config) -> 'TailConfig':
25
+ cls.check_keys(config)
26
+ return TailConfig(
27
+ copy_joint_to_tail=config.copy_joint_to_tail,
28
+ connect_tail_to_unique_son=config.connect_tail_to_unique_son,
29
+ )
30
+
31
+ class Tail():
32
+
33
+ def __init__(self, config: TailConfig):
34
+ self.config = config
35
+
36
+ def process_tail(self, asset: Asset):
37
+ if self.config.copy_joint_to_tail:
38
+ assert asset.tails is None, 'copying joints to existing tails is not permitted, please change copy_joint_to_tail to False in transform config'
39
+ asset.tails = asset.joints.copy()
40
+ if self.config.connect_tail_to_unique_son and asset.tails is not None:
41
+ children = defaultdict(list)
42
+ for (id, p) in enumerate(asset.parents):
43
+ if p is not None:
44
+ children[p].append(id)
45
+ for i in range(asset.J):
46
+ if len(children[i]) == 1:
47
+ asset.tails[i] = asset.joints[children[i][0]]
48
+
49
+ def get_tail(config: TailConfig) -> Tail:
50
+ return Tail(config=config)
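Usage sketch (illustrative; asset stands for any Asset with joints and parents populated):

tail = get_tail(TailConfig(
    copy_joint_to_tail=True,           # requires asset.tails to be None
    connect_tail_to_unique_son=True,
))
tail.process_tail(asset=asset)
# every tail now equals its own joint, except joints with exactly one
# child, whose tails are moved onto that child's joint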
UniRig/src/data/transform.py ADDED
@@ -0,0 +1,107 @@
1
+ from collections import defaultdict
2
+ from dataclasses import dataclass
3
+ from typing import Union, List, Tuple
4
+ from copy import deepcopy
5
+
6
+ from .asset import Asset
7
+ from .augment import AugmentConfig, Augment, get_augments
8
+ from .order import OrderConfig, Order, get_order
9
+ from .sampler import SamplerConfig, get_sampler
10
+ from .vertex_group import VertexGroupConfig, get_vertex_groups
11
+ from .tail import TailConfig, get_tail
12
+ from .spec import ConfigSpec
13
+
14
+ @dataclass
15
+ class TransformConfig(ConfigSpec):
16
+
17
+ tail_config: Union[TailConfig, None] = None  # no trailing comma: 'None,' would silently make each default the 1-tuple (None,)
18
+
19
+ order_config: Union[OrderConfig, None] = None
20
+
21
+ vertex_group_config: Union[VertexGroupConfig, None] = None
22
+
23
+ augment_config: Union[AugmentConfig, None] = None
24
+
25
+ sampler_config: Union[SamplerConfig, None] = None
26
+
27
+ @classmethod
28
+ def parse(cls, config) -> 'TransformConfig':
29
+ cls.check_keys(config)
30
+ tail_config = config.get('tail_config', None)
31
+ order_config = config.get('order_config', None)
32
+ vertex_group_config = config.get('vertex_group_config', None)
33
+ augment_config = config.get('augment_config', None)
34
+ sampler_config = config.get('sampler_config', None)
35
+
36
+ if tail_config is not None:
37
+ tail_config = TailConfig.parse(config=tail_config)
38
+ if order_config is not None:
39
+ order_config = OrderConfig.parse(config=order_config)
40
+ if vertex_group_config is not None:
41
+ vertex_group_config = VertexGroupConfig.parse(config=vertex_group_config)
42
+ if augment_config is not None:
43
+ augment_config = AugmentConfig.parse(config=augment_config)
44
+ if sampler_config is not None:
45
+ sampler_config = SamplerConfig.parse(config=sampler_config)
46
+
47
+ return TransformConfig(
48
+ tail_config=tail_config,
49
+ order_config=order_config,
50
+ vertex_group_config=vertex_group_config,
51
+ augment_config=augment_config,
52
+ sampler_config=sampler_config,
53
+ )
54
+
55
+ def transform_asset(
56
+ asset: Asset,
57
+ transform_config: TransformConfig,
58
+ ) -> Tuple[List[Augment], List[Augment]]:
59
+ assert isinstance(transform_config, TransformConfig), f"found {type(transform_config)}"
60
+ # 1. try processing tails
61
+ # TODO: use a better method
62
+ if transform_config.tail_config is not None:
63
+ tail = get_tail(config=transform_config.tail_config)
64
+ tail.process_tail(asset=asset)
65
+
66
+ # 2. arrange bones
67
+ if transform_config.order_config is not None:
68
+ order = get_order(config=transform_config.order_config)
69
+ asset.set_order(order=order)
70
+
71
+ # 3. collapse augments must run first
72
+ if transform_config.augment_config:
73
+ first_augments, second_augments = get_augments(config=transform_config.augment_config)
74
+ else:
75
+ first_augments = []
76
+ second_augments = []
77
+
78
+ kwargs = {}
79
+ for augment in first_augments:
80
+ augment.transform(asset=asset, **kwargs)
81
+
82
+ # 4. get vertex groups
83
+ if transform_config.vertex_group_config is not None:
84
+ vertex_groups = get_vertex_groups(config=transform_config.vertex_group_config)
85
+ d = {}
86
+ for v in vertex_groups:
87
+ d.update(v.get_vertex_group(asset=asset))
88
+ asset.vertex_groups = d
89
+ else:
90
+ asset.vertex_groups = {}
91
+
92
+ # 5. regular augments
93
+ for augment in second_augments:
94
+ augment.transform(asset=asset, **kwargs)
95
+
96
+ # 6. sample
97
+ if transform_config.sampler_config is not None:
98
+ sampler = get_sampler(config=transform_config.sampler_config)
99
+ res = sampler.sample(asset=asset)
100
+ asset.sampled_vertices = res.vertices
101
+ asset.sampled_normals = res.normals
102
+ asset.sampled_vertex_groups = res.vertex_groups
103
+ else:
104
+ asset.sampled_vertices = asset.vertices.copy()
105
+ asset.sampled_normals = asset.vertex_normals.copy()
106
+ asset.sampled_vertex_groups = deepcopy(asset.vertex_groups)
107
+ return first_augments, second_augments
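An end-to-end sketch (illustrative only, assuming the YAML's top-level keys match TransformConfig's fields) of running the pipeline from one of the transform configs shipped in this commit:

from box import Box
import yaml

cfg = TransformConfig.parse(Box(yaml.safe_load(
    open('configs/transform/inference_ar_transform.yaml', 'r')
)))
first_augments, second_augments = transform_asset(asset=asset, transform_config=cfg)
# asset.sampled_vertices / asset.sampled_normals / asset.sampled_vertex_groups
# are now populated for the model's dataloader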
UniRig/src/data/utils.py ADDED
@@ -0,0 +1,258 @@
1
+ import torch
2
+ import numpy as np
3
+ from numpy import ndarray
4
+ from torch import Tensor, FloatTensor
5
+ from typing import Tuple, Union
6
+
7
+ from scipy.spatial.transform import Rotation as R
8
+ from scipy.sparse import csc_matrix
10
+
11
+ def quaternion_to_matrix(x, use_4x4=True) -> FloatTensor:
12
+ """
13
+ Ref: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#quaternion_to_matrix
14
+ """
15
+ if not isinstance(x, Tensor):
16
+ quaternions = torch.tensor(x, dtype=torch.float32)
17
+ else:
18
+ quaternions = x
19
+ r, i, j, k = torch.unbind(quaternions, -1)
20
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
21
+ device = quaternions.device
22
+
23
+ if use_4x4:
24
+ o = torch.stack(
25
+ (
26
+ 1 - two_s * (j * j + k * k),
27
+ two_s * (i * j - k * r),
28
+ two_s * (i * k + j * r),
29
+ torch.zeros(quaternions.shape[:-1], device=device, dtype=torch.float32),
30
+ two_s * (i * j + k * r),
31
+ 1 - two_s * (i * i + k * k),
32
+ two_s * (j * k - i * r),
33
+ torch.zeros(quaternions.shape[:-1], device=device, dtype=torch.float32),
34
+ two_s * (i * k - j * r),
35
+ two_s * (j * k + i * r),
36
+ 1 - two_s * (i * i + j * j),
37
+ torch.zeros(quaternions.shape[:-1], device=device, dtype=torch.float32),
38
+ torch.zeros(quaternions.shape[:-1], device=device, dtype=torch.float32),
39
+ torch.zeros(quaternions.shape[:-1], device=device, dtype=torch.float32),
40
+ torch.zeros(quaternions.shape[:-1], device=device, dtype=torch.float32),
41
+ torch.ones(quaternions.shape[:-1], device=device, dtype=torch.float32),
42
+ ),
43
+ -1,
44
+ )
45
+ return o.reshape(quaternions.shape[:-1] + (4, 4))
46
+ else:
47
+ o = torch.stack(
48
+ (
49
+ 1 - two_s * (j * j + k * k),
50
+ two_s * (i * j - k * r),
51
+ two_s * (i * k + j * r),
52
+ two_s * (i * j + k * r),
53
+ 1 - two_s * (i * i + k * k),
54
+ two_s * (j * k - i * r),
55
+ two_s * (i * k - j * r),
56
+ two_s * (j * k + i * r),
57
+ 1 - two_s * (i * i + j * j),
58
+ ),
59
+ -1,
60
+ )
61
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
62
+
63
+ def axis_angle_to_quaternion(axis_angle: FloatTensor) -> FloatTensor:
64
+ """
65
+ Ref: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#axis_angle_to_quaternion
66
+ """
67
+ angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
68
+ half_angles = angles * 0.5
69
+ eps = 1e-6
70
+ small_angles = angles.abs() < eps
71
+ sin_half_angles_over_angles = torch.empty_like(angles)
72
+ sin_half_angles_over_angles[~small_angles] = (
73
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
74
+ )
75
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
76
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
77
+ sin_half_angles_over_angles[small_angles] = (
78
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
79
+ )
80
+ quaternions = torch.cat(
81
+ [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
82
+ )
83
+ return quaternions
84
+
85
+ def axis_angle_to_matrix(axis_angle: Union[FloatTensor, ndarray]) -> Union[FloatTensor, ndarray]:
86
+ """
87
+ Ref: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#axis_angle_to_matrix
88
+ """
89
+ if isinstance(axis_angle, Tensor):  # note: FloatTensor would miss CUDA/double tensors and wrongly fall through to the numpy branch
90
+ return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
91
+ else:
92
+ res = np.pad(R.from_rotvec(axis_angle).as_matrix(), ((0, 0), (0, 1), (0, 1)), 'constant', constant_values=((0, 0), (0, 0), (0, 0)))
93
+ assert res.ndim == 3
94
+ res[:, -1, -1] = 1
95
+ return res
96
+
97
+ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
98
+ """
99
+ Returns torch.sqrt(torch.max(0, x))
100
+ but with a zero subgradient where x is 0.
101
+ """
102
+ ret = torch.zeros_like(x)
103
+ positive_mask = x > 0
104
+ if torch.is_grad_enabled():
105
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
106
+ else:
107
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
108
+ return ret
109
+
110
+ def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
111
+ """
112
+ Convert a unit quaternion to a standard form: one in which the real
113
+ part is non negative.
114
+
115
+ Args:
116
+ quaternions: Quaternions with real part first,
117
+ as tensor of shape (..., 4).
118
+
119
+ Returns:
120
+ Standardized quaternions as tensor of shape (..., 4).
121
+ """
122
+ return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
123
+
124
+ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
125
+ """
126
+ Convert rotations given as rotation matrices to quaternions.
127
+
128
+ Args:
129
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
130
+
131
+ Returns:
132
+ quaternions with real part first, as tensor of shape (..., 4).
133
+ """
134
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
135
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
136
+
137
+ batch_dim = matrix.shape[:-2]
138
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
139
+ matrix.reshape(batch_dim + (9,)), dim=-1
140
+ )
141
+
142
+ q_abs = _sqrt_positive_part(
143
+ torch.stack(
144
+ [
145
+ 1.0 + m00 + m11 + m22,
146
+ 1.0 + m00 - m11 - m22,
147
+ 1.0 - m00 + m11 - m22,
148
+ 1.0 - m00 - m11 + m22,
149
+ ],
150
+ dim=-1,
151
+ )
152
+ )
153
+
154
+ # we produce the desired quaternion multiplied by each of r, i, j, k
155
+ quat_by_rijk = torch.stack(
156
+ [
157
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
158
+ # `int`.
159
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
160
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
161
+ # `int`.
162
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
163
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
164
+ # `int`.
165
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
166
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
167
+ # `int`.
168
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
169
+ ],
170
+ dim=-2,
171
+ )
172
+
173
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
174
+ # the candidate won't be picked.
175
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
176
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
177
+
178
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
179
+ # forall i; we pick the best-conditioned one (with the largest denominator)
180
+ out = quat_candidates[
181
+ torch.nn.functional.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
182
+ ].reshape(batch_dim + (4,))
183
+ return standardize_quaternion(out)
184
+
185
+ def linear_blend_skinning(
186
+ vertex: Union[FloatTensor, ndarray],
187
+ matrix_local: Union[FloatTensor, ndarray],
188
+ matrix: Union[FloatTensor, ndarray],
189
+ skin: Union[FloatTensor, ndarray],
190
+ pad: int=0,
191
+ value: float=0.,
192
+ ) -> Union[FloatTensor, ndarray]:
193
+ '''
194
+ Args:
195
+ vertex: (B, N, 4-pad) or (N, 4-pad)
196
+ matrix_local: (B, J, 4, 4) or (J, 4, 4)
197
+ matrix: (B, J, 4, 4) or (J, 4, 4)
198
+ skin: (B, N, J) or (N, J), value of pseudo bones should be 0
199
+ Returns:
200
+ (B, N, 3) or (N, 3)
201
+ '''
202
+ assert vertex.shape[-1] + pad == 4
203
+ if isinstance(vertex, Tensor):
204
+ dims = vertex.dim()
205
+ elif isinstance(vertex, ndarray):
206
+ dims = vertex.ndim
207
+ else:
208
+ raise NotImplementedError()
209
+ if dims == 3: # Case: (B, N, 4-pad)
210
+ assert isinstance(vertex, Tensor)
211
+ J = matrix_local.shape[1]
212
+ # (B, J, 4, N)
213
+ offset = (
214
+ matrix_local.inverse() @
215
+ torch.nn.functional.pad(vertex, (0, pad, 0, 0, 0, 0), value=value).unsqueeze(1).transpose(2, 3).repeat(1, J, 1, 1)
216
+ )
217
+ # (B, J, 4, N)
218
+ per_bone_matrix = matrix @ offset
219
+ # (B, J, 4, N)
220
+ weighted_per_bone_matrix = skin.transpose(1, 2).unsqueeze(2) * per_bone_matrix
221
+ # (B, 4, N)
222
+ g = weighted_per_bone_matrix.sum(dim=1)
223
+ # (B, 3, N)
224
+ final = g[:, 0:3, :] / (skin.transpose(1, 2).sum(dim=1) + 1e-8).unsqueeze(1)
225
+ return final.permute(0, 2, 1)
226
+
227
+ elif dims == 2: # Case: (N, 4-pad)
228
+ if isinstance(vertex, Tensor):
229
+ J = matrix_local.shape[0]
230
+ offset = (
231
+ matrix_local.inverse() @
232
+ torch.nn.functional.pad(vertex, (0, pad, 0, 0), value=value).unsqueeze(0).transpose(1, 2).repeat(J, 1, 1)
233
+ )
234
+ per_bone_matrix = matrix @ offset
235
+ weighted_per_bone_matrix = skin.transpose(0, 1).unsqueeze(1) * per_bone_matrix
236
+ g = weighted_per_bone_matrix.sum(dim=0)
237
+ final = g[0:3, :] / (skin.transpose(0, 1).sum(dim=0) + 1e-8).unsqueeze(0)
238
+ return final.permute(1, 0) # Output shape (N, 3)
239
+ else:
240
+ J = matrix_local.shape[0]
241
+ N = vertex.shape[0]
242
+ # (4, N)
243
+ padded = np.pad(vertex, ((0, 0), (0, pad)), 'constant', constant_values=(0, value)).T
244
+ # (J, 4, 4)
245
+ trans = matrix @ np.linalg.inv(matrix_local)
246
+ weighted_per_bone_matrix = []
247
+ # (J, N)
248
+ mask = (skin > 0).T
249
+ for i in range(J):
250
+ offset = np.zeros((4, N), dtype=np.float32)
251
+ offset[:, mask[i]] = (trans[i] @ padded[:, mask[i]]) * skin.T[i, mask[i]]
252
+ weighted_per_bone_matrix.append(offset)
253
+ weighted_per_bone_matrix = np.stack(weighted_per_bone_matrix)
254
+ g = np.sum(weighted_per_bone_matrix, axis=0)
255
+ final = g[:3, :] / (np.sum(skin, axis=1) + 1e-8)
256
+ return final.T
257
+ else:
258
+ assert 0, f'unsupported shape: {vertex.shape}'
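A sanity-check sketch (illustrative only) for the numpy path of linear_blend_skinning: in rest pose (matrix equal to matrix_local) the output must reproduce the input vertices:

import numpy as np

J, N = 2, 4
matrix_local = np.tile(np.eye(4, dtype=np.float32), (J, 1, 1))
matrix = matrix_local.copy()                      # rest pose: no bone moves
vertex = np.random.rand(N, 3).astype(np.float32)
skin = np.full((N, J), 0.5, dtype=np.float32)     # weights per vertex sum to 1
out = linear_blend_skinning(vertex, matrix_local, matrix, skin, pad=1, value=1.)
assert np.allclose(out, vertex, atol=1e-5)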