import os
import torch
from flask import Flask, request, jsonify, send_file
from werkzeug.utils import secure_filename
from PIL import Image
import io
import zipfile
import uuid
import traceback
from diffusers import ShapEImg2ImgPipeline
from diffusers.utils import export_to_obj
app = Flask(__name__)
# Configure directories - use /tmp, which is writable on Hugging Face Spaces
UPLOAD_FOLDER = '/tmp/uploads'
RESULTS_FOLDER = '/tmp/results'
CACHE_DIR = '/tmp/huggingface'
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
# Create necessary directories
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(RESULTS_FOLDER, exist_ok=True)
os.makedirs(CACHE_DIR, exist_ok=True)
# Set Hugging Face cache environment variables
os.environ['HF_HOME'] = CACHE_DIR
os.environ['TRANSFORMERS_CACHE'] = os.path.join(CACHE_DIR, 'transformers')
os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_DIR, 'datasets')
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max
# Lazy loading for the model - only load when needed
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = None

def load_model():
    global pipe
    if pipe is None:
        pipe = ShapEImg2ImgPipeline.from_pretrained(
            "openai/shap-e-img2img",
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            cache_dir=CACHE_DIR  # Explicitly set cache directory
        )
        pipe = pipe.to(device)
    return pipe
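
# Note: because loading is lazy, the first /convert request also pays the cost of
# downloading the Shap-E weights into CACHE_DIR, so it can take noticeably longer
# than later requests. float16 is used only on CUDA; on CPU the pipeline falls
# back to float32 and inference is substantially slower.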

def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

@app.route('/health', methods=['GET'])
def health_check():
    return jsonify({"status": "healthy", "model": "Shap-E Image to 3D"}), 200
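
# Quick smoke test (illustrative; adjust host/port for your deployment):
#   curl http://localhost:7860/health
# Expected: HTTP 200 with a JSON body reporting status "healthy".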

@app.route('/convert', methods=['POST'])
def convert_image_to_3d():
    # Check if image is in the request
    if 'image' not in request.files:
        return jsonify({"error": "No image provided"}), 400

    file = request.files['image']
    if file.filename == '':
        return jsonify({"error": "No image selected"}), 400

    if not allowed_file(file.filename):
        return jsonify({"error": f"File type not allowed. Supported types: {', '.join(ALLOWED_EXTENSIONS)}"}), 400

    # Get optional parameters
    guidance_scale = float(request.form.get('guidance_scale', 3.0))
    num_inference_steps = int(request.form.get('num_inference_steps', 64))
    output_format = request.form.get('output_format', 'obj').lower()

    # Validate output format
    if output_format not in ['obj', 'glb']:
        return jsonify({"error": "Unsupported output format. Use 'obj' or 'glb'"}), 400

    try:
        # Save the uploaded image
        filename = secure_filename(file.filename)
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(filepath)

        # Open image
        image = Image.open(filepath).convert("RGB")

        # Load model (lazy loading)
        pipe = load_model()

        # Generate 3D model
        images = pipe(
            image,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            output_type="mesh",
        ).images

        # Create unique output directory
        output_id = str(uuid.uuid4())
        output_dir = os.path.join(RESULTS_FOLDER, output_id)
        os.makedirs(output_dir, exist_ok=True)

        # Export to requested format
        if output_format == 'obj':
            obj_path = os.path.join(output_dir, "model.obj")
            export_to_obj(images[0], obj_path)

            # Create a zip file with the OBJ and, if one was written, its MTL
            zip_path = os.path.join(output_dir, "model.zip")
            with zipfile.ZipFile(zip_path, 'w') as zipf:
                zipf.write(obj_path, arcname="model.obj")
                mtl_path = os.path.join(output_dir, "model.mtl")
                if os.path.exists(mtl_path):
                    zipf.write(mtl_path, arcname="model.mtl")

            return send_file(zip_path, as_attachment=True, download_name="model.zip")

        elif output_format == 'glb':
            # For GLB format, convert the mesh via trimesh
            # (trimesh must be installed in the environment)
            from trimesh import Trimesh

            mesh = images[0]
            # verts/faces come back as torch tensors; convert to numpy for trimesh
            vertices = mesh.verts.detach().cpu().numpy()
            faces = mesh.faces.detach().cpu().numpy()

            # Create a trimesh object and export as GLB
            trimesh_obj = Trimesh(vertices=vertices, faces=faces)
            glb_path = os.path.join(output_dir, "model.glb")
            trimesh_obj.export(glb_path)

            return send_file(glb_path, as_attachment=True, download_name="model.glb")

    except Exception as e:
        # Enhanced error reporting with traceback
        error_details = traceback.format_exc()
        return jsonify({"error": str(e), "details": error_details}), 500

@app.route('/', methods=['GET'])
def index():
    return """
    <html>
    <head>
        <title>Image to 3D Model Converter</title>
        <style>
            body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
            h1 { color: #333; }
            form { margin: 20px 0; padding: 20px; border: 1px solid #ddd; border-radius: 5px; }
            label { display: block; margin: 10px 0 5px; }
            input, select { margin-bottom: 10px; padding: 8px; width: 100%; }
            button { background: #4CAF50; color: white; border: none; padding: 10px 15px; cursor: pointer; }
            .api-info { background: #f5f5f5; padding: 15px; border-radius: 5px; }
            pre { background: #eee; padding: 10px; overflow-x: auto; }
        </style>
    </head>
    <body>
        <h1>Image to 3D Model Converter</h1>
        <form action="/convert" method="post" enctype="multipart/form-data">
            <label for="image">Upload Image:</label>
            <input type="file" id="image" name="image" accept=".png,.jpg,.jpeg" required>

            <label for="guidance_scale">Guidance Scale (1.0-5.0):</label>
            <input type="number" id="guidance_scale" name="guidance_scale" min="1.0" max="5.0" step="0.1" value="3.0">

            <label for="num_inference_steps">Inference Steps (32-128):</label>
            <input type="number" id="num_inference_steps" name="num_inference_steps" min="32" max="128" value="64">

            <label for="output_format">Output Format:</label>
            <select id="output_format" name="output_format">
                <option value="obj">OBJ (for Unity)</option>
                <option value="glb">GLB (for Three.js/Unreal)</option>
            </select>

            <button type="submit">Convert to 3D</button>
        </form>
        <div class="api-info">
            <h2>API Documentation</h2>
            <p>Endpoint: <code>/convert</code> (POST)</p>
            <p>Parameters:</p>
            <ul>
                <li><code>image</code>: Image file (required)</li>
                <li><code>guidance_scale</code>: Float between 1.0-5.0 (default: 3.0)</li>
                <li><code>num_inference_steps</code>: Integer between 32-128 (default: 64)</li>
                <li><code>output_format</code>: "obj" or "glb" (default: "obj")</li>
            </ul>
            <p>Example curl request:</p>
            <pre>curl -X POST -F "image=@your_image.jpg" -F "output_format=obj" http://localhost:7860/convert -o model.zip</pre>
        </div>
    </body>
    </html>
    """

if __name__ == '__main__':
    # Use port 7860 which is standard for Hugging Face Spaces
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port)
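
# Example client (sketch): a minimal way to call /convert from Python. It assumes
# the `requests` package is installed and the server is reachable at
# localhost:7860; both are assumptions about your setup, not part of this app.
#
#   import requests
#
#   with open("your_image.jpg", "rb") as f:
#       resp = requests.post(
#           "http://localhost:7860/convert",
#           files={"image": f},
#           data={"output_format": "glb", "guidance_scale": "3.0",
#                 "num_inference_steps": "64"},
#       )
#   resp.raise_for_status()
#   with open("model.glb", "wb") as out:
#       out.write(resp.content)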