File size: 8,094 Bytes
da47f4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import os
import shutil
import sys
import tempfile
import traceback

import gradio as gr
import torch
import trimesh

# Make the vendored UniRig sources importable. They are expected in a
# 'UniRig_src' directory next to this script. The directory goes at the FRONT
# of sys.path because the modules imported below have generic names ('utils'
# in particular) that an already-installed package could otherwise shadow.
_UNIRIG_SRC_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'UniRig_src')
if _UNIRIG_SRC_DIR not in sys.path:
    sys.path.insert(0, _UNIRIG_SRC_DIR)

# Import the UniRig entry points. On failure, install stand-ins that raise a
# RuntimeError when called, so the Gradio UI still loads and each request can
# report the problem instead of the whole app failing at import time.
try:
    from autorig import AutoRigger
    from utils import setup_source_mesh
except Exception as e:
    # Deliberately broader than ImportError: a UniRig module that fails while
    # initialising (e.g. a RuntimeError/OSError at import) should also degrade
    # to the stand-ins rather than crash the app. The error is logged here.
    print("Error importing from UniRig_src. Make sure the UniRig source files are in the 'UniRig_src' directory.")
    print(f"Details: {e}")
    # Python unbinds `e` when the except block ends; keep a reference so the
    # stand-ins can chain the real cause into their RuntimeError.
    _UNIRIG_IMPORT_ERROR = e

    def AutoRigger(*args, **kwargs):
        """Stand-in used when UniRig failed to import; always raises RuntimeError."""
        raise RuntimeError("UniRig AutoRigger could not be loaded. Check UniRig_src setup.") from _UNIRIG_IMPORT_ERROR

    def setup_source_mesh(mesh, *args, **kwargs):
        """Stand-in used when UniRig failed to import; always raises RuntimeError."""
        raise RuntimeError("UniRig setup_source_mesh could not be loaded. Check UniRig_src setup.") from _UNIRIG_IMPORT_ERROR

# --- Configuration ---
# UniRig model files. In a Hugging Face Space they belong in a 'model_files'
# directory alongside this script.
MODEL_DIR = os.path.join(os.path.dirname(__file__), "model_files")
SMPL_SKELETON_PATH, SKIN_KPS_PREDICTOR_PATH = (
    os.path.join(MODEL_DIR, model_name)
    for model_name in ("smpl_skeleton.pkl", "skin_kps_predictor.pkl")
)

# Startup warning only: the module keeps loading even if the files are absent.
if not all(os.path.exists(model_path) for model_path in (SMPL_SKELETON_PATH, SKIN_KPS_PREDICTOR_PATH)):
    print(f"Warning: Model files not found at {MODEL_DIR}. Please ensure smpl_skeleton.pkl and skin_kps_predictor.pkl are present.")

# Processing device: CUDA when torch reports it available, otherwise CPU.
# NOTE(review): on ZeroGPU Spaces, GPU access is normally tied to the `spaces`
# package / @spaces.GPU decorator, which this file does not use — confirm.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")
if DEVICE.type != 'cuda':
    print("CUDA not available, UniRig performance will be significantly slower on CPU.")
else:
    print(f"CUDA Device Name: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")


# --- Core Rigging Function ---
def _resolve_upload_path(uploaded_file):
    """Return the filesystem path behind a Gradio ``gr.File`` upload value.

    Gradio 3 (``type="file"``) passes a tempfile-like object exposing ``.name``;
    Gradio 4+ (``type="filepath"``) passes the path itself as a string.
    """
    if isinstance(uploaded_file, (str, os.PathLike)):
        return os.fspath(uploaded_file)
    return uploaded_file.name


def _scene_to_mesh(scene):
    """Flatten a ``trimesh.Scene`` into a single ``trimesh.Trimesh``.

    Each triangle-mesh instance is copied and moved by its scene-graph
    transform before concatenation, so parts keep their placement from the GLB.
    Geometry that is not a ``trimesh.Trimesh`` is skipped.

    Raises:
        gr.Error: if the scene holds no triangle meshes or merging fails.
    """
    parts = []
    for node_name in scene.graph.nodes_geometry:
        transform, geometry_name = scene.graph[node_name]
        geometry = scene.geometry[geometry_name]
        if not isinstance(geometry, trimesh.Trimesh):
            continue
        part = geometry.copy()  # don't mutate geometry shared between nodes
        part.apply_transform(transform)
        parts.append(part)

    if not parts:
        raise gr.Error("Input GLB file contains no mesh geometry.")

    merged = trimesh.util.concatenate(parts)
    if not isinstance(merged, trimesh.Trimesh):
        raise gr.Error(f"Could not extract a valid mesh from the GLB scene. Found type: {type(merged)}")
    return merged


def _load_single_mesh(path):
    """Load ``path`` with trimesh and return it as one ``trimesh.Trimesh``.

    Raises:
        gr.Error: if the file does not yield usable triangle-mesh geometry.
    """
    # `force='mesh'` is a documented option of `trimesh.load` (not `load_mesh`)
    # asking trimesh to flatten scenes itself; the Scene branch is a fallback.
    loaded = trimesh.load(path, force='mesh', process=False)

    if isinstance(loaded, trimesh.Scene):
        print(f"Input is a scene with {len(loaded.geometry)} geometries. Attempting to merge.")
        loaded = _scene_to_mesh(loaded)

    if not isinstance(loaded, trimesh.Trimesh):
        raise gr.Error(f"Failed to load a valid mesh from the input file. Loaded type: {type(loaded)}")
    if loaded.is_empty:
        raise gr.Error("Input GLB file contains no mesh geometry.")
    return loaded


def rig_glb_mesh(input_glb_file):
    """
    Rig an uploaded GLB mesh with UniRig and return the path of the rigged GLB.

    Args:
        input_glb_file: The ``gr.File`` value: a path string (Gradio 4+) or a
            tempfile-like object with ``.name`` (Gradio 3); ``None`` when
            nothing was uploaded.

    Returns:
        Path to ``rigged_output.glb`` inside a fresh temporary directory. The
        directory is intentionally not removed on success, because Gradio
        needs the file to exist in order to serve it.

    Raises:
        gr.Error: for every user-visible failure (no input, missing model
            files, unreadable mesh, or any exception raised while rigging,
            including the RuntimeError from the UniRig stand-ins when the
            UniRig import failed at startup).
    """
    if input_glb_file is None:
        raise gr.Error("No input file provided. Please upload a .glb mesh.")

    input_glb_path = _resolve_upload_path(input_glb_file)

    # Fail fast, before the expensive preprocessing step.
    if not os.path.exists(SMPL_SKELETON_PATH) or not os.path.exists(SKIN_KPS_PREDICTOR_PATH):
        raise gr.Error(f"UniRig model files not found. Searched in {MODEL_DIR}. Please check your Space's file structure.")

    temp_dir = None
    try:
        temp_dir = tempfile.mkdtemp()
        output_glb_path = os.path.join(temp_dir, "rigged_output.glb")

        # 1. Load the mesh, flattening multi-part scenes into one mesh.
        print(f"Loading mesh from: {input_glb_path}")
        mesh = _load_single_mesh(input_glb_path)
        print("Mesh loaded successfully.")

        # 2. Preprocess the mesh as in UniRig's example (reported there to be
        #    required for correct results).
        print("Preprocessing mesh...")
        mesh = setup_source_mesh(mesh, device=DEVICE)
        print("Mesh preprocessing complete.")

        # 3. Initialize the AutoRigger from the model files checked above.
        print("Initializing AutoRigger...")
        autorigger = AutoRigger(SMPL_SKELETON_PATH, SKIN_KPS_PREDICTOR_PATH, device=DEVICE)
        print("AutoRigger initialized.")

        # 4. Perform rigging. Assumes setup_source_mesh produced everything
        #    `rig` needs (verts/faces/normals) — TODO confirm against UniRig.
        print("Starting rigging process...")
        output_dict = autorigger.rig(mesh)
        print("Rigging process complete.")

        # 5. Extract the rigged mesh (expected to be exportable like a Trimesh).
        rigged_mesh = output_dict['rigged_mesh']
        print("Rigged mesh extracted.")

        # 6. Export; the .glb extension selects the GLB exporter.
        print(f"Exporting rigged mesh to: {output_glb_path}")
        rigged_mesh.export(output_glb_path)
        print("Export complete.")

        return output_glb_path

    except Exception as e:
        # Don't leave a half-written output directory behind on failure.
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
        if isinstance(e, gr.Error):
            # Already a user-facing message; wrapping it again would double
            # the "An error occurred..." prefix.
            raise
        print(f"Error during rigging: {e}")
        traceback.print_exc()  # keep the full stack in the server logs
        raise gr.Error(f"An error occurred during processing: {str(e)}") from e


# --- Gradio Interface ---
# Gradio 4 removed gr.File(type="file") (tempfile objects) in favour of
# type="filepath" (plain path strings), and Gradio 5 renamed Interface's
# allow_flagging= to flagging_mode=. Pick per installed version; rig_glb_mesh
# accepts either kind of upload value.
_GRADIO_MAJOR = int(gr.__version__.split(".", 1)[0])
_FILE_INPUT_TYPE = "filepath" if _GRADIO_MAJOR >= 4 else "file"
_FLAGGING_KWARGS = {"flagging_mode": "never"} if _GRADIO_MAJOR >= 5 else {"allow_flagging": "never"}

# Custom theme: Soft with a sky-blue primary, blue secondary and slate
# (charcoal gray) neutrals.
theme = gr.themes.Soft(
    primary_hue=gr.themes.colors.sky,
    secondary_hue=gr.themes.colors.blue,
    neutral_hue=gr.themes.colors.slate,
    font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
).set(
    # Hook for fine-tuning individual theme variables, e.g.
    # button_primary_background_fill="*primary_500",
)

# Interface definition
iface = gr.Interface(
    fn=rig_glb_mesh,
    inputs=gr.File(label="Upload .glb Mesh File", type=_FILE_INPUT_TYPE, file_types=[".glb"]),
    outputs=gr.Model3D(label="Rigged 3D Model (.glb)", clear_color=[0.8, 0.8, 0.8, 1.0]),  # Model3D can display .glb
    title="UniRig Auto-Rigger for 3D Meshes",
    description=(
        "Upload a 3D mesh in `.glb` format. This application uses UniRig to automatically rig the mesh.\n"
        "The process may take a few minutes, especially for complex meshes. Ensure your GLB has clean geometry.\n"
        f"Running on: {str(DEVICE).upper()}. Model files expected in '{MODEL_DIR}'.\n"
        f"UniRig Source: https://github.com/VAST-AI-Research/UniRig"
    ),
    examples=[
        # Add paths to example GLB files if you include them in your Space
        # e.g., [os.path.join(os.path.dirname(__file__), "examples/sample_mesh.glb")]
    ],
    cache_examples=False,  # Set to True if you have static examples and want to pre-process them
    theme=theme,
    **_FLAGGING_KWARGS,
)

if __name__ == "__main__":
    # Print a loud startup hint when the vendored UniRig sources are missing;
    # the app is launched regardless.
    unirig_src_present = os.path.exists(os.path.join(os.path.dirname(__file__), 'UniRig_src'))
    if not unirig_src_present:
        print("CRITICAL: 'UniRig_src' directory not found. Please ensure UniRig source files are correctly placed.")
    iface.launch()