import os
import torch
import time
import threading
import json
import gc
from flask import Flask, request, jsonify, send_file, Response, stream_with_context
from werkzeug.utils import secure_filename
from PIL import Image
import io
import zipfile
import uuid
import traceback
from huggingface_hub import snapshot_download
from flask_cors import CORS
import numpy as np
import trimesh
from transformers import pipeline
from diffusers import StableDiffusionZero123Pipeline
import imageio
from scipy.spatial.transform import Rotation
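# Single-image -> 3D pipeline: generate novel views with a Zero123-style
# diffusion model, estimate per-view depth with a DPT model, back-project each
# depth map to a point cloud, fuse the clouds, and export a mesh as OBJ.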
app = Flask(__name__)
CORS(app)
# Configuration
UPLOAD_FOLDER = '/tmp/uploads'
RESULTS_FOLDER = '/tmp/results'
CACHE_DIR = '/tmp/huggingface'
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
VIEW_ANGLES = [(30, 0), (30, 90), (30, 180), (30, 270)] # (elevation, azimuth)
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(RESULTS_FOLDER, exist_ok=True)
os.makedirs(CACHE_DIR, exist_ok=True)
# Environment variables for caching
os.environ['HF_HOME'] = CACHE_DIR
os.environ['TRANSFORMERS_CACHE'] = os.path.join(CACHE_DIR, 'transformers')
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16 MB upload limit
# Global models
view_generator = None
depth_estimator = None
model_loaded = False
model_loading = False
processing_jobs = {}
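# Note: processing_jobs is an in-memory store; job state is lost on restart
# and is not shared across multiple worker processes.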
class TimeoutError(Exception):
    pass

def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

def preprocess_image(image_path, size=256):
    # Zero123-style view generators typically expect square 256x256 inputs
    img = Image.open(image_path).convert("RGB")
    img = img.resize((size, size), Image.LANCZOS)
    return img
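# Model loading: both models are created once (see __main__) and kept as
# module-level globals so that every request reuses the same weights.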
def load_models():
    global view_generator, depth_estimator, model_loaded
    if model_loaded:
        return
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Load view generator (float16 only makes sense on GPU; fall back to
        # float32 when running on CPU)
        view_generator = StableDiffusionZero123Pipeline.from_pretrained(
            "stabilityai/stable-zero123-6dof",
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            cache_dir=CACHE_DIR
        ).to(device)
        # Load depth estimator
        depth_estimator = pipeline(
            "depth-estimation",
            model="Intel/dpt-hybrid-midas",
            cache_dir=CACHE_DIR
        )
        model_loaded = True
        print("Models loaded successfully")
    except Exception as e:
        print(f"Error loading models: {str(e)}")
        raise
def generate_novel_views(image):
    # One novel view per (elevation, azimuth) pair in VIEW_ANGLES
    views = []
    for elevation, azimuth in VIEW_ANGLES:
        result = view_generator(
            image,
            num_inference_steps=50,
            elevation=elevation,
            azimuth=azimuth,
            guidance_scale=3.0
        ).images[0]
        views.append((result, (elevation, azimuth)))
    return views
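# Back-projection (pinhole camera model): a pixel (u, v) with depth d maps to
#   X = (u - w/2) * d / f,   Y = (v - h/2) * d / f,   Z = d,
# where f = w / (2 * tan(fov / 2)) is the focal length in pixels.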
def depth_to_pointcloud(depth_map, pose, fov=60):
    h, w = depth_map.shape
    f = w / (2 * np.tan(np.radians(fov / 2)))
    xx, yy = np.meshgrid(np.arange(w), np.arange(h))
    x = (xx - w / 2) * depth_map / f
    y = (yy - h / 2) * depth_map / f
    z = depth_map
    points = np.vstack((x.flatten(), y.flatten(), z.flatten())).T
    # Rotate the camera-frame points by the view's (elevation, azimuth) pose
    rot = Rotation.from_euler('zyx', [pose[1], pose[0], 0], degrees=True)
    points = rot.apply(points)
    return points
def create_mesh_from_pointcloud(points, image):
    # Mesh the fused point cloud via its convex hull (trimesh has no built-in
    # Delaunay surface reconstruction for point clouds)
    pcd = trimesh.PointCloud(points)
    mesh = pcd.convex_hull
    # Assign a uniform vertex color; true texturing would require
    # re-projecting each vertex into the source image
    mesh.visual.vertex_colors = np.tile([200, 200, 200, 255], (len(mesh.vertices), 1))
    return mesh
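# /convert returns 202 with a job_id immediately and runs the pipeline in a
# background thread; clients poll /progress/<job_id> and fetch the OBJ from
# /download/<job_id> once the status is 'completed'.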
@app.route('/convert', methods=['POST'])
def convert_image_to_3d():
    if 'image' not in request.files:
        return jsonify({"error": "No image provided"}), 400
    file = request.files['image']
    if not allowed_file(file.filename):
        return jsonify({"error": "Invalid file type"}), 400
    job_id = str(uuid.uuid4())
    output_dir = os.path.join(RESULTS_FOLDER, job_id)
    os.makedirs(output_dir, exist_ok=True)
    filename = secure_filename(file.filename)
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], f"{job_id}_{filename}")
    file.save(filepath)
    processing_jobs[job_id] = {
        'status': 'processing',
        'progress': 0,
        'result_url': None,
        'error': None
    }
    def process_image():
        try:
            # Preprocess input image
            img = preprocess_image(filepath)
            processing_jobs[job_id]['progress'] = 20
            # Generate novel views
            views = generate_novel_views(img)
            processing_jobs[job_id]['progress'] = 40
            # Process each view
            all_points = []
            for view_img, pose in views:
                # Estimate depth
                depth_result = depth_estimator(view_img)
                depth_map = np.array(depth_result["depth"])
                # Convert to point cloud
                points = depth_to_pointcloud(depth_map, pose)
                all_points.append(points)
                processing_jobs[job_id]['progress'] += 10
            # Combine point clouds
            combined_points = np.vstack(all_points)
            processing_jobs[job_id]['progress'] = 80
            # Create mesh
            mesh = create_mesh_from_pointcloud(combined_points, img)
            # Export
            obj_path = os.path.join(output_dir, "model.obj")
            mesh.export(obj_path)
            processing_jobs[job_id]['status'] = 'completed'
            processing_jobs[job_id]['result_url'] = f"/download/{job_id}"
            processing_jobs[job_id]['progress'] = 100
        except Exception as e:
            traceback.print_exc()
            processing_jobs[job_id]['status'] = 'error'
            processing_jobs[job_id]['error'] = str(e)
        finally:
            if os.path.exists(filepath):
                os.remove(filepath)
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    thread = threading.Thread(target=process_image)
    thread.start()
    return jsonify({"job_id": job_id}), 202
@app.route('/download/<job_id>')
def download_model(job_id):
    if job_id not in processing_jobs or processing_jobs[job_id]['status'] != 'completed':
        return jsonify({"error": "Job not complete"}), 404
    obj_path = os.path.join(RESULTS_FOLDER, job_id, "model.obj")
    return send_file(obj_path, as_attachment=True)
@app.route('/progress/<job_id>')
def get_progress(job_id):
    job = processing_jobs.get(job_id, {})
    return jsonify({
        'status': job.get('status'),
        'progress': job.get('progress'),
        'error': job.get('error')
    })
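# Illustrative client workflow (a sketch, assuming the `requests` package and
# a server reachable at http://localhost:7860; "photo.jpg" is a placeholder):
#
#   import time, requests
#   job = requests.post("http://localhost:7860/convert",
#                       files={"image": open("photo.jpg", "rb")}).json()
#   while True:
#       status = requests.get(f"http://localhost:7860/progress/{job['job_id']}").json()
#       if status["status"] in ("completed", "error"):
#           break
#       time.sleep(2)
#   if status["status"] == "completed":
#       with open("model.obj", "wb") as fh:
#           fh.write(requests.get(f"http://localhost:7860/download/{job['job_id']}").content)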
if __name__ == '__main__':
    load_models()
    app.run(host='0.0.0.0', port=7860)