import os
import torch
import time
import threading
import json
import gc
from flask import Flask, request, jsonify, send_file, Response, stream_with_context
from werkzeug.utils import secure_filename
from PIL import Image
import shutil
import uuid
import traceback
from huggingface_hub import snapshot_download
from flask_cors import CORS
import numpy as np
import trimesh
from transformers import pipeline
from scipy.ndimage import gaussian_filter
import open3d as o3d
from rembg import remove
import cv2
# Force CPU usage
os.environ["CUDA_VISIBLE_DEVICES"] = ""
torch.set_default_device("cpu")
# Patch PyTorch to disable CUDA initialization
torch.cuda.is_available = lambda: False
torch.cuda.device_count = lambda: 0
app = Flask(__name__)
CORS(app)
# Configure directories
UPLOAD_FOLDER = '/tmp/uploads'
RESULTS_FOLDER = '/tmp/results'
CACHE_DIR = '/tmp/huggingface'
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
# Create necessary directories
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(RESULTS_FOLDER, exist_ok=True)
os.makedirs(CACHE_DIR, exist_ok=True)
# Set Hugging Face cache environment variables
os.environ['HF_HOME'] = CACHE_DIR
os.environ['TRANSFORMERS_CACHE'] = os.path.join(CACHE_DIR, 'transformers')
os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_DIR, 'datasets')
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
# Job tracking dictionary
processing_jobs = {}
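# Each job record holds: status ('processing' | 'completed' | 'error'), progress
# (0-100), result_url, preview_url, error, output_format, and created_at; records
# are created by /convert below and pruned by cleanup_old_jobs().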
# Global model variables
depth_pipeline = None
model_loaded = False
model_loading = False
# Configuration for processing
TIMEOUT_SECONDS = 240 # 4 minutes max for Depth-Anything on CPU
MAX_DIMENSION = 512  # Working resolution; inputs are resized to 512x512 in preprocess_image
# TimeoutError for handling timeouts
class TimeoutError(Exception):
    pass
# Thread-safe timeout implementation
def process_with_timeout(function, args, timeout):
    result = [None]
    error = [None]
    completed = [False]

    def target():
        try:
            result[0] = function(*args)
            completed[0] = True
        except Exception as e:
            error[0] = e

    thread = threading.Thread(target=target)
    thread.daemon = True
    thread.start()
    thread.join(timeout)
    if not completed[0]:
        if thread.is_alive():
            return None, TimeoutError(f"Processing timed out after {timeout} seconds")
        if error[0]:
            return None, error[0]
    return result[0], None
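# Usage sketch (build_mesh is a hypothetical callable):
#   mesh, err = process_with_timeout(build_mesh, [img], 60)
#   if isinstance(err, TimeoutError): ...  # gave up after 60 seconds
# Note the daemon thread cannot be killed; on timeout it is simply abandoned.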
def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
# Image preprocessing: Remove background and resize
def preprocess_image(image_path):
    try:
        # Load image
        with Image.open(image_path) as img:
            # Remove background using rembg
            img_no_bg = remove(img)
            # Convert to RGB if it has an alpha channel
            if img_no_bg.mode == 'RGBA':
                img_no_bg = img_no_bg.convert('RGB')
            # Resize to the working resolution
            img_no_bg = img_no_bg.resize((MAX_DIMENSION, MAX_DIMENSION), Image.LANCZOS)
            # Optional: use cv2 to refine the mask (zero out near-black background pixels)
            img_array = np.array(img_no_bg)
            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            _, mask = cv2.threshold(gray, 10, 255, cv2.THRESH_BINARY)
            img_array = cv2.bitwise_and(img_array, img_array, mask=mask)
            return Image.fromarray(img_array)
    except Exception as e:
        raise Exception(f"Error preprocessing image: {str(e)}")
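# Standalone usage sketch (hypothetical path):
#   clean = preprocess_image('/tmp/uploads/example.jpg')
#   clean.save('/tmp/uploads/example_clean.png')  # 512x512 RGB, background zeroed to black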
def load_model():
    global depth_pipeline, model_loaded, model_loading
    if model_loaded:
        return depth_pipeline
    if model_loading:
        # Another thread is already loading; wait for it to finish
        while model_loading and not model_loaded:
            time.sleep(0.5)
        return depth_pipeline
    try:
        model_loading = True
        print("Starting model loading...")
        model_name = "LiheYoung/depth-anything-small-hf"
        # Download model with exponential-backoff retries
        max_retries = 3
        retry_delay = 5
        for attempt in range(max_retries):
            try:
                snapshot_download(
                    repo_id=model_name,
                    cache_dir=CACHE_DIR,
                    resume_download=True,
                )
                break
            except Exception as e:
                if attempt < max_retries - 1:
                    print(f"Download attempt {attempt+1} failed: {str(e)}. Retrying in {retry_delay} seconds...")
                    time.sleep(retry_delay)
                    retry_delay *= 2
                else:
                    raise
        # Load Depth-Anything pipeline (cache_dir must go through model_kwargs)
        depth_pipeline = pipeline(
            "depth-estimation",
            model=model_name,
            model_kwargs={"cache_dir": CACHE_DIR},
            device=-1,  # Force CPU
            torch_dtype=torch.float32,
        )
        model_loaded = True
        print("Model loaded successfully on CPU")
        return depth_pipeline
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        print(traceback.format_exc())
        raise
    finally:
        model_loading = False
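# Lazy, flag-guarded loading: the first caller pays the download/initialization cost
# while concurrent callers spin on `model_loading` until the pipeline is ready.
# (The flags are plain booleans rather than a lock; assumed adequate for this app.)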
def depth_to_point_cloud(depth_map, image, detail_level):
    # Parameters based on detail level
    downsample_factors = {'low': 4, 'medium': 2, 'high': 1}
    downsample = downsample_factors[detail_level]
    # Convert image and depth to numpy
    img_array = np.array(image)
    depth_array = np.array(depth_map)
    # Downsample for performance
    if downsample > 1:
        depth_array = depth_array[::downsample, ::downsample]
        img_array = img_array[::downsample, ::downsample]
    # Smooth, then normalize depth to [0, 1]
    depth_array = gaussian_filter(depth_array, sigma=1)
    depth_array = (depth_array - depth_array.min()) / (depth_array.max() - depth_array.min() + 1e-8)
    # Create point cloud
    h, w = depth_array.shape
    x, y = np.meshgrid(np.arange(w), np.arange(h))
    # Camera intrinsics (assumed focal length)
    fx = fy = w * 0.5
    cx, cy = w / 2, h / 2
    # Back-project pixels to 3D camera coordinates (pinhole model)
    z = depth_array
    x = (x - cx) * z / fx
    y = -(y - cy) * z / fy  # Flip y-axis to correct orientation
    points = np.stack([x, y, z], axis=-1).reshape(-1, 3)
    colors = img_array.reshape(-1, 3) / 255.0
    # Filter out invalid points (tighter depth range keeps the foreground)
    mask = (z.reshape(-1) > 0.2) & (z.reshape(-1) < 0.8)
    points = points[mask]
    colors = colors[mask]
    # Create Open3D point cloud
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    pcd.colors = o3d.utility.Vector3dVector(colors)
    # Estimate normals (required by Poisson reconstruction)
    pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.1, max_nn=30))
    # Poisson surface reconstruction
    mesh, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
        pcd, depth=8 if detail_level == 'high' else 6
    )
    # Convert to trimesh
    vertices = np.asarray(mesh.vertices)
    faces = np.asarray(mesh.triangles)
    vertex_colors = np.asarray(mesh.vertex_colors)
    trimesh_mesh = trimesh.Trimesh(
        vertices=vertices,
        faces=faces,
        vertex_colors=vertex_colors
    )
    # Rotate mesh to correct orientation (180 degrees around X-axis)
    trimesh_mesh.apply_transform(trimesh.transformations.rotation_matrix(np.pi, [1, 0, 0]))
    return trimesh_mesh
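# Back-projection used above (pinhole model): a pixel (u, v) with normalized depth z
# maps to camera space as
#   X = (u - cx) * z / fx,   Y = -(v - cy) * z / fy,   Z = z
# With the assumed intrinsics fx = fy = w/2 and the principal point at the image
# center, an edge pixel at maximum depth lands at X (or Y) of roughly +/-1.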
@app.route('/health', methods=['GET'])
def health_check():
    return jsonify({
        "status": "healthy",
        "model": "Depth-Anything",
        "device": "cpu"
    }), 200
@app.route('/progress/<job_id>', methods=['GET'])
def progress(job_id):
    def generate():
        if job_id not in processing_jobs:
            yield f"data: {json.dumps({'error': 'Job not found'})}\n\n"
            return
        job = processing_jobs[job_id]
        yield f"data: {json.dumps({'status': 'processing', 'progress': job['progress']})}\n\n"
        last_progress = job['progress']
        check_count = 0
        while job['status'] == 'processing':
            if job['progress'] != last_progress:
                yield f"data: {json.dumps({'status': 'processing', 'progress': job['progress']})}\n\n"
                last_progress = job['progress']
            time.sleep(0.5)
            check_count += 1
            # Every ~30 seconds, verify the worker thread is still alive
            if check_count > 60:
                if 'thread_alive' in job and not job['thread_alive']():
                    job['status'] = 'error'
                    job['error'] = 'Processing thread died unexpectedly'
                    break
                check_count = 0
        if job['status'] == 'completed':
            yield f"data: {json.dumps({'status': 'completed', 'progress': 100, 'result_url': job['result_url'], 'preview_url': job['preview_url']})}\n\n"
        else:
            yield f"data: {json.dumps({'status': 'error', 'error': job['error']})}\n\n"
    return Response(stream_with_context(generate()), mimetype='text/event-stream')
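# The /progress endpoint speaks server-sent events; a client might receive, for example:
#   data: {"status": "processing", "progress": 30}
#   data: {"status": "completed", "progress": 100, "result_url": "/download/<job_id>", ...}
# (Illustrative payloads; consume with EventSource in a browser or `curl -N`.)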
@app.route('/convert', methods=['POST'])
def convert_image_to_3d():
    if 'image' not in request.files:
        return jsonify({"error": "No image provided"}), 400
    file = request.files['image']
    if file.filename == '':
        return jsonify({"error": "No image selected"}), 400
    if not allowed_file(file.filename):
        return jsonify({"error": f"File type not allowed. Supported types: {', '.join(ALLOWED_EXTENSIONS)}"}), 400
    output_format = request.form.get('output_format', 'glb').lower()
    detail_level = request.form.get('detail_level', 'medium').lower()
    if output_format not in ['glb']:
        return jsonify({"error": "Only GLB format is supported"}), 400
    if detail_level not in ['low', 'medium', 'high']:
        return jsonify({"error": "Invalid detail_level. Use low, medium, or high"}), 400
    job_id = str(uuid.uuid4())
    output_dir = os.path.join(RESULTS_FOLDER, job_id)
    os.makedirs(output_dir, exist_ok=True)
    filename = secure_filename(file.filename)
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], f"{job_id}_{filename}")
    file.save(filepath)
    processing_jobs[job_id] = {
        'status': 'processing',
        'progress': 0,
        'result_url': None,
        'preview_url': None,
        'error': None,
        'output_format': output_format,
        'created_at': time.time()
    }

    def process_image():
        thread = threading.current_thread()
        processing_jobs[job_id]['thread_alive'] = lambda: thread.is_alive()
        try:
            processing_jobs[job_id]['progress'] = 5
            image = preprocess_image(filepath)
            processing_jobs[job_id]['progress'] = 10
            try:
                depth_estimator = load_model()
                processing_jobs[job_id]['progress'] = 30
            except Exception as e:
                processing_jobs[job_id]['status'] = 'error'
                processing_jobs[job_id]['error'] = f"Error loading model: {str(e)}"
                return
            try:
                def generate_3d():
                    # Generate depth map
                    with torch.no_grad():
                        depth_output = depth_estimator(image)
                        depth_map = depth_output["depth"]
                    # Convert depth to mesh
                    mesh = depth_to_point_cloud(depth_map, image, detail_level)
                    return mesh

                mesh, error = process_with_timeout(generate_3d, [], TIMEOUT_SECONDS)
                if error:
                    if isinstance(error, TimeoutError):
                        processing_jobs[job_id]['status'] = 'error'
                        processing_jobs[job_id]['error'] = f"Processing timed out after {TIMEOUT_SECONDS} seconds"
                        return
                    else:
                        raise error
                processing_jobs[job_id]['progress'] = 80
                # Export as GLB
                glb_path = os.path.join(output_dir, "model.glb")
                mesh.export(glb_path, file_type='glb')
                processing_jobs[job_id]['result_url'] = f"/download/{job_id}"
                processing_jobs[job_id]['preview_url'] = f"/preview/{job_id}"
                processing_jobs[job_id]['status'] = 'completed'
                processing_jobs[job_id]['progress'] = 100
                processing_jobs[job_id]['completed_at'] = time.time()
                print(f"Job {job_id} completed successfully")
            except Exception as e:
                error_details = traceback.format_exc()
                processing_jobs[job_id]['status'] = 'error'
                processing_jobs[job_id]['error'] = f"Error during processing: {str(e)}"
                print(f"Error processing job {job_id}: {str(e)}")
                print(error_details)
                return
            # Clean up the uploaded file and free memory
            if os.path.exists(filepath):
                os.remove(filepath)
            gc.collect()
        except Exception as e:
            error_details = traceback.format_exc()
            processing_jobs[job_id]['status'] = 'error'
            processing_jobs[job_id]['error'] = f"{str(e)}\n{error_details}"
            print(f"Error processing job {job_id}: {str(e)}")
            print(error_details)
            if os.path.exists(filepath):
                os.remove(filepath)

    processing_thread = threading.Thread(target=process_image)
    processing_thread.daemon = True
    processing_thread.start()
    return jsonify({"job_id": job_id}), 202
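# Example request (hypothetical file name):
#   curl -F "image=@chair.png" -F "detail_level=high" http://localhost:7860/convert
# Returns {"job_id": "<uuid>"} with HTTP 202; poll /progress/<job_id> for status,
# then fetch /download/<job_id> for the GLB.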
@app.route('/download/<job_id>', methods=['GET'])
def download_model(job_id):
    if job_id not in processing_jobs or processing_jobs[job_id]['status'] != 'completed':
        return jsonify({"error": "Model not found or processing not complete"}), 404
    output_dir = os.path.join(RESULTS_FOLDER, job_id)
    glb_path = os.path.join(output_dir, "model.glb")
    if os.path.exists(glb_path):
        return send_file(glb_path, as_attachment=True, download_name="model.glb")
    return jsonify({"error": "File not found"}), 404
@app.route('/preview/<job_id>', methods=['GET'])
def preview_model(job_id):
    if job_id not in processing_jobs or processing_jobs[job_id]['status'] != 'completed':
        return jsonify({"error": "Model not found or processing not complete"}), 404
    output_dir = os.path.join(RESULTS_FOLDER, job_id)
    glb_path = os.path.join(output_dir, "model.glb")
    if os.path.exists(glb_path):
        return send_file(glb_path, mimetype='model/gltf-binary')
    return jsonify({"error": "Model file not found"}), 404
def cleanup_old_jobs():
    current_time = time.time()
    job_ids_to_remove = []
    for job_id, job_data in processing_jobs.items():
        # Completed jobs expire after 1 hour, failed jobs after 30 minutes
        if job_data['status'] == 'completed' and (current_time - job_data.get('created_at', 0)) > 3600:
            job_ids_to_remove.append(job_id)
        elif job_data['status'] == 'error' and (current_time - job_data.get('created_at', 0)) > 1800:
            job_ids_to_remove.append(job_id)
    for job_id in job_ids_to_remove:
        output_dir = os.path.join(RESULTS_FOLDER, job_id)
        try:
            if os.path.exists(output_dir):
                shutil.rmtree(output_dir)
        except Exception as e:
            print(f"Error cleaning up job {job_id}: {str(e)}")
        if job_id in processing_jobs:
            del processing_jobs[job_id]
    # Re-schedule this cleanup to run again in 5 minutes
    threading.Timer(300, cleanup_old_jobs).start()
@app.route('/model-info/<job_id>', methods=['GET'])
def model_info(job_id):
    if job_id not in processing_jobs:
        return jsonify({"error": "Model not found"}), 404
    job = processing_jobs[job_id]
    if job['status'] != 'completed':
        return jsonify({
            "status": job['status'],
            "progress": job['progress'],
            "error": job.get('error')
        }), 200
    output_dir = os.path.join(RESULTS_FOLDER, job_id)
    model_stats = {}
    glb_path = os.path.join(output_dir, "model.glb")
    if os.path.exists(glb_path):
        model_stats['model_size'] = os.path.getsize(glb_path)  # bytes
    return jsonify({
        "status": job['status'],
        "model_format": job['output_format'],
        "download_url": job['result_url'],
        "preview_url": job['preview_url'],
        "model_stats": model_stats,
        "created_at": job.get('created_at'),
        "completed_at": job.get('completed_at')
    }), 200
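# Example completed response (illustrative values; model_size is in bytes):
#   {"status": "completed", "model_format": "glb", "model_stats": {"model_size": 1048576},
#    "download_url": "/download/<job_id>", "preview_url": "/preview/<job_id>", ...}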
@app.route('/', methods=['GET'])
def index():
    return jsonify({
        "message": "Image to 3D API (Depth-Anything)",
        "endpoints": [
            "/convert",
            "/progress/<job_id>",
            "/download/<job_id>",
            "/preview/<job_id>",
            "/model-info/<job_id>"
        ],
        "parameters": {
            "output_format": "glb",
            "detail_level": "low, medium, or high - controls point-cloud density and mesh resolution"
        },
        "description": "This API creates 3D models from 2D images using Depth-Anything depth estimation. Backgrounds are removed automatically; images with a single clear foreground subject work best."
    }), 200
if __name__ == '__main__':
    # Start the periodic cleanup timer chain, then serve
    cleanup_old_jobs()
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port)