rightnight / app.py
mac9087's picture
Update app.py
e4c93be verified
raw
history blame
20.8 kB
import os
import torch
import time
import threading
import json
import gc
from flask import Flask, request, jsonify, send_file, Response, stream_with_context
from werkzeug.utils import secure_filename
from PIL import Image
import io
import zipfile
import uuid
import traceback
from transformers import AutoImageProcessor, AutoModel
from huggingface_hub import snapshot_download
from flask_cors import CORS
import numpy as np
import trimesh
app = Flask(__name__)
CORS(app)  # Allow cross-origin requests on every route.

# --- Storage locations ------------------------------------------------------
UPLOAD_FOLDER = '/tmp/uploads'    # uploaded source images, saved as "<job_id>_<name>"
RESULTS_FOLDER = '/tmp/results'   # one sub-directory of generated files per job id
CACHE_DIR = '/tmp/huggingface'    # Hugging Face download cache
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}  # accepted upload file extensions

for _directory in (UPLOAD_FOLDER, RESULTS_FOLDER, CACHE_DIR):
    os.makedirs(_directory, exist_ok=True)
del _directory

# Point the Hugging Face libraries at CACHE_DIR.
# NOTE(review): transformers / huggingface_hub are imported above this point;
# if they read these variables at import time, these assignments come too late
# to affect them — confirm. Model loading also passes cache_dir=CACHE_DIR
# explicitly, which does not depend on these variables.
os.environ.update({
    'HF_HOME': CACHE_DIR,
    'TRANSFORMERS_CACHE': os.path.join(CACHE_DIR, 'transformers'),
    'HF_DATASETS_CACHE': os.path.join(CACHE_DIR, 'datasets'),
})

app.config.update(
    UPLOAD_FOLDER=UPLOAD_FOLDER,
    MAX_CONTENT_LENGTH=16 * 1024 * 1024,  # reject request bodies over 16 MB
)

# --- Shared mutable state ---------------------------------------------------
# job_id -> job record dict (status, progress, result_url, preview_url, error, ...)
processing_jobs = {}

# Model objects, populated lazily by load_model().
image_processor = None
model = None
model_loaded = False
model_loading = False

# --- Processing limits ------------------------------------------------------
TIMEOUT_SECONDS = 180  # wall-clock budget for the model inference call
MAX_DIMENSION = 512    # longest image side (pixels) passed to the model
class TimeoutError(TimeoutError):
    """Signals that a model job exceeded its time budget.

    This class keeps the name ``TimeoutError`` because the rest of this module
    refers to it by that name. It subclasses the built-in ``TimeoutError``
    (the base expression is resolved before the class name is bound), so code
    that catches the built-in also catches this one. Otherwise the module-level
    definition would silently shadow the built-in.
    """
def process_with_timeout(function, args, timeout):
    """Call ``function(*args)`` in a daemon thread, waiting at most ``timeout`` seconds.

    Errors are returned rather than raised, so the caller decides how to react.

    Args:
        function: Callable to run.
        args: Positional arguments for ``function``.
        timeout: Seconds to wait for the call to finish.

    Returns:
        A ``(result, error)`` pair:
        - ``(result, None)`` when ``function`` returned normally;
        - ``(None, exc)`` when ``function`` raised the ``Exception`` ``exc``;
        - ``(None, TimeoutError(...))`` when the call was still running when
          the wait expired;
        - ``(None, RuntimeError(...))`` when the worker ended without returning
          or raising an ``Exception`` (e.g. it raised ``SystemExit``).

    Note:
        A timed-out call is not cancelled (Python threads cannot be killed);
        it keeps running in the background until ``function`` returns.
    """
    # One-element lists let the nested worker publish its outcome.
    result = [None]
    error = [None]
    completed = [False]

    def target():
        try:
            result[0] = function(*args)
            completed[0] = True
        except Exception as e:
            error[0] = e

    thread = threading.Thread(target=target)
    thread.daemon = True
    thread.start()
    thread.join(timeout)

    # Check the recorded outcome before liveness: the worker can still be
    # "alive" for an instant after it has stored its result or its error.
    if completed[0]:
        return result[0], None
    if error[0] is not None:
        return None, error[0]
    if thread.is_alive():
        return None, TimeoutError(f"Processing timed out after {timeout} seconds")
    # The thread is gone but left no outcome. Report this explicitly instead of
    # returning (None, None), which would look like success.
    return None, RuntimeError("Worker thread exited without producing a result")
def allowed_file(filename):
    """Report whether ``filename`` has an extension listed in ALLOWED_EXTENSIONS.

    Only the text after the last dot is compared, case-insensitively. Names
    that contain no dot at all are rejected.
    """
    _, dot, extension = filename.rpartition('.')
    if not dot:
        return False
    return extension.lower() in ALLOWED_EXTENSIONS
def preprocess_image(image_path):
    """Open an image file as RGB, downscaling it so neither side exceeds MAX_DIMENSION.

    The aspect ratio is preserved, and the longer side becomes exactly
    MAX_DIMENSION. Images already within the limit keep their size.

    Args:
        image_path: Path of the image file to open.

    Returns:
        The converted (and possibly resized) ``PIL.Image.Image`` in RGB mode.
        The source file handle is closed by the ``with`` block before returning.
    """
    with Image.open(image_path) as img:
        img = img.convert("RGB")
        if img.width > MAX_DIMENSION or img.height > MAX_DIMENSION:
            # Pin the longer side to MAX_DIMENSION and scale the shorter one.
            # max(1, ...) stops very elongated images (e.g. 4000x3) from
            # truncating the short side to 0, which resize() rejects.
            if img.width > img.height:
                new_width = MAX_DIMENSION
                new_height = max(1, int(img.height * (MAX_DIMENSION / img.width)))
            else:
                new_height = MAX_DIMENSION
                new_width = max(1, int(img.width * (MAX_DIMENSION / img.height)))
            img = img.resize((new_width, new_height), Image.LANCZOS)
        return img
# Serialises model loading so concurrent jobs never download/initialise twice.
_model_lock = threading.Lock()


def _download_model_snapshot(model_name, max_retries=3, retry_delay=5):
    """Download ``model_name`` into CACHE_DIR, retrying with doubling back-off.

    Re-raises the last error once ``max_retries`` attempts have failed.
    """
    for attempt in range(max_retries):
        try:
            snapshot_download(
                repo_id=model_name,
                cache_dir=CACHE_DIR,
                resume_download=True,
            )
            return
        except Exception as e:
            if attempt == max_retries - 1:
                raise
            print(f"Download attempt {attempt+1} failed: {str(e)}. Retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
            retry_delay *= 2


def load_model():
    """Return ``(image_processor, model)``, loading them on first use.

    On first use, the caller:
    - downloads the model snapshot into CACHE_DIR (with retries);
    - builds the processor and model;
    - moves the model to CUDA in float16 when CUDA is available, otherwise
      keeps it on the CPU in float32;
    - stores both in module globals.

    Later calls return the cached pair. Concurrent first calls are
    serialised by ``_model_lock``, so loading happens at most once. If loading
    fails, the exception propagates, and a later caller retries instead of
    receiving ``(None, None)``.

    Raises:
        Exception: whatever the download or ``from_pretrained`` calls raise.
    """
    global image_processor, model, model_loaded, model_loading

    # Fast path once loaded: no lock needed.
    if model_loaded:
        return image_processor, model

    with _model_lock:
        # Another thread may have finished loading while we waited for the lock.
        if model_loaded:
            return image_processor, model

        model_loading = True
        try:
            print("Starting model loading...")
            # NOTE(review): confirm this repo id exists and loads with
            # AutoImageProcessor / AutoModel.
            model_name = "damo-vilab/text-to-3d-texture-base"
            _download_model_snapshot(model_name)

            device = "cuda" if torch.cuda.is_available() else "cpu"
            dtype = torch.float16 if device == "cuda" else torch.float32

            processor = AutoImageProcessor.from_pretrained(model_name, cache_dir=CACHE_DIR)
            loaded = AutoModel.from_pretrained(
                model_name,
                torch_dtype=dtype,
                cache_dir=CACHE_DIR,
                low_cpu_mem_usage=True,
            )
            loaded = loaded.to(device)
            if device == "cuda":
                loaded = loaded.half()  # use half precision on GPU

            # Publish only fully initialised objects.
            image_processor, model = processor, loaded
            model_loaded = True
            print(f"Model loaded successfully on {device}")
            return image_processor, model
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            print(traceback.format_exc())
            raise
        finally:
            model_loading = False
def create_mesh_from_output(output, resolution=64):
    """Build a ``trimesh.Trimesh`` from a model forward-pass result.

    Only ``output.last_hidden_state`` is used. Its first entry along the
    leading axis (presumably the batch dimension — confirm) is detached,
    moved to the CPU as a NumPy array, and passed to
    ``create_primitive_mesh``, which builds the geometry.

    This is a placeholder: the result is a primitive scaled by the features,
    not a reconstruction of the input image.
    """
    hidden_state = output.last_hidden_state
    first_item = hidden_state.detach().cpu().numpy()[0]
    verts, tris = create_primitive_mesh(first_item, resolution)
    return trimesh.Trimesh(vertices=verts, faces=tris)
def create_primitive_mesh(features, resolution=64, use_sphere=True):
    """Build a placeholder sphere (default) or cube whose size depends on ``features``.

    The size modifier is ``0.5 + (S % 1.0)``, where ``S`` is the sum of
    ``features[:10]``, so it lies between 0.5 and 1.5. ``S`` is accumulated in
    float64 so that float16 inputs cannot overflow to inf, which would turn
    every vertex into NaN.

    Args:
        features: Array-like of numbers; only ``features[:10]`` (the first ten
            entries along the leading axis) is used.
        resolution: Number of samples along each of the sphere's two angles.
            Ignored for the cube.
        use_sphere: Build a UV sphere when true, otherwise an axis-aligned cube.

    Returns:
        ``(vertices, faces)``: a float ``(V, 3)`` vertex array and an integer
        ``(F, 3)`` array of triangle vertex indices. The triangles are wound so
        that their normals point away from the origin.
    """
    feature_sum = np.sum(features[:10], dtype=np.float64)
    size_mod = 0.5 + (feature_sum % 1.0)

    if use_sphere:
        base_radius = 1.0
        radius = base_radius * size_mod
        u = np.linspace(0, 2 * np.pi, resolution)  # azimuth
        v = np.linspace(0, np.pi, resolution)      # polar angle measured from +z

        x = radius * np.outer(np.cos(u), np.sin(v))
        y = radius * np.outer(np.sin(u), np.sin(v))
        z = radius * np.outer(np.ones_like(u), np.cos(v))
        # Row-major flattening: grid point (i, j) becomes vertex i * resolution + j.
        vertices = np.column_stack([x.ravel(), y.ravel(), z.ravel()])

        # Two triangles per grid cell, emitted in the same order as a nested
        # i-then-j loop: [p1, p2, p4] followed by [p1, p4, p3].
        grid = np.arange(resolution * resolution).reshape(resolution, resolution)
        p1 = grid[:-1, :-1].ravel()
        p2 = grid[:-1, 1:].ravel()
        p3 = grid[1:, :-1].ravel()
        p4 = grid[1:, 1:].ravel()
        faces = np.empty((2 * p1.size, 3), dtype=np.int64)
        faces[0::2] = np.column_stack([p1, p2, p4])
        faces[1::2] = np.column_stack([p1, p4, p3])
    else:
        # Float dtype so scaling by a float works; an int array cannot be
        # multiplied in place by a float.
        vertices = np.array([
            [-1, -1, -1], [1, -1, -1], [1, 1, -1], [-1, 1, -1],
            [-1, -1, 1], [1, -1, 1], [1, 1, 1], [-1, 1, 1],
        ], dtype=np.float64) * size_mod
        faces = np.array([
            [0, 2, 1], [0, 3, 2],  # Bottom face (z = -1)
            [4, 5, 6], [4, 6, 7],  # Top face (z = +1)
            [0, 1, 5], [0, 5, 4],  # Front face (y = -1)
            [2, 3, 7], [2, 7, 6],  # Back face (y = +1)
            [0, 7, 3], [0, 4, 7],  # Left face (x = -1)
            [1, 2, 6], [1, 6, 5],  # Right face (x = +1)
        ])

    return vertices, faces
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: report that the service is up and which device torch sees."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    payload = {
        "status": "healthy",
        "model": "Lightweight 3D Model Generator",
        "device": device,
    }
    return jsonify(payload), 200
@app.route('/progress/<job_id>', methods=['GET'])
def progress(job_id):
    """Stream a job's progress to the client as Server-Sent Events.

    The stream emits ``data:`` events carrying JSON:
    - an initial ``processing`` event;
    - another ``processing`` event each time the job's progress value changes;
    - a final event, either ``completed`` (with ``result_url`` and
      ``preview_url``) or ``error``.

    An unknown job id gets a single ``{"error": "Job not found"}`` event.

    After about 30 s without a progress change, the worker thread is checked.
    If it died while the job was still processing, the job is marked as an
    error. Otherwise an SSE comment line is sent so the stream is not silent.
    """
    def generate():
        job = processing_jobs.get(job_id)
        if job is None:
            yield f"data: {json.dumps({'error': 'Job not found'})}\n\n"
            return

        # Initial progress snapshot.
        yield f"data: {json.dumps({'status': 'processing', 'progress': job['progress']})}\n\n"

        last_progress = job['progress']
        idle_polls = 0
        while job['status'] == 'processing':
            current = job['progress']  # read once so the event and tracker agree
            if current != last_progress:
                last_progress = current
                idle_polls = 0  # an update was sent; restart the idle window
                yield f"data: {json.dumps({'status': 'processing', 'progress': current})}\n\n"
            time.sleep(0.5)
            idle_polls += 1

            if idle_polls > 60:  # ~30 s (60 polls x 0.5 s) without a progress change
                thread_alive = job.get('thread_alive')
                # Re-check the status: the worker may have finished (and exited)
                # during the sleep; a completed job must not be overwritten.
                if (job['status'] == 'processing'
                        and thread_alive is not None
                        and not thread_alive()):
                    job['status'] = 'error'
                    job['error'] = 'Processing thread died unexpectedly'
                    break
                idle_polls = 0
                # SSE comment line: ignored by EventSource clients. It keeps the
                # connection from staying silent for long stretches, which
                # intermediaries may otherwise treat as idle and close.
                yield ": keep-alive\n\n"

        # Final status.
        if job['status'] == 'completed':
            yield f"data: {json.dumps({'status': 'completed', 'progress': 100, 'result_url': job['result_url'], 'preview_url': job['preview_url']})}\n\n"
        else:
            yield f"data: {json.dumps({'status': 'error', 'error': job['error']})}\n\n"

    return Response(stream_with_context(generate()), mimetype='text/event-stream')
def _export_mesh(mesh, output_dir, output_format):
    """Write ``mesh`` into ``output_dir`` in the requested format.

    ``'obj'`` writes ``model.obj`` plus ``model.zip``, which bundles the OBJ
    with ``model.mtl`` when the exporter produced one. ``'glb'`` writes
    ``model.glb``. These are the file names the /download and /preview
    routes look for.
    """
    if output_format == 'obj':
        obj_path = os.path.join(output_dir, "model.obj")
        mesh.export(obj_path, file_type='obj')
        zip_path = os.path.join(output_dir, "model.zip")
        with zipfile.ZipFile(zip_path, 'w') as zipf:
            zipf.write(obj_path, arcname="model.obj")
            mtl_path = os.path.join(output_dir, "model.mtl")
            if os.path.exists(mtl_path):
                zipf.write(mtl_path, arcname="model.mtl")
    elif output_format == 'glb':
        glb_path = os.path.join(output_dir, "model.glb")
        mesh.export(glb_path, file_type='glb')


def _run_conversion_job(job_id, filepath, output_dir, output_format):
    """Worker-thread body for one /convert job.

    Updates ``processing_jobs[job_id]`` in place: ``progress`` as work
    advances, then either ``status='completed'`` with result/preview URLs or
    ``status='error'`` with a message.

    Whatever the outcome, the uploaded file at ``filepath`` is deleted and
    memory is released when the job ends. Previously this happened only on
    success, so failed jobs leaked their uploads.
    """
    job = processing_jobs[job_id]
    worker = threading.current_thread()
    # Lets /progress detect a worker that died without reporting an outcome.
    job['thread_alive'] = worker.is_alive

    def fail(message):
        job['status'] = 'error'
        job['error'] = message

    try:
        job['progress'] = 5
        image = preprocess_image(filepath)
        job['progress'] = 10

        try:
            processor, model_instance = load_model()
        except Exception as e:
            fail(f"Error loading model: {str(e)}")
            return
        job['progress'] = 30

        # Inference (bounded by TIMEOUT_SECONDS) followed by mesh construction.
        try:
            def generate_3d():
                inputs = processor(images=image, return_tensors="pt").to(model_instance.device)
                with torch.no_grad():
                    return model_instance(**inputs)

            outputs, error = process_with_timeout(generate_3d, [], TIMEOUT_SECONDS)
            if error:
                if isinstance(error, TimeoutError):
                    fail(f"Processing timed out after {TIMEOUT_SECONDS} seconds")
                    return
                raise error
            job['progress'] = 80
            mesh = create_mesh_from_output(outputs)
        except Exception as e:
            error_details = traceback.format_exc()
            fail(f"Error during processing: {str(e)}")
            print(f"Error processing job {job_id}: {str(e)}")
            print(error_details)
            return

        try:
            _export_mesh(mesh, output_dir, output_format)
            # URLs are set before the status so /progress never sees
            # 'completed' without them.
            job['result_url'] = f"/download/{job_id}"
            job['preview_url'] = f"/preview/{job_id}"
            job['status'] = 'completed'
            job['progress'] = 100
            print(f"Job {job_id} completed successfully")
        except Exception as e:
            error_details = traceback.format_exc()
            fail(f"Error exporting model: {str(e)}")
            print(f"Error exporting model for job {job_id}: {str(e)}")
            print(error_details)
    except Exception as e:
        # Unexpected failure, e.g. an unreadable image. The traceback goes to
        # the server log only; API clients receive just the message.
        error_details = traceback.format_exc()
        fail(f"Error during processing: {str(e)}")
        print(f"Error processing job {job_id}: {str(e)}")
        print(error_details)
    finally:
        try:
            if os.path.exists(filepath):
                os.remove(filepath)
        except OSError as e:
            print(f"Error removing upload for job {job_id}: {str(e)}")
        # Free memory held by the image, inputs and outputs.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


@app.route('/convert', methods=['POST'])
def convert_image_to_3d():
    """Accept an image upload and start converting it to a 3D model in the background.

    Form fields:
        image: the file to convert (required; png/jpg/jpeg).
        guidance_scale: float in [1.0, 5.0], default 3.0. It is validated here
            but not otherwise used by this handler.
        output_format: ``'obj'`` (default) or ``'glb'``.

    Returns:
        ``202`` with ``{"job_id": ...}``; progress is available at
        ``/progress/<job_id>``. ``400`` with an ``error`` message on
        invalid input.
    """
    if 'image' not in request.files:
        return jsonify({"error": "No image provided"}), 400
    file = request.files['image']
    if file.filename == '':
        return jsonify({"error": "No image selected"}), 400
    if not allowed_file(file.filename):
        return jsonify({"error": f"File type not allowed. Supported types: {', '.join(ALLOWED_EXTENSIONS)}"}), 400

    # Optional parameters with defaults.
    try:
        guidance_scale = float(request.form.get('guidance_scale', 3.0))
        output_format = request.form.get('output_format', 'obj').lower()
    except ValueError:
        return jsonify({"error": "Invalid parameter values"}), 400

    # Written as a range check so that NaN, which float() accepts, is rejected too.
    if not (1.0 <= guidance_scale <= 5.0):
        return jsonify({"error": "Guidance scale must be between 1.0 and 5.0"}), 400
    if output_format not in ('obj', 'glb'):
        return jsonify({"error": "Unsupported output format. Use 'obj' or 'glb'"}), 400

    job_id = str(uuid.uuid4())
    output_dir = os.path.join(RESULTS_FOLDER, job_id)
    os.makedirs(output_dir, exist_ok=True)

    # The job-id prefix keeps uploads with the same name from colliding.
    filename = secure_filename(file.filename)
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], f"{job_id}_{filename}")
    file.save(filepath)

    processing_jobs[job_id] = {
        'status': 'processing',
        'progress': 0,
        'result_url': None,
        'preview_url': None,
        'error': None,
        'output_format': output_format,
        'created_at': time.time()
    }

    processing_thread = threading.Thread(
        target=_run_conversion_job,
        args=(job_id, filepath, output_dir, output_format),
        daemon=True,
    )
    processing_thread.start()

    # Respond immediately; the client follows the job via /progress.
    return jsonify({"job_id": job_id}), 202  # 202 Accepted
@app.route('/download/<job_id>', methods=['GET'])
def download_model(job_id):
    """Send a completed job's model file as a download attachment.

    OBJ jobs are served as ``model.zip``; every other job is served as
    ``model.glb``. Responds 404 while the job is unknown or unfinished, or
    when the expected file is missing on disk.
    """
    job = processing_jobs.get(job_id)
    if job is None or job['status'] != 'completed':
        return jsonify({"error": "Model not found or processing not complete"}), 404

    wants_obj = job.get('output_format', 'obj') == 'obj'
    archive_name = "model.zip" if wants_obj else "model.glb"
    archive_path = os.path.join(RESULTS_FOLDER, job_id, archive_name)

    if not os.path.exists(archive_path):
        return jsonify({"error": "File not found"}), 404
    return send_file(archive_path, as_attachment=True, download_name=archive_name)
@app.route('/preview/<job_id>', methods=['GET'])
def preview_model(job_id):
    """Serve a completed job's model file inline (not as an attachment).

    OBJ jobs return ``model.obj`` as ``model/obj``. Every other job returns
    ``model.glb`` as ``model/gltf-binary``. Responds 404 while the job is
    unknown or unfinished, or when the file is missing on disk.
    """
    job = processing_jobs.get(job_id)
    if job is None or job['status'] != 'completed':
        return jsonify({"error": "Model not found or processing not complete"}), 404

    if job.get('output_format', 'obj') == 'obj':
        file_name, mime = "model.obj", 'model/obj'
    else:
        file_name, mime = "model.glb", 'model/gltf-binary'

    model_path = os.path.join(RESULTS_FOLDER, job_id, file_name)
    if os.path.exists(model_path):
        return send_file(model_path, mimetype=mime)
    return jsonify({"error": "Model file not found"}), 404
def cleanup_old_jobs():
    """Purge expired finished jobs and their result directories, then reschedule.

    Completed jobs are removed one hour after creation and failed jobs after
    30 minutes. Jobs that are still processing are kept. The next run is
    scheduled 5 minutes later on a daemon ``threading.Timer``; it is
    rescheduled even if this pass fails, so one error cannot stop the
    cleanup cycle.
    """
    import shutil

    # Maximum age in seconds, keyed by job status; statuses not listed are kept.
    max_age = {'completed': 3600, 'error': 1800}

    try:
        now = time.time()
        expired = []
        # Snapshot the items: /convert threads may add jobs concurrently, and
        # iterating the live dict could then raise "dictionary changed size".
        for job_id, job_data in list(processing_jobs.items()):
            limit = max_age.get(job_data['status'])
            if limit is not None and now - job_data.get('created_at', 0) > limit:
                expired.append(job_id)

        for job_id in expired:
            output_dir = os.path.join(RESULTS_FOLDER, job_id)
            try:
                if os.path.exists(output_dir):
                    shutil.rmtree(output_dir)
            except Exception as e:
                print(f"Error cleaning up job {job_id}: {str(e)}")
            processing_jobs.pop(job_id, None)
    except Exception as e:
        print(f"Error during job cleanup: {str(e)}")
    finally:
        # Daemon timer: a pending cleanup must not keep the process alive at
        # shutdown. A non-daemon Timer that re-arms itself blocks exit forever.
        timer = threading.Timer(300, cleanup_old_jobs)  # run every 5 minutes
        timer.daemon = True
        timer.start()
@app.route('/', methods=['GET'])
def index():
    """Root endpoint: confirm the API is running and list its job routes."""
    routes = ["/convert", "/progress/<job_id>", "/download/<job_id>", "/preview/<job_id>"]
    body = {"message": "Image to 3D API is running", "endpoints": routes}
    return jsonify(body), 200
if __name__ == '__main__':
    # Run one cleanup pass now; cleanup_old_jobs reschedules itself afterwards.
    cleanup_old_jobs()
    app.run(
        host='0.0.0.0',
        # PORT overrides the default of 7860, the standard Hugging Face Spaces port.
        port=int(os.environ.get('PORT', 7860)),
    )