Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
|
2 |
import os
|
3 |
import torch
|
4 |
import time
|
@@ -15,15 +14,12 @@ from huggingface_hub import snapshot_download
|
|
15 |
from flask_cors import CORS
|
16 |
import numpy as np
|
17 |
import trimesh
|
18 |
-
from
|
19 |
-
from scipy.ndimage import gaussian_filter
|
20 |
-
import open3d as o3d
|
21 |
import cv2
|
22 |
|
23 |
# Force CPU usage
|
24 |
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
25 |
torch.set_default_device("cpu")
|
26 |
-
# Patch PyTorch to disable CUDA initialization
|
27 |
torch.cuda.is_available = lambda: False
|
28 |
torch.cuda.device_count = lambda: 0
|
29 |
|
@@ -36,12 +32,12 @@ RESULTS_FOLDER = '/tmp/results'
|
|
36 |
CACHE_DIR = '/tmp/huggingface'
|
37 |
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
|
38 |
|
39 |
-
# Create
|
40 |
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
41 |
os.makedirs(RESULTS_FOLDER, exist_ok=True)
|
42 |
os.makedirs(CACHE_DIR, exist_ok=True)
|
43 |
|
44 |
-
# Set Hugging Face cache
|
45 |
os.environ['HF_HOME'] = CACHE_DIR
|
46 |
os.environ['TRANSFORMERS_CACHE'] = os.path.join(CACHE_DIR, 'transformers')
|
47 |
os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_DIR, 'datasets')
|
@@ -49,23 +45,21 @@ os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_DIR, 'datasets')
|
|
49 |
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
|
50 |
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
|
51 |
|
52 |
-
# Job tracking
|
53 |
processing_jobs = {}
|
54 |
|
55 |
-
# Global model
|
56 |
-
|
57 |
model_loaded = False
|
58 |
model_loading = False
|
59 |
|
60 |
-
# Configuration
|
61 |
-
TIMEOUT_SECONDS =
|
62 |
-
MAX_DIMENSION =
|
63 |
|
64 |
-
# TimeoutError for handling timeouts
|
65 |
class TimeoutError(Exception):
|
66 |
pass
|
67 |
|
68 |
-
# Thread-safe timeout implementation
|
69 |
def process_with_timeout(function, args, timeout):
|
70 |
result = [None]
|
71 |
error = [None]
|
@@ -81,7 +75,6 @@ def process_with_timeout(function, args, timeout):
|
|
81 |
thread = threading.Thread(target=target)
|
82 |
thread.daemon = True
|
83 |
thread.start()
|
84 |
-
|
85 |
thread.join(timeout)
|
86 |
|
87 |
if not completed[0]:
|
@@ -98,76 +91,44 @@ def process_with_timeout(function, args, timeout):
|
|
98 |
def allowed_file(filename):
|
99 |
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
|
100 |
|
101 |
-
# Image preprocessing
|
102 |
def preprocess_image(image_path):
|
103 |
try:
|
104 |
-
# Load image
|
105 |
with Image.open(image_path) as img:
|
106 |
-
# Convert to RGB
|
107 |
if img.mode == 'RGBA':
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
img_rgb = img_array[:, :, :3]
|
112 |
-
else:
|
113 |
-
img_rgb = np.array(img.convert('RGB'))
|
114 |
-
alpha = None
|
115 |
-
|
116 |
-
# Resize to 512x512
|
117 |
-
img_rgb = cv2.resize(img_rgb, (512, 512), interpolation=cv2.INTER_LANCZOS4)
|
118 |
-
|
119 |
-
# Convert to grayscale
|
120 |
-
gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
|
121 |
-
|
122 |
-
# Adaptive thresholding for initial mask
|
123 |
-
thresh = cv2.adaptiveThreshold(
|
124 |
-
gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2
|
125 |
-
)
|
126 |
-
|
127 |
-
# If alpha channel exists, combine with threshold
|
128 |
-
if alpha is not None:
|
129 |
-
alpha_resized = cv2.resize(alpha, (512, 512), interpolation=cv2.INTER_LANCZOS4)
|
130 |
-
thresh = cv2.bitwise_and(thresh, alpha_resized)
|
131 |
-
|
132 |
-
# Refine with GrabCut
|
133 |
-
mask = np.zeros((512, 512), np.uint8)
|
134 |
-
mask[thresh == 255] = cv2.GC_PR_FGD # Probable foreground
|
135 |
-
mask[thresh == 0] = cv2.GC_PR_BGD # Probable background
|
136 |
-
|
137 |
-
bgdModel = np.zeros((1, 65), np.float64)
|
138 |
-
fgdModel = np.zeros((1, 65), np.float64)
|
139 |
-
|
140 |
-
rect = (10, 10, 492, 492) # ROI for GrabCut
|
141 |
-
cv2.grabCut(img_rgb, mask, rect, bgdModel, fgdModel, 5, cv2.GC_INIT_WITH_MASK)
|
142 |
-
|
143 |
-
# Create final mask (foreground = 1, background = 0)
|
144 |
-
mask2 = np.where((mask == cv2.GC_FGD) | (mask == cv2.GC_PR_FGD), 255, 0).astype('uint8')
|
145 |
|
146 |
-
#
|
147 |
-
|
|
|
|
|
|
|
148 |
|
149 |
-
return Image.fromarray(
|
150 |
except Exception as e:
|
151 |
raise Exception(f"Error preprocessing image: {str(e)}")
|
152 |
|
153 |
def load_model():
|
154 |
-
global
|
155 |
|
156 |
if model_loaded:
|
157 |
-
return
|
158 |
|
159 |
if model_loading:
|
160 |
while model_loading and not model_loaded:
|
161 |
time.sleep(0.5)
|
162 |
-
return
|
163 |
|
164 |
try:
|
165 |
model_loading = True
|
166 |
-
print("
|
167 |
|
168 |
-
model_name = "
|
169 |
|
170 |
-
# Download model
|
171 |
max_retries = 3
|
172 |
retry_delay = 5
|
173 |
for attempt in range(max_retries):
|
@@ -180,24 +141,23 @@ def load_model():
|
|
180 |
break
|
181 |
except Exception as e:
|
182 |
if attempt < max_retries - 1:
|
183 |
-
print(f"Download attempt {attempt+1} failed: {str(e)}. Retrying
|
184 |
time.sleep(retry_delay)
|
185 |
retry_delay *= 2
|
186 |
else:
|
187 |
raise
|
188 |
|
189 |
-
# Load
|
190 |
-
|
191 |
-
|
192 |
-
model=model_name,
|
193 |
cache_dir=CACHE_DIR,
|
194 |
-
device=-1, # Force CPU
|
195 |
torch_dtype=torch.float32,
|
196 |
)
|
|
|
197 |
|
198 |
model_loaded = True
|
199 |
-
print("
|
200 |
-
return
|
201 |
|
202 |
except Exception as e:
|
203 |
print(f"Error loading model: {str(e)}")
|
@@ -206,79 +166,42 @@ def load_model():
|
|
206 |
finally:
|
207 |
model_loading = False
|
208 |
|
209 |
-
def
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
points = np.stack([x, y, z], axis=-1).reshape(-1, 3)
|
241 |
-
colors = img_array.reshape(-1, 3) / 255.0
|
242 |
-
|
243 |
-
# Filter out invalid points (tighter range for foreground)
|
244 |
-
mask = (z.reshape(-1) > 0.2) & (z.reshape(-1) < 0.8)
|
245 |
-
points = points[mask]
|
246 |
-
colors = colors[mask]
|
247 |
-
|
248 |
-
# Create Open3D point cloud
|
249 |
-
pcd = o3d.geometry.PointCloud()
|
250 |
-
pcd.points = o3d.utility.Vector3dVector(points)
|
251 |
-
pcd.colors = o3d.utility.Vector3dVector(colors)
|
252 |
-
|
253 |
-
# Estimate normals
|
254 |
-
pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.1, max_nn=30))
|
255 |
-
|
256 |
-
# Poisson surface reconstruction
|
257 |
-
mesh, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
|
258 |
-
pcd, depth=8 if detail_level == 'high' else 6
|
259 |
-
)
|
260 |
-
|
261 |
-
# Convert to trimesh
|
262 |
-
vertices = np.asarray(mesh.vertices)
|
263 |
-
faces = np.asarray(mesh.triangles)
|
264 |
-
vertex_colors = np.asarray(mesh.vertex_colors)
|
265 |
-
|
266 |
-
trimesh_mesh = trimesh.Trimesh(
|
267 |
-
vertices=vertices,
|
268 |
-
faces=faces,
|
269 |
-
vertex_colors=vertex_colors
|
270 |
-
)
|
271 |
-
|
272 |
-
# Rotate mesh to correct orientation (180 degrees around X-axis)
|
273 |
-
trimesh_mesh.apply_transform(trimesh.transformations.rotation_matrix(np.pi, [1, 0, 0]))
|
274 |
-
|
275 |
-
return trimesh_mesh
|
276 |
|
277 |
@app.route('/health', methods=['GET'])
|
278 |
def health_check():
|
279 |
return jsonify({
|
280 |
"status": "healthy",
|
281 |
-
"model": "
|
282 |
"device": "cpu"
|
283 |
}), 200
|
284 |
|
@@ -375,15 +298,7 @@ def convert_image_to_3d():
|
|
375 |
|
376 |
try:
|
377 |
def generate_3d():
|
378 |
-
|
379 |
-
with torch.no_grad():
|
380 |
-
depth_output = pipeline(image)
|
381 |
-
|
382 |
-
depth_map = depth_output["depth"]
|
383 |
-
|
384 |
-
# Convert depth to mesh
|
385 |
-
mesh = depth_to_point_cloud(depth_map, image, detail_level)
|
386 |
-
return mesh
|
387 |
|
388 |
mesh, error = process_with_timeout(generate_3d, [], TIMEOUT_SECONDS)
|
389 |
|
@@ -397,7 +312,7 @@ def convert_image_to_3d():
|
|
397 |
|
398 |
processing_jobs[job_id]['progress'] = 80
|
399 |
|
400 |
-
# Export
|
401 |
file_path = os.path.join(output_dir, f"model.{output_format}")
|
402 |
mesh.export(file_path, file_type=output_format)
|
403 |
|
@@ -406,7 +321,7 @@ def convert_image_to_3d():
|
|
406 |
|
407 |
processing_jobs[job_id]['status'] = 'completed'
|
408 |
processing_jobs[job_id]['progress'] = 100
|
409 |
-
print(f"Job {job_id} completed
|
410 |
|
411 |
except Exception as e:
|
412 |
error_details = traceback.format_exc()
|
@@ -527,7 +442,7 @@ def model_info(job_id):
|
|
527 |
@app.route('/', methods=['GET'])
|
528 |
def index():
|
529 |
return jsonify({
|
530 |
-
"message": "Image to 3D API (
|
531 |
"endpoints": [
|
532 |
"/convert",
|
533 |
"/progress/<job_id>",
|
@@ -537,12 +452,12 @@ def index():
|
|
537 |
],
|
538 |
"parameters": {
|
539 |
"output_format": "glb or obj",
|
540 |
-
"detail_level": "low, medium, or high - controls
|
541 |
},
|
542 |
-
"description": "
|
543 |
}), 200
|
544 |
|
545 |
if __name__ == '__main__':
|
546 |
cleanup_old_jobs()
|
547 |
port = int(os.environ.get('PORT', 7860))
|
548 |
-
app.run(host='0.0.0.0', port=port)
|
|
|
|
|
1 |
import os
|
2 |
import torch
|
3 |
import time
|
|
|
14 |
from flask_cors import CORS
|
15 |
import numpy as np
|
16 |
import trimesh
|
17 |
+
from diffusers import DiffusionPipeline
|
|
|
|
|
18 |
import cv2
|
19 |
|
20 |
# Force CPU usage
|
21 |
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
22 |
torch.set_default_device("cpu")
|
|
|
23 |
torch.cuda.is_available = lambda: False
|
24 |
torch.cuda.device_count = lambda: 0
|
25 |
|
|
|
32 |
CACHE_DIR = '/tmp/huggingface'
|
33 |
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
|
34 |
|
35 |
+
# Create directories
|
36 |
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
37 |
os.makedirs(RESULTS_FOLDER, exist_ok=True)
|
38 |
os.makedirs(CACHE_DIR, exist_ok=True)
|
39 |
|
40 |
+
# Set Hugging Face cache
|
41 |
os.environ['HF_HOME'] = CACHE_DIR
|
42 |
os.environ['TRANSFORMERS_CACHE'] = os.path.join(CACHE_DIR, 'transformers')
|
43 |
os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_DIR, 'datasets')
|
|
|
45 |
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
|
46 |
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
|
47 |
|
48 |
+
# Job tracking
|
49 |
processing_jobs = {}
|
50 |
|
51 |
+
# Global model
|
52 |
+
tripo_pipeline = None
|
53 |
model_loaded = False
|
54 |
model_loading = False
|
55 |
|
56 |
+
# Configuration
|
57 |
+
TIMEOUT_SECONDS = 300 # 5 minutes for TripoSG
|
58 |
+
MAX_DIMENSION = 256 # TripoSG works with smaller images
|
59 |
|
|
|
60 |
class TimeoutError(Exception):
|
61 |
pass
|
62 |
|
|
|
63 |
def process_with_timeout(function, args, timeout):
|
64 |
result = [None]
|
65 |
error = [None]
|
|
|
75 |
thread = threading.Thread(target=target)
|
76 |
thread.daemon = True
|
77 |
thread.start()
|
|
|
78 |
thread.join(timeout)
|
79 |
|
80 |
if not completed[0]:
|
|
|
91 |
def allowed_file(filename):
|
92 |
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
|
93 |
|
94 |
+
# Image preprocessing
|
95 |
def preprocess_image(image_path):
|
96 |
try:
|
|
|
97 |
with Image.open(image_path) as img:
|
98 |
+
# Convert to RGB
|
99 |
if img.mode == 'RGBA':
|
100 |
+
img = img.convert('RGB')
|
101 |
+
# Resize to 256x256
|
102 |
+
img = img.resize((256, 256), Image.LANCZOS)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
|
104 |
+
# Basic cv2 cleanup
|
105 |
+
img_array = np.array(img)
|
106 |
+
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
|
107 |
+
_, mask = cv2.threshold(gray, 10, 255, cv2.THRESH_BINARY)
|
108 |
+
img_array = cv2.bitwise_and(img_array, img_array, mask=mask)
|
109 |
|
110 |
+
return Image.fromarray(img_array)
|
111 |
except Exception as e:
|
112 |
raise Exception(f"Error preprocessing image: {str(e)}")
|
113 |
|
114 |
def load_model():
|
115 |
+
global tripo_pipeline, model_loaded, model_loading
|
116 |
|
117 |
if model_loaded:
|
118 |
+
return tripo_pipeline
|
119 |
|
120 |
if model_loading:
|
121 |
while model_loading and not model_loaded:
|
122 |
time.sleep(0.5)
|
123 |
+
return tripo_pipeline
|
124 |
|
125 |
try:
|
126 |
model_loading = True
|
127 |
+
print("Loading TripoSG model...")
|
128 |
|
129 |
+
model_name = "tripo3d/tripo-sg-3d"
|
130 |
|
131 |
+
# Download model
|
132 |
max_retries = 3
|
133 |
retry_delay = 5
|
134 |
for attempt in range(max_retries):
|
|
|
141 |
break
|
142 |
except Exception as e:
|
143 |
if attempt < max_retries - 1:
|
144 |
+
print(f"Download attempt {attempt+1} failed: {str(e)}. Retrying...")
|
145 |
time.sleep(retry_delay)
|
146 |
retry_delay *= 2
|
147 |
else:
|
148 |
raise
|
149 |
|
150 |
+
# Load TripoSG pipeline
|
151 |
+
tripo_pipeline = DiffusionPipeline.from_pretrained(
|
152 |
+
model_name,
|
|
|
153 |
cache_dir=CACHE_DIR,
|
|
|
154 |
torch_dtype=torch.float32,
|
155 |
)
|
156 |
+
tripo_pipeline.to("cpu")
|
157 |
|
158 |
model_loaded = True
|
159 |
+
print("TripoSG loaded successfully on CPU")
|
160 |
+
return tripo_pipeline
|
161 |
|
162 |
except Exception as e:
|
163 |
print(f"Error loading model: {str(e)}")
|
|
|
166 |
finally:
|
167 |
model_loading = False
|
168 |
|
169 |
+
def generate_3d_model(image, detail_level):
|
170 |
+
try:
|
171 |
+
# Parameters
|
172 |
+
num_steps = {'low': 20, 'medium': 30, 'high': 40}
|
173 |
+
steps = num_steps[detail_level]
|
174 |
+
|
175 |
+
# Generate 3D model
|
176 |
+
with torch.no_grad():
|
177 |
+
result = tripo_pipeline(image, num_inference_steps=steps)
|
178 |
+
|
179 |
+
# Extract mesh
|
180 |
+
mesh = result.meshes[0]
|
181 |
+
|
182 |
+
# Convert to trimesh
|
183 |
+
vertices = np.array(mesh.vertices)
|
184 |
+
faces = np.array(mesh.faces)
|
185 |
+
vertex_colors = np.array(mesh.vertex_colors) if mesh.vertex_colors is not None else None
|
186 |
+
|
187 |
+
trimesh_mesh = trimesh.Trimesh(
|
188 |
+
vertices=vertices,
|
189 |
+
faces=faces,
|
190 |
+
vertex_colors=vertex_colors
|
191 |
+
)
|
192 |
+
|
193 |
+
# Rotate for Unity Z-up
|
194 |
+
trimesh_mesh.apply_transform(trimesh.transformations.rotation_matrix(np.pi, [1, 0, 0]))
|
195 |
+
|
196 |
+
return trimesh_mesh
|
197 |
+
except Exception as e:
|
198 |
+
raise Exception(f"Error generating 3D model: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
199 |
|
200 |
@app.route('/health', methods=['GET'])
|
201 |
def health_check():
|
202 |
return jsonify({
|
203 |
"status": "healthy",
|
204 |
+
"model": "TripoSG",
|
205 |
"device": "cpu"
|
206 |
}), 200
|
207 |
|
|
|
298 |
|
299 |
try:
|
300 |
def generate_3d():
|
301 |
+
return generate_3d_model(image, detail_level)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
302 |
|
303 |
mesh, error = process_with_timeout(generate_3d, [], TIMEOUT_SECONDS)
|
304 |
|
|
|
312 |
|
313 |
processing_jobs[job_id]['progress'] = 80
|
314 |
|
315 |
+
# Export
|
316 |
file_path = os.path.join(output_dir, f"model.{output_format}")
|
317 |
mesh.export(file_path, file_type=output_format)
|
318 |
|
|
|
321 |
|
322 |
processing_jobs[job_id]['status'] = 'completed'
|
323 |
processing_jobs[job_id]['progress'] = 100
|
324 |
+
print(f"Job {job_id} completed")
|
325 |
|
326 |
except Exception as e:
|
327 |
error_details = traceback.format_exc()
|
|
|
442 |
@app.route('/', methods=['GET'])
|
443 |
def index():
|
444 |
return jsonify({
|
445 |
+
"message": "Image to 3D API (TripoSG)",
|
446 |
"endpoints": [
|
447 |
"/convert",
|
448 |
"/progress/<job_id>",
|
|
|
452 |
],
|
453 |
"parameters": {
|
454 |
"output_format": "glb or obj",
|
455 |
+
"detail_level": "low, medium, or high - controls inference steps"
|
456 |
},
|
457 |
+
"description": "Creates 3D models from 2D images using TripoSG. Use transparent PNGs for best results."
|
458 |
}), 200
|
459 |
|
460 |
if __name__ == '__main__':
|
461 |
cleanup_old_jobs()
|
462 |
port = int(os.environ.get('PORT', 7860))
|
463 |
+
app.run(host='0.0.0.0', port=port)
|