import gradio as gr
import os
import sys
import subprocess
from huggingface_hub import snapshot_download, HfFolder
import random  # For random seed generation

# --- Repo Setup ---
DEFAULT_REPO_DIR = "./TripoSG-repo"  # Directory to clone into if not using a local path
REPO_GIT_URL = "github.com/VAST-AI-Research/TripoSG.git"  # Base URL without scheme/token
BRANCH = "scribble"

code_source_path = None

# Option 1: Use a local path if the TRIPOSG_CODE_PATH env var is set
local_code_path = os.environ.get("TRIPOSG_CODE_PATH")
if local_code_path:
    print(f"Attempting to use local code path specified by TRIPOSG_CODE_PATH: {local_code_path}")
    # Basic check: does it exist and look like a git repo (i.e., has a .git directory)?
    if os.path.isdir(local_code_path) and os.path.isdir(os.path.join(local_code_path, ".git")):
        code_source_path = os.path.abspath(local_code_path)
        print(f"Using local TripoSG code directory: {code_source_path}")
        # You might want to verify that the branch is correct here, e.g.:
        # try:
        #     current_branch = subprocess.run(
        #         ["git", "rev-parse", "--abbrev-ref", "HEAD"],
        #         cwd=code_source_path, check=True, capture_output=True, text=True,
        #     ).stdout.strip()
        #     if current_branch != BRANCH:
        #         print(f"Warning: Local repo is on branch '{current_branch}', expected '{BRANCH}'. Attempting checkout...")
        #         subprocess.run(["git", "checkout", BRANCH], cwd=code_source_path, check=True)
        # except Exception as e:
        #     print(f"Warning: Could not verify or checkout branch '{BRANCH}' in {code_source_path}: {e}")
    else:
        print(f"Warning: TRIPOSG_CODE_PATH '{local_code_path}' not found or not a valid git repository directory. Falling back to cloning.")

# Option 2: Clone from GitHub (if the local path was not used or is invalid)
if not code_source_path:
    repo_url_to_clone = f"https://{REPO_GIT_URL}"
    github_token = os.environ.get("GITHUB_TOKEN")
    if github_token:
        print("Using GITHUB_TOKEN for repository cloning.")
        repo_url_to_clone = f"https://{github_token}@{REPO_GIT_URL}"
    else:
        print("No GITHUB_TOKEN found. Using public HTTPS for cloning.")

    repo_target_dir = os.path.abspath(DEFAULT_REPO_DIR)
    if not os.path.exists(repo_target_dir):
        print(f"Cloning TripoSG repository ({BRANCH} branch) into {repo_target_dir}...")
        try:
            subprocess.run(
                ["git", "clone", "--branch", BRANCH, "--depth", "1", repo_url_to_clone, repo_target_dir],
                check=True,
            )
            code_source_path = repo_target_dir
            print("Repository cloned successfully.")
        except subprocess.CalledProcessError as e:
            print(f"Error cloning repository: {e}")
            print(f"Please ensure the URL is correct, the branch '{BRANCH}' exists, and you have access rights (or provide a GITHUB_TOKEN).")
            sys.exit(1)
        except Exception as e:
            print(f"An unexpected error occurred during cloning: {e}")
            sys.exit(1)
    else:
        print(f"Directory {repo_target_dir} already exists. Assuming it contains the correct code/branch.")
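        # A minimal freshness check realizing the optional step mentioned in the
        # original comments (an assumption, not part of the original flow):
        # fetch and fast-forward the existing clone so it tracks the expected branch.
        try:
            subprocess.run(["git", "fetch", "origin", BRANCH], cwd=repo_target_dir, check=True)
            subprocess.run(["git", "checkout", BRANCH], cwd=repo_target_dir, check=True)
            subprocess.run(["git", "pull", "--ff-only", "origin", BRANCH], cwd=repo_target_dir, check=True)
        except subprocess.CalledProcessError as e:
            print(f"Warning: could not refresh the existing clone, using it as-is: {e}")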
        code_source_path = repo_target_dir  # Reuse the existing checkout

if not code_source_path:
    print("Error: Could not determine TripoSG code source path.")
    sys.exit(1)

# Add the repo to the Python path (use the determined absolute path)
sys.path.insert(0, code_source_path)
print(f"Added {code_source_path} to sys.path")
# --- End Repo Setup ---

# --- ZeroGPU Setup ---
DISABLE_ZEROGPU = os.environ.get("DISABLE_ZEROGPU", "false").lower() in ("true", "1", "t")
ENABLE_ZEROGPU = not DISABLE_ZEROGPU
print(f"ZeroGPU Enabled: {ENABLE_ZEROGPU}")
# --- End ZeroGPU Setup ---

if ENABLE_ZEROGPU:
    import spaces  # Provides the @spaces.GPU decorator on ZeroGPU Spaces

from PIL import Image
import numpy as np
import torch
from triposg.pipelines.pipeline_triposg_scribble import TripoSGScribblePipeline
import tempfile

# --- Weight Loading Logic ---
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    HfFolder.save_token(HF_TOKEN)

HUGGING_FACE_REPO_ID = "VAST-AI/TripoSG-scribble"
DEFAULT_CACHE_PATH = "./pretrained_weights/TripoSG-scribble"

# Option 1: Use a local path if the WEIGHTS_PATH env var is set
local_weights_path = os.environ.get("WEIGHTS_PATH")
model_load_path = None
if local_weights_path:
    print(f"Attempting to load weights from local path specified by WEIGHTS_PATH: {local_weights_path}")
    if os.path.isdir(local_weights_path):
        model_load_path = local_weights_path
        print(f"Using local weights directory: {model_load_path}")
    else:
        print(f"Warning: WEIGHTS_PATH '{local_weights_path}' not found or not a directory. Falling back to Hugging Face download.")

# Option 2: Download from Hugging Face (if the local path was not used or is invalid)
if not model_load_path:
    hf_token = os.environ.get("HF_TOKEN")
    print(f"Attempting to download weights from Hugging Face repo: {HUGGING_FACE_REPO_ID}")
    if hf_token:
        print("Using Hugging Face token for download.")
        auth_token = hf_token
    else:
        print("No Hugging Face token found. Attempting public download.")
        auth_token = None
    try:
        model_load_path = snapshot_download(
            repo_id=HUGGING_FACE_REPO_ID,
            local_dir=DEFAULT_CACHE_PATH,
            local_dir_use_symlinks=False,  # Recommended for Spaces
            token=auth_token,
            # revision="main"  # Specify a branch/commit if needed
        )
        print(f"Weights downloaded/cached to: {model_load_path}")
    except Exception as e:
        print(f"Error downloading weights from Hugging Face: {e}")
        print("Please ensure the repository exists and is accessible, or provide a valid WEIGHTS_PATH.")
        sys.exit(1)  # Exit if weights cannot be loaded

# Load the pipeline using the determined path
print(f"Loading pipeline from: {model_load_path}")
pipe = TripoSGScribblePipeline.from_pretrained(model_load_path)
pipe.to(dtype=torch.float16, device="cuda")
print("Pipeline loaded.")
# --- End Weight Loading Logic ---

# Create a white background image and a transparent layer for drawing
canvas_width, canvas_height = 512, 512
initial_background = Image.new("RGB", (canvas_width, canvas_height), color="white")
initial_layer = Image.new("RGBA", (canvas_width, canvas_height), color=(0, 0, 0, 0))  # Transparent layer

# Prepare the initial value dictionary for the ImageEditor
initial_value = {
    "background": initial_background,
    "layers": [initial_layer],  # The transparent drawing layer
    "composite": None,
}
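# For reference, the scribble preprocessing that generate_3d performs inline below,
# factored into a standalone helper. This is an illustrative sketch only; the app
# itself keeps its own inline copy and does not call this function.
def invert_scribble(composite: np.ndarray) -> Image.Image:
    """Flatten an editor composite to RGB and invert it (black-on-white -> white-on-black)."""
    rgb = Image.fromarray(composite).convert("RGB")
    return Image.fromarray(255 - np.array(rgb))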
MAX_SEED = np.iinfo(np.int32).max

def get_random_seed():
    return random.randint(0, MAX_SEED)

# Apply the GPU decorator only when ZeroGPU is enabled (conditional decorator
# expressions are valid on Python 3.9+, PEP 614)
@spaces.GPU() if ENABLE_ZEROGPU else lambda func: func
def generate_3d(scribble_image_dict, prompt, scribble_confidence, prompt_confidence, seed):
    print("Generating 3D model...")

    # Extract the composite image from the ImageEditor dictionary
    if scribble_image_dict is None or scribble_image_dict.get("composite") is None:
        print("No scribble image provided.")
        return None

    # --- Seed Handling ---
    current_seed = int(seed)
    print(f"Using seed: {current_seed}")
    # --- End Seed Handling ---

    # Get the composite image, which includes the drawing. The composite may be
    # RGBA if a layer was involved, so convert to RGB for processing.
    image = Image.fromarray(scribble_image_dict["composite"]).convert("RGB")

    # Preprocess the image: invert colors (black-on-white -> white-on-black)
    image_np = np.array(image)
    processed_image_np = 255 - image_np
    processed_image = Image.fromarray(processed_image_np)
    print("Image preprocessed.")

    # Seed the generator for reproducible results
    generator = torch.Generator(device="cuda").manual_seed(current_seed)

    # Run the pipeline
    print("Running pipeline...")
    out = pipe(
        processed_image,
        prompt=prompt,
        num_tokens=512,          # Default value from the reference example
        guidance_scale=0,        # Default value from the reference example
        num_inference_steps=16,  # Default value from the reference example
        attention_kwargs={
            "cross_attention_scale": prompt_confidence,      # Text-prompt strength
            "cross_attention_2_scale": scribble_confidence,  # Scribble strength
        },
        generator=generator,
        use_flash_decoder=False,
        dense_octree_depth=8,
        hierarchical_octree_depth=8,
    )
    print("Pipeline finished.")

    # Save the output mesh to a temporary file
    if out.meshes and len(out.meshes) > 0:
        # Create a temporary file with a .glb extension
        with tempfile.NamedTemporaryFile(suffix=".glb", delete=False) as tmpfile:
            output_path = tmpfile.name
        out.meshes[0].export(output_path)
        print(f"Mesh saved to temporary file: {output_path}")
        return output_path
    else:
        print("Pipeline did not generate any meshes.")
        return None

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# TripoSG Scribble!!")
    gr.Markdown("Generate a 3D model from a simple scribble and a text prompt.")
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.ImageEditor(
                label="Scribble Input (Draw Black on White)",
                value=initial_value,
                image_mode="RGB",
                brush=gr.Brush(default_color="#000000", color_mode="fixed", default_size=5),  # Fixed small brush size
                interactive=True,
                eraser=gr.Eraser(default_size=20),  # Fixed eraser size
            )
            prompt_input = gr.Textbox(label="Prompt", placeholder="e.g., a cute cat wearing a hat")
            confidence_input = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.05, label="Scribble Confidence")
            prompt_confidence_input = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.05, label="Prompt Confidence")
            seed_input = gr.Number(label="Seed", value=0, precision=0)
            with gr.Row():
                submit_button = gr.Button("Generate 3D Model", variant="primary", scale=1)
                lucky_button = gr.Button("I'm Feeling Lucky", scale=1)
        with gr.Column(scale=1):
            model_output = gr.Model3D(label="Generated 3D Model", interactive=False)

    # Define the inputs for the main generation function
    gen_inputs = [image_input, prompt_input, confidence_input, prompt_confidence_input, seed_input]
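    # Wiring note: both buttons invoke the same generate_3d with the same input
    # list; the lucky button differs only in that it first writes a random seed
    # into seed_input, which the chained .then() call below then picks up.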
    submit_button.click(
        fn=generate_3d,
        inputs=gen_inputs,
        outputs=model_output,
    )

    # Inputs for the lucky button (same as the main button for the final call)
    lucky_gen_inputs = [image_input, prompt_input, confidence_input, prompt_confidence_input, seed_input]
    lucky_button.click(
        fn=get_random_seed,  # First, generate a random seed
        inputs=[],
        outputs=[seed_input],  # Update the seed input field
    ).then(
        fn=generate_3d,  # Then generate the model
        inputs=lucky_gen_inputs,  # Uses the freshly updated seed from the input field
        outputs=model_output,
    )

# Launch with the queue enabled (required for ZeroGPU)
print("Launching Gradio interface...")
demo.queue().launch(share=False, server_name="0.0.0.0")
print("Gradio interface launched.")
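# Usage sketch (the env vars match the lookups above; the file name "app.py" is
# an assumption about how this script is saved):
#   python app.py                                           # clone code, download weights
#   TRIPOSG_CODE_PATH=/path/to/TripoSG \
#   WEIGHTS_PATH=/path/to/TripoSG-scribble python app.py    # fully local run
#   DISABLE_ZEROGPU=1 python app.py                         # skip the @spaces.GPU decorator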