File size: 3,554 Bytes
ffe5abc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b852e7
ffe5abc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5b8684
 
e5dead6
ffe5abc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1272d13
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
from diffusers import StableDiffusionPipeline
from torch import autocast
import streamlit as st
from PIL import Image, ImageEnhance
import pandas as pd
import numpy as np

class StableDiffusionLoader:
    """
    Stable Diffusion loader and generator class.

    Utilises the Stable Diffusion pipelines published on the Hugging Face Hub
    (https://huggingface.co/spaces/stabilityai/stable-diffusion) through
    ``diffusers.StableDiffusionPipeline``.

    Attributes
    ----------
    prompt : str
        a text prompt to use to generate an associated image
    pretrain_pipe : str
        the id of a pretrained image diffusion pipeline,
        e.g. ``CompVis/stable-diffusion-v1-4``
    device : str
        ``"cuda"`` when torch reports an available CUDA device, else ``"cpu"``
    """

    def __init__(self,
                 prompt: str,
                 pretrain_pipe: str = 'lfernandopg/mach-5-model-v1'):
        """
        Store the prompt and pipeline id and select the torch device.

        Parameters
        ----------
        prompt : str
            the prompt to generate the image from
        pretrain_pipe : str
            the name of the pretrained pipeline

        Raises
        ------
        TypeError
            if ``prompt`` or ``pretrain_pipe`` is not a string
        """
        # Explicit checks instead of ``assert``: asserts are removed when
        # Python runs with ``-O``, which would silently skip validation.
        if not isinstance(prompt, str):
            raise TypeError('Please enter a string into the prompt field')
        if not isinstance(pretrain_pipe, str):
            raise TypeError('Please use value such as `CompVis/stable-diffusion-v1-4` for pretrained pipeline')

        self.prompt = prompt
        self.pretrain_pipe = pretrain_pipe
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def generate_image_from_prompt(self, save_location='prompt.jpg', use_token=False,
                                   verbose=False):
        """
        Generate an image for ``self.prompt``, save it, and return it.

        The pipeline is (re)loaded from ``self.pretrain_pipe`` on every call.

        Parameters
        ----------
        save_location : str - defaults to prompt.jpg
            the location where to save the image generated by the Diffusion Model
        use_token : bool
            whether the locally stored Hugging Face token should be used when
            downloading the pipeline (needed for gated/private checkpoints)
        verbose : bool
            boolean that defaults to False, otherwise message printed

        Returns
        -------
        PIL.Image.Image
            the first image produced by the pipeline for the prompt
        """
        # Forward ``use_token`` rather than hard-coding False; previously the
        # documented parameter had no effect.
        pipe = StableDiffusionPipeline.from_pretrained(
            self.pretrain_pipe,
            use_auth_token=use_token,
        ).to(self.device)

        with autocast(self.device):
            # ``.images`` is the list of generated PIL images; take the first.
            image = pipe(self.prompt).images[0]

        image.save(save_location)
        if verbose:
            print(f'[INFO] saving image to {save_location}')
        return image

    def __str__(self) -> str:
        """Return a log-style description of the prompt being generated."""
        return f'[INFO] Generating image for prompt: {self.prompt}'

    def __len__(self) -> int:
        """Return the number of characters in the prompt."""
        return len(self.prompt)

SAVE_LOCATION = 'prompt.jpg'

# Create the page title
st.set_page_config(page_title='Diffusion Model generator')

# Create page layout
st.title('Image generator using Stable Diffusion')
st.caption('An app to generate images based on text prompts with a :blue[_Stable Diffusion_] model :sunglasses:')

# Create text prompt
prompt = st.text_input('Input the prompt desired')

# Ignore empty or whitespace-only input: image generation is expensive and a
# blank prompt carries no content to generate from.
if prompt.strip():
    st.markdown(f"""
    This will show an image using **stable diffusion** of the desired {prompt} entered:
    """)
    # Create a spinner to show the image is being generated
    with st.spinner('Generating image based on prompt'):
        sd = StableDiffusionLoader(prompt)
        # The loader saves the image to SAVE_LOCATION and also returns it, so
        # display the returned image instead of re-opening the file from disk.
        image = sd.generate_image_from_prompt(save_location=SAVE_LOCATION)
        st.success('Generated stable diffusion image')
        st.image(image)