Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -11,11 +11,11 @@ import io
|
|
11 |
import zipfile
|
12 |
import uuid
|
13 |
import traceback
|
14 |
-
from
|
15 |
-
from diffusers.utils import export_to_obj
|
16 |
from huggingface_hub import snapshot_download
|
17 |
from flask_cors import CORS
|
18 |
-
import
|
|
|
19 |
|
20 |
app = Flask(__name__)
|
21 |
CORS(app) # Enable CORS for all routes
|
@@ -42,15 +42,15 @@ app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max
|
|
42 |
# Job tracking dictionary
|
43 |
processing_jobs = {}
|
44 |
|
45 |
-
# Global model
|
46 |
-
|
|
|
47 |
model_loaded = False
|
48 |
model_loading = False
|
49 |
|
50 |
# Configuration for processing
|
51 |
-
TIMEOUT_SECONDS =
|
52 |
MAX_DIMENSION = 512 # Max image dimension to process
|
53 |
-
MAX_INFERENCE_STEPS = 64 # Maximum allowed inference steps to prevent the index error
|
54 |
|
55 |
# TimeoutError for handling timeouts
|
56 |
class TimeoutError(Exception):
|
@@ -89,7 +89,7 @@ def process_with_timeout(function, args, timeout):
|
|
89 |
def allowed_file(filename):
|
90 |
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
|
91 |
|
92 |
-
# Function to preprocess image
|
93 |
def preprocess_image(image_path):
|
94 |
with Image.open(image_path) as img:
|
95 |
img = img.convert("RGB")
|
@@ -104,26 +104,26 @@ def preprocess_image(image_path):
|
|
104 |
new_width = int(img.width * (MAX_DIMENSION / img.height))
|
105 |
img = img.resize((new_width, new_height), Image.LANCZOS)
|
106 |
|
107 |
-
# Convert to RGB and return
|
108 |
return img
|
109 |
|
110 |
def load_model():
|
111 |
-
global
|
112 |
|
113 |
if model_loaded:
|
114 |
-
return
|
115 |
|
116 |
if model_loading:
|
117 |
# Wait for model to load if it's already in progress
|
118 |
while model_loading and not model_loaded:
|
119 |
time.sleep(0.5)
|
120 |
-
return
|
121 |
|
122 |
try:
|
123 |
model_loading = True
|
124 |
print("Starting model loading...")
|
125 |
|
126 |
-
|
|
|
127 |
|
128 |
# Download model with retry mechanism
|
129 |
max_retries = 3
|
@@ -145,24 +145,26 @@ def load_model():
|
|
145 |
else:
|
146 |
raise
|
147 |
|
148 |
-
# Initialize
|
149 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
150 |
dtype = torch.float16 if device == "cuda" else torch.float32
|
151 |
|
152 |
-
|
|
|
153 |
model_name,
|
154 |
torch_dtype=dtype,
|
155 |
cache_dir=CACHE_DIR,
|
|
|
156 |
)
|
157 |
-
|
158 |
|
159 |
# Optimize for inference
|
160 |
if device == "cuda":
|
161 |
-
|
162 |
|
163 |
model_loaded = True
|
164 |
print(f"Model loaded successfully on {device}")
|
165 |
-
return
|
166 |
|
167 |
except Exception as e:
|
168 |
print(f"Error loading model: {str(e)}")
|
@@ -171,11 +173,89 @@ def load_model():
|
|
171 |
finally:
|
172 |
model_loading = False
|
173 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
@app.route('/health', methods=['GET'])
|
175 |
def health_check():
|
176 |
return jsonify({
|
177 |
"status": "healthy",
|
178 |
-
"model": "
|
179 |
"device": "cuda" if torch.cuda.is_available() else "cpu"
|
180 |
}), 200
|
181 |
|
@@ -234,7 +314,6 @@ def convert_image_to_3d():
|
|
234 |
# Get optional parameters with defaults
|
235 |
try:
|
236 |
guidance_scale = float(request.form.get('guidance_scale', 3.0))
|
237 |
-
num_inference_steps = min(int(request.form.get('num_inference_steps', 64)), MAX_INFERENCE_STEPS)
|
238 |
output_format = request.form.get('output_format', 'obj').lower()
|
239 |
except ValueError:
|
240 |
return jsonify({"error": "Invalid parameter values"}), 400
|
@@ -243,9 +322,6 @@ def convert_image_to_3d():
|
|
243 |
if guidance_scale < 1.0 or guidance_scale > 5.0:
|
244 |
return jsonify({"error": "Guidance scale must be between 1.0 and 5.0"}), 400
|
245 |
|
246 |
-
if num_inference_steps < 32 or num_inference_steps > MAX_INFERENCE_STEPS:
|
247 |
-
num_inference_steps = min(num_inference_steps, MAX_INFERENCE_STEPS)
|
248 |
-
|
249 |
# Validate output format
|
250 |
if output_format not in ['obj', 'glb']:
|
251 |
return jsonify({"error": "Unsupported output format. Use 'obj' or 'glb'"}), 400
|
@@ -277,14 +353,14 @@ def convert_image_to_3d():
|
|
277 |
processing_jobs[job_id]['thread_alive'] = lambda: thread.is_alive()
|
278 |
|
279 |
try:
|
280 |
-
# Preprocess image
|
281 |
processing_jobs[job_id]['progress'] = 5
|
282 |
image = preprocess_image(filepath)
|
283 |
processing_jobs[job_id]['progress'] = 10
|
284 |
|
285 |
# Load model
|
286 |
try:
|
287 |
-
|
288 |
processing_jobs[job_id]['progress'] = 30
|
289 |
except Exception as e:
|
290 |
processing_jobs[job_id]['status'] = 'error'
|
@@ -293,15 +369,18 @@ def convert_image_to_3d():
|
|
293 |
|
294 |
# Process image with thread-safe timeout
|
295 |
try:
|
296 |
-
def
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
)
|
|
|
|
|
|
|
303 |
|
304 |
-
|
305 |
|
306 |
if error:
|
307 |
if isinstance(error, TimeoutError):
|
@@ -312,6 +391,10 @@ def convert_image_to_3d():
|
|
312 |
raise error
|
313 |
|
314 |
processing_jobs[job_id]['progress'] = 80
|
|
|
|
|
|
|
|
|
315 |
except Exception as e:
|
316 |
error_details = traceback.format_exc()
|
317 |
processing_jobs[job_id]['status'] = 'error'
|
@@ -324,7 +407,7 @@ def convert_image_to_3d():
|
|
324 |
try:
|
325 |
if output_format == 'obj':
|
326 |
obj_path = os.path.join(output_dir, "model.obj")
|
327 |
-
|
328 |
|
329 |
# Create a zip file with OBJ and MTL
|
330 |
zip_path = os.path.join(output_dir, "model.zip")
|
@@ -338,17 +421,9 @@ def convert_image_to_3d():
|
|
338 |
processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"
|
339 |
|
340 |
elif output_format == 'glb':
|
341 |
-
from trimesh import Trimesh
|
342 |
-
mesh = images[0]
|
343 |
-
vertices = mesh.verts
|
344 |
-
faces = mesh.faces
|
345 |
-
|
346 |
-
# Create a trimesh object
|
347 |
-
trimesh_obj = Trimesh(vertices=vertices, faces=faces)
|
348 |
-
|
349 |
# Export as GLB
|
350 |
glb_path = os.path.join(output_dir, "model.glb")
|
351 |
-
|
352 |
|
353 |
processing_jobs[job_id]['result_url'] = f"/download/{job_id}"
|
354 |
processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"
|
|
|
11 |
import zipfile
|
12 |
import uuid
|
13 |
import traceback
|
14 |
+
from transformers import AutoImageProcessor, AutoModel
|
|
|
15 |
from huggingface_hub import snapshot_download
|
16 |
from flask_cors import CORS
|
17 |
+
import numpy as np
|
18 |
+
import trimesh
|
19 |
|
20 |
app = Flask(__name__)
|
21 |
CORS(app) # Enable CORS for all routes
|
|
|
42 |
# Job tracking dictionary
|
43 |
processing_jobs = {}
|
44 |
|
45 |
+
# Global model variables
|
46 |
+
image_processor = None
|
47 |
+
model = None
|
48 |
model_loaded = False
|
49 |
model_loading = False
|
50 |
|
51 |
# Configuration for processing
|
52 |
+
TIMEOUT_SECONDS = 180 # 3 minutes max for processing
|
53 |
MAX_DIMENSION = 512 # Max image dimension to process
|
|
|
54 |
|
55 |
# TimeoutError for handling timeouts
|
56 |
class TimeoutError(Exception):
|
|
|
89 |
def allowed_file(filename):
|
90 |
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
|
91 |
|
92 |
+
# Function to preprocess image
|
93 |
def preprocess_image(image_path):
|
94 |
with Image.open(image_path) as img:
|
95 |
img = img.convert("RGB")
|
|
|
104 |
new_width = int(img.width * (MAX_DIMENSION / img.height))
|
105 |
img = img.resize((new_width, new_height), Image.LANCZOS)
|
106 |
|
|
|
107 |
return img
|
108 |
|
109 |
def load_model():
|
110 |
+
global image_processor, model, model_loaded, model_loading
|
111 |
|
112 |
if model_loaded:
|
113 |
+
return image_processor, model
|
114 |
|
115 |
if model_loading:
|
116 |
# Wait for model to load if it's already in progress
|
117 |
while model_loading and not model_loaded:
|
118 |
time.sleep(0.5)
|
119 |
+
return image_processor, model
|
120 |
|
121 |
try:
|
122 |
model_loading = True
|
123 |
print("Starting model loading...")
|
124 |
|
125 |
+
# Using a lightweight model: Pictorial 3D Scene Representation
|
126 |
+
model_name = "damo-vilab/text-to-3d-texture-base" # Smaller model than ShapE-img2img
|
127 |
|
128 |
# Download model with retry mechanism
|
129 |
max_retries = 3
|
|
|
145 |
else:
|
146 |
raise
|
147 |
|
148 |
+
# Initialize model with lower precision to save memory
|
149 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
150 |
dtype = torch.float16 if device == "cuda" else torch.float32
|
151 |
|
152 |
+
image_processor = AutoImageProcessor.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
153 |
+
model = AutoModel.from_pretrained(
|
154 |
model_name,
|
155 |
torch_dtype=dtype,
|
156 |
cache_dir=CACHE_DIR,
|
157 |
+
low_cpu_mem_usage=True,
|
158 |
)
|
159 |
+
model = model.to(device)
|
160 |
|
161 |
# Optimize for inference
|
162 |
if device == "cuda":
|
163 |
+
model = model.half() # Use half precision on GPU
|
164 |
|
165 |
model_loaded = True
|
166 |
print(f"Model loaded successfully on {device}")
|
167 |
+
return image_processor, model
|
168 |
|
169 |
except Exception as e:
|
170 |
print(f"Error loading model: {str(e)}")
|
|
|
173 |
finally:
|
174 |
model_loading = False
|
175 |
|
176 |
+
# Convert model output to 3D mesh
|
177 |
+
def create_mesh_from_output(output, resolution=64):
|
178 |
+
"""Create a mesh from model output"""
|
179 |
+
# Extract features from model output and create mesh
|
180 |
+
# This is a simplified implementation - adapt based on your specific model
|
181 |
+
features = output.last_hidden_state.detach().cpu().numpy()[0]
|
182 |
+
|
183 |
+
# Create a simple cube mesh as placeholder - replace with actual mesh generation
|
184 |
+
vertices, faces = create_primitive_mesh(features, resolution)
|
185 |
+
|
186 |
+
mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
|
187 |
+
return mesh
|
188 |
+
|
189 |
+
def create_primitive_mesh(features, resolution=64):
|
190 |
+
"""Create a simple primitive mesh based on features"""
|
191 |
+
# Create a mesh using features as modifiers
|
192 |
+
# This is a simplified implementation - adapt based on your specific model's output
|
193 |
+
|
194 |
+
# Create a cube/sphere mesh as a placeholder
|
195 |
+
use_sphere = True # Change to False for cube
|
196 |
+
|
197 |
+
if use_sphere:
|
198 |
+
# Create a sphere
|
199 |
+
u = np.linspace(0, 2 * np.pi, resolution)
|
200 |
+
v = np.linspace(0, np.pi, resolution)
|
201 |
+
|
202 |
+
# Base radius and modifiers
|
203 |
+
base_radius = 1.0
|
204 |
+
|
205 |
+
# Use some features to modify the radius (just as an example)
|
206 |
+
feature_sum = np.sum(features[:10]) # Use first 10 features
|
207 |
+
radius_mod = 0.5 + (feature_sum % 1.0) # Simple modifier between 0.5 and 1.5
|
208 |
+
|
209 |
+
# Create vertices
|
210 |
+
x = base_radius * radius_mod * np.outer(np.cos(u), np.sin(v))
|
211 |
+
y = base_radius * radius_mod * np.outer(np.sin(u), np.sin(v))
|
212 |
+
z = base_radius * radius_mod * np.outer(np.ones_like(u), np.cos(v))
|
213 |
+
|
214 |
+
# Reshape to get list of vertices
|
215 |
+
vertices = np.vstack([x.flatten(), y.flatten(), z.flatten()]).T
|
216 |
+
|
217 |
+
# Create faces (triangles)
|
218 |
+
faces = []
|
219 |
+
for i in range(resolution-1):
|
220 |
+
for j in range(resolution-1):
|
221 |
+
p1 = i * resolution + j
|
222 |
+
p2 = i * resolution + (j + 1)
|
223 |
+
p3 = (i + 1) * resolution + j
|
224 |
+
p4 = (i + 1) * resolution + (j + 1)
|
225 |
+
|
226 |
+
faces.append([p1, p2, p4])
|
227 |
+
faces.append([p1, p4, p3])
|
228 |
+
|
229 |
+
faces = np.array(faces)
|
230 |
+
else:
|
231 |
+
# Create a cube
|
232 |
+
vertices = np.array([
|
233 |
+
[-1, -1, -1], [1, -1, -1], [1, 1, -1], [-1, 1, -1],
|
234 |
+
[-1, -1, 1], [1, -1, 1], [1, 1, 1], [-1, 1, 1]
|
235 |
+
])
|
236 |
+
|
237 |
+
# Apply some feature-based modifications
|
238 |
+
feature_sum = np.sum(features[:10]) # Use first 10 features
|
239 |
+
scale_factor = 0.5 + (feature_sum % 1.0) # Simple modifier between 0.5 and 1.5
|
240 |
+
vertices *= scale_factor
|
241 |
+
|
242 |
+
# Faces (triangles)
|
243 |
+
faces = np.array([
|
244 |
+
[0, 1, 2], [0, 2, 3], # Bottom face
|
245 |
+
[4, 5, 6], [4, 6, 7], # Top face
|
246 |
+
[0, 1, 5], [0, 5, 4], # Front face
|
247 |
+
[2, 3, 7], [2, 7, 6], # Back face
|
248 |
+
[0, 3, 7], [0, 7, 4], # Left face
|
249 |
+
[1, 2, 6], [1, 6, 5] # Right face
|
250 |
+
])
|
251 |
+
|
252 |
+
return vertices, faces
|
253 |
+
|
254 |
@app.route('/health', methods=['GET'])
|
255 |
def health_check():
|
256 |
return jsonify({
|
257 |
"status": "healthy",
|
258 |
+
"model": "Lightweight 3D Model Generator",
|
259 |
"device": "cuda" if torch.cuda.is_available() else "cpu"
|
260 |
}), 200
|
261 |
|
|
|
314 |
# Get optional parameters with defaults
|
315 |
try:
|
316 |
guidance_scale = float(request.form.get('guidance_scale', 3.0))
|
|
|
317 |
output_format = request.form.get('output_format', 'obj').lower()
|
318 |
except ValueError:
|
319 |
return jsonify({"error": "Invalid parameter values"}), 400
|
|
|
322 |
if guidance_scale < 1.0 or guidance_scale > 5.0:
|
323 |
return jsonify({"error": "Guidance scale must be between 1.0 and 5.0"}), 400
|
324 |
|
|
|
|
|
|
|
325 |
# Validate output format
|
326 |
if output_format not in ['obj', 'glb']:
|
327 |
return jsonify({"error": "Unsupported output format. Use 'obj' or 'glb'"}), 400
|
|
|
353 |
processing_jobs[job_id]['thread_alive'] = lambda: thread.is_alive()
|
354 |
|
355 |
try:
|
356 |
+
# Preprocess image
|
357 |
processing_jobs[job_id]['progress'] = 5
|
358 |
image = preprocess_image(filepath)
|
359 |
processing_jobs[job_id]['progress'] = 10
|
360 |
|
361 |
# Load model
|
362 |
try:
|
363 |
+
processor, model_instance = load_model()
|
364 |
processing_jobs[job_id]['progress'] = 30
|
365 |
except Exception as e:
|
366 |
processing_jobs[job_id]['status'] = 'error'
|
|
|
369 |
|
370 |
# Process image with thread-safe timeout
|
371 |
try:
|
372 |
+
def generate_3d():
|
373 |
+
# Process the image
|
374 |
+
device = model_instance.device
|
375 |
+
inputs = processor(images=image, return_tensors="pt").to(device)
|
376 |
+
|
377 |
+
# Forward pass through model
|
378 |
+
with torch.no_grad():
|
379 |
+
outputs = model_instance(**inputs)
|
380 |
+
|
381 |
+
return outputs
|
382 |
|
383 |
+
outputs, error = process_with_timeout(generate_3d, [], TIMEOUT_SECONDS)
|
384 |
|
385 |
if error:
|
386 |
if isinstance(error, TimeoutError):
|
|
|
391 |
raise error
|
392 |
|
393 |
processing_jobs[job_id]['progress'] = 80
|
394 |
+
|
395 |
+
# Create mesh from outputs
|
396 |
+
mesh = create_mesh_from_output(outputs)
|
397 |
+
|
398 |
except Exception as e:
|
399 |
error_details = traceback.format_exc()
|
400 |
processing_jobs[job_id]['status'] = 'error'
|
|
|
407 |
try:
|
408 |
if output_format == 'obj':
|
409 |
obj_path = os.path.join(output_dir, "model.obj")
|
410 |
+
mesh.export(obj_path, file_type='obj')
|
411 |
|
412 |
# Create a zip file with OBJ and MTL
|
413 |
zip_path = os.path.join(output_dir, "model.zip")
|
|
|
421 |
processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"
|
422 |
|
423 |
elif output_format == 'glb':
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
424 |
# Export as GLB
|
425 |
glb_path = os.path.join(output_dir, "model.glb")
|
426 |
+
mesh.export(glb_path, file_type='glb')
|
427 |
|
428 |
processing_jobs[job_id]['result_url'] = f"/download/{job_id}"
|
429 |
processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"
|