mac9087 commited on
Commit
64188d6
·
verified ·
1 Parent(s): ffe4279

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -156
app.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import os
3
  import torch
4
  import time
@@ -15,15 +14,12 @@ from huggingface_hub import snapshot_download
15
  from flask_cors import CORS
16
  import numpy as np
17
  import trimesh
18
- from transformers import pipeline
19
- from scipy.ndimage import gaussian_filter
20
- import open3d as o3d
21
  import cv2
22
 
23
  # Force CPU usage
24
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
25
  torch.set_default_device("cpu")
26
- # Patch PyTorch to disable CUDA initialization
27
  torch.cuda.is_available = lambda: False
28
  torch.cuda.device_count = lambda: 0
29
 
@@ -36,12 +32,12 @@ RESULTS_FOLDER = '/tmp/results'
36
  CACHE_DIR = '/tmp/huggingface'
37
  ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
38
 
39
- # Create necessary directories
40
  os.makedirs(UPLOAD_FOLDER, exist_ok=True)
41
  os.makedirs(RESULTS_FOLDER, exist_ok=True)
42
  os.makedirs(CACHE_DIR, exist_ok=True)
43
 
44
- # Set Hugging Face cache environment variables
45
  os.environ['HF_HOME'] = CACHE_DIR
46
  os.environ['TRANSFORMERS_CACHE'] = os.path.join(CACHE_DIR, 'transformers')
47
  os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_DIR, 'datasets')
@@ -49,23 +45,21 @@ os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_DIR, 'datasets')
49
  app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
50
  app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
51
 
52
- # Job tracking dictionary
53
  processing_jobs = {}
54
 
55
- # Global model variables
56
- depth_pipeline = None
57
  model_loaded = False
58
  model_loading = False
59
 
60
- # Configuration for processing
61
- TIMEOUT_SECONDS = 240 # 4 minutes max for Depth-Anything on CPU
62
- MAX_DIMENSION = 512 # Depth-Anything expects 512x512
63
 
64
- # TimeoutError for handling timeouts
65
  class TimeoutError(Exception):
66
  pass
67
 
68
- # Thread-safe timeout implementation
69
  def process_with_timeout(function, args, timeout):
70
  result = [None]
71
  error = [None]
@@ -81,7 +75,6 @@ def process_with_timeout(function, args, timeout):
81
  thread = threading.Thread(target=target)
82
  thread.daemon = True
83
  thread.start()
84
-
85
  thread.join(timeout)
86
 
87
  if not completed[0]:
@@ -98,76 +91,44 @@ def process_with_timeout(function, args, timeout):
98
  def allowed_file(filename):
99
  return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
100
 
101
- # Image preprocessing: Remove background using cv2
102
  def preprocess_image(image_path):
103
  try:
104
- # Load image
105
  with Image.open(image_path) as img:
106
- # Convert to RGB or handle transparency
107
  if img.mode == 'RGBA':
108
- # Use alpha channel as initial mask
109
- img_array = np.array(img)
110
- alpha = img_array[:, :, 3]
111
- img_rgb = img_array[:, :, :3]
112
- else:
113
- img_rgb = np.array(img.convert('RGB'))
114
- alpha = None
115
-
116
- # Resize to 512x512
117
- img_rgb = cv2.resize(img_rgb, (512, 512), interpolation=cv2.INTER_LANCZOS4)
118
-
119
- # Convert to grayscale
120
- gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
121
-
122
- # Adaptive thresholding for initial mask
123
- thresh = cv2.adaptiveThreshold(
124
- gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2
125
- )
126
-
127
- # If alpha channel exists, combine with threshold
128
- if alpha is not None:
129
- alpha_resized = cv2.resize(alpha, (512, 512), interpolation=cv2.INTER_LANCZOS4)
130
- thresh = cv2.bitwise_and(thresh, alpha_resized)
131
-
132
- # Refine with GrabCut
133
- mask = np.zeros((512, 512), np.uint8)
134
- mask[thresh == 255] = cv2.GC_PR_FGD # Probable foreground
135
- mask[thresh == 0] = cv2.GC_PR_BGD # Probable background
136
-
137
- bgdModel = np.zeros((1, 65), np.float64)
138
- fgdModel = np.zeros((1, 65), np.float64)
139
-
140
- rect = (10, 10, 492, 492) # ROI for GrabCut
141
- cv2.grabCut(img_rgb, mask, rect, bgdModel, fgdModel, 5, cv2.GC_INIT_WITH_MASK)
142
-
143
- # Create final mask (foreground = 1, background = 0)
144
- mask2 = np.where((mask == cv2.GC_FGD) | (mask == cv2.GC_PR_FGD), 255, 0).astype('uint8')
145
 
146
- # Apply mask to image
147
- img_foreground = cv2.bitwise_and(img_rgb, img_rgb, mask=mask2)
 
 
 
148
 
149
- return Image.fromarray(img_foreground)
150
  except Exception as e:
151
  raise Exception(f"Error preprocessing image: {str(e)}")
152
 
153
  def load_model():
154
- global depth_pipeline, model_loaded, model_loading
155
 
156
  if model_loaded:
157
- return depth_pipeline
158
 
159
  if model_loading:
160
  while model_loading and not model_loaded:
161
  time.sleep(0.5)
162
- return depth_pipeline
163
 
164
  try:
165
  model_loading = True
166
- print("Starting model loading...")
167
 
168
- model_name = "LiheYoung/depth-anything-small-hf"
169
 
170
- # Download model with retry mechanism
171
  max_retries = 3
172
  retry_delay = 5
173
  for attempt in range(max_retries):
@@ -180,24 +141,23 @@ def load_model():
180
  break
181
  except Exception as e:
182
  if attempt < max_retries - 1:
183
- print(f"Download attempt {attempt+1} failed: {str(e)}. Retrying in {retry_delay} seconds...")
184
  time.sleep(retry_delay)
185
  retry_delay *= 2
186
  else:
187
  raise
188
 
189
- # Load Depth-Anything pipeline
190
- depth_pipeline = pipeline(
191
- "depth-estimation",
192
- model=model_name,
193
  cache_dir=CACHE_DIR,
194
- device=-1, # Force CPU
195
  torch_dtype=torch.float32,
196
  )
 
197
 
198
  model_loaded = True
199
- print("Model loaded successfully on CPU")
200
- return depth_pipeline
201
 
202
  except Exception as e:
203
  print(f"Error loading model: {str(e)}")
@@ -206,79 +166,42 @@ def load_model():
206
  finally:
207
  model_loading = False
208
 
209
- def depth_to_point_cloud(depth_map, image, detail_level):
210
- # Parameters based on detail level
211
- downsample_factors = {'low': 4, 'medium': 2, 'high': 1}
212
- downsample = downsample_factors[detail_level]
213
-
214
- # Convert image and depth to numpy
215
- img_array = np.array(image)
216
- depth_array = np.array(depth_map)
217
-
218
- # Downsample for performance
219
- if downsample > 1:
220
- depth_array = depth_array[::downsample, ::downsample]
221
- img_array = img_array[::downsample, ::downsample]
222
-
223
- # Normalize depth
224
- depth_array = gaussian_filter(depth_array, sigma=1)
225
- depth_array = (depth_array - depth_array.min()) / (depth_array.max() - depth_array.min() + 1e-8)
226
-
227
- # Create point cloud
228
- h, w = depth_array.shape
229
- x, y = np.meshgrid(np.arange(w), np.arange(h))
230
-
231
- # Camera intrinsics (assumed focal length)
232
- fx = fy = w * 0.5
233
- cx, cy = w / 2, h / 2
234
-
235
- # Convert to 3D coordinates (Z-up for Unity)
236
- z = depth_array
237
- x = (x - cx) * z / fx
238
- y = -(y - cy) * z / fy # Flip y-axis to correct orientation
239
-
240
- points = np.stack([x, y, z], axis=-1).reshape(-1, 3)
241
- colors = img_array.reshape(-1, 3) / 255.0
242
-
243
- # Filter out invalid points (tighter range for foreground)
244
- mask = (z.reshape(-1) > 0.2) & (z.reshape(-1) < 0.8)
245
- points = points[mask]
246
- colors = colors[mask]
247
-
248
- # Create Open3D point cloud
249
- pcd = o3d.geometry.PointCloud()
250
- pcd.points = o3d.utility.Vector3dVector(points)
251
- pcd.colors = o3d.utility.Vector3dVector(colors)
252
-
253
- # Estimate normals
254
- pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.1, max_nn=30))
255
-
256
- # Poisson surface reconstruction
257
- mesh, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
258
- pcd, depth=8 if detail_level == 'high' else 6
259
- )
260
-
261
- # Convert to trimesh
262
- vertices = np.asarray(mesh.vertices)
263
- faces = np.asarray(mesh.triangles)
264
- vertex_colors = np.asarray(mesh.vertex_colors)
265
-
266
- trimesh_mesh = trimesh.Trimesh(
267
- vertices=vertices,
268
- faces=faces,
269
- vertex_colors=vertex_colors
270
- )
271
-
272
- # Rotate mesh to correct orientation (180 degrees around X-axis)
273
- trimesh_mesh.apply_transform(trimesh.transformations.rotation_matrix(np.pi, [1, 0, 0]))
274
-
275
- return trimesh_mesh
276
 
277
  @app.route('/health', methods=['GET'])
278
  def health_check():
279
  return jsonify({
280
  "status": "healthy",
281
- "model": "Depth-Anything",
282
  "device": "cpu"
283
  }), 200
284
 
@@ -375,15 +298,7 @@ def convert_image_to_3d():
375
 
376
  try:
377
  def generate_3d():
378
- # Generate depth map
379
- with torch.no_grad():
380
- depth_output = pipeline(image)
381
-
382
- depth_map = depth_output["depth"]
383
-
384
- # Convert depth to mesh
385
- mesh = depth_to_point_cloud(depth_map, image, detail_level)
386
- return mesh
387
 
388
  mesh, error = process_with_timeout(generate_3d, [], TIMEOUT_SECONDS)
389
 
@@ -397,7 +312,7 @@ def convert_image_to_3d():
397
 
398
  processing_jobs[job_id]['progress'] = 80
399
 
400
- # Export as GLB or OBJ
401
  file_path = os.path.join(output_dir, f"model.{output_format}")
402
  mesh.export(file_path, file_type=output_format)
403
 
@@ -406,7 +321,7 @@ def convert_image_to_3d():
406
 
407
  processing_jobs[job_id]['status'] = 'completed'
408
  processing_jobs[job_id]['progress'] = 100
409
- print(f"Job {job_id} completed successfully")
410
 
411
  except Exception as e:
412
  error_details = traceback.format_exc()
@@ -527,7 +442,7 @@ def model_info(job_id):
527
  @app.route('/', methods=['GET'])
528
  def index():
529
  return jsonify({
530
- "message": "Image to 3D API (Depth-Anything)",
531
  "endpoints": [
532
  "/convert",
533
  "/progress/<job_id>",
@@ -537,12 +452,12 @@ def index():
537
  ],
538
  "parameters": {
539
  "output_format": "glb or obj",
540
- "detail_level": "low, medium, or high - controls point cloud density"
541
  },
542
- "description": "This API creates 3D models from 2D images using Depth-Anything depth estimation. Images should have transparent backgrounds for best results."
543
  }), 200
544
 
545
  if __name__ == '__main__':
546
  cleanup_old_jobs()
547
  port = int(os.environ.get('PORT', 7860))
548
- app.run(host='0.0.0.0', port=port)
 
 
1
  import os
2
  import torch
3
  import time
 
14
  from flask_cors import CORS
15
  import numpy as np
16
  import trimesh
17
+ from diffusers import DiffusionPipeline
 
 
18
  import cv2
19
 
20
  # Force CPU usage
21
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
22
  torch.set_default_device("cpu")
 
23
  torch.cuda.is_available = lambda: False
24
  torch.cuda.device_count = lambda: 0
25
 
 
32
  CACHE_DIR = '/tmp/huggingface'
33
  ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
34
 
35
+ # Create directories
36
  os.makedirs(UPLOAD_FOLDER, exist_ok=True)
37
  os.makedirs(RESULTS_FOLDER, exist_ok=True)
38
  os.makedirs(CACHE_DIR, exist_ok=True)
39
 
40
+ # Set Hugging Face cache
41
  os.environ['HF_HOME'] = CACHE_DIR
42
  os.environ['TRANSFORMERS_CACHE'] = os.path.join(CACHE_DIR, 'transformers')
43
  os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_DIR, 'datasets')
 
45
  app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
46
  app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
47
 
48
+ # Job tracking
49
  processing_jobs = {}
50
 
51
+ # Global model
52
+ tripo_pipeline = None
53
  model_loaded = False
54
  model_loading = False
55
 
56
+ # Configuration
57
+ TIMEOUT_SECONDS = 300 # 5 minutes for TripoSG
58
+ MAX_DIMENSION = 256 # TripoSG works with smaller images
59
 
 
60
  class TimeoutError(Exception):
61
  pass
62
 
 
63
  def process_with_timeout(function, args, timeout):
64
  result = [None]
65
  error = [None]
 
75
  thread = threading.Thread(target=target)
76
  thread.daemon = True
77
  thread.start()
 
78
  thread.join(timeout)
79
 
80
  if not completed[0]:
 
91
  def allowed_file(filename):
92
  return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
93
 
94
+ # Image preprocessing
95
  def preprocess_image(image_path):
96
  try:
 
97
  with Image.open(image_path) as img:
98
+ # Convert to RGB
99
  if img.mode == 'RGBA':
100
+ img = img.convert('RGB')
101
+ # Resize to 256x256
102
+ img = img.resize((256, 256), Image.LANCZOS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
+ # Basic cv2 cleanup
105
+ img_array = np.array(img)
106
+ gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
107
+ _, mask = cv2.threshold(gray, 10, 255, cv2.THRESH_BINARY)
108
+ img_array = cv2.bitwise_and(img_array, img_array, mask=mask)
109
 
110
+ return Image.fromarray(img_array)
111
  except Exception as e:
112
  raise Exception(f"Error preprocessing image: {str(e)}")
113
 
114
  def load_model():
115
+ global tripo_pipeline, model_loaded, model_loading
116
 
117
  if model_loaded:
118
+ return tripo_pipeline
119
 
120
  if model_loading:
121
  while model_loading and not model_loaded:
122
  time.sleep(0.5)
123
+ return tripo_pipeline
124
 
125
  try:
126
  model_loading = True
127
+ print("Loading TripoSG model...")
128
 
129
+ model_name = "tripo3d/tripo-sg-3d"
130
 
131
+ # Download model
132
  max_retries = 3
133
  retry_delay = 5
134
  for attempt in range(max_retries):
 
141
  break
142
  except Exception as e:
143
  if attempt < max_retries - 1:
144
+ print(f"Download attempt {attempt+1} failed: {str(e)}. Retrying...")
145
  time.sleep(retry_delay)
146
  retry_delay *= 2
147
  else:
148
  raise
149
 
150
+ # Load TripoSG pipeline
151
+ tripo_pipeline = DiffusionPipeline.from_pretrained(
152
+ model_name,
 
153
  cache_dir=CACHE_DIR,
 
154
  torch_dtype=torch.float32,
155
  )
156
+ tripo_pipeline.to("cpu")
157
 
158
  model_loaded = True
159
+ print("TripoSG loaded successfully on CPU")
160
+ return tripo_pipeline
161
 
162
  except Exception as e:
163
  print(f"Error loading model: {str(e)}")
 
166
  finally:
167
  model_loading = False
168
 
169
+ def generate_3d_model(image, detail_level):
170
+ try:
171
+ # Parameters
172
+ num_steps = {'low': 20, 'medium': 30, 'high': 40}
173
+ steps = num_steps[detail_level]
174
+
175
+ # Generate 3D model
176
+ with torch.no_grad():
177
+ result = tripo_pipeline(image, num_inference_steps=steps)
178
+
179
+ # Extract mesh
180
+ mesh = result.meshes[0]
181
+
182
+ # Convert to trimesh
183
+ vertices = np.array(mesh.vertices)
184
+ faces = np.array(mesh.faces)
185
+ vertex_colors = np.array(mesh.vertex_colors) if mesh.vertex_colors is not None else None
186
+
187
+ trimesh_mesh = trimesh.Trimesh(
188
+ vertices=vertices,
189
+ faces=faces,
190
+ vertex_colors=vertex_colors
191
+ )
192
+
193
+ # Rotate for Unity Z-up
194
+ trimesh_mesh.apply_transform(trimesh.transformations.rotation_matrix(np.pi, [1, 0, 0]))
195
+
196
+ return trimesh_mesh
197
+ except Exception as e:
198
+ raise Exception(f"Error generating 3D model: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
  @app.route('/health', methods=['GET'])
201
  def health_check():
202
  return jsonify({
203
  "status": "healthy",
204
+ "model": "TripoSG",
205
  "device": "cpu"
206
  }), 200
207
 
 
298
 
299
  try:
300
  def generate_3d():
301
+ return generate_3d_model(image, detail_level)
 
 
 
 
 
 
 
 
302
 
303
  mesh, error = process_with_timeout(generate_3d, [], TIMEOUT_SECONDS)
304
 
 
312
 
313
  processing_jobs[job_id]['progress'] = 80
314
 
315
+ # Export
316
  file_path = os.path.join(output_dir, f"model.{output_format}")
317
  mesh.export(file_path, file_type=output_format)
318
 
 
321
 
322
  processing_jobs[job_id]['status'] = 'completed'
323
  processing_jobs[job_id]['progress'] = 100
324
+ print(f"Job {job_id} completed")
325
 
326
  except Exception as e:
327
  error_details = traceback.format_exc()
 
442
  @app.route('/', methods=['GET'])
443
  def index():
444
  return jsonify({
445
+ "message": "Image to 3D API (TripoSG)",
446
  "endpoints": [
447
  "/convert",
448
  "/progress/<job_id>",
 
452
  ],
453
  "parameters": {
454
  "output_format": "glb or obj",
455
+ "detail_level": "low, medium, or high - controls inference steps"
456
  },
457
+ "description": "Creates 3D models from 2D images using TripoSG. Use transparent PNGs for best results."
458
  }), 200
459
 
460
  if __name__ == '__main__':
461
  cleanup_old_jobs()
462
  port = int(os.environ.get('PORT', 7860))
463
+ app.run(host='0.0.0.0', port=port)