dkatz2391 commited on
Commit
64a7809
·
verified ·
1 Parent(s): ea9eafd

Upload trellis_fastAPI_integration.py

Browse files
Files changed (1) hide show
  1. trellis_fastAPI_integration.py +143 -0
trellis_fastAPI_integration.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # trellis_fastAPI_integration.py
2
+ # Version: 1.0.0
3
+
4
+ # a.1 Imports and Initial Setup
5
+ import os
6
+ import shutil
7
+ import threading
8
+ import uvicorn
9
+ import logging
10
+ import numpy as np
11
+ import torch
12
+
13
+ from fastapi import FastAPI, HTTPException
14
+ from pydantic import BaseModel
15
+ from easydict import EasyDict as edict # Assuming EasyDict might be needed if state used
16
+
17
+ # Assuming these are available or installed correctly in the environment
18
+ from trellis.utils import postprocessing_utils
19
+ # We get the pipeline object passed in, so no direct import needed here
20
+
21
+ # Set up logging
22
+ logging.basicConfig(level=logging.INFO)
23
+ logger = logging.getLogger(__name__)
24
+
25
+ # FastAPI app
26
+ api_app = FastAPI()
27
+
28
+ # --- Temporary Directory --- (Consistent with appTrellis.py)
29
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
30
+ os.makedirs(TMP_DIR, exist_ok=True)
31
+
32
+ # b.1 Request/Response Models
33
+ class GenerateRequest(BaseModel):
34
+ prompt: str
35
+ seed: int = 0 # Default seed
36
+ mesh_simplify: float = 0.95 # Default simplify factor
37
+ texture_size: int = 1024 # Default texture size
38
+ # Add other generation parameters if needed (e.g., guidance, steps)
39
+
40
+ # c.1 API Endpoint for Synchronous Generation
41
+ @api_app.post("/generate-sync")
42
+ async def generate_sync_api(request_data: GenerateRequest):
43
+ """API endpoint to synchronously generate a model and return the GLB path."""
44
+ # Access the pipeline object stored in app state
45
+ pipeline = api_app.state.pipeline
46
+ if pipeline is None:
47
+ logger.error("API Error: Pipeline not initialized or passed correctly")
48
+ raise HTTPException(status_code=503, detail="Pipeline not ready")
49
+
50
+ prompt = request_data.prompt
51
+ seed = request_data.seed
52
+ mesh_simplify = request_data.mesh_simplify
53
+ texture_size = request_data.texture_size
54
+ # Extract other params if added to GenerateRequest
55
+ ss_sampling_steps = 25 # Example default
56
+ ss_guidance_strength = 7.5 # Example default
57
+ slat_sampling_steps = 25 # Example default
58
+ slat_guidance_strength = 7.5 # Example default
59
+
60
+ logger.info(f"API /generate-sync received prompt: {prompt}")
61
+ user_dir = None # Define user_dir outside try for cleanup
62
+
63
+ try:
64
+ # --- Determine a unique temporary directory for this API call ---
65
+ # Using a simpler random hash name for the API call directory
66
+ api_call_hash = f"api_sync_{np.random.randint(100000)}"
67
+ user_dir = os.path.join(TMP_DIR, api_call_hash)
68
+ os.makedirs(user_dir, exist_ok=True)
69
+ logger.info(f"API using temp dir: {user_dir}")
70
+
71
+ # --- Stage 1: Run the text-to-3D pipeline ---
72
+ logger.info("API running pipeline...")
73
+ # Ensure pipeline is run with appropriate parameters
74
+ outputs = pipeline.run(
75
+ prompt,
76
+ seed=seed,
77
+ formats=["gaussian", "mesh"],
78
+ sparse_structure_sampler_params={
79
+ "steps": ss_sampling_steps,
80
+ "cfg_strength": ss_guidance_strength,
81
+ },
82
+ slat_sampler_params={
83
+ "steps": slat_sampling_steps,
84
+ "cfg_strength": slat_guidance_strength,
85
+ },
86
+ )
87
+ gs = outputs['gaussian'][0] # Get the Gaussian representation
88
+ mesh = outputs['mesh'][0] # Get the Mesh representation
89
+ logger.info("API pipeline finished.")
90
+ torch.cuda.empty_cache()
91
+
92
+ # --- Stage 2: Extract GLB ---
93
+ logger.info("API extracting GLB...")
94
+ # Use the postprocessing utility
95
+ glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
96
+ glb_path = os.path.join(user_dir, 'generated_sync.glb')
97
+ glb.export(glb_path)
98
+ logger.info(f"API GLB exported to: {glb_path}")
99
+ torch.cuda.empty_cache()
100
+
101
+ # Return the absolute path within the container
102
+ # This path needs to be accessible via the /file= route from outside
103
+ absolute_glb_path = os.path.abspath(glb_path)
104
+ logger.info(f"API returning absolute path: {absolute_glb_path}")
105
+ return {"status": "success", "glb_path": absolute_glb_path}
106
+
107
+ except Exception as e:
108
+ logger.error(f"API /generate-sync error: {str(e)}", exc_info=True)
109
+ # Clean up temp dir on error if it exists and was created
110
+ if user_dir and os.path.exists(user_dir):
111
+ try:
112
+ shutil.rmtree(user_dir)
113
+ logger.info(f"API cleaned up failed directory: {user_dir}")
114
+ except Exception as cleanup_e:
115
+ logger.error(f"API Error cleaning up dir {user_dir}: {cleanup_e}")
116
+ raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
117
+ # Note: We don't automatically clean up the user_dir on success,
118
+ # as the file needs to be accessible for download by the calling server.
119
+ # A separate cleanup mechanism might be needed eventually.
120
+
121
+ # d.1 API Server Setup Functions
122
+ def run_api():
123
+ """Run the FastAPI server."""
124
+ # Ensure pipeline is available in app state before starting
125
+ if not hasattr(api_app.state, 'pipeline') or api_app.state.pipeline is None:
126
+ logger.error("Cannot start API server: Pipeline object not found in app state.")
127
+ return
128
+ # Run on port 8000 - ensure this doesn't conflict if Gradio also tries this port
129
+ uvicorn.run(api_app, host="0.0.0.0", port=8000)
130
+
131
+ def start_api_thread(pipeline_object):
132
+ """Start the API server in a background thread
133
+
134
+ Args:
135
+ pipeline_object: The initialized TrellisTextTo3DPipeline object
136
+ """
137
+ # Store the passed pipeline object in the app's state
138
+ api_app.state.pipeline = pipeline_object
139
+
140
+ api_thread = threading.Thread(target=run_api, daemon=True)
141
+ api_thread.start()
142
+ logger.info("Started Trellis FastAPI integration server thread on port 8000")
143
+ return api_thread