robiro commited on
Commit
47fe778
·
verified ·
1 Parent(s): fa3a13f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -48
app.py CHANGED
@@ -2,59 +2,75 @@ import gradio as gr
2
  import torch
3
  from diffusers import StableDiffusionPipeline
4
  from PIL import Image
 
5
 
6
  # --- Globale Konfiguration und Modellladung ---
7
- MODEL_ID = "runwayml/stable-diffusion-v1-5"
8
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
9
  print(f"Verwende Gerät: {DEVICE}")
10
 
11
  # Lade das Modell nur einmal beim Start der App
12
- # Für GPU: torch_dtype=torch.float16 spart VRAM und ist schneller
13
- # Für CPU: torch_dtype=torch.float32 (float16 wird auf CPU nicht gut unterstützt)
14
- dtype = torch.float16 if DEVICE == "cuda" else torch.float32
15
 
16
- print(f"Lade Modell '{MODEL_ID}'... Dies kann einige Minuten dauern.")
 
 
 
17
  try:
18
  pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=dtype)
19
- pipe = pipe.to(DEVICE)
20
  print("Modell erfolgreich geladen!")
21
  except Exception as e:
22
  print(f"Fehler beim Laden des Modells: {e}")
23
- print("Stelle sicher, dass du eine Internetverbindung hast und der Modellname korrekt ist.")
24
- print("Wenn du wenig VRAM hast, versuche ein kleineres Modell oder Einstellungen zur Speicheroptimierung.")
25
- pipe = None # Signalisiert, dass das Modell nicht geladen werden konnte
 
 
26
 
27
  # --- Bildgenerierungsfunktion ---
28
  def generate_image(
29
  prompt: str,
30
  negative_prompt: str = "",
31
- num_inference_steps: int = 50,
32
  guidance_scale: float = 7.5,
33
- height: int = 512,
34
- width: int = 512,
35
  seed: int = -1 # -1 für zufälligen Seed
36
  ) -> Image.Image:
37
  """
38
  Generiert ein Bild basierend auf dem Prompt und anderen Parametern.
39
  """
40
  if pipe is None:
41
- raise gr.Error("Modell konnte nicht geladen werden. Bitte überprüfe die Konsolenausgabe.")
 
 
 
42
 
43
- print(f"Generiere Bild für Prompt: '{prompt}'")
44
  print(f" Negative Prompt: '{negative_prompt}'")
45
  print(f" Schritte: {num_inference_steps}, Guidance: {guidance_scale}")
46
  print(f" Dimensionen: {width}x{height}, Seed: {seed}")
 
 
 
47
 
48
  # Seed Handling
49
  generator = None
50
- if seed != -1:
51
- generator = torch.Generator(device=DEVICE).manual_seed(seed)
 
 
 
 
 
 
52
 
53
- # Bild generieren
54
- # safety_checker=None kann verwendet werden, um den NSFW-Filter zu deaktivieren,
55
- # sei dir aber der Implikationen bewusst. Standardmäßig ist er aktiv.
56
  try:
57
- with torch.inference_mode(): # Wichtig für geringeren Speicherverbrauch bei Inferenz
 
58
  result = pipe(
59
  prompt,
60
  negative_prompt=negative_prompt if negative_prompt else None,
@@ -65,26 +81,22 @@ def generate_image(
65
  generator=generator
66
  )
67
  image = result.images[0]
68
- print("Bild erfolgreich generiert.")
 
 
69
  return image
70
  except Exception as e:
71
  print(f"Fehler bei der Bildgenerierung: {e}")
72
- # Versuche, eine spezifischere Fehlermeldung für OOM-Fehler (Out Of Memory) zu geben
73
- if "CUDA out of memory" in str(e):
74
- raise gr.Error(
75
- "CUDA out of memory. Versuche, die Bildgröße zu verringern, "
76
- "weniger Inferenzschritte zu verwenden oder ein kleineres Modell zu laden."
77
- )
78
- raise gr.Error(f"Fehler bei der Bildgenerierung: {e}")
79
-
80
 
81
  # --- Gradio Interface Definition ---
82
  with gr.Blocks(theme=gr.themes.Soft()) as app:
83
  gr.Markdown(
84
  """
85
- # 🖼️ Bildgenerator mit Stable Diffusion
86
  Gib einen Text-Prompt ein, um ein Bild zu generieren.
87
- Das Laden des Modells beim ersten Start kann einige Minuten dauern.
 
88
  """
89
  )
90
 
@@ -97,42 +109,44 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
97
  )
98
  negative_prompt_input = gr.Textbox(
99
  label="Negativer Prompt (was vermieden werden soll)",
100
- placeholder="z.B. schlecht gezeichnet, unscharf, Text, Wasserzeichen",
101
  lines=2
102
  )
103
  with gr.Row():
 
104
  steps_slider = gr.Slider(
105
- minimum=10, maximum=150, value=50, step=1, label="Inferenzschritte"
 
106
  )
107
  guidance_slider = gr.Slider(
108
  minimum=1, maximum=20, value=7.5, step=0.1, label="Guidance Scale (CFG)"
109
  )
110
  with gr.Row():
 
111
  height_slider = gr.Slider(
112
- minimum=256, maximum=1024, value=512, step=64, label="Höhe"
113
  )
114
  width_slider = gr.Slider(
115
- minimum=256, maximum=1024, value=512, step=64, label="Breite"
116
  )
117
  seed_input = gr.Number(
118
- label="Seed (-1 für zufällig)", value=-1, precision=0
119
  )
120
- generate_button = gr.Button("Bild generieren", variant="primary")
121
 
122
  with gr.Column(scale=1):
123
  image_output = gr.Image(label="Generiertes Bild", type="pil")
124
  gr.Markdown("### Beispiel-Prompts:")
125
  gr.Examples(
126
  examples=[
127
- ["Ein Astronaut reitet ein Pferd auf dem Mond, digitale Kunst", "", 50, 7.5, 512, 512, -1],
128
- ["Ein impressionistisches Gemälde eines Sonnenuntergangs über einem Lavendelfeld", "Menschen, Gebäude", 40, 8.0, 512, 768, -1],
129
- ["Ein niedlicher Corgi-Hund als Pixel-Art-Charakter", "fotorealistisch", 30, 7.0, 512, 512, 12345],
130
- ["Eine surreale Landschaft mit schwebenden Inseln und Wasserfällen aus Licht", "dunkel, düster", 60, 9.0, 768, 512, -1],
131
  ],
132
  inputs=[prompt_input, negative_prompt_input, steps_slider, guidance_slider, height_slider, width_slider, seed_input],
133
  outputs=image_output,
134
- fn=generate_image, # Die Funktion, die bei Klick auf ein Beispiel ausgeführt wird
135
- cache_examples=False # Oder True, wenn du die Ergebnisse cachen willst
136
  )
137
 
138
  generate_button.click(
@@ -147,13 +161,14 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
147
  seed_input
148
  ],
149
  outputs=image_output,
150
- api_name="generate_image" # Für API-Zugriff
151
  )
152
 
153
  # --- App starten ---
154
  if __name__ == "__main__":
155
  if pipe is None:
156
- print("Das Modell konnte nicht geladen werden. Die Gradio-App wird nicht gestartet.")
157
- print("Bitte behebe die Fehler und versuche es erneut.")
158
  else:
159
- app.launch(share=False) # Setze share=True, um einen öffentlichen Link zu erhalten (erfordert `gradio-client`)
 
 
2
  import torch
3
  from diffusers import StableDiffusionPipeline
4
  from PIL import Image
5
+ import time # Um die Generierungszeit zu messen
6
 
7
  # --- Globale Konfiguration und Modellladung ---
8
+ # Verwende das Modell aus deinem Textausschnitt
9
+ MODEL_ID = "sd-legacy/stable-diffusion-v1-5"
10
+ DEVICE = "cpu" # Explizit CPU verwenden
11
  print(f"Verwende Gerät: {DEVICE}")
12
 
13
  # Lade das Modell nur einmal beim Start der App
14
+ # Für CPU: torch_dtype=torch.float32
15
+ dtype = torch.float32
 
16
 
17
+ print(f"Lade Modell '{MODEL_ID}' für CPU-Nutzung... Dies kann einige Minuten dauern und benötigt viel RAM.")
18
+ print("Stelle sicher, dass du eine stabile Internetverbindung hast.")
19
+
20
+ pipe = None # Initialisiere pipe als None
21
  try:
22
  pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=dtype)
23
+ pipe = pipe.to(DEVICE) # Auf CPU verschieben
24
  print("Modell erfolgreich geladen!")
25
  except Exception as e:
26
  print(f"Fehler beim Laden des Modells: {e}")
27
+ print("Mögliche Ursachen:")
28
+ print("- Keine Internetverbindung oder Hugging Face Hub nicht erreichbar.")
29
+ print("- Nicht genügend RAM verfügbar. Versuche, andere speicherintensive Anwendungen zu schließen.")
30
+ print("- Falsche Modell-ID (sollte hier aber korrekt sein).")
31
+ # pipe bleibt None, was in generate_image abgefangen wird
32
 
33
  # --- Bildgenerierungsfunktion ---
34
  def generate_image(
35
  prompt: str,
36
  negative_prompt: str = "",
37
+ num_inference_steps: int = 20, # Reduziert für schnellere CPU-Tests, erhöhe für bessere Qualität
38
  guidance_scale: float = 7.5,
39
+ height: int = 512, # Standardauflösung für SD v1.5
40
+ width: int = 512, # Standardauflösung für SD v1.5
41
  seed: int = -1 # -1 für zufälligen Seed
42
  ) -> Image.Image:
43
  """
44
  Generiert ein Bild basierend auf dem Prompt und anderen Parametern.
45
  """
46
  if pipe is None:
47
+ raise gr.Error(
48
+ "Modell konnte nicht geladen werden. Bitte überprüfe die Konsolenausgabe "
49
+ "beim Start der App und starte die App ggf. neu, nachdem die Probleme behoben wurden."
50
+ )
51
 
52
+ print(f"\nStarte Bildgenerierung auf CPU für Prompt: '{prompt}'")
53
  print(f" Negative Prompt: '{negative_prompt}'")
54
  print(f" Schritte: {num_inference_steps}, Guidance: {guidance_scale}")
55
  print(f" Dimensionen: {width}x{height}, Seed: {seed}")
56
+ print(" Dies kann auf der CPU einige Minuten dauern...")
57
+
58
+ start_time = time.time()
59
 
60
  # Seed Handling
61
  generator = None
62
+ if seed != -1 and seed is not None: # Stelle sicher, dass seed nicht None ist
63
+ generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
64
+ else: # Zufälliger Seed
65
+ # Generiere einen zufälligen Seed, um ihn später ggf. anzuzeigen oder zu verwenden
66
+ current_seed = torch.seed()
67
+ generator = torch.Generator(device=DEVICE).manual_seed(current_seed)
68
+ print(f" Verwende zufälligen Seed: {current_seed}")
69
+
70
 
 
 
 
71
  try:
72
+ # torch.inference_mode() ist gut für geringeren Speicherverbrauch und Geschwindigkeit
73
+ with torch.inference_mode():
74
  result = pipe(
75
  prompt,
76
  negative_prompt=negative_prompt if negative_prompt else None,
 
81
  generator=generator
82
  )
83
  image = result.images[0]
84
+ end_time = time.time()
85
+ duration = end_time - start_time
86
+ print(f"Bild erfolgreich generiert in {duration:.2f} Sekunden.")
87
  return image
88
  except Exception as e:
89
  print(f"Fehler bei der Bildgenerierung: {e}")
90
+ raise gr.Error(f"Fehler bei der Bildgenerierung auf CPU: {e}")
 
 
 
 
 
 
 
91
 
92
  # --- Gradio Interface Definition ---
93
  with gr.Blocks(theme=gr.themes.Soft()) as app:
94
  gr.Markdown(
95
  """
96
+ # 🖼️ CPU Bildgenerator mit Stable Diffusion v1.5
97
  Gib einen Text-Prompt ein, um ein Bild zu generieren.
98
+ **Achtung:** Die Generierung auf der **CPU ist langsam** und kann mehrere Minuten pro Bild dauern!
99
+ Das Laden des Modells beim ersten Start benötigt ebenfalls Zeit und RAM.
100
  """
101
  )
102
 
 
109
  )
110
  negative_prompt_input = gr.Textbox(
111
  label="Negativer Prompt (was vermieden werden soll)",
112
+ placeholder="z.B. schlecht gezeichnet, unscharf, Text, Wasserzeichen, mutierte Hände",
113
  lines=2
114
  )
115
  with gr.Row():
116
+ # Reduzierte Standard-Schritte für CPU, da es sonst zu lange dauert
117
  steps_slider = gr.Slider(
118
+ minimum=5, maximum=50, value=20, step=1,
119
+ label="Inferenzschritte (weniger = schneller, aber ggf. schlechtere Qualität)"
120
  )
121
  guidance_slider = gr.Slider(
122
  minimum=1, maximum=20, value=7.5, step=0.1, label="Guidance Scale (CFG)"
123
  )
124
  with gr.Row():
125
+ # Standardauflösung für v1.5 ist 512x512. Kleinere Auflösungen sind schneller auf CPU.
126
  height_slider = gr.Slider(
127
+ minimum=256, maximum=512, value=512, step=64, label="Höhe"
128
  )
129
  width_slider = gr.Slider(
130
+ minimum=256, maximum=512, value=512, step=64, label="Breite"
131
  )
132
  seed_input = gr.Number(
133
+ label="Seed (-1 oder leer für zufällig)", value=-1, precision=0
134
  )
135
+ generate_button = gr.Button("Bild generieren (langsam auf CPU!)", variant="primary")
136
 
137
  with gr.Column(scale=1):
138
  image_output = gr.Image(label="Generiertes Bild", type="pil")
139
  gr.Markdown("### Beispiel-Prompts:")
140
  gr.Examples(
141
  examples=[
142
+ ["Ein Astronaut reitet ein Pferd auf dem Mars, digitale Kunst", "", 20, 7.5, 512, 512, -1],
143
+ ["Ein impressionistisches Gemälde eines Sonnenuntergangs über einem Lavendelfeld", "Menschen, Gebäude", 15, 8.0, 512, 512, -1],
144
+ ["Ein niedlicher Corgi-Hund als Pixel-Art-Charakter", "fotorealistisch", 25, 7.0, 512, 512, 12345],
 
145
  ],
146
  inputs=[prompt_input, negative_prompt_input, steps_slider, guidance_slider, height_slider, width_slider, seed_input],
147
  outputs=image_output,
148
+ fn=generate_image,
149
+ cache_examples=False # CPU-Generierung ist zu langsam zum Cachen während des Tests
150
  )
151
 
152
  generate_button.click(
 
161
  seed_input
162
  ],
163
  outputs=image_output,
164
+ api_name="generate_image_cpu"
165
  )
166
 
167
  # --- App starten ---
168
  if __name__ == "__main__":
169
  if pipe is None:
170
+ print("\nDas Modell konnte nicht geladen werden. Die Gradio-App wird nicht gestartet.")
171
+ print("Bitte behebe die in der Konsole angezeigten Fehler und versuche es erneut.")
172
  else:
173
+ print("\nStarte Gradio App. Öffne die angezeigte URL in deinem Browser.")
174
+ app.launch(share=False) # share=False für lokale Nutzung