Thomas Male committed
Update handler.py

handler.py +38 -15
handler.py CHANGED

@@ -1,4 +1,5 @@
 from typing import Dict, List, Any
+from PIL import Image
 import torch
 from torch import autocast
 from tqdm.auto import tqdm
@@ -24,13 +25,18 @@ class EndpointHandler():
         # load the optimized model
         print('creating base model...')
 
+        print('creating base model...')
         self.base_name = 'base40M-textvec'
-        #self.base_name = 'base40M'
-
         self.base_model = model_from_config(MODEL_CONFIGS[self.base_name], device)
         self.base_model.eval()
         self.base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[self.base_name])
 
+        print('creating image model...')
+        self.base_image_name = 'base40M'
+        self.base_image_model = model_from_config(MODEL_CONFIGS[self.base_image_name], device)
+        self.base_image_model.eval()
+        self.base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[self.base_image_name])
+
         print('creating upsample model...')
         self.upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)
         self.upsampler_model.eval()
@@ -38,6 +44,7 @@ class EndpointHandler():
 
         print('downloading base checkpoint...')
         self.base_model.load_state_dict(load_checkpoint(self.base_name, device))
+        self.base_image_model.load_state_dict(load_checkpoint(self.base_image_name, device))
 
         print('downloading upsampler checkpoint...')
         self.upsampler_model.load_state_dict(load_checkpoint('upsample', device))
@@ -58,27 +65,43 @@
             print('image data found')
         else:
             print('no image data found')
+
 
         inputs = data.pop("inputs", data)
 
-
-
-
-
-
-
-
-
-
+        if use_image:
+            sampler = PointCloudSampler(
+                device=device,
+                models=[base_model, upsampler_model],
+                diffusions=[base_diffusion, upsampler_diffusion],
+                num_points=[1024, 4096 - 1024],
+                aux_channels=['R', 'G', 'B'],
+                guidance_scale=[3.0, 3.0],
+            )
 
-
-
+            # Load an image to condition on.
+            img = Image.open('example_data/cube_stack.jpg')
+        else:
+            sampler = PointCloudSampler(
+                device=device,
+                models=[self.base_model,self.upsampler_model],
+                diffusions=[self.base_diffusion, self.upsampler_diffusion],
+                num_points=[1024, 4096 - 1024],
+                aux_channels=['R', 'G', 'B'],
+                guidance_scale=[3.0, 0.0],
+                model_kwargs_key_filter=('texts', ''), # Do not condition the upsampler at all
+            )
 
         # run inference pipeline
         with autocast(device.type):
             samples = None
-
-
+            if use_image:
+                for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(images=[img]))):
+                    samples = x
+            else:
+                for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(texts=[inputs]))):
+                    samples = x
+
             #image = self.pipe(inputs, guidance_scale=7.5)["sample"][0]
 
         pc = sampler.output_to_point_clouds(samples)[0]
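Note on the new `__call__` branch: as committed, the `if use_image:` path builds its `PointCloudSampler` from bare names (`base_model`, `base_diffusion`, `upsampler_model`, `upsampler_diffusion`) while the `else` path uses the `self.`-prefixed attributes, and `__init__` now writes the image model's diffusion config into `self.base_diffusion`, overwriting the text model's. Below is a minimal sketch (not part of the commit) of how the image branch could instead reference the attributes created in `__init__`; `self.base_image_diffusion` is a hypothetical attribute, and `use_image`, `device`, and `self.upsampler_diffusion` are assumed to be defined elsewhere in handler.py.

        # Sketch only: fragment of __call__, image-conditioned branch using
        # instance attributes. Assumes __init__ stores the image diffusion
        # config in a separate self.base_image_diffusion rather than reusing
        # self.base_diffusion.
        if use_image:
            sampler = PointCloudSampler(
                device=device,
                models=[self.base_image_model, self.upsampler_model],
                diffusions=[self.base_image_diffusion, self.upsampler_diffusion],
                num_points=[1024, 4096 - 1024],
                aux_channels=['R', 'G', 'B'],
                guidance_scale=[3.0, 3.0],
            )
            # Load an image to condition on.
            img = Image.open('example_data/cube_stack.jpg')

The sampling loop that follows in the diff would then condition the base model on the image via `model_kwargs=dict(images=[img])`, exactly as in the committed `if use_image:` case, while the text path keeps `model_kwargs=dict(texts=[inputs])`.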