import random import numpy as np from PIL import Image, ImageOps import base64 from io import BytesIO import torch import torchvision.transforms.functional as F from transformers import BlipProcessor, BlipForConditionalGeneration from src.pix2pix_turbo import Pix2Pix_Turbo import nltk from nltk import pos_tag from nltk.tokenize import word_tokenize import re import os import threading import hashlib from flask import Flask, request, send_file, jsonify, render_template_string from flask_cors import CORS import signal import sys import logging import json import gc from torch.cuda.amp import autocast # Set environment variable for better memory management os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128' # Function to clear CUDA cache and collect garbage def clear_memory(): torch.cuda.empty_cache() gc.collect() # Load the configuration from config.json with open('config.json', 'r') as config_file: config = json.load(config_file) # Setup logging as per config logging.basicConfig(level=config["logging"]["level"], format=config["logging"]["format"]) # Ensure NLTK resources are downloaded nltk.download('averaged_perceptron_tagger', quiet=True) nltk.download('punkt', quiet=True) # File paths for storing sketches and outputs SKETCH_PATH = config["file_paths"]["sketch_path"] OUTPUT_PATH = config["file_paths"]["output_path"] # Processing queue processing_queue = [] # Global Constants and Configuration STYLE_LIST = config["style_list"] STYLES = {style["name"]: style["prompt"] for style in STYLE_LIST} DEFAULT_STYLE_NAME = config["default_style_name"] RANDOM_VALUES = config["random_values"] PIX2PIX_MODEL_NAME = config["model_params"]["pix2pix_model_name"] DEVICE = config["model_params"]["device"] DEFAULT_SEED = config["model_params"]["default_seed"] VAL_R_DEFAULT = config["model_params"]["val_r_default"] MAX_SEED = config["model_params"]["max_seed"] # Canvas configuration CANVAS_WIDTH = config["canvas"]["width"] CANVAS_HEIGHT = config["canvas"]["height"] BACKGROUND_COLOR = config["canvas"]["background_color"] DEFAULT_BRUSH_COLOR = config["canvas"]["default_brush_color"] DEFAULT_BRUSH_SIZE = config["canvas"]["default_brush_size"] ERASER_COLOR = config["canvas"]["eraser_color"] MAX_BRUSH_SIZE = config["canvas"]["max_brush_size"] MIN_BRUSH_SIZE = config["canvas"]["min_brush_size"] # Preload Models logging.debug("Loading BLIP and Pix2Pix models...") processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base") blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(DEVICE).eval() # Set model to eval mode pix2pix_model = Pix2Pix_Turbo(PIX2PIX_MODEL_NAME).to(DEVICE).eval() # Set model to eval mode logging.debug("Models loaded.") style_list = [ { "name": "Cinematic", "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy", }, # Other styles... ] styles = {k["name"]: k["prompt"] for k in style_list} STYLE_NAMES = list(styles.keys()) DEFAULT_STYLE_NAME = "Fantasy art" MAX_SEED = np.iinfo(np.int32).max # Shared flag and thread for managing the current processing current_thread = None cancel_flag = threading.Event() def pil_image_to_data_uri(img: Image, format="PNG") -> str: """Converts a PIL image to a data URI.""" buffered = BytesIO() img.save(buffered, format=format) img_str = base64.b64encode(buffered.getvalue()).decode() return f"data:image/{format.lower()};base64,{img_str}" def generate_prompt_from_sketch(image: Image) -> str: """Generates a text prompt based on a sketch using the BLIP model.""" logging.debug("Generating prompt from sketch...") image = ImageOps.fit(image, (CANVAS_WIDTH, CANVAS_HEIGHT), Image.LANCZOS) inputs = processor(image, return_tensors="pt").to(DEVICE) with torch.no_grad(): out = blip_model.generate(**inputs, max_new_tokens=50) text_prompt = processor.decode(out[0], skip_special_tokens=True) logging.debug(f"Generated prompt: {text_prompt}") recognized_items = [extract_main_words(item) for item in text_prompt.split(', ') if item.strip()] random_prefix = random.choice(RANDOM_VALUES) prompt = f"a photo of a {' and '.join(recognized_items)}, {random_prefix}" logging.debug(f"Final prompt: {prompt}") return prompt def extract_main_words(item: str) -> str: """Extracts all nouns from a given text fragment and returns them as a space-separated string.""" words = word_tokenize(item.strip()) tagged = pos_tag(words) nouns = [word.capitalize() for word, tag in tagged if tag in ('NN', 'NNP', 'NNPS', 'NNS')] return ' '.join(nouns) def run(image, prompt, prompt_template, style_name, seed, val_r): """Runs the main image processing pipeline.""" logging.debug("Running model inference...") if image is None: blank_image = Image.new("L", (CANVAS_WIDTH, CANVAS_HEIGHT), 255) blank_image.save(SKETCH_PATH) # Save blank image as sketch logging.debug("No image provided. Saving blank image.") return "", "", "", "" if not prompt.strip(): prompt = generate_prompt_from_sketch(image) # Save the sketch to a file image.save(SKETCH_PATH) # Show the original prompt before processing original_prompt = f"Original Prompt: {prompt}" logging.debug(original_prompt) # Add the task to the processing queue processing_queue.append({"prompt": prompt, "status": "processing"}) prompt = prompt_template.replace("{prompt}", prompt) logging.debug(f"Processing with prompt: {prompt}") image = image.convert("RGB") image_tensor = F.to_tensor(image) * 2 - 1 # Normalize to [-1, 1] clear_memory() # Clear memory before running the model try: with torch.no_grad(): c_t = image_tensor.unsqueeze(0).to(DEVICE).float() torch.manual_seed(seed) B, C, H, W = c_t.shape noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device) logging.debug("Calling Pix2Pix model...") # Enable mixed precision with autocast(): if cancel_flag.is_set(): logging.debug("Processing canceled.") return "", "", "", original_prompt output_image = pix2pix_model(c_t, prompt, deterministic=False, r=val_r, noise_map=noise) logging.debug("Model inference completed.") except RuntimeError as e: if "CUDA out of memory" in str(e): logging.warning("CUDA out of memory error. Falling back to CPU.") with torch.no_grad(): c_t = c_t.cpu() noise = noise.cpu() pix2pix_model_cpu = pix2pix_model.cpu() # Move the model to CPU output_image = pix2pix_model_cpu(c_t, prompt, deterministic=False, r=val_r, noise_map=noise) else: raise e output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5) output_pil.save(OUTPUT_PATH) logging.debug("Output image saved.") input_sketch_uri = pil_image_to_data_uri(Image.fromarray(255 - np.array(image))) output_image_uri = pil_image_to_data_uri(output_pil) logging.debug(f"Generated output URI: {output_image_uri}") clear_memory() # Clear memory after running the model return output_image_uri, input_sketch_uri, output_image_uri, original_prompt def process_image_task(image, prompt, style_name, seed, val_r): try: global cancel_flag cancel_flag.clear() # Clear any previous cancellation flag output_image_uri, _, _, _ = run(image, prompt, STYLES.get(style_name, DEFAULT_STYLE_NAME), style_name, seed, val_r) logging.debug(f"Processed image URI: {output_image_uri}") return jsonify({"image": output_image_uri}) except Exception as e: logging.error(f"Error processing image: {e}") return jsonify({"error": str(e)}), 500 # Flask Server Setup for Preview and JSON endpoint app = Flask(__name__) CORS(app) # Enable CORS @app.route('/process-image', methods=['POST']) def process_image(): global current_thread, cancel_flag # Cancel any ongoing processing if current_thread is not None and current_thread.is_alive(): logging.debug("Cancelling previous processing...") cancel_flag.set() current_thread.join() # Wait for the thread to finish data = request.get_json() # Extract and decode the base64 image image_data = data.get("image", "").split(",")[1] image = Image.open(BytesIO(base64.b64decode(image_data))).convert("RGB") prompt = data.get("prompt", "") style_name = data.get("style_name", DEFAULT_STYLE_NAME) seed = int(data.get("seed", DEFAULT_SEED)) val_r = float(data.get("val_r", VAL_R_DEFAULT)) # Start new processing in a separate thread current_thread = threading.Thread(target=process_image_task, args=(image, prompt, style_name, seed, val_r)) current_thread.start() return jsonify({"status": "processing_started"}) @app.route('/get_sketch', methods=['GET']) def get_sketch(): if os.path.exists(SKETCH_PATH): return send_file(SKETCH_PATH, mimetype='image/png') return jsonify({"status": "error", "message": "Sketch not found."}), 404 @app.route('/get_output', methods=['GET']) def get_output(): if os.path.exists(OUTPUT_PATH): return send_file(OUTPUT_PATH, mimetype='image/png') return jsonify({"status": "error", "message": "Output not found."}), 404 @app.route('/get_status', methods=['GET']) def get_status(): """Returns a JSON with the last image base64 encoded, its checksum, and the processing queue.""" if os.path.exists(OUTPUT_PATH): with open(OUTPUT_PATH, "rb") as f: img_data = f.read() base64_image = base64.b64encode(img_data).decode('utf-8') checksum = hashlib.sha256(img_data).hexdigest() else: base64_image = "" checksum = "" return jsonify({ "image_base64": base64_image, "checksum": checksum, "processing_queue": processing_queue }) @app.route('/') def index(): # HTML template for the preview page html_template = """ Preview Page """ return render_template_string(html_template) @app.route('/draw') def draw_page(): # HTML template for the drawing page at /draw html_template = """ Drawing Page
""" return render_template_string(html_template) @app.route('/clear_preview', methods=['POST']) def clear_preview(): if os.path.exists(OUTPUT_PATH): os.remove(OUTPUT_PATH) return jsonify({"status": "cleared"}) def start_flask_app(): app.run(host=config["server"]["host"], port=config["server"]["port"], threaded=True) def signal_handler(sig, frame): print("Ctrl+C pressed, shutting down.") sys.exit(0) # Register the signal handler for Ctrl+C signal.signal(signal.SIGINT, signal_handler) if __name__ == "__main__": start_flask_app()