mac9087 committed
Commit 04ac060 · verified · Parent: 3b864a5

Update app.py

Files changed (1): app.py (+77 -19)
app.py CHANGED
@@ -16,6 +16,7 @@ from flask_cors import CORS
 import numpy as np
 import trimesh
 from transformers import pipeline
+from scipy.ndimage import gaussian_filter, uniform_filter
 
 app = Flask(__name__)
 CORS(app) # Enable CORS for all routes
@@ -48,7 +49,7 @@ model_loaded = False
 model_loading = False
 
 # Configuration for processing
-TIMEOUT_SECONDS = 180 # 3 minutes max for processing
+TIMEOUT_SECONDS = 240 # 4 minutes max for processing (increased for larger model)
 MAX_DIMENSION = 512 # Max image dimension to process
 
 # TimeoutError for handling timeouts
@@ -121,8 +122,8 @@ def load_model():
     model_loading = True
     print("Starting model loading...")
 
-    # Using DPT-Hybrid which is smaller than other depth estimation models
-    model_name = "Intel/dpt-hybrid-midas"
+    # Using DPT-Large which provides better detail than DPT-Hybrid
+    model_name = "Intel/dpt-large"
 
     # Download model with retry mechanism
     max_retries = 3
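Note: this hunk only swaps the checkpoint name; the diff does not show how the estimator is constructed. Given the `from transformers import pipeline` import above, loading the new model presumably looks like the sketch below (the variable names and input file are illustrative, not taken from app.py):

    from PIL import Image
    from transformers import pipeline

    # Depth-estimation pipeline backed by the larger DPT checkpoint used in this commit
    depth_estimator = pipeline("depth-estimation", model="Intel/dpt-large")

    result = depth_estimator(Image.open("input.jpg"))  # placeholder input image
    depth_map = result["depth"]                        # PIL image of the predicted depth
    raw_depth = result["predicted_depth"]              # underlying torch tensor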
@@ -170,20 +171,23 @@ def load_model():
     finally:
         model_loading = False
 
-# Convert depth map to 3D mesh
+# Convert depth map to 3D mesh with enhanced detail
 def depth_to_mesh(depth_map, image, resolution=100):
-    """Convert depth map to 3D mesh"""
+    """Convert depth map to 3D mesh with improved detail preservation"""
     # Convert depth_map to numpy array if it's a PIL Image
     if isinstance(depth_map, Image.Image):
         depth_map = np.array(depth_map)
 
     # Make sure the depth map is 2D
     if len(depth_map.shape) > 2:
-        # If it's a 3D array (like RGB), convert to grayscale
         depth_map = np.mean(depth_map, axis=2) if depth_map.shape[2] > 1 else depth_map[:,:,0]
 
+    # Apply bilateral filter to smooth the depth map while preserving edges
+    # First, apply a slight gaussian filter to remove noise
+    depth_map_smooth = gaussian_filter(depth_map, sigma=1.0)
+
     # Get dimensions
-    h, w = depth_map.shape
+    h, w = depth_map_smooth.shape
 
     # Create a grid of points
     x = np.linspace(0, w-1, resolution)
@@ -193,11 +197,24 @@ def depth_to_mesh(depth_map, image, resolution=100):
     # Sample depth at grid points
     x_indices = x_grid.astype(int)
     y_indices = y_grid.astype(int)
-    z_values = depth_map[y_indices, x_indices]
-
-    # Normalize depth values to suitable range
-    z_min, z_max = z_values.min(), z_values.max()
-    z_values = (z_values - z_min) / (z_max - z_min) * 2.0 # Map to 0-2 range
+    z_values = depth_map_smooth[y_indices, x_indices]
+
+    # Normalize depth values with better scaling
+    z_min, z_max = np.percentile(z_values, [2, 98]) # Removes outliers
+    z_values = (z_values - z_min) / (z_max - z_min) if z_max > z_min else z_values
+    z_values = z_values * 2.0 # Scale depth
+
+    # Apply a local contrast enhancement to bring out details
+    # Simple adaptive normalization
+    window_size = resolution // 10
+    if window_size > 0:
+        local_mean = uniform_filter(z_values, size=window_size)
+        local_var = uniform_filter(z_values**2, size=window_size) - local_mean**2
+        local_std = np.sqrt(np.maximum(local_var, 0))
+
+        # Enhance local contrast
+        enhanced_z = (z_values - local_mean) / (local_std + 0.01) * 0.5 + z_values
+        z_values = np.clip(enhanced_z, 0, None) # Keep values positive
 
     # Normalize x and y coordinates
     x_grid = (x_grid / w - 0.5) * 2.0 # Map to -1 to 1
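Note: the replacement normalization uses the 2nd/98th percentiles instead of the raw min/max, then layers an adaptive local-contrast boost on top (the new comment mentions a bilateral filter, but the code applies a Gaussian filter). A self-contained sketch of that step on a synthetic depth grid; the 100x100 random array is made up for illustration and the names mirror the diff:

    import numpy as np
    from scipy.ndimage import gaussian_filter, uniform_filter

    z_values = gaussian_filter(np.random.rand(100, 100), sigma=1.0)  # stand-in for sampled depth

    # Percentile-based range is robust to a few extreme depth outliers
    z_min, z_max = np.percentile(z_values, [2, 98])
    z_values = (z_values - z_min) / (z_max - z_min) if z_max > z_min else z_values
    z_values = z_values * 2.0

    # Adaptive local-contrast enhancement, as added in this hunk
    window_size = 100 // 10
    local_mean = uniform_filter(z_values, size=window_size)
    local_std = np.sqrt(np.maximum(uniform_filter(z_values**2, size=window_size) - local_mean**2, 0))
    z_values = np.clip((z_values - local_mean) / (local_std + 0.01) * 0.5 + z_values, 0, None)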
@@ -215,6 +232,7 @@ def depth_to_mesh(depth_map, image, resolution=100):
             p3 = (i + 1) * resolution + j
             p4 = (i + 1) * resolution + (j + 1)
 
+            # Create two triangles for each grid cell
             faces.append([p1, p2, p4])
             faces.append([p1, p4, p3])
 
@@ -223,10 +241,37 @@ def depth_to_mesh(depth_map, image, resolution=100):
     # Create mesh
     mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
 
-    # Optional: Apply texture from original image
+    # Apply texturing if image is provided
     if image:
-        # This is simplified - proper UV mapping would be needed for accurate texturing
-        pass
+        # Convert to numpy array if needed
+        if isinstance(image, Image.Image):
+            img_array = np.array(image)
+        else:
+            img_array = image
+
+        # Create simple texture by sampling the original image
+        if resolution <= img_array.shape[0] and resolution <= img_array.shape[1]:
+            # Create vertex colors by sampling the image
+            vertex_colors = np.zeros((vertices.shape[0], 4), dtype=np.uint8)
+
+            for i in range(resolution):
+                for j in range(resolution):
+                    img_x = min(int(j * img_array.shape[1] / resolution), img_array.shape[1]-1)
+                    img_y = min(int(i * img_array.shape[0] / resolution), img_array.shape[0]-1)
+
+                    vertex_idx = i * resolution + j
+                    if len(img_array.shape) == 3 and img_array.shape[2] == 3: # RGB
+                        vertex_colors[vertex_idx, :3] = img_array[img_y, img_x, :]
+                        vertex_colors[vertex_idx, 3] = 255 # Alpha
+                    elif len(img_array.shape) == 3 and img_array.shape[2] == 4: # RGBA
+                        vertex_colors[vertex_idx, :] = img_array[img_y, img_x, :]
+                    else:
+                        # Handle grayscale or other formats
+                        gray_value = img_array[img_y, img_x]
+                        vertex_colors[vertex_idx, :3] = [gray_value, gray_value, gray_value]
+                        vertex_colors[vertex_idx, 3] = 255
+
+            mesh.visual.vertex_colors = vertex_colors
 
     return mesh
 
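Note: the per-vertex double loop above runs resolution² Python iterations (40,000 at the 200 cap). A hypothetical vectorized equivalent for the RGB case using plain NumPy indexing; the function name and the RGB-only assumption are mine, not part of the commit:

    import numpy as np

    def sample_vertex_colors(img_array, resolution):
        # Same row/column mapping as the loop, computed for all vertices at once
        ys = np.minimum(np.arange(resolution) * img_array.shape[0] // resolution,
                        img_array.shape[0] - 1)
        xs = np.minimum(np.arange(resolution) * img_array.shape[1] // resolution,
                        img_array.shape[1] - 1)
        rgb = img_array[np.ix_(ys, xs)].reshape(-1, 3)           # one sample per vertex
        alpha = np.full((rgb.shape[0], 1), 255, dtype=np.uint8)  # fully opaque
        return np.concatenate([rgb, alpha], axis=1).astype(np.uint8)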
@@ -234,7 +279,7 @@ def depth_to_mesh(depth_map, image, resolution=100):
 def health_check():
     return jsonify({
         "status": "healthy",
-        "model": "Depth-Based 3D Model Generator",
+        "model": "Depth-Based 3D Model Generator (DPT-Large)",
         "device": "cuda" if torch.cuda.is_available() else "cpu"
     }), 200
 
@@ -294,6 +339,7 @@ def convert_image_to_3d():
     try:
         mesh_resolution = min(int(request.form.get('mesh_resolution', 100)), 200) # Limit max resolution
         output_format = request.form.get('output_format', 'obj').lower()
+        detail_level = request.form.get('detail_level', 'medium').lower() # New parameter for detail level
     except ValueError:
         return jsonify({"error": "Invalid parameter values"}), 400
 
@@ -301,6 +347,12 @@ def convert_image_to_3d():
     if output_format not in ['obj', 'glb']:
         return jsonify({"error": "Unsupported output format. Use 'obj' or 'glb'"}), 400
 
+    # Adjust mesh resolution based on detail level
+    if detail_level == 'high':
+        mesh_resolution = min(mesh_resolution * 1.5, 200)
+    elif detail_level == 'low':
+        mesh_resolution = max(int(mesh_resolution * 0.7), 50)
+
     # Create a job ID
     job_id = str(uuid.uuid4())
     output_dir = os.path.join(RESULTS_FOLDER, job_id)
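Note: with the default mesh_resolution of 100, 'high' scales up by 1.5 (producing a float, which a later hunk casts back to int before calling depth_to_mesh) and 'low' scales down by 0.7. A small sketch of the resulting values under the logic added here:

    def effective_resolution(mesh_resolution, detail_level):
        if detail_level == 'high':
            return int(min(mesh_resolution * 1.5, 200))  # 100 -> 150, capped at 200
        if detail_level == 'low':
            return max(int(mesh_resolution * 0.7), 50)   # 100 -> 70, floored at 50
        return mesh_resolution                           # 'medium' leaves it unchanged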
@@ -372,7 +424,8 @@ def convert_image_to_3d():
             processing_jobs[job_id]['progress'] = 60
 
             # Create mesh from depth map
-            mesh = depth_to_mesh(depth_map, image, resolution=mesh_resolution)
+            mesh_resolution_int = int(mesh_resolution)
+            mesh = depth_to_mesh(depth_map, image, resolution=mesh_resolution_int)
             processing_jobs[job_id]['progress'] = 80
 
         except Exception as e:
@@ -523,8 +576,13 @@ def cleanup_old_jobs():
 @app.route('/', methods=['GET'])
 def index():
     return jsonify({
-        "message": "Image to 3D API is running",
-        "endpoints": ["/convert", "/progress/<job_id>", "/download/<job_id>", "/preview/<job_id>"]
+        "message": "Image to 3D API is running (DPT-Large Model)",
+        "endpoints": ["/convert", "/progress/<job_id>", "/download/<job_id>", "/preview/<job_id>"],
+        "parameters": {
+            "mesh_resolution": "Integer (50-200), controls mesh density",
+            "output_format": "obj or glb",
+            "detail_level": "low, medium, or high"
+        }
     }), 200
 
 if __name__ == '__main__':
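Note: a hypothetical client call exercising the new detail_level parameter. The diff does not show the upload field name or the serving port, so "image" and localhost:7860 below are assumptions, and the response is only expected to include a job id for polling /progress/<job_id>:

    import requests

    with open("photo.jpg", "rb") as f:              # placeholder input image
        resp = requests.post(
            "http://localhost:7860/convert",        # assumed host/port
            files={"image": f},                     # assumed field name
            data={"mesh_resolution": "100",
                  "output_format": "glb",
                  "detail_level": "high"},
        )
    print(resp.status_code, resp.json())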
 