Update app.py
Browse files
app.py
CHANGED
@@ -62,22 +62,20 @@ from ldm.models.diffusion.plms import PLMSSampler
|
|
62 |
from open_clip import tokenizer
|
63 |
import open_clip
|
64 |
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
model.eval()
|
80 |
-
return model
|
81 |
|
82 |
def load_safety_model(clip_model):
|
83 |
"""load the safety model"""
|
@@ -127,10 +125,8 @@ def is_unsafe(safety_model, embeddings, threshold=0.5):
|
|
127 |
x = np.array([e[0] for e in nsfw_values])
|
128 |
return True if x > threshold else False
|
129 |
|
130 |
-
|
131 |
-
model = load_model_from_config(config,model_path_e)
|
132 |
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
133 |
-
model = model.to(device)
|
134 |
|
135 |
#NSFW CLIP Filter
|
136 |
safety_model = load_safety_model("ViT-B/32")
|
|
|
62 |
from open_clip import tokenizer
|
63 |
import open_clip
|
64 |
|
65 |
+
config = OmegaConf.load("latent-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml")
|
66 |
+
print(f"Loading model from {ckpt}")
|
67 |
+
pl_sd = torch.load(ckpt, map_location="cuda")
|
68 |
+
sd = pl_sd["state_dict"]
|
69 |
+
model = instantiate_from_config(config.model)
|
70 |
+
m, u = model.load_state_dict(sd, strict=False)
|
71 |
+
if len(m) > 0 and verbose:
|
72 |
+
print("missing keys:")
|
73 |
+
print(m)
|
74 |
+
if len(u) > 0 and verbose:
|
75 |
+
print("unexpected keys:")
|
76 |
+
print(u)
|
77 |
+
|
78 |
+
model.half().to("cuda")
|
|
|
|
|
79 |
|
80 |
def load_safety_model(clip_model):
|
81 |
"""load the safety model"""
|
|
|
125 |
x = np.array([e[0] for e in nsfw_values])
|
126 |
return True if x > threshold else False
|
127 |
|
128 |
+
# model = load_model_from_config(config,model_path_e)
|
|
|
129 |
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
130 |
|
131 |
#NSFW CLIP Filter
|
132 |
safety_model = load_safety_model("ViT-B/32")
|