ajsbsd commited on
Commit
5cbe56c
·
verified ·
1 Parent(s): a3f672f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +131 -319
app.py CHANGED
@@ -5,9 +5,9 @@ from PIL import Image
5
  import os
6
  import gc
7
  import time
 
8
  from typing import Optional, Tuple
9
  from huggingface_hub import hf_hub_download
10
- import requests
11
 
12
  # Global pipeline variables
13
  txt2img_pipe = None
@@ -17,7 +17,6 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
17
  # Hugging Face model configuration
18
  MODEL_REPO = "ajsbsd/CyberRealistic-Pony"
19
  MODEL_FILENAME = "cyberrealisticPony_v110.safetensors"
20
- LOCAL_MODEL_PATH = "./models/cyberrealisticPony_v110.safetensors"
21
 
22
  def clear_memory():
23
  """Clear GPU memory"""
@@ -25,130 +24,67 @@ def clear_memory():
25
  torch.cuda.empty_cache()
26
  gc.collect()
27
 
28
- def download_model():
29
- """Download model from Hugging Face if not already cached"""
30
- try:
31
- # Create models directory if it doesn't exist
32
- os.makedirs("./models", exist_ok=True)
33
-
34
- # Check if model already exists locally
35
- if os.path.exists(LOCAL_MODEL_PATH):
36
- print(f"Model already exists at {LOCAL_MODEL_PATH}")
37
- return LOCAL_MODEL_PATH
38
-
39
- print(f"Downloading model from {MODEL_REPO}/{MODEL_FILENAME}...")
40
- print("This may take a while on first run...")
41
-
42
- # Download the model file
43
- model_path = hf_hub_download(
44
- repo_id=MODEL_REPO,
45
- filename=MODEL_FILENAME,
46
- local_dir="./models",
47
- local_dir_use_symlinks=False,
48
- resume_download=True
49
- )
50
-
51
- print(f"Model downloaded successfully to {model_path}")
52
- return model_path
53
-
54
- except Exception as e:
55
- print(f"Error downloading model: {e}")
56
- print("Attempting to use cached version or fallback...")
57
-
58
- # Try to use Hugging Face cache directly
59
- try:
60
- cached_path = hf_hub_download(
61
- repo_id=MODEL_REPO,
62
- filename=MODEL_FILENAME,
63
- resume_download=True
64
- )
65
- print(f"Using cached model at {cached_path}")
66
- return cached_path
67
- except Exception as cache_error:
68
- print(f"Cache fallback failed: {cache_error}")
69
- return None
70
-
71
  def load_models():
72
- """Load both text2img and img2img pipelines with Hugging Face integration"""
73
  global txt2img_pipe, img2img_pipe
74
 
75
- # Download model if needed
76
- model_path = download_model()
77
-
78
- if model_path is None:
79
- print("Failed to download or locate model file")
80
- return None, None
81
-
82
- if not os.path.exists(model_path):
83
- print(f"Model file not found after download: {model_path}")
84
- return None, None
85
-
86
- if txt2img_pipe is None:
87
- try:
88
- print("Loading CyberRealistic Pony Text2Img model...")
89
  txt2img_pipe = StableDiffusionXLPipeline.from_single_file(
90
- model_path,
91
  torch_dtype=torch.float16 if device == "cuda" else torch.float32,
92
  use_safetensors=True,
93
  variant="fp16" if device == "cuda" else None
94
  )
95
 
96
- # Memory optimizations
97
  txt2img_pipe.enable_attention_slicing()
 
98
 
99
  if device == "cuda":
100
- try:
101
- txt2img_pipe.enable_model_cpu_offload()
102
- print("Text2Img CPU offload enabled")
103
- except Exception as e:
104
- print(f"Text2Img CPU offload failed: {e}")
105
- txt2img_pipe = txt2img_pipe.to(device)
106
  else:
107
  txt2img_pipe = txt2img_pipe.to(device)
108
-
109
- print("Text2Img model loaded successfully!")
110
-
111
- except Exception as e:
112
- print(f"Error loading Text2Img model: {e}")
113
- return None, None
114
-
115
- if img2img_pipe is None:
116
- try:
117
- print("Loading CyberRealistic Pony Img2Img model...")
118
- img2img_pipe = StableDiffusionXLImg2ImgPipeline.from_single_file(
119
- model_path,
120
- torch_dtype=torch.float16 if device == "cuda" else torch.float32,
121
- use_safetensors=True,
122
- variant="fp16" if device == "cuda" else None
123
  )
124
 
125
- # Memory optimizations
126
  img2img_pipe.enable_attention_slicing()
 
127
 
128
  if device == "cuda":
129
- try:
130
- img2img_pipe.enable_model_cpu_offload()
131
- print("Img2Img CPU offload enabled")
132
- except Exception as e:
133
- print(f"Img2Img CPU offload failed: {e}")
134
- img2img_pipe = img2img_pipe.to(device)
135
- else:
136
- img2img_pipe = img2img_pipe.to(device)
137
-
138
- print("Img2Img model loaded successfully!")
139
-
140
- except Exception as e:
141
- print(f"Error loading Img2Img model: {e}")
142
- return txt2img_pipe, None
143
-
144
- return txt2img_pipe, img2img_pipe
145
 
146
  def enhance_prompt(prompt: str, add_quality_tags: bool = True) -> str:
147
  """Enhance prompt with Pony-style tags"""
148
  if not prompt.strip():
149
  return prompt
150
 
151
- # Don't add tags if already present
152
  if prompt.startswith("score_") or not add_quality_tags:
153
  return prompt
154
 
@@ -157,37 +93,35 @@ def enhance_prompt(prompt: str, add_quality_tags: bool = True) -> str:
157
 
158
  def validate_dimensions(width: int, height: int) -> Tuple[int, int]:
159
  """Ensure dimensions are valid for SDXL"""
160
- # SDXL works best with dimensions divisible by 64
161
  width = ((width + 63) // 64) * 64
162
  height = ((height + 63) // 64) * 64
163
 
164
- # Ensure reasonable limits
165
- width = max(512, min(1536, width))
166
- height = max(512, min(1536, height))
167
 
168
  return width, height
169
 
 
170
  def generate_txt2img(prompt, negative_prompt, num_steps, guidance_scale, width, height, seed, add_quality_tags):
171
- """Generate image from text prompt with enhanced error handling"""
172
  global txt2img_pipe
173
 
174
  if not prompt.strip():
175
  return None, "Please enter a prompt"
176
-
177
- # Load models if not already loaded
178
  if txt2img_pipe is None:
179
- txt2img_pipe, _ = load_models()
180
- if txt2img_pipe is None:
181
- return None, "Failed to load Text2Img model. Please check your internet connection and try again."
182
 
183
  try:
184
- # Clear memory before generation
185
  clear_memory()
186
 
187
- # Validate and fix dimensions
188
  width, height = validate_dimensions(width, height)
189
 
190
- # Set seed for reproducibility
191
  generator = None
192
  if seed != -1:
193
  generator = torch.Generator(device=device).manual_seed(int(seed))
@@ -195,15 +129,15 @@ def generate_txt2img(prompt, negative_prompt, num_steps, guidance_scale, width,
195
  # Enhance prompt
196
  enhanced_prompt = enhance_prompt(prompt, add_quality_tags)
197
 
198
- print(f"Generating with prompt: {enhanced_prompt[:100]}...")
199
  start_time = time.time()
200
 
201
- # Generate image
202
  with torch.no_grad():
203
  result = txt2img_pipe(
204
  prompt=enhanced_prompt,
205
  negative_prompt=negative_prompt or "",
206
- num_inference_steps=int(num_steps),
207
  guidance_scale=float(guidance_scale),
208
  width=width,
209
  height=height,
@@ -211,38 +145,35 @@ def generate_txt2img(prompt, negative_prompt, num_steps, guidance_scale, width,
211
  )
212
 
213
  generation_time = time.time() - start_time
214
- status = f"Text2Img: Generated successfully in {generation_time:.1f}s (Size: {width}x{height})"
215
 
216
  return result.images[0], status
217
 
218
  except Exception as e:
219
- error_msg = f"Text2Img generation failed: {str(e)}"
220
- print(error_msg)
221
- return None, error_msg
222
  finally:
223
  clear_memory()
224
 
 
225
  def generate_img2img(input_image, prompt, negative_prompt, num_steps, guidance_scale, strength, seed, add_quality_tags):
226
- """Generate image from input image + text prompt with enhanced error handling"""
227
  global img2img_pipe
228
 
229
  if input_image is None:
230
- return None, "Please upload an input image for Img2Img"
231
 
232
  if not prompt.strip():
233
  return None, "Please enter a prompt"
234
-
235
- # Load models if not already loaded
236
  if img2img_pipe is None:
237
- _, img2img_pipe = load_models()
238
- if img2img_pipe is None:
239
- return None, "Failed to load Img2Img model. Please check your internet connection and try again."
240
 
241
  try:
242
- # Clear memory before generation
243
  clear_memory()
244
 
245
- # Set seed for reproducibility
246
  generator = None
247
  if seed != -1:
248
  generator = torch.Generator(device=device).manual_seed(int(seed))
@@ -252,230 +183,136 @@ def generate_img2img(input_image, prompt, negative_prompt, num_steps, guidance_s
252
 
253
  # Process input image
254
  if isinstance(input_image, Image.Image):
255
- # Ensure RGB format
256
  if input_image.mode != 'RGB':
257
  input_image = input_image.convert('RGB')
258
-
259
- # Resize to reasonable dimensions while maintaining aspect ratio
260
- original_size = input_image.size
261
- max_size = 1024
262
  input_image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
263
 
264
- # Ensure dimensions are divisible by 64
265
  w, h = input_image.size
266
  w, h = validate_dimensions(w, h)
267
  input_image = input_image.resize((w, h), Image.Resampling.LANCZOS)
268
 
269
- print(f"Generating with prompt: {enhanced_prompt[:100]}...")
270
  start_time = time.time()
271
 
272
- # Generate image
273
  with torch.no_grad():
274
  result = img2img_pipe(
275
  prompt=enhanced_prompt,
276
  negative_prompt=negative_prompt or "",
277
  image=input_image,
278
- num_inference_steps=int(num_steps),
279
  guidance_scale=float(guidance_scale),
280
  strength=float(strength),
281
  generator=generator
282
  )
283
 
284
  generation_time = time.time() - start_time
285
- status = f"Img2Img: Generated successfully in {generation_time:.1f}s (Strength: {strength})"
286
 
287
  return result.images[0], status
288
 
289
  except Exception as e:
290
- error_msg = f"Img2Img generation failed: {str(e)}"
291
- print(error_msg)
292
- return None, error_msg
293
  finally:
294
  clear_memory()
295
 
296
- # Default negative prompt (improved)
297
  DEFAULT_NEGATIVE = """
298
- (low quality:1.4), (worst quality:1.4), (bad quality:1.3), (normal quality:1.2), lowres, jpeg artifacts, blurry, noisy, ugly, deformed, disfigured, malformed, poorly drawn, bad art, amateur, render, 3D, cgi,
299
- (text, signature, watermark, username, copyright:1.5),
300
- (extra limbs:1.5), (missing limbs:1.5), (extra fingers:1.5), (missing fingers:1.5), (mutated hands:1.5), (bad hands:1.4), (poorly drawn hands:1.3), (ugly hands:1.2),
301
- (bad anatomy:1.4), (deformed body:1.3), (unnatural body:1.2), (cross-eyed:1.3), (skewed eyes:1.3), (imperfect eyes:1.2), (ugly eyes:1.2), (asymmetrical face:1.2), (unnatural face:1.2),
302
- (blush:1.1), (shadow on skin:1.1), (shaded skin:1.1), (dark skin:1.1),
303
- abstract, simplified, unrealistic, impressionistic, cartoon, anime, drawing, sketch, illustration, painting, censored, grayscale, monochrome, out of frame, cropped, distorted.
304
  """
305
 
306
- # Create Gradio interface with enhanced styling
307
  with gr.Blocks(
308
- title="CyberRealistic Pony Image Generator",
309
- theme=gr.themes.Soft(),
310
- css="""
311
- .gradio-container {
312
- max-width: 1200px !important;
313
- }
314
- .tab-nav button {
315
- font-size: 16px;
316
- font-weight: bold;
317
- }
318
- """
319
  ) as demo:
320
  gr.Markdown("""
321
- # 🎨 CyberRealistic Pony Image Generator (Hugging Face Edition)
322
 
323
- Generate high-quality images using the CyberRealistic Pony SDXL model from Hugging Face.
324
 
325
- **Features:**
326
- - 🎨 Text-to-Image generation
327
- - 🖼️ Image-to-Image transformation
328
- - 🎯 Automatic quality tag enhancement
329
- - ⚡ Memory optimized for better performance
330
- - 🤗 Auto-downloads model from Hugging Face
331
-
332
- **Note:** On first run, the model will be downloaded from Hugging Face (this may take a few minutes).
333
  """)
334
 
335
  with gr.Tabs():
336
- # Text2Image Tab
337
  with gr.TabItem("🎨 Text to Image"):
338
  with gr.Row():
339
- with gr.Column(scale=1):
340
- # Input controls for Text2Img
341
  txt2img_prompt = gr.Textbox(
342
  label="Prompt",
343
- placeholder="Enter your image description...",
344
- value="beautiful landscape with mountains and lake at sunset",
345
- lines=3
346
- )
347
-
348
- txt2img_negative = gr.Textbox(
349
- label="Negative Prompt",
350
- value=DEFAULT_NEGATIVE,
351
- lines=3
352
- )
353
-
354
- txt2img_quality_tags = gr.Checkbox(
355
- label="Add Quality Tags",
356
- value=True
357
  )
358
 
359
- with gr.Row():
360
- txt2img_steps = gr.Slider(
361
- minimum=10,
362
- maximum=50,
363
- value=25,
364
- step=1,
365
- label="Inference Steps"
366
  )
367
 
368
- txt2img_guidance = gr.Slider(
369
- minimum=1.0,
370
- maximum=20.0,
371
- value=7.5,
372
- step=0.5,
373
- label="Guidance Scale"
374
- )
375
-
376
- with gr.Row():
377
- txt2img_width = gr.Slider(
378
- minimum=512,
379
- maximum=1536,
380
- value=1024,
381
- step=64,
382
- label="Width"
383
  )
384
 
385
- txt2img_height = gr.Slider(
386
- minimum=512,
387
- maximum=1536,
388
- value=1024,
389
- step=64,
390
- label="Height"
391
- )
392
-
393
- txt2img_seed = gr.Number(
394
- label="Seed (-1 for random)",
395
- value=-1,
396
- precision=0
397
- )
398
-
399
- txt2img_btn = gr.Button("🎨 Generate Image", variant="primary")
400
 
401
- with gr.Column(scale=2):
402
- # Output for Text2Img
403
- txt2img_output = gr.Image(
404
- label="Generated Image",
405
- type="pil",
406
- height=600
407
- )
408
  txt2img_status = gr.Textbox(label="Status", interactive=False)
409
 
410
- # Image2Image Tab
411
  with gr.TabItem("🖼️ Image to Image"):
412
  with gr.Row():
413
- with gr.Column(scale=1):
414
- # Input controls for Img2Img
415
- img2img_input = gr.Image(
416
- label="Input Image",
417
- type="pil",
418
- height=300
419
- )
420
 
421
  img2img_prompt = gr.Textbox(
422
  label="Prompt",
423
- placeholder="Describe how to modify the image...",
424
- value="in the style of a digital painting, vibrant colors",
425
- lines=3
426
  )
427
 
428
- img2img_negative = gr.Textbox(
429
- label="Negative Prompt",
430
- value=DEFAULT_NEGATIVE,
431
- lines=3
432
- )
433
-
434
- img2img_quality_tags = gr.Checkbox(
435
- label="Add Quality Tags",
436
- value=True
437
- )
438
-
439
- with gr.Row():
440
- img2img_steps = gr.Slider(
441
- minimum=10,
442
- maximum=50,
443
- value=25,
444
- step=1,
445
- label="Inference Steps"
446
  )
447
 
448
- img2img_guidance = gr.Slider(
449
- minimum=1.0,
450
- maximum=20.0,
451
- value=7.5,
452
- step=0.5,
453
- label="Guidance Scale"
454
  )
 
 
 
 
 
 
 
 
 
 
 
455
 
456
- img2img_strength = gr.Slider(
457
- minimum=0.1,
458
- maximum=1.0,
459
- value=0.75,
460
- step=0.05,
461
- label="Denoising Strength (Lower = more like input, Higher = more creative)"
462
- )
463
-
464
- img2img_seed = gr.Number(
465
- label="Seed (-1 for random)",
466
- value=-1,
467
- precision=0
468
- )
469
-
470
- img2img_btn = gr.Button("🖼️ Transform Image", variant="primary")
471
-
472
- with gr.Column(scale=2):
473
- # Output for Img2Img
474
- img2img_output = gr.Image(
475
- label="Generated Image",
476
- type="pil",
477
- height=600
478
- )
479
  img2img_status = gr.Textbox(label="Status", interactive=False)
480
 
481
  # Event handlers
@@ -488,37 +325,12 @@ with gr.Blocks(
488
 
489
  img2img_btn.click(
490
  fn=generate_img2img,
491
- inputs=[img2img_input, img2img_prompt, img2img_negative, txt2img_steps, img2img_guidance,
492
  img2img_strength, img2img_seed, img2img_quality_tags],
493
  outputs=[img2img_output, img2img_status]
494
  )
495
 
496
- # Load models on startup
497
- print("Initializing CyberRealistic Pony Generator (Hugging Face Edition)...")
498
- print(f"Device: {device}")
499
- print(f"Model Repository: {MODEL_REPO}")
500
- print(f"Model File: {MODEL_FILENAME}")
501
-
502
- # Pre-load models in a separate thread to avoid blocking startup
503
- import threading
504
-
505
- def preload_models():
506
- """Pre-load models in background"""
507
- try:
508
- print("Starting background model loading...")
509
- load_models()
510
- print("Background model loading completed!")
511
- except Exception as e:
512
- print(f"Background model loading failed: {e}")
513
-
514
- # Start background loading
515
- loading_thread = threading.Thread(target=preload_models, daemon=True)
516
- loading_thread.start()
517
 
518
  if __name__ == "__main__":
519
- demo.launch(
520
- server_name="0.0.0.0",
521
- server_port=7860,
522
- share=False,
523
- show_error=True
524
- )
 
5
  import os
6
  import gc
7
  import time
8
+ import spaces
9
  from typing import Optional, Tuple
10
  from huggingface_hub import hf_hub_download
 
11
 
12
  # Global pipeline variables
13
  txt2img_pipe = None
 
17
  # Hugging Face model configuration
18
  MODEL_REPO = "ajsbsd/CyberRealistic-Pony"
19
  MODEL_FILENAME = "cyberrealisticPony_v110.safetensors"
 
20
 
21
  def clear_memory():
22
  """Clear GPU memory"""
 
24
  torch.cuda.empty_cache()
25
  gc.collect()
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def load_models():
28
+ """Load both text2img and img2img pipelines optimized for Spaces"""
29
  global txt2img_pipe, img2img_pipe
30
 
31
+ try:
32
+ print("Loading CyberRealistic Pony models...")
33
+
34
+ # Use Hugging Face Hub download with minimal local storage
35
+ print(f"Accessing model from {MODEL_REPO}...")
36
+
37
+ # Load Text2Img pipeline
38
+ if txt2img_pipe is None:
 
 
 
 
 
 
39
  txt2img_pipe = StableDiffusionXLPipeline.from_single_file(
40
+ f"https://huggingface.co/{MODEL_REPO}/resolve/main/{MODEL_FILENAME}",
41
  torch_dtype=torch.float16 if device == "cuda" else torch.float32,
42
  use_safetensors=True,
43
  variant="fp16" if device == "cuda" else None
44
  )
45
 
46
+ # Aggressive memory optimizations for Spaces
47
  txt2img_pipe.enable_attention_slicing()
48
+ txt2img_pipe.enable_vae_slicing()
49
 
50
  if device == "cuda":
51
+ txt2img_pipe.enable_model_cpu_offload()
52
+ txt2img_pipe.enable_sequential_cpu_offload()
 
 
 
 
53
  else:
54
  txt2img_pipe = txt2img_pipe.to(device)
55
+
56
+ # Share components for Img2Img to save memory
57
+ if img2img_pipe is None:
58
+ img2img_pipe = StableDiffusionXLImg2ImgPipeline(
59
+ vae=txt2img_pipe.vae,
60
+ text_encoder=txt2img_pipe.text_encoder,
61
+ text_encoder_2=txt2img_pipe.text_encoder_2,
62
+ tokenizer=txt2img_pipe.tokenizer,
63
+ tokenizer_2=txt2img_pipe.tokenizer_2,
64
+ unet=txt2img_pipe.unet,
65
+ scheduler=txt2img_pipe.scheduler,
 
 
 
 
66
  )
67
 
68
+ # Same optimizations
69
  img2img_pipe.enable_attention_slicing()
70
+ img2img_pipe.enable_vae_slicing()
71
 
72
  if device == "cuda":
73
+ img2img_pipe.enable_model_cpu_offload()
74
+ img2img_pipe.enable_sequential_cpu_offload()
75
+
76
+ print("Models loaded successfully!")
77
+ return True
78
+
79
+ except Exception as e:
80
+ print(f"Error loading models: {e}")
81
+ return False
 
 
 
 
 
 
 
82
 
83
  def enhance_prompt(prompt: str, add_quality_tags: bool = True) -> str:
84
  """Enhance prompt with Pony-style tags"""
85
  if not prompt.strip():
86
  return prompt
87
 
 
88
  if prompt.startswith("score_") or not add_quality_tags:
89
  return prompt
90
 
 
93
 
94
  def validate_dimensions(width: int, height: int) -> Tuple[int, int]:
95
  """Ensure dimensions are valid for SDXL"""
 
96
  width = ((width + 63) // 64) * 64
97
  height = ((height + 63) // 64) * 64
98
 
99
+ # More conservative limits for Spaces
100
+ width = max(512, min(1024, width))
101
+ height = max(512, min(1024, height))
102
 
103
  return width, height
104
 
105
+ @spaces.GPU(duration=60) # GPU decorator for Spaces
106
  def generate_txt2img(prompt, negative_prompt, num_steps, guidance_scale, width, height, seed, add_quality_tags):
107
+ """Generate image from text prompt with Spaces GPU support"""
108
  global txt2img_pipe
109
 
110
  if not prompt.strip():
111
  return None, "Please enter a prompt"
112
+
113
+ # Lazy load models
114
  if txt2img_pipe is None:
115
+ if not load_models():
116
+ return None, "Failed to load models. Please try again."
 
117
 
118
  try:
 
119
  clear_memory()
120
 
121
+ # Validate dimensions
122
  width, height = validate_dimensions(width, height)
123
 
124
+ # Set seed
125
  generator = None
126
  if seed != -1:
127
  generator = torch.Generator(device=device).manual_seed(int(seed))
 
129
  # Enhance prompt
130
  enhanced_prompt = enhance_prompt(prompt, add_quality_tags)
131
 
132
+ print(f"Generating: {enhanced_prompt[:100]}...")
133
  start_time = time.time()
134
 
135
+ # Generate with lower memory usage
136
  with torch.no_grad():
137
  result = txt2img_pipe(
138
  prompt=enhanced_prompt,
139
  negative_prompt=negative_prompt or "",
140
+ num_inference_steps=min(int(num_steps), 30), # Limit steps for Spaces
141
  guidance_scale=float(guidance_scale),
142
  width=width,
143
  height=height,
 
145
  )
146
 
147
  generation_time = time.time() - start_time
148
+ status = f"Generated in {generation_time:.1f}s ({width}x{height})"
149
 
150
  return result.images[0], status
151
 
152
  except Exception as e:
153
+ return None, f"Generation failed: {str(e)}"
 
 
154
  finally:
155
  clear_memory()
156
 
157
+ @spaces.GPU(duration=60) # GPU decorator for Spaces
158
  def generate_img2img(input_image, prompt, negative_prompt, num_steps, guidance_scale, strength, seed, add_quality_tags):
159
+ """Generate image from input image + text prompt with Spaces GPU support"""
160
  global img2img_pipe
161
 
162
  if input_image is None:
163
+ return None, "Please upload an input image"
164
 
165
  if not prompt.strip():
166
  return None, "Please enter a prompt"
167
+
168
+ # Lazy load models
169
  if img2img_pipe is None:
170
+ if not load_models():
171
+ return None, "Failed to load models. Please try again."
 
172
 
173
  try:
 
174
  clear_memory()
175
 
176
+ # Set seed
177
  generator = None
178
  if seed != -1:
179
  generator = torch.Generator(device=device).manual_seed(int(seed))
 
183
 
184
  # Process input image
185
  if isinstance(input_image, Image.Image):
 
186
  if input_image.mode != 'RGB':
187
  input_image = input_image.convert('RGB')
188
+
189
+ # Conservative resize for Spaces
190
+ max_size = 768
 
191
  input_image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
192
 
 
193
  w, h = input_image.size
194
  w, h = validate_dimensions(w, h)
195
  input_image = input_image.resize((w, h), Image.Resampling.LANCZOS)
196
 
197
+ print(f"Transforming: {enhanced_prompt[:100]}...")
198
  start_time = time.time()
199
 
 
200
  with torch.no_grad():
201
  result = img2img_pipe(
202
  prompt=enhanced_prompt,
203
  negative_prompt=negative_prompt or "",
204
  image=input_image,
205
+ num_inference_steps=min(int(num_steps), 30), # Limit steps
206
  guidance_scale=float(guidance_scale),
207
  strength=float(strength),
208
  generator=generator
209
  )
210
 
211
  generation_time = time.time() - start_time
212
+ status = f"Transformed in {generation_time:.1f}s (Strength: {strength})"
213
 
214
  return result.images[0], status
215
 
216
  except Exception as e:
217
+ return None, f"Transformation failed: {str(e)}"
 
 
218
  finally:
219
  clear_memory()
220
 
221
+ # Simplified negative prompt for better performance
222
  DEFAULT_NEGATIVE = """
223
+ (low quality:1.3), (worst quality:1.3), (bad quality:1.2), blurry, noisy, ugly, deformed,
224
+ (text, watermark:1.4), (extra limbs:1.3), (bad hands:1.3), (bad anatomy:1.2)
 
 
 
 
225
  """
226
 
227
+ # Gradio interface optimized for Spaces
228
  with gr.Blocks(
229
+ title="CyberRealistic Pony Generator",
230
+ theme=gr.themes.Soft()
 
 
 
 
 
 
 
 
 
231
  ) as demo:
232
  gr.Markdown("""
233
+ # 🎨 CyberRealistic Pony Image Generator
234
 
235
+ Generate high-quality images using the CyberRealistic Pony SDXL model.
236
 
237
+ ⚠️ **Note**: First generation may take longer as the model loads. GPU time is limited on Spaces.
 
 
 
 
 
 
 
238
  """)
239
 
240
  with gr.Tabs():
 
241
  with gr.TabItem("🎨 Text to Image"):
242
  with gr.Row():
243
+ with gr.Column():
 
244
  txt2img_prompt = gr.Textbox(
245
  label="Prompt",
246
+ placeholder="beautiful landscape, mountains, sunset",
247
+ lines=2
 
 
 
 
 
 
 
 
 
 
 
 
248
  )
249
 
250
+ with gr.Accordion("Advanced Settings", open=False):
251
+ txt2img_negative = gr.Textbox(
252
+ label="Negative Prompt",
253
+ value=DEFAULT_NEGATIVE,
254
+ lines=2
 
 
255
  )
256
 
257
+ txt2img_quality_tags = gr.Checkbox(
258
+ label="Add Quality Tags",
259
+ value=True
 
 
 
 
 
 
 
 
 
 
 
 
260
  )
261
 
262
+ with gr.Row():
263
+ txt2img_steps = gr.Slider(10, 30, 20, step=1, label="Steps")
264
+ txt2img_guidance = gr.Slider(1.0, 15.0, 7.5, step=0.5, label="Guidance")
265
+
266
+ with gr.Row():
267
+ txt2img_width = gr.Slider(512, 1024, 768, step=64, label="Width")
268
+ txt2img_height = gr.Slider(512, 1024, 768, step=64, label="Height")
269
+
270
+ txt2img_seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
 
 
 
 
 
 
271
 
272
+ txt2img_btn = gr.Button("🎨 Generate", variant="primary", size="lg")
273
+
274
+ with gr.Column():
275
+ txt2img_output = gr.Image(label="Generated Image", height=400)
 
 
 
276
  txt2img_status = gr.Textbox(label="Status", interactive=False)
277
 
 
278
  with gr.TabItem("🖼️ Image to Image"):
279
  with gr.Row():
280
+ with gr.Column():
281
+ img2img_input = gr.Image(label="Input Image", type="pil", height=250)
 
 
 
 
 
282
 
283
  img2img_prompt = gr.Textbox(
284
  label="Prompt",
285
+ placeholder="digital painting style, vibrant colors",
286
+ lines=2
 
287
  )
288
 
289
+ with gr.Accordion("Advanced Settings", open=False):
290
+ img2img_negative = gr.Textbox(
291
+ label="Negative Prompt",
292
+ value=DEFAULT_NEGATIVE,
293
+ lines=2
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  )
295
 
296
+ img2img_quality_tags = gr.Checkbox(
297
+ label="Add Quality Tags",
298
+ value=True
 
 
 
299
  )
300
+
301
+ with gr.Row():
302
+ img2img_steps = gr.Slider(10, 30, 20, step=1, label="Steps")
303
+ img2img_guidance = gr.Slider(1.0, 15.0, 7.5, step=0.5, label="Guidance")
304
+
305
+ img2img_strength = gr.Slider(
306
+ 0.1, 1.0, 0.75, step=0.05,
307
+ label="Strength (Higher = more creative)"
308
+ )
309
+
310
+ img2img_seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
311
 
312
+ img2img_btn = gr.Button("🖼️ Transform", variant="primary", size="lg")
313
+
314
+ with gr.Column():
315
+ img2img_output = gr.Image(label="Generated Image", height=400)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
  img2img_status = gr.Textbox(label="Status", interactive=False)
317
 
318
  # Event handlers
 
325
 
326
  img2img_btn.click(
327
  fn=generate_img2img,
328
+ inputs=[img2img_input, img2img_prompt, img2img_negative, img2img_steps, img2img_guidance,
329
  img2img_strength, img2img_seed, img2img_quality_tags],
330
  outputs=[img2img_output, img2img_status]
331
  )
332
 
333
+ print(f"🚀 CyberRealistic Pony Generator initialized on {device}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
 
335
  if __name__ == "__main__":
336
+ demo.launch()