import os import torch import time import threading import json import gc from flask import Flask, request, jsonify, send_file, Response, stream_with_context from werkzeug.utils import secure_filename from PIL import Image import io import zipfile import uuid import traceback from diffusers import ShapEImg2ImgPipeline from diffusers.utils import export_to_obj from huggingface_hub import snapshot_download from flask_cors import CORS import signal import functools app = Flask(__name__) CORS(app) # Enable CORS for all routes # Configure directories UPLOAD_FOLDER = '/tmp/uploads' RESULTS_FOLDER = '/tmp/results' CACHE_DIR = '/tmp/huggingface' ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'} # Create necessary directories os.makedirs(UPLOAD_FOLDER, exist_ok=True) os.makedirs(RESULTS_FOLDER, exist_ok=True) os.makedirs(CACHE_DIR, exist_ok=True) # Set Hugging Face cache environment variables os.environ['HF_HOME'] = CACHE_DIR os.environ['TRANSFORMERS_CACHE'] = os.path.join(CACHE_DIR, 'transformers') os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_DIR, 'datasets') app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max # Job tracking dictionary processing_jobs = {} # Global model variable pipe = None model_loaded = False model_loading = False # Configuration for processing TIMEOUT_SECONDS = 300 # 5 minutes max for processing MAX_DIMENSION = 512 # Max image dimension to process # Timeout handler for long-running processes class TimeoutError(Exception): pass def timeout_handler(signum, frame): raise TimeoutError("Processing timed out") def with_timeout(timeout): def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): # Set the timeout handler signal.signal(signal.SIGALRM, timeout_handler) signal.alarm(timeout) try: result = func(*args, **kwargs) finally: # Disable the alarm signal.alarm(0) return result return wrapper return decorator def allowed_file(filename): return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS # Function to preprocess image - resize if needed def preprocess_image(image_path): with Image.open(image_path) as img: img = img.convert("RGB") # Resize if the image is too large if img.width > MAX_DIMENSION or img.height > MAX_DIMENSION: # Calculate new dimensions while preserving aspect ratio if img.width > img.height: new_width = MAX_DIMENSION new_height = int(img.height * (MAX_DIMENSION / img.width)) else: new_height = MAX_DIMENSION new_width = int(img.width * (MAX_DIMENSION / img.height)) img = img.resize((new_width, new_height), Image.LANCZOS) # Convert to RGB and return return img def load_model(): global pipe, model_loaded, model_loading if model_loaded: return pipe if model_loading: # Wait for model to load if it's already in progress while model_loading and not model_loaded: time.sleep(0.5) return pipe try: model_loading = True print("Starting model loading...") model_name = "openai/shap-e-img2img" # Download model with retry mechanism max_retries = 3 retry_delay = 5 for attempt in range(max_retries): try: snapshot_download( repo_id=model_name, cache_dir=CACHE_DIR, resume_download=True, ) break except Exception as e: if attempt < max_retries - 1: print(f"Download attempt {attempt+1} failed: {str(e)}. Retrying in {retry_delay} seconds...") time.sleep(retry_delay) retry_delay *= 2 else: raise # Initialize pipeline with lower precision to save memory device = "cuda" if torch.cuda.is_available() else "cpu" dtype = torch.float16 if device == "cuda" else torch.float32 pipe = ShapEImg2ImgPipeline.from_pretrained( model_name, torch_dtype=dtype, cache_dir=CACHE_DIR, ) pipe = pipe.to(device) # Optimize for inference if device == "cuda": pipe.enable_model_cpu_offload() model_loaded = True print(f"Model loaded successfully on {device}") return pipe except Exception as e: print(f"Error loading model: {str(e)}") print(traceback.format_exc()) raise finally: model_loading = False @app.route('/health', methods=['GET']) def health_check(): return jsonify({ "status": "healthy", "model": "Shap-E Image to 3D", "device": "cuda" if torch.cuda.is_available() else "cpu" }), 200 @app.route('/progress/', methods=['GET']) def progress(job_id): def generate(): if job_id not in processing_jobs: yield f"data: {json.dumps({'error': 'Job not found'})}\n\n" return job = processing_jobs[job_id] # Send initial progress yield f"data: {json.dumps({'status': 'processing', 'progress': job['progress']})}\n\n" # Wait for job to complete or update last_progress = job['progress'] check_count = 0 while job['status'] == 'processing': if job['progress'] != last_progress: yield f"data: {json.dumps({'status': 'processing', 'progress': job['progress']})}\n\n" last_progress = job['progress'] time.sleep(0.5) check_count += 1 # If client hasn't received updates for a while, check if job is still running if check_count > 60: # 30 seconds with no updates if 'thread_alive' in job and not job['thread_alive'](): job['status'] = 'error' job['error'] = 'Processing thread died unexpectedly' break check_count = 0 # Send final status if job['status'] == 'completed': yield f"data: {json.dumps({'status': 'completed', 'progress': 100, 'result_url': job['result_url'], 'preview_url': job['preview_url']})}\n\n" else: yield f"data: {json.dumps({'status': 'error', 'error': job['error']})}\n\n" return Response(stream_with_context(generate()), mimetype='text/event-stream') @app.route('/convert', methods=['POST']) def convert_image_to_3d(): # Check if image is in the request if 'image' not in request.files: return jsonify({"error": "No image provided"}), 400 file = request.files['image'] if file.filename == '': return jsonify({"error": "No image selected"}), 400 if not allowed_file(file.filename): return jsonify({"error": f"File type not allowed. Supported types: {', '.join(ALLOWED_EXTENSIONS)}"}), 400 # Get optional parameters with defaults try: guidance_scale = float(request.form.get('guidance_scale', 3.0)) num_inference_steps = int(request.form.get('num_inference_steps', 64)) output_format = request.form.get('output_format', 'obj').lower() except ValueError: return jsonify({"error": "Invalid parameter values"}), 400 # Validate parameters if guidance_scale < 1.0 or guidance_scale > 5.0: return jsonify({"error": "Guidance scale must be between 1.0 and 5.0"}), 400 if num_inference_steps < 32 or num_inference_steps > 128: return jsonify({"error": "Number of inference steps must be between 32 and 128"}), 400 # Validate output format if output_format not in ['obj', 'glb']: return jsonify({"error": "Unsupported output format. Use 'obj' or 'glb'"}), 400 # Create a job ID job_id = str(uuid.uuid4()) output_dir = os.path.join(RESULTS_FOLDER, job_id) os.makedirs(output_dir, exist_ok=True) # Save the uploaded file filename = secure_filename(file.filename) filepath = os.path.join(app.config['UPLOAD_FOLDER'], f"{job_id}_{filename}") file.save(filepath) # Initialize job tracking processing_jobs[job_id] = { 'status': 'processing', 'progress': 0, 'result_url': None, 'preview_url': None, 'error': None, 'output_format': output_format, 'created_at': time.time() } # Process function with timeout @with_timeout(TIMEOUT_SECONDS) def process_with_timeout(image, steps, scale, format): # Load model pipe = load_model() processing_jobs[job_id]['progress'] = 30 # Generate 3D model return pipe( image, guidance_scale=scale, num_inference_steps=steps, output_type="mesh", ).images # Start processing in a separate thread def process_image(): thread = threading.current_thread() processing_jobs[job_id]['thread_alive'] = lambda: thread.is_alive() try: # Preprocess image (resize if needed) processing_jobs[job_id]['progress'] = 5 image = preprocess_image(filepath) processing_jobs[job_id]['progress'] = 10 # Process image with timeout try: images = process_with_timeout(image, num_inference_steps, guidance_scale, output_format) processing_jobs[job_id]['progress'] = 80 except TimeoutError: processing_jobs[job_id]['status'] = 'error' processing_jobs[job_id]['error'] = f"Processing timed out after {TIMEOUT_SECONDS} seconds" return # Export based on requested format if output_format == 'obj': obj_path = os.path.join(output_dir, "model.obj") export_to_obj(images[0], obj_path) # Create a zip file with OBJ and MTL zip_path = os.path.join(output_dir, "model.zip") with zipfile.ZipFile(zip_path, 'w') as zipf: zipf.write(obj_path, arcname="model.obj") mtl_path = os.path.join(output_dir, "model.mtl") if os.path.exists(mtl_path): zipf.write(mtl_path, arcname="model.mtl") processing_jobs[job_id]['result_url'] = f"/download/{job_id}" processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}" elif output_format == 'glb': from trimesh import Trimesh mesh = images[0] vertices = mesh.verts faces = mesh.faces # Create a trimesh object trimesh_obj = Trimesh(vertices=vertices, faces=faces) # Export as GLB glb_path = os.path.join(output_dir, "model.glb") trimesh_obj.export(glb_path) processing_jobs[job_id]['result_url'] = f"/download/{job_id}" processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}" # Update job status processing_jobs[job_id]['status'] = 'completed' processing_jobs[job_id]['progress'] = 100 # Clean up temporary file if os.path.exists(filepath): os.remove(filepath) # Force garbage collection to free memory gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() except Exception as e: # Handle errors error_details = traceback.format_exc() processing_jobs[job_id]['status'] = 'error' processing_jobs[job_id]['error'] = f"{str(e)}\n{error_details}" print(f"Error processing job {job_id}: {str(e)}") print(error_details) # Clean up on error if os.path.exists(filepath): os.remove(filepath) # Start processing thread processing_thread = threading.Thread(target=process_image) processing_thread.daemon = True processing_thread.start() # Return job ID immediately return jsonify({"job_id": job_id}), 202 # 202 Accepted @app.route('/download/', methods=['GET']) def download_model(job_id): if job_id not in processing_jobs or processing_jobs[job_id]['status'] != 'completed': return jsonify({"error": "Model not found or processing not complete"}), 404 # Get the output directory for this job output_dir = os.path.join(RESULTS_FOLDER, job_id) # Determine file format from the job data output_format = processing_jobs[job_id].get('output_format', 'obj') if output_format == 'obj': zip_path = os.path.join(output_dir, "model.zip") if os.path.exists(zip_path): return send_file(zip_path, as_attachment=True, download_name="model.zip") else: # glb glb_path = os.path.join(output_dir, "model.glb") if os.path.exists(glb_path): return send_file(glb_path, as_attachment=True, download_name="model.glb") return jsonify({"error": "File not found"}), 404 @app.route('/preview/', methods=['GET']) def preview_model(job_id): if job_id not in processing_jobs or processing_jobs[job_id]['status'] != 'completed': return jsonify({"error": "Model not found or processing not complete"}), 404 # Get the output directory for this job output_dir = os.path.join(RESULTS_FOLDER, job_id) output_format = processing_jobs[job_id].get('output_format', 'obj') if output_format == 'obj': obj_path = os.path.join(output_dir, "model.obj") if os.path.exists(obj_path): return send_file(obj_path, mimetype='model/obj') else: # glb glb_path = os.path.join(output_dir, "model.glb") if os.path.exists(glb_path): return send_file(glb_path, mimetype='model/gltf-binary') return jsonify({"error": "Model file not found"}), 404 # Cleanup old jobs periodically def cleanup_old_jobs(): current_time = time.time() job_ids_to_remove = [] for job_id, job_data in processing_jobs.items(): # Remove completed jobs after 1 hour if job_data['status'] == 'completed' and (current_time - job_data.get('created_at', 0)) > 3600: job_ids_to_remove.append(job_id) # Remove error jobs after 30 minutes elif job_data['status'] == 'error' and (current_time - job_data.get('created_at', 0)) > 1800: job_ids_to_remove.append(job_id) # Remove the jobs for job_id in job_ids_to_remove: output_dir = os.path.join(RESULTS_FOLDER, job_id) try: import shutil if os.path.exists(output_dir): shutil.rmtree(output_dir) except Exception as e: print(f"Error cleaning up job {job_id}: {str(e)}") # Remove from tracking dictionary if job_id in processing_jobs: del processing_jobs[job_id] # Schedule the next cleanup threading.Timer(300, cleanup_old_jobs).start() # Run every 5 minutes @app.route('/', methods=['GET']) def index(): return jsonify({ "message": "Image to 3D API is running", "endpoints": ["/convert", "/progress/", "/download/", "/preview/"] }), 200 if __name__ == '__main__': # Start the cleanup thread cleanup_old_jobs() # Use port 7860 which is standard for Hugging Face Spaces port = int(os.environ.get('PORT', 7860)) app.run(host='0.0.0.0', port=port)