ajsbsd committed on
Commit
1ebd84a
·
verified ·
1 Parent(s): 1ed8393

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +524 -0
app.py ADDED
@@ -0,0 +1,524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
4
+ from PIL import Image
5
+ import os
6
+ import gc
7
+ import time
8
+ from typing import Optional, Tuple
9
+ from huggingface_hub import hf_hub_download
10
+ import requests
11
+
12
# Pipeline singletons; they stay None until load_models() assigns them.
txt2img_pipe = None
img2img_pipe = None

# Run on the GPU whenever torch reports a usable CUDA device, else on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Location of the checkpoint on the Hugging Face Hub and on local disk
MODEL_REPO = "ajsbsd/CyberRealistic-Pony"
MODEL_FILENAME = "cyberrealisticPony_v110.safetensors"
LOCAL_MODEL_PATH = "./models/" + MODEL_FILENAME
21
+
22
def clear_memory():
    """Free unreferenced Python objects and release cached GPU memory.

    The garbage collector runs first so that tensors kept alive only by
    reference cycles are destroyed before the CUDA caching allocator is
    asked to give its unused blocks back; emptying the cache first would
    leave those blocks marked as in use.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
27
+
28
def download_model():
    """Return a local path to the model checkpoint, downloading it if needed.

    The checkpoint is looked up at LOCAL_MODEL_PATH first. When it is absent
    it is downloaded from MODEL_REPO into ./models. If anything in that first
    attempt raises, a second hf_hub_download call without a local_dir is
    tried.

    Returns:
        The path of the checkpoint file, or None when both attempts fail.
    """
    models_dir = "./models"

    try:
        # Make sure the target directory is there before probing or writing
        os.makedirs(models_dir, exist_ok=True)

        if not os.path.exists(LOCAL_MODEL_PATH):
            print(f"Downloading model from {MODEL_REPO}/{MODEL_FILENAME}...")
            print("This may take a while on first run...")

            fetched_path = hf_hub_download(
                repo_id=MODEL_REPO,
                filename=MODEL_FILENAME,
                local_dir=models_dir,
                local_dir_use_symlinks=False,
                resume_download=True,
            )

            print(f"Model downloaded successfully to {fetched_path}")
            return fetched_path

        # A previous run already left the file in place
        print(f"Model already exists at {LOCAL_MODEL_PATH}")
        return LOCAL_MODEL_PATH

    except Exception as e:
        print(f"Error downloading model: {e}")
        print("Attempting to use cached version or fallback...")

    # Second attempt: let huggingface_hub resolve the file without local_dir
    try:
        cached_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename=MODEL_FILENAME,
            resume_download=True,
        )
        print(f"Using cached model at {cached_path}")
        return cached_path
    except Exception as cache_error:
        print(f"Cache fallback failed: {cache_error}")
        return None
70
+
71
import threading

# Serialises load_models(): the startup preload thread and the Gradio request
# handlers can otherwise both observe a missing pipeline and load it twice.
_MODEL_LOAD_LOCK = threading.Lock()


def _load_pipeline(pipeline_cls, model_path, label):
    """Build one pipeline from the single-file checkpoint and place it on `device`.

    Args:
        pipeline_cls: Pipeline class providing `from_single_file`.
        model_path: Path of the checkpoint file.
        label: Name used in log messages ("Text2Img" or "Img2Img").

    Returns:
        The configured pipeline.

    Raises:
        Exception: whatever loading or device placement raises; a failure of
            CPU offload alone is logged and falls back to `.to(device)`.
    """
    print(f"Loading CyberRealistic Pony {label} model...")
    on_gpu = device == "cuda"
    pipe = pipeline_cls.from_single_file(
        model_path,
        torch_dtype=torch.float16 if on_gpu else torch.float32,
        use_safetensors=True,
        variant="fp16" if on_gpu else None,
    )

    # Memory optimizations
    pipe.enable_attention_slicing()

    if on_gpu:
        try:
            pipe.enable_model_cpu_offload()
            print(f"{label} CPU offload enabled")
        except Exception as e:
            print(f"{label} CPU offload failed: {e}")
            pipe = pipe.to(device)
    else:
        pipe = pipe.to(device)

    print(f"{label} model loaded successfully!")
    return pipe


def load_models():
    """Load both text2img and img2img pipelines with Hugging Face integration.

    Pipelines that are already loaded are reused. Each global is assigned
    only after its pipeline has been fully configured, so a failure partway
    through never leaves a half-initialised pipeline behind.

    Returns:
        A `(txt2img_pipe, img2img_pipe)` tuple. `(None, None)` when the
        checkpoint cannot be located or the Text2Img pipeline fails to load;
        `(txt2img_pipe, None)` when only the Img2Img pipeline fails.
    """
    global txt2img_pipe, img2img_pipe

    with _MODEL_LOAD_LOCK:
        # A caller that waited on the lock may find the work already done.
        if txt2img_pipe is not None and img2img_pipe is not None:
            return txt2img_pipe, img2img_pipe

        # Download model if needed
        model_path = download_model()

        if model_path is None:
            print("Failed to download or locate model file")
            return None, None

        if not os.path.exists(model_path):
            print(f"Model file not found after download: {model_path}")
            return None, None

        if txt2img_pipe is None:
            try:
                txt2img_pipe = _load_pipeline(
                    StableDiffusionXLPipeline, model_path, "Text2Img"
                )
            except Exception as e:
                print(f"Error loading Text2Img model: {e}")
                return None, None

        if img2img_pipe is None:
            try:
                img2img_pipe = _load_pipeline(
                    StableDiffusionXLImg2ImgPipeline, model_path, "Img2Img"
                )
            except Exception as e:
                print(f"Error loading Img2Img model: {e}")
                return txt2img_pipe, None

        return txt2img_pipe, img2img_pipe
145
+
146
def enhance_prompt(prompt: str, add_quality_tags: bool = True) -> str:
    """Prepend the Pony-style score/quality tags to a prompt.

    Args:
        prompt: The user's prompt text.
        add_quality_tags: When False the prompt is returned untouched.

    Returns:
        The prompt prefixed with the quality tags, or the prompt unchanged
        when it is empty/whitespace-only/None, when tagging is disabled, or
        when it already begins with a ``score_`` tag.
    """
    # Nothing to enhance for a missing or blank prompt.
    if not prompt or not prompt.strip():
        return prompt

    # Leading whitespace is ignored when looking for existing tags, so that
    # " score_9, ..." is not tagged a second time.
    if not add_quality_tags or prompt.lstrip().startswith("score_"):
        return prompt

    quality_tags = "score_9, score_8_up, score_7_up, masterpiece, best quality, highly detailed"
    return f"{quality_tags}, {prompt}"
157
+
158
def validate_dimensions(width: int, height: int) -> Tuple[int, int]:
    """Snap a requested size to values this app accepts for SDXL.

    Each side is rounded up to the next multiple of 64 and then clamped to
    the range 512..1536 (both bounds are themselves multiples of 64).

    Args:
        width: Requested width in pixels; floats are truncated to int first.
        height: Requested height in pixels; floats are truncated to int first.

    Returns:
        A ``(width, height)`` tuple of plain ints.
    """
    def _snap(value) -> int:
        # int() keeps the result an integer even when a UI slider hands over
        # a float such as 1024.0.
        rounded_up = ((int(value) + 63) // 64) * 64
        return max(512, min(1536, rounded_up))

    return _snap(width), _snap(height)
169
+
170
def generate_txt2img(prompt, negative_prompt, num_steps, guidance_scale, width, height, seed, add_quality_tags):
    """Generate an image from a text prompt.

    Args:
        prompt: Text description of the image; must not be blank.
        negative_prompt: Text to steer away from; None/empty becomes "".
        num_steps: Number of inference steps (converted to int).
        guidance_scale: Guidance scale (converted to float).
        width: Requested width; adjusted by validate_dimensions().
        height: Requested height; adjusted by validate_dimensions().
        seed: Integer seed; -1 or None (an emptied seed field) means random.
        add_quality_tags: Whether enhance_prompt() should add quality tags.

    Returns:
        ``(image, status)`` where image is the first generated image, or
        None together with an explanatory message on failure.
    """
    global txt2img_pipe

    if not prompt or not prompt.strip():
        return None, "Please enter a prompt"

    # Load models if not already loaded
    if txt2img_pipe is None:
        txt2img_pipe, _ = load_models()
        if txt2img_pipe is None:
            return None, "Failed to load Text2Img model. Please check your internet connection and try again."

    try:
        # Clear memory before generation
        clear_memory()

        # Validate and fix dimensions
        width, height = validate_dimensions(width, height)

        # Seed the generator only when a concrete seed was supplied; a
        # missing seed is treated like -1 instead of failing in int(None).
        generator = None
        if seed is not None and seed != -1:
            generator = torch.Generator(device=device).manual_seed(int(seed))

        # Enhance prompt
        enhanced_prompt = enhance_prompt(prompt, add_quality_tags)

        print(f"Generating with prompt: {enhanced_prompt[:100]}...")
        start_time = time.time()

        # Generate image
        with torch.no_grad():
            result = txt2img_pipe(
                prompt=enhanced_prompt,
                negative_prompt=negative_prompt or "",
                num_inference_steps=int(num_steps),
                guidance_scale=float(guidance_scale),
                width=width,
                height=height,
                generator=generator,
            )

        generation_time = time.time() - start_time
        status = f"Text2Img: Generated successfully in {generation_time:.1f}s (Size: {width}x{height})"

        return result.images[0], status

    except Exception as e:
        # Report the failure in the status box rather than raising to the UI
        error_msg = f"Text2Img generation failed: {str(e)}"
        print(error_msg)
        return None, error_msg
    finally:
        clear_memory()
224
+
225
def generate_img2img(input_image, prompt, negative_prompt, num_steps, guidance_scale, strength, seed, add_quality_tags):
    """Generate an image from an input image plus a text prompt.

    Args:
        input_image: Source image; required.
        prompt: Text describing the desired result; must not be blank.
        negative_prompt: Text to steer away from; None/empty becomes "".
        num_steps: Number of inference steps (converted to int).
        guidance_scale: Guidance scale (converted to float).
        strength: Denoising strength (converted to float).
        seed: Integer seed; -1 or None (an emptied seed field) means random.
        add_quality_tags: Whether enhance_prompt() should add quality tags.

    Returns:
        ``(image, status)`` where image is the first generated image, or
        None together with an explanatory message on failure.
    """
    global img2img_pipe

    if input_image is None:
        return None, "Please upload an input image for Img2Img"

    if not prompt or not prompt.strip():
        return None, "Please enter a prompt"

    # Load models if not already loaded
    if img2img_pipe is None:
        _, img2img_pipe = load_models()
        if img2img_pipe is None:
            return None, "Failed to load Img2Img model. Please check your internet connection and try again."

    try:
        # Clear memory before generation
        clear_memory()

        # Seed the generator only when a concrete seed was supplied; a
        # missing seed is treated like -1 instead of failing in int(None).
        generator = None
        if seed is not None and seed != -1:
            generator = torch.Generator(device=device).manual_seed(int(seed))

        # Enhance prompt
        enhanced_prompt = enhance_prompt(prompt, add_quality_tags)

        # Process input image
        if isinstance(input_image, Image.Image):
            # convert() returns a converted copy, so the in-place thumbnail()
            # below never modifies the image object handed in by the caller.
            input_image = input_image.convert('RGB')

            # Shrink to at most 1024 px per side while keeping aspect ratio
            max_size = 1024
            input_image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

            # Snap to dimensions accepted by validate_dimensions()
            w, h = validate_dimensions(*input_image.size)
            input_image = input_image.resize((w, h), Image.Resampling.LANCZOS)

        print(f"Generating with prompt: {enhanced_prompt[:100]}...")
        start_time = time.time()

        # Generate image
        with torch.no_grad():
            result = img2img_pipe(
                prompt=enhanced_prompt,
                negative_prompt=negative_prompt or "",
                image=input_image,
                num_inference_steps=int(num_steps),
                guidance_scale=float(guidance_scale),
                strength=float(strength),
                generator=generator,
            )

        generation_time = time.time() - start_time
        status = f"Img2Img: Generated successfully in {generation_time:.1f}s (Strength: {strength})"

        return result.images[0], status

    except Exception as e:
        # Report the failure in the status box rather than raising to the UI
        error_msg = f"Img2Img generation failed: {str(e)}"
        print(error_msg)
        return None, error_msg
    finally:
        clear_memory()
295
+
296
# Default negative prompt (improved)
# Used as the initial value of the "Negative Prompt" textbox on both tabs.
# The literal begins and ends with a newline because of the triple-quote layout.
# NOTE(review): the "(term:1.4)" emphasis notation is passed to the pipeline
# verbatim; confirm the pipeline in use actually interprets these weights.
DEFAULT_NEGATIVE = """
(low quality:1.4), (worst quality:1.4), (bad quality:1.3), (normal quality:1.2), lowres, jpeg artifacts, blurry, noisy, ugly, deformed, disfigured, malformed, poorly drawn, bad art, amateur, render, 3D, cgi,
(text, signature, watermark, username, copyright:1.5),
(extra limbs:1.5), (missing limbs:1.5), (extra fingers:1.5), (missing fingers:1.5), (mutated hands:1.5), (bad hands:1.4), (poorly drawn hands:1.3), (ugly hands:1.2),
(bad anatomy:1.4), (deformed body:1.3), (unnatural body:1.2), (cross-eyed:1.3), (skewed eyes:1.3), (imperfect eyes:1.2), (ugly eyes:1.2), (asymmetrical face:1.2), (unnatural face:1.2),
(blush:1.1), (shadow on skin:1.1), (shaded skin:1.1), (dark skin:1.1),
abstract, simplified, unrealistic, impressionistic, cartoon, anime, drawing, sketch, illustration, painting, censored, grayscale, monochrome, out of frame, cropped, distorted.
"""
305
+
306
# Create Gradio interface with enhanced styling.
# Two tabs (Text to Image / Image to Image) share the same layout: input
# controls in the left column, generated image and status in the right one.
with gr.Blocks(
    title="CyberRealistic Pony Image Generator",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1200px !important;
    }
    .tab-nav button {
        font-size: 16px;
        font-weight: bold;
    }
    """
) as demo:
    gr.Markdown("""
    # 🎨 CyberRealistic Pony Image Generator (Hugging Face Edition)

    Generate high-quality images using the CyberRealistic Pony SDXL model from Hugging Face.

    **Features:**
    - 🎨 Text-to-Image generation
    - 🖼️ Image-to-Image transformation
    - 🎯 Automatic quality tag enhancement
    - ⚡ Memory optimized for better performance
    - 🤗 Auto-downloads model from Hugging Face

    **Note:** On first run, the model will be downloaded from Hugging Face (this may take a few minutes).
    """)

    with gr.Tabs():
        # Text2Image Tab
        with gr.TabItem("🎨 Text to Image"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Input controls for Text2Img
                    txt2img_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Enter your image description...",
                        value="beautiful landscape with mountains and lake at sunset",
                        lines=3
                    )

                    txt2img_negative = gr.Textbox(
                        label="Negative Prompt",
                        value=DEFAULT_NEGATIVE,
                        lines=3
                    )

                    txt2img_quality_tags = gr.Checkbox(
                        label="Add Quality Tags",
                        value=True
                    )

                    with gr.Row():
                        txt2img_steps = gr.Slider(
                            minimum=10,
                            maximum=50,
                            value=25,
                            step=1,
                            label="Inference Steps"
                        )

                        txt2img_guidance = gr.Slider(
                            minimum=1.0,
                            maximum=20.0,
                            value=7.5,
                            step=0.5,
                            label="Guidance Scale"
                        )

                    with gr.Row():
                        txt2img_width = gr.Slider(
                            minimum=512,
                            maximum=1536,
                            value=1024,
                            step=64,
                            label="Width"
                        )

                        txt2img_height = gr.Slider(
                            minimum=512,
                            maximum=1536,
                            value=1024,
                            step=64,
                            label="Height"
                        )

                    txt2img_seed = gr.Number(
                        label="Seed (-1 for random)",
                        value=-1,
                        precision=0
                    )

                    txt2img_btn = gr.Button("🎨 Generate Image", variant="primary")

                with gr.Column(scale=2):
                    # Output for Text2Img
                    txt2img_output = gr.Image(
                        label="Generated Image",
                        type="pil",
                        height=600
                    )
                    txt2img_status = gr.Textbox(label="Status", interactive=False)

        # Image2Image Tab
        with gr.TabItem("🖼️ Image to Image"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Input controls for Img2Img
                    img2img_input = gr.Image(
                        label="Input Image",
                        type="pil",
                        height=300
                    )

                    img2img_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Describe how to modify the image...",
                        value="in the style of a digital painting, vibrant colors",
                        lines=3
                    )

                    img2img_negative = gr.Textbox(
                        label="Negative Prompt",
                        value=DEFAULT_NEGATIVE,
                        lines=3
                    )

                    img2img_quality_tags = gr.Checkbox(
                        label="Add Quality Tags",
                        value=True
                    )

                    with gr.Row():
                        img2img_steps = gr.Slider(
                            minimum=10,
                            maximum=50,
                            value=25,
                            step=1,
                            label="Inference Steps"
                        )

                        img2img_guidance = gr.Slider(
                            minimum=1.0,
                            maximum=20.0,
                            value=7.5,
                            step=0.5,
                            label="Guidance Scale"
                        )

                    img2img_strength = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.75,
                        step=0.05,
                        label="Denoising Strength (Lower = more like input, Higher = more creative)"
                    )

                    img2img_seed = gr.Number(
                        label="Seed (-1 for random)",
                        value=-1,
                        precision=0
                    )

                    img2img_btn = gr.Button("🖼️ Transform Image", variant="primary")

                with gr.Column(scale=2):
                    # Output for Img2Img
                    img2img_output = gr.Image(
                        label="Generated Image",
                        type="pil",
                        height=600
                    )
                    img2img_status = gr.Textbox(label="Status", interactive=False)

    # Event handlers.
    # The order of `inputs` must match the handler's positional parameters.
    txt2img_btn.click(
        fn=generate_txt2img,
        inputs=[txt2img_prompt, txt2img_negative, txt2img_steps, txt2img_guidance,
                txt2img_width, txt2img_height, txt2img_seed, txt2img_quality_tags],
        outputs=[txt2img_output, txt2img_status]
    )

    # Uses the Img2Img tab's own steps slider (img2img_steps), so the value
    # the user sets on this tab is the one that reaches generate_img2img.
    img2img_btn.click(
        fn=generate_img2img,
        inputs=[img2img_input, img2img_prompt, img2img_negative, img2img_steps, img2img_guidance,
                img2img_strength, img2img_seed, img2img_quality_tags],
        outputs=[img2img_output, img2img_status]
    )
495
+
496
# Load models on startup: announce the configuration first
print(
    "\n".join(
        [
            "Initializing CyberRealistic Pony Generator (Hugging Face Edition)...",
            f"Device: {device}",
            f"Model Repository: {MODEL_REPO}",
            f"Model File: {MODEL_FILENAME}",
        ]
    )
)

# Loading happens on a separate thread so that startup is not blocked
import threading


def preload_models():
    """Call load_models() and log the outcome; exceptions are logged, not raised."""
    try:
        print("Starting background model loading...")
        load_models()
        print("Background model loading completed!")
    except Exception as preload_error:
        print(f"Background model loading failed: {preload_error}")


# Daemon thread: it does not keep the process alive on its own
loading_thread = threading.Thread(daemon=True, target=preload_models)
loading_thread.start()
517
+
518
if __name__ == "__main__":
    # Serve on all interfaces at port 7860, without a share link, and with
    # show_error enabled.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    demo.launch(**launch_options)