cocktailpeanut committed on
Commit
4752b73
·
1 Parent(s): 9ad025b
Files changed (1) hide show
  1. app.py +84 -26
app.py CHANGED
@@ -71,40 +71,42 @@ Best results come from clean, well-lit images with clear subject isolation. Try
71
  from image_process import prepare_image
72
  from briarmbg import BriaRMBG
73
  snapshot_download("briaai/RMBG-1.4", local_dir=RMBG_PRETRAINED_MODEL)
74
- rmbg_net = BriaRMBG.from_pretrained(RMBG_PRETRAINED_MODEL).to(DEVICE)
75
- rmbg_net.eval()
76
  from triposg.pipelines.pipeline_triposg import TripoSGPipeline
77
  snapshot_download("VAST-AI/TripoSG", local_dir=TRIPOSG_PRETRAINED_MODEL)
78
- triposg_pipe = TripoSGPipeline.from_pretrained(TRIPOSG_PRETRAINED_MODEL).to(DEVICE, DTYPE)
79
 
80
  # mv adapter
81
  NUM_VIEWS = 6
82
  from inference_ig2mv_sdxl import prepare_pipeline, preprocess_image, remove_bg
83
  from mvadapter.utils import get_orthogonal_camera, tensor_to_image, make_image_grid
84
  from mvadapter.utils.render import NVDiffRastContextWrapper, load_mesh, render
85
- mv_adapter_pipe = prepare_pipeline(
86
- base_model="stabilityai/stable-diffusion-xl-base-1.0",
87
- vae_model="madebyollin/sdxl-vae-fp16-fix",
88
- unet_model=None,
89
- lora_model=None,
90
- adapter_path="huanngzh/mv-adapter",
91
- scheduler=None,
92
- num_views=NUM_VIEWS,
93
- device=DEVICE,
94
- dtype=torch.float16,
95
- )
96
- birefnet = AutoModelForImageSegmentation.from_pretrained(
97
- "ZhengPeng7/BiRefNet", trust_remote_code=True
98
- )
99
- birefnet.to(DEVICE)
100
- transform_image = transforms.Compose(
101
- [
102
- transforms.Resize((1024, 1024)),
103
- transforms.ToTensor(),
104
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
105
- ]
106
- )
107
- remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, DEVICE)
 
 
108
 
109
  if not os.path.exists("checkpoints/RealESRGAN_x2plus.pth"):
110
  hf_hub_download("dtarnow/UPscaler", filename="RealESRGAN_x2plus.pth", local_dir="checkpoints")
@@ -140,6 +142,8 @@ def run_full(image: str, req: gr.Request):
140
 
141
  image_seg = prepare_image(image, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
142
 
 
 
143
  outputs = triposg_pipe(
144
  image=image_seg,
145
  generator=torch.Generator(device=triposg_pipe.device).manual_seed(seed),
@@ -199,6 +203,19 @@ def run_full(image: str, req: gr.Request):
199
  .to(DEVICE)
200
  )
201
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  image = Image.open(image)
203
  image = remove_bg_fn(image)
204
  image = preprocess_image(image, height, width)
@@ -207,6 +224,18 @@ def run_full(image: str, req: gr.Request):
207
  if seed != -1 and isinstance(seed, int):
208
  pipe_kwargs["generator"] = torch.Generator(device=DEVICE).manual_seed(seed)
209
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  images = mv_adapter_pipe(
211
  "high quality",
212
  height=height,
@@ -256,6 +285,9 @@ def run_full(image: str, req: gr.Request):
256
  @spaces.GPU()
257
  @torch.no_grad()
258
  def run_segmentation(image: str):
 
 
 
259
  image = prepare_image(image, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
260
  return image
261
 
@@ -270,6 +302,7 @@ def image_to_3d(
270
  target_face_num: int,
271
  req: gr.Request
272
  ):
 
273
  outputs = triposg_pipe(
274
  image=image,
275
  generator=torch.Generator(device=triposg_pipe.device).manual_seed(seed),
@@ -333,6 +366,19 @@ def run_texture(image: Image, mesh_path: str, seed: int, req: gr.Request):
333
  .to(DEVICE)
334
  )
335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
  image = Image.open(image)
337
  image = remove_bg_fn(image)
338
  image = preprocess_image(image, height, width)
@@ -341,6 +387,18 @@ def run_texture(image: Image, mesh_path: str, seed: int, req: gr.Request):
341
  if seed != -1 and isinstance(seed, int):
342
  pipe_kwargs["generator"] = torch.Generator(device=DEVICE).manual_seed(seed)
343
 
 
 
 
 
 
 
 
 
 
 
 
 
344
  images = mv_adapter_pipe(
345
  "high quality",
346
  height=height,
 
71
  from image_process import prepare_image
72
  from briarmbg import BriaRMBG
73
  snapshot_download("briaai/RMBG-1.4", local_dir=RMBG_PRETRAINED_MODEL)
74
+ #rmbg_net = BriaRMBG.from_pretrained(RMBG_PRETRAINED_MODEL).to(DEVICE)
75
+ #rmbg_net.eval()
76
  from triposg.pipelines.pipeline_triposg import TripoSGPipeline
77
  snapshot_download("VAST-AI/TripoSG", local_dir=TRIPOSG_PRETRAINED_MODEL)
78
+ #triposg_pipe = TripoSGPipeline.from_pretrained(TRIPOSG_PRETRAINED_MODEL).to(DEVICE, DTYPE)
79
 
80
  # mv adapter
81
  NUM_VIEWS = 6
82
  from inference_ig2mv_sdxl import prepare_pipeline, preprocess_image, remove_bg
83
  from mvadapter.utils import get_orthogonal_camera, tensor_to_image, make_image_grid
84
  from mvadapter.utils.render import NVDiffRastContextWrapper, load_mesh, render
85
+ #mv_adapter_pipe = prepare_pipeline(
86
+ # base_model="stabilityai/stable-diffusion-xl-base-1.0",
87
+ # vae_model="madebyollin/sdxl-vae-fp16-fix",
88
+ # unet_model=None,
89
+ # lora_model=None,
90
+ # adapter_path="huanngzh/mv-adapter",
91
+ # scheduler=None,
92
+ # num_views=NUM_VIEWS,
93
+ # device=DEVICE,
94
+ # dtype=torch.float16,
95
+ #)
96
+
97
+
98
+ #birefnet = AutoModelForImageSegmentation.from_pretrained(
99
+ # "ZhengPeng7/BiRefNet", trust_remote_code=True
100
+ # )
101
+ #birefnet.to(DEVICE)
102
+ #transform_image = transforms.Compose(
103
+ # [
104
+ # transforms.Resize((1024, 1024)),
105
+ # transforms.ToTensor(),
106
+ # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
107
+ # ]
108
+ #)
109
+ #remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, DEVICE)
110
 
111
  if not os.path.exists("checkpoints/RealESRGAN_x2plus.pth"):
112
  hf_hub_download("dtarnow/UPscaler", filename="RealESRGAN_x2plus.pth", local_dir="checkpoints")
 
142
 
143
  image_seg = prepare_image(image, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
144
 
145
+ triposg_pipe = TripoSGPipeline.from_pretrained(TRIPOSG_PRETRAINED_MODEL).to(DEVICE, DTYPE)
146
+
147
  outputs = triposg_pipe(
148
  image=image_seg,
149
  generator=torch.Generator(device=triposg_pipe.device).manual_seed(seed),
 
203
  .to(DEVICE)
204
  )
205
 
206
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
207
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
208
+ )
209
+ birefnet.to(DEVICE)
210
+ transform_image = transforms.Compose(
211
+ [
212
+ transforms.Resize((1024, 1024)),
213
+ transforms.ToTensor(),
214
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
215
+ ]
216
+ )
217
+ remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, DEVICE)
218
+
219
  image = Image.open(image)
220
  image = remove_bg_fn(image)
221
  image = preprocess_image(image, height, width)
 
224
  if seed != -1 and isinstance(seed, int):
225
  pipe_kwargs["generator"] = torch.Generator(device=DEVICE).manual_seed(seed)
226
 
227
+ mv_adapter_pipe = prepare_pipeline(
228
+ base_model="stabilityai/stable-diffusion-xl-base-1.0",
229
+ vae_model="madebyollin/sdxl-vae-fp16-fix",
230
+ unet_model=None,
231
+ lora_model=None,
232
+ adapter_path="huanngzh/mv-adapter",
233
+ scheduler=None,
234
+ num_views=NUM_VIEWS,
235
+ device=DEVICE,
236
+ dtype=torch.float16,
237
+ )
238
+
239
  images = mv_adapter_pipe(
240
  "high quality",
241
  height=height,
 
285
  @spaces.GPU()
286
  @torch.no_grad()
287
  def run_segmentation(image: str):
288
+ snapshot_download("briaai/RMBG-1.4", local_dir=RMBG_PRETRAINED_MODEL)
289
+ rmbg_net = BriaRMBG.from_pretrained(RMBG_PRETRAINED_MODEL).to(DEVICE)
290
+ rmbg_net.eval()
291
  image = prepare_image(image, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
292
  return image
293
 
 
302
  target_face_num: int,
303
  req: gr.Request
304
  ):
305
+ triposg_pipe = TripoSGPipeline.from_pretrained(TRIPOSG_PRETRAINED_MODEL).to(DEVICE, DTYPE)
306
  outputs = triposg_pipe(
307
  image=image,
308
  generator=torch.Generator(device=triposg_pipe.device).manual_seed(seed),
 
366
  .to(DEVICE)
367
  )
368
 
369
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
370
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
371
+ )
372
+ birefnet.to(DEVICE)
373
+ transform_image = transforms.Compose(
374
+ [
375
+ transforms.Resize((1024, 1024)),
376
+ transforms.ToTensor(),
377
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
378
+ ]
379
+ )
380
+ remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, DEVICE)
381
+
382
  image = Image.open(image)
383
  image = remove_bg_fn(image)
384
  image = preprocess_image(image, height, width)
 
387
  if seed != -1 and isinstance(seed, int):
388
  pipe_kwargs["generator"] = torch.Generator(device=DEVICE).manual_seed(seed)
389
 
390
+ mv_adapter_pipe = prepare_pipeline(
391
+ base_model="stabilityai/stable-diffusion-xl-base-1.0",
392
+ vae_model="madebyollin/sdxl-vae-fp16-fix",
393
+ unet_model=None,
394
+ lora_model=None,
395
+ adapter_path="huanngzh/mv-adapter",
396
+ scheduler=None,
397
+ num_views=NUM_VIEWS,
398
+ device=DEVICE,
399
+ dtype=torch.float16,
400
+ )
401
+
402
  images = mv_adapter_pipe(
403
  "high quality",
404
  height=height,