# img2img-turbo-sketch / flask_sketch2imagehd.py
import random
import numpy as np
from PIL import Image, ImageOps
import base64
from io import BytesIO
import torch
import torchvision.transforms.functional as F
from transformers import BlipProcessor, BlipForConditionalGeneration
from src.pix2pix_turbo import Pix2Pix_Turbo
import nltk
from nltk import pos_tag
from nltk.tokenize import word_tokenize
import re
import os
import threading
import hashlib
from flask import Flask, request, send_file, jsonify, render_template_string
from flask_cors import CORS
import signal
import sys
import logging
import json
import gc
from torch.cuda.amp import autocast
# Set environment variable for better memory management
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
# Function to clear CUDA cache and collect garbage
def clear_memory():
    torch.cuda.empty_cache()
    gc.collect()
# Load the configuration from config.json
with open('config.json', 'r') as config_file:
    config = json.load(config_file)
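# A minimal config.json sketch for reference (assumed shape: the keys mirror
# how this script reads the config; the values are only illustrative):
# {
#   "logging": {"level": "DEBUG", "format": "%(asctime)s %(levelname)s %(message)s"},
#   "file_paths": {"sketch_path": "sketch.png", "output_path": "output.png"},
#   "style_list": [{"name": "Cinematic", "prompt": "cinematic still {prompt} ..."}],
#   "default_style_name": "Fantasy art",
#   "random_values": ["sunny day", "night scene"],
#   "model_params": {"pix2pix_model_name": "sketch_to_image_stochastic",
#                    "device": "cuda", "default_seed": 42,
#                    "val_r_default": 0.4, "max_seed": 2147483647},
#   "canvas": {"width": 512, "height": 512, "background_color": "#FFFFFF",
#              "default_brush_color": "#000000", "default_brush_size": 4,
#              "eraser_color": "#FFFFFF", "max_brush_size": 20, "min_brush_size": 1},
#   "server": {"host": "0.0.0.0", "port": 5000}
# }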
# Setup logging as per config
logging.basicConfig(level=config["logging"]["level"], format=config["logging"]["format"])
# Ensure NLTK resources are downloaded
nltk.download('averaged_perceptron_tagger', quiet=True)
nltk.download('punkt', quiet=True)
# File paths for storing sketches and outputs
SKETCH_PATH = config["file_paths"]["sketch_path"]
OUTPUT_PATH = config["file_paths"]["output_path"]
# Processing queue
processing_queue = []
# Global Constants and Configuration
STYLE_LIST = config["style_list"]
STYLES = {style["name"]: style["prompt"] for style in STYLE_LIST}
DEFAULT_STYLE_NAME = config["default_style_name"]
RANDOM_VALUES = config["random_values"]
PIX2PIX_MODEL_NAME = config["model_params"]["pix2pix_model_name"]
DEVICE = config["model_params"]["device"]
DEFAULT_SEED = config["model_params"]["default_seed"]
VAL_R_DEFAULT = config["model_params"]["val_r_default"]
MAX_SEED = config["model_params"]["max_seed"]
# Canvas configuration
CANVAS_WIDTH = config["canvas"]["width"]
CANVAS_HEIGHT = config["canvas"]["height"]
BACKGROUND_COLOR = config["canvas"]["background_color"]
DEFAULT_BRUSH_COLOR = config["canvas"]["default_brush_color"]
DEFAULT_BRUSH_SIZE = config["canvas"]["default_brush_size"]
ERASER_COLOR = config["canvas"]["eraser_color"]
MAX_BRUSH_SIZE = config["canvas"]["max_brush_size"]
MIN_BRUSH_SIZE = config["canvas"]["min_brush_size"]
# Preload Models
logging.debug("Loading BLIP and Pix2Pix models...")
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(DEVICE).eval() # Set model to eval mode
pix2pix_model = Pix2Pix_Turbo(PIX2PIX_MODEL_NAME).to(DEVICE).eval() # Set model to eval mode
logging.debug("Models loaded.")
# Style names are derived from the config-driven STYLES mapping. (A hard-coded
# style_list previously duplicated config["style_list"] here and silently
# overwrote the DEFAULT_STYLE_NAME and MAX_SEED values loaded above.)
STYLE_NAMES = list(STYLES.keys())
# Shared flag and thread for managing the current processing
current_thread = None
cancel_flag = threading.Event()
def pil_image_to_data_uri(img: Image.Image, format: str = "PNG") -> str:
    """Converts a PIL image to a data URI."""
    buffered = BytesIO()
    img.save(buffered, format=format)
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/{format.lower()};base64,{img_str}"
def generate_prompt_from_sketch(image: Image.Image) -> str:
    """Generates a text prompt based on a sketch using the BLIP model."""
    logging.debug("Generating prompt from sketch...")
    image = ImageOps.fit(image, (CANVAS_WIDTH, CANVAS_HEIGHT), Image.LANCZOS)
    inputs = processor(image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        out = blip_model.generate(**inputs, max_new_tokens=50)
    text_prompt = processor.decode(out[0], skip_special_tokens=True)
    logging.debug(f"Generated prompt: {text_prompt}")
    recognized_items = [extract_main_words(item) for item in text_prompt.split(', ') if item.strip()]
    random_prefix = random.choice(RANDOM_VALUES)
    prompt = f"a photo of a {' and '.join(recognized_items)}, {random_prefix}"
    logging.debug(f"Final prompt: {prompt}")
    return prompt
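# Illustrative flow (caption values assumed): if BLIP captions the sketch as
# "a drawing of a house, a tree", the noun extraction yields "House" and
# "Tree", and the final prompt reads roughly
# "a photo of a House and Tree, <prefix>", where <prefix> is drawn from
# config["random_values"].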
def extract_main_words(item: str) -> str:
    """Extracts all nouns from a given text fragment and returns them as a space-separated string."""
    words = word_tokenize(item.strip())
    tagged = pos_tag(words)
    nouns = [word.capitalize() for word, tag in tagged if tag in ('NN', 'NNP', 'NNPS', 'NNS')]
    return ' '.join(nouns)
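# Example (assuming the standard NLTK POS tagger):
#   extract_main_words("a red apple on the table")
# tags "apple" and "table" as nouns and returns "Apple Table".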
def run(image, prompt, prompt_template, style_name, seed, val_r):
    """Runs the main image processing pipeline."""
    logging.debug("Running model inference...")
    if image is None:
        blank_image = Image.new("L", (CANVAS_WIDTH, CANVAS_HEIGHT), 255)
        blank_image.save(SKETCH_PATH)  # Save blank image as sketch
        logging.debug("No image provided. Saving blank image.")
        return "", "", "", ""
    if not prompt.strip():
        prompt = generate_prompt_from_sketch(image)
    # Save the sketch to a file
    image.save(SKETCH_PATH)
    # Show the original prompt before processing
    original_prompt = f"Original Prompt: {prompt}"
    logging.debug(original_prompt)
    # Add the task to the processing queue; keep a reference so its status
    # can be updated once processing finishes
    task = {"prompt": prompt, "status": "processing"}
    processing_queue.append(task)
    prompt = prompt_template.replace("{prompt}", prompt)
    logging.debug(f"Processing with prompt: {prompt}")
    image = image.convert("RGB")
    image_tensor = F.to_tensor(image) * 2 - 1  # Normalize to [-1, 1]
    clear_memory()  # Clear memory before running the model
    try:
        with torch.no_grad():
            c_t = image_tensor.unsqueeze(0).to(DEVICE).float()
            torch.manual_seed(seed)
            B, C, H, W = c_t.shape
            noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
            logging.debug("Calling Pix2Pix model...")
            # Enable mixed precision
            with autocast():
                if cancel_flag.is_set():
                    logging.debug("Processing canceled.")
                    task["status"] = "canceled"
                    return "", "", "", original_prompt
                output_image = pix2pix_model(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)
        logging.debug("Model inference completed.")
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            logging.warning("CUDA out of memory error. Falling back to CPU.")
            with torch.no_grad():
                c_t = c_t.cpu()
                noise = noise.cpu()
                # Move the model to CPU (it stays there for subsequent requests)
                pix2pix_model_cpu = pix2pix_model.cpu()
                output_image = pix2pix_model_cpu(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)
        else:
            raise
    output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
    output_pil.save(OUTPUT_PATH)
    logging.debug("Output image saved.")
    task["status"] = "done"
    input_sketch_uri = pil_image_to_data_uri(Image.fromarray(255 - np.array(image)))
    output_image_uri = pil_image_to_data_uri(output_pil)
    logging.debug(f"Generated output URI: {output_image_uri}")
    clear_memory()  # Clear memory after running the model
    return output_image_uri, input_sketch_uri, output_image_uri, original_prompt
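# Direct-call sketch (hypothetical inputs; assumes a 512x512 sketch on disk).
# An empty prompt triggers BLIP caption generation:
#   out_uri, sketch_uri, _, original = run(
#       Image.open("sketch.png"), "", STYLES[DEFAULT_STYLE_NAME],
#       DEFAULT_STYLE_NAME, DEFAULT_SEED, VAL_R_DEFAULT)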
def process_image_task(image, prompt, style_name, seed, val_r):
    """Background worker: runs the pipeline and logs the result.

    Note: this runs in a separate thread with no Flask request context, so it
    must not call jsonify(); clients poll /get_status or /get_output instead.
    """
    try:
        global cancel_flag
        cancel_flag.clear()  # Clear any previous cancellation flag
        # Fall back to the default style's prompt template (not its name) when
        # the requested style is unknown
        output_image_uri, _, _, _ = run(image, prompt, STYLES.get(style_name, STYLES[DEFAULT_STYLE_NAME]), style_name, seed, val_r)
        logging.debug(f"Processed image URI: {output_image_uri}")
    except Exception as e:
        logging.error(f"Error processing image: {e}")
# Flask Server Setup for Preview and JSON endpoint
app = Flask(__name__)
CORS(app) # Enable CORS
@app.route('/process-image', methods=['POST'])
def process_image():
    global current_thread, cancel_flag
    # Cancel any ongoing processing
    if current_thread is not None and current_thread.is_alive():
        logging.debug("Cancelling previous processing...")
        cancel_flag.set()
        current_thread.join()  # Wait for the thread to finish
    data = request.get_json()
    # Extract and decode the base64 image; tolerate payloads sent without a
    # "data:image/png;base64," prefix
    image_data = data.get("image", "")
    if "," in image_data:
        image_data = image_data.split(",", 1)[1]
    image = Image.open(BytesIO(base64.b64decode(image_data))).convert("RGB")
    prompt = data.get("prompt", "")
    style_name = data.get("style_name", DEFAULT_STYLE_NAME)
    seed = int(data.get("seed", DEFAULT_SEED))
    val_r = float(data.get("val_r", VAL_R_DEFAULT))
    # Start new processing in a separate thread
    current_thread = threading.Thread(target=process_image_task, args=(image, prompt, style_name, seed, val_r))
    current_thread.start()
    return jsonify({"status": "processing_started"})
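# Example request (illustrative; <port> is config["server"]["port"] and the
# base64 payload is truncated):
#   curl -X POST http://localhost:<port>/process-image \
#     -H "Content-Type: application/json" \
#     -d '{"image": "data:image/png;base64,iVBO...", "prompt": "", "seed": 42, "val_r": 0.4}'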
@app.route('/get_sketch', methods=['GET'])
def get_sketch():
    if os.path.exists(SKETCH_PATH):
        return send_file(SKETCH_PATH, mimetype='image/png')
    return jsonify({"status": "error", "message": "Sketch not found."}), 404
@app.route('/get_output', methods=['GET'])
def get_output():
    if os.path.exists(OUTPUT_PATH):
        return send_file(OUTPUT_PATH, mimetype='image/png')
    return jsonify({"status": "error", "message": "Output not found."}), 404
@app.route('/get_status', methods=['GET'])
def get_status():
"""Returns a JSON with the last image base64 encoded, its checksum, and the processing queue."""
if os.path.exists(OUTPUT_PATH):
with open(OUTPUT_PATH, "rb") as f:
img_data = f.read()
base64_image = base64.b64encode(img_data).decode('utf-8')
checksum = hashlib.sha256(img_data).hexdigest()
else:
base64_image = ""
checksum = ""
return jsonify({
"image_base64": base64_image,
"checksum": checksum,
"processing_queue": processing_queue
})
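# Example response shape (field values illustrative):
# {
#   "image_base64": "iVBORw0KGgo...",
#   "checksum": "9f86d081884c7d65...",
#   "processing_queue": [{"prompt": "a photo of a House", "status": "done"}]
# }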
@app.route('/')
def index():
    # HTML template for the preview page
    html_template = """
    <!doctype html>
    <html lang="en">
    <head>
        <meta charset="utf-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>Preview Page</title>
        <style>
            body, html {
                margin: 0;
                padding: 0;
                height: 100%;
                background-color: black;
            }
            .full-screen-image {
                width: 100%;
                height: 100%;
                object-fit: contain;
            }
        </style>
        <script>
            function refreshImage() {
                var img = document.getElementById("output-image");
                img.src = "/get_output?" + new Date().getTime();
            }
            // Auto-refresh every 2 seconds to show the latest image
            setInterval(refreshImage, 2000);
        </script>
    </head>
    <body>
        <img id="output-image" src="/get_output" class="full-screen-image">
    </body>
    </html>
    """
    return render_template_string(html_template)
@app.route('/draw')
def draw_page():
    # HTML template for the drawing page at /draw
    html_template = """
    <!doctype html>
    <html lang="en">
    <head>
        <meta charset="utf-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>Drawing Page</title>
        <style>
            body, html {
                margin: 0;
                padding: 0;
                height: 100%;
                display: flex;
                justify-content: center;
                align-items: center;
                background-color: #f0f0f0;
            }
            .canvas-container {
                border: none;
                position: relative;
            }
            .toolbar {
                display: flex;
                justify-content: center;
                margin-bottom: 10px;
            }
            button {
                margin-right: 5px;
            }
            canvas {
                cursor: crosshair;
            }
        </style>
    </head>
    <body>
        <div style="position: fixed; bottom: 0; width: 100%;">
            <div class="toolbar">
                <button id="brush" onclick="setTool('brush')">Brush</button>
                <!-- The Line button currently draws like the brush; no straight-line logic is implemented -->
                <button id="line" onclick="setTool('line')">Line</button>
                <button id="eraser" onclick="setTool('eraser')">Eraser</button>
                <button id="clear" onclick="clearCanvas()">Clear</button>
                <input type="color" id="colorPicker" value="#000000">
                <input type="range" id="brushSize" min="1" max="20" value="4">
            </div>
        </div>
        <div class="canvas-container">
            <canvas id="drawingCanvas" width="512" height="512"></canvas>
        </div>
        <script>
            let canvas = document.getElementById('drawingCanvas');
            let ctx = canvas.getContext('2d');
            let drawing = false;
            let tool = 'brush';
            let lastX = 0, lastY = 0;
            // Fill the canvas with white background
            ctx.fillStyle = "#ffffff";
            ctx.fillRect(0, 0, canvas.width, canvas.height);
            canvas.addEventListener('mousedown', (e) => {
                drawing = true;
                [lastX, lastY] = [e.offsetX, e.offsetY];
            });
            canvas.addEventListener('mousemove', draw);
            canvas.addEventListener('mouseup', () => {
                drawing = false;
                sendDrawingToBackend();
            });
            canvas.addEventListener('mouseout', () => drawing = false);
            function draw(e) {
                if (!drawing) return;
                // Pick the stroke color per stroke: the eraser paints with the
                // white background color, everything else uses the color
                // picker. (Setting strokeStyle here is what makes the eraser
                // actually take effect; setting it only in setTool() would be
                // overwritten on the next brush stroke.)
                ctx.strokeStyle = (tool === 'eraser') ? "#ffffff" : document.getElementById('colorPicker').value;
                ctx.lineWidth = document.getElementById('brushSize').value;
                ctx.lineJoin = 'round';
                ctx.lineCap = 'round';
                ctx.beginPath();
                ctx.moveTo(lastX, lastY);
                ctx.lineTo(e.offsetX, e.offsetY);
                ctx.stroke();
                [lastX, lastY] = [e.offsetX, e.offsetY];
            }
            function setTool(selectedTool) {
                tool = selectedTool;
                ctx.globalCompositeOperation = 'source-over';
            }
            function clearCanvas() {
                ctx.fillStyle = "#ffffff";
                ctx.fillRect(0, 0, canvas.width, canvas.height);
                fetch('/clear_preview', { method: 'POST' })
                    .then(response => response.json())
                    .then(data => console.log('Cleared preview', data))
                    .catch(error => console.error('Error clearing preview:', error));
            }
            function sendDrawingToBackend() {
                let dataURL = canvas.toDataURL('image/png');
                fetch('/process-image', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json',
                    },
                    body: JSON.stringify({ image: dataURL }),
                })
                .then(response => response.json())
                .then(data => console.log('Image processed', data))
                .catch(error => console.error('Error processing image:', error));
            }
        </script>
    </body>
    </html>
    """
    return render_template_string(html_template)
@app.route('/clear_preview', methods=['POST'])
def clear_preview():
    if os.path.exists(OUTPUT_PATH):
        os.remove(OUTPUT_PATH)
    return jsonify({"status": "cleared"})
def start_flask_app():
    app.run(host=config["server"]["host"], port=config["server"]["port"], threaded=True)
def signal_handler(sig, frame):
print("Ctrl+C pressed, shutting down.")
sys.exit(0)
# Register the signal handler for Ctrl+C
signal.signal(signal.SIGINT, signal_handler)
if __name__ == "__main__":
    start_flask_app()