Spaces:
Runtime error
Runtime error
import torch | |
from diffusers import StableDiffusionPipeline | |
from torch import autocast | |
import streamlit as st | |
from PIL import Image, ImageEnhance | |
import pandas as pd | |
import numpy as np | |
class StableDiffusionLoader:
    """
    Stable Diffusion loader and generator class.

    Utilises the stable diffusion models from the `Hugging Face`
    (https://huggingface.co/spaces/stabilityai/stable-diffusion) library.

    Attributes
    ----------
    prompt : str
        a text prompt to use to generate an associated image
    pretrain_pipe : str
        a pretrained image diffusion pipeline i.e. CompVis/stable-diffusion-v1-4
    device : str
        "cuda" when torch reports an available CUDA device, otherwise "cpu"
    """

    def __init__(self,
                 prompt: str,
                 pretrain_pipe: str = 'lfernandopg/mach-5-model-v1'):
        """
        Constructs all the necessary attributes for the diffusion class.

        Parameters
        ----------
        prompt : str
            the text prompt to generate the image from
        pretrain_pipe : str
            the name of the pretrained pipeline, e.g. CompVis/stable-diffusion-v1-4

        Raises
        ------
        TypeError
            if `prompt` or `pretrain_pipe` is not a string
        """
        # Explicit checks rather than `assert`: assertions are stripped when
        # Python runs with -O, which would silently disable this validation.
        if not isinstance(prompt, str):
            raise TypeError('Please enter a string into the prompt field')
        if not isinstance(pretrain_pipe, str):
            raise TypeError('Please use value such as `CompVis/stable-diffusion-v1-4` for pretrained pipeline')
        self.prompt = prompt
        self.pretrain_pipe = pretrain_pipe
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def generate_image_from_prompt(self, save_location='prompt.jpg', use_token=False,
                                   verbose=False):
        """
        Generate an image for `self.prompt`, save it, and return it.

        Note: the pipeline is loaded with `from_pretrained` on every call.

        Parameters
        ----------
        save_location : str - defaults to prompt.jpg
            the location where to save the image generated by the Diffusion Model
        use_token : bool
            whether the stored Hugging Face token should be used; forwarded to
            `from_pretrained` as `use_auth_token`
        verbose : bool
            boolean that defaults to False, otherwise message printed

        Returns
        -------
        The first image produced by the pipeline (a PIL image under the
        pipeline's default output type); the same object that was saved to
        `save_location`.
        """
        pipe = StableDiffusionPipeline.from_pretrained(
            self.pretrain_pipe,
            use_auth_token=use_token,
        ).to(self.device)
        # Mixed precision only pays off on GPU. CPU autocast defaults to
        # bfloat16, which is slow and may cause dtype mismatches in the
        # pipeline, so it is disabled there.
        with autocast(self.device, enabled=self.device == "cuda"):
            image = pipe(self.prompt).images[0]
        image.save(save_location)
        if verbose:
            print(f'[INFO] saving image to {save_location}')
        return image

    def __str__(self) -> str:
        return f'[INFO] Generating image for prompt: {self.prompt}'

    def __len__(self):
        # Length of the loader is the character length of its prompt.
        return len(self.prompt)
SAVE_LOCATION = 'prompt.jpg'

# Create the page title
st.set_page_config(page_title='Diffusion Model generator')

# Create page layout
st.title('Image generator using Stable Diffusion')
st.caption('An app to generate images based on text prompts with a :blue[_Stable Diffusion_] model :sunglasses:')

# Create text prompt
prompt = st.text_input('Input the prompt desired')

# Streamlit re-executes this script on every interaction; only generate once
# the user has entered something other than whitespace.
if prompt.strip():
    st.markdown(f"\nThis will show an image using **stable diffusion** of the desired {prompt} entered:\n")
    print(prompt)
    # Create a spinner to show the image is being generated
    with st.spinner('Generating image based on prompt'):
        sd = StableDiffusionLoader(prompt)
        # Display the returned image instead of re-reading SAVE_LOCATION:
        # that path is shared by every session, so a concurrent run could
        # overwrite the file between this session's save and read.
        image = sd.generate_image_from_prompt(save_location=SAVE_LOCATION)
    st.success('Generated image with the stable diffusion model')
    # Display the image on the site
    st.image(image)