Spaces:
Runtime error
Runtime error
Commit
·
ffe5abc
1
Parent(s):
9681608
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from diffusers import StableDiffusionPipeline
|
3 |
+
from torch import autocast
|
4 |
+
import streamlit as st
|
5 |
+
from PIL import Image, ImageEnhance
|
6 |
+
import pandas as pd
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
class StableDiffusionLoader:
|
10 |
+
"""
|
11 |
+
Stable Diffusion loader and generator class.
|
12 |
+
Utilises the stable diffusion models from the `Hugging Face`(https://huggingface.co/spaces/stabilityai/stable-diffusion) library
|
13 |
+
Attributes
|
14 |
+
----------
|
15 |
+
prompt : str
|
16 |
+
a text prompt to use to generate an associated image
|
17 |
+
pretrain_pipe : str
|
18 |
+
a pretrained image diffusion pipeline i.e. CompVis/stable-diffusion-v1-4
|
19 |
+
"""
|
20 |
+
def __init__(self,
|
21 |
+
prompt:str,
|
22 |
+
pretrain_pipe:str='lfernandopg/mach-5-model-v1'):
|
23 |
+
"""
|
24 |
+
Constructs all the necessary attributes for the diffusion class.
|
25 |
+
Parameters
|
26 |
+
----------
|
27 |
+
prompt : str
|
28 |
+
the prompt to generate the model
|
29 |
+
pretrain_pipe : str
|
30 |
+
the name of the pretrained pipeline
|
31 |
+
"""
|
32 |
+
self.prompt = prompt
|
33 |
+
self.pretrain_pipe = pretrain_pipe
|
34 |
+
|
35 |
+
assert isinstance(self.prompt, str), 'Please enter a string into the prompt field'
|
36 |
+
assert isinstance(self.pretrain_pipe, str), 'Please use value such as `CompVis/stable-diffusion-v1-4` for pretrained pipeline'
|
37 |
+
|
38 |
+
|
39 |
+
def generate_image_from_prompt(self, save_location='prompt.jpg', use_token=False,
|
40 |
+
verbose=False):
|
41 |
+
"""
|
42 |
+
Class method to generate images based on the prompt
|
43 |
+
Parameters
|
44 |
+
----------
|
45 |
+
save_location : str - defaults to prompt.jpg
|
46 |
+
the location where to save the image generated by the Diffusion Model
|
47 |
+
use_token : bool
|
48 |
+
boolean to see if Hugging Face token should be used
|
49 |
+
verbose : bool
|
50 |
+
boolean that defaults to False, otherwise message printed
|
51 |
+
"""
|
52 |
+
|
53 |
+
pipe = StableDiffusionPipeline.from_pretrained(
|
54 |
+
self.pretrain_pipe,
|
55 |
+
torch_dtype=torch.float16,
|
56 |
+
)
|
57 |
+
pipe = pipe.to(self.device)
|
58 |
+
with autocast(self.device):
|
59 |
+
image = pipe(self.prompt)[0][0]
|
60 |
+
image.save(save_location)
|
61 |
+
if verbose:
|
62 |
+
print(f'[INFO] saving image to {save_location}')
|
63 |
+
return image
|
64 |
+
|
65 |
+
def __str__(self) -> str:
|
66 |
+
return f'[INFO] Generating image for prompt: {self.prompt}'
|
67 |
+
|
68 |
+
def __len__(self):
|
69 |
+
return len(self.prompt)
|
70 |
+
|
71 |
+
SAVE_LOCATION = 'prompt.jpg'
|
72 |
+
# Create the page title
|
73 |
+
st.set_page_config(page_title='Diffusion Model generator')
|
74 |
+
# Create page layout
|
75 |
+
|
76 |
+
st.title('Image generator using Stable Diffusion')
|
77 |
+
st.caption('An app to generate images based on text prompts with a :blue[_Stable Diffusion_] model :sunglasses:')
|
78 |
+
|
79 |
+
# Create text prompt
|
80 |
+
prompt = st.text_input('Input the prompt desired')
|
81 |
+
|
82 |
+
if len(prompt) > 0:
|
83 |
+
st.markdown(f"""
|
84 |
+
This will show an image using **stable diffusion** of the desired {prompt} entered:
|
85 |
+
""")
|
86 |
+
print(prompt)
|
87 |
+
# Create a spinner to show the image is being generated
|
88 |
+
with st.spinner('Generating image based on prompt'):
|
89 |
+
sd = StableDiffusionLoader(prompt)
|
90 |
+
sd.generate_image_from_prompt(save_location=SAVE_LOCATION)
|
91 |
+
st.success('Generated stable diffusion model')
|
92 |
+
|
93 |
+
# Open and display the image on the site
|
94 |
+
image = Image.open(SAVE_LOCATION)
|
95 |
+
st.image(image)
|