File size: 1,397 Bytes
21fcd0a
 
 
 
 
 
9577e92
 
21fcd0a
 
9577e92
21fcd0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from flask import Flask, send_file, request

import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
from io import BytesIO
from flask_cors import CORS


# Flask application object. CORS is enabled with flask_cors defaults
# (all routes, any origin), presumably so a browser front-end served from a
# different origin can call this API directly.
app = Flask(import_name=__name__)
CORS(app=app)


# Hugging Face Hub repository of the Stable Diffusion v1.4 checkpoint.
model_id = "CompVis/stable-diffusion-v1-4"
# Torch device for inference; generate_image() also reads this name.
device = "cuda"

# Load the half-precision weights (the repo's "fp16" revision) and move the
# pipeline to `device` once, at import time, so every request reuses the same
# loaded model. use_auth_token=True authenticates with the locally stored
# Hugging Face token.
# NOTE(review): newer diffusers releases deprecate `use_auth_token` (-> `token`)
# and `revision="fp16"` (-> `variant="fp16"`); confirm the pinned version.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=True,
).to(device)


def serve_pil_image(pil_img):
    """Encode a PIL image as JPEG in memory and return it as a Flask response.

    The image is written to an in-memory buffer at JPEG quality 70 and sent
    with the ``image/jpeg`` MIME type; nothing is written to disk.
    """
    buffer = BytesIO()
    pil_img.save(buffer, format="JPEG", quality=70)
    # Rewind so send_file streams the encoded bytes from the beginning.
    buffer.seek(0)
    response = send_file(buffer, mimetype="image/jpeg")
    return response


@app.route("/")
def hello_world():
    """Serve a static HTML greeting at the root URL."""
    greeting = "<p>Hello, World!</p>"
    return greeting


@app.route("/generate-image")
def generate_image():
    """Generate an image from query-string parameters and return it as JPEG.

    Query parameters (all optional):
        prompt: Text prompt. Defaults to
            "a photo of an astronaut riding a horse on mars".
        steps: Number of inference steps, 1..150. Defaults to 15.
        seed: Seed for the torch RNG passed to the pipeline. Defaults to 1024.

    Values that are not valid integers fall back to their defaults (the
    behavior of ``request.args.get(..., type=int)``). Out-of-range ``steps``
    or ``seed`` values are rejected with HTTP 400 instead of failing inside
    the pipeline/RNG with a 500.

    Returns:
        A JPEG image response, or a ``(message, 400)`` tuple on bad input.
    """
    # Upper bound on inference steps. The value comes straight from an
    # untrusted query string and run time grows with it, so an unbounded
    # value lets a single request tie up the GPU. Tune as needed.
    max_steps = 150
    # Inclusive seed range accepted by torch.Generator.manual_seed; anything
    # outside it raises RuntimeError.
    min_seed, max_seed = -(2**63), 2**64 - 1

    prompt = request.args.get(
        "prompt",
        default="a photo of an astronaut riding a horse on mars",
        type=str,
    )
    steps = request.args.get("steps", default=15, type=int)
    seed = request.args.get("seed", default=1024, type=int)

    if not 1 <= steps <= max_steps:
        return f"steps must be between 1 and {max_steps}", 400
    if not min_seed <= seed <= max_seed:
        return f"seed must be between {min_seed} and {max_seed}", 400

    # Create the generator on the same device the pipeline was moved to,
    # instead of repeating a hard-coded "cuda" literal.
    generator = torch.Generator(device).manual_seed(seed)
    # NOTE(review): `pipe` is shared module-level state used without a lock.
    # If requests are served concurrently (Flask's dev server is threaded by
    # default), two generations may run on the same pipeline at once;
    # consider serialising calls — confirm the deployment model.
    with autocast(device):
        image = pipe(
            prompt,
            guidance_scale=7.5,
            num_inference_steps=steps,
            generator=generator,
        ).images[0]

    return serve_pil_image(image)


if __name__ == "__main__":
    # Listen on all interfaces so other machines can reach the API.
    #
    # Debug mode is deliberately not forced on: Werkzeug's interactive
    # debugger lets anyone who reaches an error page execute arbitrary Python,
    # which must never be exposed on 0.0.0.0. Flask.run() still honours
    # FLASK_DEBUG=1 from the environment for local development.
    #
    # The reloader is disabled because it re-runs this module in a child
    # process, which would load the Stable Diffusion weights onto the GPU a
    # second time while the parent process still holds its copy.
    app.run(
        host="0.0.0.0",
        port=5000,
        use_reloader=False,
    )