mac9087 commited on
Commit
e1232bb
·
verified ·
1 Parent(s): 837715e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -32
app.py CHANGED
@@ -14,8 +14,18 @@ from huggingface_hub import snapshot_download
14
  from flask_cors import CORS
15
  import numpy as np
16
  import trimesh
17
- from triposr import TripoSR
18
  from rembg import remove
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  # Force CPU usage
21
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
@@ -50,13 +60,13 @@ app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
50
  processing_jobs = {}
51
 
52
  # Global model variables
53
- triposr = None
54
  model_loaded = False
55
  model_loading = False
56
 
57
  # Configuration for processing
58
- TIMEOUT_SECONDS = 180 # 3 minutes max for TripoSR on CPU
59
- MAX_DIMENSION = 256
60
 
61
  # TimeoutError for handling timeouts
62
  class TimeoutError(Exception):
@@ -95,36 +105,30 @@ def process_with_timeout(function, args, timeout):
95
  def allowed_file(filename):
96
  return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
97
 
98
- # Simplified image preprocessing
99
  def preprocess_image(image_path):
100
  with Image.open(image_path) as img:
101
  img = img.convert("RGB")
102
- if img.width > MAX_DIMENSION or img.height > MAX_DIMENSION:
103
- if img.width > img.height:
104
- new_width = MAX_DIMENSION
105
- new_height = int(img.height * (MAX_DIMENSION / img.width))
106
- else:
107
- new_height = MAX_DIMENSION
108
- new_width = int(img.width * (MAX_DIMENSION / img.height))
109
- img = img.resize((new_width, new_height), Image.LANCZOS)
110
  return img
111
 
112
  def load_model():
113
- global triposr, model_loaded, model_loading
114
 
115
  if model_loaded:
116
- return triposr
117
 
118
  if model_loading:
119
  while model_loading and not model_loaded:
120
  time.sleep(0.5)
121
- return triposr
122
 
123
  try:
124
  model_loading = True
125
  print("Starting model loading...")
126
 
127
- model_name = "VAST-AI-Research/TripoSR"
128
 
129
  # Download model with retry mechanism
130
  max_retries = 3
@@ -145,19 +149,17 @@ def load_model():
145
  else:
146
  raise
147
 
148
- # Load TripoSR
149
- triposr = TripoSR.from_pretrained(
150
  model_name,
151
- use_safetensors=True,
152
  torch_dtype=torch.float16,
153
- cache_dir=CACHE_DIR
154
  )
155
- # Explicitly move to CPU
156
- triposr.to("cpu")
157
 
158
  model_loaded = True
159
  print("Model loaded successfully on CPU")
160
- return triposr
161
 
162
  except Exception as e:
163
  print(f"Error loading model: {str(e)}")
@@ -170,7 +172,7 @@ def load_model():
170
  def health_check():
171
  return jsonify({
172
  "status": "healthy",
173
- "model": "TripoSR 3D Generator",
174
  "device": "cpu"
175
  }), 200
176
 
@@ -268,8 +270,8 @@ def convert_image_to_3d():
268
  try:
269
  def generate_3d():
270
  # Adjust settings based on detail level
271
- foreground_ratios = {'low': 0.6, 'medium': 0.8, 'high': 1.0}
272
- num_samples = {'low': 5000, 'medium': 10000, 'high': 20000}
273
 
274
  # Remove background
275
  image_no_bg = remove(image)
@@ -280,8 +282,8 @@ def convert_image_to_3d():
280
  # Generate mesh
281
  mesh = model(
282
  image=img_array,
283
- foreground_ratio=foreground_ratios[detail_level],
284
- num_samples=num_samples[detail_level],
285
  seed=12345
286
  )
287
  return mesh
@@ -422,7 +424,7 @@ def model_info(job_id):
422
  @app.route('/', methods=['GET'])
423
  def index():
424
  return jsonify({
425
- "message": "Image to 3D API (TripoSR)",
426
  "endpoints": [
427
  "/convert",
428
  "/progress/<job_id>",
@@ -432,9 +434,9 @@ def index():
432
  ],
433
  "parameters": {
434
  "output_format": "glb",
435
- "detail_level": "low, medium, or high - controls mesh detail"
436
  },
437
- "description": "This API creates full 3D models from 2D images using TripoSR"
438
  }), 200
439
 
440
  if __name__ == '__main__':
 
14
  from flask_cors import CORS
15
  import numpy as np
16
  import trimesh
 
17
  from rembg import remove
18
+ import sys
19
+
20
+ # Add SF3D code to sys.path
21
+ sys.path.append('/app/stable-fast-3d')
22
+
23
+ # Import SF3D components (adjust based on actual module structure)
24
+ try:
25
+ from sf3d.system import SF3D
26
+ except ImportError as e:
27
+ print(f"Failed to import SF3D: {e}")
28
+ raise
29
 
30
  # Force CPU usage
31
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
 
60
  processing_jobs = {}
61
 
62
  # Global model variables
63
+ sf3d_model = None
64
  model_loaded = False
65
  model_loading = False
66
 
67
  # Configuration for processing
68
+ TIMEOUT_SECONDS = 240 # 4 minutes max for SF3D on CPU
69
+ MAX_DIMENSION = 512 # SF3D expects 512x512
70
 
71
  # TimeoutError for handling timeouts
72
  class TimeoutError(Exception):
 
105
  def allowed_file(filename):
106
  return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
107
 
108
+ # Image preprocessing for SF3D (512x512)
109
  def preprocess_image(image_path):
110
  with Image.open(image_path) as img:
111
  img = img.convert("RGB")
112
+ # SF3D requires 512x512
113
+ img = img.resize((512, 512), Image.LANCZOS)
 
 
 
 
 
 
114
  return img
115
 
116
  def load_model():
117
+ global sf3d_model, model_loaded, model_loading
118
 
119
  if model_loaded:
120
+ return sf3d_model
121
 
122
  if model_loading:
123
  while model_loading and not model_loaded:
124
  time.sleep(0.5)
125
+ return sf3d_model
126
 
127
  try:
128
  model_loading = True
129
  print("Starting model loading...")
130
 
131
+ model_name = "stabilityai/stable-fast-3d"
132
 
133
  # Download model with retry mechanism
134
  max_retries = 3
 
149
  else:
150
  raise
151
 
152
+ # Load SF3D model
153
+ sf3d_model = SF3D.from_pretrained(
154
  model_name,
155
+ cache_dir=CACHE_DIR,
156
  torch_dtype=torch.float16,
157
+ device_map="cpu"
158
  )
 
 
159
 
160
  model_loaded = True
161
  print("Model loaded successfully on CPU")
162
+ return sf3d_model
163
 
164
  except Exception as e:
165
  print(f"Error loading model: {str(e)}")
 
172
  def health_check():
173
  return jsonify({
174
  "status": "healthy",
175
+ "model": "Stable Fast 3D (SF3D)",
176
  "device": "cpu"
177
  }), 200
178
 
 
270
  try:
271
  def generate_3d():
272
  # Adjust settings based on detail level
273
+ texture_sizes = {'low': 512, 'medium': 1024, 'high': 2048}
274
+ remesh_options = {'low': 0.5, 'medium': 1.0, 'high': 2.0}
275
 
276
  # Remove background
277
  image_no_bg = remove(image)
 
282
  # Generate mesh
283
  mesh = model(
284
  image=img_array,
285
+ texture_size=texture_sizes[detail_level],
286
+ remesh_option=remesh_options[detail_level],
287
  seed=12345
288
  )
289
  return mesh
 
424
  @app.route('/', methods=['GET'])
425
  def index():
426
  return jsonify({
427
+ "message": "Image to 3D API (Stable Fast 3D)",
428
  "endpoints": [
429
  "/convert",
430
  "/progress/<job_id>",
 
434
  ],
435
  "parameters": {
436
  "output_format": "glb",
437
+ "detail_level": "low, medium, or high - controls texture and mesh detail"
438
  },
439
+ "description": "This API creates full 3D models from 2D images using Stable Fast 3D"
440
  }), 200
441
 
442
  if __name__ == '__main__':