fffiloni committed on
Commit f62d5e5 · verified · 1 Parent(s): fa48cf6
Files changed (1)
  1. gradio_ui.py +109 -292
gradio_ui.py CHANGED
@@ -1,88 +1,107 @@
- import spaces
- import PIL
  import torch
- import subprocess
  import gradio as gr
- import os

  from typing import Optional
  from accelerate import Accelerator
  from diffusers import (
      AutoencoderKL,
-     StableDiffusionXLControlNetPipeline,
      ControlNetModel,
      UNet2DConditionModel,
  )
  from transformers import (
-     BlipProcessor, BlipForConditionalGeneration,
-     VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
  )
- from huggingface_hub import hf_hub_download
  from safetensors.torch import load_file
- from clip_interrogator import Interrogator, Config, list_clip_models

- from huggingface_hub import snapshot_download

- # Download colorization models
  os.makedirs("sdxl_light_caption_output", exist_ok=True)
- os.makedirs("sdxl_light_custom_caption_output", exist_ok=True)

  snapshot_download(
-     repo_id = 'nickpai/sdxl_light_caption_output',
-     local_dir = 'sdxl_light_caption_output'
  )

- snapshot_download(
-     repo_id = 'nickpai/sdxl_light_custom_caption_output',
-     local_dir = 'sdxl_light_custom_caption_output'
  )


  def apply_color(image: PIL.Image.Image, color_map: PIL.Image.Image) -> PIL.Image.Image:
-     # Convert input images to LAB color space
      image_lab = image.convert('LAB')
      color_map_lab = color_map.convert('LAB')

-     # Split LAB channels
-     l, a , b = image_lab.split()
      _, a_map, b_map = color_map_lab.split()
-
-     # Merge LAB channels with color map
      merged_lab = PIL.Image.merge('LAB', (l, a_map, b_map))

-     # Convert merged LAB image back to RGB color space
-     result_rgb = merged_lab.convert('RGB')
-     return result_rgb
-
- def remove_unlikely_words(prompt: str) -> str:
-     """
-     Removes unlikely words from a prompt.

-     Args:
-         prompt: The text prompt to be cleaned.

-     Returns:
-         The cleaned prompt with unlikely words removed.
-     """
      unlikely_words = []

-     a1_list = [f'{i}s' for i in range(1900, 2000)]
-     a2_list = [f'{i}' for i in range(1900, 2000)]
-     a3_list = [f'year {i}' for i in range(1900, 2000)]
-     a4_list = [f'circa {i}' for i in range(1900, 2000)]
-     b1_list = [f"{year[0]} {year[1]} {year[2]} {year[3]} s" for year in a1_list]
-     b2_list = [f"{year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list]
-     b3_list = [f"year {year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list]
-     b4_list = [f"circa {year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list]

-     words_list = [
          "black and white,", "black and white", "black & white,", "black & white", "circa",
          "balck and white,", "monochrome,", "black-and-white,", "black-and-white photography,",
          "black - and - white photography,", "monochrome bw,", "black white,", "black an white,",
          "grainy footage,", "grainy footage", "grainy photo,", "grainy photo", "b&w photo",
          "back and white", "back and white,", "monochrome contrast", "monochrome", "grainy",
          "grainy photograph,", "grainy photograph", "low contrast,", "low contrast", "b & w",
-         "grainy black-and-white photo,", "bw", "bw,", "grainy black-and-white photo",
          "b & w,", "b&w,", "b&w!,", "b&w", "black - and - white,", "bw photo,", "grainy photo,",
          "black-and-white photo,", "black-and-white photo", "black - and - white photography",
          "b&w photo,", "monochromatic photo,", "grainy monochrome photo,", "monochromatic",
@@ -94,287 +113,85 @@ def remove_unlikely_words(prompt: str) -> str:
          "historical photo", "historical setting,",
          "historic photo,", "historic", "desaturated!!,", "desaturated!,", "desaturated,", "desaturated",
          "taken in", "shot on leica", "shot on leica sl2", "sl2",
-         "taken with a leica camera", "taken with a leica camera", "leica sl2", "leica", "setting",
          "overcast day", "overcast weather", "slight overcast", "overcast",
          "picture taken in", "photo taken in",
          ", photo", ", photo", ", photo", ", photo", ", photograph",
          ",,", ",,,", ",,,,", " ,", " ,", " ,", " ,",
      ]

-     unlikely_words.extend(a1_list)
-     unlikely_words.extend(a2_list)
-     unlikely_words.extend(a3_list)
-     unlikely_words.extend(a4_list)
-     unlikely_words.extend(b1_list)
-     unlikely_words.extend(b2_list)
-     unlikely_words.extend(b3_list)
-     unlikely_words.extend(b4_list)
-     unlikely_words.extend(words_list)
-
      for word in unlikely_words:
          prompt = prompt.replace(word, "")
      return prompt

- def blip_image_captioning(image: PIL.Image.Image,
-                           model_backbone: str,
-                           weight_dtype: type,
-                           device: str,
-                           conditional: bool) -> str:
-     # https://huggingface.co/Salesforce/blip-image-captioning-large
-     # https://huggingface.co/Salesforce/blip-image-captioning-base
-     if weight_dtype == torch.bfloat16: # in case model might not accept bfloat16 data type
-         weight_dtype = torch.float16
-
-     processor = BlipProcessor.from_pretrained(f"Salesforce/{model_backbone}")
-     model = BlipForConditionalGeneration.from_pretrained(
-         f"Salesforce/{model_backbone}", torch_dtype=weight_dtype).to(device)
-
-     valid_backbones = ["blip-image-captioning-large", "blip-image-captioning-base"]
-     if model_backbone not in valid_backbones:
-         raise ValueError(f"Invalid model backbone '{model_backbone}'. \
-             Valid options are: {', '.join(valid_backbones)}")
-
-     if conditional:
-         text = "a photography of"
-         inputs = processor(image, text, return_tensors="pt").to(device, weight_dtype)
-     else:
-         inputs = processor(image, return_tensors="pt").to(device)
-     out = model.generate(**inputs)
-     caption = processor.decode(out[0], skip_special_tokens=True)
-     return caption
-
- # def vit_gpt2_image_captioning(image: PIL.Image.Image, device: str) -> str:
- #     # https://huggingface.co/nlpconnect/vit-gpt2-image-captioning
- #     model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning").to(device)
- #     feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
- #     tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
-
- #     max_length = 16
- #     num_beams = 4
- #     gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
-
- #     pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
- #     pixel_values = pixel_values.to(device)
-
- #     output_ids = model.generate(pixel_values, **gen_kwargs)
-
- #     preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
- #     caption = [pred.strip() for pred in preds]
-
- #     return caption[0]
-
- # def clip_image_captioning(image: PIL.Image.Image,
- #                           clip_model_name: str,
- #                           device: str) -> str:
- #     # validate clip model name
- #     models = list_clip_models()
- #     if clip_model_name not in models:
- #         raise ValueError(f"Could not find CLIP model {clip_model_name}! \
- #             Available models: {models}")
- #     config = Config(device=device, clip_model_name=clip_model_name)
- #     config.apply_low_vram_defaults()
- #     ci = Interrogator(config)
- #     caption = ci.interrogate(image)
- #     return caption
-
- # Define a function to process the image with the loaded model
  @spaces.GPU
- def process_image(image_path: str,
-                   controlnet_model_name_or_path: str,
-                   caption_model_name: str,
                    positive_prompt: Optional[str],
                    negative_prompt: Optional[str],
-                   seed: int,
-                   num_inference_steps: int,
-                   mixed_precision: str,
-                   pretrained_model_name_or_path: str,
-                   pretrained_vae_model_name_or_path: Optional[str],
-                   revision: Optional[str],
-                   variant: Optional[str],
-                   repo: str,
-                   ckpt: str,) -> PIL.Image.Image:
-     # Seed
-     generator = torch.manual_seed(seed)
-
-     # Accelerator Setting
-     accelerator = Accelerator(
-         mixed_precision=mixed_precision,
-         cpu=False
-     )
-
-     print(f"Accelerator device: {accelerator.device}")
-
-     weight_dtype = torch.float32
-     if accelerator.mixed_precision == "fp16":
-         weight_dtype = torch.float16
-     elif accelerator.mixed_precision == "bf16":
-         weight_dtype = torch.bfloat16
-
-     vae_path = (
-         pretrained_model_name_or_path
-         if pretrained_vae_model_name_or_path is None
-         else pretrained_vae_model_name_or_path
-     )
-     vae = AutoencoderKL.from_pretrained(
-         vae_path,
-         subfolder="vae" if pretrained_vae_model_name_or_path is None else None,
-         revision=revision,
-         variant=variant,
-     )
-     unet = UNet2DConditionModel.from_config(
-         pretrained_model_name_or_path,
-         subfolder="unet",
-         revision=revision,
-         variant=variant,
-     )
-     unet.load_state_dict(load_file(hf_hub_download(repo, ckpt)))
-
-     # Move vae, unet and text_encoder to device and cast to weight_dtype
-     # The VAE is in float32 to avoid NaN losses.
-     if pretrained_vae_model_name_or_path is not None:
-         vae.to(accelerator.device, dtype=weight_dtype)
-     else:
-         vae.to(accelerator.device, dtype=torch.float32)
-     unet.to(accelerator.device, dtype=weight_dtype)
-
-     controlnet = ControlNetModel.from_pretrained(controlnet_model_name_or_path, torch_dtype=weight_dtype)
-     pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
-         pretrained_model_name_or_path,
-         vae=vae,
-         unet=unet,
-         controlnet=controlnet,
-     )
-     pipe.to(accelerator.device, dtype=weight_dtype)

      image = PIL.Image.open(image_path)
-
-     # Prepare everything with our `accelerator`.
-     pipe, image = accelerator.prepare(pipe, image)
-     pipe.safety_checker = None
-
-     # Convert image into grayscale
      original_size = image.size
      control_image = image.convert("L").convert("RGB").resize((512, 512))
-
      # Image captioning
-     if caption_model_name == "blip-image-captioning-large" or "blip-image-captioning-base":
-         caption = blip_image_captioning(control_image, caption_model_name,
-                                         weight_dtype, accelerator.device, conditional=True)
-     # elif caption_model_name == "ViT-L-14/openai" or "ViT-H-14/laion2b_s32b_b79k":
-     #     caption = clip_image_captioning(control_image, caption_model_name, accelerator.device)
-     # elif caption_model_name == "vit-gpt2-image-captioning":
-     #     caption = vit_gpt2_image_captioning(control_image, accelerator.device)
      caption = remove_unlikely_words(caption)

-     # Combine positive prompt and captioning result
-     prompt = [positive_prompt + ", " + caption]
-
-     # Image colorization
-     image = pipe(prompt=prompt,
-                  negative_prompt=negative_prompt,
-                  num_inference_steps=num_inference_steps,
-                  generator=generator,
-                  image=control_image).images[0]
-
-     # Apply color mapping
-     result_image = apply_color(control_image, image)
-     result_image = result_image.resize(original_size)
-     return result_image, caption
-
- # Define the image gallery based on folder path
- def get_image_paths(folder_path):
-     import os
-     image_paths = []
-     for filename in os.listdir(folder_path):
-         if filename.endswith(".jpg") or filename.endswith(".png"):
-             image_paths.append([os.path.join(folder_path, filename)])
-     return image_paths
-
- # Create the Gradio interface
  def create_interface():
-     controlnet_model_dict = {
-         "sdxl-light-caption-30000": "sdxl_light_caption_output/checkpoint-30000/controlnet",
-         "sdxl-light-custom-caption-30000": "sdxl_light_custom_caption_output/checkpoint-30000/controlnet",
-     }
-     images = get_image_paths("example/legacy_images") # Replace with your folder path

-     interface = gr.Interface(
          fn=process_image,
          inputs=[
-             gr.Image(label="Upload image",
-                      value="example/legacy_images/Hollywood-Sign.jpg",
-                      type='filepath'),
-             gr.Dropdown(choices=[controlnet_model_dict[key] for key in controlnet_model_dict],
-                         value=controlnet_model_dict["sdxl-light-caption-30000"],
-                         label="Select ControlNet Model"),
-             gr.Dropdown(choices=["blip-image-captioning-large",
-                                  "blip-image-captioning-base",],
-                         value="blip-image-captioning-large",
-                         label="Select Image Captioning Model"),
-             gr.Textbox(label="Positive Prompt", placeholder="Text for positive prompt"),
-             gr.Textbox(value="low quality, bad quality, low contrast, black and white, bw, monochrome, grainy, blurry, historical, restored, desaturate",
-                        label="Negative Prompt", placeholder="Text for negative prompt"),
          ],
          outputs=[
-             gr.Image(label="Colorized image",
-                      value="example/UUColor_results/Hollywood-Sign.jpeg",
-                      format="jpeg"),
-             gr.Textbox(label="Captioning Result", show_copy_button=True)
-         ],
-         examples=images,
-         additional_inputs=[
-             # gr.Radio(choices=["Original", "Square"], value="Original",
-             #          label="Output resolution"),
-             # gr.Slider(minimum=128, maximum=512, value=256, step=128,
-             #           label="Height & Width",
-             #           info='Only effect if select "Square" output resolution'),
-             gr.Slider(0, 1000, 123, label="Seed"),
-             gr.Radio(choices=[1, 2, 4, 8],
-                      value=8,
-                      label="Inference Steps",
-                      info="1-step, 2-step, 4-step, or 8-step distilled models"),
-             gr.Radio(choices=["no", "fp16", "bf16"],
-                      value="fp16",
-                      label="Mixed Precision",
-                      info="Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16)."),
-             gr.Dropdown(choices=["stabilityai/stable-diffusion-xl-base-1.0"],
-                         value="stabilityai/stable-diffusion-xl-base-1.0",
-                         label="Base Model",
-                         info="Path to pretrained model or model identifier from huggingface.co/models."),
-             gr.Dropdown(choices=["None"],
-                         value=None,
-                         label="VAE Model",
-                         info="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038."),
-             gr.Dropdown(choices=["None"],
-                         value=None,
-                         label="Varient",
-                         info="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16"),
-             gr.Dropdown(choices=["None"],
-                         value=None,
-                         label="Revision",
-                         info="Revision of pretrained model identifier from huggingface.co/models."),
-             gr.Dropdown(choices=["ByteDance/SDXL-Lightning"],
-                         value="ByteDance/SDXL-Lightning",
-                         label="Repository",
-                         info="Repository from huggingface.co"),
-             gr.Dropdown(choices=["sdxl_lightning_1step_unet.safetensors",
-                                  "sdxl_lightning_2step_unet.safetensors",
-                                  "sdxl_lightning_4step_unet.safetensors",
-                                  "sdxl_lightning_8step_unet.safetensors"],
-                         value="sdxl_lightning_8step_unet.safetensors",
-                         label="Checkpoint",
-                         info="Available checkpoints from the repository. Caution! Checkpoint's 'N'step must match with inference steps"),
          ],
          title="Text-Guided Image Colorization",
-         description="Upload an image and select a model to colorize it.",
          cache_examples=False
      )
-     return interface

  def main():
-     # Launch the Gradio interface
      interface = create_interface()
      interface.launch(ssr_mode=False)

  if __name__ == "__main__":
-     main()
 
+ import os
  import torch
+ import PIL
  import gradio as gr

  from typing import Optional
  from accelerate import Accelerator
  from diffusers import (
      AutoencoderKL,
+     StableDiffusionXLControlNetPipeline,
      ControlNetModel,
      UNet2DConditionModel,
  )
  from transformers import (
+     BlipProcessor, BlipForConditionalGeneration,
  )
  from safetensors.torch import load_file
+ from huggingface_hub import hf_hub_download, snapshot_download

+ import spaces

+
+ # ========== Initialization ==========
+
+ # Ensure required directories exist
  os.makedirs("sdxl_light_caption_output", exist_ok=True)

+ # Download controlnet model snapshot
  snapshot_download(
+     repo_id='nickpai/sdxl_light_caption_output',
+     local_dir='sdxl_light_caption_output'
  )

+ # Device and precision setup
+ accelerator = Accelerator(mixed_precision="fp16")
+ weight_dtype = torch.float16 if accelerator.mixed_precision == "fp16" else torch.float32
+ device = accelerator.device
+
+ print(f"[INFO] Accelerator device: {device}")
+
+ # ========== Models ==========
+
+ # Pretrained paths
+ base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
+ safetensors_ckpt = "sdxl_lightning_8step_unet.safetensors"
+ controlnet_path = "sdxl_light_caption_output/checkpoint-30000/controlnet"
+
+ # Load diffusion components
+ vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae")
+ unet = UNet2DConditionModel.from_config(base_model_path, subfolder="unet")
+ unet.load_state_dict(load_file(hf_hub_download("ByteDance/SDXL-Lightning", safetensors_ckpt)))
+
+ controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=weight_dtype)
+
+ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+     base_model_path, vae=vae, unet=unet, controlnet=controlnet
  )
+ pipe.to(device, dtype=weight_dtype)
+ pipe.safety_checker = None
+
+ # Load BLIP captioning model
+ caption_model_name = "blip-image-captioning-large"
+ processor = BlipProcessor.from_pretrained(f"Salesforce/{caption_model_name}")
+ caption_model = BlipForConditionalGeneration.from_pretrained(
+     f"Salesforce/{caption_model_name}", torch_dtype=weight_dtype
+ ).to(device)

+ # ========== Utility Functions ==========

  def apply_color(image: PIL.Image.Image, color_map: PIL.Image.Image) -> PIL.Image.Image:
+     # Convert to LAB color space
      image_lab = image.convert('LAB')
      color_map_lab = color_map.convert('LAB')

+     # Extract and merge LAB channels
+     l, _, _ = image_lab.split()
      _, a_map, b_map = color_map_lab.split()
      merged_lab = PIL.Image.merge('LAB', (l, a_map, b_map))

+     return merged_lab.convert('RGB')


+ def remove_unlikely_words(prompt: str) -> str:
+     """Removes predefined unlikely phrases from prompt text."""
      unlikely_words = []

+     a1 = [f'{i}s' for i in range(1900, 2000)]
+     a2 = [f'{i}' for i in range(1900, 2000)]
+     a3 = [f'year {i}' for i in range(1900, 2000)]
+     a4 = [f'circa {i}' for i in range(1900, 2000)]
+
+     b1 = [f"{y[0]} {y[1]} {y[2]} {y[3]} s" for y in a1]
+     b2 = [f"{y[0]} {y[1]} {y[2]} {y[3]}" for y in a1]
+     b3 = [f"year {y[0]} {y[1]} {y[2]} {y[3]}" for y in a1]
+     b4 = [f"circa {y[0]} {y[1]} {y[2]} {y[3]}" for y in a1]

+     manual = [ # same list as your original words_list
          "black and white,", "black and white", "black & white,", "black & white", "circa",
          "balck and white,", "monochrome,", "black-and-white,", "black-and-white photography,",
          "black - and - white photography,", "monochrome bw,", "black white,", "black an white,",
          "grainy footage,", "grainy footage", "grainy photo,", "grainy photo", "b&w photo",
          "back and white", "back and white,", "monochrome contrast", "monochrome", "grainy",
          "grainy photograph,", "grainy photograph", "low contrast,", "low contrast", "b & w",
+         "grainy black-and-white photo,", "bw", "bw,", "grainy black-and-white photo",
          "b & w,", "b&w,", "b&w!,", "b&w", "black - and - white,", "bw photo,", "grainy photo,",
          "black-and-white photo,", "black-and-white photo", "black - and - white photography",
          "b&w photo,", "monochromatic photo,", "grainy monochrome photo,", "monochromatic",

          "historical photo", "historical setting,",
          "historic photo,", "historic", "desaturated!!,", "desaturated!,", "desaturated,", "desaturated",
          "taken in", "shot on leica", "shot on leica sl2", "sl2",
+         "taken with a leica camera", "leica sl2", "leica", "setting",
          "overcast day", "overcast weather", "slight overcast", "overcast",
          "picture taken in", "photo taken in",
          ", photo", ", photo", ", photo", ", photo", ", photograph",
          ",,", ",,,", ",,,,", " ,", " ,", " ,", " ,",
      ]

+     unlikely_words.extend(a1 + a2 + a3 + a4 + b1 + b2 + b3 + b4 + manual)
+
      for word in unlikely_words:
          prompt = prompt.replace(word, "")
      return prompt

+
+ def get_image_paths(folder_path: str) -> list:
+     return [[os.path.join(folder_path, f)] for f in os.listdir(folder_path)
+             if f.lower().endswith((".jpg", ".png"))]
+
+
  @spaces.GPU
+ def process_image(image_path: str,
                    positive_prompt: Optional[str],
                    negative_prompt: Optional[str],
+                   seed: int) -> tuple[PIL.Image.Image, str]:

+     torch.manual_seed(seed)
      image = PIL.Image.open(image_path)
      original_size = image.size
      control_image = image.convert("L").convert("RGB").resize((512, 512))
+
      # Image captioning
+     input_text = "a photography of"
+     inputs = processor(image, input_text, return_tensors="pt").to(device, dtype=weight_dtype)
+     caption_ids = caption_model.generate(**inputs)
+     caption = processor.decode(caption_ids[0], skip_special_tokens=True)
      caption = remove_unlikely_words(caption)

+     # Inference
+     final_prompt = [f"{positive_prompt}, {caption}"]
+     result = pipe(prompt=final_prompt,
+                   negative_prompt=negative_prompt,
+                   num_inference_steps=8,
+                   generator=torch.manual_seed(seed),
+                   image=control_image)
+
+     colorized = apply_color(control_image, result.images[0]).resize(original_size)
+     return colorized, caption
+
+
+ # ========== Gradio UI ==========
+
  def create_interface():
+     examples = get_image_paths("example/legacy_images")

+     return gr.Interface(
          fn=process_image,
          inputs=[
+             gr.Image(label="Upload Image", type='filepath',
+                      value="example/legacy_images/Hollywood-Sign.jpg"),
+             gr.Textbox(label="Positive Prompt", placeholder="Enter details to enhance the caption"),
+             gr.Textbox(label="Negative Prompt", value="low quality, bad quality, low contrast, black and white, bw, monochrome, grainy, blurry, historical, restored, desaturate"),
          ],
          outputs=[
+             gr.Image(label="Colorized Image", format="jpeg",
+                      value="example/UUColor_results/Hollywood-Sign.jpeg"),
+             gr.Textbox(label="Caption", show_copy_button=True)
          ],
+         examples=examples,
+         additional_inputs=[gr.Slider(0, 1000, 123, label="Seed")],
          title="Text-Guided Image Colorization",
+         description="Upload a grayscale image and generate a color version guided by automatic captioning.",
          cache_examples=False
      )
+

  def main():
      interface = create_interface()
      interface.launch(ssr_mode=False)

+
  if __name__ == "__main__":
+     main()
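
Aside: the colorization trick that both versions of apply_color rely on is to keep the grayscale input's L (lightness) channel and take only the a/b (color) channels from the pipeline output before converting back to RGB. Below is a minimal standalone sketch of that recombination, not part of the commit; it assumes Pillow handles the RGB/LAB conversions the same way the script already relies on, and the file paths are placeholders.

# Minimal sketch (not part of the commit): recombine lightness from the
# grayscale input with color from the colorized output, as apply_color does.
# Paths are placeholders.
import PIL.Image

def recolor(gray_path: str, colorized_path: str, out_path: str) -> None:
    gray = PIL.Image.open(gray_path).convert("RGB")
    colorized = PIL.Image.open(colorized_path).convert("RGB").resize(gray.size)

    l, _, _ = gray.convert("LAB").split()        # lightness from the original
    _, a, b = colorized.convert("LAB").split()   # color from the model output

    PIL.Image.merge("LAB", (l, a, b)).convert("RGB").save(out_path)

if __name__ == "__main__":
    recolor("input_gray.jpg", "model_output.jpg", "recolored.jpg")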