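"""Gradio demo for text-guided image colorization.

Captions a grayscale input with BLIP, feeds the caption (plus an optional user
prompt) to an SDXL-Lightning ControlNet pipeline, and transfers the generated
chrominance back onto the original luminance in LAB space.
"""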
import os
import torch
import PIL
import gradio as gr

from typing import Optional
from accelerate import Accelerator
from diffusers import (
    AutoencoderKL,
    StableDiffusionXLControlNetPipeline,
    ControlNetModel,
    UNet2DConditionModel,
)
from transformers import (
    BlipProcessor, BlipForConditionalGeneration,
)
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download, snapshot_download

import spaces


# ========== Initialization ==========

# Ensure required directories exist
os.makedirs("sdxl_light_caption_output", exist_ok=True)

# Download controlnet model snapshot
snapshot_download(
    repo_id='nickpai/sdxl_light_caption_output',
    local_dir='sdxl_light_caption_output'
)

# Device and precision setup
accelerator = Accelerator(mixed_precision="fp16")
weight_dtype = torch.float16 if accelerator.mixed_precision == "fp16" else torch.float32
device = accelerator.device

print(f"[INFO] Accelerator device: {device}")

# ========== Models ==========

# Pretrained paths
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
safetensors_ckpt = "sdxl_lightning_8step_unet.safetensors"
controlnet_path = "sdxl_light_caption_output/checkpoint-30000/controlnet"

# Load diffusion components
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae")
unet = UNet2DConditionModel.from_config(base_model_path, subfolder="unet")
unet.load_state_dict(load_file(hf_hub_download("ByteDance/SDXL-Lightning", safetensors_ckpt)))
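# Note: this is the distilled 8-step SDXL-Lightning UNet; inference below uses a
# matching num_inference_steps=8.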

controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=weight_dtype)

pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    base_model_path, vae=vae, unet=unet, controlnet=controlnet
)
pipe.to(device, dtype=weight_dtype)
pipe.safety_checker = None

# Load BLIP captioning model
caption_model_name = "blip-image-captioning-large"
processor = BlipProcessor.from_pretrained(f"Salesforce/{caption_model_name}")
caption_model = BlipForConditionalGeneration.from_pretrained(
    f"Salesforce/{caption_model_name}", torch_dtype=weight_dtype
).to(device)

# ========== Utility Functions ==========

def apply_color(image: PIL.Image.Image, color_map: PIL.Image.Image) -> PIL.Image.Image:
    """Keeps the luminance of `image` and takes the a/b chrominance from `color_map`."""
    # Convert to LAB color space
    image_lab = image.convert('LAB')
    color_map_lab = color_map.convert('LAB')

    # Extract and merge LAB channels
    l, _, _ = image_lab.split()
    _, a_map, b_map = color_map_lab.split()
    merged_lab = PIL.Image.merge('LAB', (l, a_map, b_map))

    return merged_lab.convert('RGB')


def remove_unlikely_words(prompt: str) -> str:
    """Removes predefined unlikely phrases from prompt text."""
    unlikely_words = []

    a1 = [f'{i}s' for i in range(1900, 2000)]
    a2 = [f'{i}' for i in range(1900, 2000)]
    a3 = [f'year {i}' for i in range(1900, 2000)]
    a4 = [f'circa {i}' for i in range(1900, 2000)]

    b1 = [f"{y[0]} {y[1]} {y[2]} {y[3]} s" for y in a1]
    b2 = [f"{y[0]} {y[1]} {y[2]} {y[3]}" for y in a1]
    b3 = [f"year {y[0]} {y[1]} {y[2]} {y[3]}" for y in a1]
    b4 = [f"circa {y[0]} {y[1]} {y[2]} {y[3]}" for y in a1]

    manual = [  # manually curated phrases and punctuation fragments to strip
        "black and white,", "black and white", "black & white,", "black & white", "circa", 
        "balck and white,", "monochrome,", "black-and-white,", "black-and-white photography,", 
        "black - and - white photography,", "monochrome bw,", "black white,", "black an white,",
        "grainy footage,", "grainy footage", "grainy photo,", "grainy photo", "b&w photo",
        "back and white", "back and white,", "monochrome contrast", "monochrome", "grainy",
        "grainy photograph,", "grainy photograph", "low contrast,", "low contrast", "b & w",
        "grainy black-and-white photo,", "bw", "bw,", "grainy black-and-white photo",
        "b & w,", "b&w,", "b&w!,", "b&w", "black - and - white,", "bw photo,", "grainy  photo,",
        "black-and-white photo,", "black-and-white photo", "black - and - white photography",
        "b&w photo,", "monochromatic photo,", "grainy monochrome photo,", "monochromatic",
        "blurry photo,", "blurry,", "blurry photography,", "monochromatic photo",
        "black - and - white photograph,", "black - and - white photograph", "black on white,",
        "black on white", "black-and-white", "historical image,", "historical picture,", 
        "historical photo,", "historical photograph,", "archival photo,", "taken in the early",
        "taken in the late", "taken in the", "historic photograph,", "restored,", "restored", 
        "historical photo", "historical setting,",
        "historic photo,", "historic", "desaturated!!,", "desaturated!,", "desaturated,", "desaturated", 
        "taken in", "shot on leica", "shot on leica sl2", "sl2",
        "taken with a leica camera", "leica sl2", "leica", "setting", 
        "overcast day", "overcast weather", "slight overcast", "overcast", 
        "picture taken in", "photo taken in", 
        ", photo", ",  photo", ",   photo", ",    photo", ", photograph",
        ",,", ",,,", ",,,,", " ,", "  ,", "   ,", "    ,", 
    ]

    unlikely_words.extend(a1 + a2 + a3 + a4 + b1 + b2 + b3 + b4 + manual)

    for word in unlikely_words:
        prompt = prompt.replace(word, "")
    return prompt
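# Illustrative example of the helper above (plain substring replacement, so stray
# spaces and commas may remain in the cleaned caption):
#   remove_unlikely_words("black and white photo of a city, 1950s")
#   -> " photo of a city, "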


def get_image_paths(folder_path: str) -> list:
    return [[os.path.join(folder_path, f)] for f in os.listdir(folder_path)
            if f.lower().endswith((".jpg", ".png"))]


@spaces.GPU
def process_image(image_path: str,
                  positive_prompt: Optional[str],
                  negative_prompt: Optional[str],
                  seed: int) -> tuple[PIL.Image.Image, str]:

    torch.manual_seed(seed)
    image = PIL.Image.open(image_path)
    original_size = image.size
    control_image = image.convert("L").convert("RGB").resize((512, 512))

    # Image captioning
    input_text = "a photography of"
    inputs = processor(image, input_text, return_tensors="pt").to(device, dtype=weight_dtype)
    caption_ids = caption_model.generate(**inputs)
    caption = processor.decode(caption_ids[0], skip_special_tokens=True)
    caption = remove_unlikely_words(caption)

    # Inference
    final_prompt = [f"{positive_prompt}, {caption}"]
    result = pipe(prompt=final_prompt,
                  negative_prompt=negative_prompt,
                  num_inference_steps=8,
                  generator=torch.manual_seed(seed),
                  image=control_image)

    colorized = apply_color(control_image, result.images[0]).resize(original_size)
    return colorized, caption


# ========== Gradio UI ==========

def create_interface():
    examples = get_image_paths("example/legacy_images")

    return gr.Interface(
        fn=process_image,
        inputs=[
            gr.Image(label="Upload Image", type='filepath',
                     value="example/legacy_images/Hollywood-Sign.jpg"),
            gr.Textbox(label="Positive Prompt", placeholder="Enter details to enhance the caption"),
            gr.Textbox(label="Negative Prompt", value="low quality, bad quality, low contrast, black and white, bw, monochrome, grainy, blurry, historical, restored, desaturate"),
        ],
        outputs=[
            gr.Image(label="Colorized Image", format="jpeg",
                     value="example/UUColor_results/Hollywood-Sign.jpeg"),
            gr.Textbox(label="Caption", show_copy_button=True)
        ],
        examples=examples,
        additional_inputs=[gr.Slider(0, 1000, 123, label="Seed")],
        title="Text-Guided Image Colorization",
        description="Upload a grayscale image and generate a color version guided by automatic captioning.",
        cache_examples=False
    )


def main():
    interface = create_interface()
    interface.launch(ssr_mode=False)


if __name__ == "__main__":
    main()