mac9087 commited on
Commit
d6e1b09
·
verified ·
1 Parent(s): a59f056

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -20
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import torch
3
  import time
@@ -14,7 +15,7 @@ from huggingface_hub import snapshot_download
14
  from flask_cors import CORS
15
  import numpy as np
16
  import trimesh
17
- from diffusers import DiffusionPipeline
18
 
19
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
20
  torch.set_default_device("cpu")
@@ -41,12 +42,12 @@ app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
41
  app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
42
 
43
  processing_jobs = {}
44
- zero123_pipeline = None
45
  model_loaded = False
46
  model_loading = False
47
 
48
  TIMEOUT_SECONDS = 300
49
- MAX_DIMENSION = 256
50
 
51
  class TimeoutError(Exception):
52
  pass
@@ -87,27 +88,27 @@ def preprocess_image(image_path):
87
  with Image.open(image_path) as img:
88
  if img.mode == 'RGBA':
89
  img = img.convert('RGB')
90
- img = img.resize((256, 256), Image.LANCZOS)
91
  return img
92
  except Exception as e:
93
  raise Exception(f"Error preprocessing image: {str(e)}")
94
 
95
  def load_model():
96
- global zero123_pipeline, model_loaded, model_loading
97
 
98
  if model_loaded:
99
- return zero123_pipeline
100
 
101
  if model_loading:
102
  while model_loading and not model_loaded:
103
  time.sleep(0.5)
104
- return zero123_pipeline
105
 
106
  try:
107
  model_loading = True
108
- print("Loading Zero123++...")
109
 
110
- model_name = "sudo-ai/zero123plus-v1.2"
111
 
112
  max_retries = 3
113
  retry_delay = 5
@@ -127,16 +128,17 @@ def load_model():
127
  else:
128
  raise
129
 
130
- zero123_pipeline = DiffusionPipeline.from_pretrained(
131
  model_name,
 
132
  cache_dir=CACHE_DIR,
133
  torch_dtype=torch.float32,
134
  )
135
- zero123_pipeline.to("cpu")
136
 
137
  model_loaded = True
138
- print("Zero123++ loaded successfully on CPU")
139
- return zero123_pipeline
140
 
141
  except Exception as e:
142
  print(f"Error loading model: {str(e)}")
@@ -147,13 +149,13 @@ def load_model():
147
 
148
  def generate_3d_model(image, detail_level):
149
  try:
150
- num_steps = {'low': 30, 'medium': 50, 'high': 75}
151
  steps = num_steps[detail_level]
152
 
153
  with torch.no_grad():
154
- result = zero123_pipeline(image, num_inference_steps=steps)
155
 
156
- mesh = result.meshes[0]
157
  vertices = np.array(mesh.vertices)
158
  faces = np.array(mesh.faces)
159
  vertex_colors = np.array(mesh.vertex_colors) if mesh.vertex_colors is not None else None
@@ -174,7 +176,7 @@ def generate_3d_model(image, detail_level):
174
  def health_check():
175
  return jsonify({
176
  "status": "healthy",
177
- "model": "Zero123++",
178
  "device": "cpu"
179
  }), 200
180
 
@@ -414,7 +416,7 @@ def model_info(job_id):
414
  @app.route('/', methods=['GET'])
415
  def index():
416
  return jsonify({
417
- "message": "Image to 3D API (Zero123++)",
418
  "endpoints": [
419
  "/convert",
420
  "/progress/<job_id>",
@@ -426,10 +428,10 @@ def index():
426
  "output_format": "glb or obj",
427
  "detail_level": "low, medium, or high"
428
  },
429
- "description": "Creates 3D models from 2D images using Zero123++."
430
  }), 200
431
 
432
  if __name__ == '__main__':
433
  cleanup_old_jobs()
434
  port = int(os.environ.get('PORT', 7860))
435
- app.run(host='0.0.0.0', port=port)
 
1
+
2
  import os
3
  import torch
4
  import time
 
15
  from flask_cors import CORS
16
  import numpy as np
17
  import trimesh
18
+ from diffusers import Hunyuan3DDiTPipeline
19
 
20
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
21
  torch.set_default_device("cpu")
 
42
  app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
43
 
44
  processing_jobs = {}
45
+ hunyuan_pipeline = None
46
  model_loaded = False
47
  model_loading = False
48
 
49
  TIMEOUT_SECONDS = 300
50
+ MAX_DIMENSION = 512 # Hunyuan3D-1.0 uses 512x512 inputs
51
 
52
  class TimeoutError(Exception):
53
  pass
 
88
  with Image.open(image_path) as img:
89
  if img.mode == 'RGBA':
90
  img = img.convert('RGB')
91
+ img = img.resize((512, 512), Image.LANCZOS)
92
  return img
93
  except Exception as e:
94
  raise Exception(f"Error preprocessing image: {str(e)}")
95
 
96
  def load_model():
97
+ global hunyuan_pipeline, model_loaded, model_loading
98
 
99
  if model_loaded:
100
+ return hunyuan_pipeline
101
 
102
  if model_loading:
103
  while model_loading and not model_loaded:
104
  time.sleep(0.5)
105
+ return hunyuan_pipeline
106
 
107
  try:
108
  model_loading = True
109
+ print("Loading Hunyuan3D-1.0 Lite...")
110
 
111
+ model_name = "tencent/Hunyuan3D-1"
112
 
113
  max_retries = 3
114
  retry_delay = 5
 
128
  else:
129
  raise
130
 
131
+ hunyuan_pipeline = Hunyuan3DDiTPipeline.from_pretrained(
132
  model_name,
133
+ subfolder="lite",
134
  cache_dir=CACHE_DIR,
135
  torch_dtype=torch.float32,
136
  )
137
+ hunyuan_pipeline.to("cpu")
138
 
139
  model_loaded = True
140
+ print("Hunyuan3D-1.0 Lite loaded successfully on CPU")
141
+ return hunyuan_pipeline
142
 
143
  except Exception as e:
144
  print(f"Error loading model: {str(e)}")
 
149
 
150
  def generate_3d_model(image, detail_level):
151
  try:
152
+ num_steps = {'low': 20, 'medium': 30, 'high': 50}
153
  steps = num_steps[detail_level]
154
 
155
  with torch.no_grad():
156
+ result = hunyuan_pipeline(image, num_inference_steps=steps, max_faces_num=20000)
157
 
158
+ mesh = result[0] # Hunyuan3D returns a trimesh object
159
  vertices = np.array(mesh.vertices)
160
  faces = np.array(mesh.faces)
161
  vertex_colors = np.array(mesh.vertex_colors) if mesh.vertex_colors is not None else None
 
176
  def health_check():
177
  return jsonify({
178
  "status": "healthy",
179
+ "model": "Hunyuan3D-1.0 Lite",
180
  "device": "cpu"
181
  }), 200
182
 
 
416
  @app.route('/', methods=['GET'])
417
  def index():
418
  return jsonify({
419
+ "message": "Image to 3D API (Hunyuan3D-1.0 Lite)",
420
  "endpoints": [
421
  "/convert",
422
  "/progress/<job_id>",
 
428
  "output_format": "glb or obj",
429
  "detail_level": "low, medium, or high"
430
  },
431
+ "description": "Creates 3D models from 2D images using Hunyuan3D-1.0 Lite."
432
  }), 200
433
 
434
  if __name__ == '__main__':
435
  cleanup_old_jobs()
436
  port = int(os.environ.get('PORT', 7860))
437
+ app.run(host='0.0.0.0', port=port)