multimodalart HF Staff commited on
Commit
9bc2183
·
verified ·
1 Parent(s): 20b964d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -19
app.py CHANGED
@@ -62,22 +62,20 @@ from ldm.models.diffusion.plms import PLMSSampler
62
  from open_clip import tokenizer
63
  import open_clip
64
 
65
- def load_model_from_config(config, ckpt, verbose=False):
66
- print(f"Loading model from {ckpt}")
67
- pl_sd = torch.load(ckpt, map_location="cuda")
68
- sd = pl_sd["state_dict"]
69
- model = instantiate_from_config(config.model)
70
- m, u = model.load_state_dict(sd, strict=False)
71
- if len(m) > 0 and verbose:
72
- print("missing keys:")
73
- print(m)
74
- if len(u) > 0 and verbose:
75
- print("unexpected keys:")
76
- print(u)
77
-
78
- model = model.half().cuda()
79
- model.eval()
80
- return model
81
 
82
  def load_safety_model(clip_model):
83
  """load the safety model"""
@@ -127,10 +125,8 @@ def is_unsafe(safety_model, embeddings, threshold=0.5):
127
  x = np.array([e[0] for e in nsfw_values])
128
  return True if x > threshold else False
129
 
130
- config = OmegaConf.load("latent-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml")
131
- model = load_model_from_config(config,model_path_e)
132
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
133
- model = model.to(device)
134
 
135
  #NSFW CLIP Filter
136
  safety_model = load_safety_model("ViT-B/32")
 
62
  from open_clip import tokenizer
63
  import open_clip
64
 
65
+ config = OmegaConf.load("latent-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml")
66
+ print(f"Loading model from {ckpt}")
67
+ pl_sd = torch.load(ckpt, map_location="cuda")
68
+ sd = pl_sd["state_dict"]
69
+ model = instantiate_from_config(config.model)
70
+ m, u = model.load_state_dict(sd, strict=False)
71
+ if len(m) > 0 and verbose:
72
+ print("missing keys:")
73
+ print(m)
74
+ if len(u) > 0 and verbose:
75
+ print("unexpected keys:")
76
+ print(u)
77
+
78
+ model.half().to("cuda")
 
 
79
 
80
  def load_safety_model(clip_model):
81
  """load the safety model"""
 
125
  x = np.array([e[0] for e in nsfw_values])
126
  return True if x > threshold else False
127
 
128
+ # model = load_model_from_config(config,model_path_e)
 
129
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 
130
 
131
  #NSFW CLIP Filter
132
  safety_model = load_safety_model("ViT-B/32")