English
Thomas Male commited on
Commit
f18894a
·
verified ·
1 Parent(s): 3893e21

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +38 -15
handler.py CHANGED
@@ -1,4 +1,5 @@
1
  from typing import Dict, List, Any
 
2
  import torch
3
  from torch import autocast
4
  from tqdm.auto import tqdm
@@ -24,13 +25,18 @@ class EndpointHandler():
24
  # load the optimized model
25
  print('creating base model...')
26
 
 
27
  self.base_name = 'base40M-textvec'
28
- #self.base_name = 'base40M'
29
-
30
  self.base_model = model_from_config(MODEL_CONFIGS[self.base_name], device)
31
  self.base_model.eval()
32
  self.base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[self.base_name])
33
 
 
 
 
 
 
 
34
  print('creating upsample model...')
35
  self.upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)
36
  self.upsampler_model.eval()
@@ -38,6 +44,7 @@ class EndpointHandler():
38
 
39
  print('downloading base checkpoint...')
40
  self.base_model.load_state_dict(load_checkpoint(self.base_name, device))
 
41
 
42
  print('downloading upsampler checkpoint...')
43
  self.upsampler_model.load_state_dict(load_checkpoint('upsample', device))
@@ -58,27 +65,43 @@ class EndpointHandler():
58
  print('image data found')
59
  else:
60
  print('no image data found')
 
61
 
62
  inputs = data.pop("inputs", data)
63
 
64
- sampler = PointCloudSampler(
65
- device=device,
66
- models=[self.base_model,self.upsampler_model],
67
- diffusions=[self.base_diffusion, self.upsampler_diffusion],
68
- num_points=[1024, 4096 - 1024],
69
- aux_channels=['R', 'G', 'B'],
70
- guidance_scale=[3.0, 0.0],
71
- model_kwargs_key_filter=('texts', ''), # Do not condition the upsampler at all
72
- )
73
 
74
- # Set a test prompt to condition on.
75
- # prompt = 'A bluebird mid-flight'
 
 
 
 
 
 
 
 
 
 
76
 
77
  # run inference pipeline
78
  with autocast(device.type):
79
  samples = None
80
- for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(texts=[inputs]))):
81
- samples = x
 
 
 
 
 
82
  #image = self.pipe(inputs, guidance_scale=7.5)["sample"][0]
83
 
84
  pc = sampler.output_to_point_clouds(samples)[0]
 
1
  from typing import Dict, List, Any
2
+ from PIL import Image
3
  import torch
4
  from torch import autocast
5
  from tqdm.auto import tqdm
 
25
  # load the optimized model
26
  print('creating base model...')
27
 
28
+ print('creating base model...')
29
  self.base_name = 'base40M-textvec'
 
 
30
  self.base_model = model_from_config(MODEL_CONFIGS[self.base_name], device)
31
  self.base_model.eval()
32
  self.base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[self.base_name])
33
 
34
+ print('creating image model...')
35
+ self.base_image_name = 'base40M'
36
+ self.base_image_model = model_from_config(MODEL_CONFIGS[self.base_image_name], device)
37
+ self.base_image_model.eval()
38
+ self.base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[self.base_image_name])
39
+
40
  print('creating upsample model...')
41
  self.upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)
42
  self.upsampler_model.eval()
 
44
 
45
  print('downloading base checkpoint...')
46
  self.base_model.load_state_dict(load_checkpoint(self.base_name, device))
47
+ self.base_image_model.load_state_dict(load_checkpoint(self.base_image_name, device))
48
 
49
  print('downloading upsampler checkpoint...')
50
  self.upsampler_model.load_state_dict(load_checkpoint('upsample', device))
 
65
  print('image data found')
66
  else:
67
  print('no image data found')
68
+
69
 
70
  inputs = data.pop("inputs", data)
71
 
72
+ if use_image:
73
+ sampler = PointCloudSampler(
74
+ device=device,
75
+ models=[base_model, upsampler_model],
76
+ diffusions=[base_diffusion, upsampler_diffusion],
77
+ num_points=[1024, 4096 - 1024],
78
+ aux_channels=['R', 'G', 'B'],
79
+ guidance_scale=[3.0, 3.0],
80
+ )
81
 
82
+ # Load an image to condition on.
83
+ img = Image.open('example_data/cube_stack.jpg')
84
+ else:
85
+ sampler = PointCloudSampler(
86
+ device=device,
87
+ models=[self.base_model,self.upsampler_model],
88
+ diffusions=[self.base_diffusion, self.upsampler_diffusion],
89
+ num_points=[1024, 4096 - 1024],
90
+ aux_channels=['R', 'G', 'B'],
91
+ guidance_scale=[3.0, 0.0],
92
+ model_kwargs_key_filter=('texts', ''), # Do not condition the upsampler at all
93
+ )
94
 
95
  # run inference pipeline
96
  with autocast(device.type):
97
  samples = None
98
+ if use_image:
99
+ for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(images=[img]))):
100
+ samples = x
101
+ else:
102
+ for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(texts=[inputs]))):
103
+ samples = x
104
+
105
  #image = self.pipe(inputs, guidance_scale=7.5)["sample"][0]
106
 
107
  pc = sampler.output_to_point_clouds(samples)[0]