Spaces:
Running
on
Zero
Running
on
Zero
Correctly add UniRig source files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- UniRig/.gitattributes +5 -0
- UniRig/.gitignore +59 -0
- UniRig/LICENSE +21 -0
- UniRig/README.md +164 -0
- UniRig/assets/doc/devil.gif +3 -0
- UniRig/assets/doc/dragon.gif +3 -0
- UniRig/assets/doc/rabbit.gif +3 -0
- UniRig/assets/doc/unirig_teaser.png +3 -0
- UniRig/blender/add-on-vrm-v2.20.77_modified.zip +3 -0
- UniRig/configs/data/quick_inference.yaml +16 -0
- UniRig/configs/model/unirig_ar_350m_1024_81920_float32.yaml +32 -0
- UniRig/configs/model/unirig_skin.yaml +52 -0
- UniRig/configs/skeleton/mixamo.yaml +59 -0
- UniRig/configs/skeleton/vroid.yaml +59 -0
- UniRig/configs/system/ar_inference_articulationxl.yaml +14 -0
- UniRig/configs/system/skin.yaml +5 -0
- UniRig/configs/task/quick_inference_skeleton_articulationxl_ar_256.yaml +30 -0
- UniRig/configs/task/quick_inference_unirig_skin.yaml +28 -0
- UniRig/configs/tokenizer/tokenizer_parts_articulationxl_256.yaml +14 -0
- UniRig/configs/transform/inference_ar_transform.yaml +30 -0
- UniRig/configs/transform/inference_skin_transform.yaml +32 -0
- UniRig/examples/bird.glb +3 -0
- UniRig/examples/giraffe.glb +3 -0
- UniRig/examples/skeleton/bird.fbx +3 -0
- UniRig/examples/skeleton/giraffe.fbx +3 -0
- UniRig/examples/skeleton/tira.fbx +3 -0
- UniRig/examples/skeleton/tripo_carrot.fbx +3 -0
- UniRig/examples/tira.glb +3 -0
- UniRig/examples/tripo_carrot.glb +3 -0
- UniRig/launch/inference/extract.sh +60 -0
- UniRig/launch/inference/generate_skeleton.sh +81 -0
- UniRig/launch/inference/generate_skin.sh +84 -0
- UniRig/launch/inference/merge.sh +33 -0
- UniRig/requirements.txt +15 -0
- UniRig/run.py +186 -0
- UniRig/src/data/__init__.py +0 -0
- UniRig/src/data/asset.py +433 -0
- UniRig/src/data/augment.py +152 -0
- UniRig/src/data/datapath.py +149 -0
- UniRig/src/data/dataset.py +231 -0
- UniRig/src/data/exporter.py +486 -0
- UniRig/src/data/extract.py +523 -0
- UniRig/src/data/log.py +50 -0
- UniRig/src/data/order.py +112 -0
- UniRig/src/data/raw_data.py +307 -0
- UniRig/src/data/sampler.py +210 -0
- UniRig/src/data/spec.py +15 -0
- UniRig/src/data/tail.py +50 -0
- UniRig/src/data/transform.py +107 -0
- UniRig/src/data/utils.py +258 -0
UniRig/.gitattributes
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
examples/*.glb filter=lfs diff=lfs merge=lfs -text
|
2 |
+
examples/*.fbx filter=lfs diff=lfs merge=lfs -text
|
3 |
+
examples/*.vrm filter=lfs diff=lfs merge=lfs -text
|
4 |
+
examples/*.FBX filter=lfs diff=lfs merge=lfs -text
|
5 |
+
examples/*.obj filter=lfs diff=lfs merge=lfs -text
|
UniRig/.gitignore
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# igonore all pychace
|
2 |
+
**/__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# ignore tmp & output files
|
7 |
+
_data/
|
8 |
+
tmp/
|
9 |
+
*.npz
|
10 |
+
*.blend
|
11 |
+
*.blend1
|
12 |
+
*.blend2
|
13 |
+
|
14 |
+
# ignore logs
|
15 |
+
wandb/
|
16 |
+
lightning_logs/
|
17 |
+
*.log
|
18 |
+
|
19 |
+
# ignore experiments
|
20 |
+
experiments/
|
21 |
+
results/
|
22 |
+
dataset_clean/
|
23 |
+
logs/
|
24 |
+
datalist/
|
25 |
+
dataset_inference/
|
26 |
+
dataset_inference_clean/
|
27 |
+
feature_viz/
|
28 |
+
|
29 |
+
# Distribution / packaging
|
30 |
+
dist/
|
31 |
+
build/
|
32 |
+
*.egg-info/
|
33 |
+
*.egg
|
34 |
+
*.whl
|
35 |
+
|
36 |
+
# Virtual environments
|
37 |
+
venv/
|
38 |
+
env/
|
39 |
+
.env/
|
40 |
+
.venv/
|
41 |
+
|
42 |
+
# IDE specific files
|
43 |
+
.idea/
|
44 |
+
.vscode/
|
45 |
+
*.swp
|
46 |
+
*.swo
|
47 |
+
.DS_Store
|
48 |
+
|
49 |
+
# Jupyter Notebook
|
50 |
+
.ipynb_checkpoints
|
51 |
+
*.ipynb
|
52 |
+
|
53 |
+
# Unit test / coverage reports
|
54 |
+
htmlcov/
|
55 |
+
.tox/
|
56 |
+
.coverage
|
57 |
+
.coverage.*
|
58 |
+
coverage.xml
|
59 |
+
*.cover
|
UniRig/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MIT License
|
2 |
+
|
3 |
+
# Copyright (c) 2025 VAST-AI-Research and contributors.
|
4 |
+
|
5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
7 |
+
# in the Software without restriction, including without limitation the rights
|
8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
10 |
+
# furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
# The above copyright notice and this permission notice shall be included in all
|
13 |
+
# copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
# SOFTWARE
|
UniRig/README.md
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# UniRig: One Model to Rig Them All
|
2 |
+
|
3 |
+
<div align="center">
|
4 |
+
|
5 |
+
[](https://zjp-shadow.github.io/works/UniRig/)
|
6 |
+
[](https://arxiv.org/abs/2504.12451)
|
7 |
+
[](https://huggingface.co/VAST-AI/UniRig)
|
8 |
+
|
9 |
+
</div>
|
10 |
+
|
11 |
+

|
12 |
+
|
13 |
+
This repository contains the official implementation for the **SIGGRAPH'25 (TOG) UniRig** framework, a unified solution for automatic 3D model rigging, developed by Tsinghua University and [Tripo](https://www.tripo3d.ai).
|
14 |
+
|
15 |
+
**Paper:** [One Model to Rig Them All: Diverse Skeleton Rigging with UniRig](https://arxiv.org/abs/2504.12451)
|
16 |
+
|
17 |
+
## Overview
|
18 |
+
|
19 |
+
Rigging 3D models – creating a skeleton and assigning skinning weights – is a crucial but often complex and time-consuming step in 3D animation. UniRig tackles this challenge by introducing a novel, unified framework leveraging large autoregressive models to automate the process for a diverse range of 3D assets.
|
20 |
+
|
21 |
+
Combining UniRig with keyframe animation produces these following results:
|
22 |
+
|
23 |
+
|  |  |  |
|
24 |
+
|:-----------------------------:|:-------------------------------:|:-------------------------------:|
|
25 |
+
|
26 |
+
The full UniRig system consists of two main stages:
|
27 |
+
1. **Skeleton Prediction:** An GPT-like transformer autoregressively predicts a topologically valid skeleton hierarchy using a novel **Skeleton Tree Tokenization** scheme.
|
28 |
+
2. **Skinning Weight & Attribute Prediction:** A **Bone-Point Cross Attention** mechanism predicts per-vertex skinning weights and relevant bone attributes (e.g., for physics simulation) based on the predicted skeleton and input mesh geometry.
|
29 |
+
|
30 |
+
This repository provides the code implementation for the entire framework vision, with components being released progressively.
|
31 |
+
|
32 |
+
## Key Features (Full UniRig Framework)
|
33 |
+
|
34 |
+
* **Unified Model:** Aims to handle diverse model categories (humans, animals, objects) with a single framework.
|
35 |
+
* **Automated Skeleton Generation:** Predicts topologically valid skeleton structures. **(✅ Available in current release)**
|
36 |
+
* **Automated Skinning Prediction:** Predicts per-vertex skinning weights. **(✅ Available in current release)**
|
37 |
+
* **Bone Attribute Prediction:** Predicts attributes like stiffness for physics-based secondary motion. **(⏳ Coming Soon)**
|
38 |
+
* **High Accuracy & Robustness:** Achieves state-of-the-art results on challenging datasets (as shown in the paper with Rig-XL/VRoid training).
|
39 |
+
* **Efficient Tokenization:** Uses Skeleton Tree Tokenization for compact representation and efficient processing.
|
40 |
+
* **Human-in-the-Loop Ready:** Designed to potentially support iterative refinement workflows.
|
41 |
+
|
42 |
+
## 🚨 Current Release Status & Roadmap 🚨
|
43 |
+
|
44 |
+
We are open-sourcing UniRig progressively. Please note the current status:
|
45 |
+
|
46 |
+
**Available Now (Initial Release):**
|
47 |
+
* ✅ **Code:** Implementation for skeleton and skinning prediction.
|
48 |
+
* ✅ **Model:** Skeleton & Skinning Prediction checkpoint trained on [**Articulation-XL2.0**](https://huggingface.co/datasets/Seed3D/Articulation-XL2.0). Available on [Hugging Face](https://huggingface.co/VAST-AI/UniRig).
|
49 |
+
|
50 |
+
**Planned Future Releases:**
|
51 |
+
* ⏳ Release of the **Rig-XL** and **VRoid** datasets used in the paper.
|
52 |
+
* ⏳ Full UniRig model checkpoints (Skeleton + Skinning) trained on Rig-XL/VRoid, replicating the paper's main results.
|
53 |
+
|
54 |
+
We appreciate your patience as we prepare these components for release. Follow [VAST-AI-Research](https://github.com/orgs/VAST-AI-Research) announcements for updates!
|
55 |
+
|
56 |
+
## Installation
|
57 |
+
|
58 |
+
1. **Prerequisites:**
|
59 |
+
* Python 3.11
|
60 |
+
* PyTorch (tested with version >=2.3.1)
|
61 |
+
|
62 |
+
2. **Clone the repository:**
|
63 |
+
```bash
|
64 |
+
git clone https://github.com/VAST-AI-Research/UniRig
|
65 |
+
cd UniRig
|
66 |
+
```
|
67 |
+
|
68 |
+
3. **Set up a virtual environment (recommended):**
|
69 |
+
```bash
|
70 |
+
conda create -n UniRig python=3.11
|
71 |
+
conda activate UniRig
|
72 |
+
```
|
73 |
+
|
74 |
+
4. **Install dependencies:**
|
75 |
+
```bash
|
76 |
+
python -m pip install torch torchvision
|
77 |
+
python -m pip install -r requirements.txt
|
78 |
+
python -m pip install spconv-{you-cuda-version}
|
79 |
+
python -m pip install torch_scatter torch_cluster -f https://data.pyg.org/whl/torch-{your-torch-version}+{your-cuda-version}.html --no-cache-dir
|
80 |
+
python -m pip install numpy==1.26.4
|
81 |
+
```
|
82 |
+
|
83 |
+
5. **Download Model Checkpoint:**
|
84 |
+
The currently available skeleton prediction model checkpoint is hosted on Hugging Face and will typically be downloaded automatically by the provided scripts/functions.
|
85 |
+
|
86 |
+
6. **(Optional, for importing/exporting .vrm) Install the blender addon:**
|
87 |
+
The blender addon is modifed from [VRM-Addon-for-Blender](https://github.com/saturday06/VRM-Addon-for-Blender).
|
88 |
+
|
89 |
+
Make sure you are in the root directory of the project, then:
|
90 |
+
```bash
|
91 |
+
python -c "import bpy, os; bpy.ops.preferences.addon_install(filepath=os.path.abspath('blender/add-on-vrm-v2.20.77_modified.zip'))"
|
92 |
+
```
|
93 |
+
|
94 |
+
## Usage
|
95 |
+
|
96 |
+
### Skeleton Prediction (Available Now)
|
97 |
+
|
98 |
+
Generate a skeleton for your 3D model using our pre-trained model. The process automatically analyzes the geometry and predicts an appropriate skeletal structure.
|
99 |
+
|
100 |
+
```bash
|
101 |
+
# Process a single file
|
102 |
+
bash launch/inference/generate_skeleton.sh --input examples/giraffe.glb --output results/giraffe_skeleton.fbx
|
103 |
+
|
104 |
+
# Process multiple files in a directory
|
105 |
+
bash launch/inference/generate_skeleton.sh --input_dir <your_input_directory> --output_dir <your_output_directory>
|
106 |
+
|
107 |
+
# Try different skeleton variations by changing the random seed
|
108 |
+
bash launch/inference/generate_skeleton.sh --input examples/giraffe.glb --output results/giraffe_skeleton.fbx --seed 42
|
109 |
+
```
|
110 |
+
|
111 |
+
Supported input formats: `.obj`, `.fbx`, `.glb`, and `.vrm`
|
112 |
+
|
113 |
+
### Skinning Weight Prediction (Available Now)
|
114 |
+
```bash
|
115 |
+
# Skin a single file
|
116 |
+
bash launch/inference/generate_skin.sh --input examples/skeleton/giraffe.fbx --output results/giraffe_skin.fbx
|
117 |
+
|
118 |
+
# Process multiple files in a directory
|
119 |
+
bash launch/inference/generate_skin.sh --input_dir <your_input_directory> --output_dir <your_output_directory>
|
120 |
+
```
|
121 |
+
|
122 |
+
Note that the command above uses an **edited-version** from skeleton phase. The results may degrade significantly if the skeleton is inaccurate — for example, if tail bones or wing bones are missing. Therefore, it is recommended to refine the skeleton before performing skinning in order to achieve better results.
|
123 |
+
|
124 |
+
### Merge the Predicted Results
|
125 |
+
|
126 |
+
Combine the predicted skeleton with your original 3D model to create a fully rigged asset:
|
127 |
+
|
128 |
+
```bash
|
129 |
+
# Merge skeleton from skeleton prediction
|
130 |
+
bash launch/inference/merge.sh --source results/giraffe_skeleton.fbx --target examples/giraffe.glb --output results/giraffe_rigged.glb
|
131 |
+
|
132 |
+
# Or merge skin from skin prediction
|
133 |
+
bash launch/inference/merge.sh --source results/giraffe_skin.fbx --target examples/giraffe.glb --output results/giraffe_rigged.glb
|
134 |
+
```
|
135 |
+
|
136 |
+
## Models
|
137 |
+
|
138 |
+
Available models are hosted on the: https://huggingface.co/VAST-AI/UniRig
|
139 |
+
|
140 |
+
## System Requirements
|
141 |
+
|
142 |
+
- CUDA-enabled GPU with at least 8GB VRAM
|
143 |
+
|
144 |
+
## Citation
|
145 |
+
|
146 |
+
```
|
147 |
+
@article{zhang2025unirig,
|
148 |
+
title={One Model to Rig Them All: Diverse Skeleton Rigging with UniRig},
|
149 |
+
author={Zhang, Jia-Peng and Pu, Cheng-Feng and Guo, Meng-Hao and Cao, Yan-Pei and Hu, Shi-Min},
|
150 |
+
journal={arXiv preprint arXiv:2504.12451},
|
151 |
+
year={2025}
|
152 |
+
}
|
153 |
+
```
|
154 |
+
|
155 |
+
## Acknowledgements
|
156 |
+
|
157 |
+
We would like to thank the following open-source projects and research works:
|
158 |
+
|
159 |
+
- [OPT](https://huggingface.co/facebook/opt-350m) for model architecture
|
160 |
+
- [3DShape2VecSet](https://github.com/1zb/3DShape2VecSet) for 3D shape representation
|
161 |
+
- [SAMPart3D](https://github.com/Pointcept/SAMPart3D) and [Michelangelo](https://github.com/NeuralCarver/Michelangelo/) for shape encoder implementation
|
162 |
+
- [Articulation-XL2.0](https://huggingface.co/datasets/Seed3D/Articulation-XL2.0) for a curated dataset
|
163 |
+
|
164 |
+
We are grateful to the broader research community for their open exploration and contributions to the field of 3D generation.
|
UniRig/assets/doc/devil.gif
ADDED
![]() |
Git LFS Details
|
UniRig/assets/doc/dragon.gif
ADDED
![]() |
Git LFS Details
|
UniRig/assets/doc/rabbit.gif
ADDED
![]() |
Git LFS Details
|
UniRig/assets/doc/unirig_teaser.png
ADDED
![]() |
Git LFS Details
|
UniRig/blender/add-on-vrm-v2.20.77_modified.zip
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8fe1e8b7e31ec602d9b32db33c533b03d68df747837bfad3479e1057bc9937c5
|
3 |
+
size 1331571
|
UniRig/configs/data/quick_inference.yaml
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
input_dataset_dir: &input_dataset_dir ./dataset_inference
|
2 |
+
output_dataset_dir: &output_dataset_dir ./dataset_inference_clean
|
3 |
+
|
4 |
+
predict_dataset_config:
|
5 |
+
shuffle: False
|
6 |
+
batch_size: 1
|
7 |
+
num_workers: 1
|
8 |
+
pin_memory: False
|
9 |
+
persistent_workers: False
|
10 |
+
datapath_config:
|
11 |
+
input_dataset_dir: *output_dataset_dir
|
12 |
+
use_prob: False
|
13 |
+
data_path:
|
14 |
+
inference: [
|
15 |
+
[./dataset_inference_clean/inference_datalist.txt, 1.0],
|
16 |
+
]
|
UniRig/configs/model/unirig_ar_350m_1024_81920_float32.yaml
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__target__: unirig_ar
|
2 |
+
llm:
|
3 |
+
pretrained_model_name_or_path: facebook/opt-350m
|
4 |
+
n_positions: 3076
|
5 |
+
max_position_embeddings: 3076
|
6 |
+
hidden_size: 1024
|
7 |
+
word_embed_proj_dim: 1024
|
8 |
+
do_layer_norm_before: True
|
9 |
+
_attn_implementation: flash_attention_2
|
10 |
+
|
11 |
+
mesh_encoder:
|
12 |
+
__target__: michelangelo_encoder
|
13 |
+
pretrained_path: ~
|
14 |
+
freeze_encoder: False
|
15 |
+
device: cpu
|
16 |
+
dtype: float32
|
17 |
+
num_latents: 512
|
18 |
+
embed_dim: 64
|
19 |
+
point_feats: 3
|
20 |
+
num_freqs: 8
|
21 |
+
include_pi: False
|
22 |
+
heads: 8
|
23 |
+
width: 512
|
24 |
+
num_encoder_layers: 16
|
25 |
+
use_ln_post: True
|
26 |
+
init_scale: 0.25
|
27 |
+
qkv_bias: False
|
28 |
+
use_checkpoint: False
|
29 |
+
flash: True
|
30 |
+
supervision_type: sdf
|
31 |
+
query_method: False
|
32 |
+
token_num: 1024
|
UniRig/configs/model/unirig_skin.yaml
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__target__: unirig_skin
|
2 |
+
|
3 |
+
num_train_vertex: 512 # increase this for faster speed at the cost of memory
|
4 |
+
num_heads: 16
|
5 |
+
feat_dim: 768
|
6 |
+
grid_size: 0.005
|
7 |
+
mlp_dim: 512
|
8 |
+
num_bone_attn: 8
|
9 |
+
num_mesh_bone_attn: 16
|
10 |
+
bone_embed_dim: 1024
|
11 |
+
voxel_mask: 3.0
|
12 |
+
|
13 |
+
mesh_encoder:
|
14 |
+
# vertex groups are handled in model
|
15 |
+
__target__: ptv3obj
|
16 |
+
pretrained_path: ~
|
17 |
+
freeze_encoder: False
|
18 |
+
in_channels: 9
|
19 |
+
cls_mode: False
|
20 |
+
shuffle_orders: True
|
21 |
+
drop_path: 0.0
|
22 |
+
upcast_attention: False
|
23 |
+
upcast_softmax: False
|
24 |
+
enc_depths: [3, 3, 3, 6, 16]
|
25 |
+
enc_channels: [32, 64, 128, 256, 384]
|
26 |
+
enc_num_head: [2, 4, 8, 16, 24]
|
27 |
+
enable_qknorm: True
|
28 |
+
layer_norm: False
|
29 |
+
res_linear: True
|
30 |
+
|
31 |
+
global_encoder:
|
32 |
+
__target__: michelangelo_encoder
|
33 |
+
pretrained_path: ~
|
34 |
+
freeze_encoder: False
|
35 |
+
device: cpu
|
36 |
+
dtype: float32
|
37 |
+
num_latents: 512
|
38 |
+
embed_dim: 64
|
39 |
+
point_feats: 3
|
40 |
+
num_freqs: 8
|
41 |
+
include_pi: False
|
42 |
+
heads: 8
|
43 |
+
width: 512
|
44 |
+
num_encoder_layers: 16
|
45 |
+
use_ln_post: True
|
46 |
+
init_scale: 0.25
|
47 |
+
qkv_bias: False
|
48 |
+
use_checkpoint: False
|
49 |
+
flash: True
|
50 |
+
supervision_type: sdf
|
51 |
+
query_method: False
|
52 |
+
token_num: 1024
|
UniRig/configs/skeleton/mixamo.yaml
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
parts_order: [body, hand]
|
2 |
+
|
3 |
+
parts:
|
4 |
+
body: [
|
5 |
+
mixamorig:Hips,
|
6 |
+
mixamorig:Spine,
|
7 |
+
mixamorig:Spine1,
|
8 |
+
mixamorig:Spine2,
|
9 |
+
mixamorig:Neck,
|
10 |
+
mixamorig:Head,
|
11 |
+
mixamorig:LeftShoulder,
|
12 |
+
mixamorig:LeftArm,
|
13 |
+
mixamorig:LeftForeArm,
|
14 |
+
mixamorig:LeftHand,
|
15 |
+
mixamorig:RightShoulder,
|
16 |
+
mixamorig:RightArm,
|
17 |
+
mixamorig:RightForeArm,
|
18 |
+
mixamorig:RightHand,
|
19 |
+
mixamorig:LeftUpLeg,
|
20 |
+
mixamorig:LeftLeg,
|
21 |
+
mixamorig:LeftFoot,
|
22 |
+
mixamorig:LeftToeBase,
|
23 |
+
mixamorig:RightUpLeg,
|
24 |
+
mixamorig:RightLeg,
|
25 |
+
mixamorig:RightFoot,
|
26 |
+
mixamorig:RightToeBase,
|
27 |
+
]
|
28 |
+
hand: [
|
29 |
+
mixamorig:LeftHandThumb1,
|
30 |
+
mixamorig:LeftHandThumb2,
|
31 |
+
mixamorig:LeftHandThumb3,
|
32 |
+
mixamorig:LeftHandIndex1,
|
33 |
+
mixamorig:LeftHandIndex2,
|
34 |
+
mixamorig:LeftHandIndex3,
|
35 |
+
mixamorig:LeftHandMiddle1,
|
36 |
+
mixamorig:LeftHandMiddle2,
|
37 |
+
mixamorig:LeftHandMiddle3,
|
38 |
+
mixamorig:LeftHandRing1,
|
39 |
+
mixamorig:LeftHandRing2,
|
40 |
+
mixamorig:LeftHandRing3,
|
41 |
+
mixamorig:LeftHandPinky1,
|
42 |
+
mixamorig:LeftHandPinky2,
|
43 |
+
mixamorig:LeftHandPinky3,
|
44 |
+
mixamorig:RightHandIndex1,
|
45 |
+
mixamorig:RightHandIndex2,
|
46 |
+
mixamorig:RightHandIndex3,
|
47 |
+
mixamorig:RightHandThumb1,
|
48 |
+
mixamorig:RightHandThumb2,
|
49 |
+
mixamorig:RightHandThumb3,
|
50 |
+
mixamorig:RightHandMiddle1,
|
51 |
+
mixamorig:RightHandMiddle2,
|
52 |
+
mixamorig:RightHandMiddle3,
|
53 |
+
mixamorig:RightHandRing1,
|
54 |
+
mixamorig:RightHandRing2,
|
55 |
+
mixamorig:RightHandRing3,
|
56 |
+
mixamorig:RightHandPinky1,
|
57 |
+
mixamorig:RightHandPinky2,
|
58 |
+
mixamorig:RightHandPinky3,
|
59 |
+
]
|
UniRig/configs/skeleton/vroid.yaml
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
parts_order: [body, hand]
|
2 |
+
|
3 |
+
parts:
|
4 |
+
body: [
|
5 |
+
J_Bip_C_Hips,
|
6 |
+
J_Bip_C_Spine,
|
7 |
+
J_Bip_C_Chest,
|
8 |
+
J_Bip_C_UpperChest,
|
9 |
+
J_Bip_C_Neck,
|
10 |
+
J_Bip_C_Head,
|
11 |
+
J_Bip_L_Shoulder,
|
12 |
+
J_Bip_L_UpperArm,
|
13 |
+
J_Bip_L_LowerArm,
|
14 |
+
J_Bip_L_Hand,
|
15 |
+
J_Bip_R_Shoulder,
|
16 |
+
J_Bip_R_UpperArm,
|
17 |
+
J_Bip_R_LowerArm,
|
18 |
+
J_Bip_R_Hand,
|
19 |
+
J_Bip_L_UpperLeg,
|
20 |
+
J_Bip_L_LowerLeg,
|
21 |
+
J_Bip_L_Foot,
|
22 |
+
J_Bip_L_ToeBase,
|
23 |
+
J_Bip_R_UpperLeg,
|
24 |
+
J_Bip_R_LowerLeg,
|
25 |
+
J_Bip_R_Foot,
|
26 |
+
J_Bip_R_ToeBase,
|
27 |
+
]
|
28 |
+
hand: [
|
29 |
+
J_Bip_L_Thumb1,
|
30 |
+
J_Bip_L_Thumb2,
|
31 |
+
J_Bip_L_Thumb3,
|
32 |
+
J_Bip_L_Index1,
|
33 |
+
J_Bip_L_Index2,
|
34 |
+
J_Bip_L_Index3,
|
35 |
+
J_Bip_L_Middle1,
|
36 |
+
J_Bip_L_Middle2,
|
37 |
+
J_Bip_L_Middle3,
|
38 |
+
J_Bip_L_Ring1,
|
39 |
+
J_Bip_L_Ring2,
|
40 |
+
J_Bip_L_Ring3,
|
41 |
+
J_Bip_L_Little1,
|
42 |
+
J_Bip_L_Little2,
|
43 |
+
J_Bip_L_Little3,
|
44 |
+
J_Bip_R_Index1,
|
45 |
+
J_Bip_R_Index2,
|
46 |
+
J_Bip_R_Index3,
|
47 |
+
J_Bip_R_Thumb1,
|
48 |
+
J_Bip_R_Thumb2,
|
49 |
+
J_Bip_R_Thumb3,
|
50 |
+
J_Bip_R_Middle1,
|
51 |
+
J_Bip_R_Middle2,
|
52 |
+
J_Bip_R_Middle3,
|
53 |
+
J_Bip_R_Ring1,
|
54 |
+
J_Bip_R_Ring2,
|
55 |
+
J_Bip_R_Ring3,
|
56 |
+
J_Bip_R_Little1,
|
57 |
+
J_Bip_R_Little2,
|
58 |
+
J_Bip_R_Little3,
|
59 |
+
]
|
UniRig/configs/system/ar_inference_articulationxl.yaml
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__target__: ar
|
2 |
+
val_interval: 1
|
3 |
+
generate_kwargs:
|
4 |
+
max_new_tokens: 2048
|
5 |
+
num_return_sequences: 1
|
6 |
+
num_beams: 15
|
7 |
+
do_sample: True
|
8 |
+
top_k: 5
|
9 |
+
top_p: 0.95
|
10 |
+
repetition_penalty: 3.0
|
11 |
+
temperature: 1.5 # must be a float
|
12 |
+
no_cls: False
|
13 |
+
assign_cls: articulationxl
|
14 |
+
use_dir_cls: False
|
UniRig/configs/system/skin.yaml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__target__: skin
|
2 |
+
val_interval: 1
|
3 |
+
val_start_from: 1
|
4 |
+
output_path: tmp_skin
|
5 |
+
record_res: True
|
UniRig/configs/task/quick_inference_skeleton_articulationxl_ar_256.yaml
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
mode: predict
|
2 |
+
debug: False
|
3 |
+
experiment_name: quick_inference_skeleton_articulationxl_ar_256
|
4 |
+
resume_from_checkpoint: experiments/skeleton/articulation-xl_quantization_256/model.ckpt
|
5 |
+
|
6 |
+
components:
|
7 |
+
data: quick_inference
|
8 |
+
tokenizer: tokenizer_parts_articulationxl_256
|
9 |
+
transform: inference_ar_transform
|
10 |
+
model: unirig_ar_350m_1024_81920_float32
|
11 |
+
system: ar_inference_articulationxl
|
12 |
+
data_name: raw_data.npz
|
13 |
+
|
14 |
+
writer:
|
15 |
+
__target__: ar
|
16 |
+
output_dir: ~ # export results into the same input folder
|
17 |
+
add_num: False
|
18 |
+
repeat: 1
|
19 |
+
export_npz: predict_skeleton
|
20 |
+
export_obj: skeleton
|
21 |
+
export_fbx: skeleton
|
22 |
+
# export_pc: pc
|
23 |
+
|
24 |
+
trainer:
|
25 |
+
max_epochs: 1
|
26 |
+
num_nodes: 1
|
27 |
+
devices: 1
|
28 |
+
precision: bf16-mixed
|
29 |
+
accelerator: gpu
|
30 |
+
strategy: auto
|
UniRig/configs/task/quick_inference_unirig_skin.yaml
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
mode: predict
|
2 |
+
debug: False
|
3 |
+
experiment_name: quick_inference_skin
|
4 |
+
resume_from_checkpoint: experiments/skin/articulation-xl/model.ckpt
|
5 |
+
|
6 |
+
components:
|
7 |
+
data: quick_inference
|
8 |
+
transform: inference_skin_transform
|
9 |
+
model: unirig_skin
|
10 |
+
system: skin
|
11 |
+
data_name: predict_skeleton.npz # capture data from ar phase
|
12 |
+
|
13 |
+
writer:
|
14 |
+
__target__: skin
|
15 |
+
output_dir: results
|
16 |
+
add_num: False
|
17 |
+
repeat: 1
|
18 |
+
save_name: predict
|
19 |
+
export_npz: predict_skin # this must be specified if textured results are required
|
20 |
+
export_fbx: result_fbx
|
21 |
+
|
22 |
+
trainer:
|
23 |
+
num_nodes: 1
|
24 |
+
devices: 1
|
25 |
+
precision: bf16-mixed
|
26 |
+
accelerator: gpu
|
27 |
+
strategy: auto
|
28 |
+
inference_mode: True
|
UniRig/configs/tokenizer/tokenizer_parts_articulationxl_256.yaml
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
method: tokenizer_part
|
2 |
+
num_discrete: 256
|
3 |
+
continuous_range: [-1, 1]
|
4 |
+
cls_token_id:
|
5 |
+
vroid: 0
|
6 |
+
mixamo: 1 # this is currently untrained, do not use it
|
7 |
+
articulationxl: 2
|
8 |
+
parts_token_id:
|
9 |
+
body: 0
|
10 |
+
hand: 1
|
11 |
+
order_config:
|
12 |
+
skeleton_path:
|
13 |
+
vroid: ./configs/skeleton/vroid.yaml
|
14 |
+
mixamo: ./configs/skeleton/mixamo.yaml
|
UniRig/configs/transform/inference_ar_transform.yaml
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
sampler_config: &sampler_config
|
2 |
+
method: mix
|
3 |
+
num_samples: 65536
|
4 |
+
vertex_samples: 8192
|
5 |
+
|
6 |
+
tail_config: &tail_config
|
7 |
+
copy_joint_to_tail: False # Be careful ! If tail is important, keep it False !!!
|
8 |
+
connect_tail_to_unique_son: True
|
9 |
+
|
10 |
+
order_config: &order_config
|
11 |
+
skeleton_path:
|
12 |
+
vroid: ./configs/skeleton/vroid.yaml
|
13 |
+
mixamo: ./configs/skeleton/mixamo.yaml
|
14 |
+
|
15 |
+
vertex_group_config: &vertex_group_config
|
16 |
+
|
17 |
+
validate_transform_config: &validate_transform_config
|
18 |
+
augment_config:
|
19 |
+
augment_affine_config:
|
20 |
+
normalize_into: [-1.0, 1.0]
|
21 |
+
random_scale_p: 0.0
|
22 |
+
random_scale: [1.0, 1.0]
|
23 |
+
random_shift_p: 0.0
|
24 |
+
random_shift: [0.0, 0.0]
|
25 |
+
tail_config: *tail_config
|
26 |
+
order_config: *order_config
|
27 |
+
vertex_group_config: *vertex_group_config
|
28 |
+
sampler_config: *sampler_config
|
29 |
+
|
30 |
+
predict_transform_config: *validate_transform_config
|
UniRig/configs/transform/inference_skin_transform.yaml
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
sampler_config: &sampler_config
|
2 |
+
method: mix
|
3 |
+
num_samples: 32768
|
4 |
+
vertex_samples: 8192
|
5 |
+
|
6 |
+
tail_config: &tail_config
|
7 |
+
copy_joint_to_tail: False # Be careful ! If tail is important, keep it False !!!
|
8 |
+
connect_tail_to_unique_son: True
|
9 |
+
|
10 |
+
order_config: &order_config
|
11 |
+
skeleton_path:
|
12 |
+
vroid: ./configs/skeleton/vroid.yaml
|
13 |
+
mixamo: ./configs/skeleton/mixamo.yaml
|
14 |
+
|
15 |
+
predict_transform_config:
|
16 |
+
augment_config:
|
17 |
+
augment_affine_config:
|
18 |
+
normalize_into: [-1.0, 1.0]
|
19 |
+
tail_config: *tail_config
|
20 |
+
order_config: *order_config
|
21 |
+
vertex_group_config:
|
22 |
+
names: ['voxel_skin']
|
23 |
+
kwargs:
|
24 |
+
voxel_skin:
|
25 |
+
grid: 196 # increase this for better results
|
26 |
+
alpha: 0.5
|
27 |
+
link_dis: 0.00001
|
28 |
+
grid_query: 7
|
29 |
+
vertex_query: 1
|
30 |
+
grid_weight: 3.0
|
31 |
+
# mode: exp
|
32 |
+
sampler_config: *sampler_config
|
UniRig/examples/bird.glb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bb59726f598ab4a4e4c431b9789317f5b2f4252a6fb57364d929f5e1ddd7b5bb
|
3 |
+
size 8032388
|
UniRig/examples/giraffe.glb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9a947cae00b169345802c08885c1e54313c6ce93c885bacf8e37e8a1f18f9e3b
|
3 |
+
size 6310044
|
UniRig/examples/skeleton/bird.fbx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:885d432850506ab673d509ae2481067cc623452d5910090e1e15f323b9f83fa2
|
3 |
+
size 401084
|
UniRig/examples/skeleton/giraffe.fbx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:73673a15c8103fbf9a6b39762768ae48e7c9404eab4850a60e4863ee400336fd
|
3 |
+
size 759180
|
UniRig/examples/skeleton/tira.fbx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:57ed6969266d642da2d8a579059462c7cd4a36c9bd7a8415236dcbea36607fee
|
3 |
+
size 1694668
|
UniRig/examples/skeleton/tripo_carrot.fbx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:50e08d70c7bdfa96eb842aed18564d8486a4622a474dbc0e0ef9304af1d4c6d3
|
3 |
+
size 1879420
|
UniRig/examples/tira.glb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8e5a282d0a99c61a8d2439496057b15b0c6ea02643c16da2dcbec85571157799
|
3 |
+
size 32346060
|
UniRig/examples/tripo_carrot.glb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3c00b7e6ff1e71a019a02128a5b37e8d2db259a402313ad9b4eaab4f183ae40b
|
3 |
+
size 9511824
|
UniRig/launch/inference/extract.sh
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# extract mesh
|
2 |
+
config="configs/data/quick_inference.yaml"
|
3 |
+
require_suffix="obj,fbx,FBX,dae,glb,gltf,vrm"
|
4 |
+
num_runs=1
|
5 |
+
force_override="false"
|
6 |
+
faces_target_count=50000
|
7 |
+
|
8 |
+
while [[ "$#" -gt 0 ]]; do
|
9 |
+
case $1 in
|
10 |
+
--config) config="$2"; shift ;;
|
11 |
+
--require_suffix) require_suffix="$2"; shift ;;
|
12 |
+
--num_runs) num_runs="$2"; shift ;;
|
13 |
+
--force_override) force_override="$2"; shift ;;
|
14 |
+
--faces_target_count) faces_target_count="$2"; shift ;;
|
15 |
+
--time) time="$2"; shift ;;
|
16 |
+
--input) input="$2"; shift ;;
|
17 |
+
--input_dir) input_dir="$2"; shift ;;
|
18 |
+
--output_dir) output_dir="$2"; shift ;;
|
19 |
+
*) echo "Unknown parameter: $1"; exit 1 ;;
|
20 |
+
esac
|
21 |
+
shift
|
22 |
+
done
|
23 |
+
|
24 |
+
# ensure psutil is installed for memory management
|
25 |
+
pip install psutil --quiet
|
26 |
+
if [ $? -ne 0 ]; then
|
27 |
+
echo "Warning: Failed to install psutil. Memory management may not work properly."
|
28 |
+
fi
|
29 |
+
|
30 |
+
# set the time for all processes to use
|
31 |
+
time=$(date "+%Y_%m_%d_%H_%M_%S")
|
32 |
+
|
33 |
+
for (( i=0; i<num_runs; i++ ))
|
34 |
+
do
|
35 |
+
cmd=" \
|
36 |
+
python -m src.data.extract \
|
37 |
+
--config=$config \
|
38 |
+
--require_suffix=$require_suffix \
|
39 |
+
--force_override=$force_override \
|
40 |
+
--num_runs=$num_runs \
|
41 |
+
--id=$i \
|
42 |
+
--time=$time \
|
43 |
+
--faces_target_count=$faces_target_count \
|
44 |
+
"
|
45 |
+
if [ -n "$input" ]; then
|
46 |
+
cmd="$cmd --input=$input"
|
47 |
+
fi
|
48 |
+
if [ -n "$input_dir" ]; then
|
49 |
+
cmd="$cmd --input_dir=$input_dir"
|
50 |
+
fi
|
51 |
+
if [ -n "$output_dir" ]; then
|
52 |
+
cmd="$cmd --output_dir=$output_dir"
|
53 |
+
fi
|
54 |
+
cmd="$cmd &"
|
55 |
+
eval $cmd
|
56 |
+
done
|
57 |
+
|
58 |
+
wait
|
59 |
+
|
60 |
+
echo "done"
|
UniRig/launch/inference/generate_skeleton.sh
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# generate skeleton
|
2 |
+
config="configs/data/quick_inference.yaml"
|
3 |
+
require_suffix="obj,fbx,FBX,dae,glb,gltf,vrm"
|
4 |
+
num_runs=1
|
5 |
+
force_override="false"
|
6 |
+
faces_target_count=50000
|
7 |
+
skeleton_task="configs/task/quick_inference_skeleton_articulationxl_ar_256.yaml"
|
8 |
+
add_root="false"
|
9 |
+
seed=12345
|
10 |
+
npz_dir="tmp"
|
11 |
+
|
12 |
+
while [[ "$#" -gt 0 ]]; do
|
13 |
+
case $1 in
|
14 |
+
--config) config="$2"; shift ;;
|
15 |
+
--require_suffix) require_suffix="$2"; shift ;;
|
16 |
+
--num_runs) num_runs="$2"; shift ;;
|
17 |
+
--force_override) force_override="$2"; shift ;;
|
18 |
+
--faces_target_count) faces_target_count="$2"; shift ;;
|
19 |
+
--skeleton_task) skeleton_task="$2"; shift ;;
|
20 |
+
--add_root) add_root="$2"; shift ;;
|
21 |
+
--seed) seed="$2"; shift ;;
|
22 |
+
--input) input="$2"; shift ;;
|
23 |
+
--input_dir) input_dir="$2"; shift ;;
|
24 |
+
--output_dir) output_dir="$2"; shift ;;
|
25 |
+
--output) output="$2"; shift ;;
|
26 |
+
*) echo "Unknown parameter: $1"; exit 1 ;;
|
27 |
+
esac
|
28 |
+
shift
|
29 |
+
done
|
30 |
+
|
31 |
+
# 1. extract mesh
|
32 |
+
cmd=" \
|
33 |
+
bash ./launch/inference/extract.sh \
|
34 |
+
--config $config \
|
35 |
+
--require_suffix $require_suffix \
|
36 |
+
--force_override $force_override \
|
37 |
+
--num_runs $num_runs \
|
38 |
+
--faces_target_count $faces_target_count \
|
39 |
+
"
|
40 |
+
if [ -n "$input" ]; then
|
41 |
+
cmd="$cmd --input $input"
|
42 |
+
fi
|
43 |
+
if [ -n "$input_dir" ]; then
|
44 |
+
cmd="$cmd --input_dir $input_dir"
|
45 |
+
fi
|
46 |
+
if [ -n "$npz_dir" ]; then
|
47 |
+
cmd="$cmd --output_dir $npz_dir"
|
48 |
+
fi
|
49 |
+
|
50 |
+
cmd="$cmd &"
|
51 |
+
eval $cmd
|
52 |
+
|
53 |
+
wait
|
54 |
+
|
55 |
+
# 2. inference skeleton
|
56 |
+
cmd="\
|
57 |
+
python run.py \
|
58 |
+
--task=$skeleton_task \
|
59 |
+
--seed=$seed \
|
60 |
+
"
|
61 |
+
if [ -n "$input" ]; then
|
62 |
+
cmd="$cmd --input=$input"
|
63 |
+
fi
|
64 |
+
if [ -n "$input_dir" ]; then
|
65 |
+
cmd="$cmd --input_dir=$input_dir"
|
66 |
+
fi
|
67 |
+
if [ -n "$output" ]; then
|
68 |
+
cmd="$cmd --output=$output"
|
69 |
+
fi
|
70 |
+
if [ -n "$output_dir" ]; then
|
71 |
+
cmd="$cmd --output_dir=$output_dir"
|
72 |
+
fi
|
73 |
+
if [ -n "$npz_dir" ]; then
|
74 |
+
cmd="$cmd --npz_dir=$npz_dir"
|
75 |
+
fi
|
76 |
+
|
77 |
+
eval $cmd
|
78 |
+
|
79 |
+
wait
|
80 |
+
|
81 |
+
echo "done"
|
UniRig/launch/inference/generate_skin.sh
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# generate skin
|
2 |
+
config="configs/data/quick_inference.yaml"
|
3 |
+
require_suffix="obj,fbx,FBX,dae,glb,gltf,vrm"
|
4 |
+
num_runs=1
|
5 |
+
force_override="true"
|
6 |
+
faces_target_count=50000
|
7 |
+
skin_task="configs/task/quick_inference_unirig_skin.yaml"
|
8 |
+
seed=12345
|
9 |
+
npz_dir="tmp"
|
10 |
+
data_name="raw_data.npz"
|
11 |
+
|
12 |
+
while [[ "$#" -gt 0 ]]; do
|
13 |
+
case $1 in
|
14 |
+
--config) config="$2"; shift ;;
|
15 |
+
--require_suffix) require_suffix="$2"; shift ;;
|
16 |
+
--num_runs) num_runs="$2"; shift ;;
|
17 |
+
--force_override) force_override="$2"; shift ;;
|
18 |
+
--faces_target_count) faces_target_count="$2"; shift ;;
|
19 |
+
--skin_task) skin_task="$2"; shift ;;
|
20 |
+
--seed) seed="$2"; shift ;;
|
21 |
+
--input) input="$2"; shift ;;
|
22 |
+
--input_dir) input_dir="$2"; shift ;;
|
23 |
+
--output_dir) output_dir="$2"; shift ;;
|
24 |
+
--output) output="$2"; shift ;;
|
25 |
+
--data_name) data_name="$2"; shift ;;
|
26 |
+
*) echo "Unknown parameter: $1"; exit 1 ;;
|
27 |
+
esac
|
28 |
+
shift
|
29 |
+
done
|
30 |
+
|
31 |
+
# 1. extract mesh
|
32 |
+
cmd=" \
|
33 |
+
bash ./launch/inference/extract.sh \
|
34 |
+
--config $config \
|
35 |
+
--require_suffix $require_suffix \
|
36 |
+
--force_override $force_override \
|
37 |
+
--num_runs $num_runs \
|
38 |
+
--faces_target_count $faces_target_count \
|
39 |
+
"
|
40 |
+
if [ -n "$input" ]; then
|
41 |
+
cmd="$cmd --input $input"
|
42 |
+
fi
|
43 |
+
if [ -n "$input_dir" ]; then
|
44 |
+
cmd="$cmd --input_dir $input_dir"
|
45 |
+
fi
|
46 |
+
if [ -n "$npz_dir" ]; then
|
47 |
+
cmd="$cmd --output_dir $npz_dir"
|
48 |
+
fi
|
49 |
+
|
50 |
+
cmd="$cmd &"
|
51 |
+
eval $cmd
|
52 |
+
|
53 |
+
wait
|
54 |
+
|
55 |
+
# 2. inference skin
|
56 |
+
cmd="\
|
57 |
+
python run.py \
|
58 |
+
--task=$skin_task \
|
59 |
+
--seed=$seed \
|
60 |
+
"
|
61 |
+
if [ -n "$input" ]; then
|
62 |
+
cmd="$cmd --input=$input"
|
63 |
+
fi
|
64 |
+
if [ -n "$input_dir" ]; then
|
65 |
+
cmd="$cmd --input_dir=$input_dir"
|
66 |
+
fi
|
67 |
+
if [ -n "$output" ]; then
|
68 |
+
cmd="$cmd --output=$output"
|
69 |
+
fi
|
70 |
+
if [ -n "$output_dir" ]; then
|
71 |
+
cmd="$cmd --output_dir=$output_dir"
|
72 |
+
fi
|
73 |
+
if [ -n "$npz_dir" ]; then
|
74 |
+
cmd="$cmd --npz_dir=$npz_dir"
|
75 |
+
fi
|
76 |
+
if [ -n "$data_name" ]; then
|
77 |
+
cmd="$cmd --data_name=$data_name"
|
78 |
+
fi
|
79 |
+
|
80 |
+
eval $cmd
|
81 |
+
|
82 |
+
wait
|
83 |
+
|
84 |
+
echo "done"
|
UniRig/launch/inference/merge.sh
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# merge texture
|
2 |
+
require_suffix="obj,fbx,FBX,dae,glb,gltf,vrm"
|
3 |
+
source=""
|
4 |
+
target=""
|
5 |
+
output=""
|
6 |
+
|
7 |
+
while [[ "$#" -gt 0 ]]; do
|
8 |
+
case $1 in
|
9 |
+
--require_suffix) require_suffix="$2"; shift ;;
|
10 |
+
--source) source="$2"; shift ;;
|
11 |
+
--target) target="$2"; shift ;;
|
12 |
+
--output) output="$2"; shift ;;
|
13 |
+
*) echo "Unknown parameter: $1"; exit 1 ;;
|
14 |
+
esac
|
15 |
+
shift
|
16 |
+
done
|
17 |
+
|
18 |
+
cmd=" \
|
19 |
+
python -m src.inference.merge \
|
20 |
+
--require_suffix=$require_suffix \
|
21 |
+
--num_runs=1 \
|
22 |
+
--id=0 \
|
23 |
+
--source=$source \
|
24 |
+
--target=$target \
|
25 |
+
--output=$output \
|
26 |
+
"
|
27 |
+
|
28 |
+
cmd="$cmd &"
|
29 |
+
eval $cmd
|
30 |
+
|
31 |
+
wait
|
32 |
+
|
33 |
+
echo "done"
|
UniRig/requirements.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformers
|
2 |
+
python-box
|
3 |
+
einops
|
4 |
+
omegaconf
|
5 |
+
pytorch_lightning
|
6 |
+
lightning
|
7 |
+
addict
|
8 |
+
timm
|
9 |
+
fast-simplification
|
10 |
+
bpy==4.2
|
11 |
+
flash_attn
|
12 |
+
trimesh
|
13 |
+
open3d
|
14 |
+
pyrender
|
15 |
+
huggingface_hub
|
UniRig/run.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import yaml
|
3 |
+
from box import Box
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
import lightning as L
|
7 |
+
from lightning.pytorch.callbacks import ModelCheckpoint, Callback
|
8 |
+
from typing import List
|
9 |
+
from math import ceil
|
10 |
+
import numpy as np
|
11 |
+
from lightning.pytorch.strategies import FSDPStrategy, DDPStrategy
|
12 |
+
from src.inference.download import download
|
13 |
+
|
14 |
+
from src.data.asset import Asset
|
15 |
+
from src.data.extract import get_files
|
16 |
+
from src.data.dataset import UniRigDatasetModule, DatasetConfig, ModelInput
|
17 |
+
from src.data.datapath import Datapath
|
18 |
+
from src.data.transform import TransformConfig
|
19 |
+
from src.tokenizer.spec import TokenizerConfig
|
20 |
+
from src.tokenizer.parse import get_tokenizer
|
21 |
+
from src.model.parse import get_model
|
22 |
+
from src.system.parse import get_system, get_writer
|
23 |
+
|
24 |
+
from tqdm import tqdm
|
25 |
+
import time
|
26 |
+
|
27 |
+
def load(task: str, path: str) -> Box:
|
28 |
+
if path.endswith('.yaml'):
|
29 |
+
path = path.removesuffix('.yaml')
|
30 |
+
path += '.yaml'
|
31 |
+
print(f"\033[92mload {task} config: {path}\033[0m")
|
32 |
+
return Box(yaml.safe_load(open(path, 'r')))
|
33 |
+
|
34 |
+
def nullable_string(val):
|
35 |
+
if not val:
|
36 |
+
return None
|
37 |
+
return val
|
38 |
+
|
39 |
+
if __name__ == "__main__":
|
40 |
+
torch.set_float32_matmul_precision('high')
|
41 |
+
|
42 |
+
parser = argparse.ArgumentParser()
|
43 |
+
parser.add_argument("--task", type=str, required=True)
|
44 |
+
parser.add_argument("--seed", type=int, required=False, default=123,
|
45 |
+
help="random seed")
|
46 |
+
parser.add_argument("--input", type=nullable_string, required=False, default=None,
|
47 |
+
help="a single input file or files splited by comma")
|
48 |
+
parser.add_argument("--input_dir", type=nullable_string, required=False, default=None,
|
49 |
+
help="input directory")
|
50 |
+
parser.add_argument("--output", type=nullable_string, required=False, default=None,
|
51 |
+
help="filename for a single output")
|
52 |
+
parser.add_argument("--output_dir", type=nullable_string, required=False, default=None,
|
53 |
+
help="output directory")
|
54 |
+
parser.add_argument("--npz_dir", type=nullable_string, required=False, default='tmp',
|
55 |
+
help="intermediate npz directory")
|
56 |
+
parser.add_argument("--cls", type=nullable_string, required=False, default=None,
|
57 |
+
help="class name")
|
58 |
+
parser.add_argument("--data_name", type=nullable_string, required=False, default=None,
|
59 |
+
help="npz filename from skeleton phase")
|
60 |
+
args = parser.parse_args()
|
61 |
+
|
62 |
+
L.seed_everything(args.seed, workers=True)
|
63 |
+
|
64 |
+
task = load('task', args.task)
|
65 |
+
mode = task.mode
|
66 |
+
assert mode in ['predict']
|
67 |
+
|
68 |
+
if args.input is not None or args.input_dir is not None:
|
69 |
+
assert args.output_dir is not None or args.output is not None, 'output or output_dir must be specified'
|
70 |
+
assert args.npz_dir is not None, 'npz_dir must be specified'
|
71 |
+
files = get_files(
|
72 |
+
data_name=task.components.data_name,
|
73 |
+
inputs=args.input,
|
74 |
+
input_dataset_dir=args.input_dir,
|
75 |
+
output_dataset_dir=args.npz_dir,
|
76 |
+
force_override=True,
|
77 |
+
warning=False,
|
78 |
+
)
|
79 |
+
files = [f[1] for f in files]
|
80 |
+
if len(files) > 1 and args.output is not None:
|
81 |
+
print("\033[92mwarning: output is specified, but multiple files are detected. Output will be written.\033[0m")
|
82 |
+
datapath = Datapath(files=files, cls=args.cls)
|
83 |
+
else:
|
84 |
+
datapath = None
|
85 |
+
|
86 |
+
data_config = load('data', os.path.join('configs/data', task.components.data))
|
87 |
+
transform_config = load('transform', os.path.join('configs/transform', task.components.transform))
|
88 |
+
|
89 |
+
# get tokenizer
|
90 |
+
tokenizer_config = task.components.get('tokenizer', None)
|
91 |
+
if tokenizer_config is not None:
|
92 |
+
tokenizer_config = load('tokenizer', os.path.join('configs/tokenizer', task.components.tokenizer))
|
93 |
+
tokenizer_config = TokenizerConfig.parse(config=tokenizer_config)
|
94 |
+
|
95 |
+
# get data name
|
96 |
+
data_name = task.components.get('data_name', 'raw_data.npz')
|
97 |
+
if args.data_name is not None:
|
98 |
+
data_name = args.data_name
|
99 |
+
|
100 |
+
# get predict dataset
|
101 |
+
predict_dataset_config = data_config.get('predict_dataset_config', None)
|
102 |
+
if predict_dataset_config is not None:
|
103 |
+
predict_dataset_config = DatasetConfig.parse(config=predict_dataset_config).split_by_cls()
|
104 |
+
|
105 |
+
# get predict transform
|
106 |
+
predict_transform_config = transform_config.get('predict_transform_config', None)
|
107 |
+
if predict_transform_config is not None:
|
108 |
+
predict_transform_config = TransformConfig.parse(config=predict_transform_config)
|
109 |
+
|
110 |
+
# get model
|
111 |
+
model_config = task.components.get('model', None)
|
112 |
+
if model_config is not None:
|
113 |
+
model_config = load('model', os.path.join('configs/model', model_config))
|
114 |
+
if tokenizer_config is not None:
|
115 |
+
tokenizer = get_tokenizer(config=tokenizer_config)
|
116 |
+
else:
|
117 |
+
tokenizer = None
|
118 |
+
model = get_model(tokenizer=tokenizer, **model_config)
|
119 |
+
else:
|
120 |
+
model = None
|
121 |
+
|
122 |
+
# set data
|
123 |
+
data = UniRigDatasetModule(
|
124 |
+
process_fn=None if model is None else model._process_fn,
|
125 |
+
predict_dataset_config=predict_dataset_config,
|
126 |
+
predict_transform_config=predict_transform_config,
|
127 |
+
tokenizer_config=tokenizer_config,
|
128 |
+
debug=False,
|
129 |
+
data_name=data_name,
|
130 |
+
datapath=datapath,
|
131 |
+
cls=args.cls,
|
132 |
+
)
|
133 |
+
|
134 |
+
# add call backs
|
135 |
+
callbacks = []
|
136 |
+
|
137 |
+
## get checkpoint callback
|
138 |
+
checkpoint_config = task.get('checkpoint', None)
|
139 |
+
if checkpoint_config is not None:
|
140 |
+
checkpoint_config['dirpath'] = os.path.join('experiments', task.experiment_name)
|
141 |
+
callbacks.append(ModelCheckpoint(**checkpoint_config))
|
142 |
+
|
143 |
+
## get writer callback
|
144 |
+
writer_config = task.get('writer', None)
|
145 |
+
if writer_config is not None:
|
146 |
+
assert predict_transform_config is not None, 'missing predict_transform_config in transform'
|
147 |
+
if args.output_dir is not None or args.output is not None:
|
148 |
+
if args.output is not None:
|
149 |
+
assert args.output.endswith('.fbx'), 'output must be .fbx'
|
150 |
+
writer_config['npz_dir'] = args.npz_dir
|
151 |
+
writer_config['output_dir'] = args.output_dir
|
152 |
+
writer_config['output_name'] = args.output
|
153 |
+
writer_config['user_mode'] = True
|
154 |
+
callbacks.append(get_writer(**writer_config, order_config=predict_transform_config.order_config))
|
155 |
+
|
156 |
+
# get trainer
|
157 |
+
trainer_config = task.get('trainer', {})
|
158 |
+
|
159 |
+
# get system
|
160 |
+
system_config = task.components.get('system', None)
|
161 |
+
if system_config is not None:
|
162 |
+
system_config = load('system', os.path.join('configs/system', system_config))
|
163 |
+
system = get_system(
|
164 |
+
**system_config,
|
165 |
+
model=model,
|
166 |
+
steps_per_epoch=1,
|
167 |
+
)
|
168 |
+
else:
|
169 |
+
system = None
|
170 |
+
|
171 |
+
logger = None
|
172 |
+
|
173 |
+
# set ckpt path
|
174 |
+
resume_from_checkpoint = task.get('resume_from_checkpoint', None)
|
175 |
+
resume_from_checkpoint = download(resume_from_checkpoint)
|
176 |
+
trainer = L.Trainer(
|
177 |
+
callbacks=callbacks,
|
178 |
+
logger=logger,
|
179 |
+
**trainer_config,
|
180 |
+
)
|
181 |
+
|
182 |
+
if mode == 'predict':
|
183 |
+
assert resume_from_checkpoint is not None, 'expect resume_from_checkpoint in task'
|
184 |
+
trainer.predict(system, datamodule=data, ckpt_path=resume_from_checkpoint, return_predictions=False)
|
185 |
+
else:
|
186 |
+
assert 0
|
UniRig/src/data/__init__.py
ADDED
File without changes
|
UniRig/src/data/asset.py
ADDED
@@ -0,0 +1,433 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import defaultdict
|
2 |
+
from dataclasses import dataclass
|
3 |
+
import numpy as np
|
4 |
+
from numpy import ndarray
|
5 |
+
|
6 |
+
from typing import Dict, Union, List, Tuple
|
7 |
+
|
8 |
+
from .order import Order
|
9 |
+
from .raw_data import RawData
|
10 |
+
from .exporter import Exporter
|
11 |
+
|
12 |
+
from ..tokenizer.spec import TokenizeInput
|
13 |
+
from .utils import linear_blend_skinning
|
14 |
+
|
15 |
+
import trimesh
|
16 |
+
|
17 |
+
|
18 |
+
@dataclass
|
19 |
+
class Asset(Exporter):
|
20 |
+
'''
|
21 |
+
Dataclass to handle data parsed from raw data.
|
22 |
+
'''
|
23 |
+
|
24 |
+
# data class
|
25 |
+
cls: str
|
26 |
+
|
27 |
+
# where is this asset from
|
28 |
+
path: str
|
29 |
+
|
30 |
+
# data file name
|
31 |
+
data_name: str
|
32 |
+
|
33 |
+
# vertices of the mesh, shape (N, 3), float32
|
34 |
+
vertices: ndarray
|
35 |
+
|
36 |
+
# normals of vertices, shape (N, 3), float32
|
37 |
+
vertex_normals: ndarray
|
38 |
+
|
39 |
+
# faces of mesh, shape (F, 3), face id starts from 0 to F-1, int64
|
40 |
+
faces: ndarray
|
41 |
+
|
42 |
+
# face normal of mesh, shape (F, 3), float32
|
43 |
+
face_normals: ndarray
|
44 |
+
|
45 |
+
# joints of bones, shape (J, 3), float32
|
46 |
+
joints: Union[ndarray, None]=None
|
47 |
+
|
48 |
+
# tails of joints, shape (J, 3), float32
|
49 |
+
tails: Union[ndarray, None]=None
|
50 |
+
|
51 |
+
# skinning of joints, shape (N, J), float32
|
52 |
+
skin: Union[ndarray, None]=None
|
53 |
+
|
54 |
+
# whether the joint has skin, bool
|
55 |
+
no_skin: Union[ndarray, None]=None
|
56 |
+
|
57 |
+
# vertex groups
|
58 |
+
vertex_groups: Union[Dict[str, ndarray], None]=None
|
59 |
+
|
60 |
+
# parents of joints, None represents no parent(a root joint)
|
61 |
+
# make sure parent[k] < k
|
62 |
+
parents: Union[List[Union[int, None]], None]=None
|
63 |
+
|
64 |
+
# names of joints
|
65 |
+
names: Union[List[str], None]=None
|
66 |
+
|
67 |
+
# sampled vertices, shape (N, 3)
|
68 |
+
sampled_vertices: Union[ndarray, None]=None
|
69 |
+
|
70 |
+
# sampled normals, shape (N, 3)
|
71 |
+
sampled_normals: Union[ndarray, None]=None
|
72 |
+
|
73 |
+
# sampled vertex groups, every vertex group should be (N, J)
|
74 |
+
sampled_vertex_groups: Union[Dict[str, ndarray], None]=None
|
75 |
+
|
76 |
+
# {id: part}, part==None -> a spring token
|
77 |
+
parts_bias: Union[Dict[int, Union[str, None]], None]=None
|
78 |
+
|
79 |
+
# local coordinate, shape (J, 4, 4)
|
80 |
+
matrix_local: Union[ndarray, None]=None
|
81 |
+
|
82 |
+
# pose matrix for skinning loss calculation, shape (J, 4, 4)
|
83 |
+
pose_matrix: Union[ndarray, None]=None
|
84 |
+
|
85 |
+
meta: Union[Dict[str, ...], None]=None
|
86 |
+
|
87 |
+
@property
|
88 |
+
def N(self):
|
89 |
+
'''
|
90 |
+
number of vertices
|
91 |
+
'''
|
92 |
+
return self.vertices.shape[0]
|
93 |
+
|
94 |
+
@property
|
95 |
+
def F(self):
|
96 |
+
'''
|
97 |
+
number of faces
|
98 |
+
'''
|
99 |
+
return self.faces.shape[0]
|
100 |
+
|
101 |
+
@property
|
102 |
+
def J(self):
|
103 |
+
'''
|
104 |
+
number of joints
|
105 |
+
'''
|
106 |
+
return self.joints.shape[0]
|
107 |
+
|
108 |
+
def get_matrix(self, matrix_basis: ndarray, matrix_local: Union[ndarray, None]=None):
|
109 |
+
'''
|
110 |
+
get matrix
|
111 |
+
|
112 |
+
matrix_basis: (J, 4, 4)
|
113 |
+
'''
|
114 |
+
if matrix_local is None:
|
115 |
+
assert self.joints is not None
|
116 |
+
matrix_local = self.matrix_local
|
117 |
+
if matrix_local is None:
|
118 |
+
matrix_local = np.zeros((self.J, 4, 4))
|
119 |
+
matrix_local[:, 0, 0] = 1.
|
120 |
+
matrix_local[:, 1, 1] = 1.
|
121 |
+
matrix_local[:, 2, 2] = 1.
|
122 |
+
matrix_local[:, 3, 3] = 1.
|
123 |
+
for i in range(self.J):
|
124 |
+
matrix_local[i, :3, 3] = self.joints[i]
|
125 |
+
|
126 |
+
matrix = np.zeros((self.J, 4, 4))
|
127 |
+
for i in range(self.J):
|
128 |
+
if i==0:
|
129 |
+
matrix[i] = matrix_local[i] @ matrix_basis[i]
|
130 |
+
else:
|
131 |
+
pid = self.parents[i]
|
132 |
+
matrix_parent = matrix[pid]
|
133 |
+
matrix_local_parent = matrix_local[pid]
|
134 |
+
|
135 |
+
matrix[i] = (
|
136 |
+
matrix_parent @
|
137 |
+
(np.linalg.inv(matrix_local_parent) @ matrix_local[i]) @
|
138 |
+
matrix_basis[i]
|
139 |
+
)
|
140 |
+
return matrix
|
141 |
+
|
142 |
+
def apply_matrix_basis(self, matrix_basis: ndarray):
|
143 |
+
'''
|
144 |
+
apply a pose to armature
|
145 |
+
|
146 |
+
matrix_basis: (J, 4, 4)
|
147 |
+
'''
|
148 |
+
matrix_local = self.matrix_local
|
149 |
+
if matrix_local is None:
|
150 |
+
matrix_local = np.zeros((self.J, 4, 4))
|
151 |
+
matrix_local[:, 0, 0] = 1.
|
152 |
+
matrix_local[:, 1, 1] = 1.
|
153 |
+
matrix_local[:, 2, 2] = 1.
|
154 |
+
matrix_local[:, 3, 3] = 1.
|
155 |
+
for i in range(self.J):
|
156 |
+
matrix_local[i, :3, 3] = self.joints[i].copy()
|
157 |
+
|
158 |
+
matrix = self.get_matrix(matrix_basis=matrix_basis, matrix_local=matrix_local)
|
159 |
+
self.joints = matrix[:, :3, 3].copy()
|
160 |
+
vertices = linear_blend_skinning(self.vertices, matrix_local, matrix, self.skin, pad=1, value=1.)
|
161 |
+
# update matrix_local
|
162 |
+
self.matrix_local = matrix.copy()
|
163 |
+
|
164 |
+
# change tails
|
165 |
+
if self.tails is not None:
|
166 |
+
t_skin = np.eye(self.J)
|
167 |
+
self.tails = linear_blend_skinning(self.tails, matrix_local, matrix, t_skin, pad=1, value=1.)
|
168 |
+
# in accordance with trimesh's normals
|
169 |
+
mesh = trimesh.Trimesh(vertices=vertices, faces=self.faces, process=False)
|
170 |
+
self.vertices = vertices
|
171 |
+
self.vertex_normals = mesh.vertex_normals.copy()
|
172 |
+
self.face_normals = mesh.face_normals.copy()
|
173 |
+
|
174 |
+
def set_order_by_names(self, new_names: List[str]):
|
175 |
+
assert len(new_names) == len(self.names)
|
176 |
+
name_to_id = {name: id for (id, name) in enumerate(self.names)}
|
177 |
+
new_name_to_id = {name: id for (id, name) in enumerate(new_names)}
|
178 |
+
perm = []
|
179 |
+
new_parents = []
|
180 |
+
for (new_id, name) in enumerate(new_names):
|
181 |
+
perm.append(name_to_id[name])
|
182 |
+
pid = self.parents[name_to_id[name]]
|
183 |
+
if new_id == 0:
|
184 |
+
assert pid is None, 'first bone is not root bone'
|
185 |
+
else:
|
186 |
+
pname = self.names[pid]
|
187 |
+
pid = new_name_to_id[pname]
|
188 |
+
assert pid < new_id, 'new order does not form a tree'
|
189 |
+
new_parents.append(pid)
|
190 |
+
|
191 |
+
if self.joints is not None:
|
192 |
+
self.joints = self.joints[perm]
|
193 |
+
self.parents = new_parents
|
194 |
+
if self.tails is not None:
|
195 |
+
self.tails = self.tails[perm]
|
196 |
+
if self.skin is not None:
|
197 |
+
self.skin = self.skin[:, perm]
|
198 |
+
if self.no_skin is not None:
|
199 |
+
self.no_skin = self.no_skin[perm]
|
200 |
+
if self.matrix_local is not None:
|
201 |
+
self.matrix_local = self.matrix_local[perm]
|
202 |
+
self.names = new_names
|
203 |
+
|
204 |
+
def set_order(self, order: Order):
|
205 |
+
if self.names is None or self.parents is None:
|
206 |
+
return
|
207 |
+
new_names, self.parts_bias = order.arrange_names(cls=self.cls, names=self.names, parents=self.parents)
|
208 |
+
self.set_order_by_names(new_names=new_names)
|
209 |
+
|
210 |
+
def collapse(self, keep: List[str]):
|
211 |
+
dsu = [i for i in range(self.J)]
|
212 |
+
|
213 |
+
def find(x: int) -> int:
|
214 |
+
if dsu[x] == x:
|
215 |
+
return x
|
216 |
+
y = find(dsu[x])
|
217 |
+
dsu[x] = y
|
218 |
+
return y
|
219 |
+
|
220 |
+
def merge(x: int, y: int):
|
221 |
+
dsu[find(x)] = find(y)
|
222 |
+
|
223 |
+
if self.tails is not None:
|
224 |
+
new_tails = self.tails.copy()
|
225 |
+
else:
|
226 |
+
new_tails = None
|
227 |
+
if self.skin is not None:
|
228 |
+
new_skin = self.skin.copy()
|
229 |
+
else:
|
230 |
+
new_skin = None
|
231 |
+
|
232 |
+
if self.no_skin is not None:
|
233 |
+
new_no_skin = self.no_skin.copy()
|
234 |
+
else:
|
235 |
+
new_no_skin = None
|
236 |
+
|
237 |
+
if self.matrix_local is not None:
|
238 |
+
matrix_local = self.matrix_local.copy()
|
239 |
+
else:
|
240 |
+
matrix_local = None
|
241 |
+
new_names = []
|
242 |
+
new_parents = []
|
243 |
+
perm = []
|
244 |
+
new_name_to_id = {}
|
245 |
+
tot = 0
|
246 |
+
for (i, name) in enumerate(self.names):
|
247 |
+
if name in keep:
|
248 |
+
new_names.append(name)
|
249 |
+
new_name_to_id[name] = tot
|
250 |
+
tot += 1
|
251 |
+
perm.append(i)
|
252 |
+
pid = self.parents[i]
|
253 |
+
if pid is None:
|
254 |
+
new_parents.append(None)
|
255 |
+
else:
|
256 |
+
pid = find(pid)
|
257 |
+
new_parents.append(new_name_to_id[self.names[pid]])
|
258 |
+
continue
|
259 |
+
assert i != 0, 'cannot remove root'
|
260 |
+
id = find(i)
|
261 |
+
pid = find(self.parents[id])
|
262 |
+
# be careful !
|
263 |
+
# do not copy tail here because you dont know which child to inherit from
|
264 |
+
if new_skin is not None:
|
265 |
+
new_skin[:, pid] += new_skin[:, id]
|
266 |
+
if new_no_skin is not None:
|
267 |
+
new_no_skin[pid] &= new_no_skin[id]
|
268 |
+
merge(id, pid)
|
269 |
+
|
270 |
+
if new_tails is not None:
|
271 |
+
new_tails = new_tails[perm]
|
272 |
+
if new_skin is not None:
|
273 |
+
new_skin = new_skin[:, perm]
|
274 |
+
if new_no_skin is not None:
|
275 |
+
new_no_skin = new_no_skin[perm]
|
276 |
+
if matrix_local is not None:
|
277 |
+
matrix_local = matrix_local[perm]
|
278 |
+
|
279 |
+
if self.joints is not None:
|
280 |
+
self.joints = self.joints[perm]
|
281 |
+
self.parents = new_parents
|
282 |
+
self.tails = new_tails
|
283 |
+
self.skin = new_skin
|
284 |
+
self.no_skin = new_no_skin
|
285 |
+
self.names = new_names
|
286 |
+
self.matrix_local = matrix_local
|
287 |
+
|
288 |
+
@staticmethod
|
289 |
+
def from_raw_data(
|
290 |
+
raw_data: RawData,
|
291 |
+
cls: str,
|
292 |
+
path: str,
|
293 |
+
data_name: str,
|
294 |
+
) -> 'Asset':
|
295 |
+
'''
|
296 |
+
Return an asset initialized from raw data and do transform.
|
297 |
+
'''
|
298 |
+
return Asset(
|
299 |
+
cls=cls,
|
300 |
+
path=path,
|
301 |
+
data_name=data_name,
|
302 |
+
vertices=raw_data.vertices,
|
303 |
+
vertex_normals=raw_data.vertex_normals,
|
304 |
+
faces=raw_data.faces,
|
305 |
+
face_normals=raw_data.face_normals,
|
306 |
+
joints=raw_data.joints,
|
307 |
+
tails=raw_data.tails,
|
308 |
+
skin=raw_data.skin,
|
309 |
+
no_skin=raw_data.no_skin,
|
310 |
+
parents=raw_data.parents,
|
311 |
+
names=raw_data.names,
|
312 |
+
matrix_local=raw_data.matrix_local,
|
313 |
+
meta={},
|
314 |
+
)
|
315 |
+
|
316 |
+
def get_tokenize_input(self) -> TokenizeInput:
|
317 |
+
children = defaultdict(list)
|
318 |
+
|
319 |
+
for (id, p) in enumerate(self.parents):
|
320 |
+
if p is not None:
|
321 |
+
children[p].append(id)
|
322 |
+
bones = []
|
323 |
+
branch = []
|
324 |
+
is_leaf = []
|
325 |
+
last = None
|
326 |
+
for i in range(self.J):
|
327 |
+
is_leaf.append(len(children[i])==0)
|
328 |
+
if i == 0:
|
329 |
+
bones.append(np.concatenate([self.joints[i], self.joints[i]]))
|
330 |
+
branch.append(False)
|
331 |
+
else:
|
332 |
+
pid = self.parents[i]
|
333 |
+
bones.append(np.concatenate([self.joints[pid], self.joints[i]]))
|
334 |
+
branch.append(pid!=last)
|
335 |
+
last = i
|
336 |
+
bones = np.stack(bones)
|
337 |
+
branch = np.array(branch, dtype=bool)
|
338 |
+
is_leaf = np.array(is_leaf, dtype=bool)
|
339 |
+
return TokenizeInput(
|
340 |
+
bones=bones,
|
341 |
+
tails=self.tails,
|
342 |
+
branch=branch,
|
343 |
+
is_leaf=is_leaf,
|
344 |
+
no_skin=self.no_skin,
|
345 |
+
cls=self.cls,
|
346 |
+
parts_bias=self.parts_bias,
|
347 |
+
)
|
348 |
+
|
349 |
+
def export_pc(self, path: str, with_normal: bool=True, normal_size=0.01):
|
350 |
+
'''
|
351 |
+
export point cloud
|
352 |
+
'''
|
353 |
+
vertices = self.vertices
|
354 |
+
normals = self.vertex_normals
|
355 |
+
if self.sampled_vertices is not None:
|
356 |
+
vertices = self.sampled_vertices
|
357 |
+
normals = self.sampled_normals
|
358 |
+
if with_normal == False:
|
359 |
+
normals = None
|
360 |
+
self._export_pc(vertices=vertices, path=path, vertex_normals=normals, normal_size=normal_size)
|
361 |
+
|
362 |
+
def export_mesh(self, path: str):
|
363 |
+
'''
|
364 |
+
export mesh
|
365 |
+
'''
|
366 |
+
self._export_mesh(vertices=self.vertices, faces=self.faces, path=path)
|
367 |
+
|
368 |
+
def export_skeleton(self, path: str):
|
369 |
+
'''
|
370 |
+
export spring
|
371 |
+
'''
|
372 |
+
self._export_skeleton(joints=self.joints, parents=self.parents, path=path)
|
373 |
+
|
374 |
+
def export_skeleton_sequence(self, path: str):
|
375 |
+
'''
|
376 |
+
export spring
|
377 |
+
'''
|
378 |
+
self._export_skeleton_sequence(joints=self.joints, parents=self.parents, path=path)
|
379 |
+
|
380 |
+
def export_fbx(
|
381 |
+
self,
|
382 |
+
path: str,
|
383 |
+
vertex_group_name: str,
|
384 |
+
extrude_size: float=0.03,
|
385 |
+
group_per_vertex: int=-1,
|
386 |
+
add_root: bool=False,
|
387 |
+
do_not_normalize: bool=False,
|
388 |
+
use_extrude_bone: bool=True,
|
389 |
+
use_connect_unique_child: bool=True,
|
390 |
+
extrude_from_parent: bool=True,
|
391 |
+
use_tail: bool=False,
|
392 |
+
use_origin: bool=False,
|
393 |
+
):
|
394 |
+
'''
|
395 |
+
export the whole model with skining
|
396 |
+
'''
|
397 |
+
self._export_fbx(
|
398 |
+
path=path,
|
399 |
+
vertices=self.vertices if use_origin else self.sampled_vertices,
|
400 |
+
joints=self.joints,
|
401 |
+
skin=self.sampled_vertex_groups[vertex_group_name],
|
402 |
+
parents=self.parents,
|
403 |
+
names=self.names,
|
404 |
+
faces=self.faces if use_origin else None,
|
405 |
+
extrude_size=extrude_size,
|
406 |
+
group_per_vertex=group_per_vertex,
|
407 |
+
add_root=add_root,
|
408 |
+
do_not_normalize=do_not_normalize,
|
409 |
+
use_extrude_bone=use_extrude_bone,
|
410 |
+
use_connect_unique_child=use_connect_unique_child,
|
411 |
+
extrude_from_parent=extrude_from_parent,
|
412 |
+
tails=self.tails if use_tail else None,
|
413 |
+
)
|
414 |
+
|
415 |
+
def export_render(self, path: str, resolution: Tuple[int, int]=[256, 256], use_tail: bool=False):
|
416 |
+
if use_tail:
|
417 |
+
assert self.tails is not None
|
418 |
+
self._export_render(
|
419 |
+
path=path,
|
420 |
+
vertices=self.vertices,
|
421 |
+
faces=self.faces,
|
422 |
+
bones=np.concatenate([self.joints, self.tails], axis=-1),
|
423 |
+
resolution=resolution,
|
424 |
+
)
|
425 |
+
else:
|
426 |
+
pjoints = self.joints[self.parents[1:]]
|
427 |
+
self._export_render(
|
428 |
+
path=path,
|
429 |
+
vertices=self.vertices,
|
430 |
+
faces=self.faces,
|
431 |
+
bones=np.concatenate([pjoints, self.joints[1:]], axis=-1),
|
432 |
+
resolution=resolution,
|
433 |
+
)
|
UniRig/src/data/augment.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Tuple, Union, List, Dict
|
3 |
+
from numpy import ndarray
|
4 |
+
import numpy as np
|
5 |
+
from abc import ABC, abstractmethod
|
6 |
+
from scipy.spatial.transform import Rotation as R
|
7 |
+
|
8 |
+
from .spec import ConfigSpec
|
9 |
+
from .asset import Asset
|
10 |
+
from .utils import axis_angle_to_matrix
|
11 |
+
|
12 |
+
@dataclass(frozen=True)
|
13 |
+
class AugmentAffineConfig(ConfigSpec):
|
14 |
+
# final normalization cube
|
15 |
+
normalize_into: Tuple[float, float]
|
16 |
+
|
17 |
+
# randomly scale coordinates with probability p
|
18 |
+
random_scale_p: float
|
19 |
+
|
20 |
+
# scale range (lower, upper)
|
21 |
+
random_scale: Tuple[float, float]
|
22 |
+
|
23 |
+
# randomly shift coordinates with probability p
|
24 |
+
random_shift_p: float
|
25 |
+
|
26 |
+
# shift range (lower, upper)
|
27 |
+
random_shift: Tuple[float, float]
|
28 |
+
|
29 |
+
@classmethod
|
30 |
+
def parse(cls, config) -> Union['AugmentAffineConfig', None]:
|
31 |
+
if config is None:
|
32 |
+
return None
|
33 |
+
cls.check_keys(config)
|
34 |
+
return AugmentAffineConfig(
|
35 |
+
normalize_into=config.normalize_into,
|
36 |
+
random_scale_p=config.get('random_scale_p', 0.),
|
37 |
+
random_scale=config.get('random_scale', [1., 1.]),
|
38 |
+
random_shift_p=config.get('random_shift_p', 0.),
|
39 |
+
random_shift=config.get('random_shift', [0., 0.]),
|
40 |
+
)
|
41 |
+
|
42 |
+
@dataclass(frozen=True)
|
43 |
+
class AugmentConfig(ConfigSpec):
|
44 |
+
'''
|
45 |
+
Config to handle final easy augmentation of vertices, normals and bones before sampling.
|
46 |
+
'''
|
47 |
+
augment_affine_config: Union[AugmentAffineConfig, None]
|
48 |
+
|
49 |
+
@classmethod
|
50 |
+
def parse(cls, config) -> 'AugmentConfig':
|
51 |
+
cls.check_keys(config)
|
52 |
+
return AugmentConfig(
|
53 |
+
augment_affine_config=AugmentAffineConfig.parse(config.get('augment_affine_config', None)),
|
54 |
+
)
|
55 |
+
|
56 |
+
class Augment(ABC):
|
57 |
+
'''
|
58 |
+
Abstract class for augmentation
|
59 |
+
'''
|
60 |
+
def __init__(self):
|
61 |
+
pass
|
62 |
+
|
63 |
+
@abstractmethod
|
64 |
+
def transform(self, asset: Asset, **kwargs):
|
65 |
+
pass
|
66 |
+
|
67 |
+
@abstractmethod
|
68 |
+
def inverse(self, asset: Asset):
|
69 |
+
pass
|
70 |
+
|
71 |
+
class AugmentAffine(Augment):
|
72 |
+
|
73 |
+
def __init__(self, config: AugmentAffineConfig):
|
74 |
+
super().__init__()
|
75 |
+
self.config = config
|
76 |
+
|
77 |
+
def _apply(self, v: ndarray, trans: ndarray) -> ndarray:
|
78 |
+
return np.matmul(v, trans[:3, :3].transpose()) + trans[:3, 3]
|
79 |
+
|
80 |
+
def transform(self, asset: Asset, **kwargs):
|
81 |
+
bound_min = asset.vertices.min(axis=0)
|
82 |
+
bound_max = asset.vertices.max(axis=0)
|
83 |
+
if asset.joints is not None:
|
84 |
+
joints_bound_min = asset.joints.min(axis=0)
|
85 |
+
joints_bound_max = asset.joints.max(axis=0)
|
86 |
+
bound_min = np.minimum(bound_min, joints_bound_min)
|
87 |
+
bound_max = np.maximum(bound_max, joints_bound_max)
|
88 |
+
|
89 |
+
trans_vertex = np.eye(4, dtype=np.float32)
|
90 |
+
|
91 |
+
trans_vertex = _trans_to_m(-(bound_max + bound_min)/2) @ trans_vertex
|
92 |
+
|
93 |
+
# scale into the cube
|
94 |
+
normalize_into = self.config.normalize_into
|
95 |
+
scale = np.max((bound_max - bound_min) / (normalize_into[1] - normalize_into[0]))
|
96 |
+
trans_vertex = _scale_to_m(1. / scale) @ trans_vertex
|
97 |
+
|
98 |
+
bias = (normalize_into[0] + normalize_into[1]) / 2
|
99 |
+
trans_vertex = _trans_to_m(np.array([bias, bias, bias], dtype=np.float32)) @ trans_vertex
|
100 |
+
|
101 |
+
if np.random.rand() < self.config.random_scale_p:
|
102 |
+
scale = _scale_to_m(np.random.uniform(self.config.random_scale[0], self.config.random_scale[1]))
|
103 |
+
trans_vertex = scale @ trans_vertex
|
104 |
+
|
105 |
+
if np.random.rand() < self.config.random_shift_p:
|
106 |
+
l, r = self.config.random_shift
|
107 |
+
shift = _trans_to_m(np.array([np.random.uniform(l, r), np.random.uniform(l, r), np.random.uniform(l, r)]), dtype=np.float32)
|
108 |
+
trans_vertex = shift @ trans_vertex
|
109 |
+
|
110 |
+
asset.vertices = self._apply(asset.vertices, trans_vertex)
|
111 |
+
# do not affect scale in matrix
|
112 |
+
if asset.matrix_local is not None:
|
113 |
+
asset.matrix_local[:, :, 3:4] = trans_vertex @ asset.matrix_local[:, :, 3:4]
|
114 |
+
if asset.pose_matrix is not None:
|
115 |
+
asset.pose_matrix[:, :, 3:4] = trans_vertex @ asset.pose_matrix[:, :, 3:4]
|
116 |
+
# do not affect normal here
|
117 |
+
if asset.joints is not None:
|
118 |
+
asset.joints = self._apply(asset.joints, trans_vertex)
|
119 |
+
if asset.tails is not None:
|
120 |
+
asset.tails = self._apply(asset.tails, trans_vertex)
|
121 |
+
|
122 |
+
self.trans_vertex = trans_vertex
|
123 |
+
|
124 |
+
def inverse(self, asset: Asset):
|
125 |
+
m = np.linalg.inv(self.trans_vertex)
|
126 |
+
asset.vertices = self._apply(asset.vertices, m)
|
127 |
+
if asset.joints is not None:
|
128 |
+
asset.joints = self._apply(asset.joints, m)
|
129 |
+
if asset.tails is not None:
|
130 |
+
asset.tails = self._apply(asset.tails, m)
|
131 |
+
|
132 |
+
def _trans_to_m(v: ndarray):
|
133 |
+
m = np.eye(4, dtype=np.float32)
|
134 |
+
m[0:3, 3] = v
|
135 |
+
return m
|
136 |
+
|
137 |
+
def _scale_to_m(r: ndarray):
|
138 |
+
m = np.zeros((4, 4), dtype=np.float32)
|
139 |
+
m[0, 0] = r
|
140 |
+
m[1, 1] = r
|
141 |
+
m[2, 2] = r
|
142 |
+
m[3, 3] = 1.
|
143 |
+
return m
|
144 |
+
|
145 |
+
def get_augments(config: AugmentConfig) -> Tuple[List[Augment], List[Augment]]:
|
146 |
+
first_augments = [] # augments before sample
|
147 |
+
second_augments = [] # augments after sample
|
148 |
+
augment_affine_config = config.augment_affine_config
|
149 |
+
|
150 |
+
if augment_affine_config is not None:
|
151 |
+
second_augments.append(AugmentAffine(config=augment_affine_config))
|
152 |
+
return first_augments, second_augments
|
UniRig/src/data/datapath.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from copy import deepcopy
|
2 |
+
from collections import defaultdict
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from typing import Dict, Union, Tuple, List
|
5 |
+
import numpy as np
|
6 |
+
from numpy import ndarray
|
7 |
+
import os
|
8 |
+
from random import shuffle
|
9 |
+
from box import Box
|
10 |
+
from torch.onnx.symbolic_opset11 import index_copy
|
11 |
+
|
12 |
+
from .spec import ConfigSpec
|
13 |
+
|
14 |
+
@dataclass
|
15 |
+
class DatapathConfig(ConfigSpec):
|
16 |
+
'''
|
17 |
+
Config to handle input data paths.
|
18 |
+
'''
|
19 |
+
# root
|
20 |
+
input_dataset_dir: str
|
21 |
+
|
22 |
+
# use proportion data sampling
|
23 |
+
use_prob: bool
|
24 |
+
|
25 |
+
# cls: [(path_1, p_1), ...]
|
26 |
+
data_path: Dict[str, List[Tuple[str, float]]]
|
27 |
+
|
28 |
+
# how many files to return when using data sampling
|
29 |
+
num_files: Union[int, None]
|
30 |
+
|
31 |
+
@classmethod
|
32 |
+
def from_args(cls, **kwargs) -> 'DatapathConfig':
|
33 |
+
'''
|
34 |
+
Make a temporary datapath from user inputs.
|
35 |
+
'''
|
36 |
+
input = kwargs.get('input', None)
|
37 |
+
output = kwargs.get('output', None)
|
38 |
+
recursive = kwargs.get('recursive', False)
|
39 |
+
|
40 |
+
|
41 |
+
@classmethod
|
42 |
+
def parse(cls, config) -> 'DatapathConfig':
|
43 |
+
cls.check_keys(config)
|
44 |
+
return DatapathConfig(
|
45 |
+
input_dataset_dir=config.input_dataset_dir,
|
46 |
+
use_prob=config.get('use_prob', True),
|
47 |
+
data_path=config.data_path,
|
48 |
+
num_files=config.get('num_files', None),
|
49 |
+
)
|
50 |
+
|
51 |
+
def split_by_cls(self) -> Dict[str, 'DatapathConfig']:
|
52 |
+
res: Dict[str, DatapathConfig] = {}
|
53 |
+
for cls in self.data_path:
|
54 |
+
res[cls] = deepcopy(self)
|
55 |
+
res[cls].data_path = {cls: self.data_path[cls]}
|
56 |
+
return res
|
57 |
+
|
58 |
+
class Datapath():
|
59 |
+
def __init__(
|
60 |
+
self,
|
61 |
+
config: Union[DatapathConfig, None]=None,
|
62 |
+
files: Union[List[str], None]=None,
|
63 |
+
cls: Union[str, None]=None,
|
64 |
+
):
|
65 |
+
if config is not None:
|
66 |
+
self.config = config
|
67 |
+
self.file_list = []
|
68 |
+
cls_probs_first = []
|
69 |
+
cls_first = []
|
70 |
+
|
71 |
+
self.files_by_class: Dict[str, List[Dict]] = defaultdict(list)
|
72 |
+
self.class_positions: Dict[str, List[int]] = defaultdict(list)
|
73 |
+
self.cls_probs_second: Dict[str, ndarray] = defaultdict(List)
|
74 |
+
|
75 |
+
for cls in self.config.data_path:
|
76 |
+
prob = 0.
|
77 |
+
probs_second = []
|
78 |
+
for (path, p) in self.config.data_path[cls]:
|
79 |
+
prob += p
|
80 |
+
probs_second.append(p)
|
81 |
+
with open(path, 'r') as f:
|
82 |
+
file_items = []
|
83 |
+
missing = 0
|
84 |
+
for l in f.readlines():
|
85 |
+
raw_data_path = os.path.join(self.config.input_dataset_dir, l.strip(), 'raw_data.npz')
|
86 |
+
if not os.path.exists(raw_data_path):
|
87 |
+
missing += 1
|
88 |
+
continue
|
89 |
+
file_items.append({
|
90 |
+
'cls': cls,
|
91 |
+
'path': os.path.join(self.config.input_dataset_dir, l.strip()),
|
92 |
+
'prob': p
|
93 |
+
})
|
94 |
+
assert len(file_items) > 0, f"files in {path} are all missing! root: {self.config.input_dataset_dir}"
|
95 |
+
if missing > 0:
|
96 |
+
print(f"\033[31m{cls}: {missing} missing files\033[0m")
|
97 |
+
self.files_by_class[cls].append(file_items)
|
98 |
+
self.class_positions[cls].append(0)
|
99 |
+
self.file_list.extend(file_items)
|
100 |
+
probs_second = np.array(probs_second)
|
101 |
+
self.cls_probs_second[cls] = probs_second / probs_second.sum()
|
102 |
+
cls_first.append(cls)
|
103 |
+
cls_probs_first.append(prob)
|
104 |
+
cls_probs_first = np.array(cls_probs_first)
|
105 |
+
self.cls_first: List[str] = cls_first
|
106 |
+
self.cls_probs_first: Dict[str, List[float]] = cls_probs_first / cls_probs_first.sum()
|
107 |
+
elif files is not None:
|
108 |
+
if cls is None:
|
109 |
+
cls = 'inference'
|
110 |
+
self.file_list = [{'cls': cls, 'path': file} for file in files]
|
111 |
+
cls_probs_first = np.array([1.])
|
112 |
+
cls_first = []
|
113 |
+
|
114 |
+
self.files_by_class: Dict[str, List[Dict]] = {cls: self.file_list.copy()}
|
115 |
+
self.class_positions: Dict[str, List[int]] = {cls: [0]}
|
116 |
+
self.cls_probs_second: Dict[str, ndarray] = {cls: np.array([1.])}
|
117 |
+
self.config = Box({'use_prob': False})
|
118 |
+
else:
|
119 |
+
assert(0)
|
120 |
+
|
121 |
+
def __len__(self):
|
122 |
+
if self.config.use_prob:
|
123 |
+
assert self.config.num_files is not None, 'num_files is not specified'
|
124 |
+
return self.config.num_files
|
125 |
+
return len(self.file_list)
|
126 |
+
|
127 |
+
def __getitem__(self, index) -> Tuple[str, str]:
|
128 |
+
if self.config.use_prob:
|
129 |
+
# first sample a class
|
130 |
+
cls = np.random.choice(self.cls_first, p=self.cls_probs_first)
|
131 |
+
|
132 |
+
# second sample in this class
|
133 |
+
idx = np.random.choice(len(self.files_by_class[cls]), p=self.cls_probs_second[cls])
|
134 |
+
|
135 |
+
# get the current position
|
136 |
+
pos = self.class_positions[cls][idx]
|
137 |
+
files = self.files_by_class[cls][idx]
|
138 |
+
|
139 |
+
# get the item andd update position
|
140 |
+
item = files[pos]
|
141 |
+
self.class_positions[cls][idx] = (pos + 1) % len(files)
|
142 |
+
if (pos + 1) % len(files) == 0:
|
143 |
+
shuffle(self.files_by_class[cls][idx])
|
144 |
+
else:
|
145 |
+
item = self.file_list[index]
|
146 |
+
return (item['cls'], item['path'])
|
147 |
+
|
148 |
+
def get_data(self) -> List[Tuple[str, str]]:
|
149 |
+
return [self[i] for i in range(len(self))]
|
UniRig/src/data/dataset.py
ADDED
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from copy import deepcopy
|
2 |
+
from dataclasses import dataclass
|
3 |
+
import lightning.pytorch as pl
|
4 |
+
# from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
|
5 |
+
import torch
|
6 |
+
from torch import LongTensor
|
7 |
+
from torch.utils import data
|
8 |
+
from torch.utils.data import DataLoader, Dataset
|
9 |
+
from typing import Dict, List, Tuple, Union, Callable
|
10 |
+
import os
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
from .raw_data import RawData
|
14 |
+
from .asset import Asset
|
15 |
+
from .transform import TransformConfig, transform_asset
|
16 |
+
from .datapath import DatapathConfig, Datapath
|
17 |
+
from .spec import ConfigSpec
|
18 |
+
|
19 |
+
from ..tokenizer.spec import TokenizerSpec, TokenizerConfig
|
20 |
+
from ..tokenizer.parse import get_tokenizer
|
21 |
+
from ..model.spec import ModelInput
|
22 |
+
|
23 |
+
@dataclass
|
24 |
+
class DatasetConfig(ConfigSpec):
|
25 |
+
'''
|
26 |
+
Config to handle dataset format.
|
27 |
+
'''
|
28 |
+
# shuffle dataset
|
29 |
+
shuffle: bool
|
30 |
+
|
31 |
+
# batch size
|
32 |
+
batch_size: int
|
33 |
+
|
34 |
+
# number of workers
|
35 |
+
num_workers: int
|
36 |
+
|
37 |
+
# datapath
|
38 |
+
datapath_config: DatapathConfig
|
39 |
+
|
40 |
+
# use pin memory
|
41 |
+
pin_memory: bool = True
|
42 |
+
|
43 |
+
# use persistent workers
|
44 |
+
persistent_workers: bool = True
|
45 |
+
|
46 |
+
@classmethod
|
47 |
+
def parse(cls, config) -> 'DatapathConfig':
|
48 |
+
cls.check_keys(config)
|
49 |
+
return DatasetConfig(
|
50 |
+
shuffle=config.shuffle,
|
51 |
+
batch_size=config.batch_size,
|
52 |
+
num_workers=config.num_workers,
|
53 |
+
pin_memory=config.pin_memory,
|
54 |
+
persistent_workers=config.persistent_workers,
|
55 |
+
datapath_config=DatapathConfig.parse(config.datapath_config),
|
56 |
+
)
|
57 |
+
|
58 |
+
def split_by_cls(self) -> Dict[str, 'DatasetConfig']:
|
59 |
+
res: Dict[str, DatasetConfig] = {}
|
60 |
+
datapath_config_dict = self.datapath_config.split_by_cls()
|
61 |
+
for cls in self.datapath_config.data_path:
|
62 |
+
res[cls] = deepcopy(self)
|
63 |
+
res[cls].datapath_config = datapath_config_dict[cls]
|
64 |
+
return res
|
65 |
+
|
66 |
+
class UniRigDatasetModule(pl.LightningDataModule):
|
67 |
+
def __init__(
|
68 |
+
self,
|
69 |
+
process_fn: Union[Callable[[List[ModelInput]], Dict]]=None,
|
70 |
+
predict_dataset_config: Union[Dict[str, DatasetConfig], None]=None,
|
71 |
+
predict_transform_config: Union[TransformConfig, None]=None,
|
72 |
+
tokenizer_config: Union[TokenizerConfig, None]=None,
|
73 |
+
debug: bool=False,
|
74 |
+
data_name: str='raw_data.npz',
|
75 |
+
datapath: Union[Datapath, None]=None,
|
76 |
+
cls: Union[str, None]=None,
|
77 |
+
):
|
78 |
+
super().__init__()
|
79 |
+
self.process_fn = process_fn
|
80 |
+
self.predict_dataset_config = predict_dataset_config
|
81 |
+
self.predict_transform_config = predict_transform_config
|
82 |
+
self.tokenizer_config = tokenizer_config
|
83 |
+
self.debug = debug
|
84 |
+
self.data_name = data_name
|
85 |
+
|
86 |
+
if debug:
|
87 |
+
print("\033[31mWARNING: debug mode, dataloader will be extremely slow !!!\033[0m")
|
88 |
+
|
89 |
+
if datapath is not None:
|
90 |
+
self.train_datapath = None
|
91 |
+
self.validate_datapath = None
|
92 |
+
self.predict_datapath = {
|
93 |
+
cls: deepcopy(datapath),
|
94 |
+
}
|
95 |
+
self.predict_dataset_config = {
|
96 |
+
cls: DatasetConfig(
|
97 |
+
shuffle=False,
|
98 |
+
batch_size=1,
|
99 |
+
num_workers=0,
|
100 |
+
datapath_config=deepcopy(datapath),
|
101 |
+
pin_memory=False,
|
102 |
+
persistent_workers=False,
|
103 |
+
)
|
104 |
+
}
|
105 |
+
else:
|
106 |
+
# build predict datapath
|
107 |
+
if self.predict_dataset_config is not None:
|
108 |
+
self.predict_datapath = {
|
109 |
+
cls: Datapath(self.predict_dataset_config[cls].datapath_config)
|
110 |
+
for cls in self.predict_dataset_config
|
111 |
+
}
|
112 |
+
else:
|
113 |
+
self.predict_datapath = None
|
114 |
+
|
115 |
+
# get tokenizer
|
116 |
+
if tokenizer_config is None:
|
117 |
+
self.tokenizer = None
|
118 |
+
else:
|
119 |
+
self.tokenizer = get_tokenizer(config=tokenizer_config)
|
120 |
+
|
121 |
+
def prepare_data(self):
|
122 |
+
pass
|
123 |
+
|
124 |
+
def setup(self, stage=None):
|
125 |
+
if self.predict_datapath is not None:
|
126 |
+
self._predict_ds = {}
|
127 |
+
for cls in self.predict_datapath:
|
128 |
+
self._predict_ds[cls] = UniRigDataset(
|
129 |
+
process_fn=self.process_fn,
|
130 |
+
data=self.predict_datapath[cls].get_data(),
|
131 |
+
name=f"predict-{cls}",
|
132 |
+
tokenizer=self.tokenizer,
|
133 |
+
transform_config=self.predict_transform_config,
|
134 |
+
debug=self.debug,
|
135 |
+
data_name=self.data_name,
|
136 |
+
)
|
137 |
+
|
138 |
+
def predict_dataloader(self):
|
139 |
+
if not hasattr(self, "_predict_ds"):
|
140 |
+
self.setup()
|
141 |
+
return self._create_dataloader(
|
142 |
+
dataset=self._predict_ds,
|
143 |
+
config=self.predict_dataset_config,
|
144 |
+
is_train=False,
|
145 |
+
drop_last=False,
|
146 |
+
)
|
147 |
+
|
148 |
+
def _create_dataloader(
|
149 |
+
self,
|
150 |
+
dataset: Union[Dataset, Dict[str, Dataset]],
|
151 |
+
config: DatasetConfig,
|
152 |
+
is_train: bool,
|
153 |
+
**kwargs,
|
154 |
+
) -> Union[DataLoader, Dict[str, DataLoader]]:
|
155 |
+
def create_single_dataloader(dataset, config: Union[DatasetConfig, Dict[str, DatasetConfig]], **kwargs):
|
156 |
+
return DataLoader(
|
157 |
+
dataset,
|
158 |
+
batch_size=config.batch_size,
|
159 |
+
shuffle=config.shuffle,
|
160 |
+
num_workers=config.num_workers,
|
161 |
+
pin_memory=config.pin_memory,
|
162 |
+
persistent_workers=config.persistent_workers,
|
163 |
+
collate_fn=dataset.collate_fn,
|
164 |
+
**kwargs,
|
165 |
+
)
|
166 |
+
if isinstance(dataset, Dict):
|
167 |
+
return {k: create_single_dataloader(v, config[k], **kwargs) for k, v in dataset.items()}
|
168 |
+
else:
|
169 |
+
return create_single_dataloader(dataset, config, **kwargs)
|
170 |
+
|
171 |
+
class UniRigDataset(Dataset):
|
172 |
+
def __init__(
|
173 |
+
self,
|
174 |
+
data: List[Tuple[str, str]], # (cls, part)
|
175 |
+
name: str,
|
176 |
+
process_fn: Union[Callable[[List[ModelInput]], Dict]]=None,
|
177 |
+
tokenizer: Union[TokenizerSpec, None]=None,
|
178 |
+
transform_config: Union[TransformConfig, None]=None,
|
179 |
+
debug: bool=False,
|
180 |
+
data_name: str='raw_data.npz',
|
181 |
+
) -> None:
|
182 |
+
super().__init__()
|
183 |
+
|
184 |
+
self.data = data
|
185 |
+
self.name = name
|
186 |
+
self.process_fn = process_fn
|
187 |
+
self.tokenizer = tokenizer
|
188 |
+
self.transform_config = transform_config
|
189 |
+
self.debug = debug
|
190 |
+
self.data_name = data_name
|
191 |
+
|
192 |
+
if not debug:
|
193 |
+
assert self.process_fn is not None, 'missing data processing function'
|
194 |
+
|
195 |
+
def __len__(self) -> int:
|
196 |
+
return len(self.data)
|
197 |
+
|
198 |
+
def __getitem__(self, idx) -> ModelInput:
|
199 |
+
cls, dir_path = self.data[idx]
|
200 |
+
raw_data = RawData.load(path=os.path.join(dir_path, self.data_name))
|
201 |
+
asset = Asset.from_raw_data(raw_data=raw_data, cls=cls, path=dir_path, data_name=self.data_name)
|
202 |
+
|
203 |
+
first_augments, second_augments = transform_asset(
|
204 |
+
asset=asset,
|
205 |
+
transform_config=self.transform_config,
|
206 |
+
)
|
207 |
+
if self.tokenizer is not None and asset.parents is not None:
|
208 |
+
tokens = self.tokenizer.tokenize(input=asset.get_tokenize_input())
|
209 |
+
else:
|
210 |
+
tokens = None
|
211 |
+
return ModelInput(
|
212 |
+
tokens=tokens,
|
213 |
+
pad=None if self.tokenizer is None else self.tokenizer.pad,
|
214 |
+
vertices=asset.sampled_vertices.astype(np.float32),
|
215 |
+
normals=asset.sampled_normals.astype(np.float32),
|
216 |
+
joints=None if asset.joints is None else asset.joints.astype(np.float32),
|
217 |
+
tails=None if asset.tails is None else asset.tails.astype(np.float32),
|
218 |
+
asset=asset,
|
219 |
+
augments=None,
|
220 |
+
)
|
221 |
+
|
222 |
+
def _collate_fn_debug(self, batch):
|
223 |
+
return batch
|
224 |
+
|
225 |
+
def _collate_fn(self, batch):
|
226 |
+
return data.dataloader.default_collate(self.process_fn(batch))
|
227 |
+
|
228 |
+
def collate_fn(self, batch):
|
229 |
+
if self.debug:
|
230 |
+
return self._collate_fn_debug(batch)
|
231 |
+
return self._collate_fn(batch)
|
UniRig/src/data/exporter.py
ADDED
@@ -0,0 +1,486 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from numpy import ndarray
|
3 |
+
from typing import List, Union, Tuple
|
4 |
+
from collections import defaultdict
|
5 |
+
import os
|
6 |
+
|
7 |
+
try:
|
8 |
+
import open3d as o3d
|
9 |
+
OPEN3D_EQUIPPED = True
|
10 |
+
except:
|
11 |
+
print("do not have open3d")
|
12 |
+
OPEN3D_EQUIPPED = False
|
13 |
+
|
14 |
+
class Exporter():
|
15 |
+
|
16 |
+
def _safe_make_dir(self, path):
|
17 |
+
if os.path.dirname(path) == '':
|
18 |
+
return
|
19 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
20 |
+
|
21 |
+
def _export_skeleton(self, joints: ndarray, parents: List[Union[int, None]], path: str):
|
22 |
+
format = path.split('.')[-1]
|
23 |
+
assert format in ['obj']
|
24 |
+
name = path.removesuffix('.obj')
|
25 |
+
path = name + ".obj"
|
26 |
+
self._safe_make_dir(path)
|
27 |
+
J = joints.shape[0]
|
28 |
+
with open(path, 'w') as file:
|
29 |
+
file.write("o spring_joint\n")
|
30 |
+
_joints = []
|
31 |
+
for id in range(J):
|
32 |
+
pid = parents[id]
|
33 |
+
if pid is None or pid == -1:
|
34 |
+
continue
|
35 |
+
bx, by, bz = joints[id]
|
36 |
+
ex, ey, ez = joints[pid]
|
37 |
+
_joints.extend([
|
38 |
+
f"v {bx} {bz} {-by}\n",
|
39 |
+
f"v {ex} {ez} {-ey}\n",
|
40 |
+
f"v {ex} {ez} {-ey + 0.00001}\n"
|
41 |
+
])
|
42 |
+
file.writelines(_joints)
|
43 |
+
|
44 |
+
_faces = [f"f {id*3+1} {id*3+2} {id*3+3}\n" for id in range(J)]
|
45 |
+
file.writelines(_faces)
|
46 |
+
|
47 |
+
def _export_bones(self, bones: ndarray, path: str):
|
48 |
+
format = path.split('.')[-1]
|
49 |
+
assert format in ['obj']
|
50 |
+
name = path.removesuffix('.obj')
|
51 |
+
path = name + ".obj"
|
52 |
+
self._safe_make_dir(path)
|
53 |
+
J = bones.shape[0]
|
54 |
+
with open(path, 'w') as file:
|
55 |
+
file.write("o bones\n")
|
56 |
+
_joints = []
|
57 |
+
for bone in bones:
|
58 |
+
bx, by, bz = bone[:3]
|
59 |
+
ex, ey, ez = bone[3:]
|
60 |
+
_joints.extend([
|
61 |
+
f"v {bx} {bz} {-by}\n",
|
62 |
+
f"v {ex} {ez} {-ey}\n",
|
63 |
+
f"v {ex} {ez} {-ey + 0.00001}\n"
|
64 |
+
])
|
65 |
+
file.writelines(_joints)
|
66 |
+
|
67 |
+
_faces = [f"f {id*3+1} {id*3+2} {id*3+3}\n" for id in range(J)]
|
68 |
+
file.writelines(_faces)
|
69 |
+
|
70 |
+
def _export_skeleton_sequence(self, joints: ndarray, parents: List[Union[int, None]], path: str):
|
71 |
+
format = path.split('.')[-1]
|
72 |
+
assert format in ['obj']
|
73 |
+
name = path.removesuffix('.obj')
|
74 |
+
path = name + ".obj"
|
75 |
+
self._safe_make_dir(path)
|
76 |
+
J = joints.shape[0]
|
77 |
+
for i in range(J):
|
78 |
+
file = open(name + f"_{i}.obj", 'w')
|
79 |
+
file.write("o spring_joint\n")
|
80 |
+
_joints = []
|
81 |
+
for id in range(i + 1):
|
82 |
+
pid = parents[id]
|
83 |
+
if pid is None:
|
84 |
+
continue
|
85 |
+
bx, by, bz = joints[id]
|
86 |
+
ex, ey, ez = joints[pid]
|
87 |
+
_joints.extend([
|
88 |
+
f"v {bx} {bz} {-by}\n",
|
89 |
+
f"v {ex} {ez} {-ey}\n",
|
90 |
+
f"v {ex} {ez} {-ey + 0.00001}\n"
|
91 |
+
])
|
92 |
+
file.writelines(_joints)
|
93 |
+
|
94 |
+
_faces = [f"f {id*3+1} {id*3+2} {id*3+3}\n" for id in range(J)]
|
95 |
+
file.writelines(_faces)
|
96 |
+
file.close()
|
97 |
+
|
98 |
+
def _export_mesh(self, vertices: ndarray, faces: ndarray, path: str):
|
99 |
+
format = path.split('.')[-1]
|
100 |
+
assert format in ['obj', 'ply']
|
101 |
+
if path.endswith('ply'):
|
102 |
+
if not OPEN3D_EQUIPPED:
|
103 |
+
raise RuntimeError("open3d is not available")
|
104 |
+
mesh = o3d.geometry.TriangleMesh()
|
105 |
+
mesh.vertices = o3d.utility.Vector3dVector(vertices)
|
106 |
+
mesh.triangles = o3d.utility.Vector3iVector(faces)
|
107 |
+
self._safe_make_dir(path)
|
108 |
+
o3d.io.write_triangle_mesh(path, mesh)
|
109 |
+
return
|
110 |
+
name = path.removesuffix('.obj')
|
111 |
+
path = name + ".obj"
|
112 |
+
self._safe_make_dir(path)
|
113 |
+
with open(path, 'w') as file:
|
114 |
+
file.write("o mesh\n")
|
115 |
+
_vertices = []
|
116 |
+
for co in vertices:
|
117 |
+
_vertices.append(f"v {co[0]} {co[2]} {-co[1]}\n")
|
118 |
+
file.writelines(_vertices)
|
119 |
+
_faces = []
|
120 |
+
for face in faces:
|
121 |
+
_faces.append(f"f {face[0]+1} {face[1]+1} {face[2]+1}\n")
|
122 |
+
file.writelines(_faces)
|
123 |
+
|
124 |
+
def _export_pc(self, vertices: ndarray, path: str, vertex_normals: Union[ndarray, None]=None, normal_size: float=0.01):
|
125 |
+
if path.endswith('.ply'):
|
126 |
+
if vertex_normals is not None:
|
127 |
+
print("normal result will not be displayed in .ply format")
|
128 |
+
name = path.removesuffix('.ply')
|
129 |
+
path = name + ".ply"
|
130 |
+
pc = o3d.geometry.PointCloud()
|
131 |
+
pc.points = o3d.utility.Vector3dVector(vertices)
|
132 |
+
# segment fault when numpy >= 2.0 !! use torch environment
|
133 |
+
self._safe_make_dir(path)
|
134 |
+
o3d.io.write_point_cloud(path, pc)
|
135 |
+
return
|
136 |
+
name = path.removesuffix('.obj')
|
137 |
+
path = name + ".obj"
|
138 |
+
self._safe_make_dir(path)
|
139 |
+
with open(path, 'w') as file:
|
140 |
+
file.write("o pc\n")
|
141 |
+
_vertex = []
|
142 |
+
for co in vertices:
|
143 |
+
_vertex.append(f"v {co[0]} {co[2]} {-co[1]}\n")
|
144 |
+
file.writelines(_vertex)
|
145 |
+
if vertex_normals is not None:
|
146 |
+
new_path = path.replace('.obj', '_normal.obj')
|
147 |
+
nfile = open(new_path, 'w')
|
148 |
+
nfile.write("o normal\n")
|
149 |
+
_normal = []
|
150 |
+
for i in range(vertices.shape[0]):
|
151 |
+
co = vertices[i]
|
152 |
+
x = vertex_normals[i, 0]
|
153 |
+
y = vertex_normals[i, 1]
|
154 |
+
z = vertex_normals[i, 2]
|
155 |
+
_normal.extend([
|
156 |
+
f"v {co[0]} {co[2]} {-co[1]}\n",
|
157 |
+
f"v {co[0]+0.0001} {co[2]} {-co[1]}\n",
|
158 |
+
f"v {co[0]+x*normal_size} {co[2]+z*normal_size} {-(co[1]+y*normal_size)}\n",
|
159 |
+
f"f {i*3+1} {i*3+2} {i*3+3}\n",
|
160 |
+
])
|
161 |
+
nfile.writelines(_normal)
|
162 |
+
|
163 |
+
def _make_armature(
|
164 |
+
self,
|
165 |
+
vertices: Union[ndarray, None],
|
166 |
+
joints: ndarray,
|
167 |
+
skin: Union[ndarray, None],
|
168 |
+
parents: List[Union[int, None]],
|
169 |
+
names: List[str],
|
170 |
+
faces: Union[ndarray, None]=None,
|
171 |
+
extrude_size: float=0.03,
|
172 |
+
group_per_vertex: int=-1,
|
173 |
+
add_root: bool=False,
|
174 |
+
do_not_normalize: bool=False,
|
175 |
+
use_extrude_bone: bool=True,
|
176 |
+
use_connect_unique_child: bool=True,
|
177 |
+
extrude_from_parent: bool=True,
|
178 |
+
tails: Union[ndarray, None]=None,
|
179 |
+
):
|
180 |
+
import bpy # type: ignore
|
181 |
+
from mathutils import Vector # type: ignore
|
182 |
+
|
183 |
+
# make collection
|
184 |
+
collection = bpy.data.collections.new('new_collection')
|
185 |
+
bpy.context.scene.collection.children.link(collection)
|
186 |
+
|
187 |
+
# make mesh
|
188 |
+
if vertices is not None:
|
189 |
+
mesh = bpy.data.meshes.new('mesh')
|
190 |
+
if faces is None:
|
191 |
+
faces = []
|
192 |
+
mesh.from_pydata(vertices, [], faces)
|
193 |
+
mesh.update()
|
194 |
+
|
195 |
+
# make object from mesh
|
196 |
+
object = bpy.data.objects.new('character', mesh)
|
197 |
+
|
198 |
+
# add object to scene collection
|
199 |
+
collection.objects.link(object)
|
200 |
+
|
201 |
+
# deselect mesh
|
202 |
+
bpy.ops.object.armature_add(enter_editmode=True)
|
203 |
+
armature = bpy.data.armatures.get('Armature')
|
204 |
+
edit_bones = armature.edit_bones
|
205 |
+
|
206 |
+
J = joints.shape[0]
|
207 |
+
if tails is None:
|
208 |
+
tails = joints.copy()
|
209 |
+
tails[:, 2] += extrude_size
|
210 |
+
connects = [False for _ in range(J)]
|
211 |
+
children = defaultdict(list)
|
212 |
+
for i in range(1, J):
|
213 |
+
children[parents[i]].append(i)
|
214 |
+
if tails is not None:
|
215 |
+
if use_extrude_bone:
|
216 |
+
for i in range(J):
|
217 |
+
if len(children[i]) != 1 and extrude_from_parent and i != 0:
|
218 |
+
pjoint = joints[parents[i]]
|
219 |
+
joint = joints[i]
|
220 |
+
d = joint - pjoint
|
221 |
+
if np.linalg.norm(d) < 0.000001:
|
222 |
+
d = np.array([0., 0., 1.]) # in case son.head == parent.head
|
223 |
+
else:
|
224 |
+
d = d / np.linalg.norm(d)
|
225 |
+
tails[i] = joint + d * extrude_size
|
226 |
+
if use_connect_unique_child:
|
227 |
+
for i in range(J):
|
228 |
+
if len(children[i]) == 1:
|
229 |
+
child = children[i][0]
|
230 |
+
tails[i] = joints[child]
|
231 |
+
if parents[i] is not None and len(children[parents[i]]) == 1:
|
232 |
+
connects[i] = True
|
233 |
+
|
234 |
+
if add_root:
|
235 |
+
bone_root = edit_bones.get('Bone')
|
236 |
+
bone_root.name = 'Root'
|
237 |
+
bone_root.tail = Vector((joints[0, 0], joints[0, 1], joints[0, 2]))
|
238 |
+
else:
|
239 |
+
bone_root = edit_bones.get('Bone')
|
240 |
+
bone_root.name = names[0]
|
241 |
+
bone_root.head = Vector((joints[0, 0], joints[0, 1], joints[0, 2]))
|
242 |
+
bone_root.tail = Vector((joints[0, 0], joints[0, 1], joints[0, 2] + extrude_size))
|
243 |
+
|
244 |
+
def extrude_bone(
|
245 |
+
edit_bones,
|
246 |
+
name: str,
|
247 |
+
parent_name: str,
|
248 |
+
head: Tuple[float, float, float],
|
249 |
+
tail: Tuple[float, float, float],
|
250 |
+
connect: bool
|
251 |
+
):
|
252 |
+
bone = edit_bones.new(name)
|
253 |
+
bone.head = Vector((head[0], head[1], head[2]))
|
254 |
+
bone.tail = Vector((tail[0], tail[1], tail[2]))
|
255 |
+
bone.name = name
|
256 |
+
parent_bone = edit_bones.get(parent_name)
|
257 |
+
bone.parent = parent_bone
|
258 |
+
bone.use_connect = connect
|
259 |
+
assert not np.isnan(head).any(), f"nan found in head of bone {name}"
|
260 |
+
assert not np.isnan(tail).any(), f"nan found in tail of bone {name}"
|
261 |
+
|
262 |
+
for i in range(J):
|
263 |
+
if add_root is False and i==0:
|
264 |
+
continue
|
265 |
+
edit_bones = armature.edit_bones
|
266 |
+
pname = 'Root' if parents[i] is None else names[parents[i]]
|
267 |
+
extrude_bone(edit_bones, names[i], pname, joints[i], tails[i], connects[i])
|
268 |
+
for i in range(J):
|
269 |
+
bone = edit_bones.get(names[i])
|
270 |
+
bone.head = Vector((joints[i, 0], joints[i, 1], joints[i, 2]))
|
271 |
+
bone.tail = Vector((tails[i, 0], tails[i, 1], tails[i, 2]))
|
272 |
+
|
273 |
+
if vertices is None or skin is None:
|
274 |
+
return
|
275 |
+
# must set to object mode to enable parent_set
|
276 |
+
bpy.ops.object.mode_set(mode='OBJECT')
|
277 |
+
objects = bpy.data.objects
|
278 |
+
for o in bpy.context.selected_objects:
|
279 |
+
o.select_set(False)
|
280 |
+
ob = objects['character']
|
281 |
+
arm = bpy.data.objects['Armature']
|
282 |
+
ob.select_set(True)
|
283 |
+
arm.select_set(True)
|
284 |
+
bpy.ops.object.parent_set(type='ARMATURE_NAME')
|
285 |
+
vis = []
|
286 |
+
for x in ob.vertex_groups:
|
287 |
+
vis.append(x.name)
|
288 |
+
#sparsify
|
289 |
+
argsorted = np.argsort(-skin, axis=1)
|
290 |
+
vertex_group_reweight = skin[np.arange(skin.shape[0])[..., None], argsorted]
|
291 |
+
if group_per_vertex == -1:
|
292 |
+
group_per_vertex = vertex_group_reweight.shape[-1]
|
293 |
+
if not do_not_normalize:
|
294 |
+
vertex_group_reweight = vertex_group_reweight / vertex_group_reweight[..., :group_per_vertex].sum(axis=1)[...,None]
|
295 |
+
|
296 |
+
for v, w in enumerate(skin):
|
297 |
+
for ii in range(group_per_vertex):
|
298 |
+
i = argsorted[v, ii]
|
299 |
+
if i >= J:
|
300 |
+
continue
|
301 |
+
n = names[i]
|
302 |
+
if n not in vis:
|
303 |
+
continue
|
304 |
+
ob.vertex_groups[n].add([v], vertex_group_reweight[v, ii], 'REPLACE')
|
305 |
+
|
306 |
+
def _clean_bpy(self):
|
307 |
+
import bpy # type: ignore
|
308 |
+
for c in bpy.data.actions:
|
309 |
+
bpy.data.actions.remove(c)
|
310 |
+
for c in bpy.data.armatures:
|
311 |
+
bpy.data.armatures.remove(c)
|
312 |
+
for c in bpy.data.cameras:
|
313 |
+
bpy.data.cameras.remove(c)
|
314 |
+
for c in bpy.data.collections:
|
315 |
+
bpy.data.collections.remove(c)
|
316 |
+
for c in bpy.data.images:
|
317 |
+
bpy.data.images.remove(c)
|
318 |
+
for c in bpy.data.materials:
|
319 |
+
bpy.data.materials.remove(c)
|
320 |
+
for c in bpy.data.meshes:
|
321 |
+
bpy.data.meshes.remove(c)
|
322 |
+
for c in bpy.data.objects:
|
323 |
+
bpy.data.objects.remove(c)
|
324 |
+
for c in bpy.data.textures:
|
325 |
+
bpy.data.textures.remove(c)
|
326 |
+
|
327 |
+
def _export_fbx(
|
328 |
+
self,
|
329 |
+
path: str,
|
330 |
+
vertices: Union[ndarray, None],
|
331 |
+
joints: ndarray,
|
332 |
+
skin: Union[ndarray, None],
|
333 |
+
parents: List[Union[int, None]],
|
334 |
+
names: List[str],
|
335 |
+
faces: Union[ndarray, None]=None,
|
336 |
+
extrude_size: float=0.03,
|
337 |
+
group_per_vertex: int=-1,
|
338 |
+
add_root: bool=False,
|
339 |
+
do_not_normalize: bool=False,
|
340 |
+
use_extrude_bone: bool=True,
|
341 |
+
use_connect_unique_child: bool=True,
|
342 |
+
extrude_from_parent: bool=True,
|
343 |
+
tails: Union[ndarray, None]=None,
|
344 |
+
):
|
345 |
+
'''
|
346 |
+
Requires bpy installed
|
347 |
+
'''
|
348 |
+
import bpy # type: ignore
|
349 |
+
self._safe_make_dir(path)
|
350 |
+
self._clean_bpy()
|
351 |
+
self._make_armature(
|
352 |
+
vertices=vertices,
|
353 |
+
joints=joints,
|
354 |
+
skin=skin,
|
355 |
+
parents=parents,
|
356 |
+
names=names,
|
357 |
+
faces=faces,
|
358 |
+
extrude_size=extrude_size,
|
359 |
+
group_per_vertex=group_per_vertex,
|
360 |
+
add_root=add_root,
|
361 |
+
do_not_normalize=do_not_normalize,
|
362 |
+
use_extrude_bone=use_extrude_bone,
|
363 |
+
use_connect_unique_child=use_connect_unique_child,
|
364 |
+
extrude_from_parent=extrude_from_parent,
|
365 |
+
tails=tails,
|
366 |
+
)
|
367 |
+
|
368 |
+
# always enable add_leaf_bones to keep leaf bones
|
369 |
+
bpy.ops.export_scene.fbx(filepath=path, check_existing=False, add_leaf_bones=False)
|
370 |
+
|
371 |
+
def _export_render(
|
372 |
+
self,
|
373 |
+
path: str,
|
374 |
+
vertices: Union[ndarray, None],
|
375 |
+
faces: Union[ndarray, None],
|
376 |
+
bones: Union[ndarray, None],
|
377 |
+
resolution: Tuple[float, float]=[256, 256],
|
378 |
+
):
|
379 |
+
import bpy # type: ignore
|
380 |
+
import bpy_extras # type: ignore
|
381 |
+
from mathutils import Vector # type: ignore
|
382 |
+
|
383 |
+
self._safe_make_dir(path)
|
384 |
+
# normalize into [-1, 1]^3
|
385 |
+
# copied from augment
|
386 |
+
assert (vertices is not None) or (bones is not None)
|
387 |
+
bounds = []
|
388 |
+
if vertices is not None:
|
389 |
+
bounds.append(vertices)
|
390 |
+
if bones is not None:
|
391 |
+
bounds.append(bones[:, :3])
|
392 |
+
bounds.append(bones[:, 3:])
|
393 |
+
bounds = np.concatenate(bounds, axis=0)
|
394 |
+
bound_min = bounds.min(axis=0)
|
395 |
+
bound_max = bounds.max(axis=0)
|
396 |
+
|
397 |
+
trans_vertex = np.eye(4)
|
398 |
+
|
399 |
+
trans_vertex = _trans_to_m(-(bound_max + bound_min)/2) @ trans_vertex
|
400 |
+
|
401 |
+
# scale into the cube [-1, 1]
|
402 |
+
scale = np.max((bound_max - bound_min) / 2)
|
403 |
+
trans_vertex = _scale_to_m(1. / scale) @ trans_vertex
|
404 |
+
|
405 |
+
def _apply(v: ndarray, trans: ndarray) -> ndarray:
|
406 |
+
return np.matmul(v, trans[:3, :3].transpose()) + trans[:3, 3]
|
407 |
+
|
408 |
+
if vertices is not None:
|
409 |
+
vertices = _apply(vertices, trans_vertex)
|
410 |
+
if bones is not None:
|
411 |
+
bones[:, :3] = _apply(bones[:, :3], trans_vertex)
|
412 |
+
bones[:, 3:] = _apply(bones[:, 3:], trans_vertex)
|
413 |
+
|
414 |
+
# bpy api calls
|
415 |
+
self._clean_bpy()
|
416 |
+
bpy.context.scene.render.engine = 'BLENDER_WORKBENCH'
|
417 |
+
bpy.context.scene.render.film_transparent = True
|
418 |
+
bpy.context.scene.display.shading.background_type = 'VIEWPORT'
|
419 |
+
|
420 |
+
collection = bpy.data.collections.new('new_collection')
|
421 |
+
bpy.context.scene.collection.children.link(collection)
|
422 |
+
|
423 |
+
if vertices is not None:
|
424 |
+
mesh_data = bpy.data.meshes.new(name="MeshData")
|
425 |
+
mesh_obj = bpy.data.objects.new(name="MeshObject", object_data=mesh_data)
|
426 |
+
collection.objects.link(mesh_obj)
|
427 |
+
|
428 |
+
mesh_data.from_pydata((vertices).tolist(), [], faces.tolist())
|
429 |
+
mesh_data.update()
|
430 |
+
|
431 |
+
def look_at(camera, point):
|
432 |
+
direction = point - camera.location
|
433 |
+
rot_quat = direction.to_track_quat('-Z', 'Y')
|
434 |
+
camera.rotation_euler = rot_quat.to_euler()
|
435 |
+
|
436 |
+
bpy.ops.object.camera_add(location=(4, -4, 2.5))
|
437 |
+
camera = bpy.context.object
|
438 |
+
camera.data.angle = np.radians(25.0)
|
439 |
+
look_at(camera, Vector((0, 0, -0.2)))
|
440 |
+
bpy.context.scene.camera = camera
|
441 |
+
|
442 |
+
bpy.context.scene.render.resolution_x = resolution[0]
|
443 |
+
bpy.context.scene.render.resolution_y = resolution[1]
|
444 |
+
bpy.context.scene.render.image_settings.file_format = 'PNG'
|
445 |
+
bpy.context.scene.render.filepath = path
|
446 |
+
|
447 |
+
bpy.ops.render.render(write_still=True)
|
448 |
+
# some AI generated code to draw bones over mesh
|
449 |
+
if bones is not None:
|
450 |
+
# TODO: do not save image after rendering
|
451 |
+
from PIL import Image, ImageDraw
|
452 |
+
img_pil = Image.open(path).convert("RGBA")
|
453 |
+
draw = ImageDraw.Draw(img_pil)
|
454 |
+
|
455 |
+
from bpy_extras.image_utils import load_image # type: ignore
|
456 |
+
bpy.context.scene.use_nodes = True
|
457 |
+
nodes = bpy.context.scene.node_tree.nodes
|
458 |
+
# nodes.clear()
|
459 |
+
|
460 |
+
img = load_image(path)
|
461 |
+
image_node = nodes.new(type='CompositorNodeImage')
|
462 |
+
image_node.image = img
|
463 |
+
|
464 |
+
for i, bone in enumerate(bones):
|
465 |
+
head, tail = bone[:3], bone[3:]
|
466 |
+
head_2d = bpy_extras.object_utils.world_to_camera_view(bpy.context.scene, camera, Vector(head))
|
467 |
+
tail_2d = bpy_extras.object_utils.world_to_camera_view(bpy.context.scene, camera, Vector(tail))
|
468 |
+
|
469 |
+
res_x, res_y = resolution
|
470 |
+
head_pix = (head_2d.x * res_x, (1 - head_2d.y) * res_y)
|
471 |
+
tail_pix = (tail_2d.x * res_x, (1 - tail_2d.y) * res_y)
|
472 |
+
draw.line([head_pix, tail_pix], fill=(255, 0, 0, 255), width=1)
|
473 |
+
img_pil.save(path)
|
474 |
+
|
475 |
+
def _trans_to_m(v: ndarray):
|
476 |
+
m = np.eye(4)
|
477 |
+
m[0:3, 3] = v
|
478 |
+
return m
|
479 |
+
|
480 |
+
def _scale_to_m(r: ndarray):
|
481 |
+
m = np.zeros((4, 4))
|
482 |
+
m[0, 0] = r
|
483 |
+
m[1, 1] = r
|
484 |
+
m[2, 2] = r
|
485 |
+
m[3, 3] = 1.
|
486 |
+
return m
|
UniRig/src/data/extract.py
ADDED
@@ -0,0 +1,523 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import bpy, os
|
2 |
+
from collections import defaultdict
|
3 |
+
from tqdm import tqdm
|
4 |
+
import numpy as np
|
5 |
+
from numpy import ndarray
|
6 |
+
from typing import Dict, Tuple, List, Optional, Union
|
7 |
+
import trimesh
|
8 |
+
import fast_simplification
|
9 |
+
from scipy.spatial import KDTree
|
10 |
+
|
11 |
+
import argparse
|
12 |
+
import yaml
|
13 |
+
from box import Box
|
14 |
+
import os
|
15 |
+
|
16 |
+
from .log import new_entry, add_error, add_warning, new_log, end_log
|
17 |
+
from .raw_data import RawData
|
18 |
+
|
19 |
+
def load(filepath: str):
|
20 |
+
old_objs = set(bpy.context.scene.objects)
|
21 |
+
|
22 |
+
if not os.path.exists(filepath):
|
23 |
+
raise ValueError(f'File {filepath} does not exist !')
|
24 |
+
|
25 |
+
try:
|
26 |
+
if filepath.endswith(".vrm"):
|
27 |
+
# enable vrm addon and load vrm model
|
28 |
+
bpy.ops.preferences.addon_enable(module='vrm')
|
29 |
+
|
30 |
+
bpy.ops.import_scene.vrm(
|
31 |
+
filepath=filepath,
|
32 |
+
use_addon_preferences=True,
|
33 |
+
extract_textures_into_folder=False,
|
34 |
+
make_new_texture_folder=False,
|
35 |
+
set_shading_type_to_material_on_import=False,
|
36 |
+
set_view_transform_to_standard_on_import=True,
|
37 |
+
set_armature_display_to_wire=True,
|
38 |
+
set_armature_display_to_show_in_front=True,
|
39 |
+
set_armature_bone_shape_to_default=True,
|
40 |
+
disable_bake=True, # customized option for better performance
|
41 |
+
)
|
42 |
+
elif filepath.endswith(".obj"):
|
43 |
+
bpy.ops.wm.obj_import(filepath=filepath)
|
44 |
+
elif filepath.endswith(".fbx") or filepath.endswith(".FBX"):
|
45 |
+
# end bone is removed using remove_dummy_bone
|
46 |
+
bpy.ops.import_scene.fbx(filepath=filepath, ignore_leaf_bones=False, use_image_search=False)
|
47 |
+
elif filepath.endswith(".glb") or filepath.endswith(".gltf"):
|
48 |
+
bpy.ops.import_scene.gltf(filepath=filepath, import_pack_images=False)
|
49 |
+
elif filepath.endswith(".dae"):
|
50 |
+
bpy.ops.wm.collada_import(filepath=filepath)
|
51 |
+
elif filepath.endswith(".blend"):
|
52 |
+
with bpy.data.libraries.load(filepath) as (data_from, data_to):
|
53 |
+
data_to.objects = data_from.objects
|
54 |
+
for obj in data_to.objects:
|
55 |
+
if obj is not None:
|
56 |
+
bpy.context.collection.objects.link(obj)
|
57 |
+
else:
|
58 |
+
raise ValueError(f"not suported type {filepath}")
|
59 |
+
except:
|
60 |
+
raise ValueError(f"failed to load {filepath}")
|
61 |
+
|
62 |
+
armature = [x for x in set(bpy.context.scene.objects)-old_objs if x.type=="ARMATURE"]
|
63 |
+
if len(armature)==0:
|
64 |
+
return None
|
65 |
+
if len(armature)>1:
|
66 |
+
raise ValueError(f"multiple armatures found")
|
67 |
+
armature = armature[0]
|
68 |
+
|
69 |
+
armature.select_set(True)
|
70 |
+
bpy.context.view_layer.objects.active = armature
|
71 |
+
bpy.ops.object.mode_set(mode='EDIT')
|
72 |
+
for bone in bpy.data.armatures[0].edit_bones:
|
73 |
+
bone.roll = 0. # change all roll to 0. to prevent weird behaviour
|
74 |
+
|
75 |
+
bpy.ops.object.mode_set(mode='OBJECT')
|
76 |
+
armature.select_set(False)
|
77 |
+
|
78 |
+
bpy.ops.object.select_all(action='DESELECT')
|
79 |
+
return armature
|
80 |
+
|
81 |
+
# remove all data in bpy
|
82 |
+
def clean_bpy():
|
83 |
+
# First try to purge orphan data
|
84 |
+
try:
|
85 |
+
bpy.ops.outliner.orphans_purge(do_local_ids=True, do_linked_ids=True, do_recursive=True)
|
86 |
+
except Exception as e:
|
87 |
+
print(f"Warning: Could not purge orphans: {e}")
|
88 |
+
|
89 |
+
# Then remove all data by type
|
90 |
+
data_types = [
|
91 |
+
bpy.data.actions,
|
92 |
+
bpy.data.armatures,
|
93 |
+
bpy.data.cameras,
|
94 |
+
bpy.data.collections,
|
95 |
+
bpy.data.curves,
|
96 |
+
bpy.data.images,
|
97 |
+
bpy.data.lights,
|
98 |
+
bpy.data.materials,
|
99 |
+
bpy.data.meshes,
|
100 |
+
bpy.data.objects,
|
101 |
+
bpy.data.textures,
|
102 |
+
bpy.data.worlds,
|
103 |
+
bpy.data.node_groups
|
104 |
+
]
|
105 |
+
|
106 |
+
for data_collection in data_types:
|
107 |
+
try:
|
108 |
+
for item in data_collection:
|
109 |
+
try:
|
110 |
+
data_collection.remove(item)
|
111 |
+
except Exception as e:
|
112 |
+
print(f"Warning: Could not remove {item.name} from {data_collection}: {e}")
|
113 |
+
except Exception as e:
|
114 |
+
print(f"Warning: Error processing {data_collection}: {e}")
|
115 |
+
|
116 |
+
# Force garbage collection to free memory
|
117 |
+
import gc
|
118 |
+
gc.collect()
|
119 |
+
|
120 |
+
def get_arranged_bones(armature):
|
121 |
+
matrix_world = armature.matrix_world
|
122 |
+
arranged_bones = []
|
123 |
+
root = armature.pose.bones[0]
|
124 |
+
while root.parent is not None:
|
125 |
+
root = root.parent
|
126 |
+
Q = [root]
|
127 |
+
rot = np.array(matrix_world)[:3, :3]
|
128 |
+
|
129 |
+
# dfs and sort
|
130 |
+
while len(Q) != 0:
|
131 |
+
b = Q.pop(0)
|
132 |
+
arranged_bones.append(b)
|
133 |
+
children = []
|
134 |
+
for cb in b.children:
|
135 |
+
head = rot @ np.array(b.head)
|
136 |
+
children.append((cb, head[0], head[1], head[2]))
|
137 |
+
children = sorted(children, key=lambda x: (x[3], x[1], x[2]))
|
138 |
+
_c = [x[0] for x in children]
|
139 |
+
Q = _c + Q
|
140 |
+
return arranged_bones
|
141 |
+
|
142 |
+
def process_mesh():
|
143 |
+
meshes = []
|
144 |
+
for v in bpy.data.objects:
|
145 |
+
if v.type == 'MESH':
|
146 |
+
meshes.append(v)
|
147 |
+
|
148 |
+
_dict_mesh = {}
|
149 |
+
for obj in meshes:
|
150 |
+
m = np.array(obj.matrix_world)
|
151 |
+
matrix_world_rot = m[:3, :3]
|
152 |
+
matrix_world_bias = m[:3, 3]
|
153 |
+
rot = matrix_world_rot
|
154 |
+
total_vertices = len(obj.data.vertices)
|
155 |
+
vertex = np.zeros((4, total_vertices))
|
156 |
+
vertex_normal = np.zeros((total_vertices, 3))
|
157 |
+
obj_verts = obj.data.vertices
|
158 |
+
faces = []
|
159 |
+
normals = []
|
160 |
+
|
161 |
+
for v in obj_verts:
|
162 |
+
vertex_normal[v.index] = rot @ np.array(v.normal) # be careful !
|
163 |
+
vv = rot @ v.co
|
164 |
+
vv = np.array(vv) + matrix_world_bias
|
165 |
+
vertex[0:3, v.index] = vv
|
166 |
+
vertex[3][v.index] = 1 # affine coordinate
|
167 |
+
|
168 |
+
for polygon in obj.data.polygons:
|
169 |
+
edges = polygon.edge_keys
|
170 |
+
nodes = []
|
171 |
+
adj = {}
|
172 |
+
for edge in edges:
|
173 |
+
if adj.get(edge[0]) is None:
|
174 |
+
adj[edge[0]] = []
|
175 |
+
adj[edge[0]].append(edge[1])
|
176 |
+
if adj.get(edge[1]) is None:
|
177 |
+
adj[edge[1]] = []
|
178 |
+
adj[edge[1]].append(edge[0])
|
179 |
+
nodes.append(edge[0])
|
180 |
+
nodes.append(edge[1])
|
181 |
+
normal = polygon.normal
|
182 |
+
nodes = list(set(sorted(nodes)))
|
183 |
+
first = nodes[0]
|
184 |
+
loop = []
|
185 |
+
now = first
|
186 |
+
vis = {}
|
187 |
+
while True:
|
188 |
+
loop.append(now)
|
189 |
+
vis[now] = True
|
190 |
+
if vis.get(adj[now][0]) is None:
|
191 |
+
now = adj[now][0]
|
192 |
+
elif vis.get(adj[now][1]) is None:
|
193 |
+
now = adj[now][1]
|
194 |
+
else:
|
195 |
+
break
|
196 |
+
for (second, third) in zip(loop[1:], loop[2:]):
|
197 |
+
faces.append((first + 1, second + 1, third + 1)) # the cursed +1
|
198 |
+
normals.append(rot @ normal) # and the cursed normal of BLENDER
|
199 |
+
|
200 |
+
correct_faces = []
|
201 |
+
for (i, face) in enumerate(faces):
|
202 |
+
normal = normals[i]
|
203 |
+
v0 = face[0] - 1
|
204 |
+
v1 = face[1] - 1
|
205 |
+
v2 = face[2] - 1
|
206 |
+
v = np.cross(
|
207 |
+
vertex[:3, v1] - vertex[:3, v0],
|
208 |
+
vertex[:3, v2] - vertex[:3, v0],
|
209 |
+
)
|
210 |
+
if (v*normal).sum() > 0:
|
211 |
+
correct_faces.append(face)
|
212 |
+
else:
|
213 |
+
correct_faces.append((face[0], face[2], face[1]))
|
214 |
+
if len(correct_faces) > 0:
|
215 |
+
_dict_mesh[obj.name] = {
|
216 |
+
'vertex': vertex,
|
217 |
+
'face': correct_faces,
|
218 |
+
}
|
219 |
+
|
220 |
+
vertex = np.concatenate([_dict_mesh[name]['vertex'] for name in _dict_mesh], axis=1)[:3, :].transpose()
|
221 |
+
|
222 |
+
total_faces = 0
|
223 |
+
now_bias = 0
|
224 |
+
for name in _dict_mesh:
|
225 |
+
total_faces += len(_dict_mesh[name]['face'])
|
226 |
+
faces = np.zeros((total_faces, 3), dtype=np.int64)
|
227 |
+
tot = 0
|
228 |
+
for name in _dict_mesh:
|
229 |
+
f = np.array(_dict_mesh[name]['face'], dtype=np.int64)
|
230 |
+
faces[tot:tot+f.shape[0]] = f + now_bias
|
231 |
+
now_bias += _dict_mesh[name]['vertex'].shape[1]
|
232 |
+
tot += f.shape[0]
|
233 |
+
|
234 |
+
return vertex, faces
|
235 |
+
|
236 |
+
def process_armature(
|
237 |
+
armature,
|
238 |
+
arranged_bones,
|
239 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
240 |
+
matrix_world = armature.matrix_world
|
241 |
+
index = {}
|
242 |
+
|
243 |
+
for (id, pbone) in enumerate(arranged_bones):
|
244 |
+
index[pbone.name] = id
|
245 |
+
|
246 |
+
root = armature.pose.bones[0]
|
247 |
+
while root.parent is not None:
|
248 |
+
root = root.parent
|
249 |
+
m = np.array(matrix_world.to_4x4())
|
250 |
+
scale_inv = np.linalg.inv(np.diag(matrix_world.to_scale()))
|
251 |
+
rot = m[:3, :3]
|
252 |
+
bias = m[:3, 3]
|
253 |
+
|
254 |
+
s = []
|
255 |
+
bpy.ops.object.editmode_toggle()
|
256 |
+
edit_bones = armature.data.edit_bones
|
257 |
+
|
258 |
+
J = len(arranged_bones)
|
259 |
+
joints = np.zeros((J, 3), dtype=np.float32)
|
260 |
+
tails = np.zeros((J, 3), dtype=np.float32)
|
261 |
+
parents = []
|
262 |
+
name_to_id = {}
|
263 |
+
names = []
|
264 |
+
matrix_local_stack = np.zeros((J, 4, 4), dtype=np.float32)
|
265 |
+
for (id, pbone) in enumerate(arranged_bones):
|
266 |
+
name = pbone.name
|
267 |
+
names.append(name)
|
268 |
+
matrix_local = np.array(pbone.bone.matrix_local)
|
269 |
+
use_inherit_rotation = pbone.bone.use_inherit_rotation
|
270 |
+
if use_inherit_rotation == False:
|
271 |
+
add_warning(f"use_inherit_rotation of bone {name} is False !")
|
272 |
+
head = rot @ matrix_local[0:3, 3] + bias
|
273 |
+
s.append(head)
|
274 |
+
edit_bone = edit_bones.get(name)
|
275 |
+
tail = rot @ np.array(edit_bone.tail) + bias
|
276 |
+
|
277 |
+
name_to_id[name] = id
|
278 |
+
joints[id] = head
|
279 |
+
tails[id] = tail
|
280 |
+
parents.append(None if pbone.parent not in arranged_bones else name_to_id[pbone.parent.name])
|
281 |
+
# remove scale part
|
282 |
+
matrix_local[:, 3:4] = m @ matrix_local[:, 3:4]
|
283 |
+
matrix_local[:3, :3] = scale_inv @ matrix_local[:3, :3]
|
284 |
+
matrix_local_stack[id] = matrix_local
|
285 |
+
bpy.ops.object.editmode_toggle()
|
286 |
+
|
287 |
+
return joints, tails, parents, names, matrix_local_stack
|
288 |
+
|
289 |
+
def save_raw_data(
|
290 |
+
path: str,
|
291 |
+
vertices: ndarray,
|
292 |
+
faces: ndarray,
|
293 |
+
joints: Union[ndarray, None],
|
294 |
+
tails: Union[ndarray, None],
|
295 |
+
parents: Union[List[Union[int, None]], None],
|
296 |
+
names: Union[List[str], None],
|
297 |
+
matrix_local: Union[ndarray, None],
|
298 |
+
target_count: int,
|
299 |
+
):
|
300 |
+
mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
|
301 |
+
vertices = np.array(mesh.vertices, dtype=np.float32)
|
302 |
+
faces = np.array(mesh.faces, dtype=np.int64)
|
303 |
+
if faces.shape[0] > target_count:
|
304 |
+
vertices, faces = fast_simplification.simplify(vertices, faces, target_count=target_count)
|
305 |
+
mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
|
306 |
+
|
307 |
+
new_vertices = np.array(mesh.vertices, dtype=np.float32)
|
308 |
+
new_vertex_normals = np.array(mesh.vertex_normals, dtype=np.float32)
|
309 |
+
new_faces = np.array(mesh.faces, dtype=np.int64)
|
310 |
+
new_face_normals = np.array(mesh.face_normals, dtype=np.float32)
|
311 |
+
if joints is not None:
|
312 |
+
new_joints = np.array(joints, dtype=np.float32)
|
313 |
+
else:
|
314 |
+
new_joints = None
|
315 |
+
raw_data = RawData(
|
316 |
+
vertices=new_vertices,
|
317 |
+
vertex_normals=new_vertex_normals,
|
318 |
+
faces=new_faces,
|
319 |
+
face_normals=new_face_normals,
|
320 |
+
joints=new_joints,
|
321 |
+
tails=tails,
|
322 |
+
skin=None,
|
323 |
+
no_skin=None,
|
324 |
+
parents=parents,
|
325 |
+
names=names,
|
326 |
+
matrix_local=matrix_local,
|
327 |
+
)
|
328 |
+
raw_data.check()
|
329 |
+
raw_data.save(path=path)
|
330 |
+
|
331 |
+
def extract_builtin(
|
332 |
+
output_folder: str,
|
333 |
+
target_count: int,
|
334 |
+
num_runs: int,
|
335 |
+
id: int,
|
336 |
+
time: str,
|
337 |
+
files: List[Union[str, str]],
|
338 |
+
):
|
339 |
+
log_path = "./logs"
|
340 |
+
log_path = os.path.join(log_path, time)
|
341 |
+
|
342 |
+
num_files = len(files)
|
343 |
+
gap = num_files // num_runs
|
344 |
+
start = gap * id
|
345 |
+
end = gap * (id + 1)
|
346 |
+
if id+1==num_runs:
|
347 |
+
end = num_files
|
348 |
+
|
349 |
+
files = sorted(files)
|
350 |
+
if end!=-1:
|
351 |
+
files = files[:end]
|
352 |
+
new_log(log_path, f"extract_builtin_{start}_{end}")
|
353 |
+
tot = 0
|
354 |
+
for file in tqdm(files[start:]):
|
355 |
+
input_file = file[0]
|
356 |
+
output_dir = file[1]
|
357 |
+
clean_bpy()
|
358 |
+
new_entry(input_file)
|
359 |
+
try:
|
360 |
+
print(f"Now processing {input_file}...")
|
361 |
+
|
362 |
+
armature = load(input_file)
|
363 |
+
|
364 |
+
print('save to:', output_dir)
|
365 |
+
os.makedirs(output_dir, exist_ok=True)
|
366 |
+
|
367 |
+
vertices, faces = process_mesh()
|
368 |
+
if armature is not None:
|
369 |
+
arranged_bones = get_arranged_bones(armature)
|
370 |
+
joints, tails, parents, names, matrix_local = process_armature(armature, arranged_bones)
|
371 |
+
|
372 |
+
else:
|
373 |
+
joints = None
|
374 |
+
tails = None
|
375 |
+
parents = None
|
376 |
+
names = None
|
377 |
+
matrix_local = None
|
378 |
+
|
379 |
+
save_file = os.path.join(output_dir, 'raw_data.npz')
|
380 |
+
save_raw_data(
|
381 |
+
path=save_file,
|
382 |
+
vertices=vertices,
|
383 |
+
faces=faces-1,
|
384 |
+
joints=joints,
|
385 |
+
tails=tails,
|
386 |
+
parents=parents,
|
387 |
+
names=names,
|
388 |
+
matrix_local=matrix_local,
|
389 |
+
target_count=target_count,
|
390 |
+
)
|
391 |
+
|
392 |
+
tot += 1
|
393 |
+
|
394 |
+
except ValueError as e:
|
395 |
+
add_error(str(e))
|
396 |
+
print(f"ValueError: {str(e)}")
|
397 |
+
except RuntimeError as e:
|
398 |
+
add_error(str(e))
|
399 |
+
print(f"RuntimeError: {str(e)}")
|
400 |
+
except TimeoutError as e:
|
401 |
+
add_error("time out")
|
402 |
+
print("TimeoutError: Processing timed out")
|
403 |
+
except Exception as e:
|
404 |
+
add_error(f"Unexpected error: {str(e)}")
|
405 |
+
print(f"Unexpected error: {str(e)}")
|
406 |
+
end_log()
|
407 |
+
print(f"{tot} models processed")
|
408 |
+
|
409 |
+
def str2bool(v):
|
410 |
+
if isinstance(v, bool):
|
411 |
+
return v
|
412 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
413 |
+
return True
|
414 |
+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
415 |
+
return False
|
416 |
+
else:
|
417 |
+
raise argparse.ArgumentTypeError('Boolean value expected.')
|
418 |
+
|
419 |
+
def nullable_string(val):
|
420 |
+
if not val:
|
421 |
+
return None
|
422 |
+
return val
|
423 |
+
|
424 |
+
def get_files(
|
425 |
+
data_name: str,
|
426 |
+
input_dataset_dir: str,
|
427 |
+
output_dataset_dir: str,
|
428 |
+
inputs: Union[str, None]=None,
|
429 |
+
require_suffix: List[str]=['obj','fbx','FBX','dae','glb','gltf','vrm'],
|
430 |
+
force_override: bool=False,
|
431 |
+
warning: bool=True,
|
432 |
+
) -> List[Tuple[str, str]]:
|
433 |
+
|
434 |
+
files = [] # (input_file, output_dir)
|
435 |
+
if inputs is not None: # specified input file(s)
|
436 |
+
vis = {}
|
437 |
+
inputs = inputs.split(',')
|
438 |
+
for file in inputs:
|
439 |
+
file_name = file.removeprefix("./")
|
440 |
+
# remove suffix
|
441 |
+
file_name = '.'.join(file_name.split('.')[:-1])
|
442 |
+
output_dir = os.path.join(output_dataset_dir, file_name)
|
443 |
+
raw_data_npz = os.path.join(output_dir, data_name)
|
444 |
+
if not force_override and os.path.exists(raw_data_npz):
|
445 |
+
continue
|
446 |
+
if warning and output_dir in vis:
|
447 |
+
print(f"\033[33mWARNING: duplicate output directory: {output_dir}, you need to rename prefix of files to avoid ambiguity\033[0m")
|
448 |
+
vis[output_dir] = True
|
449 |
+
files.append((file, output_dir))
|
450 |
+
else:
|
451 |
+
vis = {}
|
452 |
+
for root, dirs, f in os.walk(input_dataset_dir):
|
453 |
+
for file in f:
|
454 |
+
if file.split('.')[-1] in require_suffix:
|
455 |
+
file_name = file.removeprefix("./")
|
456 |
+
# remove suffix
|
457 |
+
file_name = '.'.join(file_name.split('.')[:-1])
|
458 |
+
|
459 |
+
output_dir = os.path.join(output_dataset_dir, os.path.relpath(root, input_dataset_dir), file_name)
|
460 |
+
raw_data_npz = os.path.join(output_dir, data_name)
|
461 |
+
|
462 |
+
# Check if all required files exist
|
463 |
+
if not force_override and os.path.exists(raw_data_npz):
|
464 |
+
continue
|
465 |
+
if warning and output_dir in vis:
|
466 |
+
print(f"\033[33mWARNING: duplicate output directory: {output_dir}, you need to rename prefix of files to avoid ambiguity\033[0m")
|
467 |
+
vis[output_dir] = True
|
468 |
+
files.append((os.path.join(root, file), output_dir))
|
469 |
+
|
470 |
+
return files
|
471 |
+
|
472 |
+
def parse():
|
473 |
+
parser = argparse.ArgumentParser()
|
474 |
+
parser.add_argument('--config', type=str, required=True)
|
475 |
+
parser.add_argument('--require_suffix', type=str, required=True)
|
476 |
+
parser.add_argument('--faces_target_count', type=int, required=True)
|
477 |
+
parser.add_argument('--num_runs', type=int, required=True)
|
478 |
+
parser.add_argument('--force_override', type=str2bool, required=True)
|
479 |
+
parser.add_argument('--id', type=int, required=True)
|
480 |
+
parser.add_argument('--time', type=str, required=True)
|
481 |
+
|
482 |
+
parser.add_argument('--input', type=nullable_string, required=False, default=None)
|
483 |
+
parser.add_argument('--input_dir', type=nullable_string, required=False, default=None)
|
484 |
+
parser.add_argument('--output_dir', type=nullable_string, required=False, default=None)
|
485 |
+
return parser.parse_args()
|
486 |
+
|
487 |
+
if __name__ == "__main__":
|
488 |
+
args = parse()
|
489 |
+
|
490 |
+
config = Box(yaml.safe_load(open(args.config, "r")))
|
491 |
+
|
492 |
+
num_runs = args.num_runs
|
493 |
+
id = args.id
|
494 |
+
timestamp = args.time
|
495 |
+
require_suffix = args.require_suffix.split(',')
|
496 |
+
force_override = args.force_override
|
497 |
+
target_count = args.faces_target_count
|
498 |
+
|
499 |
+
if args.input_dir:
|
500 |
+
config.input_dataset_dir = args.input_dir
|
501 |
+
if args.output_dir:
|
502 |
+
config.output_dataset_dir = args.output_dir
|
503 |
+
|
504 |
+
assert config.input_dataset_dir is not None or args.input is None, 'you cannot specify both input and input_dir'
|
505 |
+
|
506 |
+
files = get_files(
|
507 |
+
data_name='raw_data.npz',
|
508 |
+
inputs=args.input,
|
509 |
+
input_dataset_dir=config.input_dataset_dir,
|
510 |
+
output_dataset_dir=config.output_dataset_dir,
|
511 |
+
require_suffix=require_suffix,
|
512 |
+
force_override=force_override,
|
513 |
+
warning=True,
|
514 |
+
)
|
515 |
+
|
516 |
+
extract_builtin(
|
517 |
+
output_folder=config.output_dataset_dir,
|
518 |
+
target_count=target_count,
|
519 |
+
num_runs=num_runs,
|
520 |
+
id=id,
|
521 |
+
time=timestamp,
|
522 |
+
files=files,
|
523 |
+
)
|
UniRig/src/data/log.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
login_time = ''
|
5 |
+
log_filepath = ''
|
6 |
+
|
7 |
+
class Entry:
|
8 |
+
def __init__(self, entry_name):
|
9 |
+
self.entry = entry_name
|
10 |
+
self.error = None
|
11 |
+
self.warning = []
|
12 |
+
|
13 |
+
def have_error(self):
|
14 |
+
return self.error != None
|
15 |
+
|
16 |
+
def have_warning(self):
|
17 |
+
return len(self.warning) != 0
|
18 |
+
|
19 |
+
logs: List[Entry] = []
|
20 |
+
|
21 |
+
def new_log(path, log_name):
|
22 |
+
global login_time, log_filepath
|
23 |
+
log_filepath = os.path.join(path, f"{log_name}.txt")
|
24 |
+
os.makedirs(path, exist_ok=True)
|
25 |
+
with open(log_filepath, 'a') as file:
|
26 |
+
file.write(f"Log: {log_name}\n")
|
27 |
+
|
28 |
+
def end_log():
|
29 |
+
global log_filepath
|
30 |
+
with open(log_filepath, 'a') as file:
|
31 |
+
file.write(f"End of file\n")
|
32 |
+
|
33 |
+
def new_entry(entry_name):
|
34 |
+
global log_filepath
|
35 |
+
print(f"\033[32mNow processing {entry_name}...\033[0m")
|
36 |
+
logs.append(Entry(entry_name))
|
37 |
+
|
38 |
+
def add_error(error):
|
39 |
+
global log_filepath
|
40 |
+
print(f"\033[31mError found when processing {logs[-1].entry}: {error}\033[0m")
|
41 |
+
logs[-1].error = error
|
42 |
+
with open(log_filepath, 'a') as file:
|
43 |
+
file.write(f"Entry: {logs[-1].entry}, Error: {error}\n")
|
44 |
+
|
45 |
+
def add_warning(warning):
|
46 |
+
global log_filepath
|
47 |
+
print(f"\033[33mWarning found when processing {logs[-1].entry}: {warning}\033[0m")
|
48 |
+
logs[-1].warning.append(warning)
|
49 |
+
with open(log_filepath, 'a') as file:
|
50 |
+
file.write(f"Entry: {logs[-1].entry}, Warning: {warning}\n")
|
UniRig/src/data/order.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Tuple, Union
|
2 |
+
from collections import defaultdict
|
3 |
+
from dataclasses import dataclass
|
4 |
+
import yaml
|
5 |
+
from box import Box
|
6 |
+
|
7 |
+
from .spec import ConfigSpec
|
8 |
+
|
9 |
+
@dataclass
|
10 |
+
class OrderConfig(ConfigSpec):
|
11 |
+
'''
|
12 |
+
Config to handle bones re-ordering.
|
13 |
+
'''
|
14 |
+
|
15 |
+
# {skeleton_name: path}
|
16 |
+
skeleton_path: Dict[str, str]
|
17 |
+
|
18 |
+
# {cls: {part_name: [bone_name_1, bone_name_2, ...]}}
|
19 |
+
parts: Dict[str, Dict[str, List[str]]]
|
20 |
+
|
21 |
+
# {cls: parts of bones to be arranged in [part_name_1, part_name_2, ...]}
|
22 |
+
parts_order: Dict[str, List[str]]
|
23 |
+
|
24 |
+
@classmethod
|
25 |
+
def parse(cls, config):
|
26 |
+
cls.check_keys(config)
|
27 |
+
skeleton_path = config.skeleton_path
|
28 |
+
parts = {}
|
29 |
+
parts_order = {}
|
30 |
+
for (cls, path) in skeleton_path.items():
|
31 |
+
assert cls not in parts, 'cls conflicts'
|
32 |
+
d = Box(yaml.safe_load(open(path, 'r')))
|
33 |
+
parts[cls] = d.parts
|
34 |
+
parts_order[cls] = d.parts_order
|
35 |
+
return OrderConfig(
|
36 |
+
skeleton_path=skeleton_path,
|
37 |
+
parts=parts,
|
38 |
+
parts_order=parts_order,
|
39 |
+
)
|
40 |
+
|
41 |
+
class Order():
|
42 |
+
|
43 |
+
# {part_name: [bone_name_1, bone_name_2, ...]}
|
44 |
+
parts: Dict[str, Dict[str, List[str]]]
|
45 |
+
|
46 |
+
# parts of bones to be arranged in [part_name_1, part_name_2, ...]
|
47 |
+
parts_order: Dict[str, List[str]]
|
48 |
+
|
49 |
+
def __init__(self, config: OrderConfig):
|
50 |
+
self.parts = config.parts
|
51 |
+
self.parts_order = config.parts_order
|
52 |
+
|
53 |
+
def part_exists(self, cls: str, part: str, names: List[str]) -> bool:
|
54 |
+
'''
|
55 |
+
Check if part exists.
|
56 |
+
'''
|
57 |
+
if part not in self.parts[cls]:
|
58 |
+
return False
|
59 |
+
for name in self.parts[cls][part]:
|
60 |
+
if name not in names:
|
61 |
+
return False
|
62 |
+
return True
|
63 |
+
|
64 |
+
def make_names(self, cls: Union[str, None], parts: List[Union[str, None]], num_bones: int) -> List[str]:
|
65 |
+
'''
|
66 |
+
Get names for specified cls.
|
67 |
+
'''
|
68 |
+
names = []
|
69 |
+
for part in parts:
|
70 |
+
if part is None: # spring
|
71 |
+
continue
|
72 |
+
if cls in self.parts and part in self.parts[cls]:
|
73 |
+
names.extend(self.parts[cls][part])
|
74 |
+
assert len(names) <= num_bones, "number of bones in required skeleton is more than existing bones"
|
75 |
+
for i in range(len(names), num_bones):
|
76 |
+
names.append(f"bone_{i}")
|
77 |
+
return names
|
78 |
+
|
79 |
+
def arrange_names(self, cls: str, names: List[str], parents: List[Union[int, None]]) -> Tuple[List[str], Dict[int, Union[str]]]:
|
80 |
+
'''
|
81 |
+
Arrange names according to required parts order.
|
82 |
+
'''
|
83 |
+
if cls not in self.parts_order:
|
84 |
+
return names, {0: None} # add a spring token
|
85 |
+
vis = defaultdict(bool)
|
86 |
+
name_to_id = {name: i for (i, name) in enumerate(names)}
|
87 |
+
new_names = []
|
88 |
+
parts_bias = {}
|
89 |
+
for part in self.parts_order[cls]:
|
90 |
+
if self.part_exists(cls=cls, part=part, names=names):
|
91 |
+
for name in self.parts[cls][part]:
|
92 |
+
vis[name] = True
|
93 |
+
flag = False
|
94 |
+
for name in self.parts[cls][part]:
|
95 |
+
pid = parents[name_to_id[name]]
|
96 |
+
if pid is None:
|
97 |
+
continue
|
98 |
+
if not vis[names[pid]]:
|
99 |
+
flag = True
|
100 |
+
break
|
101 |
+
if flag: # incorrect parts order and should immediately add a spring token
|
102 |
+
break
|
103 |
+
parts_bias[len(new_names)] = part
|
104 |
+
new_names.extend(self.parts[cls][part])
|
105 |
+
parts_bias[len(new_names)] = None # add a spring token
|
106 |
+
for name in names:
|
107 |
+
if name not in new_names:
|
108 |
+
new_names.append(name)
|
109 |
+
return new_names, parts_bias
|
110 |
+
|
111 |
+
def get_order(config: OrderConfig) -> Order:
|
112 |
+
return Order(config=config)
|
UniRig/src/data/raw_data.py
ADDED
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
import numpy as np
|
3 |
+
from numpy import ndarray
|
4 |
+
|
5 |
+
import os
|
6 |
+
from typing import Union, List, Tuple
|
7 |
+
|
8 |
+
from .exporter import Exporter
|
9 |
+
|
10 |
+
from ..tokenizer.spec import DetokenzeOutput
|
11 |
+
from .order import Order
|
12 |
+
|
13 |
+
@dataclass(frozen=True)
|
14 |
+
class RawData(Exporter):
|
15 |
+
'''
|
16 |
+
Dataclass to handle data from processed model files.
|
17 |
+
'''
|
18 |
+
|
19 |
+
# vertices of the mesh, shape (N, 3), float32
|
20 |
+
vertices: Union[ndarray, None]
|
21 |
+
|
22 |
+
# normals of vertices, shape (N, 3), float32
|
23 |
+
vertex_normals: Union[ndarray, None]
|
24 |
+
|
25 |
+
# faces of mesh, shape (F, 3), face id starts from 0 to F-1, int64
|
26 |
+
faces: Union[ndarray, None]
|
27 |
+
|
28 |
+
# face normal of mesh, shape (F, 3), float32
|
29 |
+
face_normals: Union[ndarray, None]
|
30 |
+
|
31 |
+
# joints of bones, shape (J, 3), float32
|
32 |
+
joints: Union[ndarray, None]
|
33 |
+
|
34 |
+
# tails of joints, shape (J, 3), float32
|
35 |
+
tails: Union[ndarray, None]
|
36 |
+
|
37 |
+
# skinning of joints, shape (N, J), float32
|
38 |
+
skin: Union[ndarray, None]
|
39 |
+
|
40 |
+
# whether the joint has skin, bool
|
41 |
+
no_skin: Union[ndarray, None]
|
42 |
+
|
43 |
+
# parents of joints, None represents no parent(a root joint)
|
44 |
+
# make sure parent[k] < k
|
45 |
+
parents: Union[List[Union[int, None]], None]
|
46 |
+
|
47 |
+
# names of joints
|
48 |
+
names: Union[List[str], None]
|
49 |
+
|
50 |
+
# local coordinate
|
51 |
+
matrix_local: Union[ndarray, None]
|
52 |
+
|
53 |
+
# path to data
|
54 |
+
path: Union[str, None]=None
|
55 |
+
|
56 |
+
# data cls
|
57 |
+
cls: Union[str, None]=None
|
58 |
+
|
59 |
+
@staticmethod
|
60 |
+
def load(path: str) -> 'RawData':
|
61 |
+
data = np.load(path, allow_pickle=True)
|
62 |
+
d = {name: data[name][()] for name in data}
|
63 |
+
d['path'] = path
|
64 |
+
return RawData(**d)
|
65 |
+
|
66 |
+
def save(self, path: str):
|
67 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
68 |
+
np.savez(file=path, **self.__dict__)
|
69 |
+
|
70 |
+
@property
|
71 |
+
def N(self):
|
72 |
+
'''
|
73 |
+
number of vertices
|
74 |
+
'''
|
75 |
+
return self.vertices.shape[0]
|
76 |
+
|
77 |
+
@property
|
78 |
+
def F(self):
|
79 |
+
'''
|
80 |
+
number of faces
|
81 |
+
'''
|
82 |
+
return self.faces.shape[0]
|
83 |
+
|
84 |
+
@property
|
85 |
+
def J(self):
|
86 |
+
'''
|
87 |
+
number of joints
|
88 |
+
'''
|
89 |
+
return self.joints.shape[0]
|
90 |
+
|
91 |
+
def check(self):
|
92 |
+
if self.names is not None and self.joints is not None:
|
93 |
+
assert len(self.names) == self.J
|
94 |
+
if self.names is not None and self.parents is not None:
|
95 |
+
assert len(self.names) == len(self.parents)
|
96 |
+
if self.parents is not None:
|
97 |
+
for (i, pid) in enumerate(self.parents):
|
98 |
+
if i==0:
|
99 |
+
assert pid is None
|
100 |
+
else:
|
101 |
+
assert pid is not None
|
102 |
+
assert pid < i
|
103 |
+
|
104 |
+
def export_pc(self, path: str, with_normal: bool=True, normal_size=0.01):
|
105 |
+
'''
|
106 |
+
export point cloud
|
107 |
+
'''
|
108 |
+
if with_normal:
|
109 |
+
self._export_pc(vertices=self.vertices, path=path, vertex_normals=self.vertex_normals, normal_size=normal_size)
|
110 |
+
else:
|
111 |
+
self._export_pc(vertices=self.vertices, path=path, vertex_normals=None, normal_size=normal_size)
|
112 |
+
|
113 |
+
def export_mesh(self, path: str):
|
114 |
+
'''
|
115 |
+
export mesh
|
116 |
+
'''
|
117 |
+
self._export_mesh(vertices=self.vertices, faces=self.faces, path=path)
|
118 |
+
|
119 |
+
def export_skeleton(self, path: str):
|
120 |
+
'''
|
121 |
+
export spring
|
122 |
+
'''
|
123 |
+
self._export_skeleton(joints=self.joints, parents=self.parents, path=path)
|
124 |
+
|
125 |
+
def export_skeleton_sequence(self, path: str):
|
126 |
+
'''
|
127 |
+
export spring
|
128 |
+
'''
|
129 |
+
self._export_skeleton_sequence(joints=self.joints, parents=self.parents, path=path)
|
130 |
+
|
131 |
+
def export_fbx(
|
132 |
+
self,
|
133 |
+
path: str,
|
134 |
+
extrude_size: float=0.03,
|
135 |
+
group_per_vertex: int=-1,
|
136 |
+
add_root: bool=False,
|
137 |
+
do_not_normalize: bool=False,
|
138 |
+
use_extrude_bone: bool=True,
|
139 |
+
use_connect_unique_child: bool=True,
|
140 |
+
extrude_from_parent: bool=True,
|
141 |
+
use_tail: bool=False,
|
142 |
+
custom_vertex_group: Union[ndarray, None]=None,
|
143 |
+
):
|
144 |
+
'''
|
145 |
+
export the whole model with skining
|
146 |
+
'''
|
147 |
+
self._export_fbx(
|
148 |
+
path=path,
|
149 |
+
vertices=self.vertices,
|
150 |
+
joints=self.joints,
|
151 |
+
skin=self.skin if custom_vertex_group is None else custom_vertex_group,
|
152 |
+
parents=self.parents,
|
153 |
+
names=self.names,
|
154 |
+
faces=self.faces,
|
155 |
+
extrude_size=extrude_size,
|
156 |
+
group_per_vertex=group_per_vertex,
|
157 |
+
add_root=add_root,
|
158 |
+
do_not_normalize=do_not_normalize,
|
159 |
+
use_extrude_bone=use_extrude_bone,
|
160 |
+
use_connect_unique_child=use_connect_unique_child,
|
161 |
+
extrude_from_parent=extrude_from_parent,
|
162 |
+
tails=self.tails if use_tail else None,
|
163 |
+
)
|
164 |
+
|
165 |
+
def export_render(self, path: str, resolution: Tuple[int, int]=[256, 256]):
|
166 |
+
self._export_render(
|
167 |
+
path=path,
|
168 |
+
vertices=self.vertices,
|
169 |
+
faces=self.faces,
|
170 |
+
bones=np.concatenate([self.joints, self.tails], axis=-1),
|
171 |
+
resolution=resolution,
|
172 |
+
)
|
173 |
+
|
174 |
+
@dataclass(frozen=True)
|
175 |
+
class RawSkeleton(Exporter):
|
176 |
+
'''
|
177 |
+
Dataclass to handle skeleton from AR.
|
178 |
+
'''
|
179 |
+
# joints of bones, shape (J, 3), float32
|
180 |
+
joints: Union[ndarray, None]
|
181 |
+
|
182 |
+
# tails of joints, shape (J, 3), float32
|
183 |
+
tails: Union[ndarray, None]
|
184 |
+
|
185 |
+
# whether the joint has skin, bool
|
186 |
+
no_skin: Union[ndarray, None]
|
187 |
+
|
188 |
+
# parents of joints, None represents no parent(a root joint)
|
189 |
+
# make sure parent[k] < k
|
190 |
+
parents: Union[List[Union[int, None]], None]
|
191 |
+
|
192 |
+
# names of joints
|
193 |
+
names: Union[List[str], None]
|
194 |
+
|
195 |
+
@staticmethod
|
196 |
+
def load(path: str) -> 'RawSkeleton':
|
197 |
+
data = np.load(path, allow_pickle=True)
|
198 |
+
return RawSkeleton(**{name: data[name][()] for name in data})
|
199 |
+
|
200 |
+
def save(self, path: str):
|
201 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
202 |
+
np.savez(file=path, **self.__dict__)
|
203 |
+
|
204 |
+
@staticmethod
|
205 |
+
def from_detokenize_output(res: DetokenzeOutput, order: Union[Order, None]) -> 'RawSkeleton':
|
206 |
+
J = len(res.bones)
|
207 |
+
names = order.make_names(cls=res.cls, parts=res.parts, num_bones=J)
|
208 |
+
joints = res.joints
|
209 |
+
p_joints = res.p_joints
|
210 |
+
parents = []
|
211 |
+
for (i, joint) in enumerate(joints):
|
212 |
+
if i == 0:
|
213 |
+
parents.append(None)
|
214 |
+
continue
|
215 |
+
p_joint = p_joints[i]
|
216 |
+
dis = 999999
|
217 |
+
pid = None
|
218 |
+
for j in reversed(range(i)):
|
219 |
+
n_dis = ((joints[j] - p_joint)**2).sum()
|
220 |
+
if n_dis < dis:
|
221 |
+
pid = j
|
222 |
+
dis = n_dis
|
223 |
+
parents.append(pid)
|
224 |
+
return RawSkeleton(
|
225 |
+
joints=joints,
|
226 |
+
tails=res.tails,
|
227 |
+
no_skin=res.no_skin,
|
228 |
+
parents=parents,
|
229 |
+
names=names,
|
230 |
+
)
|
231 |
+
|
232 |
+
def export_skeleton(self, path: str):
|
233 |
+
'''
|
234 |
+
export spring
|
235 |
+
'''
|
236 |
+
self._export_skeleton(joints=self.joints, parents=self.parents, path=path)
|
237 |
+
|
238 |
+
def export_skeleton_sequence(self, path: str):
|
239 |
+
'''
|
240 |
+
export spring
|
241 |
+
'''
|
242 |
+
self._export_skeleton_sequence(joints=self.joints, parents=self.parents, path=path)
|
243 |
+
|
244 |
+
def export_fbx(
|
245 |
+
self,
|
246 |
+
path: str,
|
247 |
+
extrude_size: float=0.03,
|
248 |
+
group_per_vertex: int=-1,
|
249 |
+
add_root: bool=False,
|
250 |
+
do_not_normalize: bool=False,
|
251 |
+
use_extrude_bone: bool=True,
|
252 |
+
use_connect_unique_child: bool=True,
|
253 |
+
extrude_from_parent: bool=True,
|
254 |
+
use_tail: bool=False,
|
255 |
+
):
|
256 |
+
'''
|
257 |
+
export the whole model with skining
|
258 |
+
'''
|
259 |
+
self._export_fbx(
|
260 |
+
path=path,
|
261 |
+
vertices=None,
|
262 |
+
joints=self.joints,
|
263 |
+
skin=None,
|
264 |
+
parents=self.parents,
|
265 |
+
names=self.names,
|
266 |
+
faces=None,
|
267 |
+
extrude_size=extrude_size,
|
268 |
+
group_per_vertex=group_per_vertex,
|
269 |
+
add_root=add_root,
|
270 |
+
do_not_normalize=do_not_normalize,
|
271 |
+
use_extrude_bone=use_extrude_bone,
|
272 |
+
use_connect_unique_child=use_connect_unique_child,
|
273 |
+
extrude_from_parent=extrude_from_parent,
|
274 |
+
tails=self.tails if use_tail else None,
|
275 |
+
)
|
276 |
+
|
277 |
+
def export_render(self, path: str, resolution: Tuple[int, int]=[256, 256]):
|
278 |
+
self._export_render(
|
279 |
+
path=path,
|
280 |
+
vertices=None,
|
281 |
+
faces=None,
|
282 |
+
bones=np.concatenate([self.joints, self.tails], axis=-1),
|
283 |
+
resolution=resolution,
|
284 |
+
)
|
285 |
+
|
286 |
+
@dataclass
|
287 |
+
class RawSkin(Exporter):
|
288 |
+
'''
|
289 |
+
Dataclass to handle skeleton from AR.
|
290 |
+
'''
|
291 |
+
# skin, shape (J, N)
|
292 |
+
skin: ndarray
|
293 |
+
|
294 |
+
# always sampled, shape (N, 3)
|
295 |
+
vertices: Union[ndarray, None]=None
|
296 |
+
|
297 |
+
# for future use, shape (J, 3)
|
298 |
+
joints: Union[ndarray, None]=None
|
299 |
+
|
300 |
+
@staticmethod
|
301 |
+
def load(path: str) -> 'RawSkin':
|
302 |
+
data = np.load(path, allow_pickle=True)
|
303 |
+
return RawSkin(**{name: data[name][()] for name in data})
|
304 |
+
|
305 |
+
def save(self, path: str):
|
306 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
307 |
+
np.savez(file=path, **self.__dict__)
|
UniRig/src/data/sampler.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
from heapq import heappush, heappop, heapify
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from abc import ABC, abstractmethod
|
5 |
+
import numpy as np
|
6 |
+
from numpy import ndarray
|
7 |
+
|
8 |
+
from typing import Dict, Tuple
|
9 |
+
|
10 |
+
from .asset import Asset
|
11 |
+
from .spec import ConfigSpec
|
12 |
+
|
13 |
+
@dataclass
|
14 |
+
class SamplerConfig(ConfigSpec):
|
15 |
+
'''
|
16 |
+
Config to handle bones re-ordering.
|
17 |
+
'''
|
18 |
+
# which sampler to use
|
19 |
+
method: str
|
20 |
+
|
21 |
+
# how many samples in total
|
22 |
+
num_samples: int
|
23 |
+
|
24 |
+
# how many vertex samples
|
25 |
+
vertex_samples: int
|
26 |
+
|
27 |
+
# kwargs
|
28 |
+
kwargs: Dict[str, Dict]
|
29 |
+
|
30 |
+
@classmethod
|
31 |
+
def parse(cls, config) -> 'SamplerConfig':
|
32 |
+
cls.check_keys(config)
|
33 |
+
return SamplerConfig(
|
34 |
+
method=config.method,
|
35 |
+
num_samples=config.get('num_samples', 0),
|
36 |
+
vertex_samples=config.get('vertex_samples', 0),
|
37 |
+
kwargs=config.get('kwargs', {}),
|
38 |
+
)
|
39 |
+
|
40 |
+
@dataclass
|
41 |
+
class SamplerResult():
|
42 |
+
# sampled vertices
|
43 |
+
vertices: ndarray
|
44 |
+
|
45 |
+
# sampled normals
|
46 |
+
normals: ndarray
|
47 |
+
|
48 |
+
# sampled vertex groups
|
49 |
+
vertex_groups: Dict[str, ndarray]
|
50 |
+
|
51 |
+
class Sampler(ABC):
|
52 |
+
'''
|
53 |
+
Abstract class for samplers.
|
54 |
+
'''
|
55 |
+
|
56 |
+
def _sample_barycentric(
|
57 |
+
self,
|
58 |
+
vertex_group: ndarray,
|
59 |
+
faces: ndarray,
|
60 |
+
face_index: ndarray,
|
61 |
+
random_lengths: ndarray,
|
62 |
+
):
|
63 |
+
v_origins = vertex_group[faces[face_index, 0]]
|
64 |
+
v_vectors = vertex_group[faces[face_index, 1:]]
|
65 |
+
v_vectors -= v_origins[:, np.newaxis, :]
|
66 |
+
|
67 |
+
sample_vector = (v_vectors * random_lengths).sum(axis=1)
|
68 |
+
v_samples = sample_vector + v_origins
|
69 |
+
return v_samples
|
70 |
+
|
71 |
+
@abstractmethod
|
72 |
+
def __init__(self, config: SamplerConfig):
|
73 |
+
pass
|
74 |
+
|
75 |
+
@abstractmethod
|
76 |
+
def sample(
|
77 |
+
self,
|
78 |
+
asset: Asset,
|
79 |
+
) -> SamplerResult:
|
80 |
+
'''
|
81 |
+
Return sampled vertices, sampled normals and vertex groups.
|
82 |
+
'''
|
83 |
+
pass
|
84 |
+
|
85 |
+
class SamplerOrigin(Sampler):
|
86 |
+
def __init__(self, config: SamplerConfig):
|
87 |
+
super().__init__(config)
|
88 |
+
self.num_samples = config.num_samples
|
89 |
+
self.vertex_samples = config.vertex_samples
|
90 |
+
|
91 |
+
def sample(
|
92 |
+
self,
|
93 |
+
asset: Asset,
|
94 |
+
) -> SamplerResult:
|
95 |
+
perm = np.random.permutation(asset.vertices.shape[0])
|
96 |
+
if asset.vertices.shape[0] < self.num_samples:
|
97 |
+
m = self.num_samples - asset.vertices.shape[0]
|
98 |
+
perm = np.concatenate([perm, np.random.randint(0, asset.vertices.shape[0], (m,))])
|
99 |
+
perm = perm[:self.num_samples]
|
100 |
+
n_v = asset.vertices[perm]
|
101 |
+
n_n = asset.vertex_normals[perm]
|
102 |
+
n_vg = {name: v[perm] for name, v in asset.vertex_groups.items()}
|
103 |
+
return SamplerResult(
|
104 |
+
vertices=n_v,
|
105 |
+
normals=n_n,
|
106 |
+
vertex_groups=n_vg,
|
107 |
+
)
|
108 |
+
|
109 |
+
class SamplerMix(Sampler):
|
110 |
+
def __init__(self, config: SamplerConfig):
|
111 |
+
super().__init__(config)
|
112 |
+
self.num_samples = config.num_samples
|
113 |
+
self.vertex_samples = config.vertex_samples
|
114 |
+
assert self.num_samples >= self.vertex_samples, 'num_samples should >= vertex_samples'
|
115 |
+
|
116 |
+
@property
|
117 |
+
def mesh_preserve(self):
|
118 |
+
return self.num_samples==-1
|
119 |
+
|
120 |
+
def sample(
|
121 |
+
self,
|
122 |
+
asset: Asset,
|
123 |
+
) -> SamplerResult:
|
124 |
+
# 1. sample vertices
|
125 |
+
num_samples = self.num_samples
|
126 |
+
perm = np.random.permutation(asset.vertices.shape[0])
|
127 |
+
vertex_samples = min(self.vertex_samples, asset.vertices.shape[0])
|
128 |
+
num_samples -= vertex_samples
|
129 |
+
perm = perm[:vertex_samples]
|
130 |
+
n_vertex = asset.vertices[perm]
|
131 |
+
n_normal = asset.vertex_normals[perm]
|
132 |
+
n_v = {name: v[perm] for name, v in asset.vertex_groups.items()}
|
133 |
+
|
134 |
+
# 2. sample surface
|
135 |
+
perm = np.random.permutation(num_samples)
|
136 |
+
vertex_samples, face_index, random_lengths = sample_surface(
|
137 |
+
num_samples=num_samples,
|
138 |
+
vertices=asset.vertices,
|
139 |
+
faces=asset.faces,
|
140 |
+
return_weight=True,
|
141 |
+
)
|
142 |
+
vertex_samples = np.concatenate([n_vertex, vertex_samples], axis=0)
|
143 |
+
normal_samples = np.concatenate([n_normal, asset.face_normals[face_index]], axis=0)
|
144 |
+
vertex_group_samples = {}
|
145 |
+
for n, v in asset.vertex_groups.items():
|
146 |
+
g = self._sample_barycentric(
|
147 |
+
vertex_group=v,
|
148 |
+
faces=asset.faces,
|
149 |
+
face_index=face_index,
|
150 |
+
random_lengths=random_lengths,
|
151 |
+
)
|
152 |
+
vertex_group_samples[n] = np.concatenate([n_v[n], g], axis=0)
|
153 |
+
return SamplerResult(
|
154 |
+
vertices=vertex_samples,
|
155 |
+
normals=normal_samples,
|
156 |
+
vertex_groups=vertex_group_samples,
|
157 |
+
)
|
158 |
+
|
159 |
+
def sample_surface(
|
160 |
+
num_samples: int,
|
161 |
+
vertices: ndarray,
|
162 |
+
faces: ndarray,
|
163 |
+
return_weight: bool=False,
|
164 |
+
):
|
165 |
+
'''
|
166 |
+
Randomly pick samples according to face area.
|
167 |
+
|
168 |
+
See sample_surface: https://github.com/mikedh/trimesh/blob/main/trimesh/sample.py
|
169 |
+
'''
|
170 |
+
# get face area
|
171 |
+
offset_0 = vertices[faces[:, 1]] - vertices[faces[:, 0]]
|
172 |
+
offset_1 = vertices[faces[:, 2]] - vertices[faces[:, 0]]
|
173 |
+
face_weight = np.cross(offset_0, offset_1, axis=-1)
|
174 |
+
face_weight = (face_weight * face_weight).sum(axis=1)
|
175 |
+
|
176 |
+
weight_cum = np.cumsum(face_weight, axis=0)
|
177 |
+
face_pick = np.random.rand(num_samples) * weight_cum[-1]
|
178 |
+
face_index = np.searchsorted(weight_cum, face_pick)
|
179 |
+
|
180 |
+
# pull triangles into the form of an origin + 2 vectors
|
181 |
+
tri_origins = vertices[faces[:, 0]]
|
182 |
+
tri_vectors = vertices[faces[:, 1:]]
|
183 |
+
tri_vectors -= np.tile(tri_origins, (1, 2)).reshape((-1, 2, 3))
|
184 |
+
|
185 |
+
# pull the vectors for the faces we are going to sample from
|
186 |
+
tri_origins = tri_origins[face_index]
|
187 |
+
tri_vectors = tri_vectors[face_index]
|
188 |
+
|
189 |
+
# randomly generate two 0-1 scalar components to multiply edge vectors b
|
190 |
+
random_lengths = np.random.rand(len(tri_vectors), 2, 1)
|
191 |
+
|
192 |
+
random_test = random_lengths.sum(axis=1).reshape(-1) > 1.0
|
193 |
+
random_lengths[random_test] -= 1.0
|
194 |
+
random_lengths = np.abs(random_lengths)
|
195 |
+
|
196 |
+
sample_vector = (tri_vectors * random_lengths).sum(axis=1)
|
197 |
+
vertex_samples = sample_vector + tri_origins
|
198 |
+
if not return_weight:
|
199 |
+
return vertex_samples
|
200 |
+
return vertex_samples, face_index, random_lengths
|
201 |
+
|
202 |
+
def get_sampler(config: SamplerConfig) -> Sampler:
|
203 |
+
method = config.method
|
204 |
+
if method=='origin':
|
205 |
+
sampler = SamplerOrigin(config)
|
206 |
+
elif method=='mix':
|
207 |
+
sampler = SamplerMix(config)
|
208 |
+
else:
|
209 |
+
raise ValueError(f"sampler method {method} not supported")
|
210 |
+
return sampler
|
UniRig/src/data/spec.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from dataclasses import fields
|
3 |
+
|
4 |
+
class ConfigSpec(ABC):
|
5 |
+
@classmethod
|
6 |
+
def check_keys(cls, config):
|
7 |
+
expect = [field.name for field in fields(cls)]
|
8 |
+
for key in config.keys():
|
9 |
+
if key not in expect:
|
10 |
+
raise ValueError(f"expect names {expect} in {cls.__name__}, found {key}")
|
11 |
+
|
12 |
+
@classmethod
|
13 |
+
@abstractmethod
|
14 |
+
def parse(cls, config):
|
15 |
+
pass
|
UniRig/src/data/tail.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import defaultdict
|
2 |
+
from dataclasses import dataclass
|
3 |
+
import numpy as np
|
4 |
+
from numpy import ndarray
|
5 |
+
|
6 |
+
from typing import Tuple
|
7 |
+
|
8 |
+
from .asset import Asset
|
9 |
+
from .spec import ConfigSpec
|
10 |
+
|
11 |
+
@dataclass
|
12 |
+
class TailConfig(ConfigSpec):
|
13 |
+
'''
|
14 |
+
Config to handle tails.
|
15 |
+
'''
|
16 |
+
|
17 |
+
# copy joints to tails
|
18 |
+
copy_joint_to_tail: bool
|
19 |
+
|
20 |
+
# if the joint has only one son, then connect tail to son's joint
|
21 |
+
connect_tail_to_unique_son: bool
|
22 |
+
|
23 |
+
@classmethod
|
24 |
+
def parse(cls, config) -> 'TailConfig':
|
25 |
+
cls.check_keys(config)
|
26 |
+
return TailConfig(
|
27 |
+
copy_joint_to_tail=config.copy_joint_to_tail,
|
28 |
+
connect_tail_to_unique_son=config.connect_tail_to_unique_son,
|
29 |
+
)
|
30 |
+
|
31 |
+
class Tail():
|
32 |
+
|
33 |
+
def __init__(self, config: TailConfig):
|
34 |
+
self.config = config
|
35 |
+
|
36 |
+
def process_tail(self, asset: Asset):
|
37 |
+
if self.config.copy_joint_to_tail:
|
38 |
+
assert asset.tails is None, 'copying joints to existing tails is not permitted, please change copy_joint_to_tail to False in transform config'
|
39 |
+
asset.tails = asset.joints.copy()
|
40 |
+
if self.config.connect_tail_to_unique_son and asset.tails is not None:
|
41 |
+
children = defaultdict(list)
|
42 |
+
for (id, p) in enumerate(asset.parents):
|
43 |
+
if p is not None:
|
44 |
+
children[p].append(id)
|
45 |
+
for i in range(asset.J):
|
46 |
+
if len(children[i]) == 1:
|
47 |
+
asset.tails[i] = asset.joints[children[i][0]]
|
48 |
+
|
49 |
+
def get_tail(config: TailConfig) -> Tail:
|
50 |
+
return Tail(config=config)
|
UniRig/src/data/transform.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import defaultdict
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import Union, List, Tuple
|
4 |
+
from copy import deepcopy
|
5 |
+
|
6 |
+
from .asset import Asset
|
7 |
+
from .augment import AugmentConfig, Augment, get_augments
|
8 |
+
from .order import OrderConfig, Order, get_order
|
9 |
+
from .sampler import SamplerConfig, get_sampler
|
10 |
+
from .vertex_group import VertexGroupConfig, get_vertex_groups
|
11 |
+
from .tail import TailConfig, get_tail
|
12 |
+
from .spec import ConfigSpec
|
13 |
+
|
14 |
+
@dataclass
|
15 |
+
class TransformConfig(ConfigSpec):
|
16 |
+
|
17 |
+
tail_config: Union[TailConfig, None]=None,
|
18 |
+
|
19 |
+
order_config: Union[OrderConfig, None]=None,
|
20 |
+
|
21 |
+
vertex_group_config: Union[VertexGroupConfig, None]=None,
|
22 |
+
|
23 |
+
augment_config: Union[AugmentConfig, None]=None,
|
24 |
+
|
25 |
+
sampler_config: Union[SamplerConfig, None]=None,
|
26 |
+
|
27 |
+
@classmethod
|
28 |
+
def parse(cls, config) -> 'TransformConfig':
|
29 |
+
cls.check_keys(config)
|
30 |
+
tail_config = config.get('tail_config', None)
|
31 |
+
order_config = config.get('order_config', None)
|
32 |
+
vertex_group_config = config.get('vertex_group_config', None)
|
33 |
+
augment_config = config.get('augment_config', None)
|
34 |
+
sampler_config = config.get('sampler_config', None)
|
35 |
+
|
36 |
+
if tail_config is not None:
|
37 |
+
tail_config = TailConfig.parse(config=tail_config)
|
38 |
+
if order_config is not None:
|
39 |
+
order_config = OrderConfig.parse(config=order_config)
|
40 |
+
if vertex_group_config is not None:
|
41 |
+
vertex_group_config = VertexGroupConfig.parse(config=vertex_group_config)
|
42 |
+
if augment_config is not None:
|
43 |
+
augment_config = AugmentConfig.parse(config=augment_config)
|
44 |
+
if sampler_config is not None:
|
45 |
+
sampler_config = SamplerConfig.parse(config=sampler_config)
|
46 |
+
|
47 |
+
return TransformConfig(
|
48 |
+
tail_config=tail_config,
|
49 |
+
order_config=order_config,
|
50 |
+
vertex_group_config=vertex_group_config,
|
51 |
+
augment_config=augment_config,
|
52 |
+
sampler_config=sampler_config,
|
53 |
+
)
|
54 |
+
|
55 |
+
def transform_asset(
|
56 |
+
asset: Asset,
|
57 |
+
transform_config: TransformConfig,
|
58 |
+
) -> Tuple[List[Augment], List[Augment]]:
|
59 |
+
assert isinstance(transform_config, TransformConfig), f"found {type(transform_config)}"
|
60 |
+
# 1. try processing tails
|
61 |
+
# TODO: use a better method
|
62 |
+
if transform_config.tail_config is not None:
|
63 |
+
tail = get_tail(config=transform_config.tail_config)
|
64 |
+
tail.process_tail(asset=asset)
|
65 |
+
|
66 |
+
# 2. arrange bones
|
67 |
+
if transform_config.order_config is not None:
|
68 |
+
order = get_order(config=transform_config.order_config)
|
69 |
+
asset.set_order(order=order)
|
70 |
+
|
71 |
+
# 3. collapse must perform first
|
72 |
+
if transform_config.augment_config:
|
73 |
+
first_augments, second_augments = get_augments(config=transform_config.augment_config)
|
74 |
+
else:
|
75 |
+
first_augments = []
|
76 |
+
second_augments = []
|
77 |
+
|
78 |
+
kwargs = {}
|
79 |
+
for augment in first_augments:
|
80 |
+
augment.transform(asset=asset, **kwargs)
|
81 |
+
|
82 |
+
# 4. get vertex groups
|
83 |
+
if transform_config.vertex_group_config is not None:
|
84 |
+
vertex_groups = get_vertex_groups(config=transform_config.vertex_group_config)
|
85 |
+
d = {}
|
86 |
+
for v in vertex_groups:
|
87 |
+
d.update(v.get_vertex_group(asset=asset))
|
88 |
+
asset.vertex_groups = d
|
89 |
+
else:
|
90 |
+
asset.vertex_groups = {}
|
91 |
+
|
92 |
+
# 5. regular augments
|
93 |
+
for augment in second_augments:
|
94 |
+
augment.transform(asset=asset, **kwargs)
|
95 |
+
|
96 |
+
# 6. sample
|
97 |
+
if transform_config.sampler_config is not None:
|
98 |
+
sampler = get_sampler(config=transform_config.sampler_config)
|
99 |
+
res = sampler.sample(asset=asset)
|
100 |
+
asset.sampled_vertices = res.vertices
|
101 |
+
asset.sampled_normals = res.normals
|
102 |
+
asset.sampled_vertex_groups = res.vertex_groups
|
103 |
+
else:
|
104 |
+
asset.sampled_vertices = asset.vertices.copy()
|
105 |
+
asset.sampled_normals = asset.vertex_normals.copy()
|
106 |
+
asset.sampled_vertex_groups = deepcopy(asset.vertex_groups)
|
107 |
+
return first_augments, second_augments
|
UniRig/src/data/utils.py
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from numpy import ndarray
|
4 |
+
from torch import Tensor, FloatTensor
|
5 |
+
from typing import Tuple, Union
|
6 |
+
|
7 |
+
from scipy.spatial.transform import Rotation as R
|
8 |
+
from scipy.sparse import csc_matrix
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
def quaternion_to_matrix(x, use_4x4=True) -> FloatTensor:
|
12 |
+
"""
|
13 |
+
Ref: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#quaternion_to_matrix
|
14 |
+
"""
|
15 |
+
if not isinstance(x, Tensor):
|
16 |
+
quaternions = torch.tensor(x, dtype=torch.float32)
|
17 |
+
else:
|
18 |
+
quaternions = x
|
19 |
+
r, i, j, k = torch.unbind(quaternions, -1)
|
20 |
+
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
21 |
+
device = quaternions.device
|
22 |
+
|
23 |
+
if use_4x4:
|
24 |
+
o = torch.stack(
|
25 |
+
(
|
26 |
+
1 - two_s * (j * j + k * k),
|
27 |
+
two_s * (i * j - k * r),
|
28 |
+
two_s * (i * k + j * r),
|
29 |
+
torch.zeros(quaternions.shape[:-1], device=device, dtype=torch.float32),
|
30 |
+
two_s * (i * j + k * r),
|
31 |
+
1 - two_s * (i * i + k * k),
|
32 |
+
two_s * (j * k - i * r),
|
33 |
+
torch.zeros(quaternions.shape[:-1], device=device, dtype=torch.float32),
|
34 |
+
two_s * (i * k - j * r),
|
35 |
+
two_s * (j * k + i * r),
|
36 |
+
1 - two_s * (i * i + j * j),
|
37 |
+
torch.zeros(quaternions.shape[:-1], device=device, dtype=torch.float32),
|
38 |
+
torch.zeros(quaternions.shape[:-1], device=device, dtype=torch.float32),
|
39 |
+
torch.zeros(quaternions.shape[:-1], device=device, dtype=torch.float32),
|
40 |
+
torch.zeros(quaternions.shape[:-1], device=device, dtype=torch.float32),
|
41 |
+
torch.ones(quaternions.shape[:-1], device=device, dtype=torch.float32),
|
42 |
+
),
|
43 |
+
-1,
|
44 |
+
)
|
45 |
+
return o.reshape(quaternions.shape[:-1] + (4, 4))
|
46 |
+
else:
|
47 |
+
o = torch.stack(
|
48 |
+
(
|
49 |
+
1 - two_s * (j * j + k * k),
|
50 |
+
two_s * (i * j - k * r),
|
51 |
+
two_s * (i * k + j * r),
|
52 |
+
two_s * (i * j + k * r),
|
53 |
+
1 - two_s * (i * i + k * k),
|
54 |
+
two_s * (j * k - i * r),
|
55 |
+
two_s * (i * k - j * r),
|
56 |
+
two_s * (j * k + i * r),
|
57 |
+
1 - two_s * (i * i + j * j),
|
58 |
+
),
|
59 |
+
-1,
|
60 |
+
)
|
61 |
+
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
62 |
+
|
63 |
+
def axis_angle_to_quaternion(axis_angle: FloatTensor) -> FloatTensor:
|
64 |
+
"""
|
65 |
+
Ref: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#axis_angle_to_quaternion
|
66 |
+
"""
|
67 |
+
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
|
68 |
+
half_angles = angles * 0.5
|
69 |
+
eps = 1e-6
|
70 |
+
small_angles = angles.abs() < eps
|
71 |
+
sin_half_angles_over_angles = torch.empty_like(angles)
|
72 |
+
sin_half_angles_over_angles[~small_angles] = (
|
73 |
+
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
|
74 |
+
)
|
75 |
+
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
|
76 |
+
# so sin(x/2)/x is about 1/2 - (x*x)/48
|
77 |
+
sin_half_angles_over_angles[small_angles] = (
|
78 |
+
0.5 - (angles[small_angles] * angles[small_angles]) / 48
|
79 |
+
)
|
80 |
+
quaternions = torch.cat(
|
81 |
+
[torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
|
82 |
+
)
|
83 |
+
return quaternions
|
84 |
+
|
85 |
+
def axis_angle_to_matrix(axis_angle: Union[FloatTensor, ndarray]) -> Union[FloatTensor, ndarray]:
|
86 |
+
"""
|
87 |
+
Ref: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#axis_angle_to_matrix
|
88 |
+
"""
|
89 |
+
if isinstance(axis_angle, FloatTensor):
|
90 |
+
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
|
91 |
+
else:
|
92 |
+
res = np.pad(R.from_rotvec(axis_angle).as_matrix(), ((0, 0), (0, 1), (0, 1)), 'constant', constant_values=((0, 0), (0, 0), (0, 0)))
|
93 |
+
assert res.ndim == 3
|
94 |
+
res[:, -1, -1] = 1
|
95 |
+
return res
|
96 |
+
|
97 |
+
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
|
98 |
+
"""
|
99 |
+
Returns torch.sqrt(torch.max(0, x))
|
100 |
+
but with a zero subgradient where x is 0.
|
101 |
+
"""
|
102 |
+
ret = torch.zeros_like(x)
|
103 |
+
positive_mask = x > 0
|
104 |
+
if torch.is_grad_enabled():
|
105 |
+
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
106 |
+
else:
|
107 |
+
ret = torch.where(positive_mask, torch.sqrt(x), ret)
|
108 |
+
return ret
|
109 |
+
|
110 |
+
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
|
111 |
+
"""
|
112 |
+
Convert a unit quaternion to a standard form: one in which the real
|
113 |
+
part is non negative.
|
114 |
+
|
115 |
+
Args:
|
116 |
+
quaternions: Quaternions with real part first,
|
117 |
+
as tensor of shape (..., 4).
|
118 |
+
|
119 |
+
Returns:
|
120 |
+
Standardized quaternions as tensor of shape (..., 4).
|
121 |
+
"""
|
122 |
+
return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
|
123 |
+
|
124 |
+
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
|
125 |
+
"""
|
126 |
+
Convert rotations given as rotation matrices to quaternions.
|
127 |
+
|
128 |
+
Args:
|
129 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
130 |
+
|
131 |
+
Returns:
|
132 |
+
quaternions with real part first, as tensor of shape (..., 4).
|
133 |
+
"""
|
134 |
+
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
135 |
+
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
|
136 |
+
|
137 |
+
batch_dim = matrix.shape[:-2]
|
138 |
+
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
|
139 |
+
matrix.reshape(batch_dim + (9,)), dim=-1
|
140 |
+
)
|
141 |
+
|
142 |
+
q_abs = _sqrt_positive_part(
|
143 |
+
torch.stack(
|
144 |
+
[
|
145 |
+
1.0 + m00 + m11 + m22,
|
146 |
+
1.0 + m00 - m11 - m22,
|
147 |
+
1.0 - m00 + m11 - m22,
|
148 |
+
1.0 - m00 - m11 + m22,
|
149 |
+
],
|
150 |
+
dim=-1,
|
151 |
+
)
|
152 |
+
)
|
153 |
+
|
154 |
+
# we produce the desired quaternion multiplied by each of r, i, j, k
|
155 |
+
quat_by_rijk = torch.stack(
|
156 |
+
[
|
157 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
158 |
+
# `int`.
|
159 |
+
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
|
160 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
161 |
+
# `int`.
|
162 |
+
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
|
163 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
164 |
+
# `int`.
|
165 |
+
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
|
166 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
167 |
+
# `int`.
|
168 |
+
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
|
169 |
+
],
|
170 |
+
dim=-2,
|
171 |
+
)
|
172 |
+
|
173 |
+
# We floor here at 0.1 but the exact level is not important; if q_abs is small,
|
174 |
+
# the candidate won't be picked.
|
175 |
+
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
|
176 |
+
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
|
177 |
+
|
178 |
+
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
|
179 |
+
# forall i; we pick the best-conditioned one (with the largest denominator)
|
180 |
+
out = quat_candidates[
|
181 |
+
torch.nn.functional.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
|
182 |
+
].reshape(batch_dim + (4,))
|
183 |
+
return standardize_quaternion(out)
|
184 |
+
|
185 |
+
def linear_blend_skinning(
|
186 |
+
vertex: Union[FloatTensor, ndarray],
|
187 |
+
matrix_local: Union[FloatTensor, ndarray],
|
188 |
+
matrix: Union[FloatTensor, ndarray],
|
189 |
+
skin: Union[FloatTensor, ndarray],
|
190 |
+
pad: int=0,
|
191 |
+
value: float=0.,
|
192 |
+
) -> Union[FloatTensor, ndarray]:
|
193 |
+
'''
|
194 |
+
Args:
|
195 |
+
vertex: (B, N, 4-pad) or (N, 4-pad)
|
196 |
+
matrix_local: (B, J, 4, 4) or (J, 4, 4)
|
197 |
+
matrix: (B, J, 4, 4) or (J, 4, 4)
|
198 |
+
skin: (B, N, J) or (N, J), value of pseudo bones should be 0
|
199 |
+
Returns:
|
200 |
+
(B, N, 3) or (N, 3)
|
201 |
+
'''
|
202 |
+
assert vertex.shape[-1] + pad == 4
|
203 |
+
if isinstance(vertex, Tensor):
|
204 |
+
dims = vertex.dim()
|
205 |
+
elif isinstance(vertex, ndarray):
|
206 |
+
dims = vertex.ndim
|
207 |
+
else:
|
208 |
+
raise NotImplementedError()
|
209 |
+
if dims == 3: # Case: (B, N, 3+pad)
|
210 |
+
assert isinstance(vertex, Tensor)
|
211 |
+
J = matrix_local.shape[1]
|
212 |
+
# (B, J, 3+pad, N)
|
213 |
+
offset = (
|
214 |
+
matrix_local.inverse() @
|
215 |
+
torch.nn.functional.pad(vertex, (0, pad, 0, 0, 0, 0), value=value).unsqueeze(1).transpose(2, 3).repeat(1, J, 1, 1)
|
216 |
+
)
|
217 |
+
# (B, J, 4, N)
|
218 |
+
per_bone_matrix = matrix @ offset
|
219 |
+
# (B, J, 4, N)
|
220 |
+
weighted_per_bone_matrix = skin.transpose(1, 2).unsqueeze(2) * per_bone_matrix
|
221 |
+
# (B, 3, N)
|
222 |
+
g = weighted_per_bone_matrix.sum(dim=1)
|
223 |
+
# (B, 3, N)
|
224 |
+
final = g[:, 0:3, :] / (skin.transpose(1, 2).sum(dim=1) + 1e-8).unsqueeze(1)
|
225 |
+
return final.permute(0, 2, 1)
|
226 |
+
|
227 |
+
elif dims == 2: # Case: (N, 3+pad)
|
228 |
+
if isinstance(vertex, Tensor):
|
229 |
+
J = matrix_local.shape[0]
|
230 |
+
offset = (
|
231 |
+
matrix_local.inverse() @
|
232 |
+
torch.nn.functional.pad(vertex, (0, pad, 0, 0), value=value).unsqueeze(0).transpose(1, 2).repeat(J, 1, 1)
|
233 |
+
)
|
234 |
+
per_bone_matrix = matrix @ offset
|
235 |
+
weighted_per_bone_matrix = skin.transpose(0, 1).unsqueeze(1) * per_bone_matrix
|
236 |
+
g = weighted_per_bone_matrix.sum(dim=0)
|
237 |
+
final = g[0:3, :] / (skin.transpose(0, 1).sum(dim=0) + 1e-8).unsqueeze(0)
|
238 |
+
return final.permute(1, 0) # Output shape (N, 3)
|
239 |
+
else:
|
240 |
+
J = matrix_local.shape[0]
|
241 |
+
N = vertex.shape[0]
|
242 |
+
# (4, N)
|
243 |
+
padded = np.pad(vertex, ((0, 0), (0, pad)), 'constant', constant_values=(0, value)).T
|
244 |
+
# (J, 4, 4)
|
245 |
+
trans = matrix @ np.linalg.inv(matrix_local)
|
246 |
+
weighted_per_bone_matrix = []
|
247 |
+
# (J, N)
|
248 |
+
mask = (skin > 0).T
|
249 |
+
for i in range(J):
|
250 |
+
offset = np.zeros((4, N), dtype=np.float32)
|
251 |
+
offset[:, mask[i]] = (trans[i] @ padded[:, mask[i]]) * skin.T[i, mask[i]]
|
252 |
+
weighted_per_bone_matrix.append(offset)
|
253 |
+
weighted_per_bone_matrix = np.stack(weighted_per_bone_matrix)
|
254 |
+
g = np.sum(weighted_per_bone_matrix, axis=0)
|
255 |
+
final = g[:3, :] / (np.sum(skin, axis=1) + 1e-8)
|
256 |
+
return final.T
|
257 |
+
else:
|
258 |
+
assert 0, f'unsupported shape: {vertex.shape}'
|