# img2img-turbo / gradio_sketch2imagehd.py
# Uploaded by Inmental ("Upload folder using huggingface_hub"), commit d97b2a0 (verified).
import random
import numpy as np
from PIL import Image, ImageOps
import base64
from io import BytesIO
import torch
import torchvision.transforms.functional as F
from transformers import BlipProcessor, BlipForConditionalGeneration
from src.pix2pix_turbo import Pix2Pix_Turbo
import nltk
from nltk import pos_tag
from nltk.tokenize import word_tokenize
import re
import os
import json
import logging
import gc
import gradio as gr
from torch.cuda.amp import autocast
# Set environment variable for better memory management: cap the caching
# allocator's split size to reduce fragmentation. setdefault() keeps any value
# the user already exported instead of silently overwriting it.
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'max_split_size_mb:128')
# Function to collect garbage and clear the CUDA cache
def clear_memory():
    """Free unreachable Python objects, then release PyTorch's cached CUDA memory.

    ``gc.collect()`` runs first. Tensors kept alive only by reference cycles
    are freed before ``torch.cuda.empty_cache()`` hands the now-unused cached
    blocks back to the driver. Calling them in the reverse order leaves those
    blocks sitting in the cache.
    """
    gc.collect()
    torch.cuda.empty_cache()
# Load the configuration from config.json. JSON text is UTF-8, so decode it
# explicitly instead of relying on the platform's locale encoding (which breaks
# non-ASCII style prompts on e.g. Windows).
with open('config.json', 'r', encoding='utf-8') as config_file:
    config = json.load(config_file)
# Setup logging as per config. String level names are upper-cased so that
# "debug" is accepted as well as "DEBUG"; numeric levels pass through unchanged.
_log_level = config["logging"]["level"]
if isinstance(_log_level, str):
    _log_level = _log_level.upper()
logging.basicConfig(level=_log_level, format=config["logging"]["format"])
# Ensure NLTK resources are downloaded. Older NLTK releases load the tagger and
# tokenizer from 'averaged_perceptron_tagger' / 'punkt'. Newer releases look for
# 'averaged_perceptron_tagger_eng' / 'punkt_tab' instead. Fetch both sets so
# pos_tag() and word_tokenize() work on either. An id unknown to the installed
# NLTK is reported by nltk.download() rather than raised (raise_on_error
# defaults to False).
for _nltk_resource in (
    'averaged_perceptron_tagger',
    'averaged_perceptron_tagger_eng',
    'punkt',
    'punkt_tab',
):
    nltk.download(_nltk_resource)
# File paths for storing sketches and outputs
_paths_cfg = config["file_paths"]
SKETCH_PATH = _paths_cfg["sketch_path"]
OUTPUT_PATH = _paths_cfg["output_path"]
# Global constants and configuration; STYLES maps style name -> prompt template
STYLE_LIST = config["style_list"]
STYLES = {}
for _style_entry in STYLE_LIST:
    STYLES[_style_entry["name"]] = _style_entry["prompt"]
DEFAULT_STYLE_NAME = config["default_style_name"]
RANDOM_VALUES = config["random_values"]
# Model parameters
_model_cfg = config["model_params"]
PIX2PIX_MODEL_NAME = _model_cfg["pix2pix_model_name"]
DEVICE = _model_cfg["device"]
DEFAULT_SEED = _model_cfg["default_seed"]
VAL_R_DEFAULT = _model_cfg["val_r_default"]
MAX_SEED = _model_cfg["max_seed"]
# Canvas configuration
_canvas_cfg = config["canvas"]
CANVAS_WIDTH = _canvas_cfg["width"]
CANVAS_HEIGHT = _canvas_cfg["height"]
# Preload models once at import time; the functions below use these globals.
_BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"
logging.debug("Loading BLIP and Pix2Pix models...")
processor = BlipProcessor.from_pretrained(_BLIP_CHECKPOINT)
blip_model = BlipForConditionalGeneration.from_pretrained(_BLIP_CHECKPOINT)
blip_model = blip_model.to(DEVICE)
pix2pix_model = Pix2Pix_Turbo(PIX2PIX_MODEL_NAME)
logging.debug("Models loaded.")
def pil_image_to_data_uri(img: "Image.Image", format="PNG") -> str:
    """Encode an image as a base64 ``data:`` URI.

    :param img: A PIL image (anything with a ``save(fp, format=...)`` method works).
    :param format: Pillow format name used for encoding; its lower-cased form
        becomes the MIME subtype, e.g. "PNG" -> ``image/png``.
    :return: ``"data:image/<format>;base64,<payload>"``.
    """
    buffered = BytesIO()
    img.save(buffered, format=format)
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/{format.lower()};base64,{img_str}"
def generate_prompt_from_sketch(image: "Image.Image") -> str:
    """Caption a sketch with BLIP and turn the caption into a generation prompt.

    Steps:
    1. Fit the sketch to the canvas size.
    2. Caption it with ``blip_model`` (at most 50 new tokens).
    3. Split the caption on ", " and reduce each fragment to its nouns with
       ``extract_main_words``.

    Fragments in which no noun was found are dropped. If nothing is left, the
    raw caption is used as the subject, so the prompt never degenerates to
    "a photo of a , ...". When ``RANDOM_VALUES`` is non-empty, a randomly
    chosen entry is appended after a comma.

    :param image: The sketch as a PIL image.
    :return: A prompt of the form "a photo of a <subject>[, <random suffix>]".
    """
    logging.debug("Generating prompt from sketch...")
    image = ImageOps.fit(image, (CANVAS_WIDTH, CANVAS_HEIGHT), Image.LANCZOS)
    inputs = processor(image, return_tensors="pt").to(DEVICE)
    out = blip_model.generate(**inputs, max_new_tokens=50)
    text_prompt = processor.decode(out[0], skip_special_tokens=True)
    logging.debug(f"Generated prompt: {text_prompt}")
    recognized_items = [extract_main_words(item) for item in text_prompt.split(', ') if item.strip()]
    # extract_main_words returns "" when the tagger finds no noun; such entries
    # would otherwise leave dangling " and " separators in the prompt.
    recognized_items = [words for words in recognized_items if words]
    subject = ' and '.join(recognized_items) if recognized_items else text_prompt.strip()
    prompt = f"a photo of a {subject}"
    if RANDOM_VALUES:  # random.choice raises IndexError on an empty list
        prompt = f"{prompt}, {random.choice(RANDOM_VALUES)}"
    logging.debug(f"Final prompt: {prompt}")
    return prompt
def extract_main_words(item: str) -> str:
    """Return the nouns of *item*, each capitalized, joined by single spaces.

    The fragment is stripped, tokenized with NLTK ``word_tokenize`` and tagged
    with ``pos_tag``. Tokens tagged NN, NNS, NNP or NNPS are kept in order.
    Returns "" when no token carries one of those tags.
    """
    noun_tags = {'NN', 'NNS', 'NNP', 'NNPS'}
    tokens = word_tokenize(item.strip())
    kept = []
    for token, tag in pos_tag(tokens):
        if tag in noun_tags:
            kept.append(token.capitalize())
    return ' '.join(kept)
def normalize_image(image, range_from=(-1, 1)):
    """
    Convert an image to a tensor and linearly rescale it to a target range.

    ``F.to_tensor`` produces values in [0, 1]. They are then mapped onto
    ``[low, high]`` via ``value * (high - low) + low``:

    - (-1, 1) gives ``value * 2 - 1``, exactly as before.
    - (0, 1) returns the tensor unchanged.
    - Any other range, e.g. (0, 255), is honored instead of being silently
      ignored.

    :param image: The PIL Image (or anything ``F.to_tensor`` accepts) to be normalized.
    :param range_from: Target ``(low, high)`` range for normalization, typically
        (-1, 1) or (0, 1). Despite the name, this is the range values are mapped
        *to*. Any two-element sequence is accepted.
    :return: Normalized image tensor.
    """
    low, high = range_from
    # Convert the image to a tensor in [0, 1]
    image_t = F.to_tensor(image)
    if (low, high) == (0, 1):
        return image_t
    return image_t * (high - low) + low
def _snap_to_multiple_of_8(image):
    """Return *image* resized down so width and height are multiples of 8 (minimum 8).

    The noise map built in ``_pix2pix_inference`` has shape ``(H // 8, W // 8)``.
    It presumably only lines up with the model's latent when both sides are
    divisible by 8. NOTE(review): confirm against Pix2Pix_Turbo's VAE.
    """
    width, height = image.size
    target = (max(8, width - width % 8), max(8, height - height % 8))
    if target == (width, height):
        return image
    logging.debug(f"Resizing input from {width}x{height} to {target[0]}x{target[1]}.")
    return image.resize(target, Image.LANCZOS)


def _pix2pix_inference(image_tensor, prompt, seed, val_r):
    """Run ``pix2pix_model`` on a CHW tensor in [-1, 1] and return its raw output.

    The model runs on ``DEVICE`` under autocast with noise seeded by *seed*.
    If that raises a RuntimeError mentioning "CUDA out of memory", the call is
    retried on the CPU. The retry reuses the same noise when it had already
    been drawn. The shared model is moved back to ``DEVICE`` afterwards so
    later requests are not left with a CPU-resident model. Any other
    RuntimeError propagates unchanged.
    """
    height, width = image_tensor.shape[-2:]
    noise_shape = (1, 4, height // 8, width // 8)
    c_t = None
    noise = None
    try:
        with torch.no_grad():
            c_t = image_tensor.unsqueeze(0).to(DEVICE).float()
            torch.manual_seed(seed)
            noise = torch.randn(noise_shape, device=c_t.device)
            logging.debug("Calling Pix2Pix model...")
            # Enable mixed precision
            with autocast():
                output_image = pix2pix_model(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)
            logging.debug("Model inference completed.")
            return output_image
    except RuntimeError as e:
        if "CUDA out of memory" not in str(e):
            raise
        logging.warning("CUDA out of memory error. Falling back to CPU.")

    # Retry outside the ``except`` block: by now the exception's traceback,
    # which pins the failed attempt's frames and GPU tensors, has been released.
    cpu_noise = noise.cpu() if noise is not None else None
    c_t = noise = None
    clear_memory()
    if cpu_noise is None:
        # OOM struck before the noise was drawn; draw it on the CPU instead.
        torch.manual_seed(seed)
        cpu_noise = torch.randn(noise_shape)
    try:
        with torch.no_grad():
            # .cpu() moves the shared model itself (assumes Pix2Pix_Turbo is a
            # torch.nn.Module, whose .cpu() works in place).
            pix2pix_model_cpu = pix2pix_model.cpu()
            return pix2pix_model_cpu(
                image_tensor.unsqueeze(0).float(), prompt,
                deterministic=False, r=val_r, noise_map=cpu_noise,
            )
    finally:
        try:
            pix2pix_model.to(DEVICE)
        except RuntimeError:
            logging.exception(f"Could not move the Pix2Pix model back to {DEVICE}.")


def run(image, prompt, prompt_template, style_name, seed, val_r):
    """Run the sketch -> image pipeline and return the generated PIL image.

    :param image: Sketch as a PIL image, or None. When None, a blank white
        canvas is saved to ``SKETCH_PATH`` and None is returned, which clears
        the single image output.
    :param prompt: User prompt. When empty or None, a prompt is generated from
        the sketch with BLIP.
    :param prompt_template: Style template; every "{prompt}" occurrence is
        replaced by the prompt.
    :param style_name: Unused here; kept for interface compatibility.
    :param seed: Seed for the noise map.
    :param val_r: Sketch-guidance value passed to the model as ``r``.
    :return: The generated PIL image (also saved to ``OUTPUT_PATH``), or None.
    """
    logging.debug("Running model inference...")
    if image is None:
        blank_image = Image.new("L", (CANVAS_WIDTH, CANVAS_HEIGHT), 255)
        blank_image.save(SKETCH_PATH)  # Save blank image as sketch
        logging.debug("No image provided. Saving blank image.")
        return None
    if not prompt or not prompt.strip():
        prompt = generate_prompt_from_sketch(image)
    # Save the sketch to a file
    image.save(SKETCH_PATH)
    # Show the original prompt before processing
    original_prompt = f"Original Prompt: {prompt}"
    logging.debug(original_prompt)
    prompt = prompt_template.replace("{prompt}", prompt)
    logging.debug(f"Processing with prompt: {prompt}")
    image = _snap_to_multiple_of_8(image.convert("RGB"))
    image_tensor = normalize_image(image, (-1, 1))
    clear_memory()  # Clear memory before running the model
    output_image = _pix2pix_inference(image_tensor, prompt, seed, val_r)
    # Map [-1, 1] back to [0, 1]. Clamp so out-of-range values cannot wrap
    # when to_pil_image casts to uint8.
    output_pil = F.to_pil_image((output_image[0].float().cpu() * 0.5 + 0.5).clamp(0, 1))
    output_pil.save(OUTPUT_PATH)
    logging.debug("Output image saved.")
    return output_pil
def gradio_interface(image, prompt, style_name, seed, val_r):
    """Gradio callback: resolve the style's prompt template and run the pipeline.

    Arguments arrive in the order of the interface's input components:

    - ``image``: the sketch image.
    - ``prompt``: text prompt (optional).
    - ``style_name``: the selected dropdown style.
    - ``seed``: seed for reproducibility.
    - ``val_r``: sketch guidance value.

    Style names not present in ``STYLES`` fall back to ``DEFAULT_STYLE_NAME``.

    :return: Whatever ``run`` returns (the generated image, or None).
    """
    # Look the default up only when it is needed, so a misconfigured
    # DEFAULT_STYLE_NAME cannot break requests that name a valid style.
    if style_name in STYLES:
        prompt_template = STYLES[style_name]
    else:
        prompt_template = STYLES[DEFAULT_STYLE_NAME]
    result_image = run(image, prompt, prompt_template, style_name, seed, val_r)
    return result_image
# Create the Gradio Interface. The order of `inputs` must match the parameters
# of gradio_interface(image, prompt, style_name, seed, val_r).
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        # No `source=` argument: "upload" is already the Gradio 3 default, and
        # Gradio 4 replaced `source` with a `sources` list.
        gr.Image(type="pil", label="Sketch Image"),  # Endpoint: `image`
        gr.Textbox(lines=2, placeholder="Enter a text prompt (optional)", label="Prompt"),  # Endpoint: `prompt`
        gr.Dropdown(choices=list(STYLES.keys()), value=DEFAULT_STYLE_NAME, label="Style"),  # Endpoint: `style_name`
        # Initial slider positions go through `value=`; `default=` is the legacy
        # gr.inputs keyword and is not how gr.Slider sets its starting value.
        gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=DEFAULT_SEED, label="Seed"),  # Endpoint: `seed`
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=VAL_R_DEFAULT, label="Sketch Guidance"),  # Endpoint: `val_r`
    ],
    outputs=gr.Image(label="Generated Image"),  # Output endpoint: `result_image`
    title="Sketch to Image Generation",
    description="Upload a sketch and generate an image based on a prompt and style."
)
if __name__ == "__main__":
    # Start the Gradio server. share=True additionally requests a public share
    # link, so anyone holding that URL can reach the app.
    launch_options = {"share": True}
    interface.launch(**launch_options)