omer11a committed on
Commit
d157117
·
1 Parent(s): 431d204

Load float32 model

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -135,7 +135,7 @@ def inference(
135
  raise gr.Error("cuda is not available")
136
 
137
  device = torch.device("cuda")
138
- model.to(device)
139
 
140
  seed_everything(seed)
141
  start_code = torch.randn([len(prompts), 4, 128, 128], device=device)
@@ -162,7 +162,7 @@ def inference(
162
  register_attention_editor_diffusers(model, editor)
163
  images = model(prompts, latents=start_code, guidance_scale=classifier_free_guidance_scale).images
164
  unregister_attention_editor_diffusers(model)
165
- model.to(torch.device("cpu"))
166
  return images
167
 
168
 
@@ -256,7 +256,7 @@ def main():
256
  nltk.download("averaged_perceptron_tagger")
257
 
258
  scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
259
- model = StableDiffusionXLPipeline.from_pretrained(MODEL_PATH, scheduler=scheduler, torch_dtype=torch.float16)
260
  model.unet.set_attn_processor(AttnProcessor2_0())
261
  model.enable_sequential_cpu_offload()
262
 
 
135
  raise gr.Error("cuda is not available")
136
 
137
  device = torch.device("cuda")
138
+ model.to(device).half()
139
 
140
  seed_everything(seed)
141
  start_code = torch.randn([len(prompts), 4, 128, 128], device=device)
 
162
  register_attention_editor_diffusers(model, editor)
163
  images = model(prompts, latents=start_code, guidance_scale=classifier_free_guidance_scale).images
164
  unregister_attention_editor_diffusers(model)
165
+ model.double().to(torch.device("cpu"))
166
  return images
167
 
168
 
 
256
  nltk.download("averaged_perceptron_tagger")
257
 
258
  scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
259
+ model = StableDiffusionXLPipeline.from_pretrained(MODEL_PATH, scheduler=scheduler)
260
  model.unet.set_attn_processor(AttnProcessor2_0())
261
  model.enable_sequential_cpu_offload()
262