Commit 05f00dc
Parent(s): 8728fbf

Upload 6 files

Files changed:
- README.md +32 -6
- app.py +255 -0
- model-card.md +50 -0
- requirements.txt +5 -0
- server.py +175 -0
- setup.py +15 -0
README.md CHANGED
@@ -1,12 +1,38 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: Glide Text2im
+emoji: π
+colorFrom: purple
+colorTo: gray
 sdk: gradio
-sdk_version: 3.15.0
 app_file: app.py
 pinned: false
+duplicated_from: valhalla/glide-text2im
 ---
 
-
+# Configuration
+
+`title`: _string_
+Display title for the Space
+
+`emoji`: _string_
+Space emoji (emoji-only character allowed)
+
+`colorFrom`: _string_
+Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+`colorTo`: _string_
+Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+`sdk`: _string_
+Can be either `gradio` or `streamlit`
+
+`sdk_version`: _string_
+Only applicable for `streamlit` SDK.
+See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
+
+`app_file`: _string_
+Path to your main application file (which contains either `gradio` or `streamlit` Python code).
+Path is relative to the root of the repository.
+
+`pinned`: _boolean_
+Whether the Space stays on top of your list.
app.py ADDED
@@ -0,0 +1,255 @@

import os
os.system('pip install -e .')
import gradio as gr

import base64
from io import BytesIO
# from fastapi import FastAPI

from PIL import Image
import torch as th

from glide_text2im.download import load_checkpoint
from glide_text2im.model_creation import (
    create_model_and_diffusion,
    model_and_diffusion_defaults,
    model_and_diffusion_defaults_upsampler
)

# print("Loading models...")
# app = FastAPI()

# This notebook supports both CPU and GPU.
# On CPU, generating one sample may take on the order of 20 minutes.
# On a GPU, it should be under a minute.

has_cuda = th.cuda.is_available()
device = th.device('cpu' if not has_cuda else 'cuda')

# # Create base model.
# options = model_and_diffusion_defaults()
# options['use_fp16'] = has_cuda
# options['timestep_respacing'] = '40'  # use 100 diffusion steps for fast sampling (Previous it was 100)
# model, diffusion = create_model_and_diffusion(**options)
# model.eval()
# if has_cuda:
#     model.convert_to_fp16()
# model.to(device)
# # model.load_state_dict(load_checkpoint('base', device))
# model.load_state_dict(th.load("base.pt", map_location=device))
# print('total base parameters', sum(x.numel() for x in model.parameters()))

# # Create upsampler model.
# options_up = model_and_diffusion_defaults_upsampler()
# options_up['use_fp16'] = has_cuda
# options_up['timestep_respacing'] = 'fast27'  # use 27 diffusion steps for very fast sampling
# model_up, diffusion_up = create_model_and_diffusion(**options_up)
# model_up.eval()
# if has_cuda:
#     model_up.convert_to_fp16()
# model_up.to(device)
# # model_up.load_state_dict(load_checkpoint('upsample', device))
# model.load_state_dict(th.load("upsample.pt", map_location=device))
# print('total upsampler parameters', sum(x.numel() for x in model_up.parameters()))

base_timestep_respacing = '40'  #@param {type:"string"}

sr_timestep_respacing = 'fast27'

#@title Create base model.
glide_path = 'base.pt'  #@param {type:"string"}
import os
options = model_and_diffusion_defaults()
options['use_fp16'] = has_cuda
options['timestep_respacing'] = base_timestep_respacing  # use 100 diffusion steps for fast sampling
model, diffusion = create_model_and_diffusion(**options)

if len(glide_path) > 0:
    assert os.path.exists(
        glide_path
    ), f"Failed to resume from {glide_path}, file does not exist."
    weights = th.load(glide_path, map_location="cpu")
    model, diffusion = create_model_and_diffusion(**options)
    model.load_state_dict(weights)
    print(f"Resumed from {glide_path} successfully.")
else:
    model, diffusion = create_model_and_diffusion(**options)
    model.load_state_dict(load_checkpoint("base", device))
model.eval()
if has_cuda:
    model.convert_to_fp16()
model.to(device)
print('total base parameters', sum(x.numel() for x in model.parameters()))


#@title Create upsampler model.
sr_glide_path = "upsample.pt"  #@param {type:"string"}

options_up = model_and_diffusion_defaults_upsampler()
options_up['use_fp16'] = has_cuda
options_up['timestep_respacing'] = sr_timestep_respacing  # use 27 diffusion steps for very fast sampling

if len(sr_glide_path) > 0:
    assert os.path.exists(
        sr_glide_path
    ), f"Failed to resume from {sr_glide_path}, file does not exist."
    weights = th.load(sr_glide_path, map_location="cpu")
    model_up, diffusion_up = create_model_and_diffusion(**options_up)
    model_up.load_state_dict(weights)
    print(f"Resumed from {sr_glide_path} successfully.")
else:
    model_up, diffusion_up = create_model_and_diffusion(**options)
    model_up.load_state_dict(load_checkpoint("upsample", device))

if has_cuda:
    model_up.convert_to_fp16()
model_up.to(device)
print('total upsampler parameters', sum(x.numel() for x in model_up.parameters()))


def get_images(batch: th.Tensor):
    """Display a batch of images inline."""
    scaled = ((batch + 1) * 127.5).round().clamp(0, 255).to(th.uint8).cpu()
    reshaped = scaled.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])
    return Image.fromarray(reshaped.numpy())


# Create a classifier-free guidance sampling function
guidance_scale = 3.0

def model_fn(x_t, ts, **kwargs):
    half = x_t[: len(x_t) // 2]
    combined = th.cat([half, half], dim=0)
    model_out = model(combined, ts, **kwargs)
    eps, rest = model_out[:, :3], model_out[:, 3:]
    cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
    eps = th.cat([half_eps, half_eps], dim=0)
    return th.cat([eps, rest], dim=1)


# @app.get("/")
def read_root():
    return {"glide!"}

# @app.get("/{generate}")
def sample(prompt):
    # Sampling parameters
    batch_size = 1

    # Tune this parameter to control the sharpness of 256x256 images.
    # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
    upsample_temp = 0.997

    ##############################
    # Sample from the base model #
    ##############################

    # Create the text tokens to feed to the model.
    tokens = model.tokenizer.encode(prompt)
    tokens, mask = model.tokenizer.padded_tokens_and_mask(
        tokens, options['text_ctx']
    )

    # Create the classifier-free guidance tokens (empty)
    full_batch_size = batch_size * 2
    uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(
        [], options['text_ctx']
    )

    # Pack the tokens together into model kwargs.
    model_kwargs = dict(
        tokens=th.tensor(
            [tokens] * batch_size + [uncond_tokens] * batch_size, device=device
        ),
        mask=th.tensor(
            [mask] * batch_size + [uncond_mask] * batch_size,
            dtype=th.bool,
            device=device,
        ),
    )

    # Sample from the base model.
    model.del_cache()
    samples = diffusion.p_sample_loop(
        model_fn,
        (full_batch_size, 3, options["image_size"], options["image_size"]),
        device=device,
        clip_denoised=True,
        progress=True,
        model_kwargs=model_kwargs,
        cond_fn=None,
    )[:batch_size]
    model.del_cache()

    ##############################
    # Upsample the 64x64 samples #
    ##############################

    tokens = model_up.tokenizer.encode(prompt)
    tokens, mask = model_up.tokenizer.padded_tokens_and_mask(
        tokens, options_up['text_ctx']
    )

    # Create the model conditioning dict.
    model_kwargs = dict(
        # Low-res image to upsample.
        low_res=((samples + 1) * 127.5).round() / 127.5 - 1,

        # Text tokens
        tokens=th.tensor(
            [tokens] * batch_size, device=device
        ),
        mask=th.tensor(
            [mask] * batch_size,
            dtype=th.bool,
            device=device,
        ),
    )

    # Sample from the base model.
    model_up.del_cache()
    up_shape = (batch_size, 3, options_up["image_size"], options_up["image_size"])
    up_samples = diffusion_up.ddim_sample_loop(
        model_up,
        up_shape,
        noise=th.randn(up_shape, device=device) * upsample_temp,
        device=device,
        clip_denoised=True,
        progress=True,
        model_kwargs=model_kwargs,
        cond_fn=None,
    )[:batch_size]
    model_up.del_cache()

    # Show the output
    image = get_images(up_samples)
    # image = to_base64(image)
    # return {"image": image}
    return image


def to_base64(pil_image):
    buffered = BytesIO()
    pil_image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue())

title = "Interactive demo: glide-text2im"
description = "Demo for OpenAI's GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.10741'>GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models</a> | <a href='https://github.com/openai/glide-text2im/'>Official Repo</a></p>"
examples = ["an oil painting of a corgi"]

iface = gr.Interface(fn=sample,
                     inputs=gr.inputs.Textbox(label='What would you like to see?'),
                     outputs=gr.outputs.Image(type="pil", label="Model input + completions"),
                     title=title,
                     description=description,
                     article=article,
                     examples=examples,
                     enable_queue=True)
iface.launch(debug=True)
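Note: the Gradio interface returns the PIL image directly, while the commented-out FastAPI lines suggest the image could instead be returned as base64 via `to_base64`. Below is a minimal, self-contained sketch of that round trip; the stand-in image and the `from_base64` helper are illustrative and not part of the commit.

import base64
from io import BytesIO

from PIL import Image


def to_base64(pil_image: Image.Image) -> bytes:
    # Serialize the PIL image to JPEG in memory, then base64-encode the bytes.
    buffered = BytesIO()
    pil_image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue())


def from_base64(encoded: bytes) -> Image.Image:
    # Reverse the encoding: decode base64 and re-open the JPEG bytes as a PIL image.
    return Image.open(BytesIO(base64.b64decode(encoded)))


if __name__ == "__main__":
    # Stand-in for the 256x256 output of get_images(up_samples).
    image = Image.new("RGB", (256, 256), color=(128, 128, 128))
    payload = {"image": to_base64(image).decode("ascii")}  # JSON-friendly string
    restored = from_base64(payload["image"].encode("ascii"))
    print(restored.size)  # (256, 256)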
model-card.md ADDED
@@ -0,0 +1,50 @@
# Overview

This card describes the diffusion model GLIDE (filtered) and noised CLIP model described in the paper [GLIDE: Towards
Photorealistic Image Generation and Editing with Text-Guided Diffusion Models](https://arxiv.org/abs/2112.10741)

# Datasets

GLIDE (filtered) was trained on a filtered version of a dataset comprised of several hundred million text-image pairs
collected from the internet. We constructed a set of filters intended to remove all images of people, violent objects,
and some hate symbols (see Appendix F of the paper for details). The size of the dataset after filtering was approximately
67M text-image pairs.

Our noised CLIP model was trained on the dataset described above, augmented with a filtered version of the dataset used
to train the [original CLIP models](https://github.com/openai/clip). The total size of this augmented dataset is approximately 137M pairs.

# Performance

Qualitatively, we find that the generated images from GLIDE (filtered) often look semi-realistic, but the small size of the model hinders
its ability to bind attributes to objects and perform compositional tasks. Because the dataset used to train GLIDE
(filtered) has been preprocessed to remove images of people, this also limits its world knowledge, especially in regard
to concepts that involve people.
Finally, due to the dataset used to train GLIDE (filtered), the model has reduced capabilities to compose multiple objects in complex ways compared to models of a similar size trained on our internal dataset.

We do not directly measure quantitative metrics for GLIDE (filtered). In particular, most of the evaluations we report for our other models are biased against GLIDE (filtered), since they use prompts that often require generations of people. Evaluating people-free models remains an open area of research.

# Intended Use

We release these models to help advance research in generative modeling. Due to the limitations and biases of GLIDE (filtered), we do not currently recommend it for commercial use.

Functionally, these models are intended to be able to perform the following tasks for research purposes:
* Generate images from natural language prompts
* Iteratively edit and refine images using inpainting

These models are explicitly not intended to generate images of people or other subjects we filtered for (see Appendix F of the paper for details).

# Limitations

Despite the dataset filtering applied before training, GLIDE (filtered) continues to exhibit biases that extend beyond those found in images of people.
We explore some of these biases in our paper. For example:

* It produces different outputs when asked to generate toys for boys and toys for girls.
* It gravitates toward generating images of churches when asked to generate "a religious place",
and this bias is amplified by classifier-free guidance.
* It may have a greater propensity for generating hate symbols other than swastikas and confederate flags. Our filter
for hate symbols focused specifically on these two cases, as we found few relevant images of hate symbols in our
dataset. However, we also found that the model has diminished capabilities across a wider set of symbols.

GLIDE (filtered) can fail to produce realistic outputs for complex prompts or for prompts that involve concepts that are
not well-represented in its training data. While the data for the model was filtered to remove certain types of images,
the data still exhibits biases toward Western-centric concepts.
requirements.txt ADDED
@@ -0,0 +1,5 @@
git+https://github.com/openai/glide-text2im.git
fastapi
uvicorn
regex
git-lfs
server.py ADDED
@@ -0,0 +1,175 @@
import base64
from io import BytesIO
from fastapi import FastAPI

from PIL import Image
import torch as th

from glide_text2im.download import load_checkpoint
from glide_text2im.model_creation import (
    create_model_and_diffusion,
    model_and_diffusion_defaults,
    model_and_diffusion_defaults_upsampler
)

print("Loading models...")
app = FastAPI()

# This notebook supports both CPU and GPU.
# On CPU, generating one sample may take on the order of 20 minutes.
# On a GPU, it should be under a minute.

has_cuda = th.cuda.is_available()
device = th.device('cpu' if not has_cuda else 'cuda')

# Create base model.
options = model_and_diffusion_defaults()
options['use_fp16'] = has_cuda
options['timestep_respacing'] = '100'  # use 100 diffusion steps for fast sampling
model, diffusion = create_model_and_diffusion(**options)
model.eval()
if has_cuda:
    model.convert_to_fp16()
model.to(device)
model.load_state_dict(load_checkpoint('base', device))
print('total base parameters', sum(x.numel() for x in model.parameters()))

# Create upsampler model.
options_up = model_and_diffusion_defaults_upsampler()
options_up['use_fp16'] = has_cuda
options_up['timestep_respacing'] = 'fast27'  # use 27 diffusion steps for very fast sampling
model_up, diffusion_up = create_model_and_diffusion(**options_up)
model_up.eval()
if has_cuda:
    model_up.convert_to_fp16()
model_up.to(device)
model_up.load_state_dict(load_checkpoint('upsample', device))
print('total upsampler parameters', sum(x.numel() for x in model_up.parameters()))


def get_images(batch: th.Tensor):
    """Display a batch of images inline."""
    scaled = ((batch + 1) * 127.5).round().clamp(0, 255).to(th.uint8).cpu()
    reshaped = scaled.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])
    Image.fromarray(reshaped.numpy())


# Create a classifier-free guidance sampling function
guidance_scale = 3.0

def model_fn(x_t, ts, **kwargs):
    half = x_t[: len(x_t) // 2]
    combined = th.cat([half, half], dim=0)
    model_out = model(combined, ts, **kwargs)
    eps, rest = model_out[:, :3], model_out[:, 3:]
    cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
    eps = th.cat([half_eps, half_eps], dim=0)
    return th.cat([eps, rest], dim=1)


@app.get("/")
def read_root():
    return {"glide!"}

@app.get("/{generate}")
def sample(prompt):
    # Sampling parameters
    batch_size = 1

    # Tune this parameter to control the sharpness of 256x256 images.
    # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
    upsample_temp = 0.997

    ##############################
    # Sample from the base model #
    ##############################

    # Create the text tokens to feed to the model.
    tokens = model.tokenizer.encode(prompt)
    tokens, mask = model.tokenizer.padded_tokens_and_mask(
        tokens, options['text_ctx']
    )

    # Create the classifier-free guidance tokens (empty)
    full_batch_size = batch_size * 2
    uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(
        [], options['text_ctx']
    )

    # Pack the tokens together into model kwargs.
    model_kwargs = dict(
        tokens=th.tensor(
            [tokens] * batch_size + [uncond_tokens] * batch_size, device=device
        ),
        mask=th.tensor(
            [mask] * batch_size + [uncond_mask] * batch_size,
            dtype=th.bool,
            device=device,
        ),
    )

    # Sample from the base model.
    model.del_cache()
    samples = diffusion.p_sample_loop(
        model_fn,
        (full_batch_size, 3, options["image_size"], options["image_size"]),
        device=device,
        clip_denoised=True,
        progress=True,
        model_kwargs=model_kwargs,
        cond_fn=None,
    )[:batch_size]
    model.del_cache()

    ##############################
    # Upsample the 64x64 samples #
    ##############################

    tokens = model_up.tokenizer.encode(prompt)
    tokens, mask = model_up.tokenizer.padded_tokens_and_mask(
        tokens, options_up['text_ctx']
    )

    # Create the model conditioning dict.
    model_kwargs = dict(
        # Low-res image to upsample.
        low_res=((samples + 1) * 127.5).round() / 127.5 - 1,

        # Text tokens
        tokens=th.tensor(
            [tokens] * batch_size, device=device
        ),
        mask=th.tensor(
            [mask] * batch_size,
            dtype=th.bool,
            device=device,
        ),
    )

    # Sample from the base model.
    model_up.del_cache()
    up_shape = (batch_size, 3, options_up["image_size"], options_up["image_size"])
    up_samples = diffusion_up.ddim_sample_loop(
        model_up,
        up_shape,
        noise=th.randn(up_shape, device=device) * upsample_temp,
        device=device,
        clip_denoised=True,
        progress=True,
        model_kwargs=model_kwargs,
        cond_fn=None,
    )[:batch_size]
    model_up.del_cache()

    # Show the output
    image = get_images(up_samples)
    image = to_base64(image)
    return {"image": image}


def to_base64(pil_image):
    buffered = BytesIO()
    pil_image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue())
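If the FastAPI variant above is served (for example with `uvicorn server:app`), a client might fetch a sample roughly as sketched below. This is a hypothetical sketch, not part of the commit: the host, port, and path segment are assumptions; `prompt` travels as a query parameter because the declared `{generate}` path parameter is never read by `sample()`; and, as committed, `get_images` builds the PIL image without returning it, so that `return` would need to be added before the endpoint can actually respond with an image.

# Hypothetical client for server.py; assumes the server is running locally on
# port 8000 and that get_images has been fixed to return its PIL image.
import base64
from io import BytesIO

import requests
from PIL import Image

resp = requests.get(
    "http://127.0.0.1:8000/generate",   # any single path segment matches /{generate}
    params={"prompt": "an oil painting of a corgi"},
    timeout=3600,                        # CPU sampling can take many minutes
)
resp.raise_for_status()

# server.py base64-encodes a JPEG of the upsampled 256x256 sample.
image_b64 = resp.json()["image"]
image = Image.open(BytesIO(base64.b64decode(image_b64)))
image.save("sample.jpg")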
setup.py ADDED
@@ -0,0 +1,15 @@
from setuptools import setup

setup(
    name="glide-text2im",
    packages=["glide_text2im"],
    install_requires=[
        "Pillow",
        "attrs",
        "torch",
        "filelock",
        "requests",
        "tqdm",
    ],
    author="OpenAI",
)