mac9087 committed
Commit 6a5c502 · verified · 1 Parent(s): 03bce9d

Update app.py

Files changed (1): app.py (+94, -36)
app.py CHANGED
@@ -14,7 +14,9 @@ from huggingface_hub import snapshot_download
 from flask_cors import CORS
 import numpy as np
 import trimesh
-from diffusers import DiffusionPipeline
+from transformers import pipeline
+from scipy.ndimage import gaussian_filter
+import open3d as o3d
 
 # Force CPU usage
 os.environ["CUDA_VISIBLE_DEVICES"] = ""
@@ -49,13 +51,13 @@ app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
 processing_jobs = {}
 
 # Global model variables
-triposg_pipeline = None
+depth_pipeline = None
 model_loaded = False
 model_loading = False
 
 # Configuration for processing
-TIMEOUT_SECONDS = 240 # 4 minutes max for TripoSG on CPU
-MAX_DIMENSION = 512 # TripoSG expects 512x512
+TIMEOUT_SECONDS = 240 # 4 minutes max for Depth-Anything on CPU
+MAX_DIMENSION = 512 # Depth-Anything expects 512x512
 
 # TimeoutError for handling timeouts
 class TimeoutError(Exception):
@@ -94,30 +96,30 @@ def process_with_timeout(function, args, timeout):
 def allowed_file(filename):
     return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
 
-# Image preprocessing for TripoSG (512x512, no background removal)
+# Image preprocessing for Depth-Anything (512x512, no background removal)
 def preprocess_image(image_path):
     with Image.open(image_path) as img:
         img = img.convert("RGB")
-        # TripoSG requires 512x512
+        # Depth-Anything requires 512x512
         img = img.resize((512, 512), Image.LANCZOS)
     return img
 
 def load_model():
-    global triposg_pipeline, model_loaded, model_loading
+    global depth_pipeline, model_loaded, model_loading
 
     if model_loaded:
-        return triposg_pipeline
+        return depth_pipeline
 
     if model_loading:
         while model_loading and not model_loaded:
             time.sleep(0.5)
-        return triposg_pipeline
+        return depth_pipeline
 
     try:
         model_loading = True
         print("Starting model loading...")
 
-        model_name = "VAST-AI/TripoSG"
+        model_name = "LiheYoung/depth-anything-small-hf"
 
         # Download model with retry mechanism
         max_retries = 3
@@ -138,18 +140,18 @@ def load_model():
             else:
                 raise
 
-        # Load TripoSG pipeline
-        triposg_pipeline = DiffusionPipeline.from_pretrained(
-            model_name,
+        # Load Depth-Anything pipeline
+        depth_pipeline = pipeline(
+            "depth-estimation",
+            model=model_name,
             cache_dir=CACHE_DIR,
-            torch_dtype=torch.float32, # Use float32 for CPU
-            custom_pipeline="VAST-AI/TripoSG",
+            device=-1, # Force CPU
+            torch_dtype=torch.float32,
         )
-        triposg_pipeline.to("cpu") # Explicitly move to CPU
 
         model_loaded = True
         print("Model loaded successfully on CPU")
-        return triposg_pipeline
+        return depth_pipeline
 
     except Exception as e:
         print(f"Error loading model: {str(e)}")
@@ -158,11 +160,76 @@ def load_model():
     finally:
         model_loading = False
 
+def depth_to_point_cloud(depth_map, image, detail_level):
+    # Parameters based on detail level
+    downsample_factors = {'low': 4, 'medium': 2, 'high': 1}
+    downsample = downsample_factors[detail_level]
+
+    # Convert image and depth to numpy
+    img_array = np.array(image)
+    depth_array = np.array(depth_map)
+
+    # Downsample for performance
+    if downsample > 1:
+        depth_array = depth_array[::downsample, ::downsample]
+        img_array = img_array[::downsample, ::downsample]
+
+    # Normalize depth
+    depth_array = gaussian_filter(depth_array, sigma=1)
+    depth_array = (depth_array - depth_array.min()) / (depth_array.max() - depth_array.min() + 1e-8)
+
+    # Create point cloud
+    h, w = depth_array.shape
+    x, y = np.meshgrid(np.arange(w), np.arange(h))
+
+    # Simple camera intrinsics (assumed focal length)
+    fx = fy = w * 0.5
+    cx, cy = w / 2, h / 2
+
+    # Convert to 3D coordinates
+    z = depth_array
+    x = (x - cx) * z / fx
+    y = (y - cy) * z / fy
+
+    points = np.stack([x, y, z], axis=-1).reshape(-1, 3)
+    colors = img_array.reshape(-1, 3) / 255.0
+
+    # Filter out invalid points (e.g., background)
+    mask = (z.reshape(-1) > 0.1) & (z.reshape(-1) < 0.9)
+    points = points[mask]
+    colors = colors[mask]
+
+    # Create Open3D point cloud
+    pcd = o3d.geometry.PointCloud()
+    pcd.points = o3d.utility.Vector3dVector(points)
+    pcd.colors = o3d.utility.Vector3dVector(colors)
+
+    # Estimate normals
+    pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.1, max_nn=30))
+
+    # Poisson surface reconstruction
+    mesh, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
+        pcd, depth=8 if detail_level == 'high' else 6
+    )
+
+    # Convert to trimesh
+    vertices = np.asarray(mesh.vertices)
+    faces = np.asarray(mesh.triangles)
+    vertex_colors = np.asarray(mesh.vertex_colors)
+
+    trimesh_mesh = trimesh.Trimesh(
+        vertices=vertices,
+        faces=faces,
+        vertex_colors=vertex_colors
+    )
+
+    return trimesh_mesh
+
 @app.route('/health', methods=['GET'])
 def health_check():
     return jsonify({
         "status": "healthy",
-        "model": "TripoSG",
+        "model": "Depth-Anything",
         "device": "cpu"
     }), 200
 
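
Note: the new depth_to_point_cloud() helper is the core of this change: depth map -> colored point cloud -> Poisson reconstruction -> trimesh. A usage sketch under the assumption that the RGB image and the depth map are both 512x512; the file names are placeholders:

from PIL import Image

rgb = Image.open("sample.jpg").convert("RGB").resize((512, 512))
depth = Image.open("depth_preview.png")  # e.g. the "depth" image from the pipeline sketch above

mesh = depth_to_point_cloud(depth, rgb, detail_level="medium")
mesh.export("output.glb")  # trimesh infers the GLB format from the file extension
print(f"vertices: {len(mesh.vertices)}, faces: {len(mesh.faces)}")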
 
@@ -259,23 +326,14 @@ def convert_image_to_3d():
 
     try:
         def generate_3d():
-            # Adjust settings based on detail level
-            num_steps = {'low': 20, 'medium': 50, 'high': 75}
-            faces = {'low': 3000, 'medium': 5000, 'high': 8000}
-
-            # Convert image to tensor
-            img_array = np.array(image) / 255.0
-            img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).float()
-
-            # Generate mesh with TripoSG
-            with torch.no_grad():
-                mesh = pipeline(
-                    img_tensor.unsqueeze(0),
-                    num_inference_steps=num_steps[detail_level],
-                    num_faces=faces[detail_level],
-                    guidance_scale=7.5,
-                    seed=12345
-                ).meshes[0]
+            # Generate depth map
+            with torch.no_grad():
+                depth_output = pipeline(image)
+
+            depth_map = depth_output["depth"]
+
+            # Convert depth to mesh
+            mesh = depth_to_point_cloud(depth_map, image, detail_level)
             return mesh
 
         mesh, error = process_with_timeout(generate_3d, [], TIMEOUT_SECONDS)
@@ -414,7 +472,7 @@ def model_info(job_id):
 @app.route('/', methods=['GET'])
 def index():
     return jsonify({
-        "message": "Image to 3D API (TripoSG)",
+        "message": "Image to 3D API (Depth-Anything)",
         "endpoints": [
             "/convert",
             "/progress/<job_id>",
@@ -424,9 +482,9 @@ def index():
         ],
         "parameters": {
             "output_format": "glb",
-            "detail_level": "low, medium, or high - controls inference steps and mesh faces"
+            "detail_level": "low, medium, or high - controls point cloud density"
         },
-        "description": "This API creates full 3D models from 2D images using TripoSG. Images should have transparent backgrounds."
+        "description": "This API creates 3D models from 2D images using Depth-Anything depth estimation. Images should have transparent backgrounds."
     }), 200
 
 if __name__ == '__main__':
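
Note: for completeness, a client-side sketch against the endpoints listed above. The multipart field names ("image", "detail_level", "output_format"), the response keys, and the host/port are assumptions, since the request-parsing code is outside this diff:

import requests

BASE = "http://localhost:7860"  # assumed host/port

with open("sample.jpg", "rb") as f:
    resp = requests.post(
        f"{BASE}/convert",
        files={"image": f},  # assumed field name
        data={"detail_level": "medium", "output_format": "glb"},
    )
job_id = resp.json().get("job_id")  # assumed response key
print(requests.get(f"{BASE}/progress/{job_id}").json())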
 