dkatz2391 commited on
Commit
df3aef9
·
verified ·
1 Parent(s): ca8ba1c

Delete trellis_fastAPI_integration.py

Browse files
Files changed (1) hide show
  1. trellis_fastAPI_integration.py +0 -168
trellis_fastAPI_integration.py DELETED
@@ -1,168 +0,0 @@
1
- # trellis_fastAPI_integration.py
2
- # Version: 1.0.0
3
-
4
- # a.1 Imports and Initial Setup
5
- import os
6
- import shutil
7
- import threading
8
- import uvicorn
9
- import logging
10
- import numpy as np
11
- import torch
12
-
13
- from fastapi import FastAPI, HTTPException
14
- from pydantic import BaseModel
15
- from easydict import EasyDict as edict # Assuming EasyDict might be needed if state used
16
-
17
- # Assuming these are available or installed correctly in the environment
18
- from trellis.utils import postprocessing_utils
19
- # We get the pipeline object passed in, so no direct import needed here
20
-
21
- # Set up logging
22
- logging.basicConfig(level=logging.INFO)
23
- logger = logging.getLogger(__name__)
24
-
25
- # FastAPI app
26
- api_app = FastAPI()
27
-
28
- # --- Temporary Directory --- (Consistent with appTrellis.py)
29
- TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
30
- os.makedirs(TMP_DIR, exist_ok=True)
31
-
32
- # b.1 Request/Response Models
33
- class GenerateRequest(BaseModel):
34
- prompt: str
35
- seed: int = 0 # Default seed
36
- mesh_simplify: float = 0.95 # Default simplify factor
37
- texture_size: int = 1024 # Default texture size
38
- # Add other generation parameters if needed (e.g., guidance, steps)
39
-
40
- # c.1 API Endpoint for Synchronous Generation
41
- @api_app.post("/generate-sync")
42
- async def generate_sync_api(request_data: GenerateRequest):
43
- """API endpoint to synchronously generate a model and return the GLB path."""
44
- logger.info("API /generate-sync endpoint hit.") # Log when endpoint is called
45
- # Access the pipeline object stored in app state
46
- pipeline = api_app.state.pipeline
47
- if pipeline is None:
48
- logger.error("API Error: Pipeline not initialized or passed correctly")
49
- raise HTTPException(status_code=503, detail="Pipeline not ready")
50
-
51
- prompt = request_data.prompt
52
- seed = request_data.seed
53
- mesh_simplify = request_data.mesh_simplify
54
- texture_size = request_data.texture_size
55
- # Extract other params if added to GenerateRequest
56
- ss_sampling_steps = 25 # Example default
57
- ss_guidance_strength = 7.5 # Example default
58
- slat_sampling_steps = 25 # Example default
59
- slat_guidance_strength = 7.5 # Example default
60
-
61
- logger.info(f"API /generate-sync received prompt: {prompt}")
62
- user_dir = None # Define user_dir outside try for cleanup
63
-
64
- try:
65
- # --- Determine a unique temporary directory for this API call ---
66
- # Using a simpler random hash name for the API call directory
67
- api_call_hash = f"api_sync_{np.random.randint(100000)}"
68
- user_dir = os.path.join(TMP_DIR, api_call_hash)
69
- os.makedirs(user_dir, exist_ok=True)
70
- logger.info(f"API using temp dir: {user_dir}")
71
-
72
- # --- Stage 1: Run the text-to-3D pipeline ---
73
- logger.info("API running pipeline...")
74
- # Ensure pipeline is run with appropriate parameters
75
- outputs = pipeline.run(
76
- prompt,
77
- seed=seed,
78
- formats=["gaussian", "mesh"],
79
- sparse_structure_sampler_params={
80
- "steps": ss_sampling_steps,
81
- "cfg_strength": ss_guidance_strength,
82
- },
83
- slat_sampler_params={
84
- "steps": slat_sampling_steps,
85
- "cfg_strength": slat_guidance_strength,
86
- },
87
- )
88
- gs = outputs['gaussian'][0] # Get the Gaussian representation
89
- mesh = outputs['mesh'][0] # Get the Mesh representation
90
- logger.info("API pipeline finished.")
91
- torch.cuda.empty_cache()
92
-
93
- # --- Stage 2: Extract GLB ---
94
- logger.info("API extracting GLB...")
95
- # Use the postprocessing utility
96
- glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
97
- glb_path = os.path.join(user_dir, 'generated_sync.glb')
98
- glb.export(glb_path)
99
- logger.info(f"API GLB exported to: {glb_path}")
100
- torch.cuda.empty_cache()
101
-
102
- # Return the absolute path within the container
103
- # This path needs to be accessible via the /file= route from outside
104
- absolute_glb_path = os.path.abspath(glb_path)
105
- logger.info(f"API returning absolute path: {absolute_glb_path}")
106
- return {"status": "success", "glb_path": absolute_glb_path}
107
-
108
- except Exception as e:
109
- logger.error(f"API /generate-sync error: {str(e)}", exc_info=True)
110
- # Clean up temp dir on error if it exists and was created
111
- if user_dir and os.path.exists(user_dir):
112
- try:
113
- shutil.rmtree(user_dir)
114
- logger.info(f"API cleaned up failed directory: {user_dir}")
115
- except Exception as cleanup_e:
116
- logger.error(f"API Error cleaning up dir {user_dir}: {cleanup_e}")
117
- raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
118
- # Note: We don't automatically clean up the user_dir on success,
119
- # as the file needs to be accessible for download by the calling server.
120
- # A separate cleanup mechanism might be needed eventually.
121
-
122
- # d.1 API Server Setup Functions
123
- def run_api():
124
- """Run the FastAPI server."""
125
- logger.info("FastAPI Integration: run_api function called.")
126
- # Ensure pipeline is available in app state before starting
127
- if not hasattr(api_app.state, 'pipeline') or api_app.state.pipeline is None:
128
- logger.error("FastAPI Integration: Cannot start API server - Pipeline object not found in app state.")
129
- return
130
- logger.info("FastAPI Integration: Pipeline object found in state. Attempting to start Uvicorn...")
131
- # Run on port 8000 - ensure this doesn't conflict if Gradio also tries this port
132
- try:
133
- uvicorn.run(api_app, host="0.0.0.0", port=8000)
134
- logger.info("FastAPI Integration: Uvicorn server stopped.") # Logged when server exits cleanly
135
- except Exception as e:
136
- logger.error(f"FastAPI Integration: Uvicorn server failed to run or crashed: {e}", exc_info=True)
137
-
138
- def start_api_thread(pipeline_object):
139
- """Start the API server in a background thread
140
-
141
- Args:
142
- pipeline_object: The initialized TrellisTextTo3DPipeline object
143
- """
144
- logger.info("FastAPI Integration: start_api_thread called.")
145
- # Store the passed pipeline object in the app's state
146
- if pipeline_object is None:
147
- logger.error("FastAPI Integration: start_api_thread received a None pipeline_object. Aborting thread start.")
148
- return None
149
- try:
150
- api_app.state.pipeline = pipeline_object
151
- logger.info("FastAPI Integration: Pipeline object successfully stored in app state.")
152
- except Exception as e:
153
- logger.error(f"FastAPI Integration: Failed to store pipeline object in app state: {e}", exc_info=True)
154
- return None
155
-
156
- logger.info("FastAPI Integration: Creating API thread...")
157
- api_thread = threading.Thread(target=run_api, daemon=True)
158
-
159
- logger.info("FastAPI Integration: Attempting to start API thread...")
160
- try:
161
- api_thread.start()
162
- logger.info("FastAPI Integration: API thread started (start() method called).")
163
- except Exception as e:
164
- logger.error(f"FastAPI Integration: Failed to start API thread: {e}", exc_info=True)
165
- return None # Indicate thread failed to start
166
-
167
- logger.info("Started Trellis FastAPI integration server thread function finished.") # Confirms this function completed
168
- return api_thread