Scaryplasmon96 committed
Commit 457b619 · verified · 1 Parent(s): 8a2e6c2

😶‍🌫️

Files changed (1)
  1. app.py +221 -147
app.py CHANGED
@@ -1,154 +1,228 @@
-import gradio as gr
-import numpy as np
-import random
 
-# import spaces #[uncomment to use ZeroGPU]
-from diffusers import DiffusionPipeline
 import torch
-
 device = "cuda" if torch.cuda.is_available() else "cpu"
-model_repo_id = "stabilityai/sdxl-turbo"  # Replace to the model you would like to use
-
-if torch.cuda.is_available():
-    torch_dtype = torch.float16
-else:
-    torch_dtype = torch.float32
-
-pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
-pipe = pipe.to(device)
-
-MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 1024
-
-
-# @spaces.GPU #[uncomment to use ZeroGPU]
-def infer(
-    prompt,
-    negative_prompt,
-    seed,
-    randomize_seed,
-    width,
-    height,
-    guidance_scale,
-    num_inference_steps,
-    progress=gr.Progress(track_tqdm=True),
-):
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-
-    generator = torch.Generator().manual_seed(seed)
-
-    image = pipe(
-        prompt=prompt,
-        negative_prompt=negative_prompt,
-        guidance_scale=guidance_scale,
-        num_inference_steps=num_inference_steps,
-        width=width,
-        height=height,
-        generator=generator,
-    ).images[0]
-
-    return image, seed
-
-
-examples = [
-    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-    "An astronaut riding a green horse",
-    "A delicious ceviche cheesecake slice",
-]
-
-css = """
-#col-container {
-    margin: 0 auto;
-    max-width: 640px;
-}
-"""
-
-with gr.Blocks(css=css) as demo:
-    with gr.Column(elem_id="col-container"):
-        gr.Markdown(" # Text-to-Image Gradio Template")
-
-        with gr.Row():
-            prompt = gr.Text(
-                label="Prompt",
-                show_label=False,
-                max_lines=1,
-                placeholder="Enter your prompt",
-                container=False,
-            )
-
-            run_button = gr.Button("Run", scale=0, variant="primary")
-
-        result = gr.Image(label="Result", show_label=False)
-
-        with gr.Accordion("Advanced Settings", open=False):
-            negative_prompt = gr.Text(
-                label="Negative prompt",
-                max_lines=1,
-                placeholder="Enter a negative prompt",
-                visible=False,
             )
-
-            seed = gr.Slider(
-                label="Seed",
-                minimum=0,
-                maximum=MAX_SEED,
-                step=1,
-                value=0,
-            )
-
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-            with gr.Row():
-                width = gr.Slider(
-                    label="Width",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024,  # Replace with defaults that work for your model
-                )
-
-                height = gr.Slider(
-                    label="Height",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024,  # Replace with defaults that work for your model
-                )
-
-            with gr.Row():
-                guidance_scale = gr.Slider(
-                    label="Guidance scale",
-                    minimum=0.0,
-                    maximum=10.0,
-                    step=0.1,
-                    value=0.0,  # Replace with defaults that work for your model
-                )
-
-                num_inference_steps = gr.Slider(
-                    label="Number of inference steps",
-                    minimum=1,
-                    maximum=50,
-                    step=1,
-                    value=2,  # Replace with defaults that work for your model
-                )
-
-        gr.Examples(examples=examples, inputs=[prompt])
-    gr.on(
-        triggers=[run_button.click, prompt.submit],
-        fn=infer,
-        inputs=[
-            prompt,
-            negative_prompt,
-            seed,
-            randomize_seed,
-            width,
-            height,
-            guidance_scale,
-            num_inference_steps,
-        ],
-        outputs=[result, seed],
     )
 
 if __name__ == "__main__":
-    demo.launch()
+# --- Fix 1: Set Matplotlib backend ---
+import matplotlib
+matplotlib.use('Agg')  # Set backend BEFORE importing pyplot or other conflicting libs
+# --- End Fix 1 ---
 
+import gradio as gr
 import torch
+from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler
+from PIL import Image, ImageOps  # Added ImageOps for inversion
+import numpy as np
+import os
+import importlib
+import traceback  # For detailed error printing
+
+# --- FidelityMLP Class (Ensure this is correct as provided by user) ---
+class FidelityMLP(torch.nn.Module):
+    def __init__(self, hidden_size, output_size=None):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.output_size = output_size or hidden_size
+        self.net = torch.nn.Sequential(
+            torch.nn.Linear(1, 128), torch.nn.LayerNorm(128), torch.nn.SiLU(),
+            torch.nn.Linear(128, 256), torch.nn.LayerNorm(256), torch.nn.SiLU(),
+            torch.nn.Linear(256, hidden_size), torch.nn.LayerNorm(hidden_size), torch.nn.Tanh()
+        )
+        self.output_proj = torch.nn.Linear(hidden_size, self.output_size)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, module):
+        if isinstance(module, torch.nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=0.01)
+            if module.bias is not None: module.bias.data.zero_()
+
+    def forward(self, x, target_dim=None):
+        features = self.net(x)
+        outputs = self.output_proj(features)
+        if target_dim is not None and target_dim != self.output_size:
+            return self._adjust_dimension(outputs, target_dim)
+        return outputs
+
+    def _adjust_dimension(self, embeddings, target_dim):
+        current_dim = embeddings.shape[-1]
+        if target_dim > current_dim:
+            pad_size = target_dim - current_dim
+            padding = torch.zeros((*embeddings.shape[:-1], pad_size), device=embeddings.device, dtype=embeddings.dtype)
+            return torch.cat([embeddings, padding], dim=-1)
+        elif target_dim < current_dim:
+            return embeddings[..., :target_dim]
+        return embeddings
+
+    def save_pretrained(self, save_directory):
+        os.makedirs(save_directory, exist_ok=True)
+        config = {"hidden_size": self.hidden_size, "output_size": self.output_size}
+        torch.save(config, os.path.join(save_directory, "config.json"))
+        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_path):
+        config_file = os.path.join(pretrained_model_path, "config.json")
+        model_file = os.path.join(pretrained_model_path, "pytorch_model.bin")
+        if not os.path.exists(config_file): raise FileNotFoundError(f"Config file not found at {config_file}")
+        if not os.path.exists(model_file): raise FileNotFoundError(f"Model file not found at {model_file}")
+        try:
+            config = torch.load(config_file, map_location=torch.device('cpu'))
+            if not isinstance(config, dict): raise TypeError(f"Expected config dict, got {type(config)}")
+        except Exception as e: print(f"Error loading config {config_file}: {e}"); raise
+        model = cls(hidden_size=config["hidden_size"], output_size=config.get("output_size", config["hidden_size"]))
+        try:
+            state_dict = torch.load(model_file, map_location=torch.device('cpu'))
+            model.load_state_dict(state_dict)
+            print(f"Successfully loaded FidelityMLP state dict from {model_file}")
+        except Exception as e: print(f"Error loading state dict {model_file}: {e}"); raise
+        return model
+
+# --- Global Variables ---
+pipeline = None
 device = "cuda" if torch.cuda.is_available() else "cpu"
+model_id = "Scaryplasmon96/DoodlePixV1"
+
+# --- Model Loading Function ---
+def load_pipeline():
+    global pipeline
+    if pipeline is not None: return True
+    print(f"Loading model {model_id} onto {device}...")
+    try:
+        hf_cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
+        local_model_path = model_id  # Let diffusers find/download
+
+        # Load Fidelity MLP if possible
+        fidelity_mlp_instance = None
+        try:
+            from huggingface_hub import snapshot_download, hf_hub_download
+            # Attempt to download config first to check existence
+            hf_hub_download(repo_id=model_id, filename="fidelity_mlp/config.json", cache_dir=hf_cache_dir)
+            # If config exists, download the whole subfolder
+            fidelity_mlp_path = snapshot_download(repo_id=model_id, allow_patterns="fidelity_mlp/*", local_dir_use_symlinks=False, cache_dir=hf_cache_dir)
+            fidelity_mlp_instance = FidelityMLP.from_pretrained(os.path.join(fidelity_mlp_path, "fidelity_mlp"))
+            fidelity_mlp_instance = fidelity_mlp_instance.to(device=device, dtype=torch.float16)
+            print("Fidelity MLP loaded successfully.")
+        except Exception as e:
+            print(f"Fidelity MLP not found or failed to load for {model_id}: {e}. Proceeding without MLP.")
+            fidelity_mlp_instance = None
+
+        scheduler = EulerAncestralDiscreteScheduler.from_pretrained(local_model_path, subfolder="scheduler")
+        pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
+            local_model_path, torch_dtype=torch.float16, scheduler=scheduler, safety_checker=None
+        ).to(device)
+
+        if fidelity_mlp_instance:
+            pipeline.fidelity_mlp = fidelity_mlp_instance
+            print("Attached Fidelity MLP to pipeline.")
+
+        # Optimizations
+        if device == "cuda" and hasattr(pipeline, "enable_xformers_memory_efficient_attention"):
+            try: pipeline.enable_xformers_memory_efficient_attention(); print("Enabled xformers.")
+            except: print("Could not enable xformers. Using attention slicing."); pipeline.enable_attention_slicing()
+        else: pipeline.enable_attention_slicing(); print("Enabled attention slicing.")
+
+        print("Pipeline loaded successfully.")
+        return True
+    except Exception as e:
+        print(f"Error loading pipeline: {e}"); traceback.print_exc()
+        pipeline = None; raise gr.Error(f"Failed to load model: {e}")
+
+# --- Image Generation Function (Corrected Input Handling) ---
+def generate_image(drawing_input, prompt, fidelity_slider, steps, guidance, image_guidance, seed_val):
+    global pipeline
+    if pipeline is None:
+        if not load_pipeline(): return None, "Model not loaded. Check logs."
+
+    # --- Corrected Input Processing ---
+    print(f"DEBUG: Received drawing_input type: {type(drawing_input)}")
+    if isinstance(drawing_input, dict): print(f"DEBUG: Received drawing_input keys: {drawing_input.keys()}")
+
+    # Check if input is dict and get PIL image from 'composite' key
+    if isinstance(drawing_input, dict) and "composite" in drawing_input and isinstance(drawing_input["composite"], Image.Image):
+        input_image_pil = drawing_input["composite"].convert("RGB")  # Get composite image
+        print("DEBUG: Using PIL Image from 'composite' key.")
+    else:
+        err_msg = "Drawing input format unexpected. Expected dict with PIL Image under 'composite' key."
+        print(f"ERROR: {err_msg} Input: {drawing_input}")
+        return None, err_msg
+    # --- End Corrected Input Processing ---
+
+    try:
+        # Invert the image: White bg -> Black bg, Black lines -> White lines
+        input_image_inverted = ImageOps.invert(input_image_pil)
+        # Save the inverted image
+        input_image_inverted.save("input_image_inverted.png")
+
+        # Ensure image is 512x512
+        if input_image_inverted.size != (512, 512):
+            print(f"Resizing input image from {input_image_inverted.size} to (512, 512)")
+            input_image_inverted = input_image_inverted.resize((512, 512), Image.Resampling.LANCZOS)
+
+        # Prompt Construction
+        final_prompt = f"f{int(fidelity_slider)}, {prompt}"
+        if not final_prompt.endswith("background."): final_prompt += " background."
+
+        negative_prompt = "artifacts, blur, jpg, uncanny, deformed, glow, shadow, text, words, letters, signature, watermark"
+
+        # Generation
+        print(f"Generating with: Prompt='{final_prompt[:100]}...', Fidelity={int(fidelity_slider)}, Steps={steps}, Guidance={guidance}, ImageGuidance={image_guidance}, Seed={seed_val}")
+        seed_val = int(seed_val)
+        generator = torch.Generator(device=device).manual_seed(seed_val)
+
+        with torch.no_grad():
+            output = pipeline(
+                prompt=final_prompt, negative_prompt=negative_prompt, image=input_image_inverted,
+                num_inference_steps=int(steps), guidance_scale=float(guidance),
+                image_guidance_scale=float(image_guidance), generator=generator,
+            ).images[0]
+
+        print("Generation complete.")
+        return output, "Generation Complete"
+
+    except Exception as e:
+        print(f"Error during generation: {e}"); traceback.print_exc()
+        return None, f"Error during generation: {str(e)}"
+
+# --- Gradio Interface ---
+with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange", secondary_hue="blue")) as demo:
+    gr.Markdown("# DoodlePix Gradio App")
+    gr.Markdown(f"Using model: `{model_id}`.")
+    status_output = gr.Textbox(label="Status", interactive=False, value="App loading...")
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            gr.Markdown("## 1. Draw Something (Black on White)")
+            # Keep type="pil" as it provides the composite key
+            drawing = gr.Sketchpad(
+                label="Drawing Canvas",
+                type="pil",  # type="pil" gives dict output with 'composite' key
+                height=512, width=512,
+                brush=gr.Brush(colors=["#000000"], color_mode="fixed", default_size=5),
+                show_label=True
             )
+            prompt_input = gr.Textbox(label="2. Enter Prompt", placeholder="Describe the image you want...")
+            fidelity = gr.Slider(0, 9, step=1, value=4, label="Fidelity (0=Creative, 9=Faithful)")
+            num_steps = gr.Slider(10, 50, step=1, value=25, label="Inference Steps")
+            guidance_scale = gr.Slider(1.0, 15.0, step=0.5, value=7.5, label="Guidance Scale (CFG)")
+            image_guidance_scale = gr.Slider(0.5, 5.0, step=0.1, value=1.5, label="Image Guidance Scale")
+            seed = gr.Number(label="Seed", value=42, precision=0)
+            generate_button = gr.Button("🚀 Generate Image!", variant="primary")
+
+        with gr.Column(scale=1):
+            gr.Markdown("## 3. Generated Image")
+            output_image = gr.Image(label="Result", type="pil", height=512, width=512, show_label=True)
+
+    generate_button.click(
+        fn=generate_image,
+        inputs=[drawing, prompt_input, fidelity, num_steps, guidance_scale, image_guidance_scale, seed],
+        outputs=[output_image, status_output]
     )
 
+# --- Launch App ---
 if __name__ == "__main__":
+    initial_status = "App loading..."
+    print("Attempting to pre-load pipeline...")
+    try:
+        if load_pipeline(): initial_status = "Model pre-loaded successfully."
+        else: initial_status = "Model pre-loading failed. Will retry on first generation."
+    except Exception as e:
+        print(f"Pre-loading failed: {e}")
+        initial_status = f"Model pre-loading failed: {e}. Will retry on first generation."
+    print(f"Pre-loading status: {initial_status}")
+
+    demo.launch()
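
A quick way to sanity-check the new FidelityMLP outside the app is to feed it a scalar fidelity value directly: it maps a single number to an embedding-sized vector and can pad or truncate to a target dimension. A minimal sketch, assuming the class definition from app.py above; the hidden size of 768 and the 0-to-1 input scaling are illustrative assumptions, not values confirmed by this commit:

import torch

# Hypothetical smoke test for FidelityMLP (hidden_size and input scaling are assumed).
mlp = FidelityMLP(hidden_size=768)
fidelity_value = torch.tensor([[4 / 9.0]])    # slider value 4 on the 0-9 scale, normalized
emb = mlp(fidelity_value)                     # -> torch.Size([1, 768])
wider = mlp(fidelity_value, target_dim=1024)  # zero-padded to torch.Size([1, 1024])
print(emb.shape, wider.shape)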
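The preprocessing contract in generate_image is the part of the diff worth reviewing closely: the Sketchpad sends a dict whose "composite" key holds the drawing (black lines on white), which is inverted to white-on-black, resized to 512x512, and given an f{fidelity} prompt prefix before the InstructPix2Pix call. A minimal headless sketch of that same flow, assuming load_pipeline() has already populated the global pipeline and that doodle.png is a hypothetical input drawing:

from PIL import Image, ImageOps
import torch

# Hypothetical input file: any black-lines-on-white drawing.
sketch = Image.open("doodle.png").convert("RGB")
sketch = ImageOps.invert(sketch).resize((512, 512), Image.Resampling.LANCZOS)

# Mirrors the app's defaults: fidelity 4, 25 steps, CFG 7.5, image CFG 1.5, seed 42.
generator = torch.Generator(device=device).manual_seed(42)
with torch.no_grad():
    image = pipeline(
        prompt="f4, a cozy cottage, flat colors, background.",
        negative_prompt="artifacts, blur, jpg, uncanny, deformed",
        image=sketch,
        num_inference_steps=25,
        guidance_scale=7.5,
        image_guidance_scale=1.5,
        generator=generator,
    ).images[0]
image.save("result.png")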