import gc
import io
import json
import os
import shutil
import threading
import time
import traceback
import uuid
import zipfile

import torch
from flask import Flask, request, jsonify, send_file, Response, stream_with_context
from werkzeug.utils import secure_filename
from PIL import Image
from huggingface_hub import snapshot_download
from flask_cors import CORS
import numpy as np
import trimesh
import cv2
from transformers import AutoModel, AutoProcessor  # For TripoSR
import torchvision.transforms as T

from u2net import U2NET  # For background removal; install from https://github.com/xuebinqin/U-2-Net

app = Flask(__name__)
CORS(app)

# Configure directories
UPLOAD_FOLDER = '/tmp/uploads'
RESULTS_FOLDER = '/tmp/results'
CACHE_DIR = '/tmp/huggingface'
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}

# Create directories
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(RESULTS_FOLDER, exist_ok=True)
os.makedirs(CACHE_DIR, exist_ok=True)

# Set Hugging Face cache.
# NOTE(review): these are set after huggingface_hub/transformers are imported; if
# those libraries read them at import time this only affects child processes. The
# download/load calls below pass cache_dir explicitly regardless.
os.environ['HF_HOME'] = CACHE_DIR
os.environ['TRANSFORMERS_CACHE'] = os.path.join(CACHE_DIR, 'transformers')
os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_DIR, 'datasets')

app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB max

# Job tracking: job_id -> state dict (status, progress, result/preview URLs, error,
# output_format, timestamps). Written by request handlers and worker threads, read
# by the SSE progress stream and the periodic cleanup timer.
processing_jobs = {}

# Global model variables (lazily initialised on first use).
u2net_model = None
triposr_model = None
triposr_processor = None
model_loaded = False
model_loading = False

# Serialise lazy model initialisation so concurrent jobs neither load the same
# weights twice nor observe a half-initialised model.
_u2net_lock = threading.Lock()
_model_lock = threading.Lock()

# Configuration
TIMEOUT_SECONDS = 240  # 4 minutes max
MAX_DIMENSION = 512  # Max image dimension

# Target face counts used by optimize_mesh(); unknown detail levels get the 'low' budget.
MESH_FACE_BUDGETS = {'high': 50000, 'medium': 30000, 'low': 15000}

# Job retention used by cleanup_old_jobs(), in seconds.
COMPLETED_JOB_TTL_SECONDS = 3600
FAILED_JOB_TTL_SECONDS = 1800
CLEANUP_INTERVAL_SECONDS = 300


class TimeoutError(Exception):
    """Signal that a call made through process_with_timeout() ran past its budget.

    Keeps its original name for compatibility, which means it shadows the builtin
    ``TimeoutError`` everywhere inside this module.
    """


def process_with_timeout(function, args, timeout):
    """Run ``function(*args)`` in a daemon thread and wait at most ``timeout`` seconds.

    Returns:
        ``(result, None)`` on success, ``(None, exc)`` if the call raised an
        ``Exception``, or ``(None, TimeoutError(...))`` if it was still running when
        the timeout expired.

    Python threads cannot be cancelled, so on timeout the worker keeps running in
    the background until ``function`` returns on its own.
    """
    result = [None]
    error = [None]
    completed = [False]

    def target():
        try:
            result[0] = function(*args)
            completed[0] = True
        except Exception as e:
            error[0] = e

    thread = threading.Thread(target=target, daemon=True)
    thread.start()
    thread.join(timeout)

    if not completed[0] and thread.is_alive():
        return None, TimeoutError(f"Processing timed out after {timeout} seconds")
    if error[0] is not None:
        return None, error[0]
    return result[0], None


def allowed_file(filename):
    """Return True if ``filename`` has an extension listed in ALLOWED_EXTENSIONS."""
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS


def preprocess_image(image_path):
    """Load an image as RGB, shrink it to fit MAX_DIMENSION, and apply CLAHE.

    The longer side is scaled down to exactly MAX_DIMENSION, keeping the aspect
    ratio. Contrast-limited adaptive histogram equalisation is applied to the L
    channel in LAB space.

    Returns:
        A new ``PIL.Image.Image`` in RGB mode.
    """
    with Image.open(image_path) as img:
        # convert() returns an independent copy, so it stays valid after the file closes.
        img = img.convert("RGB")

    if img.width > MAX_DIMENSION or img.height > MAX_DIMENSION:
        if img.width > img.height:
            new_width = MAX_DIMENSION
            new_height = int(img.height * (MAX_DIMENSION / img.width))
        else:
            new_height = MAX_DIMENSION
            new_width = int(img.width * (MAX_DIMENSION / img.height))
        # Very elongated images could otherwise round a side down to 0 pixels.
        img = img.resize((max(1, new_width), max(1, new_height)), Image.LANCZOS)

    # Apply adaptive histogram equalization
    img_array = np.array(img)
    if len(img_array.shape) == 3 and img_array.shape[2] == 3:
        lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
        l_channel, a_channel, b_channel = cv2.split(lab)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        enhanced_lab = cv2.merge((clahe.apply(l_channel), a_channel, b_channel))
        img_array = cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2RGB)
        img = Image.fromarray(img_array)

    return img


def _get_u2net_model():
    """Return the shared U2NET model, loading ``u2net.pth`` on CPU on first use."""
    global u2net_model
    with _u2net_lock:
        if u2net_model is None:
            model = U2NET()
            model.load_state_dict(torch.load('u2net.pth', map_location='cpu'))
            model.eval()
            model.to('cpu')
            # Publish only once fully initialised.
            u2net_model = model
    return u2net_model


def remove_background(image):
    """Replace the background of ``image`` with white using a U2NET saliency mask.

    The image is resized to 320x320 for the model. The first U2NET output is
    min-max normalised and thresholded at 0.5, and the mask is resized back to the
    input size. Pixels outside the mask are blended towards white (255).

    Returns:
        A new RGB ``PIL.Image.Image`` the same size as ``image``.
    """
    model = _get_u2net_model()

    img_array = np.array(image)
    # NOTE(review): the U-2-Net reference pipeline normalises inputs before
    # inference; this feeds raw [0, 1] tensors — confirm mask quality.
    img_tensor = T.ToTensor()(image.resize((320, 320))).unsqueeze(0)
    with torch.no_grad():
        d1, *_ = model(img_tensor)

    pred = d1[:, 0, :, :]
    value_range = pred.max() - pred.min()
    if value_range > 0:
        pred = (pred - pred.min()) / value_range
    else:
        # Constant saliency map: no foreground found (same empty mask the NaN path gave).
        pred = torch.zeros_like(pred)

    mask = (pred > 0.5).float().squeeze().numpy()
    mask_img = Image.fromarray((mask * 255).astype('uint8')).resize(image.size)
    mask_array = np.array(mask_img)[:, :, np.newaxis] / 255

    result = img_array * mask_array + (1 - mask_array) * 255  # White background
    return Image.fromarray(result.astype('uint8'))


def _download_model_snapshot(model_name, max_retries=3, retry_delay=5):
    """Download ``model_name`` into CACHE_DIR, retrying with exponential backoff.

    Re-raises the last download error once ``max_retries`` attempts have failed.
    """
    for attempt in range(max_retries):
        try:
            snapshot_download(
                repo_id=model_name,
                cache_dir=CACHE_DIR,
                resume_download=True,
            )
            return
        except Exception as e:
            if attempt >= max_retries - 1:
                raise
            print(f"Download attempt {attempt+1} failed: {str(e)}. Retrying...")
            time.sleep(retry_delay)
            retry_delay *= 2


def load_model():
    """Return ``(model, processor)`` for TripoSR, loading them once on first call.

    Thread-safe: concurrent callers block on a lock while one thread loads. If
    loading fails, the exception propagates to the caller that attempted it, and
    the next caller retries instead of receiving ``(None, None)``.

    Raises:
        Exception: any download or model-construction error.
    """
    global triposr_model, triposr_processor, model_loaded, model_loading

    # Fast path once loaded; no lock needed for a simple flag read.
    if model_loaded:
        return triposr_model, triposr_processor

    with _model_lock:
        # Another thread may have finished loading while we waited for the lock.
        if model_loaded:
            return triposr_model, triposr_processor

        model_loading = True
        try:
            print("Loading TripoSR model...")
            model_name = "stabilityai/TripoSR"
            _download_model_snapshot(model_name)

            # NOTE(review): confirm stabilityai/TripoSR is loadable through
            # transformers' AutoModel/AutoProcessor and exposes `.mesh` on its output.
            processor = AutoProcessor.from_pretrained(model_name, cache_dir=CACHE_DIR)
            model = AutoModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
            model.to('cpu')

            triposr_processor = processor
            triposr_model = model
            model_loaded = True
            print("TripoSR model loaded successfully on CPU")
            return triposr_model, triposr_processor
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            print(traceback.format_exc())
            raise
        finally:
            model_loading = False


def optimize_mesh(mesh, detail_level='medium'):
    """Decimate ``mesh`` to the face budget for ``detail_level`` and fix its normals.

    Budgets come from MESH_FACE_BUDGETS; unknown levels use the 'low' budget.
    Meshes already under budget are not simplified.
    """
    target_faces = MESH_FACE_BUDGETS.get(detail_level, MESH_FACE_BUDGETS['low'])

    if len(mesh.faces) > target_faces:
        # Pass by keyword: newer trimesh takes `percent` as the first positional arg.
        mesh = mesh.simplify_quadric_decimation(face_count=target_faces)

    mesh.fix_normals()
    return mesh


def _export_mesh(mesh, output_dir, output_format):
    """Write ``mesh`` into ``output_dir`` as 'obj' (zipped with MTL/texture) or 'glb'."""
    if output_format == 'obj':
        obj_path = os.path.join(output_dir, "model.obj")
        mesh.export(
            obj_path,
            file_type='obj',
            include_normals=True,
            include_texture=True,
        )
        zip_path = os.path.join(output_dir, "model.zip")
        with zipfile.ZipFile(zip_path, 'w') as zipf:
            zipf.write(obj_path, arcname="model.obj")
            # Companion files are only bundled when the exporter produced them.
            for companion in ("model.mtl", "model.png"):
                companion_path = os.path.join(output_dir, companion)
                if os.path.exists(companion_path):
                    zipf.write(companion_path, arcname=companion)
    elif output_format == 'glb':
        glb_path = os.path.join(output_dir, "model.glb")
        mesh.export(glb_path, file_type='glb')


@app.route('/health', methods=['GET'])
def health_check():
    """Liveness endpoint; does not load or check the models."""
    return jsonify({
        "status": "healthy",
        "model": "TripoSR 3D Model Generator",
        "device": "cpu"
    }), 200


@app.route('/progress/<job_id>', methods=['GET'])
def progress(job_id):
    """Stream job progress as Server-Sent Events until the job completes or fails.

    Every 60 polls (about 30 s) it checks whether the worker thread is still
    alive. If the thread died without setting a final status, the job is marked
    as an error.
    """
    def generate():
        job = processing_jobs.get(job_id)
        if job is None:
            yield f"data: {json.dumps({'error': 'Job not found'})}\n\n"
            return

        yield f"data: {json.dumps({'status': 'processing', 'progress': job['progress']})}\n\n"

        last_progress = job['progress']
        check_count = 0
        while job['status'] == 'processing':
            if job['progress'] != last_progress:
                yield f"data: {json.dumps({'status': 'processing', 'progress': job['progress']})}\n\n"
                last_progress = job['progress']
            time.sleep(0.5)
            check_count += 1
            if check_count > 60:
                if 'thread_alive' in job and not job['thread_alive']():
                    job['status'] = 'error'
                    job['error'] = 'Processing thread died unexpectedly'
                    break
                check_count = 0

        if job['status'] == 'completed':
            yield f"data: {json.dumps({'status': 'completed', 'progress': 100, 'result_url': job['result_url'], 'preview_url': job['preview_url']})}\n\n"
        else:
            yield f"data: {json.dumps({'status': 'error', 'error': job['error']})}\n\n"

    return Response(stream_with_context(generate()), mimetype='text/event-stream')


@app.route('/convert', methods=['POST'])
def convert_image_to_3d():
    """Accept an uploaded image and start background 3D conversion.

    Form fields:
        image: the image file (png/jpg/jpeg).
        output_format: 'obj' or 'glb' (default 'glb').
        detail_level: 'low', 'medium' or 'high' (default 'medium'); unknown values
            use the 'low' face budget.
        texture_quality: accepted for API compatibility but currently unused.

    Returns:
        202 with ``{"job_id": ...}``; poll /progress/<job_id> or /model-info/<job_id>.
    """
    if 'image' not in request.files:
        return jsonify({"error": "No image provided"}), 400

    file = request.files['image']
    if file.filename == '':
        return jsonify({"error": "No image selected"}), 400

    if not allowed_file(file.filename):
        return jsonify({"error": f"File type not allowed: {', '.join(sorted(ALLOWED_EXTENSIONS))}"}), 400

    output_format = request.form.get('output_format', 'glb').lower()
    detail_level = request.form.get('detail_level', 'medium').lower()

    if output_format not in ['obj', 'glb']:
        return jsonify({"error": "Unsupported output format: 'obj' or 'glb'"}), 400

    job_id = str(uuid.uuid4())
    output_dir = os.path.join(RESULTS_FOLDER, job_id)
    os.makedirs(output_dir, exist_ok=True)

    filename = secure_filename(file.filename)
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], f"{job_id}_{filename}")
    file.save(filepath)

    processing_jobs[job_id] = {
        'status': 'processing',
        'progress': 0,
        'result_url': None,
        'preview_url': None,
        'error': None,
        'output_format': output_format,
        'created_at': time.time()
    }

    def process_image():
        """Run the full pipeline for this job; always removes the upload afterwards."""
        job = processing_jobs[job_id]
        job['thread_alive'] = threading.current_thread().is_alive

        def fail(message):
            job['status'] = 'error'
            job['error'] = message

        try:
            # Preprocess image
            job['progress'] = 5
            image = preprocess_image(filepath)
            job['progress'] = 10

            # Remove background
            job['progress'] = 20
            clean_image = remove_background(image)
            job['progress'] = 30

            # Load TripoSR model
            try:
                model, processor = load_model()
            except Exception as e:
                fail(f"Error loading model: {str(e)}")
                return
            job['progress'] = 40

            # Generate and optimise the 3D mesh
            try:
                def generate_3d():
                    inputs = processor(images=clean_image, return_tensors="pt").to('cpu')
                    with torch.no_grad():
                        outputs = model(**inputs)
                    return outputs.mesh  # Expected to be a trimesh object (see NOTE in load_model)

                mesh, error = process_with_timeout(generate_3d, [], TIMEOUT_SECONDS)
                if isinstance(error, TimeoutError):
                    fail(f"Processing timed out after {TIMEOUT_SECONDS} seconds")
                    return
                if error is not None:
                    raise error
                if mesh is None:
                    raise ValueError("Model returned no mesh")

                job['progress'] = 70
                mesh = optimize_mesh(mesh, detail_level)
                job['progress'] = 80
            except Exception as e:
                fail(f"Error during processing: {str(e)}")
                print(f"Error processing job {job_id}: {str(e)}")
                print(traceback.format_exc())
                return

            # Export model
            try:
                _export_mesh(mesh, output_dir, output_format)
                # URLs and timestamp are set before status so SSE readers see them.
                job['result_url'] = f"/download/{job_id}"
                job['preview_url'] = f"/preview/{job_id}"
                job['completed_at'] = time.time()
                job['status'] = 'completed'
                job['progress'] = 100
                print(f"Job {job_id} completed successfully")
            except Exception as e:
                fail(f"Error exporting model: {str(e)}")
                print(f"Error exporting model for job {job_id}: {str(e)}")
                print(traceback.format_exc())
        except Exception as e:
            # Log the traceback server-side; clients only get the message.
            fail(f"Error preprocessing image: {str(e)}")
            print(f"Error processing job {job_id}: {str(e)}")
            print(traceback.format_exc())
        finally:
            # Runs on every exit path, including the early returns above.
            if os.path.exists(filepath):
                os.remove(filepath)
            gc.collect()

    processing_thread = threading.Thread(target=process_image, daemon=True)
    processing_thread.start()

    return jsonify({"job_id": job_id}), 202


@app.route('/download/<job_id>', methods=['GET'])
def download_model(job_id):
    """Send the finished model as an attachment (model.zip for obj, model.glb for glb)."""
    job = processing_jobs.get(job_id)
    if job is None or job['status'] != 'completed':
        return jsonify({"error": "Model not found or processing not complete"}), 404

    output_dir = os.path.join(RESULTS_FOLDER, job_id)
    if job.get('output_format', 'glb') == 'obj':
        zip_path = os.path.join(output_dir, "model.zip")
        if os.path.exists(zip_path):
            return send_file(zip_path, as_attachment=True, download_name="model.zip")
    else:
        glb_path = os.path.join(output_dir, "model.glb")
        if os.path.exists(glb_path):
            return send_file(glb_path, as_attachment=True, download_name="model.glb")

    return jsonify({"error": "File not found"}), 404


@app.route('/preview/<job_id>', methods=['GET'])
def preview_model(job_id):
    """Serve the raw model file inline (model.obj or model.glb) for in-browser preview."""
    job = processing_jobs.get(job_id)
    if job is None or job['status'] != 'completed':
        return jsonify({"error": "Model not found or processing not complete"}), 404

    output_dir = os.path.join(RESULTS_FOLDER, job_id)
    if job.get('output_format', 'glb') == 'obj':
        obj_path = os.path.join(output_dir, "model.obj")
        if os.path.exists(obj_path):
            return send_file(obj_path, mimetype='model/obj')
    else:
        glb_path = os.path.join(output_dir, "model.glb")
        if os.path.exists(glb_path):
            return send_file(glb_path, mimetype='model/gltf-binary')

    return jsonify({"error": "Model file not found"}), 404


def cleanup_old_jobs():
    """Delete expired jobs and their output directories, then reschedule itself.

    Completed jobs expire after COMPLETED_JOB_TTL_SECONDS and failed jobs after
    FAILED_JOB_TTL_SECONDS, both measured from ``created_at``. Jobs still
    processing are never removed. Rescheduling happens in ``finally``, so one
    failed pass does not stop future cleanups.
    """
    try:
        current_time = time.time()
        ttl_by_status = {
            'completed': COMPLETED_JOB_TTL_SECONDS,
            'error': FAILED_JOB_TTL_SECONDS,
        }
        # Snapshot the items: request threads may add jobs while we iterate.
        job_ids_to_remove = [
            job_id
            for job_id, job_data in list(processing_jobs.items())
            if job_data['status'] in ttl_by_status
            and (current_time - job_data.get('created_at', 0)) > ttl_by_status[job_data['status']]
        ]

        for job_id in job_ids_to_remove:
            output_dir = os.path.join(RESULTS_FOLDER, job_id)
            try:
                if os.path.exists(output_dir):
                    shutil.rmtree(output_dir)
            except Exception as e:
                print(f"Error cleaning up job {job_id}: {str(e)}")
            processing_jobs.pop(job_id, None)
    finally:
        # Daemon timer so a pending cleanup does not keep the process alive at exit.
        timer = threading.Timer(CLEANUP_INTERVAL_SECONDS, cleanup_old_jobs)
        timer.daemon = True
        timer.start()


@app.route('/model-info/<job_id>', methods=['GET'])
def model_info(job_id):
    """Return job status. Once completed, also return URLs, file sizes and timestamps."""
    job = processing_jobs.get(job_id)
    if job is None:
        return jsonify({"error": "Model not found"}), 404

    if job['status'] != 'completed':
        return jsonify({
            "status": job['status'],
            "progress": job['progress'],
            "error": job.get('error')
        }), 200

    output_dir = os.path.join(RESULTS_FOLDER, job_id)
    model_stats = {}

    if job['output_format'] == 'obj':
        obj_path = os.path.join(output_dir, "model.obj")
        zip_path = os.path.join(output_dir, "model.zip")
        if os.path.exists(obj_path):
            model_stats['obj_size'] = os.path.getsize(obj_path)
        if os.path.exists(zip_path):
            model_stats['package_size'] = os.path.getsize(zip_path)
    else:
        glb_path = os.path.join(output_dir, "model.glb")
        if os.path.exists(glb_path):
            model_stats['model_size'] = os.path.getsize(glb_path)

    return jsonify({
        "status": job['status'],
        "model_format": job['output_format'],
        "download_url": job['result_url'],
        "preview_url": job['preview_url'],
        "model_stats": model_stats,
        "created_at": job.get('created_at'),
        "completed_at": job.get('completed_at')
    }), 200


@app.route('/', methods=['GET'])
def index():
    """Describe the API: endpoints, accepted parameters and purpose."""
    return jsonify({
        "message": "TripoSR Image to 3D API",
        "endpoints": [
            "/convert",
            "/progress/<job_id>",
            "/download/<job_id>",
            "/preview/<job_id>",
            "/model-info/<job_id>"
        ],
        "parameters": {
            "output_format": "obj or glb",
            "detail_level": "low, medium, or high - controls mesh density",
            "texture_quality": "low, medium, or high - controls texture quality"
        },
        "description": "Creates full 3D models from 2D images with background removal"
    }), 200


if __name__ == '__main__':
    # Start the self-rescheduling cleanup timer (only when run as a script).
    cleanup_old_jobs()
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port)