File size: 3,358 Bytes
f6b8b7e
13298a2
b0bcf89
 
c1282a1
216a041
b0bcf89
 
b85438c
b0bcf89
13298a2
 
 
b0bcf89
 
 
 
 
 
 
 
 
 
 
66e43e3
 
 
13298a2
b0bcf89
 
66e43e3
b0bcf89
66e43e3
13298a2
b0bcf89
66e43e3
 
6d1d03a
b0bcf89
216a041
f466dd9
b0bcf89
 
 
216a041
 
 
13298a2
b0bcf89
 
d26a101
b0bcf89
 
d26a101
f6b8b7e
c301a62
6449f8f
f466dd9
b5ad13a
b0bcf89
 
d26a101
b0bcf89
216a041
6035350
b0bcf89
e0ec116
b0bcf89
 
e0ec116
 
6449f8f
6035350
b0bcf89
 
 
216a041
 
b0bcf89
 
216a041
b0bcf89
6035350
f466dd9
b0bcf89
f466dd9
 
d05fa5e
1adc78a
 
e0ec116
d05fa5e
6035350
58f74fc
f466dd9
b0bcf89
f466dd9
a9b8939
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import gradio as gr
from diffusers import AutoPipelineForText2Image
from diffusers.schedulers import DPMSolverMultistepScheduler
from generate_propmts import generate_prompt  # Assuming you have this module
from PIL import Image
import asyncio
import threading
import traceback

# Define the SchedulerWrapper class
class SchedulerWrapper:
    """Proxy around a scheduler object that counts ``step`` calls per thread.

    ``step``, ``timesteps`` and ``set_timesteps`` are forwarded explicitly;
    every other attribute is forwarded to the wrapped scheduler through
    ``__getattr__`` so the wrapper can stand in wherever the scheduler
    itself is expected.

    Args:
        scheduler: The scheduler instance to wrap. It must provide ``step``,
            ``timesteps`` and ``set_timesteps``.
    """

    def __init__(self, scheduler):
        self.scheduler = scheduler
        # Attributes of a threading.local exist only in the thread that set
        # them, so this initial value is visible to the constructing thread
        # only. step() therefore reads the counter with a default of 0.
        self._step = threading.local()
        self._step.step = 0

    def __getattr__(self, name):
        """Forward attributes the wrapper does not define to the scheduler.

        Python calls this only after normal attribute lookup has failed.

        Raises:
            AttributeError: If the wrapped scheduler lacks the attribute too.
        """
        # Never proxy the wrapper's own fields or dunder lookups: doing so
        # would recurse forever if this is hit before __init__ has run
        # (for example while the object is being copied or unpickled).
        if name in ("scheduler", "_step") or name.startswith("__"):
            raise AttributeError(name)
        return getattr(self.scheduler, name)

    def step(self, *args, **kwargs):
        """Forward one ``step`` call to the scheduler and count it.

        On ``IndexError`` the calling thread's counter is reset to 0 and the
        call is retried once; a failure of the retry propagates.

        Returns:
            Whatever the wrapped scheduler's ``step`` returns.
        """
        counter = self._step
        try:
            # getattr default: worker threads have no counter yet.
            counter.step = getattr(counter, "step", 0) + 1
            return self.scheduler.step(*args, **kwargs)
        except IndexError:
            counter.step = 0
            return self.scheduler.step(*args, **kwargs)

    @property
    def timesteps(self):
        """The wrapped scheduler's current ``timesteps``."""
        return self.scheduler.timesteps

    def set_timesteps(self, *args, **kwargs):
        """Forward ``set_timesteps`` to the wrapped scheduler unchanged."""
        return self.scheduler.set_timesteps(*args, **kwargs)

# Load the model and wrap the scheduler.
# NOTE: this runs at import time. from_pretrained receives only the model id
# "stabilityai/sdxl-turbo"; no dtype or device is specified here.
model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")

# Build a DPMSolverMultistepScheduler from the pipeline's existing scheduler
# config, wrap it in SchedulerWrapper, and install the wrapper on the pipeline.
scheduler = DPMSolverMultistepScheduler.from_config(model.scheduler.config)
wrapped_scheduler = SchedulerWrapper(scheduler)
model.scheduler = wrapped_scheduler

# Image generation for a single prompt
async def generate_image(prompt):
    """Run the module-level ``model`` on ``prompt`` and return the first image.

    The pipeline call is handed to ``asyncio.to_thread`` so it does not run
    on the event loop thread.

    Args:
        prompt: Text prompt passed to the pipeline as ``prompt``.

    Returns:
        The first entry of the pipeline output's ``images``, or ``None`` when
        anything goes wrong (the error and its traceback are printed).
    """
    try:
        pipeline_kwargs = {
            "prompt": prompt,
            "num_inference_steps": 5,  # Adjust this value as needed
            "guidance_scale": 0.0,  # guidance scale passed to the pipeline
            "output_type": "pil",  # ask for PIL Image objects directly
        }
        result = await asyncio.to_thread(model, **pipeline_kwargs)

        # An empty result is routed through the same error path as any
        # other failure so it gets logged identically.
        if not result.images:
            raise Exception("No images returned by the model.")
        return result.images[0]
    except Exception as err:
        print(f"Error generating image: {err}")
        traceback.print_exc()
        return None  # callers treat None as "no image for this prompt"

# Define the inference function
async def inference(sentence_mapping, character_dict, selected_style):
    """Generate one image per paragraph of ``sentence_mapping``.

    For every ``(paragraph_number, sentences)`` item the sentences are joined
    with spaces, turned into a prompt by ``generate_prompt`` and rendered by
    ``generate_image``.

    Args:
        sentence_mapping: Mapping of paragraph number to an iterable of
            sentence strings. ``None`` or an empty mapping yields no images.
        character_dict: Passed through to ``generate_prompt`` unchanged.
        selected_style: Passed through to ``generate_prompt`` unchanged.

    Returns:
        A list of the generated images in paragraph order, skipping any
        prompt for which ``generate_image`` returned ``None``.
    """
    print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')

    # Nothing to do for a missing/empty mapping (avoids .items() on None).
    if not sentence_mapping:
        return []

    # Generate prompts for each paragraph
    prompts = []
    for paragraph_number, sentences in sentence_mapping.items():
        combined_sentence = " ".join(sentences)
        prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
        prompts.append(prompt)
        print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")

    # Generate images one at a time. Every generate_image call drives the
    # same module-level pipeline, whose scheduler keeps mutable per-run
    # state; overlapping runs in separate threads can corrupt that state.
    # Each call is still awaited, so the event loop stays responsive.
    images = []
    for prompt in prompts:
        image = await generate_image(prompt)
        if image is not None:  # None marks a failed generation
            images.append(image)

    return images

# Define the Gradio interface: `inference` is the handler, fed by two JSON
# inputs (sentence mapping, character dict) and a style dropdown, in that
# order; its result is displayed in a gallery.
gradio_interface = gr.Interface(
    fn=inference,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
    ],
    outputs=gr.Gallery(label="Generated Images")
)

# Run the Gradio app only when this file is executed as a script
# (not when it is imported as a module).
if __name__ == "__main__":
    gradio_interface.launch()