aleafy commited on
Commit
c8013a6
·
1 Parent(s): 307df09
Files changed (2) hide show
  1. app.py +2 -1
  2. pl_trainer/inference/inference.py +11 -6
app.py CHANGED
@@ -243,7 +243,8 @@ def dummy_process(input_fg, input_bg):
243
  TEXT_CFG = 7.5
244
  text_cond = diffusion_model.encode_text([EDIT_PROMPT]) # (1, 77, 768)
245
  text_uncond = diffusion_model.encode_text([''])
246
- # to float16 todo
 
247
  init_latent, text_cond, text_uncond, cond_tensor = (
248
  init_latent.to(dtype=torch.float16),
249
  text_cond.to(dtype=torch.float16),
 
243
  TEXT_CFG = 7.5
244
  text_cond = diffusion_model.encode_text([EDIT_PROMPT]) # (1, 77, 768)
245
  text_uncond = diffusion_model.encode_text([''])
246
+ # to float16
247
+ print('------------to float 16----------------')
248
  init_latent, text_cond, text_uncond, cond_tensor = (
249
  init_latent.to(dtype=torch.float16),
250
  text_cond.to(dtype=torch.float16),
pl_trainer/inference/inference.py CHANGED
@@ -176,8 +176,8 @@ class InferenceIP2PVideo(Inference):
176
  text x x v
177
  img x v v
178
  '''
179
- all_latent = []
180
- all_pred = [] # x0_hat
181
  for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
182
  t = int(t)
183
  latent1 = torch.cat([latent, self.zeros(img_cond)], dim=2)
@@ -208,13 +208,18 @@ class InferenceIP2PVideo(Inference):
208
  pred_samples = self.scheduler.step(noise_pred, t, latent)
209
  latent = pred_samples.prev_sample
210
  pred = pred_samples.pred_original_sample
211
- all_latent.append(latent.detach())
212
- all_pred.append(pred.detach())
 
 
 
 
 
213
 
214
  return {
215
  'latent': latent,
216
- 'all_latent': all_latent,
217
- 'all_pred': all_pred
218
  }
219
 
220
  @torch.no_grad()
 
176
  text x x v
177
  img x v v
178
  '''
179
+ # all_latent = []
180
+ # all_pred = [] # x0_hat
181
  for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
182
  t = int(t)
183
  latent1 = torch.cat([latent, self.zeros(img_cond)], dim=2)
 
208
  pred_samples = self.scheduler.step(noise_pred, t, latent)
209
  latent = pred_samples.prev_sample
210
  pred = pred_samples.pred_original_sample
211
+
212
+ del noise_pred, noise_pred1, noise_pred2, noise_pred3, pred_samples
213
+ del latent_input, context_input
214
+ torch.cuda.empty_cache()
215
+
216
+ # all_latent.append(latent.detach())
217
+ # all_pred.append(pred.detach())
218
 
219
  return {
220
  'latent': latent,
221
+ # 'all_latent': all_latent,
222
+ # 'all_pred': all_pred
223
  }
224
 
225
  @torch.no_grad()