Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -16,8 +16,8 @@ from flask_cors import CORS
|
|
16 |
import numpy as np
|
17 |
import trimesh
|
18 |
import cv2
|
19 |
-
from
|
20 |
-
from
|
21 |
import torchvision.transforms as T
|
22 |
|
23 |
app = Flask(__name__)
|
@@ -47,8 +47,7 @@ processing_jobs = {}
|
|
47 |
|
48 |
# Global model variables
|
49 |
u2net_model = None
|
50 |
-
|
51 |
-
triposr_processor = None
|
52 |
model_loaded = False
|
53 |
model_loading = False
|
54 |
|
@@ -121,6 +120,8 @@ def preprocess_image(image_path):
|
|
121 |
def remove_background(image):
|
122 |
global u2net_model
|
123 |
if u2net_model is None:
|
|
|
|
|
124 |
u2net_model = U2NET()
|
125 |
u2net_model.load_state_dict(torch.load('u2net.pth', map_location='cpu'))
|
126 |
u2net_model.eval()
|
@@ -141,15 +142,15 @@ def remove_background(image):
|
|
141 |
return Image.fromarray(result.astype('uint8'))
|
142 |
|
143 |
def load_model():
|
144 |
-
global
|
145 |
|
146 |
if model_loaded:
|
147 |
-
return
|
148 |
|
149 |
if model_loading:
|
150 |
while model_loading and not model_loaded:
|
151 |
time.sleep(0.5)
|
152 |
-
return
|
153 |
|
154 |
try:
|
155 |
model_loading = True
|
@@ -175,13 +176,15 @@ def load_model():
|
|
175 |
else:
|
176 |
raise
|
177 |
|
178 |
-
|
179 |
-
|
180 |
-
|
|
|
|
|
181 |
|
182 |
model_loaded = True
|
183 |
-
print("TripoSR
|
184 |
-
return
|
185 |
|
186 |
except Exception as e:
|
187 |
print(f"Error loading model: {str(e)}")
|
@@ -301,9 +304,9 @@ def convert_image_to_3d():
|
|
301 |
clean_image = remove_background(image)
|
302 |
processing_jobs[job_id]['progress'] = 30
|
303 |
|
304 |
-
# Load TripoSR
|
305 |
try:
|
306 |
-
|
307 |
processing_jobs[job_id]['progress'] = 40
|
308 |
except Exception as e:
|
309 |
processing_jobs[job_id]['status'] = 'error'
|
@@ -313,10 +316,8 @@ def convert_image_to_3d():
|
|
313 |
# Generate 3D model
|
314 |
try:
|
315 |
def generate_3d():
|
316 |
-
|
317 |
-
|
318 |
-
outputs = model(**inputs)
|
319 |
-
mesh = outputs.mesh # TripoSR outputs a trimesh object
|
320 |
return mesh
|
321 |
|
322 |
mesh, error = process_with_timeout(generate_3d, [], TIMEOUT_SECONDS)
|
|
|
16 |
import numpy as np
|
17 |
import trimesh
|
18 |
import cv2
|
19 |
+
from tsr.models import TripoSR # Custom TripoSR model
|
20 |
+
from tsr.pipeline import TripoSRPipeline # Custom pipeline
|
21 |
import torchvision.transforms as T
|
22 |
|
23 |
app = Flask(__name__)
|
|
|
47 |
|
48 |
# Global model variables
|
49 |
u2net_model = None
|
50 |
+
triposr_pipeline = None
|
|
|
51 |
model_loaded = False
|
52 |
model_loading = False
|
53 |
|
|
|
120 |
def remove_background(image):
|
121 |
global u2net_model
|
122 |
if u2net_model is None:
|
123 |
+
# Dynamically import U2NET to avoid circular import issues
|
124 |
+
from u2net import U2NET
|
125 |
u2net_model = U2NET()
|
126 |
u2net_model.load_state_dict(torch.load('u2net.pth', map_location='cpu'))
|
127 |
u2net_model.eval()
|
|
|
142 |
return Image.fromarray(result.astype('uint8'))
|
143 |
|
144 |
def load_model():
|
145 |
+
global triposr_pipeline, model_loaded, model_loading
|
146 |
|
147 |
if model_loaded:
|
148 |
+
return triposr_pipeline
|
149 |
|
150 |
if model_loading:
|
151 |
while model_loading and not model_loaded:
|
152 |
time.sleep(0.5)
|
153 |
+
return triposr_pipeline
|
154 |
|
155 |
try:
|
156 |
model_loading = True
|
|
|
176 |
else:
|
177 |
raise
|
178 |
|
179 |
+
# Initialize TripoSR pipeline
|
180 |
+
triposr_pipeline = TripoSRPipeline(
|
181 |
+
model_path=os.path.join(CACHE_DIR, "stabilityai/TripoSR"),
|
182 |
+
device='cpu'
|
183 |
+
)
|
184 |
|
185 |
model_loaded = True
|
186 |
+
print("TripoSR pipeline loaded successfully on CPU")
|
187 |
+
return triposr_pipeline
|
188 |
|
189 |
except Exception as e:
|
190 |
print(f"Error loading model: {str(e)}")
|
|
|
304 |
clean_image = remove_background(image)
|
305 |
processing_jobs[job_id]['progress'] = 30
|
306 |
|
307 |
+
# Load TripoSR pipeline
|
308 |
try:
|
309 |
+
pipeline = load_model()
|
310 |
processing_jobs[job_id]['progress'] = 40
|
311 |
except Exception as e:
|
312 |
processing_jobs[job_id]['status'] = 'error'
|
|
|
316 |
# Generate 3D model
|
317 |
try:
|
318 |
def generate_3d():
|
319 |
+
# TripoSR pipeline expects a PIL image
|
320 |
+
mesh = pipeline(clean_image)
|
|
|
|
|
321 |
return mesh
|
322 |
|
323 |
mesh, error = process_with_timeout(generate_3d, [], TIMEOUT_SECONDS)
|