mac9087 committed
Commit d706f08 · verified · Parent: da0c0da

Update app.py

Files changed (1):
  1. app.py +92 -410

app.py CHANGED
@@ -37,7 +37,7 @@ os.environ['HF_HOME'] = CACHE_DIR
  app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
  app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
 
- # Job track
  processing_jobs = {}
 
  # Model variables
@@ -47,7 +47,7 @@ depth_anything_processor = None
  model_loaded = False
  model_loading = False
 
- TIMEOUT_SECONDS = 300  # Increased timeout for better processing
  MAX_DIMENSION = 518
 
  class TimeoutError(Exception):
@@ -84,77 +84,50 @@ def process_with_timeout(function, args, timeout):
  def allowed_file(filename):
      return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
 
- def remove_background(img):
-     """
-     Remove background from image using OpenCV
-     """
-     # Store original mode
-     original_mode = img.mode
-
-     # Convert to RGBA if not already
-     if img.mode != 'RGBA':
-         img = img.convert('RGBA')
-
-     # Convert to numpy array
-     img_array = np.array(img)
-
-     # Create a mask with alpha channel
-     if img_array.shape[2] == 4:
-         # If image already has alpha channel, use it
-         alpha = img_array[:, :, 3]
-         if np.all(alpha == 255):  # If alpha is all 255, it's not transparent
-             alpha = None
-     else:
-         alpha = None
-
-     # If no alpha channel or all opaque, we need to create a mask
-     if alpha is None:
-         # Convert to RGB for processing
-         img_rgb = cv2.cvtColor(img_array[:, :, :3], cv2.COLOR_RGB2BGR)
-
-         # Create a blank mask
-         mask = np.zeros(img_rgb.shape[:2], np.uint8)
-
-         # Approximate background with GrabCut algorithm
-         bgd_model = np.zeros((1, 65), np.float64)
-         fgd_model = np.zeros((1, 65), np.float64)
-
-         # Define rectangle for initial segmentation (use most of the image)
-         h, w = img_rgb.shape[:2]
-         margin = min(h, w) // 10
-         rect = (margin, margin, w - 2*margin, h - 2*margin)
-
-         try:
-             # Apply GrabCut
-             cv2.grabCut(img_rgb, mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_RECT)
-
-             # Create binary mask
-             mask2 = np.where((mask == 2) | (mask == 0), 0, 1).astype('uint8')
-
-             # Refine the mask with morphological operations
-             kernel = np.ones((5, 5), np.uint8)
-             mask2 = cv2.morphologyEx(mask2, cv2.MORPH_CLOSE, kernel)
-             mask2 = cv2.morphologyEx(mask2, cv2.MORPH_OPEN, kernel)
-
-             # Create alpha channel
-             alpha = mask2 * 255
-         except Exception as e:
-             print(f"GrabCut failed: {str(e)}. Using simple thresholding instead.")
-             # Fallback to simpler method if GrabCut fails
-             gray = cv2.cvtColor(img_rgb, cv2.COLOR_BGR2GRAY)
-             _, alpha = cv2.threshold(gray, 240, 255, cv2.THRESH_BINARY_INV)
-
-     # Create RGBA image
-     result = np.zeros((img_array.shape[0], img_array.shape[1], 4), dtype=np.uint8)
-     result[:, :, :3] = img_array[:, :, :3]
-     result[:, :, 3] = alpha
-
-     # Convert back to PIL Image
-     return Image.fromarray(result)
-
- def preprocess_image(image_path, remove_bg=True):
      with Image.open(image_path) as img:
          img = img.convert("RGB")
 
          if img.width > MAX_DIMENSION or img.height > MAX_DIMENSION:
@@ -166,42 +139,21 @@ def preprocess_image(image_path, remove_bg=True):
              new_width = int(img.width * (MAX_DIMENSION / img.height))
          img = img.resize((new_width, new_height), Image.LANCZOS)
 
-         # Enhanced contrast and brightness adjustment
          img_array = np.array(img)
          if len(img_array.shape) == 3 and img_array.shape[2] == 3:
-             # Enhance contrast with CLAHE
              lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
              l, a, b = cv2.split(lab)
-             clahe = cv2.createCLAHE(clipLimit=2.5, tileGridSize=(8, 8))
              cl = clahe.apply(l)
             enhanced_lab = cv2.merge((cl, a, b))
             img_array = cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2RGB)
-
-             # Additional brightness adjustment
-             alpha = 1.1  # Contrast control (1.0-3.0)
-             beta = 5  # Brightness control (0-100)
-             img_array = cv2.convertScaleAbs(img_array, alpha=alpha, beta=beta)
-
             img = Image.fromarray(img_array)
 
-         # Background removal if requested
-         if remove_bg:
-             img = remove_background(img)
-
-             # Save original image with alpha for texture
-             processed_img = img
-
-             # For model processing, we need to save an RGB copy
-             model_img = Image.new('RGB', img.size, (255, 255, 255))
-             model_img.paste(img, mask=img.split()[3])  # Use alpha as mask
-
-             # Return both processed image (with alpha) and model image (RGB)
-             return processed_img
-
         return img
 
-
-
  def load_models():
      global dpt_estimator, depth_anything_model, depth_anything_processor, model_loaded, model_loading
 
@@ -217,13 +169,11 @@ def load_models():
      model_loading = True
      print("Loading models...")
 
-     # Authenticate with Hugging Face
      hf_token = os.environ.get('HF_TOKEN')
      if hf_token:
          login(token=hf_token)
          print("Authenticated with Hugging Face token")
 
-     # DPT-Large
      dpt_model_name = "Intel/dpt-large"
      max_retries = 3
      retry_delay = 5
@@ -254,7 +204,6 @@ def load_models():
              print("DPT-Large loaded")
              gc.collect()
 
-     # Depth Anything
      da_model_name = "depth-anything/Depth-Anything-V2-Small-hf"
      for attempt in range(max_retries):
          try:
@@ -300,9 +249,6 @@ def load_models():
          model_loading = False
 
  def fuse_depth_maps(dpt_depth, da_depth, detail_level='medium'):
-     """
-     Improved depth map fusion with better edge preservation and depth control
-     """
      if isinstance(dpt_depth, Image.Image):
          dpt_depth = np.array(dpt_depth)
      if isinstance(da_depth, torch.Tensor):
@@ -315,170 +261,78 @@ def fuse_depth_maps(dpt_depth, da_depth, detail_level='medium'):
      if dpt_depth.shape != da_depth.shape:
          da_depth = cv2.resize(da_depth, (dpt_depth.shape[1], dpt_depth.shape[0]), interpolation=cv2.INTER_CUBIC)
 
-     # Better normalization with more robust percentiles
-     p_low_dpt, p_high_dpt = np.percentile(dpt_depth, [2, 98])
-     p_low_da, p_high_da = np.percentile(da_depth, [2, 98])
      dpt_depth = np.clip((dpt_depth - p_low_dpt) / (p_high_dpt - p_low_dpt), 0, 1) if p_high_dpt > p_low_dpt else dpt_depth
      da_depth = np.clip((da_depth - p_low_da) / (p_high_da - p_low_da), 0, 1) if p_high_da > p_low_da else da_depth
 
-     # Apply bilateral filter for edge-preserving smoothing
-     dpt_depth_smooth = cv2.bilateralFilter((dpt_depth * 255).astype(np.uint8), 9, 75, 75) / 255.0
-     da_depth_smooth = cv2.bilateralFilter((da_depth * 255).astype(np.uint8), 9, 75, 75) / 255.0
-
-     # Detect edges more precisely using Canny with auto thresholds
-     edges_dpt = cv2.Canny(
-         (dpt_depth * 255).astype(np.uint8),
-         int(np.mean(dpt_depth * 255) * 0.66),
-         int(np.mean(dpt_depth * 255) * 1.33)
-     )
-     edges_da = cv2.Canny(
-         (da_depth * 255).astype(np.uint8),
-         int(np.mean(da_depth * 255) * 0.66),
-         int(np.mean(da_depth * 255) * 1.33)
-     )
-
-     # Combine edge maps
-     combined_edges = np.maximum(edges_dpt, edges_da)
-     edge_mask = gaussian_filter(combined_edges / 255.0, sigma=1.0)
-
      if detail_level == 'high':
-         # For high detail, we use more of the DA model at edges and more DPT for flat areas
          weight_da = 0.6
         dpt_weight = gaussian_filter(1 - edge_mask, sigma=1.0)
         da_weight = gaussian_filter(edge_mask, sigma=1.0)
-
-         # Adaptive depth scaling - reduce extreme depth values
-         depth_scale = np.ones_like(dpt_depth)
-         depth_scale = np.where(da_depth > 0.8, 0.8, depth_scale)  # Limit maximum depth
-
-         fused_depth = (dpt_weight * dpt_depth_smooth +
-                        da_weight * da_depth_smooth * weight_da * depth_scale +
-                        (1 - weight_da) * dpt_depth_smooth * (1 - da_weight))
-     elif detail_level == 'medium':
-         # For medium detail, balanced approach
-         weight_da = 0.45
-         # More aggressive depth limitation
-         depth_scale = np.ones_like(dpt_depth)
-         depth_scale = np.where(da_depth > 0.75, 0.75 / da_depth, depth_scale)  # Limit maximum depth
-
-         fused_depth = ((1 - weight_da) * dpt_depth_smooth +
-                        weight_da * da_depth_smooth * depth_scale)
      else:
-         # For low detail, simpler approach with more smoothing
-         weight_da = 0.3
-         fused_depth = (1 - weight_da) * gaussian_filter(dpt_depth_smooth, sigma=0.5) + weight_da * gaussian_filter(da_depth_smooth, sigma=0.5)
 
-     # Final cleanup
      fused_depth = np.clip(fused_depth, 0, 1)
-
-     # Apply depth compression to avoid extreme depth values
-     fused_depth = np.power(fused_depth, 0.85)  # Compress depth range
-
     return fused_depth
 
  def enhance_depth_map(depth_map, detail_level='medium'):
-     """
-     Enhanced depth map processing with better depth control
-     """
      enhanced_depth = depth_map.copy().astype(np.float32)
-
-     # More robust percentile clipping
-     p_low, p_high = np.percentile(enhanced_depth, [2, 98])
      enhanced_depth = np.clip(enhanced_depth, p_low, p_high)
      enhanced_depth = (enhanced_depth - p_low) / (p_high - p_low) if p_high > p_low else enhanced_depth
 
      if detail_level == 'high':
-         # Apply bilateral filter for edge-preserving smoothing
-         enhanced_depth_smooth = cv2.bilateralFilter(
-             (enhanced_depth * 255).astype(np.uint8), 7, 50, 50
-         ).astype(np.float32) / 255.0
-
-         # Enhance edges
-         edges = cv2.Canny((enhanced_depth * 255).astype(np.uint8), 50, 150)
-         edge_mask = gaussian_filter(edges / 255.0, sigma=1.0)
-
-         # Detail enhancement through unsharp masking
-         blurred = gaussian_filter(enhanced_depth, sigma=1.5)
-         detail_mask = enhanced_depth - blurred
-         enhanced_depth = enhanced_depth_smooth + 1.2 * detail_mask * (1 - edge_mask)
-
-         # Compression for better depth control
-         enhanced_depth = np.power(enhanced_depth, 0.85)
-     elif detail_level == 'medium':
-         # Medium detail processing
-         enhanced_depth_smooth = cv2.bilateralFilter(
-             (enhanced_depth * 255).astype(np.uint8), 5, 40, 40
-         ).astype(np.float32) / 255.0
-
          blurred = gaussian_filter(enhanced_depth, sigma=1.0)
-         detail_mask = enhanced_depth - blurred
-         enhanced_depth = enhanced_depth_smooth + 0.7 * detail_mask
-         enhanced_depth = np.power(enhanced_depth, 0.9)  # Milder compression
      else:
-         # Low detail - more smoothing
-         enhanced_depth = gaussian_filter(enhanced_depth, sigma=0.8)
-         enhanced_depth = np.power(enhanced_depth, 0.95)  # Light compression
 
-     # Final normalization
      enhanced_depth = np.clip(enhanced_depth, 0, 1)
      return enhanced_depth
 
  def depth_to_mesh(depth_map, image, resolution=100, detail_level='medium'):
-     """
-     Improved mesh creation with better vertex distribution and depth control
-     """
-     # Apply enhanced depth processing
      enhanced_depth = enhance_depth_map(depth_map, detail_level)
-
-     # Get dimensions
      h, w = enhanced_depth.shape
-
-     # Create grid coordinates
      x = np.linspace(0, w-1, resolution)
      y = np.linspace(0, h-1, resolution)
      x_grid, y_grid = np.meshgrid(x, y)
 
-     # Use bicubic interpolation for smoother depth
      interp_func = interpolate.RectBivariateSpline(
          np.arange(h), np.arange(w), enhanced_depth, kx=3, ky=3
      )
      z_values = interp_func(y, x, grid=True)
 
-     # Enhanced edge preservation for high detail
      if detail_level == 'high':
          dx = np.gradient(z_values, axis=1)
          dy = np.gradient(z_values, axis=0)
          gradient_magnitude = np.sqrt(dx**2 + dy**2)
-
-         # Limit excessive depth at edges
-         max_gradient = np.percentile(gradient_magnitude, 95)
-         edge_factor = np.clip(gradient_magnitude / max_gradient, 0, 1)
-         edge_depth_limit = np.clip(0.2 - edge_factor * 0.1, 0, 0.2)
-
-         # Apply depth limiting at high-gradient areas
-         z_values = z_values - edge_factor * edge_depth_limit
 
-     # Better normalization for z values
-     z_min, z_max = np.percentile(z_values, [2, 98])
-     z_values = (z_values - z_min) / (z_max - z_min) if z_max > z_min else z_values
 
-     # Adaptive depth scaling based on detail level
-     if detail_level == 'high':
-         z_scaling = 1.8  # Reduced from 2.5
-     elif detail_level == 'medium':
-         z_scaling = 1.5  # Reduced from 2.0
-     else:
-         z_scaling = 1.2  # Reduced from 1.5
-
-     # Apply depth compression to avoid extreme values
-     z_values = np.power(z_values, 0.85) * z_scaling
-
-     # Create 3D coordinates
-     x_grid = (x_grid / w - 0.5) * 2.0
-     y_grid = (y_grid / h - 0.5) * 2.0
      vertices = np.vstack([x_grid.flatten(), -y_grid.flatten(), -z_values.flatten()]).T
 
-     # Create faces with improved topology
      faces = []
      for i in range(resolution-1):
          for j in range(resolution-1):
@@ -486,18 +340,12 @@ def depth_to_mesh(depth_map, image, resolution=100, detail_level='medium'):
              p2 = i * resolution + (j + 1)
              p3 = (i + 1) * resolution + j
              p4 = (i + 1) * resolution + (j + 1)
-
-             # Check face orientation for better topology
              v1 = vertices[p1]
              v2 = vertices[p2]
              v3 = vertices[p3]
              v4 = vertices[p4]
-
-             # Calculate normals
             norm1 = np.cross(v2-v1, v4-v1)
             norm2 = np.cross(v4-v3, v1-v3)
-
-             # Check if faces should be flipped
             if np.dot(norm1, norm2) >= 0:
                 faces.append([p1, p2, p4])
                 faces.append([p1, p4, p3])
@@ -505,40 +353,12 @@ def depth_to_mesh(depth_map, image, resolution=100, detail_level='medium'):
                 faces.append([p1, p2, p3])
                 faces.append([p2, p4, p3])
 
-     # Check if we have valid faces before creating the mesh
-     if len(faces) == 0:
-         # Create a simple square mesh as fallback
-         faces = [[0, 1, 2], [1, 3, 2]]
-
      faces = np.array(faces)
-
-     # Ensure we have at least one vertex and face
-     if len(vertices) == 0 or len(faces) == 0:
-         # Create a minimal mesh to avoid errors
-         vertices = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]])
-         faces = np.array([[0, 1, 2], [1, 3, 2]])
-
-     # Check for out-of-bounds indexes
-     max_vertex_idx = len(vertices) - 1
-     valid_faces = []
-     for face in faces:
-         if np.all(face <= max_vertex_idx):
-             valid_faces.append(face)
-
-     if len(valid_faces) == 0:
-         # Create a minimal mesh to avoid errors
-         vertices = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]])
-         faces = np.array([[0, 1, 2], [1, 3, 2]])
-     else:
-         faces = np.array(valid_faces)
-
      mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
 
-     # Apply vertex colors from image
      if image:
          img_array = np.array(image)
          vertex_colors = np.zeros((vertices.shape[0], 4), dtype=np.uint8)
-
          for i in range(resolution):
              for j in range(resolution):
                  img_x = j * (img_array.shape[1] - 1) / (resolution - 1)
@@ -547,79 +367,30 @@ def depth_to_mesh(depth_map, image, resolution=100, detail_level='medium'):
                  x1, y1 = min(x0 + 1, img_array.shape[1] - 1), min(y0 + 1, img_array.shape[0] - 1)
                  wx = img_x - x0
                  wy = img_y - y0
-
                  vertex_idx = i * resolution + j
-
-                 # Skip if vertex index is out of range
-                 if vertex_idx >= len(vertices):
-                     continue
-
-                 # Handle RGBA images
                  if len(img_array.shape) == 3 and img_array.shape[2] == 4:
-                     # Direct copy of RGBA values with bilinear interpolation
                      for c in range(4):
-                         vertex_colors[vertex_idx, c] = int(
-                             (1-wx)*(1-wy)*img_array[y0, x0, c] +
-                             wx*(1-wy)*img_array[y0, x1, c] +
-                             (1-wx)*wy*img_array[y1, x0, c] +
-                             wx*wy*img_array[y1, x1, c]
-                         )
-                 # Handle RGB images - add full opacity
-                 elif len(img_array.shape) == 3 and img_array.shape[2] == 3:
-                     for c in range(3):
-                         vertex_colors[vertex_idx, c] = int(
-                             (1-wx)*(1-wy)*img_array[y0, x0, c] +
-                             wx*(1-wy)*img_array[y0, x1, c] +
-                             (1-wx)*wy*img_array[y1, x0, c] +
-                             wx*wy*img_array[y1, x1, c]
-                         )
-                     vertex_colors[vertex_idx, 3] = 255
-                 # Handle grayscale images - convert to RGB with full opacity
                  else:
-                     gray = int(
-                         (1-wx)*(1-wy)*img_array[y0, x0] +
-                         wx*(1-wy)*img_array[y0, x1] +
-                         (1-wx)*wy*img_array[y1, x0] +
-                         wx*wy*img_array[y1, x1]
-                     )
-                     vertex_colors[vertex_idx, :3] = [gray, gray, gray]
                      vertex_colors[vertex_idx, 3] = 255
-
          mesh.visual.vertex_colors = vertex_colors
 
-     try:
-         # Apply smoothing for non-high detail levels
-         if detail_level != 'high':
-             mesh = mesh.smoothed(method='laplacian', iterations=1)
-
-         # Try to fix normals but catch any errors
-         try:
-             mesh.fix_normals()
-         except Exception as e:
-             print(f"Warning: Could not fix normals: {str(e)}")
-             # Compute face normals manually if fix_normals fails
-             mesh.face_normals = trimesh.geometry.triangles_normals(
-                 mesh.triangles
-             )
-     except Exception as e:
-         print(f"Warning: Error in mesh post-processing: {str(e)}")
-
      return mesh
 
  @app.route('/health', methods=['GET'])
  def health_check():
      return jsonify({
          "status": "healthy",
          "model": "DPT-Large + Depth Anything",
-         "device": "cpu",
-         "version": "1.1.0"  # Added version indicator
      }), 200
 
  @app.route('/progress/<job_id>', methods=['GET'])
@@ -671,14 +442,12 @@ def convert_image_to_3d():
          output_format = request.form.get('output_format', 'glb').lower()
          detail_level = request.form.get('detail_level', 'medium').lower()
          texture_quality = request.form.get('texture_quality', 'medium').lower()
-         remove_bg = request.form.get('remove_background', 'true').lower() == 'true'
      except ValueError:
          return jsonify({"error": "Invalid parameter values"}), 400
 
      if output_format not in ['obj', 'glb']:
          return jsonify({"error": "Unsupported output format. Use 'obj' or 'glb'"}), 400
 
-     # Adjust resolution based on detail level
      if detail_level == 'high':
          mesh_resolution = min(int(mesh_resolution * 1.5), 150)
      elif detail_level == 'low':
@@ -708,13 +477,9 @@ def convert_image_to_3d():
 
      try:
          processing_jobs[job_id]['progress'] = 5
-         image = preprocess_image(filepath, remove_bg=remove_bg)
          processing_jobs[job_id]['progress'] = 10
 
-         # Save the processed image for debugging if needed
-         debug_img_path = os.path.join(output_dir, "processed_input.png")
-         image.save(debug_img_path, format="PNG")
-
         try:
             dpt_model, da_model, da_processor = load_models()
             processing_jobs[job_id]['progress'] = 30
@@ -726,24 +491,11 @@ def convert_image_to_3d():
              try:
                  def estimate_depth():
                      with torch.no_grad():
-                         # Make sure image is in RGB format for the models
-                         rgb_image = image
-                         if rgb_image.mode == 'RGBA':
-                             # Convert RGBA to RGB for model processing
-                             rgb_image = Image.new('RGB', image.size, (255, 255, 255))
-                             rgb_image.paste(image, mask=image.split()[3])  # Use alpha channel as mask
-
-                         # DPT-Large
-                         dpt_result = dpt_model(rgb_image)
                          dpt_depth = dpt_result["depth"]
-                         processing_jobs[job_id]['progress'] = 40
-
-                         # Depth Anything (if loaded)
                          if da_model and da_processor:
-                             inputs = da_processor(images=rgb_image, return_tensors="pt")  # Use RGB image here
                              inputs = {k: v.to("cpu") for k, v in inputs.items()}
                              outputs = da_model(**inputs)
                              da_depth = outputs.predicted_depth.squeeze()
@@ -753,27 +505,15 @@ def convert_image_to_3d():
                                  mode='bicubic',
                                  align_corners=False
                              ).squeeze()
-                             processing_jobs[job_id]['progress'] = 50
-
-                             # Improved fusion of depth maps
                              fused_depth = fuse_depth_maps(dpt_depth, da_depth, detail_level)
                          else:
-                             # Just use DPT with enhanced processing if Depth Anything is not available
                              fused_depth = np.array(dpt_depth) if isinstance(dpt_depth, Image.Image) else dpt_depth
                              if len(fused_depth.shape) > 2:
                                  fused_depth = np.mean(fused_depth, axis=2)
-                             # Apply more conservative normalization
-                             p_low, p_high = np.percentile(fused_depth, [2, 98])
                              fused_depth = np.clip((fused_depth - p_low) / (p_high - p_low), 0, 1) if p_high > p_low else fused_depth
-                             # Apply compression to limit extreme depths
-                             fused_depth = np.power(fused_depth, 0.85)
-
-                         # Save depth map for debugging
-                         depth_debug_path = os.path.join(output_dir, "depth_map.png")
-                         cv2.imwrite(depth_debug_path, (fused_depth * 255).astype(np.uint8))
-
                          return fused_depth
-
 
                  fused_depth, error = process_with_timeout(estimate_depth, [], TIMEOUT_SECONDS)
 
@@ -822,7 +562,6 @@ def convert_image_to_3d():
 
              processing_jobs[job_id]['status'] = 'completed'
              processing_jobs[job_id]['progress'] = 100
-             processing_jobs[job_id]['completed_at'] = time.time()
              print(f"Job {job_id} completed")
 
          except Exception as e:
@@ -890,58 +629,6 @@ def preview_model(job_id):
 
      return jsonify({"error": "File not found"}), 404
 
- @app.route('/debug/<job_id>', methods=['GET'])
- def debug_processing(job_id):
-     """New endpoint to provide debug info about processing"""
-     if job_id not in processing_jobs:
-         return jsonify({"error": "Job not found"}), 404
-
-     job = processing_jobs[job_id]
-     output_dir = os.path.join(RESULTS_FOLDER, job_id)
-
-     debug_info = {
-         "job_status": job['status'],
-         "progress": job['progress'],
-         "created_at": job.get('created_at'),
-         "completed_at": job.get('completed_at'),
-         "processing_time": job.get('completed_at', time.time()) - job.get('created_at', time.time()) if job.get('created_at') else None,
-         "error": job.get('error'),
-         "output_format": job.get('output_format'),
-         "available_files": []
-     }
-
-     # List available debug files
-     if os.path.exists(output_dir):
-         for file in os.listdir(output_dir):
-             file_path = os.path.join(output_dir, file)
-             if os.path.isfile(file_path):
-                 debug_info["available_files"].append({
-                     "filename": file,
-                     "size": os.path.getsize(file_path),
-                     "url": f"/files/{job_id}/{file}"
-                 })
-
-     return jsonify(debug_info), 200
-
- @app.route('/files/<job_id>/<filename>', methods=['GET'])
- def get_job_file(job_id, filename):
-     """Access debug files from processing"""
-     if job_id not in processing_jobs:
-         return jsonify({"error": "Job not found"}), 404
-
-     file_path = os.path.join(RESULTS_FOLDER, job_id, filename)
-     if not os.path.exists(file_path) or not os.path.isfile(file_path):
-         return jsonify({"error": "File not found"}), 404
-
-     # Determine MIME type
-     mimetype = "application/octet-stream"
-     if filename.endswith(".png"):
-         mimetype = "image/png"
-     elif filename.endswith(".jpg") or filename.endswith(".jpeg"):
-         mimetype = "image/jpeg"
-
-     return send_file(file_path, mimetype=mimetype)
-
  def cleanup_old_jobs():
      current_time = time.time()
      job_ids_to_remove = []
@@ -1002,35 +689,30 @@ def model_info(job_id):
          "preview_url": job['preview_url'],
          "model_stats": model_stats,
          "created_at": job.get('created_at'),
-         "completed_at": job.get('completed_at'),
-         "processing_time": job.get('completed_at', 0) - job.get('created_at', 0) if job.get('completed_at') and job.get('created_at') else None
      }), 200
 
  @app.route('/', methods=['GET'])
  def index():
      return jsonify({
-         "message": "Enhanced Image to 3D API (DPT-Large + Depth Anything)",
-         "version": "1.1.0",
          "endpoints": [
              "/convert",
              "/progress/<job_id>",
              "/download/<job_id>",
              "/preview/<job_id>",
-             "/model-info/<job_id>",
-             "/debug/<job_id>",  # New debug endpoint
-             "/health"
         ],
         "parameters": {
             "mesh_resolution": "Integer (50-150)",
             "output_format": "obj or glb",
             "detail_level": "low, medium, or high",
-             "texture_quality": "low, medium, or high",
-             "remove_background": "true or false (default: true)"
         },
-         "description": "Creates high-quality 3D models from 2D images with improved depth estimation and background removal."
     }), 200
 
  if __name__ == '__main__':
      cleanup_old_jobs()
      port = int(os.environ.get('PORT', 7860))
-     app.run(host='0.0.0.0', port=port)
 
37
  app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
38
  app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
39
 
40
+ # Job tracking
41
  processing_jobs = {}
42
 
43
  # Model variables
 
47
  model_loaded = False
48
  model_loading = False
49
 
50
+ TIMEOUT_SECONDS = 240
51
  MAX_DIMENSION = 518
52
 
53
  class TimeoutError(Exception):
 
84
  def allowed_file(filename):
85
  return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
86
 
87
+ def remove_background(image):
88
+ """Remove background using OpenCV GrabCut algorithm with improved precision"""
89
+ img_array = np.array(image)
 
 
 
90
 
91
+ # Convert to RGB if image has alpha channel
92
+ if img_array.shape[2] == 4:
93
+ img_array = cv2.cvtColor(img_array, cv2.COLOR_RGBA2RGB)
94
 
95
+ # Create mask for GrabCut
96
+ mask = np.zeros(img_array.shape[:2], np.uint8)
97
+ bgdModel = np.zeros((1, 65), np.float64)
98
+ fgdModel = np.zeros((1, 65), np.float64)
99
 
100
+ # Define a tighter rectangle for foreground, adjusting based on image content
101
+ height, width = img_array.shape[:2]
102
+ rect = (int(width * 0.1), int(height * 0.1), int(width * 0.8), int(height * 0.8))
 
 
 
 
 
103
 
104
+ # Run GrabCut with multiple iterations for better accuracy
105
+ cv2.grabCut(img_array, mask, rect, bgdModel, fgdModel, 10, cv2.GC_INIT_WITH_RECT)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
+ # Refine mask using edge detection to preserve subject edges
108
+ mask2 = np.where((mask == cv2.GC_PR_FGD) | (mask == cv2.GC_FGD), 1, 0).astype('uint8')
109
+ edges = cv2.Canny(mask2 * 255, 50, 150)
110
+ mask2 = cv2.dilate(edges, np.ones((3, 3), np.uint8), iterations=1)
111
+ mask2 = cv2.erode(mask2, np.ones((3, 3), np.uint8), iterations=1)
112
 
113
+ # Apply mask to image
114
+ result = img_array * mask2[:, :, np.newaxis]
115
+
116
+ # Create alpha channel
117
+ alpha = mask2 * 255
118
+ result = np.dstack((result, alpha))
119
+
120
+ return Image.fromarray(result, 'RGBA')
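
A minimal sketch of driving the new remove_background on its own, assuming the imports already present in app.py; the file paths are invented for illustration and are not part of the commit:

    # Hypothetical smoke test for remove_background; "input.jpg" and
    # "cutout.png" are invented paths, not from this commit.
    from PIL import Image

    img = Image.open("input.jpg").convert("RGB")
    cutout = remove_background(img)   # RGBA result; alpha > 0 marks kept pixels
    cutout.save("cutout.png")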
 
+ def preprocess_image(image_path):
      with Image.open(image_path) as img:
+         # Handle PNG transparency
+         if img.mode == 'RGBA':
+             # Create white background
+             background = Image.new('RGB', img.size, (255, 255, 255))
+             background.paste(img, mask=img.split()[3])
+             img = background
+
          img = img.convert("RGB")
 
          if img.width > MAX_DIMENSION or img.height > MAX_DIMENSION:
 
              new_width = int(img.width * (MAX_DIMENSION / img.height))
          img = img.resize((new_width, new_height), Image.LANCZOS)
 
+         # Remove background
+         img = remove_background(img)
+
          img_array = np.array(img)
          if len(img_array.shape) == 3 and img_array.shape[2] == 3:
              lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
              l, a, b = cv2.split(lab)
+             clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
              cl = clahe.apply(l)
              enhanced_lab = cv2.merge((cl, a, b))
              img_array = cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2RGB)
              img = Image.fromarray(img_array)
 
          return img
 
  def load_models():
      global dpt_estimator, depth_anything_model, depth_anything_processor, model_loaded, model_loading
 
      model_loading = True
      print("Loading models...")
 
      hf_token = os.environ.get('HF_TOKEN')
      if hf_token:
          login(token=hf_token)
          print("Authenticated with Hugging Face token")
 
      dpt_model_name = "Intel/dpt-large"
      max_retries = 3
      retry_delay = 5
 
              print("DPT-Large loaded")
              gc.collect()
 
      da_model_name = "depth-anything/Depth-Anything-V2-Small-hf"
      for attempt in range(max_retries):
          try:
 
          model_loading = False
 
  def fuse_depth_maps(dpt_depth, da_depth, detail_level='medium'):
      if isinstance(dpt_depth, Image.Image):
          dpt_depth = np.array(dpt_depth)
      if isinstance(da_depth, torch.Tensor):
 
      if dpt_depth.shape != da_depth.shape:
          da_depth = cv2.resize(da_depth, (dpt_depth.shape[1], dpt_depth.shape[0]), interpolation=cv2.INTER_CUBIC)
 
+     p_low_dpt, p_high_dpt = np.percentile(dpt_depth, [5, 95])
+     p_low_da, p_high_da = np.percentile(da_depth, [5, 95])
      dpt_depth = np.clip((dpt_depth - p_low_dpt) / (p_high_dpt - p_low_dpt), 0, 1) if p_high_dpt > p_low_dpt else dpt_depth
      da_depth = np.clip((da_depth - p_low_da) / (p_high_da - p_low_da), 0, 1) if p_high_da > p_low_da else da_depth
 
      if detail_level == 'high':
          weight_da = 0.6
+         edges = cv2.Canny((da_depth * 255).astype(np.uint8), 50, 150)
+         edge_mask = (edges > 0).astype(np.float32)
          dpt_weight = gaussian_filter(1 - edge_mask, sigma=1.0)
          da_weight = gaussian_filter(edge_mask, sigma=1.0)
+         fused_depth = dpt_weight * dpt_depth + da_weight * da_depth * weight_da + (1 - weight_da) * dpt_depth
      else:
+         weight_da = 0.4 if detail_level == 'medium' else 0.2
+         fused_depth = (1 - weight_da) * dpt_depth + weight_da * da_depth
 
      fused_depth = np.clip(fused_depth, 0, 1)
      return fused_depth
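
To make the non-high fusion weights concrete: for detail_level='medium' the blend above reduces to fused = 0.6·dpt + 0.4·da elementwise. A tiny numeric check (the array values are invented):

    import numpy as np

    # Medium detail: weight_da = 0.4, so fused = 0.6 * dpt + 0.4 * da.
    dpt = np.array([[0.0, 0.5], [1.0, 0.25]])
    da  = np.array([[0.2, 0.5], [0.8, 0.75]])
    fused = (1 - 0.4) * dpt + 0.4 * da
    print(fused)  # [[0.08, 0.5], [0.92, 0.45]]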
 
  def enhance_depth_map(depth_map, detail_level='medium'):
      enhanced_depth = depth_map.copy().astype(np.float32)
+     p_low, p_high = np.percentile(enhanced_depth, [5, 95])
      enhanced_depth = np.clip(enhanced_depth, p_low, p_high)
      enhanced_depth = (enhanced_depth - p_low) / (p_high - p_low) if p_high > p_low else enhanced_depth
 
      if detail_level == 'high':
          blurred = gaussian_filter(enhanced_depth, sigma=1.0)
+         mask = enhanced_depth - blurred
+         enhanced_depth = enhanced_depth + 1.0 * mask
+         smooth1 = gaussian_filter(enhanced_depth, sigma=0.3)
+         smooth2 = gaussian_filter(enhanced_depth, sigma=1.5)
+         edge_mask = enhanced_depth - smooth2
+         enhanced_depth = smooth1 + 0.8 * edge_mask  # Reduced enhancement
+     elif detail_level == 'medium':
+         blurred = gaussian_filter(enhanced_depth, sigma=0.7)
+         mask = enhanced_depth - blurred
+         enhanced_depth = enhanced_depth + 0.6 * mask
+         enhanced_depth = gaussian_filter(enhanced_depth, sigma=0.4)
      else:
+         enhanced_depth = gaussian_filter(enhanced_depth, sigma=0.5)
 
      enhanced_depth = np.clip(enhanced_depth, 0, 1)
      return enhanced_depth
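
The high and medium branches above are unsharp masking applied to the depth map: subtract a Gaussian-blurred copy to isolate fine detail, then add the detail back scaled. A one-dimensional toy illustration (values invented, not from the commit):

    import numpy as np
    from scipy.ndimage import gaussian_filter

    depth = np.array([0.2, 0.2, 0.8, 0.8], dtype=np.float32)  # a step edge
    blurred = gaussian_filter(depth, sigma=0.7)
    detail = depth - blurred              # negative left of the edge, positive right
    sharpened = depth + 0.6 * detail      # overshoot on both sides steepens the edge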
 
  def depth_to_mesh(depth_map, image, resolution=100, detail_level='medium'):
      enhanced_depth = enhance_depth_map(depth_map, detail_level)
      h, w = enhanced_depth.shape
      x = np.linspace(0, w-1, resolution)
      y = np.linspace(0, h-1, resolution)
      x_grid, y_grid = np.meshgrid(x, y)
 
      interp_func = interpolate.RectBivariateSpline(
          np.arange(h), np.arange(w), enhanced_depth, kx=3, ky=3
      )
      z_values = interp_func(y, x, grid=True)
 
      if detail_level == 'high':
          dx = np.gradient(z_values, axis=1)
          dy = np.gradient(z_values, axis=0)
          gradient_magnitude = np.sqrt(dx**2 + dy**2)
+         edge_mask = np.clip(gradient_magnitude * 2, 0, 0.1)
+         z_values = z_values + edge_mask * (z_values - gaussian_filter(z_values, sigma=0.5))
 
+     z_min, z_max = np.percentile(z_values, [10, 90])
+     z_values = np.clip((z_values - z_min) / (z_max - z_min), 0, 1) if z_max > z_min else z_values
+     z_scaling = 1.5 if detail_level == 'high' else 1.2 if detail_level == 'medium' else 1.0
+     z_values = z_values * z_scaling
 
+     x_grid = (x_grid / w - 0.5) * 1.5
+     y_grid = (y_grid / h - 0.5) * 1.5
      vertices = np.vstack([x_grid.flatten(), -y_grid.flatten(), -z_values.flatten()]).T
 
      faces = []
      for i in range(resolution-1):
          for j in range(resolution-1):
 
              p2 = i * resolution + (j + 1)
              p3 = (i + 1) * resolution + j
              p4 = (i + 1) * resolution + (j + 1)
              v1 = vertices[p1]
              v2 = vertices[p2]
              v3 = vertices[p3]
              v4 = vertices[p4]
              norm1 = np.cross(v2-v1, v4-v1)
              norm2 = np.cross(v4-v3, v1-v3)
              if np.dot(norm1, norm2) >= 0:
                  faces.append([p1, p2, p4])
                  faces.append([p1, p4, p3])
 
                  faces.append([p1, p2, p3])
                  faces.append([p2, p4, p3])
 
      faces = np.array(faces)
      mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
 
      if image:
          img_array = np.array(image)
          vertex_colors = np.zeros((vertices.shape[0], 4), dtype=np.uint8)
          for i in range(resolution):
              for j in range(resolution):
                  img_x = j * (img_array.shape[1] - 1) / (resolution - 1)
 
                  x1, y1 = min(x0 + 1, img_array.shape[1] - 1), min(y0 + 1, img_array.shape[0] - 1)
                  wx = img_x - x0
                  wy = img_y - y0
                  vertex_idx = i * resolution + j
                  if len(img_array.shape) == 3 and img_array.shape[2] == 4:
                      for c in range(4):
+                         vertex_colors[vertex_idx, c] = int((1-wx)*(1-wy)*img_array[y0, x0, c] +
+                                                            wx*(1-wy)*img_array[y0, x1, c] +
+                                                            (1-wx)*wy*img_array[y1, x0, c] +
+                                                            wx*wy*img_array[y1, x1, c])
                  else:
+                     r, g, b = img_array[y0, x0]
+                     vertex_colors[vertex_idx, :3] = [r, g, b]
                      vertex_colors[vertex_idx, 3] = 255
          mesh.visual.vertex_colors = vertex_colors
 
+     if detail_level != 'high':
+         mesh = mesh.smoothed(method='laplacian', iterations=1)
+     mesh.fix_normals()
 
      return mesh
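
A hedged end-to-end sketch of exercising depth_to_mesh with a synthetic depth ramp and no texture image; the 64×64 ramp and output filename are invented, and the smoothing/normal-fixing behavior inside the function depends on the installed trimesh version:

    import numpy as np

    # Synthetic 64x64 depth ramp; image=None skips vertex coloring.
    depth = np.tile(np.linspace(0.0, 1.0, 64, dtype=np.float32), (64, 1))
    mesh = depth_to_mesh(depth, image=None, resolution=50, detail_level='high')
    mesh.export("ramp.glb")  # trimesh infers the format from the extension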
 
  @app.route('/health', methods=['GET'])
  def health_check():
      return jsonify({
          "status": "healthy",
          "model": "DPT-Large + Depth Anything",
+         "device": "cpu"
      }), 200
 
  @app.route('/progress/<job_id>', methods=['GET'])
 
          output_format = request.form.get('output_format', 'glb').lower()
          detail_level = request.form.get('detail_level', 'medium').lower()
          texture_quality = request.form.get('texture_quality', 'medium').lower()
      except ValueError:
          return jsonify({"error": "Invalid parameter values"}), 400
 
      if output_format not in ['obj', 'glb']:
          return jsonify({"error": "Unsupported output format. Use 'obj' or 'glb'"}), 400
 
      if detail_level == 'high':
          mesh_resolution = min(int(mesh_resolution * 1.5), 150)
      elif detail_level == 'low':
 
      try:
          processing_jobs[job_id]['progress'] = 5
+         image = preprocess_image(filepath)
          processing_jobs[job_id]['progress'] = 10
 
          try:
              dpt_model, da_model, da_processor = load_models()
              processing_jobs[job_id]['progress'] = 30
 
              try:
                  def estimate_depth():
                      with torch.no_grad():
+                         dpt_result = dpt_model(image)
                          dpt_depth = dpt_result["depth"]
+
                          if da_model and da_processor:
+                             inputs = da_processor(images=image, return_tensors="pt")
                              inputs = {k: v.to("cpu") for k, v in inputs.items()}
                              outputs = da_model(**inputs)
                              da_depth = outputs.predicted_depth.squeeze()
 
                                  mode='bicubic',
                                  align_corners=False
                              ).squeeze()
                              fused_depth = fuse_depth_maps(dpt_depth, da_depth, detail_level)
                          else:
                              fused_depth = np.array(dpt_depth) if isinstance(dpt_depth, Image.Image) else dpt_depth
                              if len(fused_depth.shape) > 2:
                                  fused_depth = np.mean(fused_depth, axis=2)
+                             p_low, p_high = np.percentile(fused_depth, [5, 95])
                              fused_depth = np.clip((fused_depth - p_low) / (p_high - p_low), 0, 1) if p_high > p_low else fused_depth
+
                          return fused_depth
 
                  fused_depth, error = process_with_timeout(estimate_depth, [], TIMEOUT_SECONDS)
 
              processing_jobs[job_id]['status'] = 'completed'
              processing_jobs[job_id]['progress'] = 100
              print(f"Job {job_id} completed")
 
          except Exception as e:
 
      return jsonify({"error": "File not found"}), 404
 
  def cleanup_old_jobs():
      current_time = time.time()
      job_ids_to_remove = []
 
          "preview_url": job['preview_url'],
          "model_stats": model_stats,
          "created_at": job.get('created_at'),
+         "completed_at": job.get('completed_at')
      }), 200
 
  @app.route('/', methods=['GET'])
  def index():
      return jsonify({
+         "message": "Image to 3D API (DPT-Large + Depth Anything)",
          "endpoints": [
              "/convert",
              "/progress/<job_id>",
              "/download/<job_id>",
              "/preview/<job_id>",
+             "/model-info/<job_id>"
          ],
          "parameters": {
              "mesh_resolution": "Integer (50-150)",
              "output_format": "obj or glb",
              "detail_level": "low, medium, or high",
+             "texture_quality": "low, medium, or high"
          },
+         "description": "Creates high-quality 3D models from 2D images using DPT-Large and Depth Anything."
      }), 200
 
  if __name__ == '__main__':
      cleanup_old_jobs()
      port = int(os.environ.get('PORT', 7860))
+     app.run(host='0.0.0.0', port=port)
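
For reference, a hypothetical client for the endpoints listed in index(), using the host and port from app.run above. The upload field name and the response keys are assumptions, since neither appears in this diff:

    import requests

    with open("photo.jpg", "rb") as f:          # invented sample image
        resp = requests.post(
            "http://localhost:7860/convert",
            files={"image": f},                 # field name assumed
            data={
                "mesh_resolution": 100,
                "output_format": "glb",
                "detail_level": "medium",
                "texture_quality": "medium",
            },
        )
    job_id = resp.json().get("job_id")          # response key assumed
    print(requests.get(f"http://localhost:7860/progress/{job_id}").json())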