mac9087 committed on
Commit
dffcbc8
·
verified ·
1 Parent(s): 1e66a9f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -44
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import torch
3
  import time
@@ -14,12 +15,9 @@ from huggingface_hub import snapshot_download
14
  from flask_cors import CORS
15
  import numpy as np
16
  import trimesh
17
- from trellis.pipelines import TrellisImageTo3DPipeline
18
 
19
- # Force CPU usage
20
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
21
- os.environ["ATTN_BACKEND"] = "native" # Disable xformers/flash-attn
22
- os.environ["SPCONV_ALGO"] = "native" # Optimize for CPU
23
  torch.set_default_device("cpu")
24
  torch.cuda.is_available = lambda: False
25
  torch.cuda.device_count = lambda: 0
@@ -27,18 +25,15 @@ torch.cuda.device_count = lambda: 0
27
  app = Flask(__name__)
28
  CORS(app)
29
 
30
- # Configure directories
31
  UPLOAD_FOLDER = '/tmp/uploads'
32
  RESULTS_FOLDER = '/tmp/results'
33
  CACHE_DIR = '/tmp/huggingface'
34
  ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
35
 
36
- # Create directories
37
  os.makedirs(UPLOAD_FOLDER, exist_ok=True)
38
  os.makedirs(RESULTS_FOLDER, exist_ok=True)
39
  os.makedirs(CACHE_DIR, exist_ok=True)
40
 
41
- # Set Hugging Face cache
42
  os.environ['HF_HOME'] = CACHE_DIR
43
  os.environ['TRANSFORMERS_CACHE'] = os.path.join(CACHE_DIR, 'transformers')
44
  os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_DIR, 'datasets')
@@ -46,17 +41,13 @@ os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_DIR, 'datasets')
46
  app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
47
  app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
48
 
49
- # Job tracking
50
  processing_jobs = {}
51
-
52
- # Global model
53
- trellis_pipeline = None
54
  model_loaded = False
55
  model_loading = False
56
 
57
- # Configuration
58
- TIMEOUT_SECONDS = 360 # 6 minutes for TRELLIS
59
- MAX_DIMENSION = 256 # TRELLIS works with smaller images
60
 
61
  class TimeoutError(Exception):
62
  pass
@@ -92,37 +83,33 @@ def process_with_timeout(function, args, timeout):
92
  def allowed_file(filename):
93
  return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
94
 
95
- # Image preprocessing
96
  def preprocess_image(image_path):
97
  try:
98
  with Image.open(image_path) as img:
99
- # Convert to RGB
100
  if img.mode == 'RGBA':
101
  img = img.convert('RGB')
102
- # Resize to 256x256
103
  img = img.resize((256, 256), Image.LANCZOS)
104
  return img
105
  except Exception as e:
106
  raise Exception(f"Error preprocessing image: {str(e)}")
107
 
108
  def load_model():
109
- global trellis_pipeline, model_loaded, model_loading
110
 
111
  if model_loaded:
112
- return trellis_pipeline
113
 
114
  if model_loading:
115
  while model_loading and not model_loaded:
116
  time.sleep(0.5)
117
- return trellis_pipeline
118
 
119
  try:
120
  model_loading = True
121
- print("Loading TRELLIS-image-large...")
122
 
123
- model_name = "JeffreyXiang/TRELLIS-image-large"
124
 
125
- # Download model
126
  max_retries = 3
127
  retry_delay = 5
128
  for attempt in range(max_retries):
@@ -141,17 +128,16 @@ def load_model():
141
  else:
142
  raise
143
 
144
- # Load TRELLIS pipeline
145
- trellis_pipeline = TrellisImageTo3DPipeline.from_pretrained(
146
  model_name,
147
  cache_dir=CACHE_DIR,
148
  torch_dtype=torch.float32,
149
  )
150
- trellis_pipeline.to("cpu")
151
 
152
  model_loaded = True
153
- print("TRELLIS loaded successfully on CPU")
154
- return trellis_pipeline
155
 
156
  except Exception as e:
157
  print(f"Error loading model: {str(e)}")
@@ -162,18 +148,16 @@ def load_model():
162
 
163
  def generate_3d_model(image, detail_level):
164
  try:
165
- # Parameters
166
- num_steps = {'low': 50, 'medium': 75, 'high': 100}
167
  steps = num_steps[detail_level]
168
 
169
- # Generate 3D model
170
  with torch.no_grad():
171
- result = trellis_pipeline(image, num_inference_steps=steps, output_type="mesh")
172
 
173
- # Extract mesh
174
- vertices = np.array(result.vertices)
175
- faces = np.array(result.faces)
176
- vertex_colors = np.array(result.vertex_colors) if result.vertex_colors is not None else None
177
 
178
  trimesh_mesh = trimesh.Trimesh(
179
  vertices=vertices,
@@ -181,7 +165,6 @@ def generate_3d_model(image, detail_level):
181
  vertex_colors=vertex_colors
182
  )
183
 
184
- # Rotate for Unity Z-up
185
  trimesh_mesh.apply_transform(trimesh.transformations.rotation_matrix(np.pi, [1, 0, 0]))
186
 
187
  return trimesh_mesh
@@ -192,7 +175,7 @@ def generate_3d_model(image, detail_level):
192
  def health_check():
193
  return jsonify({
194
  "status": "healthy",
195
- "model": "TRELLIS-image-large",
196
  "device": "cpu"
197
  }), 200
198
 
@@ -303,7 +286,6 @@ def convert_image_to_3d():
303
 
304
  processing_jobs[job_id]['progress'] = 80
305
 
306
- # Export
307
  file_path = os.path.join(output_dir, f"model.{output_format}")
308
  mesh.export(file_path, file_type=output_format)
309
 
@@ -369,7 +351,7 @@ def preview_model(job_id):
369
  if os.path.exists(file_path):
370
  if output_format == 'glb':
371
  return send_file(file_path, mimetype='model/gltf-binary')
372
- else: # OBJ
373
  return send_file(file_path, mimetype='text/plain')
374
 
375
  return jsonify({"error": "Model file not found"}), 404
@@ -433,7 +415,7 @@ def model_info(job_id):
433
  @app.route('/', methods=['GET'])
434
  def index():
435
  return jsonify({
436
- "message": "Image to 3D API (TRELLIS-image-large)",
437
  "endpoints": [
438
  "/convert",
439
  "/progress/<job_id>",
@@ -443,12 +425,13 @@ def index():
443
  ],
444
  "parameters": {
445
  "output_format": "glb or obj",
446
- "detail_level": "low, medium, or high - controls inference steps"
447
  },
448
- "description": "Creates 3D models from 2D images using TRELLIS-image-large. Use transparent PNGs for best results."
449
  }), 200
450
 
451
  if __name__ == '__main__':
452
  cleanup_old_jobs()
453
  port = int(os.environ.get('PORT', 7860))
454
- app.run(host='0.0.0.0', port=port)
 
 
1
+ ```python
2
  import os
3
  import torch
4
  import time
 
15
  from flask_cors import CORS
16
  import numpy as np
17
  import trimesh
18
+ from diffusers import DiffusionPipeline
19
 
 
20
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
 
 
21
  torch.set_default_device("cpu")
22
  torch.cuda.is_available = lambda: False
23
  torch.cuda.device_count = lambda: 0
 
25
  app = Flask(__name__)
26
  CORS(app)
27
 
 
28
  UPLOAD_FOLDER = '/tmp/uploads'
29
  RESULTS_FOLDER = '/tmp/results'
30
  CACHE_DIR = '/tmp/huggingface'
31
  ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
32
 
 
33
  os.makedirs(UPLOAD_FOLDER, exist_ok=True)
34
  os.makedirs(RESULTS_FOLDER, exist_ok=True)
35
  os.makedirs(CACHE_DIR, exist_ok=True)
36
 
 
37
  os.environ['HF_HOME'] = CACHE_DIR
38
  os.environ['TRANSFORMERS_CACHE'] = os.path.join(CACHE_DIR, 'transformers')
39
  os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_DIR, 'datasets')
 
41
  app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
42
  app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
43
 
 
44
  processing_jobs = {}
45
+ zero123_pipeline = None
 
 
46
  model_loaded = False
47
  model_loading = False
48
 
49
+ TIMEOUT_SECONDS = 300
50
+ MAX_DIMENSION = 256
 
51
 
52
  class TimeoutError(Exception):
53
  pass
 
83
  def allowed_file(filename):
84
  return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
85
 
 
86
  def preprocess_image(image_path):
87
  try:
88
  with Image.open(image_path) as img:
 
89
  if img.mode == 'RGBA':
90
  img = img.convert('RGB')
 
91
  img = img.resize((256, 256), Image.LANCZOS)
92
  return img
93
  except Exception as e:
94
  raise Exception(f"Error preprocessing image: {str(e)}")
95
 
96
  def load_model():
97
+ global zero123_pipeline, model_loaded, model_loading
98
 
99
  if model_loaded:
100
+ return zero123_pipeline
101
 
102
  if model_loading:
103
  while model_loading and not model_loaded:
104
  time.sleep(0.5)
105
+ return zero123_pipeline
106
 
107
  try:
108
  model_loading = True
109
+ print("Loading Zero123++...")
110
 
111
+ model_name = "sudo-ai/zero123plus-v1.2"
112
 
 
113
  max_retries = 3
114
  retry_delay = 5
115
  for attempt in range(max_retries):
 
128
  else:
129
  raise
130
 
131
+ zero123_pipeline = DiffusionPipeline.from_pretrained(
 
132
  model_name,
133
  cache_dir=CACHE_DIR,
134
  torch_dtype=torch.float32,
135
  )
136
+ zero123_pipeline.to("cpu")
137
 
138
  model_loaded = True
139
+ print("Zero123++ loaded successfully on CPU")
140
+ return zero123_pipeline
141
 
142
  except Exception as e:
143
  print(f"Error loading model: {str(e)}")
 
148
 
149
  def generate_3d_model(image, detail_level):
150
  try:
151
+ num_steps = {'low': 30, 'medium': 50, 'high': 75}
 
152
  steps = num_steps[detail_level]
153
 
 
154
  with torch.no_grad():
155
+ result = zero123_pipeline(image, num_inference_steps=steps)
156
 
157
+ mesh = result.meshes[0]
158
+ vertices = np.array(mesh.vertices)
159
+ faces = np.array(mesh.faces)
160
+ vertex_colors = np.array(mesh.vertex_colors) if mesh.vertex_colors is not None else None
161
 
162
  trimesh_mesh = trimesh.Trimesh(
163
  vertices=vertices,
 
165
  vertex_colors=vertex_colors
166
  )
167
 
 
168
  trimesh_mesh.apply_transform(trimesh.transformations.rotation_matrix(np.pi, [1, 0, 0]))
169
 
170
  return trimesh_mesh
 
175
  def health_check():
176
  return jsonify({
177
  "status": "healthy",
178
+ "model": "Zero123++",
179
  "device": "cpu"
180
  }), 200
181
 
 
286
 
287
  processing_jobs[job_id]['progress'] = 80
288
 
 
289
  file_path = os.path.join(output_dir, f"model.{output_format}")
290
  mesh.export(file_path, file_type=output_format)
291
 
 
351
  if os.path.exists(file_path):
352
  if output_format == 'glb':
353
  return send_file(file_path, mimetype='model/gltf-binary')
354
+ else:
355
  return send_file(file_path, mimetype='text/plain')
356
 
357
  return jsonify({"error": "Model file not found"}), 404
 
415
  @app.route('/', methods=['GET'])
416
  def index():
417
  return jsonify({
418
+ "message": "Image to 3D API (Zero123++)",
419
  "endpoints": [
420
  "/convert",
421
  "/progress/<job_id>",
 
425
  ],
426
  "parameters": {
427
  "output_format": "glb or obj",
428
+ "detail_level": "low, medium, or high"
429
  },
430
+ "description": "Creates 3D models from 2D images using Zero123++."
431
  }), 200
432
 
433
  if __name__ == '__main__':
434
  cleanup_old_jobs()
435
  port = int(os.environ.get('PORT', 7860))
436
+ app.run(host='0.0.0.0', port=port)
437
+ ```