mac9087 commited on
Commit
cfa68ff
·
verified ·
1 Parent(s): 8c48b7b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -539
app.py CHANGED
@@ -15,13 +15,13 @@ from huggingface_hub import snapshot_download
15
  from flask_cors import CORS
16
  import numpy as np
17
  import trimesh
18
- from transformers import pipeline
19
- from scipy.ndimage import gaussian_filter, uniform_filter, median_filter
20
- from scipy import interpolate
21
  import cv2
 
 
 
22
 
23
  app = Flask(__name__)
24
- CORS(app) # Enable CORS for all routes
25
 
26
  # Configure directories
27
  UPLOAD_FOLDER = '/tmp/uploads'
@@ -29,12 +29,12 @@ RESULTS_FOLDER = '/tmp/results'
29
  CACHE_DIR = '/tmp/huggingface'
30
  ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
31
 
32
- # Create necessary directories
33
  os.makedirs(UPLOAD_FOLDER, exist_ok=True)
34
  os.makedirs(RESULTS_FOLDER, exist_ok=True)
35
  os.makedirs(CACHE_DIR, exist_ok=True)
36
 
37
- # Set Hugging Face cache environment variables
38
  os.environ['HF_HOME'] = CACHE_DIR
39
  os.environ['TRANSFORMERS_CACHE'] = os.path.join(CACHE_DIR, 'transformers')
40
  os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_DIR, 'datasets')
@@ -42,23 +42,23 @@ os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_DIR, 'datasets')
42
  app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
43
  app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max
44
 
45
- # Job tracking dictionary
46
  processing_jobs = {}
47
 
48
  # Global model variables
49
- depth_estimator = None
 
 
50
  model_loaded = False
51
  model_loading = False
52
 
53
- # Configuration for processing
54
- TIMEOUT_SECONDS = 240 # 4 minutes max for processing
55
- MAX_DIMENSION = 512 # Max image dimension to process
56
 
57
- # TimeoutError for handling timeouts
58
  class TimeoutError(Exception):
59
  pass
60
 
61
- # Thread-safe timeout implementation
62
  def process_with_timeout(function, args, timeout):
63
  result = [None]
64
  error = [None]
@@ -91,70 +91,71 @@ def process_with_timeout(function, args, timeout):
91
  def allowed_file(filename):
92
  return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
93
 
94
- # Enhanced image preprocessing with better detail preservation
95
  def preprocess_image(image_path):
96
  with Image.open(image_path) as img:
97
  img = img.convert("RGB")
98
 
99
- # Resize if the image is too large
100
  if img.width > MAX_DIMENSION or img.height > MAX_DIMENSION:
101
- # Calculate new dimensions while preserving aspect ratio
102
  if img.width > img.height:
103
  new_width = MAX_DIMENSION
104
  new_height = int(img.height * (MAX_DIMENSION / img.width))
105
  else:
106
  new_height = MAX_DIMENSION
107
  new_width = int(img.width * (MAX_DIMENSION / img.height))
108
-
109
- # Use high-quality Lanczos resampling for better detail preservation
110
  img = img.resize((new_width, new_height), Image.LANCZOS)
111
 
112
- # Convert to numpy array for additional preprocessing
113
  img_array = np.array(img)
114
-
115
- # Optional: Apply adaptive histogram equalization for better contrast
116
- # This helps the depth model detect more details
117
  if len(img_array.shape) == 3 and img_array.shape[2] == 3:
118
- # Convert to LAB color space
119
  lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
120
  l, a, b = cv2.split(lab)
121
-
122
- # Apply CLAHE to L channel
123
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
124
  cl = clahe.apply(l)
125
-
126
- # Merge channels back
127
  enhanced_lab = cv2.merge((cl, a, b))
128
-
129
- # Convert back to RGB
130
  img_array = cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2RGB)
131
-
132
- # Convert back to PIL Image
133
  img = Image.fromarray(img_array)
134
 
135
  return img
136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  def load_model():
138
- global depth_estimator, model_loaded, model_loading
139
 
140
  if model_loaded:
141
- return depth_estimator
142
 
143
  if model_loading:
144
- # Wait for model to load if it's already in progress
145
  while model_loading and not model_loaded:
146
  time.sleep(0.5)
147
- return depth_estimator
148
 
149
  try:
150
  model_loading = True
151
- print("Starting model loading...")
152
-
153
- # Using DPT-Large which provides better detail than DPT-Hybrid
154
- # Alternatively, consider "vinvino02/glpn-nyu" for different detail characteristics
155
- model_name = "Intel/dpt-large"
156
 
157
- # Download model with retry mechanism
158
  max_retries = 3
159
  retry_delay = 5
160
 
@@ -168,30 +169,19 @@ def load_model():
168
  break
169
  except Exception as e:
170
  if attempt < max_retries - 1:
171
- print(f"Download attempt {attempt+1} failed: {str(e)}. Retrying in {retry_delay} seconds...")
172
  time.sleep(retry_delay)
173
  retry_delay *= 2
174
  else:
175
  raise
176
 
177
- # Initialize model with appropriate precision
178
- device = "cuda" if torch.cuda.is_available() else "cpu"
179
-
180
- # Load depth estimator pipeline
181
- depth_estimator = pipeline(
182
- "depth-estimation",
183
- model=model_name,
184
- device=device if device == "cuda" else -1,
185
- cache_dir=CACHE_DIR
186
- )
187
-
188
- # Optimize memory usage
189
- if device == "cuda":
190
- torch.cuda.empty_cache()
191
 
192
  model_loaded = True
193
- print(f"Model loaded successfully on {device}")
194
- return depth_estimator
195
 
196
  except Exception as e:
197
  print(f"Error loading model: {str(e)}")
@@ -200,227 +190,28 @@ def load_model():
200
  finally:
201
  model_loading = False
202
 
203
- # Enhanced depth processing function to improve detail quality
204
- def enhance_depth_map(depth_map, detail_level='medium'):
205
- """Apply sophisticated processing to enhance depth map details"""
206
- # Convert to numpy array if needed
207
- if isinstance(depth_map, Image.Image):
208
- depth_map = np.array(depth_map)
209
-
210
- # Make sure the depth map is 2D
211
- if len(depth_map.shape) > 2:
212
- depth_map = np.mean(depth_map, axis=2) if depth_map.shape[2] > 1 else depth_map[:,:,0]
213
-
214
- # Create a copy for processing
215
- enhanced_depth = depth_map.copy().astype(np.float32)
216
-
217
- # Remove outliers using percentile clipping (more stable than min/max)
218
- p_low, p_high = np.percentile(enhanced_depth, [1, 99])
219
- enhanced_depth = np.clip(enhanced_depth, p_low, p_high)
220
-
221
- # Normalize to 0-1 range for processing
222
- enhanced_depth = (enhanced_depth - p_low) / (p_high - p_low) if p_high > p_low else enhanced_depth
223
-
224
- # Apply different enhancement methods based on detail level
225
  if detail_level == 'high':
226
- # Apply unsharp masking for edge enhancement - simulating Hunyuan's detail technique
227
- # First apply gaussian blur
228
- blurred = gaussian_filter(enhanced_depth, sigma=1.5)
229
- # Create the unsharp mask
230
- mask = enhanced_depth - blurred
231
- # Apply the mask with strength factor
232
- enhanced_depth = enhanced_depth + 1.5 * mask
233
-
234
- # Apply bilateral filter to preserve edges while smoothing noise
235
- # Simulate using gaussian combinations
236
- smooth1 = gaussian_filter(enhanced_depth, sigma=0.5)
237
- smooth2 = gaussian_filter(enhanced_depth, sigma=2.0)
238
- edge_mask = enhanced_depth - smooth2
239
- enhanced_depth = smooth1 + 1.2 * edge_mask
240
-
241
  elif detail_level == 'medium':
242
- # Less aggressive but still effective enhancement
243
- # Apply mild unsharp masking
244
- blurred = gaussian_filter(enhanced_depth, sigma=1.0)
245
- mask = enhanced_depth - blurred
246
- enhanced_depth = enhanced_depth + 0.8 * mask
247
-
248
- # Apply mild smoothing to reduce noise but preserve edges
249
- enhanced_depth = gaussian_filter(enhanced_depth, sigma=0.5)
250
-
251
- else: # low
252
- # Just apply noise reduction without too much detail enhancement
253
- enhanced_depth = gaussian_filter(enhanced_depth, sigma=0.7)
254
-
255
- # Normalize again after processing
256
- enhanced_depth = np.clip(enhanced_depth, 0, 1)
257
-
258
- return enhanced_depth
259
-
260
- # Convert depth map to 3D mesh with significantly enhanced detail
261
- def depth_to_mesh(depth_map, image, resolution=100, detail_level='medium'):
262
- """Convert depth map to 3D mesh with highly improved detail preservation"""
263
- # First, enhance the depth map for better details
264
- enhanced_depth = enhance_depth_map(depth_map, detail_level)
265
-
266
- # Get dimensions of depth map
267
- h, w = enhanced_depth.shape
268
-
269
- # Create a higher resolution grid for better detail
270
- x = np.linspace(0, w-1, resolution)
271
- y = np.linspace(0, h-1, resolution)
272
- x_grid, y_grid = np.meshgrid(x, y)
273
-
274
- # Use bicubic interpolation for smoother surface with better details
275
- # Create interpolation function
276
- interp_func = interpolate.RectBivariateSpline(
277
- np.arange(h), np.arange(w), enhanced_depth, kx=3, ky=3
278
- )
279
-
280
- # Sample depth at grid points with the interpolation function
281
- z_values = interp_func(y, x, grid=True)
282
-
283
- # Apply a post-processing step to enhance small details even further
284
- if detail_level == 'high':
285
- # Calculate local gradients to detect edges
286
- dx = np.gradient(z_values, axis=1)
287
- dy = np.gradient(z_values, axis=0)
288
-
289
- # Enhance edges by increasing depth differences at high gradient areas
290
- gradient_magnitude = np.sqrt(dx**2 + dy**2)
291
- edge_mask = np.clip(gradient_magnitude * 5, 0, 0.2) # Scale and limit effect
292
-
293
- # Apply edge enhancement
294
- z_values = z_values + edge_mask * (z_values - gaussian_filter(z_values, sigma=1.0))
295
-
296
- # Normalize z-values with advanced scaling for better depth impression
297
- z_min, z_max = np.percentile(z_values, [2, 98]) # Remove outliers
298
- z_values = (z_values - z_min) / (z_max - z_min) if z_max > z_min else z_values
299
-
300
- # Apply depth scaling appropriate to the detail level
301
- if detail_level == 'high':
302
- z_scaling = 2.5 # More pronounced depth variations
303
- elif detail_level == 'medium':
304
- z_scaling = 2.0 # Standard depth
305
  else:
306
- z_scaling = 1.5 # More subtle depth variations
307
-
308
- z_values = z_values * z_scaling
309
-
310
- # Normalize x and y coordinates
311
- x_grid = (x_grid / w - 0.5) * 2.0 # Map to -1 to 1
312
- y_grid = (y_grid / h - 0.5) * 2.0 # Map to -1 to 1
313
-
314
- # Create vertices
315
- vertices = np.vstack([x_grid.flatten(), -y_grid.flatten(), -z_values.flatten()]).T
316
-
317
- # Create faces (triangles) with optimized winding for better normals
318
- faces = []
319
- for i in range(resolution-1):
320
- for j in range(resolution-1):
321
- p1 = i * resolution + j
322
- p2 = i * resolution + (j + 1)
323
- p3 = (i + 1) * resolution + j
324
- p4 = (i + 1) * resolution + (j + 1)
325
-
326
- # Calculate normals to ensure consistent orientation
327
- v1 = vertices[p1]
328
- v2 = vertices[p2]
329
- v3 = vertices[p3]
330
- v4 = vertices[p4]
331
-
332
- # Calculate normals for both possible triangulations
333
- # and choose the one that's more consistent
334
- norm1 = np.cross(v2-v1, v4-v1)
335
- norm2 = np.cross(v4-v3, v1-v3)
336
-
337
- if np.dot(norm1, norm2) >= 0:
338
- # Standard triangulation
339
- faces.append([p1, p2, p4])
340
- faces.append([p1, p4, p3])
341
- else:
342
- # Alternative triangulation for smoother surface
343
- faces.append([p1, p2, p3])
344
- faces.append([p2, p4, p3])
345
-
346
- faces = np.array(faces)
347
-
348
- # Create mesh
349
- mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
350
-
351
- # Apply advanced texturing if image is provided
352
- if image:
353
- # Convert to numpy array if needed
354
- if isinstance(image, Image.Image):
355
- img_array = np.array(image)
356
- else:
357
- img_array = image
358
-
359
- # Create vertex colors with improved sampling
360
- if resolution <= img_array.shape[0] and resolution <= img_array.shape[1]:
361
- # Create vertex colors by sampling the image with bilinear interpolation
362
- vertex_colors = np.zeros((vertices.shape[0], 4), dtype=np.uint8)
363
-
364
- # Get normalized coordinates for sampling
365
- for i in range(resolution):
366
- for j in range(resolution):
367
- # Calculate exact image coordinates with proper scaling
368
- img_x = j * (img_array.shape[1] - 1) / (resolution - 1)
369
- img_y = i * (img_array.shape[0] - 1) / (resolution - 1)
370
-
371
- # Bilinear interpolation for smooth color transitions
372
- x0, y0 = int(img_x), int(img_y)
373
- x1, y1 = min(x0 + 1, img_array.shape[1] - 1), min(y0 + 1, img_array.shape[0] - 1)
374
-
375
- # Calculate interpolation weights
376
- wx = img_x - x0
377
- wy = img_y - y0
378
-
379
- vertex_idx = i * resolution + j
380
-
381
- if len(img_array.shape) == 3 and img_array.shape[2] == 3: # RGB
382
- # Perform bilinear interpolation for each color channel
383
- r = int((1-wx)*(1-wy)*img_array[y0, x0, 0] + wx*(1-wy)*img_array[y0, x1, 0] +
384
- (1-wx)*wy*img_array[y1, x0, 0] + wx*wy*img_array[y1, x1, 0])
385
- g = int((1-wx)*(1-wy)*img_array[y0, x0, 1] + wx*(1-wy)*img_array[y0, x1, 1] +
386
- (1-wx)*wy*img_array[y1, x0, 1] + wx*wy*img_array[y1, x1, 1])
387
- b = int((1-wx)*(1-wy)*img_array[y0, x0, 2] + wx*(1-wy)*img_array[y0, x1, 2] +
388
- (1-wx)*wy*img_array[y1, x0, 2] + wx*wy*img_array[y1, x1, 2])
389
-
390
- vertex_colors[vertex_idx, :3] = [r, g, b]
391
- vertex_colors[vertex_idx, 3] = 255 # Alpha
392
- elif len(img_array.shape) == 3 and img_array.shape[2] == 4: # RGBA
393
- for c in range(4): # For each RGBA channel
394
- vertex_colors[vertex_idx, c] = int((1-wx)*(1-wy)*img_array[y0, x0, c] +
395
- wx*(1-wy)*img_array[y0, x1, c] +
396
- (1-wx)*wy*img_array[y1, x0, c] +
397
- wx*wy*img_array[y1, x1, c])
398
- else:
399
- # Handle grayscale with bilinear interpolation
400
- gray = int((1-wx)*(1-wy)*img_array[y0, x0] + wx*(1-wy)*img_array[y0, x1] +
401
- (1-wx)*wy*img_array[y1, x0] + wx*wy*img_array[y1, x1])
402
- vertex_colors[vertex_idx, :3] = [gray, gray, gray]
403
- vertex_colors[vertex_idx, 3] = 255
404
-
405
- mesh.visual.vertex_colors = vertex_colors
406
 
407
- # Apply smoothing to get rid of staircase artifacts
408
- if detail_level != 'high':
409
- # For medium and low detail, apply Laplacian smoothing
410
- # but preserve the overall shape
411
- mesh = mesh.smoothed(method='laplacian', iterations=1)
412
 
413
- # Calculate and fix normals for better rendering
414
  mesh.fix_normals()
415
-
416
  return mesh
417
 
418
  @app.route('/health', methods=['GET'])
419
  def health_check():
420
  return jsonify({
421
- "status": "healthy",
422
- "model": "Enhanced Depth-Based 3D Model Generator (DPT-Large)",
423
- "device": "cuda" if torch.cuda.is_available() else "cpu"
424
  }), 200
425
 
426
  @app.route('/progress/<job_id>', methods=['GET'])
@@ -431,30 +222,23 @@ def progress(job_id):
431
  return
432
 
433
  job = processing_jobs[job_id]
434
-
435
- # Send initial progress
436
  yield f"data: {json.dumps({'status': 'processing', 'progress': job['progress']})}\n\n"
437
 
438
- # Wait for job to complete or update
439
  last_progress = job['progress']
440
  check_count = 0
441
  while job['status'] == 'processing':
442
  if job['progress'] != last_progress:
443
  yield f"data: {json.dumps({'status': 'processing', 'progress': job['progress']})}\n\n"
444
  last_progress = job['progress']
445
-
446
  time.sleep(0.5)
447
  check_count += 1
448
-
449
- # If client hasn't received updates for a while, check if job is still running
450
- if check_count > 60: # 30 seconds with no updates
451
  if 'thread_alive' in job and not job['thread_alive']():
452
  job['status'] = 'error'
453
  job['error'] = 'Processing thread died unexpectedly'
454
  break
455
  check_count = 0
456
 
457
- # Send final status
458
  if job['status'] == 'completed':
459
  yield f"data: {json.dumps({'status': 'completed', 'progress': 100, 'result_url': job['result_url'], 'preview_url': job['preview_url']})}\n\n"
460
  else:
@@ -464,7 +248,6 @@ def progress(job_id):
464
 
465
  @app.route('/convert', methods=['POST'])
466
  def convert_image_to_3d():
467
- # Check if image is in the request
468
  if 'image' not in request.files:
469
  return jsonify({"error": "No image provided"}), 400
470
 
@@ -473,38 +256,26 @@ def convert_image_to_3d():
473
  return jsonify({"error": "No image selected"}), 400
474
 
475
  if not allowed_file(file.filename):
476
- return jsonify({"error": f"File type not allowed. Supported types: {', '.join(ALLOWED_EXTENSIONS)}"}), 400
477
 
478
- # Get optional parameters with defaults
479
  try:
480
- mesh_resolution = min(int(request.form.get('mesh_resolution', 100)), 200) # Limit max resolution
481
- output_format = request.form.get('output_format', 'obj').lower()
482
- detail_level = request.form.get('detail_level', 'medium').lower() # Parameter for detail level
483
- texture_quality = request.form.get('texture_quality', 'medium').lower() # New parameter for texture quality
484
  except ValueError:
485
  return jsonify({"error": "Invalid parameter values"}), 400
486
 
487
- # Validate output format
488
  if output_format not in ['obj', 'glb']:
489
- return jsonify({"error": "Unsupported output format. Use 'obj' or 'glb'"}), 400
490
-
491
- # Adjust mesh resolution based on detail level
492
- if detail_level == 'high':
493
- mesh_resolution = min(int(mesh_resolution * 1.5), 200)
494
- elif detail_level == 'low':
495
- mesh_resolution = max(int(mesh_resolution * 0.7), 50)
496
 
497
- # Create a job ID
498
  job_id = str(uuid.uuid4())
499
  output_dir = os.path.join(RESULTS_FOLDER, job_id)
500
  os.makedirs(output_dir, exist_ok=True)
501
 
502
- # Save the uploaded file
503
  filename = secure_filename(file.filename)
504
  filepath = os.path.join(app.config['UPLOAD_FOLDER'], f"{job_id}_{filename}")
505
  file.save(filepath)
506
 
507
- # Initialize job tracking
508
  processing_jobs[job_id] = {
509
  'status': 'processing',
510
  'progress': 0,
@@ -515,44 +286,40 @@ def convert_image_to_3d():
515
  'created_at': time.time()
516
  }
517
 
518
- # Start processing in a separate thread
519
  def process_image():
520
  thread = threading.current_thread()
521
  processing_jobs[job_id]['thread_alive'] = lambda: thread.is_alive()
522
 
523
  try:
524
- # Preprocess image with enhanced detail preservation
525
  processing_jobs[job_id]['progress'] = 5
526
  image = preprocess_image(filepath)
527
  processing_jobs[job_id]['progress'] = 10
528
 
529
- # Load model
 
 
 
 
 
530
  try:
531
- model = load_model()
532
- processing_jobs[job_id]['progress'] = 30
533
  except Exception as e:
534
  processing_jobs[job_id]['status'] = 'error'
535
  processing_jobs[job_id]['error'] = f"Error loading model: {str(e)}"
536
  return
537
 
538
- # Process image with thread-safe timeout
539
  try:
540
- def estimate_depth():
541
- # Get depth map
542
- result = model(image)
543
- depth_map = result["depth"]
544
-
545
- # Convert to numpy array if needed
546
- if isinstance(depth_map, torch.Tensor):
547
- depth_map = depth_map.cpu().numpy()
548
- elif hasattr(depth_map, 'numpy'):
549
- depth_map = depth_map.numpy()
550
- elif isinstance(depth_map, Image.Image):
551
- depth_map = np.array(depth_map)
552
-
553
- return depth_map
554
 
555
- depth_map, error = process_with_timeout(estimate_depth, [], TIMEOUT_SECONDS)
556
 
557
  if error:
558
  if isinstance(error, TimeoutError):
@@ -561,12 +328,11 @@ def convert_image_to_3d():
561
  return
562
  else:
563
  raise error
564
-
565
- processing_jobs[job_id]['progress'] = 60
566
 
567
- # Create mesh from depth map with enhanced detail handling
568
- mesh_resolution_int = int(mesh_resolution)
569
- mesh = depth_to_mesh(depth_map, image, resolution=mesh_resolution_int, detail_level=detail_level)
 
570
  processing_jobs[job_id]['progress'] = 80
571
 
572
  except Exception as e:
@@ -577,50 +343,39 @@ def convert_image_to_3d():
577
  print(error_details)
578
  return
579
 
580
- # Export based on requested format with enhanced quality settings
581
  try:
582
  if output_format == 'obj':
583
  obj_path = os.path.join(output_dir, "model.obj")
584
-
585
- # Export with normal and texture coordinates
586
  mesh.export(
587
- obj_path,
588
  file_type='obj',
589
  include_normals=True,
590
  include_texture=True
591
  )
592
-
593
- # Create a zip file with OBJ and MTL
594
  zip_path = os.path.join(output_dir, "model.zip")
595
  with zipfile.ZipFile(zip_path, 'w') as zipf:
596
  zipf.write(obj_path, arcname="model.obj")
597
  mtl_path = os.path.join(output_dir, "model.mtl")
598
  if os.path.exists(mtl_path):
599
  zipf.write(mtl_path, arcname="model.mtl")
600
-
601
- # Include texture file if it exists
602
  texture_path = os.path.join(output_dir, "model.png")
603
  if os.path.exists(texture_path):
604
  zipf.write(texture_path, arcname="model.png")
605
 
606
  processing_jobs[job_id]['result_url'] = f"/download/{job_id}"
607
  processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"
608
-
609
  elif output_format == 'glb':
610
- # Export as GLB with enhanced settings
611
  glb_path = os.path.join(output_dir, "model.glb")
612
- mesh.export(
613
- glb_path,
614
- file_type='glb'
615
- )
616
-
617
  processing_jobs[job_id]['result_url'] = f"/download/{job_id}"
618
  processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"
619
 
620
- # Update job status
621
  processing_jobs[job_id]['status'] = 'completed'
622
  processing_jobs[job_id]['progress'] = 100
623
  print(f"Job {job_id} completed successfully")
 
624
  except Exception as e:
625
  error_details = traceback.format_exc()
626
  processing_jobs[job_id]['status'] = 'error'
@@ -628,51 +383,39 @@ def convert_image_to_3d():
628
  print(f"Error exporting model for job {job_id}: {str(e)}")
629
  print(error_details)
630
 
631
- # Clean up temporary file
632
  if os.path.exists(filepath):
633
  os.remove(filepath)
634
 
635
- # Force garbage collection to free memory
636
  gc.collect()
637
- if torch.cuda.is_available():
638
- torch.cuda.empty_cache()
639
 
640
  except Exception as e:
641
- # Handle errors
642
  error_details = traceback.format_exc()
643
  processing_jobs[job_id]['status'] = 'error'
644
  processing_jobs[job_id]['error'] = f"{str(e)}\n{error_details}"
645
  print(f"Error processing job {job_id}: {str(e)}")
646
  print(error_details)
647
-
648
- # Clean up on error
649
  if os.path.exists(filepath):
650
  os.remove(filepath)
651
 
652
- # Start processing thread
653
  processing_thread = threading.Thread(target=process_image)
654
  processing_thread.daemon = True
655
  processing_thread.start()
656
 
657
- # Return job ID immediately
658
- return jsonify({"job_id": job_id}), 202 # 202 Accepted
659
 
660
  @app.route('/download/<job_id>', methods=['GET'])
661
  def download_model(job_id):
662
  if job_id not in processing_jobs or processing_jobs[job_id]['status'] != 'completed':
663
  return jsonify({"error": "Model not found or processing not complete"}), 404
664
 
665
- # Get the output directory for this job
666
  output_dir = os.path.join(RESULTS_FOLDER, job_id)
667
-
668
- # Determine file format from the job data
669
- output_format = processing_jobs[job_id].get('output_format', 'obj')
670
 
671
  if output_format == 'obj':
672
  zip_path = os.path.join(output_dir, "model.zip")
673
  if os.path.exists(zip_path):
674
  return send_file(zip_path, as_attachment=True, download_name="model.zip")
675
- else: # glb
676
  glb_path = os.path.join(output_dir, "model.glb")
677
  if os.path.exists(glb_path):
678
  return send_file(glb_path, as_attachment=True, download_name="model.glb")
@@ -684,35 +427,30 @@ def preview_model(job_id):
684
  if job_id not in processing_jobs or processing_jobs[job_id]['status'] != 'completed':
685
  return jsonify({"error": "Model not found or processing not complete"}), 404
686
 
687
- # Get the output directory for this job
688
  output_dir = os.path.join(RESULTS_FOLDER, job_id)
689
- output_format = processing_jobs[job_id].get('output_format', 'obj')
690
 
691
  if output_format == 'obj':
692
  obj_path = os.path.join(output_dir, "model.obj")
693
  if os.path.exists(obj_path):
694
  return send_file(obj_path, mimetype='model/obj')
695
- else: # glb
696
  glb_path = os.path.join(output_dir, "model.glb")
697
  if os.path.exists(glb_path):
698
  return send_file(glb_path, mimetype='model/gltf-binary')
699
 
700
  return jsonify({"error": "Model file not found"}), 404
701
 
702
- # Cleanup old jobs periodically
703
  def cleanup_old_jobs():
704
  current_time = time.time()
705
  job_ids_to_remove = []
706
 
707
  for job_id, job_data in processing_jobs.items():
708
- # Remove completed jobs after 1 hour
709
  if job_data['status'] == 'completed' and (current_time - job_data.get('created_at', 0)) > 3600:
710
  job_ids_to_remove.append(job_id)
711
- # Remove error jobs after 30 minutes
712
  elif job_data['status'] == 'error' and (current_time - job_data.get('created_at', 0)) > 1800:
713
  job_ids_to_remove.append(job_id)
714
 
715
- # Remove the jobs
716
  for job_id in job_ids_to_remove:
717
  output_dir = os.path.join(RESULTS_FOLDER, job_id)
718
  try:
@@ -721,15 +459,11 @@ def cleanup_old_jobs():
721
  shutil.rmtree(output_dir)
722
  except Exception as e:
723
  print(f"Error cleaning up job {job_id}: {str(e)}")
724
-
725
- # Remove from tracking dictionary
726
  if job_id in processing_jobs:
727
  del processing_jobs[job_id]
728
 
729
- # Schedule the next cleanup
730
- threading.Timer(300, cleanup_old_jobs).start() # Run every 5 minutes
731
 
732
- # New endpoint to get detailed information about a model
733
  @app.route('/model-info/<job_id>', methods=['GET'])
734
  def model_info(job_id):
735
  if job_id not in processing_jobs:
@@ -744,27 +478,21 @@ def model_info(job_id):
744
  "error": job.get('error')
745
  }), 200
746
 
747
- # For completed jobs, include information about the model
748
  output_dir = os.path.join(RESULTS_FOLDER, job_id)
749
  model_stats = {}
750
 
751
- # Get file size
752
  if job['output_format'] == 'obj':
753
  obj_path = os.path.join(output_dir, "model.obj")
754
  zip_path = os.path.join(output_dir, "model.zip")
755
-
756
  if os.path.exists(obj_path):
757
  model_stats['obj_size'] = os.path.getsize(obj_path)
758
-
759
  if os.path.exists(zip_path):
760
  model_stats['package_size'] = os.path.getsize(zip_path)
761
-
762
- else: # glb
763
  glb_path = os.path.join(output_dir, "model.glb")
764
  if os.path.exists(glb_path):
765
  model_stats['model_size'] = os.path.getsize(glb_path)
766
 
767
- # Return detailed info
768
  return jsonify({
769
  "status": job['status'],
770
  "model_format": job['output_format'],
@@ -778,185 +506,23 @@ def model_info(job_id):
778
  @app.route('/', methods=['GET'])
779
  def index():
780
  return jsonify({
781
- "message": "Enhanced Image to 3D API (DPT-Large Model)",
782
  "endpoints": [
783
- "/convert",
784
- "/progress/<job_id>",
785
- "/download/<job_id>",
786
  "/preview/<job_id>",
787
  "/model-info/<job_id>"
788
  ],
789
  "parameters": {
790
- "mesh_resolution": "Integer (50-200), controls mesh density",
791
  "output_format": "obj or glb",
792
- "detail_level": "low, medium, or high - controls the level of detail in the final model",
793
- "texture_quality": "low, medium, or high - controls the quality of textures"
794
  },
795
- "description": "This API creates high-quality 3D models from 2D images with enhanced detail finishing similar to Hunyuan model"
796
  }), 200
797
 
798
- # Example endpoint showing how to compare different detail levels
799
- @app.route('/detail-comparison', methods=['POST'])
800
- def compare_detail_levels():
801
- # Check if image is in the request
802
- if 'image' not in request.files:
803
- return jsonify({"error": "No image provided"}), 400
804
-
805
- file = request.files['image']
806
- if file.filename == '':
807
- return jsonify({"error": "No image selected"}), 400
808
-
809
- if not allowed_file(file.filename):
810
- return jsonify({"error": f"File type not allowed. Supported types: {', '.join(ALLOWED_EXTENSIONS)}"}), 400
811
-
812
- # Create a job ID
813
- job_id = str(uuid.uuid4())
814
- output_dir = os.path.join(RESULTS_FOLDER, job_id)
815
- os.makedirs(output_dir, exist_ok=True)
816
-
817
- # Save the uploaded file
818
- filename = secure_filename(file.filename)
819
- filepath = os.path.join(app.config['UPLOAD_FOLDER'], f"{job_id}_{filename}")
820
- file.save(filepath)
821
-
822
- # Initialize job tracking
823
- processing_jobs[job_id] = {
824
- 'status': 'processing',
825
- 'progress': 0,
826
- 'result_url': None,
827
- 'preview_url': None,
828
- 'error': None,
829
- 'output_format': 'glb', # Use GLB for comparison
830
- 'created_at': time.time(),
831
- 'comparison': True
832
- }
833
-
834
- # Process in separate thread to create 3 different detail levels
835
- def process_comparison():
836
- thread = threading.current_thread()
837
- processing_jobs[job_id]['thread_alive'] = lambda: thread.is_alive()
838
-
839
- try:
840
- # Preprocess image
841
- image = preprocess_image(filepath)
842
- processing_jobs[job_id]['progress'] = 10
843
-
844
- # Load model
845
- try:
846
- model = load_model()
847
- processing_jobs[job_id]['progress'] = 20
848
- except Exception as e:
849
- processing_jobs[job_id]['status'] = 'error'
850
- processing_jobs[job_id]['error'] = f"Error loading model: {str(e)}"
851
- return
852
-
853
- # Process image to get depth map
854
- try:
855
- depth_map = model(image)["depth"]
856
- if isinstance(depth_map, torch.Tensor):
857
- depth_map = depth_map.cpu().numpy()
858
- elif hasattr(depth_map, 'numpy'):
859
- depth_map = depth_map.numpy()
860
- elif isinstance(depth_map, Image.Image):
861
- depth_map = np.array(depth_map)
862
-
863
- processing_jobs[job_id]['progress'] = 40
864
- except Exception as e:
865
- processing_jobs[job_id]['status'] = 'error'
866
- processing_jobs[job_id]['error'] = f"Error estimating depth: {str(e)}"
867
- return
868
-
869
- # Create meshes at different detail levels
870
- result_urls = {}
871
-
872
- for detail_level in ['low', 'medium', 'high']:
873
- try:
874
- # Update progress
875
- if detail_level == 'low':
876
- processing_jobs[job_id]['progress'] = 50
877
- elif detail_level == 'medium':
878
- processing_jobs[job_id]['progress'] = 70
879
- else:
880
- processing_jobs[job_id]['progress'] = 90
881
-
882
- # Create mesh with appropriate detail level
883
- mesh_resolution = 100 # Fixed resolution for fair comparison
884
- if detail_level == 'high':
885
- mesh_resolution = 150
886
- elif detail_level == 'low':
887
- mesh_resolution = 80
888
-
889
- mesh = depth_to_mesh(depth_map, image,
890
- resolution=mesh_resolution,
891
- detail_level=detail_level)
892
-
893
- # Export as GLB
894
- model_path = os.path.join(output_dir, f"model_{detail_level}.glb")
895
- mesh.export(model_path, file_type='glb')
896
-
897
- # Add to result URLs
898
- result_urls[detail_level] = f"/compare-download/{job_id}/{detail_level}"
899
-
900
- except Exception as e:
901
- print(f"Error processing {detail_level} detail level: {str(e)}")
902
- # Continue with other detail levels even if one fails
903
-
904
- # Update job status
905
- processing_jobs[job_id]['status'] = 'completed'
906
- processing_jobs[job_id]['progress'] = 100
907
- processing_jobs[job_id]['result_urls'] = result_urls
908
- processing_jobs[job_id]['completed_at'] = time.time()
909
-
910
- # Clean up temporary file
911
- if os.path.exists(filepath):
912
- os.remove(filepath)
913
-
914
- # Force garbage collection
915
- gc.collect()
916
- if torch.cuda.is_available():
917
- torch.cuda.empty_cache()
918
-
919
- except Exception as e:
920
- # Handle errors
921
- processing_jobs[job_id]['status'] = 'error'
922
- processing_jobs[job_id]['error'] = f"Error during processing: {str(e)}"
923
-
924
- # Clean up on error
925
- if os.path.exists(filepath):
926
- os.remove(filepath)
927
-
928
- # Start processing thread
929
- processing_thread = threading.Thread(target=process_comparison)
930
- processing_thread.daemon = True
931
- processing_thread.start()
932
-
933
- # Return job ID immediately
934
- return jsonify({"job_id": job_id, "check_progress_at": f"/progress/{job_id}"}), 202
935
-
936
- @app.route('/compare-download/<job_id>/<detail_level>', methods=['GET'])
937
- def download_comparison_model(job_id, detail_level):
938
- if job_id not in processing_jobs or processing_jobs[job_id]['status'] != 'completed':
939
- return jsonify({"error": "Model not found or processing not complete"}), 404
940
-
941
- if 'comparison' not in processing_jobs[job_id] or not processing_jobs[job_id]['comparison']:
942
- return jsonify({"error": "This is not a comparison job"}), 400
943
-
944
- if detail_level not in ['low', 'medium', 'high']:
945
- return jsonify({"error": "Invalid detail level"}), 400
946
-
947
- # Get the output directory for this job
948
- output_dir = os.path.join(RESULTS_FOLDER, job_id)
949
- model_path = os.path.join(output_dir, f"model_{detail_level}.glb")
950
-
951
- if os.path.exists(model_path):
952
- return send_file(model_path, as_attachment=True, download_name=f"model_{detail_level}.glb")
953
-
954
- return jsonify({"error": "File not found"}), 404
955
-
956
  if __name__ == '__main__':
957
- # Start the cleanup thread
958
  cleanup_old_jobs()
959
-
960
- # Use port 7860 which is standard for Hugging Face Spaces
961
  port = int(os.environ.get('PORT', 7860))
962
  app.run(host='0.0.0.0', port=port)
 
15
  from flask_cors import CORS
16
  import numpy as np
17
  import trimesh
 
 
 
18
  import cv2
19
+ from transformers import AutoModel, AutoProcessor # For TripoSR
20
+ from u2net import U2NET # For background removal; install from https://github.com/xuebinqin/U-2-Net
21
+ import torchvision.transforms as T
22
 
23
  app = Flask(__name__)
24
+ CORS(app)
25
 
26
  # Configure directories
27
  UPLOAD_FOLDER = '/tmp/uploads'
 
29
  CACHE_DIR = '/tmp/huggingface'
30
  ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
31
 
32
+ # Create directories
33
  os.makedirs(UPLOAD_FOLDER, exist_ok=True)
34
  os.makedirs(RESULTS_FOLDER, exist_ok=True)
35
  os.makedirs(CACHE_DIR, exist_ok=True)
36
 
37
+ # Set Hugging Face cache
38
  os.environ['HF_HOME'] = CACHE_DIR
39
  os.environ['TRANSFORMERS_CACHE'] = os.path.join(CACHE_DIR, 'transformers')
40
  os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_DIR, 'datasets')
 
42
  app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
43
  app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max
44
 
45
+ # Job tracking
46
  processing_jobs = {}
47
 
48
  # Global model variables
49
+ u2net_model = None
50
+ triposr_model = None
51
+ triposr_processor = None
52
  model_loaded = False
53
  model_loading = False
54
 
55
+ # Configuration
56
+ TIMEOUT_SECONDS = 240 # 4 minutes max
57
+ MAX_DIMENSION = 512 # Max image dimension
58
 
 
59
  class TimeoutError(Exception):
60
  pass
61
 
 
62
  def process_with_timeout(function, args, timeout):
63
  result = [None]
64
  error = [None]
 
91
  def allowed_file(filename):
92
  return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
93
 
 
94
  def preprocess_image(image_path):
95
  with Image.open(image_path) as img:
96
  img = img.convert("RGB")
97
 
98
+ # Resize if too large
99
  if img.width > MAX_DIMENSION or img.height > MAX_DIMENSION:
 
100
  if img.width > img.height:
101
  new_width = MAX_DIMENSION
102
  new_height = int(img.height * (MAX_DIMENSION / img.width))
103
  else:
104
  new_height = MAX_DIMENSION
105
  new_width = int(img.width * (MAX_DIMENSION / img.height))
 
 
106
  img = img.resize((new_width, new_height), Image.LANCZOS)
107
 
108
+ # Apply adaptive histogram equalization
109
  img_array = np.array(img)
 
 
 
110
  if len(img_array.shape) == 3 and img_array.shape[2] == 3:
 
111
  lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
112
  l, a, b = cv2.split(lab)
 
 
113
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
114
  cl = clahe.apply(l)
 
 
115
  enhanced_lab = cv2.merge((cl, a, b))
 
 
116
  img_array = cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2RGB)
 
 
117
  img = Image.fromarray(img_array)
118
 
119
  return img
120
 
121
+ def remove_background(image):
122
+ global u2net_model
123
+ if u2net_model is None:
124
+ u2net_model = U2NET()
125
+ u2net_model.load_state_dict(torch.load('u2net.pth', map_location='cpu'))
126
+ u2net_model.eval()
127
+ u2net_model.to('cpu')
128
+
129
+ img_array = np.array(image)
130
+ img_tensor = T.ToTensor()(image.resize((320, 320))).unsqueeze(0)
131
+
132
+ with torch.no_grad():
133
+ d1, *_ = u2net_model(img_tensor)
134
+ pred = d1[:, 0, :, :]
135
+ pred = (pred - pred.min()) / (pred.max() - pred.min())
136
+ mask = (pred > 0.5).float().squeeze().numpy()
137
+
138
+ mask_img = Image.fromarray((mask * 255).astype('uint8')).resize(image.size)
139
+ mask_array = np.array(mask_img)[:, :, np.newaxis] / 255
140
+ result = img_array * mask_array + (1 - mask_array) * 255 # White background
141
+ return Image.fromarray(result.astype('uint8'))
142
+
143
  def load_model():
144
+ global triposr_model, triposr_processor, model_loaded, model_loading
145
 
146
  if model_loaded:
147
+ return triposr_model, triposr_processor
148
 
149
  if model_loading:
 
150
  while model_loading and not model_loaded:
151
  time.sleep(0.5)
152
+ return triposr_model, triposr_processor
153
 
154
  try:
155
  model_loading = True
156
+ print("Loading TripoSR model...")
 
 
 
 
157
 
158
+ model_name = "stabilityai/TripoSR"
159
  max_retries = 3
160
  retry_delay = 5
161
 
 
169
  break
170
  except Exception as e:
171
  if attempt < max_retries - 1:
172
+ print(f"Download attempt {attempt+1} failed: {str(e)}. Retrying...")
173
  time.sleep(retry_delay)
174
  retry_delay *= 2
175
  else:
176
  raise
177
 
178
+ triposr_processor = AutoProcessor.from_pretrained(model_name, cache_dir=CACHE_DIR)
179
+ triposr_model = AutoModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
180
+ triposr_model.to('cpu')
 
 
 
 
 
 
 
 
 
 
 
181
 
182
  model_loaded = True
183
+ print("TripoSR model loaded successfully on CPU")
184
+ return triposr_model, triposr_processor
185
 
186
  except Exception as e:
187
  print(f"Error loading model: {str(e)}")
 
190
  finally:
191
  model_loading = False
192
 
193
+ def optimize_mesh(mesh, detail_level='medium'):
194
+ # Simplify mesh based on detail level
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  if detail_level == 'high':
196
+ target_faces = 50000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  elif detail_level == 'medium':
198
+ target_faces = 30000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  else:
200
+ target_faces = 15000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
+ if len(mesh.faces) > target_faces:
203
+ mesh = mesh.simplify_quadric_decimation(target_faces)
 
 
 
204
 
205
+ # Fix normals
206
  mesh.fix_normals()
 
207
  return mesh
208
 
209
  @app.route('/health', methods=['GET'])
210
  def health_check():
211
  return jsonify({
212
+ "status": "healthy",
213
+ "model": "TripoSR 3D Model Generator",
214
+ "device": "cpu"
215
  }), 200
216
 
217
  @app.route('/progress/<job_id>', methods=['GET'])
 
222
  return
223
 
224
  job = processing_jobs[job_id]
 
 
225
  yield f"data: {json.dumps({'status': 'processing', 'progress': job['progress']})}\n\n"
226
 
 
227
  last_progress = job['progress']
228
  check_count = 0
229
  while job['status'] == 'processing':
230
  if job['progress'] != last_progress:
231
  yield f"data: {json.dumps({'status': 'processing', 'progress': job['progress']})}\n\n"
232
  last_progress = job['progress']
 
233
  time.sleep(0.5)
234
  check_count += 1
235
+ if check_count > 60:
 
 
236
  if 'thread_alive' in job and not job['thread_alive']():
237
  job['status'] = 'error'
238
  job['error'] = 'Processing thread died unexpectedly'
239
  break
240
  check_count = 0
241
 
 
242
  if job['status'] == 'completed':
243
  yield f"data: {json.dumps({'status': 'completed', 'progress': 100, 'result_url': job['result_url'], 'preview_url': job['preview_url']})}\n\n"
244
  else:
 
248
 
249
  @app.route('/convert', methods=['POST'])
250
  def convert_image_to_3d():
 
251
  if 'image' not in request.files:
252
  return jsonify({"error": "No image provided"}), 400
253
 
 
256
  return jsonify({"error": "No image selected"}), 400
257
 
258
  if not allowed_file(file.filename):
259
+ return jsonify({"error": f"File type not allowed: {', '.join(ALLOWED_EXTENSIONS)}"}), 400
260
 
 
261
  try:
262
+ output_format = request.form.get('output_format', 'glb').lower()
263
+ detail_level = request.form.get('detail_level', 'medium').lower()
264
+ texture_quality = request.form.get('texture_quality', 'medium').lower()
 
265
  except ValueError:
266
  return jsonify({"error": "Invalid parameter values"}), 400
267
 
 
268
  if output_format not in ['obj', 'glb']:
269
+ return jsonify({"error": "Unsupported output format: 'obj' or 'glb'"}), 400
 
 
 
 
 
 
270
 
 
271
  job_id = str(uuid.uuid4())
272
  output_dir = os.path.join(RESULTS_FOLDER, job_id)
273
  os.makedirs(output_dir, exist_ok=True)
274
 
 
275
  filename = secure_filename(file.filename)
276
  filepath = os.path.join(app.config['UPLOAD_FOLDER'], f"{job_id}_{filename}")
277
  file.save(filepath)
278
 
 
279
  processing_jobs[job_id] = {
280
  'status': 'processing',
281
  'progress': 0,
 
286
  'created_at': time.time()
287
  }
288
 
 
289
  def process_image():
290
  thread = threading.current_thread()
291
  processing_jobs[job_id]['thread_alive'] = lambda: thread.is_alive()
292
 
293
  try:
294
+ # Preprocess image
295
  processing_jobs[job_id]['progress'] = 5
296
  image = preprocess_image(filepath)
297
  processing_jobs[job_id]['progress'] = 10
298
 
299
+ # Remove background
300
+ processing_jobs[job_id]['progress'] = 20
301
+ clean_image = remove_background(image)
302
+ processing_jobs[job_id]['progress'] = 30
303
+
304
+ # Load TripoSR model
305
  try:
306
+ model, processor = load_model()
307
+ processing_jobs[job_id]['progress'] = 40
308
  except Exception as e:
309
  processing_jobs[job_id]['status'] = 'error'
310
  processing_jobs[job_id]['error'] = f"Error loading model: {str(e)}"
311
  return
312
 
313
+ # Generate 3D model
314
  try:
315
+ def generate_3d():
316
+ inputs = processor(images=clean_image, return_tensors="pt").to('cpu')
317
+ with torch.no_grad():
318
+ outputs = model(**inputs)
319
+ mesh = outputs.mesh # TripoSR outputs a trimesh object
320
+ return mesh
 
 
 
 
 
 
 
 
321
 
322
+ mesh, error = process_with_timeout(generate_3d, [], TIMEOUT_SECONDS)
323
 
324
  if error:
325
  if isinstance(error, TimeoutError):
 
328
  return
329
  else:
330
  raise error
 
 
331
 
332
+ processing_jobs[job_id]['progress'] = 70
333
+
334
+ # Optimize mesh
335
+ mesh = optimize_mesh(mesh, detail_level)
336
  processing_jobs[job_id]['progress'] = 80
337
 
338
  except Exception as e:
 
343
  print(error_details)
344
  return
345
 
346
+ # Export model
347
  try:
348
  if output_format == 'obj':
349
  obj_path = os.path.join(output_dir, "model.obj")
 
 
350
  mesh.export(
351
+ obj_path,
352
  file_type='obj',
353
  include_normals=True,
354
  include_texture=True
355
  )
 
 
356
  zip_path = os.path.join(output_dir, "model.zip")
357
  with zipfile.ZipFile(zip_path, 'w') as zipf:
358
  zipf.write(obj_path, arcname="model.obj")
359
  mtl_path = os.path.join(output_dir, "model.mtl")
360
  if os.path.exists(mtl_path):
361
  zipf.write(mtl_path, arcname="model.mtl")
 
 
362
  texture_path = os.path.join(output_dir, "model.png")
363
  if os.path.exists(texture_path):
364
  zipf.write(texture_path, arcname="model.png")
365
 
366
  processing_jobs[job_id]['result_url'] = f"/download/{job_id}"
367
  processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"
368
+
369
  elif output_format == 'glb':
 
370
  glb_path = os.path.join(output_dir, "model.glb")
371
+ mesh.export(glb_path, file_type='glb')
 
 
 
 
372
  processing_jobs[job_id]['result_url'] = f"/download/{job_id}"
373
  processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"
374
 
 
375
  processing_jobs[job_id]['status'] = 'completed'
376
  processing_jobs[job_id]['progress'] = 100
377
  print(f"Job {job_id} completed successfully")
378
+
379
  except Exception as e:
380
  error_details = traceback.format_exc()
381
  processing_jobs[job_id]['status'] = 'error'
 
383
  print(f"Error exporting model for job {job_id}: {str(e)}")
384
  print(error_details)
385
 
 
386
  if os.path.exists(filepath):
387
  os.remove(filepath)
388
 
 
389
  gc.collect()
 
 
390
 
391
  except Exception as e:
 
392
  error_details = traceback.format_exc()
393
  processing_jobs[job_id]['status'] = 'error'
394
  processing_jobs[job_id]['error'] = f"{str(e)}\n{error_details}"
395
  print(f"Error processing job {job_id}: {str(e)}")
396
  print(error_details)
 
 
397
  if os.path.exists(filepath):
398
  os.remove(filepath)
399
 
 
400
  processing_thread = threading.Thread(target=process_image)
401
  processing_thread.daemon = True
402
  processing_thread.start()
403
 
404
+ return jsonify({"job_id": job_id}), 202
 
405
 
406
  @app.route('/download/<job_id>', methods=['GET'])
407
  def download_model(job_id):
408
  if job_id not in processing_jobs or processing_jobs[job_id]['status'] != 'completed':
409
  return jsonify({"error": "Model not found or processing not complete"}), 404
410
 
 
411
  output_dir = os.path.join(RESULTS_FOLDER, job_id)
412
+ output_format = processing_jobs[job_id].get('output_format', 'glb')
 
 
413
 
414
  if output_format == 'obj':
415
  zip_path = os.path.join(output_dir, "model.zip")
416
  if os.path.exists(zip_path):
417
  return send_file(zip_path, as_attachment=True, download_name="model.zip")
418
+ else:
419
  glb_path = os.path.join(output_dir, "model.glb")
420
  if os.path.exists(glb_path):
421
  return send_file(glb_path, as_attachment=True, download_name="model.glb")
 
427
  if job_id not in processing_jobs or processing_jobs[job_id]['status'] != 'completed':
428
  return jsonify({"error": "Model not found or processing not complete"}), 404
429
 
 
430
  output_dir = os.path.join(RESULTS_FOLDER, job_id)
431
+ output_format = processing_jobs[job_id].get('output_format', 'glb')
432
 
433
  if output_format == 'obj':
434
  obj_path = os.path.join(output_dir, "model.obj")
435
  if os.path.exists(obj_path):
436
  return send_file(obj_path, mimetype='model/obj')
437
+ else:
438
  glb_path = os.path.join(output_dir, "model.glb")
439
  if os.path.exists(glb_path):
440
  return send_file(glb_path, mimetype='model/gltf-binary')
441
 
442
  return jsonify({"error": "Model file not found"}), 404
443
 
 
444
  def cleanup_old_jobs():
445
  current_time = time.time()
446
  job_ids_to_remove = []
447
 
448
  for job_id, job_data in processing_jobs.items():
 
449
  if job_data['status'] == 'completed' and (current_time - job_data.get('created_at', 0)) > 3600:
450
  job_ids_to_remove.append(job_id)
 
451
  elif job_data['status'] == 'error' and (current_time - job_data.get('created_at', 0)) > 1800:
452
  job_ids_to_remove.append(job_id)
453
 
 
454
  for job_id in job_ids_to_remove:
455
  output_dir = os.path.join(RESULTS_FOLDER, job_id)
456
  try:
 
459
  shutil.rmtree(output_dir)
460
  except Exception as e:
461
  print(f"Error cleaning up job {job_id}: {str(e)}")
 
 
462
  if job_id in processing_jobs:
463
  del processing_jobs[job_id]
464
 
465
+ threading.Timer(300, cleanup_old_jobs).start()
 
466
 
 
467
  @app.route('/model-info/<job_id>', methods=['GET'])
468
  def model_info(job_id):
469
  if job_id not in processing_jobs:
 
478
  "error": job.get('error')
479
  }), 200
480
 
 
481
  output_dir = os.path.join(RESULTS_FOLDER, job_id)
482
  model_stats = {}
483
 
 
484
  if job['output_format'] == 'obj':
485
  obj_path = os.path.join(output_dir, "model.obj")
486
  zip_path = os.path.join(output_dir, "model.zip")
 
487
  if os.path.exists(obj_path):
488
  model_stats['obj_size'] = os.path.getsize(obj_path)
 
489
  if os.path.exists(zip_path):
490
  model_stats['package_size'] = os.path.getsize(zip_path)
491
+ else:
 
492
  glb_path = os.path.join(output_dir, "model.glb")
493
  if os.path.exists(glb_path):
494
  model_stats['model_size'] = os.path.getsize(glb_path)
495
 
 
496
  return jsonify({
497
  "status": job['status'],
498
  "model_format": job['output_format'],
 
506
  @app.route('/', methods=['GET'])
507
  def index():
508
  return jsonify({
509
+ "message": "TripoSR Image to 3D API",
510
  "endpoints": [
511
+ "/convert",
512
+ "/progress/<job_id>",
513
+ "/download/<job_id>",
514
  "/preview/<job_id>",
515
  "/model-info/<job_id>"
516
  ],
517
  "parameters": {
 
518
  "output_format": "obj or glb",
519
+ "detail_level": "low, medium, or high - controls mesh density",
520
+ "texture_quality": "low, medium, or high - controls texture quality"
521
  },
522
+ "description": "Creates full 3D models from 2D images with background removal"
523
  }), 200
524
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
525
  if __name__ == '__main__':
 
526
  cleanup_old_jobs()
 
 
527
  port = int(os.environ.get('PORT', 7860))
528
  app.run(host='0.0.0.0', port=port)