syedMohib44 committed on
Commit
4d7448f
·
1 Parent(s): 00e5927
Files changed (1)
  1. app.py +287 -0
app.py ADDED
@@ -0,0 +1,287 @@
+ """
+ A model worker executes the model.
+ """
+ import argparse
+ import asyncio
+ import base64
+ import io
+ import logging
+ import logging.handlers
+ import os
+ import sys
+ import threading
+ import traceback
+ import uuid
+ from io import BytesIO
+
+ import torch
+ import trimesh
+ import uvicorn
+ from PIL import Image
+ from fastapi import FastAPI, Request, UploadFile
+ from fastapi.responses import JSONResponse, FileResponse
+
+ from hy3dgen.rembg import BackgroundRemover
+ from hy3dgen.shapegen import Hunyuan3DDiTFlowMatchingPipeline, FloaterRemover, DegenerateFaceRemover, FaceReducer
+ from hy3dgen.texgen import Hunyuan3DPaintPipeline
+ from hy3dgen.text2image import HunyuanDiTPipeline
+
+ LOGDIR = '.'
+
+ server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+ moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
+
+ handler = None
+
+
+ def build_logger(logger_name, logger_filename):
+     global handler
+
+     formatter = logging.Formatter(
+         fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+         datefmt="%Y-%m-%d %H:%M:%S",
+     )
+
+     # Set the format of root handlers
+     if not logging.getLogger().handlers:
+         logging.basicConfig(level=logging.INFO)
+     logging.getLogger().handlers[0].setFormatter(formatter)
+
+     # Redirect stdout and stderr to loggers
+     stdout_logger = logging.getLogger("stdout")
+     stdout_logger.setLevel(logging.INFO)
+     sl = StreamToLogger(stdout_logger, logging.INFO)
+     sys.stdout = sl
+
+     stderr_logger = logging.getLogger("stderr")
+     stderr_logger.setLevel(logging.ERROR)
+     sl = StreamToLogger(stderr_logger, logging.ERROR)
+     sys.stderr = sl
+
+     # Get logger
+     logger = logging.getLogger(logger_name)
+     logger.setLevel(logging.INFO)
+
+     # Add a file handler for all loggers
+     if handler is None:
+         os.makedirs(LOGDIR, exist_ok=True)
+         filename = os.path.join(LOGDIR, logger_filename)
+         handler = logging.handlers.TimedRotatingFileHandler(
+             filename, when='D', utc=True, encoding='UTF-8')
+         handler.setFormatter(formatter)
+
+         for name, item in logging.root.manager.loggerDict.items():
+             if isinstance(item, logging.Logger):
+                 item.addHandler(handler)
+
+     return logger
+
+
+ class StreamToLogger(object):
+     """
+     Fake file-like stream object that redirects writes to a logger instance.
+     """
+
+     def __init__(self, logger, log_level=logging.INFO):
+         self.terminal = sys.stdout
+         self.logger = logger
+         self.log_level = log_level
+         self.linebuf = ''
+
+     def __getattr__(self, attr):
+         return getattr(self.terminal, attr)
+
+     def write(self, buf):
+         temp_linebuf = self.linebuf + buf
+         self.linebuf = ''
+         for line in temp_linebuf.splitlines(True):
+             # From the io.TextIOWrapper docs:
+             #   On output, if newline is None, any '\n' characters written
+             #   are translated to the system default line separator.
+             # By default sys.stdout.write() expects '\n' newlines and then
+             # translates them so this is still cross platform.
+             if line[-1] == '\n':
+                 self.logger.log(self.log_level, line.rstrip())
+             else:
+                 self.linebuf += line
+
+     def flush(self):
+         if self.linebuf != '':
+             self.logger.log(self.log_level, self.linebuf.rstrip())
+             self.linebuf = ''
+
+
+ def pretty_print_semaphore(semaphore):
+     if semaphore is None:
+         return "None"
+     return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
+
+
+ SAVE_DIR = 'gradio_cache'
+ os.makedirs(SAVE_DIR, exist_ok=True)
+
+ # Assigned in __main__; initialized to None so get_queue_length() is safe on import.
+ model_semaphore = None
+
+ worker_id = str(uuid.uuid4())[:6]
+ logger = build_logger("controller", f"{SAVE_DIR}/controller.log")
+
+
+ def load_image_from_base64(image):
+     return Image.open(BytesIO(base64.b64decode(image)))
+
+
+ def load_image_from_upload(image: UploadFile):
+     """Loads a PIL image from an uploaded file."""
+     try:
+         with image.file as f:  # ensures the underlying file is closed after reading
+             image_bytes = f.read()
+         return Image.open(io.BytesIO(image_bytes))
+     except Exception as e:
+         # Raise so the endpoint's ValueError handler returns a clean error response.
+         raise ValueError(f"Failed to read image: {e}")
+
+
+ class ModelWorker:
+     def __init__(self, model_path='tencent/Hunyuan3D-2', device='cuda'):
+         self.model_path = model_path
+         self.worker_id = worker_id
+         self.device = device
+         logger.info(f"Loading the model {model_path} on worker {worker_id} ...")
+
+         self.rembg = BackgroundRemover()
+         self.pipeline = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(model_path, device=device)
+         # self.pipeline_t2i = HunyuanDiTPipeline('Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled',
+         #                                        device=device)
+         self.pipeline_tex = Hunyuan3DPaintPipeline.from_pretrained(model_path)
+
+     def get_queue_length(self):
+         if model_semaphore is None:
+             return 0
+         else:
+             return args.limit_model_concurrency - model_semaphore._value + (len(
+                 model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
+
+     def get_status(self):
+         return {
+             "speed": 1,
+             "queue_length": self.get_queue_length(),
+         }
+
+     @torch.inference_mode()
+     def generate(self, uid, form):
+         params = dict()
+         image = form.get("image")  # UploadFile from multipart form, or base64 string from JSON
+         if image is not None:
+             if isinstance(image, str):
+                 image = load_image_from_base64(image)
+             else:
+                 image = load_image_from_upload(image)
+             image = self.rembg(image)
+             params['image'] = image
+
+         if form.get("mesh"):
+             mesh = trimesh.load(BytesIO(base64.b64decode(form["mesh"])), file_type='glb')
+         else:
+             # Form values arrive as strings, so cast them before use.
+             seed = int(form.get("seed", 1234))
+             params['generator'] = torch.Generator(self.device).manual_seed(seed)
+             params['octree_resolution'] = int(form.get("octree_resolution", 256))
+             params['num_inference_steps'] = int(form.get("num_inference_steps", 30))
+             params['guidance_scale'] = float(form.get('guidance_scale', 7.5))
+             params['mc_algo'] = 'mc'
+             mesh = self.pipeline(**params)[0]
+
+         if str(form.get('texture', 'false')).lower() in ('1', 'true', 'yes'):
+             mesh = FloaterRemover()(mesh)
+             mesh = DegenerateFaceRemover()(mesh)
+             mesh = FaceReducer()(mesh, max_facenum=int(form.get('face_count', 40000)))
+             mesh = self.pipeline_tex(mesh, image)
+
+         save_path = os.path.join(SAVE_DIR, f'{str(uid)}.glb')
+         logger.info(f"Exporting mesh to {save_path}")
+         mesh.export(save_path)
+         torch.cuda.empty_cache()
+         return save_path, uid
+
+
+ app = FastAPI()
+
+
+ @app.post("/generate")
+ async def generate(request: Request):
+     logger.info("Worker generating...")
+     form = await request.form()
+
+     uid = uuid.uuid4()
+     try:
+         file_path, uid = worker.generate(uid, form)
+         return FileResponse(file_path)
+     except ValueError as e:
+         traceback.print_exc()
+         print("Caught ValueError:", e)
+         ret = {
+             "text": server_error_msg,
+             "error_code": 1,
+         }
+         return JSONResponse(ret, status_code=404)
+     except torch.cuda.CudaError as e:
+         print("Caught torch.cuda.CudaError:", e)
+         ret = {
+             "text": server_error_msg,
+             "error_code": 1,
+         }
+         return JSONResponse(ret, status_code=404)
+     except Exception as e:
+         print("Caught unknown error:", e)
+         traceback.print_exc()
+         ret = {
+             "text": server_error_msg,
+             "error_code": 1,
+         }
+         return JSONResponse(ret, status_code=404)
+
+
+ @app.post("/send")
+ async def send(request: Request):
+     logger.info("Worker send...")
+     params = await request.json()
+     uid = uuid.uuid4()
+     threading.Thread(target=worker.generate, args=(uid, params,)).start()
+     ret = {"uid": str(uid)}
+     return JSONResponse(ret, status_code=200)
+
+
+ @app.get("/status/{uid}")
+ async def status(uid: str):
+     save_file_path = os.path.join(SAVE_DIR, f'{uid}.glb')
+     if not os.path.exists(save_file_path):
+         response = {'status': 'processing'}
+         return JSONResponse(response, status_code=200)
+     else:
+         with open(save_file_path, 'rb') as f:
+             base64_str = base64.b64encode(f.read()).decode()
+         response = {'status': 'completed', 'model_base64': base64_str}
+         return JSONResponse(response, status_code=200)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--host", type=str, default="0.0.0.0")
+     parser.add_argument("--port", type=int, default=8081)
+     parser.add_argument("--model_path", type=str, default='tencent/Hunyuan3D-2')
+     parser.add_argument("--device", type=str, default="cuda")
+     parser.add_argument("--limit-model-concurrency", type=int, default=5)
+     args = parser.parse_args()
+     logger.info(f"args: {args}")
+
+     model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
+
+     worker = ModelWorker(model_path=args.model_path, device=args.device)
+     uvicorn.run(app, host=args.host, port=args.port, log_level="info")
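
For reference, a minimal client sketch of the two request patterns this worker exposes: a blocking multipart POST to /generate that streams back a GLB file, and a fire-and-forget POST to /send followed by polling /status/{uid}. The base URL, demo.png input, and output filenames are illustrative assumptions, not part of the commit.

# Minimal client sketch (assumes the worker runs on localhost:8081 and that
# demo.png exists; both are assumptions for illustration).
import base64
import time

import requests

BASE = "http://localhost:8081"

# Blocking path: upload the image as multipart form data, save the returned GLB.
with open("demo.png", "rb") as f:
    resp = requests.post(
        f"{BASE}/generate",
        files={"image": f},
        data={"seed": 1234, "octree_resolution": 256, "num_inference_steps": 30},
    )
resp.raise_for_status()
with open("out.glb", "wb") as f:
    f.write(resp.content)

# Non-blocking path: /send returns a uid immediately; poll /status/{uid} until
# the worker thread has exported the mesh, then decode the base64 GLB payload.
with open("demo.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode()
uid = requests.post(f"{BASE}/send", json={"image": image_b64}).json()["uid"]
while True:
    status = requests.get(f"{BASE}/status/{uid}").json()
    if status["status"] == "completed":
        with open("out_async.glb", "wb") as f:
            f.write(base64.b64decode(status["model_base64"]))
        break
    time.sleep(2)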