mac9087 commited on
Commit
138cb5e
·
verified ·
1 Parent(s): d470117

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -18
app.py CHANGED
@@ -16,8 +16,8 @@ from flask_cors import CORS
16
  import numpy as np
17
  import trimesh
18
  import cv2
19
- from transformers import AutoModel, AutoProcessor # For TripoSR
20
- from u2net import U2NET # For background removal; install from https://github.com/xuebinqin/U-2-Net
21
  import torchvision.transforms as T
22
 
23
  app = Flask(__name__)
@@ -47,8 +47,7 @@ processing_jobs = {}
47
 
48
  # Global model variables
49
  u2net_model = None
50
- triposr_model = None
51
- triposr_processor = None
52
  model_loaded = False
53
  model_loading = False
54
 
@@ -121,6 +120,8 @@ def preprocess_image(image_path):
121
  def remove_background(image):
122
  global u2net_model
123
  if u2net_model is None:
 
 
124
  u2net_model = U2NET()
125
  u2net_model.load_state_dict(torch.load('u2net.pth', map_location='cpu'))
126
  u2net_model.eval()
@@ -141,15 +142,15 @@ def remove_background(image):
141
  return Image.fromarray(result.astype('uint8'))
142
 
143
  def load_model():
144
- global triposr_model, triposr_processor, model_loaded, model_loading
145
 
146
  if model_loaded:
147
- return triposr_model, triposr_processor
148
 
149
  if model_loading:
150
  while model_loading and not model_loaded:
151
  time.sleep(0.5)
152
- return triposr_model, triposr_processor
153
 
154
  try:
155
  model_loading = True
@@ -175,13 +176,15 @@ def load_model():
175
  else:
176
  raise
177
 
178
- triposr_processor = AutoProcessor.from_pretrained(model_name, cache_dir=CACHE_DIR)
179
- triposr_model = AutoModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
180
- triposr_model.to('cpu')
 
 
181
 
182
  model_loaded = True
183
- print("TripoSR model loaded successfully on CPU")
184
- return triposr_model, triposr_processor
185
 
186
  except Exception as e:
187
  print(f"Error loading model: {str(e)}")
@@ -301,9 +304,9 @@ def convert_image_to_3d():
301
  clean_image = remove_background(image)
302
  processing_jobs[job_id]['progress'] = 30
303
 
304
- # Load TripoSR model
305
  try:
306
- model, processor = load_model()
307
  processing_jobs[job_id]['progress'] = 40
308
  except Exception as e:
309
  processing_jobs[job_id]['status'] = 'error'
@@ -313,10 +316,8 @@ def convert_image_to_3d():
313
  # Generate 3D model
314
  try:
315
  def generate_3d():
316
- inputs = processor(images=clean_image, return_tensors="pt").to('cpu')
317
- with torch.no_grad():
318
- outputs = model(**inputs)
319
- mesh = outputs.mesh # TripoSR outputs a trimesh object
320
  return mesh
321
 
322
  mesh, error = process_with_timeout(generate_3d, [], TIMEOUT_SECONDS)
 
16
  import numpy as np
17
  import trimesh
18
  import cv2
19
+ from tsr.models import TripoSR # Custom TripoSR model
20
+ from tsr.pipeline import TripoSRPipeline # Custom pipeline
21
  import torchvision.transforms as T
22
 
23
  app = Flask(__name__)
 
47
 
48
  # Global model variables
49
  u2net_model = None
50
+ triposr_pipeline = None
 
51
  model_loaded = False
52
  model_loading = False
53
 
 
120
  def remove_background(image):
121
  global u2net_model
122
  if u2net_model is None:
123
+ # Dynamically import U2NET to avoid circular import issues
124
+ from u2net import U2NET
125
  u2net_model = U2NET()
126
  u2net_model.load_state_dict(torch.load('u2net.pth', map_location='cpu'))
127
  u2net_model.eval()
 
142
  return Image.fromarray(result.astype('uint8'))
143
 
144
  def load_model():
145
+ global triposr_pipeline, model_loaded, model_loading
146
 
147
  if model_loaded:
148
+ return triposr_pipeline
149
 
150
  if model_loading:
151
  while model_loading and not model_loaded:
152
  time.sleep(0.5)
153
+ return triposr_pipeline
154
 
155
  try:
156
  model_loading = True
 
176
  else:
177
  raise
178
 
179
+ # Initialize TripoSR pipeline
180
+ triposr_pipeline = TripoSRPipeline(
181
+ model_path=os.path.join(CACHE_DIR, "stabilityai/TripoSR"),
182
+ device='cpu'
183
+ )
184
 
185
  model_loaded = True
186
+ print("TripoSR pipeline loaded successfully on CPU")
187
+ return triposr_pipeline
188
 
189
  except Exception as e:
190
  print(f"Error loading model: {str(e)}")
 
304
  clean_image = remove_background(image)
305
  processing_jobs[job_id]['progress'] = 30
306
 
307
+ # Load TripoSR pipeline
308
  try:
309
+ pipeline = load_model()
310
  processing_jobs[job_id]['progress'] = 40
311
  except Exception as e:
312
  processing_jobs[job_id]['status'] = 'error'
 
316
  # Generate 3D model
317
  try:
318
  def generate_3d():
319
+ # TripoSR pipeline expects a PIL image
320
+ mesh = pipeline(clean_image)
 
 
321
  return mesh
322
 
323
  mesh, error = process_with_timeout(generate_3d, [], TIMEOUT_SECONDS)