Spaces:
Sleeping
Sleeping
import os | |
import torch | |
import time | |
import threading | |
import json | |
import gc | |
from flask import Flask, request, jsonify, send_file, Response, stream_with_context | |
from werkzeug.utils import secure_filename | |
from PIL import Image | |
import io | |
import zipfile | |
import uuid | |
import traceback | |
from huggingface_hub import snapshot_download | |
from flask_cors import CORS | |
import numpy as np | |
import trimesh | |
from transformers import pipeline | |
from diffusers import StableDiffusionZero123Pipeline | |
import imageio | |
from scipy.spatial.transform import Rotation | |
app = Flask(__name__)
CORS(app)

# --- Paths, limits and camera poses -----------------------------------------
UPLOAD_FOLDER = '/tmp/uploads'
RESULTS_FOLDER = '/tmp/results'
CACHE_DIR = '/tmp/huggingface'
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
# Poses for the generated views, each (elevation, azimuth) in degrees.
VIEW_ANGLES = [(30, 0), (30, 90), (30, 180), (30, 270)]

for _folder in (UPLOAD_FOLDER, RESULTS_FOLDER, CACHE_DIR):
    os.makedirs(_folder, exist_ok=True)
del _folder

# Point the Hugging Face libraries at the shared cache directory.
# NOTE(review): huggingface_hub and transformers are imported above, before
# these variables are set, and may already have resolved their cache paths;
# the explicit cache_dir arguments passed in load_models are what reliably
# apply. Consider setting these before the imports.
os.environ.update({
    'HF_HOME': CACHE_DIR,
    'TRANSFORMERS_CACHE': os.path.join(CACHE_DIR, 'transformers'),
})

app.config.update(
    UPLOAD_FOLDER=UPLOAD_FOLDER,
    MAX_CONTENT_LENGTH=16 * 1024 * 1024,  # largest accepted request body: 16 MiB
)

# --- Shared state (read and written by request handlers and worker threads) --
view_generator = None
depth_estimator = None
model_loaded = False
model_loading = False
# job_id -> {'status', 'progress', 'result_url', 'error'}.
# NOTE(review): entries are never removed in this source, so this grows with
# every job for the life of the process.
processing_jobs = {}
class TimeoutError(Exception):
    """Module-local timeout exception.

    NOTE(review): this shadows the builtin ``TimeoutError`` inside this module
    and is not raised anywhere in the visible source; confirm it is still
    needed, or rename it to avoid confusion with the builtin.
    """
def allowed_file(filename):
    """Return True if *filename* ends in an extension from ALLOWED_EXTENSIONS.

    The comparison is case-insensitive and uses only the text after the last
    dot. Names that contain no dot at all are rejected.
    """
    if '.' not in filename:
        return False
    _, extension = filename.rsplit('.', 1)
    return extension.lower() in ALLOWED_EXTENSIONS
def preprocess_image(image_path, size=256):
    """Load an image file as RGB and resize it to a ``size`` x ``size`` square.

    Args:
        image_path: path of the image file to read.
        size: edge length in pixels of the returned square image.

    Returns:
        A new RGB ``PIL.Image.Image`` resized with Lanczos resampling. The
        aspect ratio is not preserved: non-square inputs are stretched.
    """
    # Close the source file deterministically; the caller deletes the upload
    # right after processing, so the handle must not outlive this function.
    with Image.open(image_path) as source:
        img = source.convert("RGB")  # returns an independent, fully loaded copy
    return img.resize((size, size), Image.LANCZOS)
def load_models():
    """Load the novel-view generator and the depth estimator into module globals.

    Returns immediately once ``model_loaded`` is True. Both models go to CUDA
    when it is available. Half precision is requested only on CUDA, because
    float16 inference on CPU is poorly supported by PyTorch kernels.

    ``view_generator`` and ``depth_estimator`` are assigned only after both
    models load, so a failure never leaves one of them half-initialised.
    ``model_loading`` is True only while a load is in progress.

    Raises:
        Exception: whatever the underlying ``from_pretrained`` / ``pipeline``
        calls raise. The error is printed and re-raised, so a later call can
        retry.
    """
    global view_generator, depth_estimator, model_loaded, model_loading
    if model_loaded:
        return
    model_loading = True
    use_cuda = torch.cuda.is_available()
    try:
        # Novel-view generator
        generator = StableDiffusionZero123Pipeline.from_pretrained(
            "stabilityai/stable-zero123-6dof",
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            cache_dir=CACHE_DIR,
        ).to("cuda" if use_cuda else "cpu")
        # Depth estimator. pipeline() forwards its own extra kwargs to the
        # pipeline class, not to from_pretrained, so the cache directory goes
        # through model_kwargs. device=0 is the first GPU and -1 is the CPU.
        estimator = pipeline(
            "depth-estimation",
            model="Intel/dpt-hybrid-midas",
            device=0 if use_cuda else -1,
            model_kwargs={"cache_dir": CACHE_DIR},
        )
        view_generator = generator
        depth_estimator = estimator
        model_loaded = True
        print("Models loaded successfully")
    except Exception as e:
        print(f"Error loading models: {str(e)}")
        raise
    finally:
        model_loading = False
def generate_novel_views(image, num_views=4, *, num_inference_steps=50, guidance_scale=3.0):
    """Render novel views of *image* with the global ``view_generator``.

    Args:
        image: conditioning image passed to the view generator.
        num_views: how many poses from VIEW_ANGLES to render, taken in order.
            Values larger than ``len(VIEW_ANGLES)`` render every pose. Values
            of 0 or below render none.
        num_inference_steps: diffusion steps per view.
        guidance_scale: guidance scale passed to the view generator.

    Returns:
        A list of ``(view_image, (elevation, azimuth))`` tuples, one per
        rendered pose.
    """
    views = []
    # Previously num_views was accepted but ignored; honour it now (the
    # default of 4 still covers all four poses in VIEW_ANGLES).
    for elevation, azimuth in VIEW_ANGLES[:max(num_views, 0)]:
        result = view_generator(
            image,
            num_inference_steps=num_inference_steps,
            elevation=elevation,
            azimuth=azimuth,
            guidance_scale=guidance_scale,
        ).images[0]
        views.append((result, (elevation, azimuth)))
    return views
def depth_to_pointcloud(depth_map, pose, fov=60):
    """Back-project a 2-D depth map to 3-D points and rotate them by *pose*.

    Uses a pinhole model. The focal length is derived from the image width
    and *fov* (degrees), and the principal point is the image centre.

    Args:
        depth_map: 2-D array of per-pixel depth values, shape ``(h, w)``.
        pose: ``(elevation, azimuth)`` in degrees.
        fov: horizontal field of view in degrees.

    Returns:
        A float array of shape ``(h * w, 3)``, one point per pixel in
        row-major order.

    NOTE(review): if depth_map comes from the DPT/MiDaS "depth" output it is
    presumably relative inverse depth (larger = closer), not metric depth.
    Confirm before treating it as z.
    """
    rows, cols = depth_map.shape
    focal = cols / (2 * np.tan(np.radians(fov / 2)))
    grid_x, grid_y = np.meshgrid(np.arange(cols), np.arange(rows))
    cloud = np.column_stack([
        ((grid_x - cols / 2) * depth_map / focal).ravel(),
        ((grid_y - rows / 2) * depth_map / focal).ravel(),
        np.ravel(depth_map),
    ])
    # Extrinsic 'zyx': azimuth rotates about z, elevation about y.
    rotation = Rotation.from_euler('zyx', [pose[1], pose[0], 0], degrees=True)
    return rotation.apply(cloud)
def create_mesh_from_pointcloud(points, image):
    """Build a vertex-coloured triangle mesh enclosing a 3-D point cloud.

    The mesh is the convex hull of *points*, so concave detail is lost. Swap
    in a real surface-reconstruction step if that matters.

    Args:
        points: ``(N, 3)`` array of point coordinates.
        image: PIL image used as the colour source for the vertices.

    Returns:
        A ``trimesh.Trimesh`` with one RGB(A) colour per vertex.

    Raises:
        Whatever the hull computation raises for degenerate input (for example
        fewer than four non-coplanar points).
    """
    # trimesh Scenes have no Delaunay/triangulation API (the previous
    # `scene.delaunay_3d.triangulate_pcd` always raised AttributeError).
    mesh = trimesh.PointCloud(points).convex_hull
    n_vertices = len(mesh.vertices)
    # Stretch the image into a 1-pixel-tall strip to get exactly one RGB
    # colour per vertex as an (N, 3) uint8 array, which trimesh accepts.
    # NOTE(review): the colours are not projected from the views onto the
    # points; they only give each vertex some colour from the input image.
    strip = image.convert("RGB").resize((n_vertices, 1))
    mesh.visual.vertex_colors = np.asarray(strip, dtype=np.uint8).reshape(-1, 3)
    return mesh
# Serialises model loading and inference across background jobs. Each job
# runs several diffusion passes, and two jobs sharing one device can exhaust
# its memory. It also stops two first-time jobs from loading the models twice.
_pipeline_lock = threading.Lock()


def _run_conversion_job(job_id, filepath, output_dir):
    """Run the image -> views -> point cloud -> mesh pipeline for one job.

    Updates ``processing_jobs[job_id]`` in place with progress, then either
    ``status='completed'`` plus ``result_url``, or ``status='error'`` plus an
    error message. The uploaded file at *filepath* is always deleted.
    """
    job = processing_jobs[job_id]
    try:
        with _pipeline_lock:
            # Models are loaded eagerly only under `python <this file>`; load
            # them lazily so jobs also work when served by a WSGI server.
            load_models()
            img = preprocess_image(filepath)
            job['progress'] = 20
            views = generate_novel_views(img)
            job['progress'] = 40
            all_points = []
            for done, (view_img, pose) in enumerate(views, start=1):
                depth_result = depth_estimator(view_img)
                depth_map = np.array(depth_result["depth"])
                all_points.append(depth_to_pointcloud(depth_map, pose))
                # Spread 40..80 evenly over however many views were rendered.
                job['progress'] = 40 + (40 * done) // len(views)
            combined_points = np.vstack(all_points)
            job['progress'] = 80
            mesh = create_mesh_from_pointcloud(combined_points, img)
            mesh.export(os.path.join(output_dir, "model.obj"))
        # Publish the URL before the status so a poller never sees
        # 'completed' with a missing result_url.
        job['result_url'] = f"/download/{job_id}"
        job['progress'] = 100
        job['status'] = 'completed'
    except Exception as e:
        # Top of a background thread: record the failure for the client and
        # keep the full traceback in the server log.
        traceback.print_exc()
        job['status'] = 'error'
        job['error'] = str(e)
    finally:
        if os.path.exists(filepath):
            os.remove(filepath)
        gc.collect()
        torch.cuda.empty_cache()


def convert_image_to_3d():
    """Accept an uploaded image and start converting it to a 3-D mesh.

    Expects a multipart file field named ``image`` whose extension is in
    ALLOWED_EXTENSIONS. The upload is saved to UPLOAD_FOLDER and processed on
    a background thread. Progress and the final result are recorded in
    ``processing_jobs[job_id]``, and the mesh is written to
    ``RESULTS_FOLDER/<job_id>/model.obj``.

    Returns:
        ``({"job_id": ...}, 202)`` once the job has started, or
        ``({"error": ...}, 400)`` if the file is missing or has a disallowed
        extension.
    """
    # NOTE(review): no @app.route decorator for this view is visible in this
    # source; confirm it is registered (presumably as a POST endpoint).
    if 'image' not in request.files:
        return jsonify({"error": "No image provided"}), 400
    file = request.files['image']
    if not allowed_file(file.filename):
        return jsonify({"error": "Invalid file type"}), 400
    job_id = str(uuid.uuid4())
    output_dir = os.path.join(RESULTS_FOLDER, job_id)
    os.makedirs(output_dir, exist_ok=True)
    filename = secure_filename(file.filename)
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], f"{job_id}_{filename}")
    file.save(filepath)
    processing_jobs[job_id] = {
        'status': 'processing',
        'progress': 0,
        'result_url': None,
        'error': None,
    }
    thread = threading.Thread(
        target=_run_conversion_job, args=(job_id, filepath, output_dir)
    )
    thread.start()
    return jsonify({"job_id": job_id}), 202
def download_model(job_id):
    """Send the exported OBJ of a completed job as a file attachment.

    Returns:
        The ``send_file`` response for ``RESULTS_FOLDER/<job_id>/model.obj``.
        Otherwise a 404 JSON error when the job is unknown or not completed,
        or when its result file is no longer on disk (previously that case
        made ``send_file`` raise and produced a 500).
    """
    # NOTE(review): clients are sent to f"/download/{job_id}" (result_url),
    # but no @app.route for this view is visible in this source; confirm it
    # is registered.
    job = processing_jobs.get(job_id)
    # Only job ids created by this process pass this check, so job_id cannot
    # be used to escape RESULTS_FOLDER.
    if job is None or job['status'] != 'completed':
        return jsonify({"error": "Job not complete"}), 404
    obj_path = os.path.join(RESULTS_FOLDER, job_id, "model.obj")
    if not os.path.isfile(obj_path):
        return jsonify({"error": "Result file not found"}), 404
    return send_file(obj_path, as_attachment=True)
def get_progress(job_id):
    """Report the state of a conversion job as JSON.

    Returns:
        For a known job: ``status``, ``progress``, ``error`` and the
        ``result_url`` to download from once completed (status 200).
        For an unknown job: the same keys, with ``error`` set to
        ``"Job not found"``, and status 404. Previously this returned 200
        with all-null fields, which a poller could not tell apart from a
        live job.
    """
    # NOTE(review): no @app.route decorator for this view is visible in this
    # source; confirm it is registered.
    job = processing_jobs.get(job_id)
    if job is None:
        return jsonify({
            'status': None,
            'progress': None,
            'error': 'Job not found',
            'result_url': None,
        }), 404
    return jsonify({
        'status': job.get('status'),
        'progress': job.get('progress'),
        'error': job.get('error'),
        'result_url': job.get('result_url'),
    })
if __name__ == '__main__':
    # Load the models before serving so jobs can use them.
    load_models()
    # 7860 stays the default; a PORT environment variable can override it.
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 7860)))