mac9087 commited on
Commit
6dbef31
·
verified ·
1 Parent(s): 2a62c40

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -23
app.py CHANGED
@@ -14,11 +14,15 @@ from huggingface_hub import snapshot_download
14
  from flask_cors import CORS
15
  import numpy as np
16
  import trimesh
17
- from lgm.pipeline import LGMPipeline
 
18
 
19
  # Force CPU usage
20
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
21
  torch.set_default_device("cpu")
 
 
 
22
 
23
  app = Flask(__name__)
24
  CORS(app)
@@ -46,12 +50,12 @@ app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
46
  processing_jobs = {}
47
 
48
  # Global model variables
49
- lgm_pipeline = None
50
  model_loaded = False
51
  model_loading = False
52
 
53
  # Configuration for processing
54
- TIMEOUT_SECONDS = 240 # 4 minutes max for LGM on CPU
55
  MAX_DIMENSION = 256
56
 
57
  # TimeoutError for handling timeouts
@@ -106,21 +110,21 @@ def preprocess_image(image_path):
106
  return img
107
 
108
  def load_model():
109
- global lgm_pipeline, model_loaded, model_loading
110
 
111
  if model_loaded:
112
- return lgm_pipeline
113
 
114
  if model_loading:
115
  while model_loading and not model_loaded:
116
  time.sleep(0.5)
117
- return lgm_pipeline
118
 
119
  try:
120
  model_loading = True
121
  print("Starting model loading...")
122
 
123
- model_name = "open-mmlab/LGM"
124
 
125
  # Download model with retry mechanism
126
  max_retries = 3
@@ -141,18 +145,19 @@ def load_model():
141
  else:
142
  raise
143
 
144
- # Load LGM pipeline
145
- lgm_pipeline = LGMPipeline.from_pretrained(
146
  model_name,
147
  use_safetensors=True,
148
  torch_dtype=torch.float16,
149
- cache_dir=CACHE_DIR,
150
- device_map="cpu"
151
  )
 
 
152
 
153
  model_loaded = True
154
  print("Model loaded successfully on CPU")
155
- return lgm_pipeline
156
 
157
  except Exception as e:
158
  print(f"Error loading model: {str(e)}")
@@ -165,7 +170,7 @@ def load_model():
165
  def health_check():
166
  return jsonify({
167
  "status": "healthy",
168
- "model": "LGM 3D Generator",
169
  "device": "cpu"
170
  }), 200
171
 
@@ -263,19 +268,21 @@ def convert_image_to_3d():
263
  try:
264
  def generate_3d():
265
  # Adjust settings based on detail level
266
- steps = {'low': 40, 'medium': 60, 'high': 80}
267
- resolution = {'low': 512, 'medium': 1024, 'high': 2048}
268
 
269
- # Convert PIL image to numpy
270
- img_array = np.array(image)
 
 
 
271
 
272
  # Generate mesh
273
  mesh = model(
274
  image=img_array,
275
- num_inference_steps=steps[detail_level],
276
- texture_resolution=resolution[detail_level],
277
- generator=torch.manual_seed(12345),
278
- output_type="trimesh"
279
  )
280
  return mesh
281
 
@@ -415,7 +422,7 @@ def model_info(job_id):
415
  @app.route('/', methods=['GET'])
416
  def index():
417
  return jsonify({
418
- "message": "Image to 3D API (LGM)",
419
  "endpoints": [
420
  "/convert",
421
  "/progress/<job_id>",
@@ -427,7 +434,7 @@ def index():
427
  "output_format": "glb",
428
  "detail_level": "low, medium, or high - controls mesh detail"
429
  },
430
- "description": "This API creates full 3D models from 2D images using LGM"
431
  }), 200
432
 
433
  if __name__ == '__main__':
 
14
  from flask_cors import CORS
15
  import numpy as np
16
  import trimesh
17
+ from triposr import TripoSR
18
+ from rembg import remove
19
 
20
  # Force CPU usage
21
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
22
  torch.set_default_device("cpu")
23
+ # Patch PyTorch to disable CUDA initialization
24
+ torch.cuda.is_available = lambda: False
25
+ torch.cuda.device_count = lambda: 0
26
 
27
  app = Flask(__name__)
28
  CORS(app)
 
50
  processing_jobs = {}
51
 
52
  # Global model variables
53
+ triposr = None
54
  model_loaded = False
55
  model_loading = False
56
 
57
  # Configuration for processing
58
+ TIMEOUT_SECONDS = 180 # 3 minutes max for TripoSR on CPU
59
  MAX_DIMENSION = 256
60
 
61
  # TimeoutError for handling timeouts
 
110
  return img
111
 
112
  def load_model():
113
+ global triposr, model_loaded, model_loading
114
 
115
  if model_loaded:
116
+ return triposr
117
 
118
  if model_loading:
119
  while model_loading and not model_loaded:
120
  time.sleep(0.5)
121
+ return triposr
122
 
123
  try:
124
  model_loading = True
125
  print("Starting model loading...")
126
 
127
+ model_name = "VAST-AI-Research/TripoSR"
128
 
129
  # Download model with retry mechanism
130
  max_retries = 3
 
145
  else:
146
  raise
147
 
148
+ # Load TripoSR
149
+ triposr = TripoSR.from_pretrained(
150
  model_name,
151
  use_safetensors=True,
152
  torch_dtype=torch.float16,
153
+ cache_dir=CACHE_DIR
 
154
  )
155
+ # Explicitly move to CPU
156
+ triposr.to("cpu")
157
 
158
  model_loaded = True
159
  print("Model loaded successfully on CPU")
160
+ return triposr
161
 
162
  except Exception as e:
163
  print(f"Error loading model: {str(e)}")
 
170
  def health_check():
171
  return jsonify({
172
  "status": "healthy",
173
+ "model": "TripoSR 3D Generator",
174
  "device": "cpu"
175
  }), 200
176
 
 
268
  try:
269
  def generate_3d():
270
  # Adjust settings based on detail level
271
+ foreground_ratios = {'low': 0.6, 'medium': 0.8, 'high': 1.0}
272
+ num_samples = {'low': 5000, 'medium': 10000, 'high': 20000}
273
 
274
+ # Remove background
275
+ image_no_bg = remove(image)
276
+
277
+ # Convert to numpy
278
+ img_array = np.array(image_no_bg)
279
 
280
  # Generate mesh
281
  mesh = model(
282
  image=img_array,
283
+ foreground_ratio=foreground_ratios[detail_level],
284
+ num_samples=num_samples[detail_level],
285
+ seed=12345
 
286
  )
287
  return mesh
288
 
 
422
  @app.route('/', methods=['GET'])
423
  def index():
424
  return jsonify({
425
+ "message": "Image to 3D API (TripoSR)",
426
  "endpoints": [
427
  "/convert",
428
  "/progress/<job_id>",
 
434
  "output_format": "glb",
435
  "detail_level": "low, medium, or high - controls mesh detail"
436
  },
437
+ "description": "This API creates full 3D models from 2D images using TripoSR"
438
  }), 200
439
 
440
  if __name__ == '__main__':