mac9087 committed
Commit 220d6cc · verified · 1 Parent(s): f3a3457

Update app.py

Files changed (1)
  1. app.py +334 -111
app.py CHANGED
@@ -1,4 +1,3 @@
-
 import os
 import torch
 import time
@@ -9,26 +8,22 @@ from flask import Flask, request, jsonify, send_file, Response, stream_with_context
 from werkzeug.utils import secure_filename
 from PIL import Image
 import io
+import zipfile
 import uuid
 import traceback
 from huggingface_hub import snapshot_download
 from flask_cors import CORS
 import numpy as np
 import trimesh
+from transformers import pipeline, AutoImageProcessor, AutoModelForDepthEstimation
+from scipy.ndimage import gaussian_filter
+from scipy import interpolate
 import cv2
-try:
-    from lgm.models import LGM
-except ImportError:
-    LGM = None
-
-os.environ["CUDA_VISIBLE_DEVICES"] = ""
-torch.set_default_device("cpu")
-torch.cuda.is_available = lambda: False
-torch.cuda.device_count = lambda: 0
 
 app = Flask(__name__)
 CORS(app)
 
+# Configure directories
 UPLOAD_FOLDER = '/tmp/uploads'
 RESULTS_FOLDER = '/tmp/results'
 CACHE_DIR = '/tmp/huggingface'
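Note: the imports added in this hunk swap the LGM dependency for Hugging Face transformers depth-estimation models plus scipy post-processing. A minimal sketch of what the new `pipeline` import provides ("Intel/dpt-large" matches the diff; the input path is a placeholder assumption):

```python
# Minimal sketch: run the transformers depth-estimation pipeline on one image.
# "Intel/dpt-large" matches the diff; "input.jpg" is a placeholder assumption.
from PIL import Image
from transformers import pipeline

depth_estimator = pipeline("depth-estimation", model="Intel/dpt-large", device=-1)  # device=-1 -> CPU

img = Image.open("input.jpg").convert("RGB")
result = depth_estimator(img)
depth = result["depth"]           # PIL.Image depth map, same size as the input
depth.save("depth_preview.png")   # quick visual sanity check
```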
@@ -45,13 +40,18 @@ os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_DIR, 'datasets')
 app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
 app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
 
+# Job tracking
 processing_jobs = {}
-lgm_model = None
+
+# Model variables
+dpt_estimator = None
+depth_anything_model = None
+depth_anything_processor = None
 model_loaded = False
 model_loading = False
 
-TIMEOUT_SECONDS = 300
-MAX_DIMENSION = 512  # LGM uses 512x512 inputs
+TIMEOUT_SECONDS = 240
+MAX_DIMENSION = 518  # Depth Anything uses 518x518
 
 class TimeoutError(Exception):
     pass
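Note: later hunks call `process_with_timeout(fn, args, TIMEOUT_SECONDS)` and expect a `(result, error)` pair, but the helper sits in unchanged code and is not shown in this diff. A plausible reconstruction, assuming a daemon-thread implementation that returns the custom TimeoutError defined above:

```python
# Hypothetical reconstruction of the unshown helper; only the call signature
# and the (result, error) return convention are confirmed by the diff.
import threading

def process_with_timeout(func, args, timeout):
    result = [None]   # boxed so the worker thread can write to it
    error = [None]

    def target():
        try:
            result[0] = func(*args)
        except Exception as e:
            error[0] = e

    thread = threading.Thread(target=target, daemon=True)
    thread.start()
    thread.join(timeout)
    if thread.is_alive():
        # Worker still running: report a timeout, leave the daemon thread behind
        return None, TimeoutError(f"Processing timed out after {timeout} seconds")
    return result[0], error[0]
```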
@@ -88,108 +88,267 @@ def allowed_file(filename):
     return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
 
 def preprocess_image(image_path):
-    try:
-        img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
-        if img.shape[2] == 4:  # RGBA
-            alpha = img[:, :, 3]
-            rgb = img[:, :, :3]
-            white_bg = np.ones_like(rgb) * 255
-            mask = alpha[:, :, np.newaxis] / 255.0
-            img = rgb * mask + white_bg * (1 - mask)
-            img = img.astype(np.uint8)
-        img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LANCZOS4)
-        return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
-    except Exception as e:
-        raise Exception(f"Error preprocessing image: {str(e)}")
+    with Image.open(image_path) as img:
+        img = img.convert("RGB")
+
+        if img.width > MAX_DIMENSION or img.height > MAX_DIMENSION:
+            if img.width > img.height:
+                new_width = MAX_DIMENSION
+                new_height = int(img.height * (MAX_DIMENSION / img.width))
+            else:
+                new_height = MAX_DIMENSION
+                new_width = int(img.width * (MAX_DIMENSION / img.height))
+            img = img.resize((new_width, new_height), Image.LANCZOS)
+
+        img_array = np.array(img)
+        if len(img_array.shape) == 3 and img_array.shape[2] == 3:
+            lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
+            l, a, b = cv2.split(lab)
+            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
+            cl = clahe.apply(l)
+            enhanced_lab = cv2.merge((cl, a, b))
+            img_array = cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2RGB)
+            img = Image.fromarray(img_array)
+
+        return img
 
-def load_model():
-    global lgm_model, model_loaded, model_loading
+def load_models():
+    global dpt_estimator, depth_anything_model, depth_anything_processor, model_loaded, model_loading
 
     if model_loaded:
-        return lgm_model
+        return dpt_estimator, depth_anything_model, depth_anything_processor
 
     if model_loading:
         while model_loading and not model_loaded:
             time.sleep(0.5)
-        return lgm_model
+        return dpt_estimator, depth_anything_model, depth_anything_processor
 
     try:
         model_loading = True
-        print("Loading LGM...")
-
-        model_name = "large-gaussian-model/lgm"
+        print("Loading models...")
 
+        # DPT-Large
+        dpt_model_name = "Intel/dpt-large"
         max_retries = 3
         retry_delay = 5
         for attempt in range(max_retries):
             try:
                 snapshot_download(
-                    repo_id=model_name,
+                    repo_id=dpt_model_name,
                     cache_dir=CACHE_DIR,
                     resume_download=True,
                 )
                 break
             except Exception as e:
                 if attempt < max_retries - 1:
-                    print(f"Download attempt {attempt+1} failed: {str(e)}. Retrying...")
+                    print(f"DPT download attempt {attempt+1} failed: {str(e)}. Retrying...")
                     time.sleep(retry_delay)
                     retry_delay *= 2
                 else:
                     raise
 
-        if LGM is None:
-            raise ImportError("LGM module not available. Ensure lgm is installed from https://github.com/baaivision/LGM.")
+        dpt_estimator = pipeline(
+            "depth-estimation",
+            model=dpt_model_name,
+            device=-1,  # CPU
+            cache_dir=CACHE_DIR
+        )
+        print("DPT-Large loaded")
+        gc.collect()
+
+        # Depth Anything
+        da_model_name = "LiheYoung/depth-anything-v2-small"
+        for attempt in range(max_retries):
+            try:
+                snapshot_download(
+                    repo_id=da_model_name,
+                    cache_dir=CACHE_DIR,
+                    resume_download=True,
+                )
+                break
+            except Exception as e:
+                if attempt < max_retries - 1:
+                    print(f"Depth Anything download attempt {attempt+1} failed: {str(e)}. Retrying...")
+                    time.sleep(retry_delay)
+                    retry_delay *= 2
+                else:
+                    raise
 
-        lgm_model = LGM.from_pretrained(
-            model_name,
-            cache_dir=CACHE_DIR,
-            device="cpu",
+        depth_anything_processor = AutoImageProcessor.from_pretrained(
+            da_model_name,
+            cache_dir=CACHE_DIR
         )
+        depth_anything_model = AutoModelForDepthEstimation.from_pretrained(
+            da_model_name,
+            cache_dir=CACHE_DIR
+        ).to("cpu")
 
         model_loaded = True
-        print("LGM loaded successfully on CPU")
-        return lgm_model
+        print("Depth Anything loaded")
+        return dpt_estimator, depth_anything_model, depth_anything_processor
 
     except Exception as e:
-        print(f"Error loading model: {str(e)}")
+        print(f"Error loading models: {str(e)}")
         print(traceback.format_exc())
         raise
     finally:
         model_loading = False
 
-def generate_3d_model(image, detail_level):
-    try:
-        resolution = {'low': 256, 'medium': 512, 'high': 1024}
-        res = resolution[detail_level]
-
-        with torch.no_grad():
-            mesh_data = lgm_model.generate_mesh(
-                image,
-                resolution=res,
-                device="cpu"
-            )
-
-        vertices = np.array(mesh_data['vertices'])
-        faces = np.array(mesh_data['faces'])
-        vertex_colors = np.array(mesh_data['vertex_colors']) if 'vertex_colors' in mesh_data else None
-
-        trimesh_mesh = trimesh.Trimesh(
-            vertices=vertices,
-            faces=faces,
-            vertex_colors=vertex_colors
-        )
-
-        trimesh_mesh.apply_transform(trimesh.transformations.rotation_matrix(np.pi, [1, 0, 0]))
-
-        return trimesh_mesh
-    except Exception as e:
-        raise Exception(f"Error generating 3D model: {str(e)}")
+def fuse_depth_maps(dpt_depth, da_depth, detail_level='medium'):
+    if isinstance(dpt_depth, Image.Image):
+        dpt_depth = np.array(dpt_depth)
+    if isinstance(da_depth, torch.Tensor):
+        da_depth = da_depth.cpu().numpy()
+    if len(dpt_depth.shape) > 2:
+        dpt_depth = np.mean(dpt_depth, axis=2)
+    if len(da_depth.shape) > 2:
+        da_depth = np.mean(da_depth, axis=2)
+
+    # Resize to match
+    if dpt_depth.shape != da_depth.shape:
+        da_depth = cv2.resize(da_depth, (dpt_depth.shape[1], dpt_depth.shape[0]), interpolation=cv2.INTER_CUBIC)
+
+    # Normalize
+    p_low_dpt, p_high_dpt = np.percentile(dpt_depth, [1, 99])
+    p_low_da, p_high_da = np.percentile(da_depth, [1, 99])
+    dpt_depth = np.clip((dpt_depth - p_low_dpt) / (p_high_dpt - p_low_dpt), 0, 1) if p_high_dpt > p_low_dpt else dpt_depth
+    da_depth = np.clip((da_depth - p_low_da) / (p_high_da - p_low_da), 0, 1) if p_high_da > p_low_da else da_depth
+
+    # Edge-aware fusion
+    if detail_level == 'high':
+        weight_da = 0.7  # Favor Depth Anything for details
+        edges = cv2.Canny((da_depth * 255).astype(np.uint8), 50, 150)
+        edge_mask = (edges > 0).astype(np.float32)
+        dpt_weight = gaussian_filter(1 - edge_mask, sigma=1.0)
+        da_weight = gaussian_filter(edge_mask, sigma=1.0)
+        fused_depth = dpt_weight * dpt_depth + da_weight * da_depth * weight_da + (1 - weight_da) * dpt_depth
+    else:
+        weight_da = 0.5 if detail_level == 'medium' else 0.3
+        fused_depth = (1 - weight_da) * dpt_depth + weight_da * da_depth
+
+    fused_depth = np.clip(fused_depth, 0, 1)
+    return fused_depth
+
+def enhance_depth_map(depth_map, detail_level='medium'):
+    enhanced_depth = depth_map.copy().astype(np.float32)
+    p_low, p_high = np.percentile(enhanced_depth, [1, 99])
+    enhanced_depth = np.clip(enhanced_depth, p_low, p_high)
+    enhanced_depth = (enhanced_depth - p_low) / (p_high - p_low) if p_high > p_low else enhanced_depth
+
+    if detail_level == 'high':
+        blurred = gaussian_filter(enhanced_depth, sigma=1.5)
+        mask = enhanced_depth - blurred
+        enhanced_depth = enhanced_depth + 1.5 * mask
+        smooth1 = gaussian_filter(enhanced_depth, sigma=0.5)
+        smooth2 = gaussian_filter(enhanced_depth, sigma=2.0)
+        edge_mask = enhanced_depth - smooth2
+        enhanced_depth = smooth1 + 1.2 * edge_mask
+    elif detail_level == 'medium':
+        blurred = gaussian_filter(enhanced_depth, sigma=1.0)
+        mask = enhanced_depth - blurred
+        enhanced_depth = enhanced_depth + 0.8 * mask
+        enhanced_depth = gaussian_filter(enhanced_depth, sigma=0.5)
+    else:
+        enhanced_depth = gaussian_filter(enhanced_depth, sigma=0.7)
+
+    enhanced_depth = np.clip(enhanced_depth, 0, 1)
+    return enhanced_depth
+
+def depth_to_mesh(depth_map, image, resolution=100, detail_level='medium'):
+    enhanced_depth = enhance_depth_map(depth_map, detail_level)
+    h, w = enhanced_depth.shape
+    x = np.linspace(0, w-1, resolution)
+    y = np.linspace(0, h-1, resolution)
+    x_grid, y_grid = np.meshgrid(x, y)
+
+    interp_func = interpolate.RectBivariateSpline(
+        np.arange(h), np.arange(w), enhanced_depth, kx=3, ky=3
+    )
+    z_values = interp_func(y, x, grid=True)
+
+    if detail_level == 'high':
+        dx = np.gradient(z_values, axis=1)
+        dy = np.gradient(z_values, axis=0)
+        gradient_magnitude = np.sqrt(dx**2 + dy**2)
+        edge_mask = np.clip(gradient_magnitude * 5, 0, 0.2)
+        z_values = z_values + edge_mask * (z_values - gaussian_filter(z_values, sigma=1.0))
+
+    z_min, z_max = np.percentile(z_values, [2, 98])
+    z_values = (z_values - z_min) / (z_max - z_min) if z_max > z_min else z_values
+    z_scaling = 2.5 if detail_level == 'high' else 2.0 if detail_level == 'medium' else 1.5
+    z_values = z_values * z_scaling
+
+    x_grid = (x_grid / w - 0.5) * 2.0
+    y_grid = (y_grid / h - 0.5) * 2.0
+    vertices = np.vstack([x_grid.flatten(), -y_grid.flatten(), -z_values.flatten()]).T
+
+    faces = []
+    for i in range(resolution-1):
+        for j in range(resolution-1):
+            p1 = i * resolution + j
+            p2 = i * resolution + (j + 1)
+            p3 = (i + 1) * resolution + j
+            p4 = (i + 1) * resolution + (j + 1)
+            v1 = vertices[p1]
+            v2 = vertices[p2]
+            v3 = vertices[p3]
+            v4 = vertices[p4]
+            norm1 = np.cross(v2-v1, v4-v1)
+            norm2 = np.cross(v4-v3, v1-v3)
+            if np.dot(norm1, norm2) >= 0:
+                faces.append([p1, p2, p4])
+                faces.append([p1, p4, p3])
+            else:
+                faces.append([p1, p2, p3])
+                faces.append([p2, p4, p3])
+
+    faces = np.array(faces)
+    mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
+
+    if image:
+        img_array = np.array(image)
+        vertex_colors = np.zeros((vertices.shape[0], 4), dtype=np.uint8)
+        for i in range(resolution):
+            for j in range(resolution):
+                img_x = j * (img_array.shape[1] - 1) / (resolution - 1)
+                img_y = i * (img_array.shape[0] - 1) / (resolution - 1)
+                x0, y0 = int(img_x), int(img_y)
+                x1, y1 = min(x0 + 1, img_array.shape[1] - 1), min(y0 + 1, img_array.shape[0] - 1)
+                wx = img_x - x0
+                wy = img_y - y0
+                vertex_idx = i * resolution + j
+                if len(img_array.shape) == 3 and img_array.shape[2] == 3:
+                    r = int((1-wx)*(1-wy)*img_array[y0, x0, 0] + wx*(1-wy)*img_array[y0, x1, 0] +
+                            (1-wx)*wy*img_array[y1, x0, 0] + wx*wy*img_array[y1, x1, 0])
+                    g = int((1-wx)*(1-wy)*img_array[y0, x0, 1] + wx*(1-wy)*img_array[y0, x1, 1] +
+                            (1-wx)*wy*img_array[y1, x0, 1] + wx*wy*img_array[y1, x1, 1])
+                    b = int((1-wx)*(1-wy)*img_array[y0, x0, 2] + wx*(1-wy)*img_array[y0, x1, 2] +
+                            (1-wx)*wy*img_array[y1, x0, 2] + wx*wy*img_array[y1, x1, 2])
+                    vertex_colors[vertex_idx, :3] = [r, g, b]
+                    vertex_colors[vertex_idx, 3] = 255
+                elif len(img_array.shape) == 3 and img_array.shape[2] == 4:
+                    for c in range(4):
+                        vertex_colors[vertex_idx, c] = int((1-wx)*(1-wy)*img_array[y0, x0, c] +
+                                                           wx*(1-wy)*img_array[y0, x1, c] +
+                                                           (1-wx)*wy*img_array[y1, x0, c] +
+                                                           wx*wy*img_array[y1, x1, c])
+                else:
+                    gray = int((1-wx)*(1-wy)*img_array[y0, x0] + wx*(1-wy)*img_array[y0, x1] +
+                               (1-wx)*wy*img_array[y1, x0] + wx*wy*img_array[y1, x1])
+                    vertex_colors[vertex_idx, :3] = [gray, gray, gray]
+                    vertex_colors[vertex_idx, 3] = 255
+        mesh.visual.vertex_colors = vertex_colors
+
+    if detail_level != 'high':
+        mesh = mesh.smoothed(method='laplacian', iterations=1)
+    mesh.fix_normals()
+    return mesh
 
 @app.route('/health', methods=['GET'])
 def health_check():
     return jsonify({
         "status": "healthy",
-        "model": "LGM",
+        "model": "DPT-Large + Depth Anything",
         "device": "cpu"
     }), 200
 
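Note: the core of fuse_depth_maps() above is percentile normalization followed by a weighted blend ('high' additionally gates the weights with a Canny edge mask of the Depth Anything map). A toy run of that math on synthetic arrays, for intuition:

```python
# Toy illustration of the fusion math: percentile-normalize two depth maps on
# different scales, then blend with the 'medium' weighting. Data is synthetic.
import numpy as np

rng = np.random.default_rng(0)
dpt = rng.random((4, 4)) * 100.0   # stand-in for the DPT-Large depth map
da = rng.random((4, 4)) * 7.0      # stand-in for Depth Anything (other scale)

def pct_norm(d):
    lo, hi = np.percentile(d, [1, 99])
    return np.clip((d - lo) / (hi - lo), 0, 1) if hi > lo else d

weight_da = 0.5                    # 'medium' detail level in the diff
fused = (1 - weight_da) * pct_norm(dpt) + weight_da * pct_norm(da)
print(fused.min(), fused.max())    # stays within [0, 1] regardless of input scale
```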
@@ -201,19 +360,16 @@ def progress(job_id):
         return
 
     job = processing_jobs[job_id]
-
     yield f"data: {json.dumps({'status': 'processing', 'progress': job['progress']})}\n\n"
 
-    last_progress = job['progress']
+    last_progress = job['progress']
     check_count = 0
     while job['status'] == 'processing':
         if job['progress'] != last_progress:
             yield f"data: {json.dumps({'status': 'processing', 'progress': job['progress']})}\n\n"
             last_progress = job['progress']
-
         time.sleep(0.5)
         check_count += 1
-
         if check_count > 60:
             if 'thread_alive' in job and not job['thread_alive']():
                 job['status'] = 'error'
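Note: the /progress route streams server-sent events, one `data: {...}` line per progress change. A client sketch; the host URL is an assumption, and it assumes the stream ends with a terminal completed/error event, as the job states suggest:

```python
# SSE client sketch for /progress/<job_id> using requests' streaming API.
import json
import requests

job_id = "your-job-id"  # as returned by /convert
url = f"http://localhost:7860/progress/{job_id}"  # host is an assumption

with requests.get(url, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if line and line.startswith("data: "):
            event = json.loads(line[len("data: "):])
            print(event)
            if event.get("status") in ("completed", "error"):
                break
```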
@@ -241,13 +397,20 @@ def convert_image_to_3d():
         return jsonify({"error": f"File type not allowed. Supported types: {', '.join(ALLOWED_EXTENSIONS)}"}), 400
 
     try:
+        mesh_resolution = min(int(request.form.get('mesh_resolution', 100)), 150)
         output_format = request.form.get('output_format', 'glb').lower()
         detail_level = request.form.get('detail_level', 'medium').lower()
+        texture_quality = request.form.get('texture_quality', 'medium').lower()
     except ValueError:
         return jsonify({"error": "Invalid parameter values"}), 400
 
-    if output_format not in ['glb', 'obj']:
-        return jsonify({"error": "Supported formats: glb, obj"}), 400
+    if output_format not in ['obj', 'glb']:
+        return jsonify({"error": "Unsupported output format. Use 'obj' or 'glb'"}), 400
+
+    if detail_level == 'high':
+        mesh_resolution = min(int(mesh_resolution * 1.5), 150)
+    elif detail_level == 'low':
+        mesh_resolution = max(int(mesh_resolution * 0.7), 50)
 
     job_id = str(uuid.uuid4())
     output_dir = os.path.join(RESULTS_FOLDER, job_id)
@@ -277,18 +440,37 @@ def convert_image_to_3d():
         processing_jobs[job_id]['progress'] = 10
 
         try:
-            model = load_model()
+            dpt_model, da_model, da_processor = load_models()
             processing_jobs[job_id]['progress'] = 30
         except Exception as e:
             processing_jobs[job_id]['status'] = 'error'
-            processing_jobs[job_id]['error'] = f"Error loading model: {str(e)}"
+            processing_jobs[job_id]['error'] = f"Error loading models: {str(e)}"
             return
 
         try:
-            def generate_3d():
-                return generate_3d_model(image, detail_level)
+            def estimate_depth():
+                with torch.no_grad():
+                    # DPT-Large
+                    dpt_result = dpt_model(image)
+                    dpt_depth = dpt_result["depth"]
+
+                    # Depth Anything
+                    inputs = da_processor(images=image, return_tensors="pt")
+                    inputs = {k: v.to("cpu") for k, v in inputs.items()}
+                    outputs = da_model(**inputs)
+                    da_depth = outputs.predicted_depth.squeeze()
+                    da_depth = torch.nn.functional.interpolate(
+                        da_depth.unsqueeze(0).unsqueeze(0),
+                        size=(image.height, image.width),
+                        mode='bicubic',
+                        align_corners=False
+                    ).squeeze()
+
+                    # Fuse depth maps
+                    fused_depth = fuse_depth_maps(dpt_depth, da_depth, detail_level)
+                    return fused_depth
 
-            mesh, error = process_with_timeout(generate_3d, [], TIMEOUT_SECONDS)
+            fused_depth, error = process_with_timeout(estimate_depth, [], TIMEOUT_SECONDS)
 
             if error:
                 if isinstance(error, TimeoutError):
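Note: the fused depth map produced here is handed to depth_to_mesh() (defined in the earlier hunk) in the next hunk. Its triangulation turns an r×r vertex grid into 2(r-1)² triangles; a stripped-down sketch with a fixed quad diagonal (the real function picks the diagonal per quad by comparing face normals):

```python
# Toy grid triangulation mirroring depth_to_mesh(): r x r vertices -> 2*(r-1)^2
# triangles. Heights are synthetic; the fixed diagonal is a simplification.
import numpy as np
import trimesh

r = 3
xs, ys = np.meshgrid(np.linspace(-1, 1, r), np.linspace(-1, 1, r))
zs = 0.2 * (xs ** 2 + ys ** 2)                  # synthetic "depth"
vertices = np.column_stack([xs.ravel(), -ys.ravel(), -zs.ravel()])

faces = []
for i in range(r - 1):
    for j in range(r - 1):
        p1, p2 = i * r + j, i * r + (j + 1)              # top-left, top-right
        p3, p4 = (i + 1) * r + j, (i + 1) * r + (j + 1)  # bottom-left, bottom-right
        faces += [[p1, p2, p4], [p1, p4, p3]]            # two triangles per quad

mesh = trimesh.Trimesh(vertices=vertices, faces=np.array(faces))
print(len(mesh.vertices), len(mesh.faces))               # 9 vertices, 8 faces
```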
@@ -298,18 +480,45 @@ def convert_image_to_3d():
                 else:
                     raise error
 
+            processing_jobs[job_id]['progress'] = 60
+            mesh_resolution_int = int(mesh_resolution)
+            mesh = depth_to_mesh(fused_depth, image, resolution=mesh_resolution_int, detail_level=detail_level)
             processing_jobs[job_id]['progress'] = 80
 
-            file_path = os.path.join(output_dir, f"model.{output_format}")
-            mesh.export(file_path, file_type=output_format)
+            if output_format == 'obj':
+                obj_path = os.path.join(output_dir, "model.obj")
+                mesh.export(
+                    obj_path,
+                    file_type='obj',
+                    include_normals=True,
+                    include_texture=True
+                )
+                zip_path = os.path.join(output_dir, "model.zip")
+                with zipfile.ZipFile(zip_path, 'w') as zipf:
+                    zipf.write(obj_path, arcname="model.obj")
+                    mtl_path = os.path.join(output_dir, "model.mtl")
+                    if os.path.exists(mtl_path):
+                        zipf.write(mtl_path, arcname="model.mtl")
+                    texture_path = os.path.join(output_dir, "model.png")
+                    if os.path.exists(texture_path):
+                        zipf.write(texture_path, arcname="model.png")
+
+                processing_jobs[job_id]['result_url'] = f"/download/{job_id}"
+                processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"
 
-            processing_jobs[job_id]['result_url'] = f"/download/{job_id}"
-            processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"
+            elif output_format == 'glb':
+                glb_path = os.path.join(output_dir, "model.glb")
+                mesh.export(
+                    glb_path,
+                    file_type='glb'
+                )
+                processing_jobs[job_id]['result_url'] = f"/download/{job_id}"
+                processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"
 
             processing_jobs[job_id]['status'] = 'completed'
             processing_jobs[job_id]['progress'] = 100
             print(f"Job {job_id} completed")
-
+
         except Exception as e:
             error_details = traceback.format_exc()
             processing_jobs[job_id]['status'] = 'error'
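Note: for the 'obj' branch above, the downloadable artifact is a zip bundling model.obj plus model.mtl / model.png when they exist. A small consumer-side sketch (local paths are assumptions):

```python
# Unpack the OBJ bundle produced by the 'obj' export branch and load the mesh.
import zipfile
import trimesh

with zipfile.ZipFile("model.zip") as zf:
    zf.extractall("unpacked")          # model.obj, plus optional .mtl/.png

mesh = trimesh.load("unpacked/model.obj")
print(mesh.vertices.shape, mesh.faces.shape)
```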
@@ -320,7 +529,6 @@ def convert_image_to_3d():
 
             if os.path.exists(filepath):
                 os.remove(filepath)
-
             gc.collect()
 
         except Exception as e:
@@ -329,7 +537,6 @@ def convert_image_to_3d():
             processing_jobs[job_id]['error'] = f"{str(e)}\n{error_details}"
             print(f"Error processing job {job_id}: {str(e)}")
             print(error_details)
-
             if os.path.exists(filepath):
                 os.remove(filepath)
 
@@ -345,11 +552,16 @@ def download_model(job_id):
         return jsonify({"error": "Model not found or processing not complete"}), 404
 
     output_dir = os.path.join(RESULTS_FOLDER, job_id)
-    output_format = processing_jobs[job_id]['output_format']
-    file_path = os.path.join(output_dir, f"model.{output_format}")
+    output_format = processing_jobs[job_id].get('output_format', 'glb')
 
-    if os.path.exists(file_path):
-        return send_file(file_path, as_attachment=True, download_name=f"model.{output_format}")
+    if output_format == 'obj':
+        zip_path = os.path.join(output_dir, "model.zip")
+        if os.path.exists(zip_path):
+            return send_file(zip_path, as_attachment=True, download_name="model.zip")
+    else:
+        glb_path = os.path.join(output_dir, "model.glb")
+        if os.path.exists(glb_path):
+            return send_file(glb_path, as_attachment=True, download_name="model.glb")
 
     return jsonify({"error": "File not found"}), 404
 
@@ -359,14 +571,16 @@ def preview_model(job_id):
         return jsonify({"error": "Model not found or processing not complete"}), 404
 
     output_dir = os.path.join(RESULTS_FOLDER, job_id)
-    output_format = processing_jobs[job_id]['output_format']
-    file_path = os.path.join(output_dir, f"model.{output_format}")
+    output_format = processing_jobs[job_id].get('output_format', 'glb')
 
-    if os.path.exists(file_path):
-        if output_format == 'glb':
-            return send_file(file_path, mimetype='model/gltf-binary')
-        else:
-            return send_file(file_path, mimetype='text/plain')
+    if output_format == 'obj':
+        obj_path = os.path.join(output_dir, "model.obj")
+        if os.path.exists(obj_path):
+            return send_file(obj_path, mimetype='model/obj')
+    else:
+        glb_path = os.path.join(output_dir, "model.glb")
+        if os.path.exists(glb_path):
+            return send_file(glb_path, mimetype='model/gltf-binary')
 
     return jsonify({"error": "File not found"}), 404
 
@@ -409,16 +623,23 @@ def model_info(job_id):
         }), 200
 
     output_dir = os.path.join(RESULTS_FOLDER, job_id)
-    output_format = job['output_format']
     model_stats = {}
 
-    file_path = os.path.join(output_dir, f"model.{output_format}")
-    if os.path.exists(file_path):
-        model_stats['model_size'] = os.path.getsize(file_path)
+    if job['output_format'] == 'obj':
+        obj_path = os.path.join(output_dir, "model.obj")
+        zip_path = os.path.join(output_dir, "model.zip")
+        if os.path.exists(obj_path):
+            model_stats['obj_size'] = os.path.getsize(obj_path)
+        if os.path.exists(zip_path):
+            model_stats['package_size'] = os.path.getsize(zip_path)
+    else:
+        glb_path = os.path.join(output_dir, "model.glb")
+        if os.path.exists(glb_path):
+            model_stats['model_size'] = os.path.getsize(glb_path)
 
     return jsonify({
         "status": job['status'],
-        "model_format": output_format,
+        "model_format": job['output_format'],
         "download_url": job['result_url'],
         "preview_url": job['preview_url'],
         "model_stats": model_stats,
@@ -429,7 +650,7 @@ def model_info(job_id):
 @app.route('/', methods=['GET'])
 def index():
     return jsonify({
-        "message": "Image to 3D API (LGM)",
+        "message": "Image to 3D API (DPT-Large + Depth Anything)",
         "endpoints": [
             "/convert",
             "/progress/<job_id>",
@@ -438,13 +659,15 @@ def index():
             "/model-info/<job_id>"
         ],
         "parameters": {
-            "output_format": "glb or obj",
-            "detail_level": "low, medium, or high"
+            "mesh_resolution": "Integer (50-150)",
+            "output_format": "obj or glb",
+            "detail_level": "low, medium, or high",
+            "texture_quality": "low, medium, or high"
         },
-        "description": "Creates 3D models from 2D images using LGM."
+        "description": "Creates high-quality 3D models from 2D images using DPT-Large and Depth Anything."
     }), 200
 
 if __name__ == '__main__':
     cleanup_old_jobs()
     port = int(os.environ.get('PORT', 7860))
-    app.run(host='0.0.0.0', port=port)
+    app.run(host='0.0.0.0', port=port)
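Note: taken together, the endpoints define a simple upload/poll/download flow. An end-to-end client sketch; the host, the image form-field name, and the job-id response key are assumptions based on the routes shown:

```python
# End-to-end usage sketch: POST /convert, poll /model-info, fetch the result.
import time
import requests

BASE = "http://localhost:7860"  # assumption: local run on the default port

with open("photo.jpg", "rb") as f:
    r = requests.post(
        f"{BASE}/convert",
        files={"image": f},  # assumption: form field name
        data={"output_format": "glb", "detail_level": "medium", "mesh_resolution": 100},
    )
job_id = r.json()["job_id"]  # assumption: key under which /convert returns the id

while True:
    info = requests.get(f"{BASE}/model-info/{job_id}").json()
    if info["status"] in ("completed", "error"):
        break
    time.sleep(1)

if info["status"] == "completed":
    blob = requests.get(f"{BASE}{info['download_url']}")
    with open("model.glb", "wb") as out:
        out.write(blob.content)
```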