Spaces: Running on Zero

Commit message: acc

Browse files
Files changed:
- app.py (+2 -1)
- pl_trainer/inference/inference.py (+11 -6)
app.py
CHANGED

@@ -243,7 +243,8 @@ def dummy_process(input_fg, input_bg):
     TEXT_CFG = 7.5
     text_cond = diffusion_model.encode_text([EDIT_PROMPT]) # (1, 77, 768)
     text_uncond = diffusion_model.encode_text([''])
-    # to float16
+    # to float16
+    print('------------to float 16----------------')
     init_latent, text_cond, text_uncond, cond_tensor = (
         init_latent.to(dtype=torch.float16),
         text_cond.to(dtype=torch.float16),
pl_trainer/inference/inference.py
CHANGED

@@ -176,8 +176,8 @@ class InferenceIP2PVideo(Inference):
     text x x v
     img x v v
     '''
-    all_latent = []
-    all_pred = [] # x0_hat
+    # all_latent = []
+    # # all_pred = [] # x0_hat
     for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
         t = int(t)
         latent1 = torch.cat([latent, self.zeros(img_cond)], dim=2)

@@ -208,13 +208,18 @@ class InferenceIP2PVideo(Inference):
     pred_samples = self.scheduler.step(noise_pred, t, latent)
     latent = pred_samples.prev_sample
     pred = pred_samples.pred_original_sample
-
-
+
+    del noise_pred, noise_pred1, noise_pred2, noise_pred3, pred_samples
+    del latent_input, context_input
+    torch.cuda.empty_cache()
+
+    # all_latent.append(latent.detach())
+    # all_pred.append(pred.detach())

     return {
         'latent': latent,
-        'all_latent': all_latent,
-        'all_pred': all_pred
+        # 'all_latent': all_latent,
+        # 'all_pred': all_pred
     }

     @torch.no_grad()

[NOTE(review): the two removed lines at old positions 211-212 in the second hunk
have no visible content in the extracted page — presumably the now-commented
`all_latent.append(...)` / `all_pred.append(...)` calls; verify against the
original commit. Also note the rewritten hunk above reflects exactly what the
extraction shows: `# all_pred = [] # x0_hat` appears with a doubled comment
marker in the new side of hunk 1 as scraped.]