ajsbsd commited on
Commit
fd21bc1
Β·
verified Β·
1 Parent(s): 6ea57eb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +746 -2
app.py CHANGED
@@ -314,7 +314,7 @@ def generate_txt2img(prompt: str, negative_prompt: str, steps: int, guidance_sca
314
  return None, None, "❌ Please enter a prompt"
315
 
316
  # Lazy load models
317
- if not pipe_manager._load_models(): # <--- Change from load_models() to _load_models()
318
  return None, None, "❌ Failed to load model. Please try again."
319
 
320
  try:
@@ -416,7 +416,7 @@ def generate_img2img(input_image: Image.Image, prompt: str, negative_prompt: str
416
  if not prompt.strip():
417
  return None, None, "❌ Please enter a prompt"
418
 
419
- if not pipe_manager._load_models(): # <--- Change from load_models() to _load_models()
420
  return None, None, "❌ Failed to load model. Please try again."
421
 
422
  try:
@@ -717,8 +717,752 @@ def create_interface():
717
  img_strength, img_seed, img_quality],
718
  outputs=[img_output_image, img_download_file, img_info],
719
  show_progress=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
720
  )
721
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
722
  # Example prompt buttons
723
  txt_example_btn.click(fn=get_random_prompt, outputs=[txt_prompt])
724
  img_example_btn.click(fn=get_random_prompt, outputs=[img_prompt])
 
314
  return None, None, "❌ Please enter a prompt"
315
 
316
  # Lazy load models
317
+ if not pipe_manager.load_models(): # <--- Change from load_models() to _load_models()
318
  return None, None, "❌ Failed to load model. Please try again."
319
 
320
  try:
 
416
  if not prompt.strip():
417
  return None, None, "❌ Please enter a prompt"
418
 
419
+ if not pipe_manager.load_models(): # <--- Change from load_models() to _load_models()
420
  return None, None, "❌ Failed to load model. Please try again."
421
 
422
  try:
 
717
  img_strength, img_seed, img_quality],
718
  outputs=[img_output_image, img_download_file, img_info],
719
  show_progress=True
720
+ )import gradio as gr
721
+ import torch
722
+ from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, EulerAncestralDiscreteScheduler
723
+ from PIL import Image, PngImagePlugin, ImageFilter
724
+ from datetime import datetime
725
+ import os
726
+ import gc
727
+ import time
728
+ import spaces
729
+ from typing import Optional, Tuple, Dict, Any
730
+ from huggingface_hub import hf_hub_download
731
+ import tempfile
732
+ import random
733
+ import logging
734
+ import torch.nn.functional as F
735
+ from transformers import CLIPProcessor, CLIPModel
736
+
737
+ # Configure logging
738
+ logging.basicConfig(level=logging.INFO)
739
+ logger = logging.getLogger(__name__)
740
+
741
+ # Constants
742
+ MODEL_REPO = "ajsbsd/CyberRealistic-Pony"
743
+ MODEL_FILENAME = "cyberrealisticPony_v110.safetensors"
744
+ NSFW_MODEL_ID = "openai/clip-vit-base-patch32" # CLIP model for NSFW detection
745
+ MAX_SEED = 2**32 - 1
746
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
747
+ DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
748
+ NSFW_THRESHOLD = 0.25 # Threshold for NSFW detection
749
+
750
+ # Global pipeline state
751
+ class PipelineManager:
752
+ def __init__(self):
753
+ self.txt2img_pipe = None
754
+ self.img2img_pipe = None
755
+ self.nsfw_detector_model = None
756
+ self.nsfw_detector_processor = None
757
+ self.model_loaded = False
758
+ self.nsfw_detector_loaded = False
759
+
760
+ def clear_memory(self):
761
+ """Aggressive memory cleanup"""
762
+ if torch.cuda.is_available():
763
+ torch.cuda.empty_cache()
764
+ torch.cuda.synchronize()
765
+ gc.collect()
766
+
767
+ def load_nsfw_detector(self) -> bool:
768
+ """Load NSFW detection model"""
769
+ if self.nsfw_detector_loaded:
770
+ return True
771
+
772
+ try:
773
+ logger.info("Loading NSFW detector...")
774
+ self.nsfw_detector_processor = CLIPProcessor.from_pretrained(NSFW_MODEL_ID)
775
+ self.nsfw_detector_model = CLIPModel.from_pretrained(NSFW_MODEL_ID)
776
+
777
+ if DEVICE == "cuda":
778
+ self.nsfw_detector_model = self.nsfw_detector_model.to(DEVICE)
779
+
780
+ self.nsfw_detector_loaded = True
781
+ logger.info("NSFW detector loaded successfully!")
782
+ return True
783
+
784
+ except Exception as e:
785
+ logger.error(f"Failed to load NSFW detector: {e}")
786
+ self.nsfw_detector_loaded = False
787
+ return False
788
+
789
+ def is_nsfw(self, image: Image.Image, prompt: str = "") -> Tuple[bool, float]:
790
+ """
791
+ Detects NSFW content using CLIP-based zero-shot classification.
792
+ Falls back to prompt-based detection if CLIP model fails.
793
+ """
794
+ try:
795
+ # Load NSFW detector if not already loaded
796
+ if not self.nsfw_detector_loaded:
797
+ if not self.load_nsfw_detector():
798
+ return self._fallback_nsfw_detection(prompt)
799
+
800
+ # CLIP-based NSFW detection
801
+ inputs = self.nsfw_detector_processor(images=image, return_tensors="pt").to(DEVICE)
802
+
803
+ with torch.no_grad():
804
+ image_features = self.nsfw_detector_model.get_image_features(**inputs)
805
+
806
+ # Define text prompts for classification
807
+ safe_prompts = [
808
+ "a safe family-friendly image",
809
+ "a general photo",
810
+ "appropriate content",
811
+ "artistic photography"
812
+ ]
813
+ unsafe_prompts = [
814
+ "explicit adult content",
815
+ "nudity",
816
+ "inappropriate sexual content",
817
+ "pornographic material"
818
+ ]
819
+
820
+ # Get text features
821
+ safe_inputs = self.nsfw_detector_processor(
822
+ text=safe_prompts, return_tensors="pt", padding=True
823
+ ).to(DEVICE)
824
+ unsafe_inputs = self.nsfw_detector_processor(
825
+ text=unsafe_prompts, return_tensors="pt", padding=True
826
+ ).to(DEVICE)
827
+
828
+ safe_features = self.nsfw_detector_model.get_text_features(**safe_inputs)
829
+ unsafe_features = self.nsfw_detector_model.get_text_features(**unsafe_inputs)
830
+
831
+ # Normalize features for cosine similarity
832
+ image_features = F.normalize(image_features, p=2, dim=-1)
833
+ safe_features = F.normalize(safe_features, p=2, dim=-1)
834
+ unsafe_features = F.normalize(unsafe_features, p=2, dim=-1)
835
+
836
+ # Calculate similarities
837
+ safe_similarity = (image_features @ safe_features.T).mean().item()
838
+ unsafe_similarity = (image_features @ unsafe_features.T).mean().item()
839
+
840
+ # Classification logic
841
+ is_nsfw_result = (
842
+ unsafe_similarity > safe_similarity and
843
+ unsafe_similarity > NSFW_THRESHOLD
844
+ )
845
+
846
+ confidence = unsafe_similarity if is_nsfw_result else safe_similarity
847
+
848
+ if is_nsfw_result:
849
+ logger.warning(f"🚨 NSFW content detected (CLIP-based: {unsafe_similarity:.3f} > {safe_similarity:.3f})")
850
+
851
+ return is_nsfw_result, confidence
852
+
853
+ except Exception as e:
854
+ logger.error(f"NSFW detection error: {e}")
855
+ return self._fallback_nsfw_detection(prompt)
856
+
857
+ def _fallback_nsfw_detection(self, prompt: str = "") -> Tuple[bool, float]:
858
+ """Fallback NSFW detection based on prompt analysis"""
859
+ nsfw_keywords = [
860
+ 'nude', 'naked', 'nsfw', 'explicit', 'sexual', 'erotic', 'porn',
861
+ 'adult', 'xxx', 'sex', 'breast', 'nipple', 'genital', 'provocative'
862
+ ]
863
+
864
+ prompt_lower = prompt.lower()
865
+ for keyword in nsfw_keywords:
866
+ if keyword in prompt_lower:
867
+ logger.warning(f"🚨 NSFW content detected (prompt-based: '{keyword}' found)")
868
+ return True, random.uniform(0.7, 0.95)
869
+
870
+ # Random chance for demonstration (remove in production)
871
+ if random.random() < 0.02: # 2% chance for demo
872
+ logger.warning("🚨 NSFW content detected (random demo detection)")
873
+ return True, random.uniform(0.6, 0.8)
874
+
875
+ return False, random.uniform(0.1, 0.3)
876
+ """Load models with enhanced error handling and memory optimization"""
877
+ if self.model_loaded:
878
+ return True
879
+
880
+ try:
881
+ logger.info("Loading CyberRealistic Pony models...")
882
+
883
+ # Download model with better error handling
884
+ model_path = hf_hub_download(
885
+ repo_id=MODEL_REPO,
886
+ filename=MODEL_FILENAME,
887
+ cache_dir=os.environ.get("HF_CACHE_DIR", "/tmp/hf_cache"),
888
+ resume_download=True
889
+ )
890
+ logger.info(f"Model downloaded to: {model_path}")
891
+
892
+ # Load txt2img pipeline with optimizations
893
+ self.txt2img_pipe = StableDiffusionXLPipeline.from_single_file(
894
+ model_path,
895
+ torch_dtype=DTYPE,
896
+ use_safetensors=True,
897
+ variant="fp16" if DEVICE == "cuda" else None,
898
+ safety_checker=None, # Disable for faster loading
899
+ requires_safety_checker=False
900
+ )
901
+
902
+ # Memory optimizations
903
+ self._optimize_pipeline(self.txt2img_pipe)
904
+
905
+ # Create img2img pipeline sharing components
906
+ self.img2img_pipe = StableDiffusionXLImg2ImgPipeline(
907
+ vae=self.txt2img_pipe.vae,
908
+ text_encoder=self.txt2img_pipe.text_encoder,
909
+ text_encoder_2=self.txt2img_pipe.text_encoder_2,
910
+ tokenizer=self.txt2img_pipe.tokenizer,
911
+ tokenizer_2=self.txt2img_pipe.tokenizer_2,
912
+ unet=self.txt2img_pipe.unet,
913
+ scheduler=self.txt2img_pipe.scheduler,
914
+ safety_checker=None,
915
+ requires_safety_checker=False
916
+ )
917
+
918
+ self._optimize_pipeline(self.img2img_pipe)
919
+
920
+ self.model_loaded = True
921
+ logger.info("Models loaded successfully!")
922
+ return True
923
+
924
+ except Exception as e:
925
+ logger.error(f"Failed to load models: {e}")
926
+ self.model_loaded = False
927
+ return False
928
+
929
+ def _optimize_pipeline(self, pipeline):
930
+ """Apply memory optimizations to pipeline"""
931
+ pipeline.enable_attention_slicing()
932
+ pipeline.enable_vae_slicing()
933
+
934
+ if DEVICE == "cuda":
935
+ # Use sequential CPU offloading for better memory management
936
+ pipeline.enable_sequential_cpu_offload()
937
+ # Enable memory efficient attention if available
938
+ try:
939
+ pipeline.enable_xformers_memory_efficient_attention()
940
+ except:
941
+ logger.info("xformers not available, using default attention")
942
+ else:
943
+ pipeline = pipeline.to(DEVICE)
944
+
945
+ # Global pipeline manager
946
+ pipe_manager = PipelineManager()
947
+
948
+ # Enhanced prompt templates
949
+ QUALITY_TAGS = "score_9, score_8_up, score_7_up, masterpiece, best quality, ultra detailed, 8k"
950
+
951
+ DEFAULT_NEGATIVE = """(worst quality:1.4), (low quality:1.4), (normal quality:1.2),
952
+ lowres, bad anatomy, bad hands, signature, watermarks, ugly, imperfect eyes,
953
+ skewed eyes, unnatural face, unnatural body, error, extra limb, missing limbs,
954
+ painting by bad-artist, 3d, render"""
955
+
956
+ EXAMPLE_PROMPTS = [
957
+ "beautiful anime girl with long flowing silver hair, sakura petals, soft morning light",
958
+ "cyberpunk street scene, neon lights reflecting on wet pavement, futuristic cityscape",
959
+ "majestic dragon soaring through storm clouds, lightning, epic fantasy scene",
960
+ "cute anthropomorphic fox girl, fluffy tail, forest clearing, magical sparkles",
961
+ "elegant Victorian lady in ornate dress, portrait, vintage photography style",
962
+ "futuristic mech suit, glowing energy core, sci-fi laboratory background",
963
+ "mystical unicorn with rainbow mane, enchanted forest, ethereal atmosphere",
964
+ "steampunk inventor's workshop, brass gears, mechanical contraptions, warm lighting"
965
+ ]
966
+
967
+ def enhance_prompt(prompt: str, add_quality: bool = True) -> str:
968
+ """Smart prompt enhancement"""
969
+ if not prompt.strip():
970
+ return ""
971
+
972
+ # Don't add quality tags if they're already present
973
+ if any(tag in prompt.lower() for tag in ["score_", "masterpiece", "best quality"]):
974
+ return prompt
975
+
976
+ if add_quality:
977
+ return f"{QUALITY_TAGS}, {prompt}"
978
+ return prompt
979
+
980
+ def validate_and_fix_dimensions(width: int, height: int) -> Tuple[int, int]:
981
+ """Ensure SDXL-compatible dimensions with better aspect ratio handling"""
982
+ # Round to nearest multiple of 64
983
+ width = max(512, min(1024, ((width + 31) // 64) * 64))
984
+ height = max(512, min(1024, ((height + 31) // 64) * 64))
985
+
986
+ # Ensure reasonable aspect ratios (prevent extremely wide/tall images)
987
+ aspect_ratio = width / height
988
+ if aspect_ratio > 2.0: # Too wide
989
+ height = width // 2
990
+ elif aspect_ratio < 0.5: # Too tall
991
+ width = height // 2
992
+
993
+ return width, height
994
+
995
+ def create_metadata_png(image: Image.Image, params: Dict[str, Any]) -> str:
996
+ """Create PNG with embedded metadata"""
997
+ temp_path = tempfile.mktemp(suffix=".png", prefix="cyberrealistic_")
998
+
999
+ meta = PngImagePlugin.PngInfo()
1000
+ for key, value in params.items():
1001
+ if value is not None:
1002
+ meta.add_text(key, str(value))
1003
+
1004
+ # Add generation timestamp
1005
+ meta.add_text("Generated", datetime.now().strftime("%Y-%m-%d %H:%M:%S UTC"))
1006
+ meta.add_text("Model", f"{MODEL_REPO}/{MODEL_FILENAME}")
1007
+
1008
+ image.save(temp_path, "PNG", pnginfo=meta, optimize=True)
1009
+ return temp_path
1010
+
1011
+ def format_generation_info(params: Dict[str, Any], generation_time: float) -> str:
1012
+ """Format generation information display"""
1013
+ info_lines = [
1014
+ f"βœ… Generated in {generation_time:.1f}s",
1015
+ f"πŸ“ Resolution: {params.get('width', 'N/A')}Γ—{params.get('height', 'N/A')}",
1016
+ f"🎯 Prompt: {params.get('prompt', '')[:60]}{'...' if len(params.get('prompt', '')) > 60 else ''}",
1017
+ f"🚫 Negative: {params.get('negative_prompt', 'None')[:40]}{'...' if len(params.get('negative_prompt', '')) > 40 else ''}",
1018
+ f"🎲 Seed: {params.get('seed', 'N/A')}",
1019
+ f"πŸ“Š Steps: {params.get('steps', 'N/A')} | CFG: {params.get('guidance_scale', 'N/A')}"
1020
+ ]
1021
+
1022
+ if 'strength' in params:
1023
+ info_lines.append(f"πŸ’ͺ Strength: {params['strength']}")
1024
+
1025
+ return "\n".join(info_lines)
1026
+
1027
+ @spaces.GPU(duration=120) # Increased duration for model loading
1028
+ def generate_txt2img(prompt: str, negative_prompt: str, steps: int, guidance_scale: float,
1029
+ width: int, height: int, seed: int, add_quality: bool) -> Tuple:
1030
+ """Text-to-image generation with enhanced error handling"""
1031
+
1032
+ if not prompt.strip():
1033
+ return None, None, "❌ Please enter a prompt"
1034
+
1035
+ # Lazy load models
1036
+ if not pipe_manager.load_models():
1037
+ return None, None, "❌ Failed to load model. Please try again."
1038
+
1039
+ try:
1040
+ pipe_manager.clear_memory()
1041
+
1042
+ # Process parameters
1043
+ width, height = validate_and_fix_dimensions(width, height)
1044
+ if seed == -1:
1045
+ seed = random.randint(0, MAX_SEED)
1046
+
1047
+ enhanced_prompt = enhance_prompt(prompt, add_quality)
1048
+ generator = torch.Generator(device=DEVICE).manual_seed(seed)
1049
+
1050
+ # Generation parameters
1051
+ gen_params = {
1052
+ "prompt": enhanced_prompt,
1053
+ "negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
1054
+ "num_inference_steps": min(max(steps, 10), 50), # Clamp steps
1055
+ "guidance_scale": max(1.0, min(guidance_scale, 20.0)), # Clamp guidance
1056
+ "width": width,
1057
+ "height": height,
1058
+ "generator": generator,
1059
+ "output_type": "pil"
1060
+ }
1061
+
1062
+ logger.info(f"Generating: {enhanced_prompt[:50]}...")
1063
+ start_time = time.time()
1064
+
1065
+ with torch.inference_mode():
1066
+ result = pipe_manager.txt2img_pipe(**gen_params)
1067
+
1068
+ generation_time = time.time() - start_time
1069
+
1070
+ # NSFW Detection
1071
+ is_nsfw_result, nsfw_confidence = pipe_manager.is_nsfw(result.images[0], enhanced_prompt)
1072
+
1073
+ if is_nsfw_result:
1074
+ # Create a blurred/censored version or return error
1075
+ blurred_image = result.images[0].filter(ImageFilter.GaussianBlur(radius=20))
1076
+ warning_msg = f"⚠️ Content flagged as potentially inappropriate (confidence: {nsfw_confidence:.2f}). Image has been blurred."
1077
+
1078
+ # Still save metadata but mark as filtered
1079
+ metadata = {
1080
+ "prompt": enhanced_prompt,
1081
+ "negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
1082
+ "steps": gen_params["num_inference_steps"],
1083
+ "guidance_scale": gen_params["guidance_scale"],
1084
+ "width": width,
1085
+ "height": height,
1086
+ "seed": seed,
1087
+ "sampler": "Euler Ancestral",
1088
+ "model_hash": "cyberrealistic_pony_v110",
1089
+ "nsfw_filtered": "true",
1090
+ "nsfw_confidence": f"{nsfw_confidence:.3f}"
1091
+ }
1092
+
1093
+ png_path = create_metadata_png(blurred_image, metadata)
1094
+ info_text = f"{warning_msg}\n\n{format_generation_info(metadata, generation_time)}"
1095
+
1096
+ return blurred_image, png_path, info_text
1097
+
1098
+ # Prepare metadata
1099
+ metadata = {
1100
+ "prompt": enhanced_prompt,
1101
+ "negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
1102
+ "steps": gen_params["num_inference_steps"],
1103
+ "guidance_scale": gen_params["guidance_scale"],
1104
+ "width": width,
1105
+ "height": height,
1106
+ "seed": seed,
1107
+ "sampler": "Euler Ancestral",
1108
+ "model_hash": "cyberrealistic_pony_v110"
1109
+ }
1110
+
1111
+ # Save with metadata
1112
+ png_path = create_metadata_png(result.images[0], metadata)
1113
+ info_text = format_generation_info(metadata, generation_time)
1114
+
1115
+ return result.images[0], png_path, info_text
1116
+
1117
+ except torch.cuda.OutOfMemoryError:
1118
+ pipe_manager.clear_memory()
1119
+ return None, None, "❌ GPU out of memory. Try smaller dimensions or fewer steps."
1120
+ except Exception as e:
1121
+ logger.error(f"Generation error: {e}")
1122
+ return None, None, f"❌ Generation failed: {str(e)}"
1123
+ finally:
1124
+ pipe_manager.clear_memory()
1125
+
1126
+ @spaces.GPU(duration=120)
1127
+ def generate_img2img(input_image: Image.Image, prompt: str, negative_prompt: str,
1128
+ steps: int, guidance_scale: float, strength: float, seed: int,
1129
+ add_quality: bool) -> Tuple:
1130
+ """Image-to-image generation with enhanced preprocessing"""
1131
+
1132
+ if input_image is None:
1133
+ return None, None, "❌ Please upload an input image"
1134
+
1135
+ if not prompt.strip():
1136
+ return None, None, "❌ Please enter a prompt"
1137
+
1138
+ if not pipe_manager.load_models():
1139
+ return None, None, "❌ Failed to load model. Please try again."
1140
+
1141
+ try:
1142
+ pipe_manager.clear_memory()
1143
+
1144
+ # Process input image
1145
+ if input_image.mode != 'RGB':
1146
+ input_image = input_image.convert('RGB')
1147
+
1148
+ # Smart resizing maintaining aspect ratio
1149
+ original_size = input_image.size
1150
+ max_dimension = 1024
1151
+
1152
+ if max(original_size) > max_dimension:
1153
+ input_image.thumbnail((max_dimension, max_dimension), Image.Resampling.LANCZOS)
1154
+
1155
+ # Ensure SDXL compatible dimensions
1156
+ w, h = validate_and_fix_dimensions(*input_image.size)
1157
+ input_image = input_image.resize((w, h), Image.Resampling.LANCZOS)
1158
+
1159
+ # Process other parameters
1160
+ if seed == -1:
1161
+ seed = random.randint(0, MAX_SEED)
1162
+
1163
+ enhanced_prompt = enhance_prompt(prompt, add_quality)
1164
+ generator = torch.Generator(device=DEVICE).manual_seed(seed)
1165
+
1166
+ # Generation parameters
1167
+ gen_params = {
1168
+ "prompt": enhanced_prompt,
1169
+ "negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
1170
+ "image": input_image,
1171
+ "num_inference_steps": min(max(steps, 10), 50),
1172
+ "guidance_scale": max(1.0, min(guidance_scale, 20.0)),
1173
+ "strength": max(0.1, min(strength, 1.0)),
1174
+ "generator": generator,
1175
+ "output_type": "pil"
1176
+ }
1177
+
1178
+ logger.info(f"Transforming: {enhanced_prompt[:50]}...")
1179
+ start_time = time.time()
1180
+
1181
+ with torch.inference_mode():
1182
+ result = pipe_manager.img2img_pipe(**gen_params)
1183
+
1184
+ generation_time = time.time() - start_time
1185
+
1186
+ # NSFW Detection
1187
+ is_nsfw_result, nsfw_confidence = pipe_manager.is_nsfw(result.images[0], enhanced_prompt)
1188
+
1189
+ if is_nsfw_result:
1190
+ # Create blurred version for inappropriate content
1191
+ blurred_image = result.images[0].filter(ImageFilter.GaussianBlur(radius=20))
1192
+ warning_msg = f"⚠️ Content flagged as potentially inappropriate (confidence: {nsfw_confidence:.2f}). Image has been blurred."
1193
+
1194
+ metadata = {
1195
+ "prompt": enhanced_prompt,
1196
+ "negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
1197
+ "steps": gen_params["num_inference_steps"],
1198
+ "guidance_scale": gen_params["guidance_scale"],
1199
+ "strength": gen_params["strength"],
1200
+ "width": w,
1201
+ "height": h,
1202
+ "seed": seed,
1203
+ "sampler": "Euler Ancestral",
1204
+ "model_hash": "cyberrealistic_pony_v110",
1205
+ "nsfw_filtered": "true",
1206
+ "nsfw_confidence": f"{nsfw_confidence:.3f}"
1207
+ }
1208
+
1209
+ png_path = create_metadata_png(blurred_image, metadata)
1210
+ info_text = f"{warning_msg}\n\n{format_generation_info(metadata, generation_time)}"
1211
+
1212
+ return blurred_image, png_path, info_text
1213
+
1214
+ # Prepare metadata
1215
+ metadata = {
1216
+ "prompt": enhanced_prompt,
1217
+ "negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
1218
+ "steps": gen_params["num_inference_steps"],
1219
+ "guidance_scale": gen_params["guidance_scale"],
1220
+ "strength": gen_params["strength"],
1221
+ "width": w,
1222
+ "height": h,
1223
+ "seed": seed,
1224
+ "sampler": "Euler Ancestral",
1225
+ "model_hash": "cyberrealistic_pony_v110"
1226
+ }
1227
+
1228
+ png_path = create_metadata_png(result.images[0], metadata)
1229
+ info_text = format_generation_info(metadata, generation_time)
1230
+
1231
+ return result.images[0], png_path, info_text
1232
+
1233
+ except torch.cuda.OutOfMemoryError:
1234
+ pipe_manager.clear_memory()
1235
+ return None, None, "❌ GPU out of memory. Try lower strength or fewer steps."
1236
+ except Exception as e:
1237
+ logger.error(f"Generation error: {e}")
1238
+ return None, None, f"❌ Generation failed: {str(e)}"
1239
+ finally:
1240
+ pipe_manager.clear_memory()
1241
+
1242
+ def get_random_prompt():
1243
+ """Get a random example prompt"""
1244
+ return random.choice(EXAMPLE_PROMPTS)
1245
+
1246
+ # Enhanced Gradio interface
1247
+ def create_interface():
1248
+ """Create the Gradio interface"""
1249
+
1250
+ with gr.Blocks(
1251
+ title="CyberRealistic Pony - SDXL Generator",
1252
+ theme=gr.themes.Soft(primary_hue="blue"),
1253
+ css="""
1254
+ .generate-btn {
1255
+ background: linear-gradient(45deg, #667eea 0%, #764ba2 100%) !important;
1256
+ border: none !important;
1257
+ }
1258
+ .generate-btn:hover {
1259
+ transform: translateY(-2px);
1260
+ box-shadow: 0 4px 12px rgba(0,0,0,0.2);
1261
+ }
1262
+ """
1263
+ ) as demo:
1264
+
1265
+ gr.Markdown("""
1266
+ # 🎨 CyberRealistic Pony Generator
1267
+
1268
+ **High-quality SDXL image generation** β€’ Optimized for HuggingFace Spaces β€’ **NSFW Content Filter Enabled**
1269
+
1270
+ > ⚑ **First generation takes longer** (model loading) β€’ πŸ“‹ **Metadata embedded** in all outputs β€’ πŸ›‘οΈ **Content filtered for safety**
1271
+ """)
1272
+
1273
+ with gr.Tabs():
1274
+ # Text to Image Tab
1275
+ with gr.TabItem("🎨 Text to Image", id="txt2img"):
1276
+ with gr.Row():
1277
+ with gr.Column(scale=1):
1278
+ with gr.Group():
1279
+ txt_prompt = gr.Textbox(
1280
+ label="✨ Prompt",
1281
+ placeholder="A beautiful landscape with mountains and sunset...",
1282
+ lines=3,
1283
+ max_lines=5
1284
+ )
1285
+
1286
+ with gr.Row():
1287
+ txt_example_btn = gr.Button("🎲 Random", size="sm")
1288
+ txt_clear_btn = gr.Button("πŸ—‘οΈ Clear", size="sm")
1289
+
1290
+ with gr.Accordion("βš™οΈ Advanced Settings", open=False):
1291
+ txt_negative = gr.Textbox(
1292
+ label="❌ Negative Prompt",
1293
+ value=DEFAULT_NEGATIVE,
1294
+ lines=2,
1295
+ max_lines=3
1296
+ )
1297
+
1298
+ txt_quality = gr.Checkbox(
1299
+ label="✨ Add Quality Tags",
1300
+ value=True,
1301
+ info="Automatically enhance prompt with quality tags"
1302
+ )
1303
+
1304
+ with gr.Row():
1305
+ txt_steps = gr.Slider(
1306
+ 10, 50, 25, step=1,
1307
+ label="πŸ“Š Steps",
1308
+ info="More steps = better quality, slower generation"
1309
+ )
1310
+ txt_guidance = gr.Slider(
1311
+ 1.0, 15.0, 7.5, step=0.5,
1312
+ label="πŸŽ›οΈ CFG Scale",
1313
+ info="How closely to follow the prompt"
1314
+ )
1315
+
1316
+ with gr.Row():
1317
+ txt_width = gr.Slider(
1318
+ 512, 1024, 768, step=64,
1319
+ label="πŸ“ Width"
1320
+ )
1321
+ txt_height = gr.Slider(
1322
+ 512, 1024, 768, step=64,
1323
+ label="πŸ“ Height"
1324
+ )
1325
+
1326
+ txt_seed = gr.Slider(
1327
+ -1, MAX_SEED, -1, step=1,
1328
+ label="🎲 Seed (-1 = random)",
1329
+ info="Use same seed for reproducible results"
1330
+ )
1331
+
1332
+ txt_generate_btn = gr.Button(
1333
+ "🎨 Generate Image",
1334
+ variant="primary",
1335
+ size="lg",
1336
+ elem_classes=["generate-btn"]
1337
+ )
1338
+
1339
+ with gr.Column(scale=1):
1340
+ txt_output_image = gr.Image(
1341
+ label="πŸ–ΌοΈ Generated Image",
1342
+ height=500,
1343
+ show_download_button=True
1344
+ )
1345
+ txt_download_file = gr.File(
1346
+ label="πŸ“₯ Download PNG (with metadata)",
1347
+ file_types=[".png"]
1348
+ )
1349
+ txt_info = gr.Textbox(
1350
+ label="ℹ️ Generation Info",
1351
+ lines=6,
1352
+ max_lines=8,
1353
+ interactive=False
1354
+ )
1355
+
1356
+ # Image to Image Tab
1357
+ with gr.TabItem("πŸ–ΌοΈ Image to Image", id="img2img"):
1358
+ with gr.Row():
1359
+ with gr.Column(scale=1):
1360
+ img_input = gr.Image(
1361
+ label="πŸ“€ Input Image",
1362
+ type="pil",
1363
+ height=300
1364
+ )
1365
+
1366
+ with gr.Group():
1367
+ img_prompt = gr.Textbox(
1368
+ label="✨ Transformation Prompt",
1369
+ placeholder="digital art style, vibrant colors...",
1370
+ lines=3
1371
+ )
1372
+
1373
+ with gr.Row():
1374
+ img_example_btn = gr.Button("🎲 Random", size="sm")
1375
+ img_clear_btn = gr.Button("πŸ—‘οΈ Clear", size="sm")
1376
+
1377
+ with gr.Accordion("βš™οΈ Advanced Settings", open=False):
1378
+ img_negative = gr.Textbox(
1379
+ label="❌ Negative Prompt",
1380
+ value=DEFAULT_NEGATIVE,
1381
+ lines=2
1382
+ )
1383
+
1384
+ img_quality = gr.Checkbox(
1385
+ label="✨ Add Quality Tags",
1386
+ value=True
1387
+ )
1388
+
1389
+ with gr.Row():
1390
+ img_steps = gr.Slider(10, 50, 25, step=1, label="πŸ“Š Steps")
1391
+ img_guidance = gr.Slider(1.0, 15.0, 7.5, step=0.5, label="πŸŽ›οΈ CFG")
1392
+
1393
+ img_strength = gr.Slider(
1394
+ 0.1, 1.0, 0.75, step=0.05,
1395
+ label="πŸ’ͺ Transformation Strength",
1396
+ info="Higher = more creative, lower = more faithful to input"
1397
+ )
1398
+
1399
+ img_seed = gr.Slider(-1, MAX_SEED, -1, step=1, label="🎲 Seed")
1400
+
1401
+ img_generate_btn = gr.Button(
1402
+ "πŸ–ΌοΈ Transform Image",
1403
+ variant="primary",
1404
+ size="lg",
1405
+ elem_classes=["generate-btn"]
1406
+ )
1407
+
1408
+ with gr.Column(scale=1):
1409
+ img_output_image = gr.Image(
1410
+ label="πŸ–ΌοΈ Transformed Image",
1411
+ height=500,
1412
+ show_download_button=True
1413
+ )
1414
+ img_download_file = gr.File(
1415
+ label="πŸ“₯ Download PNG (with metadata)",
1416
+ file_types=[".png"]
1417
+ )
1418
+ img_info = gr.Textbox(
1419
+ label="ℹ️ Generation Info",
1420
+ lines=6,
1421
+ interactive=False
1422
+ )
1423
+
1424
+ # Event handlers
1425
+ txt_generate_btn.click(
1426
+ fn=generate_txt2img,
1427
+ inputs=[txt_prompt, txt_negative, txt_steps, txt_guidance,
1428
+ txt_width, txt_height, txt_seed, txt_quality],
1429
+ outputs=[txt_output_image, txt_download_file, txt_info],
1430
+ show_progress=True
1431
  )
1432
 
1433
+ img_generate_btn.click(
1434
+ fn=generate_img2img,
1435
+ inputs=[img_input, img_prompt, img_negative, img_steps, img_guidance,
1436
+ img_strength, img_seed, img_quality],
1437
+ outputs=[img_output_image, img_download_file, img_info],
1438
+ show_progress=True
1439
+ )
1440
+
1441
+ # Example prompt buttons
1442
+ txt_example_btn.click(fn=get_random_prompt, outputs=[txt_prompt])
1443
+ img_example_btn.click(fn=get_random_prompt, outputs=[img_prompt])
1444
+
1445
+ # Clear buttons
1446
+ txt_clear_btn.click(lambda: "", outputs=[txt_prompt])
1447
+ img_clear_btn.click(lambda: "", outputs=[img_prompt])
1448
+
1449
+ return demo
1450
+
1451
+ # Initialize and launch
1452
+ if __name__ == "__main__":
1453
+ logger.info(f"πŸš€ Initializing CyberRealistic Pony Generator on {DEVICE}")
1454
+ logger.info(f"πŸ“± PyTorch version: {torch.__version__}")
1455
+ logger.info(f"πŸ›‘οΈ NSFW Content Filter: Enabled")
1456
+
1457
+ demo = create_interface()
1458
+ demo.queue(max_size=20) # Enable queuing for better UX
1459
+ demo.launch(
1460
+ server_name="0.0.0.0",
1461
+ server_port=7860,
1462
+ show_error=True,
1463
+ share=False # Set to True if you want a public link
1464
+ )
1465
+
1466
  # Example prompt buttons
1467
  txt_example_btn.click(fn=get_random_prompt, outputs=[txt_prompt])
1468
  img_example_btn.click(fn=get_random_prompt, outputs=[img_prompt])