# Page residue from the hosting "Spaces" listing (status: Sleeping); not part of the program.
import gc
import io
import json
import os
import shutil
import threading
import time
import traceback
import uuid

import numpy as np
import torch
import trimesh
from flask import Flask, request, jsonify, send_file, Response, stream_with_context
from flask_cors import CORS
from huggingface_hub import snapshot_download
from PIL import Image
from werkzeug.utils import secure_filename

from tsr.system import TripoSR
from tsr.utils import remove_background, resize_foreground
# Pin all work to the CPU: hide CUDA devices from the process, make CPU the
# default device for new tensors, and stub the torch.cuda probes so any code
# that asks for a GPU is told there is none.
os.environ.update({"CUDA_VISIBLE_DEVICES": ""})
torch.set_default_device("cpu")


def _cuda_is_available():
    return False


def _cuda_device_count():
    return 0


torch.cuda.is_available = _cuda_is_available
torch.cuda.device_count = _cuda_device_count

# Flask application; flask_cors is applied with its default settings.
app = Flask(__name__)
CORS(app)
# Scratch locations for uploads, generated models and the Hugging Face cache.
UPLOAD_FOLDER = '/tmp/uploads'
RESULTS_FOLDER = '/tmp/results'
CACHE_DIR = '/tmp/huggingface'
# Upload extensions accepted by allowed_file() (compared case-insensitively).
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}

for _folder in (UPLOAD_FOLDER, RESULTS_FOLDER, CACHE_DIR):
    os.makedirs(_folder, exist_ok=True)
del _folder

# NOTE(review): huggingface_hub (and anything imported via tsr) is imported
# before these lines run and may read these variables at import time, so they
# may not change those libraries' defaults. load_model() passes
# cache_dir=CACHE_DIR explicitly, which does not depend on them.
os.environ.update({
    'HF_HOME': CACHE_DIR,
    'TRANSFORMERS_CACHE': os.path.join(CACHE_DIR, 'transformers'),
    'HF_DATASETS_CACHE': os.path.join(CACHE_DIR, 'datasets'),
})

# MAX_CONTENT_LENGTH caps request bodies at 16 MiB.
app.config.update(
    UPLOAD_FOLDER=UPLOAD_FOLDER,
    MAX_CONTENT_LENGTH=16 * 1024 * 1024,
)

# In-memory job registry: job_id -> job dict (status, progress, URLs, error, ...).
processing_jobs = {}
# Lazily loaded TripoSR model plus the flags load_model() maintains.
triposr_model = None
model_loaded = False
model_loading = False

TIMEOUT_SECONDS = 300  # time budget for the mesh-generation step
MAX_DIMENSION = 512  # TripoSR uses 512x512 inputs
class TimeoutError(Exception):
    """Signal that a processing step ran longer than its allowed time.

    NOTE(review): this name shadows the built-in ``TimeoutError`` for the
    rest of the module; it derives directly from ``Exception``.
    """
def process_with_timeout(function, args, timeout):
    """Run ``function(*args)`` on a daemon thread, waiting up to *timeout* seconds.

    Args:
        function: Callable to execute.
        args: Positional arguments, star-unpacked into the call.
        timeout: Seconds to wait, as accepted by ``threading.Thread.join``.

    Returns:
        A ``(result, error)`` pair:
        - ``(value, None)`` if the call returned ``value``;
        - ``(None, exc)`` if the call raised the ``Exception`` ``exc``;
        - ``(None, TimeoutError(...))`` if it is still running after *timeout*;
        - ``(None, RuntimeError(...))`` if the thread ended without returning
          or raising an ``Exception`` (e.g. it raised ``SystemExit``).

    Note:
        Python threads cannot be cancelled. After a timeout the worker keeps
        running in the background until *function* returns by itself.
    """
    outcome = {}

    def target():
        try:
            outcome['result'] = function(*args)
        except Exception as exc:  # handed back to the caller, not swallowed
            outcome['error'] = exc

    worker = threading.Thread(target=target, daemon=True)
    worker.start()
    worker.join(timeout)

    if 'error' in outcome:
        return None, outcome['error']
    # Check for a result before liveness: a worker that stored its result
    # just as join() timed out has still succeeded.
    if 'result' in outcome:
        return outcome['result'], None
    if worker.is_alive():
        return None, TimeoutError(f"Processing timed out after {timeout} seconds")
    # Previously this case fell through to a silent (None, None).
    return None, RuntimeError("Worker thread exited without producing a result")
def allowed_file(filename):
    """Return True when *filename* ends in an extension from ALLOWED_EXTENSIONS.

    Only the text after the last dot is considered, compared in lower case.
    Names that contain no dot are rejected.
    """
    _, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS
def preprocess_image(image_path):
    """Load an uploaded image and turn it into the model's input image.

    Steps: convert to RGB, resize to ``MAX_DIMENSION`` x ``MAX_DIMENSION``
    (LANCZOS), scale to [0, 1], apply ``remove_background`` and
    ``resize_foreground`` (ratio 0.85), clip to [0, 1], then rebuild an
    8-bit PIL image.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A ``PIL.Image.Image`` created from the processed uint8 array.

    Raises:
        Exception: for any failure. The message is prefixed with
            "Error preprocessing image: " and the original error is chained
            as ``__cause__``.
    """
    try:
        with Image.open(image_path) as img:
            # Normalise every mode (RGBA, P, L, LA, CMYK, ...) to 3-channel RGB.
            # Previously only RGBA was converted, so e.g. grayscale or palette
            # uploads produced a 2-D index/luma array.
            if img.mode != 'RGB':
                img = img.convert('RGB')
            # NOTE: resizes straight to a square, so the aspect ratio of
            # non-square uploads is not preserved.
            img = img.resize((MAX_DIMENSION, MAX_DIMENSION), Image.LANCZOS)
            img_array = np.array(img) / 255.0
            # NOTE(review): assumes tsr.utils.remove_background /
            # resize_foreground accept float arrays in [0, 1] - confirm
            # against the installed tsr version.
            img_array = remove_background(img_array)
            img_array = resize_foreground(img_array, 0.85)
            img_array = np.clip(img_array, 0, 1) * 255
            return Image.fromarray(img_array.astype(np.uint8))
    except Exception as e:
        raise Exception(f"Error preprocessing image: {str(e)}") from e
# Serialises model loading so concurrent jobs never download/load twice and a
# waiting caller never receives an unloaded (None) model.
_model_lock = threading.Lock()


def _download_model_files(model_name, max_retries=3, retry_delay=5):
    """Download *model_name* into CACHE_DIR, retrying with exponential back-off.

    Re-raises the last error once *max_retries* attempts have failed.
    """
    for attempt in range(max_retries):
        try:
            # NOTE(review): newer huggingface_hub releases deprecate
            # `resume_download` (downloads resume automatically) - consider
            # dropping it once the pinned version is confirmed.
            snapshot_download(
                repo_id=model_name,
                cache_dir=CACHE_DIR,
                resume_download=True,
            )
            return
        except Exception as e:
            if attempt == max_retries - 1:
                raise
            print(f"Download attempt {attempt+1} failed: {str(e)}. Retrying...")
            time.sleep(retry_delay)
            retry_delay *= 2


def load_model():
    """Return the process-wide TripoSR model, loading it on first use.

    The weights for "tripo3d/triposr" are fetched into CACHE_DIR (three
    attempts, back-off starting at 5 s), then loaded on the CPU. Loading is
    guarded by a lock. Callers that arrive while another thread is loading
    wait for it. If that attempt failed, they retry the load themselves
    instead of getting ``None`` back.

    Returns:
        The loaded model object.

    Raises:
        Exception: whatever the download or ``TripoSR.from_pretrained``
            raised (after printing it and its traceback).
    """
    global triposr_model, model_loaded, model_loading
    if model_loaded:  # fast path, no locking once the model is ready
        return triposr_model
    with _model_lock:
        # Another thread may have finished loading while we waited.
        if model_loaded:
            return triposr_model
        model_loading = True
        try:
            print("Loading TripoSR...")
            model_name = "tripo3d/triposr"
            _download_model_files(model_name)
            # NOTE(review): keyword arguments assume this project's TripoSR
            # wrapper accepts cache_dir/device - confirm against tsr.system.
            triposr_model = TripoSR.from_pretrained(
                model_name,
                cache_dir=CACHE_DIR,
                device="cpu",
            )
            model_loaded = True
            print("TripoSR loaded successfully on CPU")
            return triposr_model
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            print(traceback.format_exc())
            raise
        finally:
            model_loading = False
def generate_3d_model(image, detail_level):
    """Run TripoSR on a preprocessed image and return a ``trimesh.Trimesh``.

    Args:
        image: Model input as produced by ``preprocess_image``.
        detail_level: ``'low'``, ``'medium'`` or ``'high'``; selects the
            ``chunk_size`` (4096 / 8192 / 16384) passed to
            ``triposr_model.mesher``.

    Returns:
        A ``trimesh.Trimesh`` built from the first extracted mesh (with vertex
        colours when the mesh provides them), rotated 180 degrees about the
        X axis.

    Raises:
        Exception: for any failure, including an unknown *detail_level* or a
            model that has not been loaded yet. The message is prefixed with
            "Error generating 3D model: " and the original error is chained.
    """
    chunk_sizes = {'low': 4096, 'medium': 8192, 'high': 16384}
    try:
        if detail_level not in chunk_sizes:
            raise ValueError(
                f"invalid detail_level {detail_level!r} "
                f"(expected one of: {', '.join(chunk_sizes)})"
            )
        if triposr_model is None:
            raise RuntimeError("model is not loaded; call load_model() first")
        with torch.no_grad():
            scene_codes = triposr_model(image, device="cpu")
            # NOTE(review): a chunk size usually trades memory for speed
            # rather than mesh resolution - confirm it really changes detail.
            meshes = triposr_model.mesher(
                scene_codes, chunk_size=chunk_sizes[detail_level]
            )
        mesh = meshes[0]
        vertices = np.array(mesh.vertices)
        faces = np.array(mesh.faces)
        colors = getattr(mesh, 'vertex_colors', None)
        vertex_colors = np.array(colors) if colors is not None else None
        trimesh_mesh = trimesh.Trimesh(
            vertices=vertices,
            faces=faces,
            vertex_colors=vertex_colors,
        )
        # Flip the mesh upside down (180 degrees about X).
        trimesh_mesh.apply_transform(
            trimesh.transformations.rotation_matrix(np.pi, [1, 0, 0])
        )
        return trimesh_mesh
    except Exception as e:
        raise Exception(f"Error generating 3D model: {str(e)}") from e
def health_check():
    """Report that the API process is up.

    Always answers HTTP 200 with a fixed payload naming the model and device.
    It does not check whether the model has actually been loaded.

    NOTE(review): no route decorator is visible in this file; confirm this
    view is registered on ``app``.
    """
    payload = {"status": "healthy", "model": "TripoSR", "device": "cpu"}
    return jsonify(payload), 200
def progress(job_id):
    """Stream a job's progress to the client as Server-Sent Events.

    Emits a ``data: <json>`` event with the current progress. While the job's
    status is ``'processing'`` it polls every 0.5 s and emits a new event
    whenever ``progress`` changes. It ends with a ``completed`` event
    (carrying the result/preview URLs) or an ``error`` event. After roughly
    every 30 s of polling it calls the job's ``thread_alive`` callback, if
    present, and marks the job as failed when the worker is gone. An unknown
    *job_id* yields a single ``{"error": "Job not found"}`` event.

    NOTE(review): no route decorator is visible in this file; index() lists
    ``/progress/<job_id>`` - confirm this view is registered on ``app``.
    """
    def sse(payload):
        return f"data: {json.dumps(payload)}\n\n"

    def generate():
        if job_id not in processing_jobs:
            yield sse({'error': 'Job not found'})
            return
        job = processing_jobs[job_id]

        last_seen = job['progress']
        yield sse({'status': 'processing', 'progress': last_seen})

        polls = 0
        while job['status'] == 'processing':
            current = job['progress']
            if current != last_seen:
                yield sse({'status': 'processing', 'progress': current})
                last_seen = current
            time.sleep(0.5)
            polls += 1
            if polls <= 60:
                continue
            polls = 0
            # Periodically make sure the worker thread is still running.
            if 'thread_alive' in job and not job['thread_alive']():
                job['status'] = 'error'
                job['error'] = 'Processing thread died unexpectedly'
                break

        if job['status'] == 'completed':
            yield sse({
                'status': 'completed',
                'progress': 100,
                'result_url': job['result_url'],
                'preview_url': job['preview_url'],
            })
        else:
            yield sse({'status': 'error', 'error': job['error']})

    return Response(stream_with_context(generate()), mimetype='text/event-stream')
def convert_image_to_3d():
    """Accept an uploaded image and start an asynchronous image-to-3D job.

    Expects a multipart request with an ``image`` file (png/jpg/jpeg). It also
    accepts optional form fields ``output_format`` (``glb`` or ``obj``,
    default ``glb``) and ``detail_level`` (``low``/``medium``/``high``,
    default ``medium``).

    The upload is saved under ``UPLOAD_FOLDER`` and processed on a daemon
    thread. ``processing_jobs[job_id]`` tracks ``status``, ``progress``, the
    result URLs, ``error`` and ``completed_at``. The uploaded file is deleted
    when the worker finishes, whether it succeeds or fails.

    Returns:
        ``({"job_id": ...}, 202)`` once the job is started, or a JSON error
        with status 400 for a missing/invalid upload or parameter.

    NOTE(review): no route decorator is visible in this file; index() lists
    ``/convert`` - confirm this view is registered on ``app`` (POST).
    """
    if 'image' not in request.files:
        return jsonify({"error": "No image provided"}), 400
    file = request.files['image']
    # `not` also covers a missing (None) filename, which crashed allowed_file.
    if not file.filename:
        return jsonify({"error": "No image selected"}), 400
    if not allowed_file(file.filename):
        return jsonify({"error": f"File type not allowed. Supported types: {', '.join(ALLOWED_EXTENSIONS)}"}), 400

    output_format = request.form.get('output_format', 'glb').lower()
    detail_level = request.form.get('detail_level', 'medium').lower()
    if output_format not in ['glb', 'obj']:
        return jsonify({"error": "Supported formats: glb, obj"}), 400
    # Reject unknown detail levels up front. Previously they only failed inside
    # the worker, after the model had been loaded.
    if detail_level not in ('low', 'medium', 'high'):
        return jsonify({"error": "Supported detail levels: low, medium, high"}), 400

    job_id = str(uuid.uuid4())
    output_dir = os.path.join(RESULTS_FOLDER, job_id)
    os.makedirs(output_dir, exist_ok=True)
    filename = secure_filename(file.filename)
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], f"{job_id}_{filename}")
    file.save(filepath)

    job = {
        'status': 'processing',
        'progress': 0,
        'result_url': None,
        'preview_url': None,
        'error': None,
        'output_format': output_format,
        'created_at': time.time(),
    }
    processing_jobs[job_id] = job

    def fail(message):
        # Store the message before flipping the status so a reader that sees
        # status == 'error' never finds error still None.
        job['error'] = message
        job['status'] = 'error'

    def process_image():
        worker = threading.current_thread()
        # Lets progress() detect a worker that died without reporting.
        job['thread_alive'] = worker.is_alive
        try:
            job['progress'] = 5
            try:
                image = preprocess_image(filepath)
            except Exception as e:
                print(f"Error processing job {job_id}: {str(e)}")
                print(traceback.format_exc())
                # The traceback is logged server-side only; it is no longer
                # sent to API clients through job['error'].
                fail(str(e))
                return
            job['progress'] = 10

            try:
                load_model()
            except Exception as e:
                fail(f"Error loading model: {str(e)}")
                return
            job['progress'] = 30

            mesh, error = process_with_timeout(
                generate_3d_model, [image, detail_level], TIMEOUT_SECONDS
            )
            if isinstance(error, TimeoutError):
                fail(f"Processing timed out after {TIMEOUT_SECONDS} seconds")
                return
            if error is not None:
                raise error
            job['progress'] = 80

            file_path = os.path.join(output_dir, f"model.{output_format}")
            mesh.export(file_path, file_type=output_format)
            # Was f"/download/{job_id.ConcurrentHashMap}", which raised
            # AttributeError and turned every finished job into an error.
            job['result_url'] = f"/download/{job_id}"
            job['preview_url'] = f"/preview/{job_id}"
            job['completed_at'] = time.time()  # read by model_info()
            job['status'] = 'completed'
            job['progress'] = 100
            print(f"Job {job_id} completed")
        except Exception as e:
            print(f"Error processing job {job_id}: {str(e)}")
            print(traceback.format_exc())
            fail(f"Error during processing: {str(e)}")
        finally:
            # Delete the upload on every path; the early returns used to leak it.
            if os.path.exists(filepath):
                os.remove(filepath)
            gc.collect()

    threading.Thread(target=process_image, daemon=True).start()
    return jsonify({"job_id": job_id}), 202
def download_model(job_id):
    """Send a finished job's model file as an attachment named ``model.<fmt>``.

    Responds 404 when the job is unknown, not yet completed, or its file is
    missing from ``RESULTS_FOLDER/<job_id>``.

    NOTE(review): no route decorator is visible in this file; index() lists
    ``/download/<job_id>`` - confirm this view is registered on ``app``.
    """
    job = processing_jobs.get(job_id)
    if job is None or job['status'] != 'completed':
        return jsonify({"error": "Model not found or processing not complete"}), 404
    extension = job['output_format']
    target = os.path.join(RESULTS_FOLDER, job_id, f"model.{extension}")
    if not os.path.exists(target):
        return jsonify({"error": "File not found"}), 404
    return send_file(target, as_attachment=True, download_name=f"model.{extension}")
def preview_model(job_id):
    """Serve a finished job's model inline for previewing.

    GLB files are sent as ``model/gltf-binary`` and anything else (OBJ) as
    ``text/plain``. Responds 404 when the job is unknown, not completed, or
    the file is missing.

    NOTE(review): no route decorator is visible in this file; index() lists
    ``/preview/<job_id>`` - confirm this view is registered on ``app``.
    """
    job = processing_jobs.get(job_id)
    if job is None or job['status'] != 'completed':
        return jsonify({"error": "Model not found or processing not complete"}), 404
    extension = job['output_format']
    target = os.path.join(RESULTS_FOLDER, job_id, f"model.{extension}")
    if not os.path.exists(target):
        return jsonify({"error": "File not found"}), 404
    mimetype = 'model/gltf-binary' if extension == 'glb' else 'text/plain'
    return send_file(target, mimetype=mimetype)
def cleanup_old_jobs():
    """Drop expired jobs and their result folders, then reschedule itself.

    Completed jobs older than 3600 s and failed jobs older than 1800 s
    (measured from ``created_at``) are removed from ``processing_jobs``, and
    their ``RESULTS_FOLDER/<job_id>`` directory is deleted. Jobs in any other
    status are left alone. A new pass is always scheduled 300 s later on a
    daemon ``threading.Timer``, even if this pass raised.
    """
    max_age = {'completed': 3600, 'error': 1800}
    try:
        now = time.time()
        # Snapshot the items first. Request threads insert jobs concurrently,
        # and iterating the live dict could raise "dictionary changed size
        # during iteration", which used to stop the cleanup loop for good.
        expired = [
            job_id
            for job_id, job_data in list(processing_jobs.items())
            if job_data['status'] in max_age
            and now - job_data.get('created_at', 0) > max_age[job_data['status']]
        ]
        for job_id in expired:
            output_dir = os.path.join(RESULTS_FOLDER, job_id)
            try:
                if os.path.exists(output_dir):
                    shutil.rmtree(output_dir)
            except Exception as e:
                print(f"Error cleaning up job {job_id}: {str(e)}")
            processing_jobs.pop(job_id, None)
    finally:
        # Daemon timer so this periodic task cannot keep the interpreter alive
        # at shutdown.
        timer = threading.Timer(300, cleanup_old_jobs)
        timer.daemon = True
        timer.start()
def model_info(job_id):
    """Describe a job as JSON.

    Unknown jobs get 404. Unfinished jobs (any status other than
    ``'completed'``) get their status, progress and error. Completed jobs get
    the output format, download/preview URLs, the model file size (only if
    the file exists), ``created_at`` and ``completed_at``. Both timestamps are
    read with ``.get()`` and are ``None`` when the job dict lacks them.

    NOTE(review): no route decorator is visible in this file; index() lists
    ``/model-info/<job_id>`` - confirm this view is registered on ``app``.
    """
    job = processing_jobs.get(job_id)
    if job is None:
        return jsonify({"error": "Model not found"}), 404

    if job['status'] != 'completed':
        pending = {
            "status": job['status'],
            "progress": job['progress'],
            "error": job.get('error'),
        }
        return jsonify(pending), 200

    extension = job['output_format']
    target = os.path.join(RESULTS_FOLDER, job_id, f"model.{extension}")
    stats = {'model_size': os.path.getsize(target)} if os.path.exists(target) else {}
    summary = {
        "status": job['status'],
        "model_format": extension,
        "download_url": job['result_url'],
        "preview_url": job['preview_url'],
        "model_stats": stats,
        "created_at": job.get('created_at'),
        "completed_at": job.get('completed_at'),
    }
    return jsonify(summary), 200
def index():
    """Describe the API: its endpoints, accepted parameters and purpose.

    NOTE(review): no route decorator is visible in this file; confirm this
    view is registered on ``app``.
    """
    endpoints = [
        "/convert",
        "/progress/<job_id>",
        "/download/<job_id>",
        "/preview/<job_id>",
        "/model-info/<job_id>",
    ]
    parameters = {
        "output_format": "glb or obj",
        "detail_level": "low, medium, or high",
    }
    overview = {
        "message": "Image to 3D API (TripoSR)",
        "endpoints": endpoints,
        "parameters": parameters,
        "description": "Creates 3D models from 2D images using TripoSR.",
    }
    return jsonify(overview), 200
if __name__ == '__main__':
    # Start the self-rescheduling job cleanup, then serve on $PORT (default
    # 7860). NOTE(review): under an external WSGI server this block does not
    # run, so cleanup would never be started - confirm how the app is launched.
    cleanup_old_jobs()
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 7860)))