mac9087 commited on
Commit
fa24ab7
·
verified ·
1 Parent(s): f8775ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -16
app.py CHANGED
@@ -16,8 +16,7 @@ from flask_cors import CORS
16
  import numpy as np
17
  import trimesh
18
  import cv2
19
- from tsr.models import TripoSR # Custom TripoSR model
20
- from tsr.pipeline import TripoSRPipeline # Custom pipeline
21
  import torchvision.transforms as T
22
 
23
  app = Flask(__name__)
@@ -47,7 +46,7 @@ processing_jobs = {}
47
 
48
  # Global model variables
49
  u2net_model = None
50
- triposr_pipeline = None
51
  model_loaded = False
52
  model_loading = False
53
 
@@ -142,15 +141,15 @@ def remove_background(image):
142
  return Image.fromarray(result.astype('uint8'))
143
 
144
  def load_model():
145
- global triposr_pipeline, model_loaded, model_loading
146
 
147
  if model_loaded:
148
- return triposr_pipeline
149
 
150
  if model_loading:
151
  while model_loading and not model_loaded:
152
  time.sleep(0.5)
153
- return triposr_pipeline
154
 
155
  try:
156
  model_loading = True
@@ -176,15 +175,17 @@ def load_model():
176
  else:
177
  raise
178
 
179
- # Initialize TripoSR pipeline
180
- triposr_pipeline = TripoSRPipeline(
181
- model_path=os.path.join(CACHE_DIR, "stabilityai/TripoSR"),
182
- device='cpu'
 
 
183
  )
184
 
185
  model_loaded = True
186
- print("TripoSR pipeline loaded successfully on CPU")
187
- return triposr_pipeline
188
 
189
  except Exception as e:
190
  print(f"Error loading model: {str(e)}")
@@ -304,9 +305,9 @@ def convert_image_to_3d():
304
  clean_image = remove_background(image)
305
  processing_jobs[job_id]['progress'] = 30
306
 
307
- # Load TripoSR pipeline
308
  try:
309
- pipeline = load_model()
310
  processing_jobs[job_id]['progress'] = 40
311
  except Exception as e:
312
  processing_jobs[job_id]['status'] = 'error'
@@ -316,8 +317,8 @@ def convert_image_to_3d():
316
  # Generate 3D model
317
  try:
318
  def generate_3d():
319
- # TripoSR pipeline expects a PIL image
320
- mesh = pipeline(clean_image)
321
  return mesh
322
 
323
  mesh, error = process_with_timeout(generate_3d, [], TIMEOUT_SECONDS)
 
16
  import numpy as np
17
  import trimesh
18
  import cv2
19
+ from tsr.system import TSR # Updated import
 
20
  import torchvision.transforms as T
21
 
22
  app = Flask(__name__)
 
46
 
47
  # Global model variables
48
  u2net_model = None
49
+ triposr_model = None
50
  model_loaded = False
51
  model_loading = False
52
 
 
141
  return Image.fromarray(result.astype('uint8'))
142
 
143
  def load_model():
144
+ global triposr_model, model_loaded, model_loading
145
 
146
  if model_loaded:
147
+ return triposr_model
148
 
149
  if model_loading:
150
  while model_loading and not model_loaded:
151
  time.sleep(0.5)
152
+ return triposr_model
153
 
154
  try:
155
  model_loading = True
 
175
  else:
176
  raise
177
 
178
+ # Initialize TSR model
179
+ triposr_model = TSR.from_pretrained(
180
+ model_name,
181
+ torch_dtype=torch.float32,
182
+ device="cpu",
183
+ cache_dir=CACHE_DIR
184
  )
185
 
186
  model_loaded = True
187
+ print("TripoSR model loaded successfully on CPU")
188
+ return triposr_model
189
 
190
  except Exception as e:
191
  print(f"Error loading model: {str(e)}")
 
305
  clean_image = remove_background(image)
306
  processing_jobs[job_id]['progress'] = 30
307
 
308
+ # Load TripoSR model
309
  try:
310
+ model = load_model()
311
  processing_jobs[job_id]['progress'] = 40
312
  except Exception as e:
313
  processing_jobs[job_id]['status'] = 'error'
 
317
  # Generate 3D model
318
  try:
319
  def generate_3d():
320
+ # TSR expects a PIL image
321
+ mesh = model(clean_image)
322
  return mesh
323
 
324
  mesh, error = process_with_timeout(generate_3d, [], TIMEOUT_SECONDS)