Spaces:
Runtime error
Runtime error
File size: 3,554 Bytes
ffe5abc 3b852e7 ffe5abc c5b8684 e5dead6 ffe5abc 1272d13 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import torch
from diffusers import StableDiffusionPipeline
from torch import autocast
import streamlit as st
from PIL import Image, ImageEnhance
import pandas as pd
import numpy as np
class StableDiffusionLoader:
    """
    Stable Diffusion loader and generator class.

    Utilises the stable diffusion models from the `Hugging Face`(https://huggingface.co/spaces/stabilityai/stable-diffusion) library

    Attributes
    ----------
    prompt : str
        a text prompt to use to generate an associated image
    pretrain_pipe : str
        a pretrained image diffusion pipeline i.e. CompVis/stable-diffusion-v1-4
    device : str
        "cuda" when torch reports an available CUDA device, otherwise "cpu"
    """

    def __init__(self,
                 prompt: str,
                 pretrain_pipe: str = 'lfernandopg/mach-5-model-v1'):
        """
        Constructs all the necessary attributes for the diffusion class.

        Parameters
        ----------
        prompt : str
            the prompt to generate the image from
        pretrain_pipe : str
            the name of the pretrained pipeline

        Raises
        ------
        TypeError
            if `prompt` or `pretrain_pipe` is not a string
        """
        # Explicit raises instead of `assert`: assertions are removed when
        # Python runs with -O, which would silently disable the validation.
        if not isinstance(prompt, str):
            raise TypeError('Please enter a string into the prompt field')
        if not isinstance(pretrain_pipe, str):
            raise TypeError('Please use value such as `CompVis/stable-diffusion-v1-4` for pretrained pipeline')

        self.prompt = prompt
        self.pretrain_pipe = pretrain_pipe
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def generate_image_from_prompt(self, save_location='prompt.jpg', use_token=False,
                                   verbose=False):
        """
        Generate an image for `self.prompt`, save it and return it.

        The pipeline is loaded from `self.pretrain_pipe` on every call.

        Parameters
        ----------
        save_location : str - defaults to prompt.jpg
            the location where to save the image generated by the Diffusion Model
        use_token : bool
            forwarded to `from_pretrained` as `use_auth_token`; controls whether
            the Hugging Face token should be used
        verbose : bool
            boolean that defaults to False, otherwise message printed

        Returns
        -------
        the generated image object (the same object that was saved)
        """
        pipe = StableDiffusionPipeline.from_pretrained(
            self.pretrain_pipe,
            # Honour the caller's choice; previously this was hard-coded to
            # False and the `use_token` argument had no effect.
            use_auth_token=use_token,
        ).to(self.device)

        # NOTE(review): autocast is entered for whichever device was selected,
        # including "cpu" - confirm mixed precision on CPU is intended.
        with autocast(self.device):
            # First field of the pipeline output, then its first entry;
            # presumably the first generated image - TODO confirm.
            image = pipe(self.prompt)[0][0]

        image.save(save_location)
        if verbose:
            print(f'[INFO] saving image to {save_location}')
        return image

    def __str__(self) -> str:
        """Return a short status line that includes the prompt."""
        return f'[INFO] Generating image for prompt: {self.prompt}'

    def __len__(self) -> int:
        """Return the number of characters in the prompt."""
        return len(self.prompt)
# File the generated image is written to and then read back from for display
SAVE_LOCATION = 'prompt.jpg'

# Page metadata, heading and tagline
st.set_page_config(page_title='Diffusion Model generator')
st.title('Image generator using Stable Diffusion')
st.caption('An app to generate images based on text prompts with a :blue[_Stable Diffusion_] model :sunglasses:')

# Ask the user for the text prompt
prompt = st.text_input('Input the prompt desired')

# Nothing is generated until a non-empty prompt has been entered
if prompt:
    st.markdown(f"""
This will show an image using **stable diffusion** of the desired {prompt} entered:
""")
    print(prompt)

    # Keep a spinner on screen while the image is generated and saved
    with st.spinner('Generating image based on prompt'):
        loader = StableDiffusionLoader(prompt)
        loader.generate_image_from_prompt(save_location=SAVE_LOCATION)
        st.success('Generated stable diffusion model')

    # Load the saved file and render it on the page
    image = Image.open(SAVE_LOCATION)
    st.image(image)