mac9087 committed on
Commit
e4c93be
·
verified ·
1 Parent(s): 5a23d7c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -43
app.py CHANGED
@@ -11,11 +11,11 @@ import io
11
  import zipfile
12
  import uuid
13
  import traceback
14
- from diffusers import ShapEImg2ImgPipeline
15
- from diffusers.utils import export_to_obj
16
  from huggingface_hub import snapshot_download
17
  from flask_cors import CORS
18
- import functools
 
19
 
20
  app = Flask(__name__)
21
  CORS(app) # Enable CORS for all routes
@@ -42,15 +42,15 @@ app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max
42
  # Job tracking dictionary
43
  processing_jobs = {}
44
 
45
- # Global model variable
46
- pipe = None
 
47
  model_loaded = False
48
  model_loading = False
49
 
50
  # Configuration for processing
51
- TIMEOUT_SECONDS = 300 # 5 minutes max for processing
52
  MAX_DIMENSION = 512 # Max image dimension to process
53
- MAX_INFERENCE_STEPS = 64 # Maximum allowed inference steps to prevent the index error
54
 
55
  # TimeoutError for handling timeouts
56
  class TimeoutError(Exception):
@@ -89,7 +89,7 @@ def process_with_timeout(function, args, timeout):
89
  def allowed_file(filename):
90
  return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
91
 
92
- # Function to preprocess image - resize if needed
93
  def preprocess_image(image_path):
94
  with Image.open(image_path) as img:
95
  img = img.convert("RGB")
@@ -104,26 +104,26 @@ def preprocess_image(image_path):
104
  new_width = int(img.width * (MAX_DIMENSION / img.height))
105
  img = img.resize((new_width, new_height), Image.LANCZOS)
106
 
107
- # Convert to RGB and return
108
  return img
109
 
110
  def load_model():
111
- global pipe, model_loaded, model_loading
112
 
113
  if model_loaded:
114
- return pipe
115
 
116
  if model_loading:
117
  # Wait for model to load if it's already in progress
118
  while model_loading and not model_loaded:
119
  time.sleep(0.5)
120
- return pipe
121
 
122
  try:
123
  model_loading = True
124
  print("Starting model loading...")
125
 
126
- model_name = "openai/shap-e-img2img"
 
127
 
128
  # Download model with retry mechanism
129
  max_retries = 3
@@ -145,24 +145,26 @@ def load_model():
145
  else:
146
  raise
147
 
148
- # Initialize pipeline with lower precision to save memory
149
  device = "cuda" if torch.cuda.is_available() else "cpu"
150
  dtype = torch.float16 if device == "cuda" else torch.float32
151
 
152
- pipe = ShapEImg2ImgPipeline.from_pretrained(
 
153
  model_name,
154
  torch_dtype=dtype,
155
  cache_dir=CACHE_DIR,
 
156
  )
157
- pipe = pipe.to(device)
158
 
159
  # Optimize for inference
160
  if device == "cuda":
161
- pipe.enable_model_cpu_offload()
162
 
163
  model_loaded = True
164
  print(f"Model loaded successfully on {device}")
165
- return pipe
166
 
167
  except Exception as e:
168
  print(f"Error loading model: {str(e)}")
@@ -171,11 +173,89 @@ def load_model():
171
  finally:
172
  model_loading = False
173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  @app.route('/health', methods=['GET'])
175
  def health_check():
176
  return jsonify({
177
  "status": "healthy",
178
- "model": "Shap-E Image to 3D",
179
  "device": "cuda" if torch.cuda.is_available() else "cpu"
180
  }), 200
181
 
@@ -234,7 +314,6 @@ def convert_image_to_3d():
234
  # Get optional parameters with defaults
235
  try:
236
  guidance_scale = float(request.form.get('guidance_scale', 3.0))
237
- num_inference_steps = min(int(request.form.get('num_inference_steps', 64)), MAX_INFERENCE_STEPS)
238
  output_format = request.form.get('output_format', 'obj').lower()
239
  except ValueError:
240
  return jsonify({"error": "Invalid parameter values"}), 400
@@ -243,9 +322,6 @@ def convert_image_to_3d():
243
  if guidance_scale < 1.0 or guidance_scale > 5.0:
244
  return jsonify({"error": "Guidance scale must be between 1.0 and 5.0"}), 400
245
 
246
- if num_inference_steps < 32 or num_inference_steps > MAX_INFERENCE_STEPS:
247
- num_inference_steps = min(num_inference_steps, MAX_INFERENCE_STEPS)
248
-
249
  # Validate output format
250
  if output_format not in ['obj', 'glb']:
251
  return jsonify({"error": "Unsupported output format. Use 'obj' or 'glb'"}), 400
@@ -277,14 +353,14 @@ def convert_image_to_3d():
277
  processing_jobs[job_id]['thread_alive'] = lambda: thread.is_alive()
278
 
279
  try:
280
- # Preprocess image (resize if needed)
281
  processing_jobs[job_id]['progress'] = 5
282
  image = preprocess_image(filepath)
283
  processing_jobs[job_id]['progress'] = 10
284
 
285
  # Load model
286
  try:
287
- pipe = load_model()
288
  processing_jobs[job_id]['progress'] = 30
289
  except Exception as e:
290
  processing_jobs[job_id]['status'] = 'error'
@@ -293,15 +369,18 @@ def convert_image_to_3d():
293
 
294
  # Process image with thread-safe timeout
295
  try:
296
- def generate_mesh():
297
- return pipe(
298
- image,
299
- guidance_scale=guidance_scale,
300
- num_inference_steps=num_inference_steps,
301
- output_type="mesh",
302
- ).images
 
 
 
303
 
304
- images, error = process_with_timeout(generate_mesh, [], TIMEOUT_SECONDS)
305
 
306
  if error:
307
  if isinstance(error, TimeoutError):
@@ -312,6 +391,10 @@ def convert_image_to_3d():
312
  raise error
313
 
314
  processing_jobs[job_id]['progress'] = 80
 
 
 
 
315
  except Exception as e:
316
  error_details = traceback.format_exc()
317
  processing_jobs[job_id]['status'] = 'error'
@@ -324,7 +407,7 @@ def convert_image_to_3d():
324
  try:
325
  if output_format == 'obj':
326
  obj_path = os.path.join(output_dir, "model.obj")
327
- export_to_obj(images[0], obj_path)
328
 
329
  # Create a zip file with OBJ and MTL
330
  zip_path = os.path.join(output_dir, "model.zip")
@@ -338,17 +421,9 @@ def convert_image_to_3d():
338
  processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"
339
 
340
  elif output_format == 'glb':
341
- from trimesh import Trimesh
342
- mesh = images[0]
343
- vertices = mesh.verts
344
- faces = mesh.faces
345
-
346
- # Create a trimesh object
347
- trimesh_obj = Trimesh(vertices=vertices, faces=faces)
348
-
349
  # Export as GLB
350
  glb_path = os.path.join(output_dir, "model.glb")
351
- trimesh_obj.export(glb_path)
352
 
353
  processing_jobs[job_id]['result_url'] = f"/download/{job_id}"
354
  processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"
 
11
  import zipfile
12
  import uuid
13
  import traceback
14
+ from transformers import AutoImageProcessor, AutoModel
 
15
  from huggingface_hub import snapshot_download
16
  from flask_cors import CORS
17
+ import numpy as np
18
+ import trimesh
19
 
20
  app = Flask(__name__)
21
  CORS(app) # Enable CORS for all routes
 
42
  # Job tracking dictionary
43
  processing_jobs = {}
44
 
45
+ # Global model variables
46
+ image_processor = None
47
+ model = None
48
  model_loaded = False
49
  model_loading = False
50
 
51
  # Configuration for processing
52
+ TIMEOUT_SECONDS = 180 # 3 minutes max for processing
53
  MAX_DIMENSION = 512 # Max image dimension to process
 
54
 
55
  # TimeoutError for handling timeouts
56
  class TimeoutError(Exception):
 
89
  def allowed_file(filename):
90
  return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
91
 
92
+ # Function to preprocess image
93
  def preprocess_image(image_path):
94
  with Image.open(image_path) as img:
95
  img = img.convert("RGB")
 
104
  new_width = int(img.width * (MAX_DIMENSION / img.height))
105
  img = img.resize((new_width, new_height), Image.LANCZOS)
106
 
 
107
  return img
108
 
109
  def load_model():
110
+ global image_processor, model, model_loaded, model_loading
111
 
112
  if model_loaded:
113
+ return image_processor, model
114
 
115
  if model_loading:
116
  # Wait for model to load if it's already in progress
117
  while model_loading and not model_loaded:
118
  time.sleep(0.5)
119
+ return image_processor, model
120
 
121
  try:
122
  model_loading = True
123
  print("Starting model loading...")
124
 
125
+ # Using a lightweight model: Pictorial 3D Scene Representation
126
+ model_name = "damo-vilab/text-to-3d-texture-base" # Smaller model than ShapE-img2img
127
 
128
  # Download model with retry mechanism
129
  max_retries = 3
 
145
  else:
146
  raise
147
 
148
+ # Initialize model with lower precision to save memory
149
  device = "cuda" if torch.cuda.is_available() else "cpu"
150
  dtype = torch.float16 if device == "cuda" else torch.float32
151
 
152
+ image_processor = AutoImageProcessor.from_pretrained(model_name, cache_dir=CACHE_DIR)
153
+ model = AutoModel.from_pretrained(
154
  model_name,
155
  torch_dtype=dtype,
156
  cache_dir=CACHE_DIR,
157
+ low_cpu_mem_usage=True,
158
  )
159
+ model = model.to(device)
160
 
161
  # Optimize for inference
162
  if device == "cuda":
163
+ model = model.half() # Use half precision on GPU
164
 
165
  model_loaded = True
166
  print(f"Model loaded successfully on {device}")
167
+ return image_processor, model
168
 
169
  except Exception as e:
170
  print(f"Error loading model: {str(e)}")
 
173
  finally:
174
  model_loading = False
175
 
176
+ # Convert model output to 3D mesh
177
def create_mesh_from_output(output, resolution=64):
    """Build a ``trimesh.Trimesh`` from a model forward-pass result.

    The geometry is a placeholder: the hidden states are only handed to
    ``create_primitive_mesh``, which returns a simple primitive rather than
    a reconstruction of the input image.

    Args:
        output: Model output exposing a ``last_hidden_state`` tensor.
        resolution: Passed through unchanged to ``create_primitive_mesh``.

    Returns:
        A ``trimesh.Trimesh`` built from the primitive's vertices and faces.
    """
    hidden_states = output.last_hidden_state
    # Detach and move to host memory before converting to NumPy; only the
    # first entry along the leading axis is used (presumably the batch
    # dimension -- confirm against the model in use).
    first_entry = hidden_states.detach().cpu().numpy()[0]

    verts, tris = create_primitive_mesh(first_entry, resolution)
    return trimesh.Trimesh(vertices=verts, faces=tris)
188
+
189
def create_primitive_mesh(features, resolution=64, *, use_sphere=True):
    """Create a placeholder primitive mesh whose size is driven by ``features``.

    The mesh is either a UV sphere or an axis-aligned cube. Its size is
    ``0.5 + (sum(features[:10]) % 1.0)``, i.e. a value in ``[0.5, 1.5)``.
    This is a stand-in for real mesh generation, not a reconstruction.

    Args:
        features: Array-like of numbers. Only ``features[:10]`` is used, and
            everything in that slice is summed; for a 2-D array this means
            the first ten rows, not the first ten scalars.
        resolution: Number of samples along each sphere parameter. The sphere
            has ``resolution ** 2`` vertices and ``2 * (resolution - 1) ** 2``
            triangles. Ignored for the cube.
        use_sphere: Keyword-only. ``True`` (default) builds the sphere,
            ``False`` builds the cube.

    Returns:
        Tuple ``(vertices, faces)``: a float array of shape ``(V, 3)`` and an
        integer array of shape ``(F, 3)`` indexing into ``vertices``.
    """
    feature_sum = np.sum(np.asarray(features)[:10])
    # A non-finite sum (e.g. half-precision overflow) would turn into NaN via
    # ``% 1.0`` and poison every vertex, so fall back to the neutral value.
    if not np.isfinite(feature_sum):
        feature_sum = 0.0
    scale = float(0.5 + (feature_sum % 1.0))

    if use_sphere:
        u = np.linspace(0, 2 * np.pi, resolution)
        v = np.linspace(0, np.pi, resolution)

        x = scale * np.outer(np.cos(u), np.sin(v))
        y = scale * np.outer(np.sin(u), np.sin(v))
        z = scale * np.outer(np.ones_like(u), np.cos(v))
        vertices = np.column_stack([x.ravel(), y.ravel(), z.ravel()])

        # Each grid cell (i, j) yields two triangles. Vertex (i, j) lives at
        # flat index ``i * resolution + j``, matching the ravel order above.
        steps = np.arange(resolution - 1)
        p1 = (steps[:, None] * resolution + steps[None, :]).ravel()
        p2 = p1 + 1
        p3 = p1 + resolution
        p4 = p3 + 1
        faces = np.stack(
            [
                np.stack([p1, p2, p4], axis=1),
                np.stack([p1, p4, p3], axis=1),
            ],
            axis=1,
        ).reshape(-1, 3)
    else:
        # Float dtype is required: scaling an integer array in place by a
        # float raises a casting error.
        vertices = scale * np.array(
            [
                [-1, -1, -1], [1, -1, -1], [1, 1, -1], [-1, 1, -1],
                [-1, -1, 1], [1, -1, 1], [1, 1, 1], [-1, 1, 1],
            ],
            dtype=float,
        )
        faces = np.array(
            [
                [0, 1, 2], [0, 2, 3],  # Bottom face
                [4, 5, 6], [4, 6, 7],  # Top face
                [0, 1, 5], [0, 5, 4],  # Front face
                [2, 3, 7], [2, 7, 6],  # Back face
                [0, 3, 7], [0, 7, 4],  # Left face
                [1, 2, 6], [1, 6, 5],  # Right face
            ]
        )

    return vertices, faces
253
+
254
  @app.route('/health', methods=['GET'])
255
  def health_check():
256
  return jsonify({
257
  "status": "healthy",
258
+ "model": "Lightweight 3D Model Generator",
259
  "device": "cuda" if torch.cuda.is_available() else "cpu"
260
  }), 200
261
 
 
314
  # Get optional parameters with defaults
315
  try:
316
  guidance_scale = float(request.form.get('guidance_scale', 3.0))
 
317
  output_format = request.form.get('output_format', 'obj').lower()
318
  except ValueError:
319
  return jsonify({"error": "Invalid parameter values"}), 400
 
322
  if guidance_scale < 1.0 or guidance_scale > 5.0:
323
  return jsonify({"error": "Guidance scale must be between 1.0 and 5.0"}), 400
324
 
 
 
 
325
  # Validate output format
326
  if output_format not in ['obj', 'glb']:
327
  return jsonify({"error": "Unsupported output format. Use 'obj' or 'glb'"}), 400
 
353
  processing_jobs[job_id]['thread_alive'] = lambda: thread.is_alive()
354
 
355
  try:
356
+ # Preprocess image
357
  processing_jobs[job_id]['progress'] = 5
358
  image = preprocess_image(filepath)
359
  processing_jobs[job_id]['progress'] = 10
360
 
361
  # Load model
362
  try:
363
+ processor, model_instance = load_model()
364
  processing_jobs[job_id]['progress'] = 30
365
  except Exception as e:
366
  processing_jobs[job_id]['status'] = 'error'
 
369
 
370
  # Process image with thread-safe timeout
371
  try:
372
+ def generate_3d():
373
+ # Process the image
374
+ device = model_instance.device
375
+ inputs = processor(images=image, return_tensors="pt").to(device)
376
+
377
+ # Forward pass through model
378
+ with torch.no_grad():
379
+ outputs = model_instance(**inputs)
380
+
381
+ return outputs
382
 
383
+ outputs, error = process_with_timeout(generate_3d, [], TIMEOUT_SECONDS)
384
 
385
  if error:
386
  if isinstance(error, TimeoutError):
 
391
  raise error
392
 
393
  processing_jobs[job_id]['progress'] = 80
394
+
395
+ # Create mesh from outputs
396
+ mesh = create_mesh_from_output(outputs)
397
+
398
  except Exception as e:
399
  error_details = traceback.format_exc()
400
  processing_jobs[job_id]['status'] = 'error'
 
407
  try:
408
  if output_format == 'obj':
409
  obj_path = os.path.join(output_dir, "model.obj")
410
+ mesh.export(obj_path, file_type='obj')
411
 
412
  # Create a zip file with OBJ and MTL
413
  zip_path = os.path.join(output_dir, "model.zip")
 
421
  processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"
422
 
423
  elif output_format == 'glb':
 
 
 
 
 
 
 
 
424
  # Export as GLB
425
  glb_path = os.path.join(output_dir, "model.glb")
426
+ mesh.export(glb_path, file_type='glb')
427
 
428
  processing_jobs[job_id]['result_url'] = f"/download/{job_id}"
429
  processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"