Spaces:
Runtime error
Runtime error
File size: 5,008 Bytes
3c6590b 690f094 9da79fd 5e2c7ed 3c6590b b85438c 3c6590b 13298a2 3c6590b 3b7350e 3c6590b 3b7350e 3c6590b 3b7350e d26a101 5e2c7ed c7f120b 690f094 bdf16c0 690f094 c7f120b 3c6590b 690f094 bdf16c0 3c6590b c7f120b bdf16c0 c7f120b bdf16c0 c7f120b 690f094 081cd9c 5e2c7ed c7f120b 5e2c7ed bdf16c0 5e2c7ed c7f120b 5e2c7ed 3c6590b bdf16c0 5e2c7ed c7f120b 5e2c7ed 081cd9c 690f094 bdf16c0 690f094 bdf16c0 f466dd9 c7f120b bdf16c0 c7f120b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import os
import asyncio
from generate_prompts import generate_prompt
from diffusers import AutoPipelineForText2Image
from io import BytesIO
import gradio as gr
import threading
from diffusers.schedulers import PNDMScheduler
# Wrapper that keeps an independent step counter per thread and forwards the real
# update call to a wrapped scheduler object.
class ThreadLocalStepScheduler:
    """Per-thread step counter around a scheduler object.

    NOTE(review): this is not a drop-in replacement for a diffusers scheduler.
    It has no scheduler base class, does not forward any other attributes of the
    wrapped object, and ``step`` here is an *integer property*, not the
    ``step(...)`` method a pipeline would call. Use ``step_process(...)`` to reach
    ``base_scheduler.step(...)``.

    Attributes:
        base_scheduler: The wrapped object; only its ``step`` method is used here.
    """

    def __init__(self, base_scheduler):
        self.base_scheduler = base_scheduler
        # threading.local: every thread sees its own ``step`` attribute.
        self._step = threading.local()

    def _init_step_index(self):
        """Set the calling thread's counter to 0."""
        self._step.step = 0

    @property
    def step(self):
        """The calling thread's counter, lazily created as 0 on first access."""
        try:
            return self._step.step
        except AttributeError:
            # First access from this thread: nothing stored for it yet.
            self._init_step_index()
            return self._step.step

    @step.setter
    def step(self, value):
        self._step.step = value

    def step_process(self, *args, **kwargs):
        """Advance this thread's counter, then return ``base_scheduler.step(*args, **kwargs)``.

        The counter is bumped before delegating, so it advances even if the
        wrapped call raises.
        """
        self.step = self.step + 1
        return self.base_scheduler.step(*args, **kwargs)
# Load the model once at import time so every request reuses the same weights.
print("Loading the model...")
# Use the scheduler that ships with SDXL-Turbo rather than overriding it.
# The previous override could not work:
#   - PNDMScheduler.from_config() expects a config dict, not a repo id. The repo's
#     scheduler config lives under its "scheduler" subfolder, not the repo root.
#   - PNDM is a multi-step sampler, not the sampler SDXL-Turbo was distilled for
#     (this app runs num_inference_steps=1).
#   - ThreadLocalStepScheduler has no scheduler base class. Its `step` attribute
#     is an int property, so a pipeline calling scheduler.step(...) would fail.
model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
# Keep these module-level names available: process_prompt resets the
# custom_scheduler counter. They now wrap the pipeline's own scheduler instead of
# being injected into the pipeline.
base_scheduler = model.scheduler
custom_scheduler = ThreadLocalStepScheduler(base_scheduler)
print("Model loaded successfully.")
# One lock for the one shared pipeline. generate_image is dispatched onto several
# executor threads at once (see queue_api_calls), so model calls are serialized.
_MODEL_LOCK = threading.Lock()


def generate_image(prompt, prompt_name):
    """Generate one image for ``prompt`` and return it encoded as JPEG bytes.

    Uses the module-level ``model`` pipeline with a single inference step and
    ``guidance_scale=0.0``.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        prompt_name: Human-readable label used only in log messages.

    Returns:
        JPEG bytes of the first returned image. Returns ``None`` when the pipeline
        raises, returns no images, or the image cannot be encoded. Errors are
        printed, not raised.
    """
    try:
        print(f"Generating response for {prompt_name} with prompt: {prompt}")
        # The pipeline, including its scheduler, is a single shared object.
        # diffusers schedulers keep per-run state (timesteps / step index) on the
        # instance, so concurrent calls can corrupt each other. Serialize them,
        # trading parallelism for correctness.
        with _MODEL_LOCK:
            output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
        print(f"Output for {prompt_name}: {output}")
        images = output.images
    except Exception as e:  # boundary: report and fall back to None
        print(f"Error generating image for {prompt_name}: {e}")
        return None

    # Check if the model returned images; report it the same way as other failures.
    if not (isinstance(images, list) and len(images) > 0):
        print(f"Error generating image for {prompt_name}: No images returned by the model for {prompt_name}.")
        return None

    buffered = BytesIO()
    try:
        images[0].save(buffered, format="JPEG")
    except Exception as e:
        print(f"Error saving image for {prompt_name}: {e}")
        return None
    image_bytes = buffered.getvalue()
    print(f"Image bytes length for {prompt_name}: {len(image_bytes)}")
    return image_bytes
async def queue_api_calls(sentence_mapping, character_dict, selected_style):
    """Turn each paragraph into a prompt, then render all prompts concurrently.

    Args:
        sentence_mapping: Mapping of paragraph id to an iterable of sentence
            strings. Each paragraph's sentences are joined with single spaces.
        character_dict: Passed unchanged to ``generate_prompt``.
        selected_style: Passed unchanged to ``generate_prompt``.

    Returns:
        Dict mapping each paragraph id to ``generate_image``'s result for it.
    """
    print(f"queue_api_calls invoked with sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}")

    # Build prompts in order. generate_prompt is called synchronously, one
    # paragraph at a time.
    paragraph_ids = []
    paragraph_prompts = []
    for paragraph_id, paragraph_sentences in sentence_mapping.items():
        merged_text = " ".join(paragraph_sentences)
        print(f"combined_sentence for paragraph {paragraph_id}: {merged_text}")
        paragraph_prompt = generate_prompt(merged_text, character_dict, selected_style)
        paragraph_ids.append(paragraph_id)
        paragraph_prompts.append(paragraph_prompt)
        print(f"Generated prompt for paragraph {paragraph_id}: {paragraph_prompt}")

    # Hand every blocking generate_image call to the loop's default executor,
    # then wait for all of them together.
    event_loop = asyncio.get_running_loop()
    pending = []
    for paragraph_id, paragraph_prompt in zip(paragraph_ids, paragraph_prompts):
        pending.append(
            event_loop.run_in_executor(None, generate_image, paragraph_prompt, f"Prompt {paragraph_id}")
        )
    print("Tasks created for image generation.")

    results = await asyncio.gather(*pending)
    print("Responses received from image generation tasks.")

    # gather() preserves submission order, so results line up with paragraph_ids.
    images = dict(zip(paragraph_ids, results))
    print(f"Images generated: {images}")
    return images
def process_prompt(sentence_mapping, character_dict, selected_style):
    """Synchronous Gradio entry point: generate one image per paragraph.

    Args:
        sentence_mapping: Mapping of paragraph id to an iterable of sentence
            strings (forwarded to ``queue_api_calls``).
        character_dict: Forwarded to prompt generation.
        selected_style: Forwarded to prompt generation.

    Returns:
        The dict returned by ``queue_api_calls`` (paragraph id -> image result).

    Raises:
        RuntimeError: if called from a thread that is already running an event
            loop. Such a loop cannot be re-entered synchronously; callers there
            should ``await queue_api_calls(...)`` directly.
    """
    print(f"process_prompt called with sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}")

    # A loop that is already running cannot be driven with run_until_complete()
    # or asyncio.run(). Fail fast with a clear message, before creating the
    # coroutine.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        pass
    else:
        raise RuntimeError(
            "process_prompt() cannot be called from a running event loop; "
            "await queue_api_calls(...) instead."
        )

    # Initialize thread-local variables.
    # NOTE(review): this resets only the calling thread's counter. Images are
    # generated on executor threads (see queue_api_calls), which keep their own.
    custom_scheduler._init_step_index()

    print("Event loop created.")
    # asyncio.run() creates a fresh loop and runs the coroutine. It then shuts down
    # the loop's default executor and closes the loop, so no loop, selector, or
    # worker threads are leaked per request.
    cmpt_return = asyncio.run(queue_api_calls(sentence_mapping, character_dict, selected_style))
    print(f"process_prompt completed with return value: {cmpt_return}")
    return cmpt_return
# Gradio UI: two JSON inputs plus a style dropdown, wired to process_prompt, with
# the result rendered as JSON.
# NOTE(review): no concurrency/queue settings are passed here, despite the app
# fanning work out per paragraph.
_STYLE_OPTIONS = ["oil painting", "sketch", "watercolor"]

_sentence_mapping_input = gr.JSON(label="Sentence Mapping")
_character_dict_input = gr.JSON(label="Character Dict")
_style_input = gr.Dropdown(_STYLE_OPTIONS, label="Selected Style")

gradio_interface = gr.Interface(
    fn=process_prompt,
    inputs=[_sentence_mapping_input, _character_dict_input, _style_input],
    outputs="json",
)
def _launch():
    """Start the Gradio app, logging before and after ``launch()``.

    NOTE(review): ``launch()`` presumably blocks when run as a script, so the
    second message would only appear after the server stops. Confirm for the
    Gradio version in use.
    """
    print("Launching Gradio interface...")
    gradio_interface.launch()
    print("Gradio interface launched.")


if __name__ == "__main__":
    _launch()
|