File size: 9,444 Bytes
498b8bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd9d8fc
 
498b8bb
 
 
 
fd9d8fc
 
498b8bb
fd9d8fc
498b8bb
 
fd9d8fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498b8bb
 
fd9d8fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498b8bb
 
 
 
fd9d8fc
 
 
 
498b8bb
 
 
 
fd9d8fc
498b8bb
 
 
 
fd9d8fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498b8bb
 
 
fd9d8fc
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import gradio as gr
import torch
import os
import sys
import tempfile
import shutil
import subprocess
# from huggingface_hub import HfApi, snapshot_download # For future model management if needed
# import spaces # For @spaces.GPU decorator if you add it

# --- Configuration ---
# The UniRig checkout is expected in a "UniRig" folder sitting next to this file.
UNIRIG_REPO_DIR = os.path.join(os.path.dirname(__file__), "UniRig")

# Startup check only: a missing checkout is reported here, not raised.
if not os.path.isdir(UNIRIG_REPO_DIR):
    print(f"ERROR: UniRig repository not found at {UNIRIG_REPO_DIR}. Please clone it there.")

# Use CUDA whenever PyTorch reports it as available; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

if DEVICE.type != 'cuda':
    print("Warning: CUDA not available or not detected by PyTorch. UniRig performance will be severely impacted.")
else:
    # Log which GPU (index 0) and which CUDA build PyTorch is using.
    print(f"CUDA Device Name: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")

def run_unirig_command(command_args, step_name):
    """Run one UniRig module as ``python -m <module> <args...>`` in a subprocess.

    The child process uses the same interpreter as this app. It runs with
    ``cwd=UNIRIG_REPO_DIR`` so Hydra can find UniRig's configs, and
    UNIRIG_REPO_DIR is prepended to its PYTHONPATH so the ``unirig`` package
    is importable.

    Args:
        command_args: Module name followed by its arguments, e.g.
            ``["unirig.predict_skeleton", "input.path=...", "output.path=..."]``.
        step_name: Human-readable step label used in log lines and error messages.

    Returns:
        subprocess.CompletedProcess: The finished process with captured text
        stdout/stderr. Callers that ignore the return value are unaffected.

    Raises:
        gr.Error: If the process exits non-zero, with the last lines of its
            stderr (or of stdout when stderr is empty) in the message. Also
            raised if the interpreter or working directory cannot be found,
            and on any other unexpected exception.
    """
    # "-m" makes Python run the first element of command_args as a module.
    cmd = [sys.executable, "-m", *command_args]

    print(f"Running {step_name}: {' '.join(cmd)}")

    process_env = os.environ.copy()

    # Prepend the repo so the child can import the 'unirig' package located in
    # UNIRIG_REPO_DIR. Only add a separator when something follows it: a
    # trailing separator leaves an empty PYTHONPATH entry, which puts an extra
    # current-directory entry on sys.path instead of being a no-op.
    existing_pythonpath = process_env.get("PYTHONPATH", "")
    if existing_pythonpath:
        process_env["PYTHONPATH"] = f"{UNIRIG_REPO_DIR}{os.pathsep}{existing_pythonpath}"
    else:
        process_env["PYTHONPATH"] = UNIRIG_REPO_DIR
    print(f"Set PYTHONPATH for subprocess: {process_env['PYTHONPATH']}")

    try:
        # Execute the command from the UniRig directory for Hydra to find configs.
        result = subprocess.run(
            cmd,
            cwd=UNIRIG_REPO_DIR,
            capture_output=True,
            text=True,
            check=True,
            env=process_env,
        )
        print(f"{step_name} STDOUT:\n{result.stdout}")
        if result.stderr:
            print(f"{step_name} STDERR (non-fatal or warnings):\n{result.stderr}")
        return result
    except subprocess.CalledProcessError as e:
        print(f"ERROR during {step_name}:")
        print(f"Command: {' '.join(e.cmd)}")
        print(f"Return code: {e.returncode}")
        print(f"Stdout: {e.stdout}")
        print(f"Stderr: {e.stderr}")
        # Show the user only a short tail. Prefer stderr, but fall back to
        # stdout so the message is not empty when the failure was reported there.
        tail_source = e.stderr if (e.stderr or "").strip() else (e.stdout or "")
        error_summary = tail_source.splitlines()[-5:]  # Last 5 lines
        raise gr.Error(f"Error in UniRig {step_name}. Details: {' '.join(error_summary)}") from e
    except FileNotFoundError as e:
        # subprocess.run raises this when the executable or the cwd does not exist.
        print(f"ERROR: Could not find executable or script for {step_name}. Is UniRig cloned correctly in {UNIRIG_REPO_DIR} and Python environment setup?")
        raise gr.Error(f"Setup error for UniRig {step_name}. Check server logs and UniRig directory structure.") from e
    except Exception as e_general:
        print(f"An unexpected Python exception occurred in run_unirig_command for {step_name}: {e_general}")
        raise gr.Error(f"Unexpected Python error during {step_name}: {str(e_general)[:500]}") from e_general


# If you are using @spaces.GPU, you would import it:
# import spaces
# @spaces.GPU # You can specify type like @spaces.GPU(type="t4") or count
def _remove_processing_dir(path):
    """Best-effort removal of a per-request temp directory; never raises.

    A failure to delete is logged rather than raised, so cleanup in an error
    handler cannot replace the error that is already propagating.
    """
    if not os.path.exists(path):
        return
    try:
        shutil.rmtree(path)
    except OSError as err:
        print(f"Warning: failed to remove temporary directory {path}: {err}")
    else:
        print(f"Cleaned up temporary directory: {path}")


def _run_step_checked(command_args, step_name, output_path, missing_output_message):
    """Run one UniRig step, then raise gr.Error if ``output_path`` was not written."""
    run_unirig_command(command_args, step_name)
    if not os.path.exists(output_path):
        raise gr.Error(missing_output_message)


def rig_glb_mesh_multistep(input_glb_file_obj):
    """Rig an uploaded GLB mesh with UniRig's three-step pipeline.

    The steps are skeleton prediction, skinning-weight prediction, and merging
    the result back onto the original mesh. Each step runs as a UniRig
    subprocess writing into a fresh temporary directory.

    Args:
        input_glb_file_obj: Path to the uploaded mesh, as delivered by
            ``gr.File(type="filepath")``.

    Returns:
        Path to the final rigged ``.glb``. On success the temporary directory
        is deliberately left in place so Gradio's Model3D output can serve
        the file from it.

    Raises:
        gr.Error: If the UniRig checkout is missing, no input (or an empty
            path) was given, the input file does not exist, or any step fails
            or produces no output. Once the temporary directory has been
            created, it is removed (best effort) before the error propagates.
    """
    if not os.path.isdir(UNIRIG_REPO_DIR):
        raise gr.Error(f"UniRig repository not found at {UNIRIG_REPO_DIR}. Cannot proceed. Please check Space setup.")

    # Covers both "nothing uploaded" (None) and an empty path string.
    if not input_glb_file_obj:
        raise gr.Error("No input file provided. Please upload a .glb mesh.")

    # When type="filepath", the handler receives the path string directly.
    input_glb_path = input_glb_file_obj
    print(f"Input GLB path received: {input_glb_path}")

    # Create a dedicated temporary directory for all intermediate and final files.
    processing_temp_dir = tempfile.mkdtemp(prefix="unirig_processing_")
    print(f"Using temporary processing directory: {processing_temp_dir}")

    try:
        # Fail fast with a clear message instead of an opaque UniRig failure.
        if not os.path.isfile(input_glb_path):
            raise gr.Error(f"Input file not found: {input_glb_path}")

        # Absolute paths are passed to UniRig because it runs with a different cwd.
        abs_input_path = os.path.abspath(input_glb_path)
        base_name = os.path.splitext(os.path.basename(input_glb_path))[0]
        temp_skeleton_path = os.path.join(processing_temp_dir, f"{base_name}_skeleton.fbx")
        temp_skin_path = os.path.join(processing_temp_dir, f"{base_name}_skin.fbx")
        final_rigged_glb_path = os.path.join(processing_temp_dir, f"{base_name}_rigged_final.glb")

        # Step 1: Skeleton Prediction
        print("Step 1: Predicting Skeleton...")
        _run_step_checked(
            [
                "unirig.predict_skeleton",
                f"input.path={abs_input_path}",
                f"output.path={os.path.abspath(temp_skeleton_path)}",
            ],
            "Skeleton Prediction",
            temp_skeleton_path,
            "Skeleton prediction failed to produce an output file. Check logs for UniRig errors.",
        )

        # Step 2: Skinning Weight Prediction
        print("Step 2: Predicting Skinning Weights...")
        _run_step_checked(
            [
                "unirig.predict_skin",
                f"input.skeleton_path={os.path.abspath(temp_skeleton_path)}",
                f"input.source_mesh_path={abs_input_path}",
                f"output.path={os.path.abspath(temp_skin_path)}",
            ],
            "Skinning Prediction",
            temp_skin_path,
            "Skinning prediction failed to produce an output file. Check logs for UniRig errors.",
        )

        # Step 3: Merge Skeleton/Skin with Original Mesh
        print("Step 3: Merging Results...")
        _run_step_checked(
            [
                "unirig.merge_skeleton_skin",
                f"input.source_rig_path={os.path.abspath(temp_skin_path)}",
                f"input.target_mesh_path={abs_input_path}",
                f"output.path={os.path.abspath(final_rigged_glb_path)}",
            ],
            "Merging",
            final_rigged_glb_path,
            "Merging process failed to produce the final rigged GLB file. Check logs for UniRig errors.",
        )

        # Kept on disk: Gradio's gr.Model3D output component serves this file.
        return final_rigged_glb_path

    except gr.Error:  # Re-raise Gradio errors directly, after cleanup.
        _remove_processing_dir(processing_temp_dir)
        raise
    except Exception as e:
        print(f"An unexpected error occurred in rig_glb_mesh_multistep: {e}")
        _remove_processing_dir(processing_temp_dir)
        raise gr.Error(f"An unexpected error occurred during processing: {str(e)[:500]}") from e


# --- Gradio Interface ---
# Soft theme: sky/blue accents over slate neutrals, with Inter as the primary
# font and generic system sans-serif fonts as fallbacks.
theme = gr.themes.Soft(
    font=[
        gr.themes.GoogleFont("Inter"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
    neutral_hue=gr.themes.colors.slate,
    secondary_hue=gr.themes.colors.blue,
    primary_hue=gr.themes.colors.sky,
)

# When executed directly (not imported), warn before the interface is built if
# the UniRig checkout is missing. This only prints; it does not stop startup.
if __name__ == "__main__":
    if not os.path.isdir(UNIRIG_REPO_DIR):
        print(f"CRITICAL STARTUP ERROR: UniRig repository not found at {UNIRIG_REPO_DIR}. The application will not work.")

# Define the interface.
# Note: the @spaces.GPU decorator, if used, goes above `rig_glb_mesh_multistep`, not here.
iface = gr.Interface(
    fn=rig_glb_mesh_multistep,
    inputs=gr.File(
        label="Upload .glb Mesh File",
        type="filepath",  # The handler receives the uploaded file's path as a string.
        # The pipeline and the description below expect GLB input, so restrict uploads to it.
        file_types=[".glb"],
    ),
    outputs=gr.Model3D(
        label="Rigged 3D Model (.glb)",
        clear_color=[0.8, 0.8, 0.8, 1.0],
    ),
    title="UniRig Auto-Rigger (Python 3.11 / PyTorch 2.3+)",
    description=(
        "Upload a 3D mesh in `.glb` format. This application uses the latest UniRig to automatically rig the mesh.\n"
        "The process involves: 1. Skeleton Prediction, 2. Skinning Weight Prediction, 3. Merging.\n"
        "This may take several minutes. Ensure your GLB has clean geometry.\n"
        f"Running on: {str(DEVICE).upper()}. UniRig repo expected at: '{os.path.basename(UNIRIG_REPO_DIR)}'.\n"
        f"UniRig Source: https://github.com/VAST-AI-Research/UniRig"
    ),
    cache_examples=False,
    theme=theme,
    # Flagging: the Gradio docs give "manual" as Interface's default flagging
    # mode, meaning a Flag button IS shown. Gradio 5 replaced the deprecated
    # `allow_flagging` argument with `flagging_mode`; pass flagging_mode="never"
    # (Gradio 5+ only) to hide it.
)

if __name__ == "__main__":
    # Re-check the checkout right before serving so the warning sits next to the launch logs.
    _repo_missing = not os.path.isdir(UNIRIG_REPO_DIR)
    if _repo_missing:
        print(f"CRITICAL: UniRig repository not found at {UNIRIG_REPO_DIR}. Ensure it's cloned in the Space's root.")

    iface.launch()