import os
import torch
import PIL
import gradio as gr
from typing import Optional
from accelerate import Accelerator
from diffusers import (
    AutoencoderKL,
    StableDiffusionXLControlNetPipeline,
    ControlNetModel,
    UNet2DConditionModel,
)
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
)
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download, snapshot_download
import spaces
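
# Overview (summary of the flow implemented below): a grayscale input is captioned
# with BLIP, the cleaned caption conditions an SDXL-Lightning (8-step) + ControlNet
# pipeline, and the generated colors are merged back onto the original luminance in
# LAB space before being returned through a Gradio interface (ZeroGPU via @spaces.GPU).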
# ========== Initialization ==========
# Ensure required directories exist
os.makedirs("sdxl_light_caption_output", exist_ok=True)
# Download controlnet model snapshot
snapshot_download(
    repo_id='nickpai/sdxl_light_caption_output',
    local_dir='sdxl_light_caption_output'
)
# Device and precision setup
accelerator = Accelerator(mixed_precision="fp16")
weight_dtype = torch.float16 if accelerator.mixed_precision == "fp16" else torch.float32
device = accelerator.device
print(f"[INFO] Accelerator device: {device}")
# ========== Models ==========
# Pretrained paths
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
safetensors_ckpt = "sdxl_lightning_8step_unet.safetensors"
controlnet_path = "sdxl_light_caption_output/checkpoint-30000/controlnet"
# Load diffusion components
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae")
unet = UNet2DConditionModel.from_config(base_model_path, subfolder="unet")
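# Swap in the distilled 8-step SDXL-Lightning UNet weights from ByteDance in place of the base SDXL UNet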
unet.load_state_dict(load_file(hf_hub_download("ByteDance/SDXL-Lightning", safetensors_ckpt)))
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=weight_dtype)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    base_model_path, vae=vae, unet=unet, controlnet=controlnet
)
pipe.to(device, dtype=weight_dtype)
pipe.safety_checker = None
# Load BLIP captioning model
caption_model_name = "blip-image-captioning-large"
processor = BlipProcessor.from_pretrained(f"Salesforce/{caption_model_name}")
caption_model = BlipForConditionalGeneration.from_pretrained(
    f"Salesforce/{caption_model_name}", torch_dtype=weight_dtype
).to(device)
# ========== Utility Functions ==========
def apply_color(image: PIL.Image.Image, color_map: PIL.Image.Image) -> PIL.Image.Image:
    """Transfer the a/b color channels from `color_map` onto the luminance of `image` in LAB space."""
    # Convert both images to LAB color space
    image_lab = image.convert('LAB')
    color_map_lab = color_map.convert('LAB')
    # Keep the lightness channel of the input; take the color channels from the generated map
    l, _, _ = image_lab.split()
    _, a_map, b_map = color_map_lab.split()
    merged_lab = PIL.Image.merge('LAB', (l, a_map, b_map))
    return merged_lab.convert('RGB')
def remove_unlikely_words(prompt: str) -> str:
    """Remove predefined unlikely phrases (grayscale/era descriptors such as "black and white" or years) from the caption."""
    unlikely_words = []
    a1 = [f'{i}s' for i in range(1900, 2000)]
    a2 = [f'{i}' for i in range(1900, 2000)]
    a3 = [f'year {i}' for i in range(1900, 2000)]
    a4 = [f'circa {i}' for i in range(1900, 2000)]
    # Variants with the year digits separated by spaces, as they sometimes appear in decoded captions
    b1 = [f"{y[0]} {y[1]} {y[2]} {y[3]} s" for y in a1]
    b2 = [f"{y[0]} {y[1]} {y[2]} {y[3]}" for y in a1]
    b3 = [f"year {y[0]} {y[1]} {y[2]} {y[3]}" for y in a1]
    b4 = [f"circa {y[0]} {y[1]} {y[2]} {y[3]}" for y in a1]
    manual = [  # manually curated phrases and punctuation artifacts to strip
        "black and white,", "black and white", "black & white,", "black & white", "circa",
        "balck and white,", "monochrome,", "black-and-white,", "black-and-white photography,",
        "black - and - white photography,", "monochrome bw,", "black white,", "black an white,",
        "grainy footage,", "grainy footage", "grainy photo,", "grainy photo", "b&w photo",
        "back and white", "back and white,", "monochrome contrast", "monochrome", "grainy",
        "grainy photograph,", "grainy photograph", "low contrast,", "low contrast", "b & w",
        "grainy black-and-white photo,", "bw", "bw,", "grainy black-and-white photo",
        "b & w,", "b&w,", "b&w!,", "b&w", "black - and - white,", "bw photo,", "grainy photo,",
        "black-and-white photo,", "black-and-white photo", "black - and - white photography",
        "b&w photo,", "monochromatic photo,", "grainy monochrome photo,", "monochromatic",
        "blurry photo,", "blurry,", "blurry photography,", "monochromatic photo",
        "black - and - white photograph,", "black - and - white photograph", "black on white,",
        "black on white", "black-and-white", "historical image,", "historical picture,",
        "historical photo,", "historical photograph,", "archival photo,", "taken in the early",
        "taken in the late", "taken in the", "historic photograph,", "restored,", "restored",
        "historical photo", "historical setting,",
        "historic photo,", "historic", "desaturated!!,", "desaturated!,", "desaturated,", "desaturated",
        "taken in", "shot on leica", "shot on leica sl2", "sl2",
        "taken with a leica camera", "leica sl2", "leica", "setting",
        "overcast day", "overcast weather", "slight overcast", "overcast",
        "picture taken in", "photo taken in",
        ", photo", ", photograph",
        ",,", ",,,", ",,,,", " ,",
    ]
    unlikely_words.extend(a1 + a2 + a3 + a4 + b1 + b2 + b3 + b4 + manual)
    for word in unlikely_words:
        prompt = prompt.replace(word, "")
    return prompt
def get_image_paths(folder_path: str) -> list:
    return [[os.path.join(folder_path, f)] for f in os.listdir(folder_path)
            if f.lower().endswith((".jpg", ".png"))]
@spaces.GPU
def process_image(image_path: str,
                  positive_prompt: Optional[str],
                  negative_prompt: Optional[str],
                  seed: int) -> tuple[PIL.Image.Image, str]:
    torch.manual_seed(seed)
    image = PIL.Image.open(image_path)
    original_size = image.size
    # Grayscale control image at the pipeline's working resolution
    control_image = image.convert("L").convert("RGB").resize((512, 512))

    # Image captioning with BLIP
    input_text = "a photography of"
    inputs = processor(image, input_text, return_tensors="pt").to(device, dtype=weight_dtype)
    caption_ids = caption_model.generate(**inputs)
    caption = processor.decode(caption_ids[0], skip_special_tokens=True)
    caption = remove_unlikely_words(caption)

    # Inference: condition the prompt on the cleaned caption (skip the prefix if no positive prompt was given)
    final_prompt = [f"{positive_prompt}, {caption}"] if positive_prompt else [caption]
    result = pipe(prompt=final_prompt,
                  negative_prompt=negative_prompt,
                  num_inference_steps=8,
                  generator=torch.manual_seed(seed),
                  image=control_image)
    # Re-apply the original luminance and restore the input resolution
    colorized = apply_color(control_image, result.images[0]).resize(original_size)
    return colorized, caption
# ========== Gradio UI ==========
def create_interface():
    examples = get_image_paths("example/legacy_images")
    return gr.Interface(
        fn=process_image,
        inputs=[
            gr.Image(label="Upload Image", type='filepath',
                     value="example/legacy_images/Hollywood-Sign.jpg"),
            gr.Textbox(label="Positive Prompt", placeholder="Enter details to enhance the caption"),
            gr.Textbox(label="Negative Prompt",
                       value="low quality, bad quality, low contrast, black and white, bw, monochrome, grainy, blurry, historical, restored, desaturate"),
        ],
        outputs=[
            gr.Image(label="Colorized Image", format="jpeg",
                     value="example/UUColor_results/Hollywood-Sign.jpeg"),
            gr.Textbox(label="Caption", show_copy_button=True),
        ],
        examples=examples,
        # Additional inputs are passed to fn after the main inputs, so the slider maps to the `seed` parameter
        additional_inputs=[gr.Slider(0, 1000, 123, label="Seed")],
        title="Text-Guided Image Colorization",
        description="Upload a grayscale image and generate a color version guided by automatic captioning.",
        cache_examples=False,
    )
def main():
    interface = create_interface()
    interface.launch(ssr_mode=False)


if __name__ == "__main__":
    main()