RanM commited on
Commit
422af54
·
verified ·
1 Parent(s): caa13d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -50
app.py CHANGED
@@ -1,11 +1,10 @@
1
  import os
2
- import asyncio
3
- from concurrent.futures import ProcessPoolExecutor
4
  from io import BytesIO
5
  from PIL import Image
6
- from diffusers import AutoPipelineForText2Image
7
  import gradio as gr
8
  from generate_prompts import generate_prompt
 
9
 
10
  # Load the model once at the start
11
  print("Loading the Stable Diffusion model...")
@@ -16,79 +15,65 @@ except Exception as e:
16
  print(f"Error loading model: {e}")
17
  model = None
18
 
19
- def generate_image(prompt, prompt_name):
20
  try:
21
  if model is None:
22
  raise ValueError("Model not loaded properly.")
23
 
24
- print(f"Generating image for {prompt_name} with prompt: {prompt}")
25
  output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
26
- print(f"Model output for {prompt_name}: {output}")
27
 
28
  if output is None:
29
- raise ValueError(f"Model returned None for {prompt_name}")
30
 
31
  if hasattr(output, 'images') and output.images:
32
- print(f"Image generated for {prompt_name}")
33
  image = output.images[0]
34
  buffered = BytesIO()
35
  image.save(buffered, format="JPEG")
36
  image_bytes = buffered.getvalue()
37
- return image_bytes
 
38
  else:
39
- print(f"No images found in model output for {prompt_name}")
40
- raise ValueError(f"No images found in model output for {prompt_name}")
41
  except Exception as e:
42
- print(f"An error occurred while generating image for {prompt_name}: {e}")
43
- return None
44
 
45
- async def queue_api_calls(sentence_mapping, character_dict, selected_style):
46
- print("Starting to queue API calls...")
47
- prompts = []
48
- for paragraph_number, sentences in sentence_mapping.items():
49
- combined_sentence = " ".join(sentences)
50
- prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
51
- prompts.append((paragraph_number, prompt))
52
- print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
53
-
54
- loop = asyncio.get_running_loop()
55
- with ProcessPoolExecutor() as pool:
56
- tasks = [
57
- loop.run_in_executor(pool, generate_image, prompt, f"Prompt {paragraph_number}")
58
- for paragraph_number, prompt in prompts
59
- ]
60
- responses = await asyncio.gather(*tasks)
61
-
62
- images = {paragraph_number: response for (paragraph_number, _), response in zip(prompts, responses)}
63
- print("Finished queuing API calls. Generated images: ", images)
64
- return images
65
-
66
- def process_prompt(sentence_mapping, character_dict, selected_style):
67
- print("Processing prompt...")
68
- print(f"Sentence Mapping: {sentence_mapping}")
69
- print(f"Character Dict: {character_dict}")
70
- print(f"Selected Style: {selected_style}")
71
  try:
72
- loop = asyncio.get_running_loop()
73
- print("Using existing event loop.")
74
- except RuntimeError:
75
- loop = asyncio.new_event_loop()
76
- asyncio.set_event_loop(loop)
77
- print("Created new event loop.")
78
 
79
- cmpt_return = loop.run_until_complete(queue_api_calls(sentence_mapping, character_dict, selected_style))
80
- print("Prompt processing complete. Generated images: ", cmpt_return)
81
- return cmpt_return
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  gradio_interface = gr.Interface(
84
- fn=process_prompt,
85
  inputs=[
86
  gr.JSON(label="Sentence Mapping"),
87
  gr.JSON(label="Character Dict"),
88
  gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
89
  ],
90
  outputs="json"
91
- ).queue(default_concurrency_limit=20) # Set concurrency limit if needed
92
 
93
  if __name__ == "__main__":
94
  print("Launching Gradio interface...")
 
1
  import os
 
 
2
  from io import BytesIO
3
  from PIL import Image
4
+ from transformers import AutoPipelineForText2Image
5
  import gradio as gr
6
  from generate_prompts import generate_prompt
7
+ import base64
8
 
9
  # Load the model once at the start
10
  print("Loading the Stable Diffusion model...")
 
15
  print(f"Error loading model: {e}")
16
  model = None
17
 
18
+ def generate_image(prompt):
19
  try:
20
  if model is None:
21
  raise ValueError("Model not loaded properly.")
22
 
23
+ print(f"Generating image with prompt: {prompt}")
24
  output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
25
+ print(f"Model output: {output}")
26
 
27
  if output is None:
28
+ raise ValueError("Model returned None")
29
 
30
  if hasattr(output, 'images') and output.images:
31
+ print(f"Image generated")
32
  image = output.images[0]
33
  buffered = BytesIO()
34
  image.save(buffered, format="JPEG")
35
  image_bytes = buffered.getvalue()
36
+ img_str = base64.b64encode(image_bytes).decode("utf-8")
37
+ return img_str, None
38
  else:
39
+ print(f"No images found in model output")
40
+ raise ValueError("No images found in model output")
41
  except Exception as e:
42
+ print(f"An error occurred while generating image: {e}")
43
+ return None, str(e)
44
 
45
+ def inference(sentence_mapping, character_dict, selected_style):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  try:
47
+ # Debugging statements
48
+ print(f"Received sentence_mapping: {sentence_mapping}")
49
+ print(f"Received character_dict: {character_dict}")
50
+ print(f"Received selected_style: {selected_style}")
 
 
51
 
52
+ if sentence_mapping is None or character_dict is None or selected_style is None:
53
+ return {"error": "One or more inputs are None"}
54
+
55
+ images = {}
56
+ for paragraph_number, sentences in sentence_mapping.items():
57
+ combined_sentence = " ".join(sentences)
58
+ prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
59
+ img_str, error = generate_image(prompt)
60
+ if error:
61
+ images[paragraph_number] = f"Error: {error}"
62
+ else:
63
+ images[paragraph_number] = img_str
64
+ return images
65
+ except Exception as e:
66
+ return {"error": str(e)}
67
 
68
  gradio_interface = gr.Interface(
69
+ fn=inference,
70
  inputs=[
71
  gr.JSON(label="Sentence Mapping"),
72
  gr.JSON(label="Character Dict"),
73
  gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
74
  ],
75
  outputs="json"
76
+ )
77
 
78
  if __name__ == "__main__":
79
  print("Launching Gradio interface...")