mac9087 commited on
Commit
4954710
·
verified ·
1 Parent(s): 8d8ab2a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -30
app.py CHANGED
@@ -15,7 +15,8 @@ from huggingface_hub import snapshot_download
15
  from flask_cors import CORS
16
  import numpy as np
17
  import trimesh
18
- from diffusers import StableFast3DPipeline
 
19
 
20
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
21
  torch.set_default_device("cpu")
@@ -42,12 +43,12 @@ app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
42
  app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
43
 
44
  processing_jobs = {}
45
- sf3d_pipeline = None
46
  model_loaded = False
47
  model_loading = False
48
 
49
  TIMEOUT_SECONDS = 300
50
- MAX_DIMENSION = 512 # Stable-Fast-3D uses 512x512 inputs
51
 
52
  class TimeoutError(Exception):
53
  pass
@@ -89,26 +90,30 @@ def preprocess_image(image_path):
89
  if img.mode == 'RGBA':
90
  img = img.convert('RGB')
91
  img = img.resize((512, 512), Image.LANCZOS)
92
- return img
 
 
 
 
93
  except Exception as e:
94
  raise Exception(f"Error preprocessing image: {str(e)}")
95
 
96
  def load_model():
97
- global sf3d_pipeline, model_loaded, model_loading
98
 
99
  if model_loaded:
100
- return sf3d_pipeline
101
 
102
  if model_loading:
103
  while model_loading and not model_loaded:
104
  time.sleep(0.5)
105
- return sf3d_pipeline
106
 
107
  try:
108
  model_loading = True
109
- print("Loading Stable-Fast-3D...")
110
 
111
- model_name = "stabilityai/stable-fast-3d"
112
 
113
  max_retries = 3
114
  retry_delay = 5
@@ -128,16 +133,15 @@ def load_model():
128
  else:
129
  raise
130
 
131
- sf3d_pipeline = StableFast3DPipeline.from_pretrained(
132
  model_name,
133
  cache_dir=CACHE_DIR,
134
- torch_dtype=torch.float32,
135
  )
136
- sf3d_pipeline.to("cpu")
137
 
138
  model_loaded = True
139
- print("Stable-Fast-3D loaded successfully on CPU")
140
- return sf3d_pipeline
141
 
142
  except Exception as e:
143
  print(f"Error loading model: {str(e)}")
@@ -148,21 +152,17 @@ def load_model():
148
 
149
  def generate_3d_model(image, detail_level):
150
  try:
151
- num_steps = {'low': 20, 'medium': 30, 'high': 50}
152
- steps = num_steps[detail_level]
153
 
154
  with torch.no_grad():
155
- result = sf3d_pipeline(
156
- image,
157
- num_inference_steps=steps,
158
- normal_num_inference_steps=steps // 2,
159
- guidance_scale=7.0,
160
- )
161
 
162
- mesh = result.trimesh_meshes[0]
163
  vertices = np.array(mesh.vertices)
164
  faces = np.array(mesh.faces)
165
- vertex_colors = np.array(mesh.visual.vertex_colors) if mesh.visual.vertex_colors is not None else None
166
 
167
  trimesh_mesh = trimesh.Trimesh(
168
  vertices=vertices,
@@ -180,7 +180,7 @@ def generate_3d_model(image, detail_level):
180
  def health_check():
181
  return jsonify({
182
  "status": "healthy",
183
- "model": "Stable-Fast-3D",
184
  "device": "cpu"
185
  }), 200
186
 
@@ -268,7 +268,7 @@ def convert_image_to_3d():
268
  processing_jobs[job_id]['progress'] = 10
269
 
270
  try:
271
- pipeline = load_model()
272
  processing_jobs[job_id]['progress'] = 30
273
  except Exception as e:
274
  processing_jobs[job_id]['status'] = 'error'
@@ -294,7 +294,7 @@ def convert_image_to_3d():
294
  file_path = os.path.join(output_dir, f"model.{output_format}")
295
  mesh.export(file_path, file_type=output_format)
296
 
297
- processing_jobs[job_id]['result_url'] = f"/download/{job_id}"
298
  processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"
299
 
300
  processing_jobs[job_id]['status'] = 'completed'
@@ -359,7 +359,7 @@ def preview_model(job_id):
359
  else:
360
  return send_file(file_path, mimetype='text/plain')
361
 
362
- return jsonify({"error": "Model not found"}), 404
363
 
364
  def cleanup_old_jobs():
365
  current_time = time.time()
@@ -420,7 +420,7 @@ def model_info(job_id):
420
  @app.route('/', methods=['GET'])
421
  def index():
422
  return jsonify({
423
- "message": "Image to 3D API (Stable-Fast-3D)",
424
  "endpoints": [
425
  "/convert",
426
  "/progress/<job_id>",
@@ -432,7 +432,7 @@ def index():
432
  "output_format": "glb or obj",
433
  "detail_level": "low, medium, or high"
434
  },
435
- "description": "Creates 3D models from 2D images using Stable-Fast-3D."
436
  }), 200
437
 
438
  if __name__ == '__main__':
 
15
  from flask_cors import CORS
16
  import numpy as np
17
  import trimesh
18
+ from tsr.system import TripoSR
19
+ from tsr.utils import remove_background, resize_foreground
20
 
21
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
22
  torch.set_default_device("cpu")
 
43
  app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
44
 
45
  processing_jobs = {}
46
+ triposr_model = None
47
  model_loaded = False
48
  model_loading = False
49
 
50
  TIMEOUT_SECONDS = 300
51
+ MAX_DIMENSION = 512 # TripoSR uses 512x512 inputs
52
 
53
  class TimeoutError(Exception):
54
  pass
 
90
  if img.mode == 'RGBA':
91
  img = img.convert('RGB')
92
  img = img.resize((512, 512), Image.LANCZOS)
93
+ img_array = np.array(img) / 255.0
94
+ img_array = remove_background(img_array)
95
+ img_array = resize_foreground(img_array, 0.85)
96
+ img_array = np.clip(img_array, 0, 1) * 255
97
+ return Image.fromarray(img_array.astype(np.uint8))
98
  except Exception as e:
99
  raise Exception(f"Error preprocessing image: {str(e)}")
100
 
101
  def load_model():
102
+ global triposr_model, model_loaded, model_loading
103
 
104
  if model_loaded:
105
+ return triposr_model
106
 
107
  if model_loading:
108
  while model_loading and not model_loaded:
109
  time.sleep(0.5)
110
+ return triposr_model
111
 
112
  try:
113
  model_loading = True
114
+ print("Loading TripoSR...")
115
 
116
+ model_name = "tripo3d/triposr"
117
 
118
  max_retries = 3
119
  retry_delay = 5
 
133
  else:
134
  raise
135
 
136
+ triposr_model = TripoSR.from_pretrained(
137
  model_name,
138
  cache_dir=CACHE_DIR,
139
+ device="cpu",
140
  )
 
141
 
142
  model_loaded = True
143
+ print("TripoSR loaded successfully on CPU")
144
+ return triposr_model
145
 
146
  except Exception as e:
147
  print(f"Error loading model: {str(e)}")
 
152
 
153
  def generate_3d_model(image, detail_level):
154
  try:
155
+ chunk_size = {'low': 4096, 'medium': 8192, 'high': 16384}
156
+ chunk = chunk_size[detail_level]
157
 
158
  with torch.no_grad():
159
+ scene_codes = triposr_model(image, device="cpu")
160
+ meshes = triposr_model.mesher(scene_codes, chunk_size=chunk)
 
 
 
 
161
 
162
+ mesh = meshes[0]
163
  vertices = np.array(mesh.vertices)
164
  faces = np.array(mesh.faces)
165
+ vertex_colors = np.array(mesh.vertex_colors) if hasattr(mesh, 'vertex_colors') and mesh.vertex_colors is not None else None
166
 
167
  trimesh_mesh = trimesh.Trimesh(
168
  vertices=vertices,
 
180
  def health_check():
181
  return jsonify({
182
  "status": "healthy",
183
+ "model": "TripoSR",
184
  "device": "cpu"
185
  }), 200
186
 
 
268
  processing_jobs[job_id]['progress'] = 10
269
 
270
  try:
271
+ model = load_model()
272
  processing_jobs[job_id]['progress'] = 30
273
  except Exception as e:
274
  processing_jobs[job_id]['status'] = 'error'
 
294
  file_path = os.path.join(output_dir, f"model.{output_format}")
295
  mesh.export(file_path, file_type=output_format)
296
 
297
+ processing_jobs[job_id]['result_url'] = f"/download/{job_id.ConcurrentHashMap}"
298
  processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"
299
 
300
  processing_jobs[job_id]['status'] = 'completed'
 
359
  else:
360
  return send_file(file_path, mimetype='text/plain')
361
 
362
+ return jsonify({"error": "File not found"}), 404
363
 
364
  def cleanup_old_jobs():
365
  current_time = time.time()
 
420
  @app.route('/', methods=['GET'])
421
  def index():
422
  return jsonify({
423
+ "message": "Image to 3D API (TripoSR)",
424
  "endpoints": [
425
  "/convert",
426
  "/progress/<job_id>",
 
432
  "output_format": "glb or obj",
433
  "detail_level": "low, medium, or high"
434
  },
435
+ "description": "Creates 3D models from 2D images using TripoSR."
436
  }), 200
437
 
438
  if __name__ == '__main__':