import random import numpy as np from PIL import Image, ImageOps import base64 from io import BytesIO import torch import torchvision.transforms.functional as F import gradio as gr from transformers import BlipProcessor, BlipForConditionalGeneration from flask import Flask, request, jsonify, render_template_string, send_file from flask_cors import CORS import threading import hashlib import signal import sys import os # Load models processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base") blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to("cuda") # Pix2Pix model placeholder (Assume you have this model correctly implemented) class Pix2Pix_Turbo: def __init__(self, mode): pass def __call__(self, c_t, prompt, deterministic, r, noise_map): # Dummy image processing function for demonstration purposes return c_t pix2pix_model = Pix2Pix_Turbo("sketch_to_image_stochastic") # Flask application setup app = Flask(__name__) CORS(app) # Handle CORS issues # Global Constants and Configuration STYLE_LIST = [ {"name": "Cinematic", "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy"}, {"name": "3D Model", "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting"}, {"name": "Anime", "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed"}, {"name": "Digital Art", "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed"}, {"name": "Photographic", "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed"}, {"name": "Pixel art", "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics"}, {"name": "Fantasy art", "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy"}, {"name": "Neonpunk", "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional"}, {"name": "Manga", "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style"}, ] STYLES = {style["name"]: style["prompt"] for style in STYLE_LIST} STYLE_NAMES = list(STYLES.keys()) DEFAULT_STYLE_NAME = "Fantasy art" MAX_SEED = np.iinfo(np.int32).max # Paths for storing sketches and outputs SKETCH_PATH = "sketch.png" OUTPUT_PATH = "output.png" # Image processing function def run(image, prompt, prompt_template, style_name, seed, val_r): if not prompt.strip(): prompt = "Generated by drawing tool" prompt = prompt_template.replace("{prompt}", prompt) image = image.convert("RGB") image_tensor = F.to_tensor(image) > 0.5 with torch.no_grad(): c_t = image_tensor.unsqueeze(0).to("cuda").float() torch.manual_seed(seed) noise = torch.randn((1, 4, c_t.shape[2] // 8, c_t.shape[3] // 8), device=c_t.device) output_image = pix2pix_model(c_t, prompt, deterministic=False, r=val_r, noise_map=noise) output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5) output_pil.save(OUTPUT_PATH) # Save the output image buffered = BytesIO() output_pil.save(buffered, format="PNG") output_data = base64.b64encode(buffered.getvalue()).decode("utf-8") return output_data # Flask route to handle image processing @app.route('/process-image', methods=['POST']) def process_image(): try: data = request.get_json() image_data = data.get("image", "").split(",")[1] image = Image.open(BytesIO(base64.b64decode(image_data))).convert("RGB") # Process the image output_image_uri = run( image, data.get("prompt", ""), STYLES.get(data.get("style_name", DEFAULT_STYLE_NAME)), data.get("style_name", DEFAULT_STYLE_NAME), int(data.get("seed", random.randint(0, MAX_SEED))), float(data.get("val_r", 0.4)) ) return jsonify({"image": output_image_uri}) except Exception as e: return jsonify({"error": str(e)}), 500 # Flask route to serve the sketch image @app.route('/get_sketch', methods=['GET']) def get_sketch(): if os.path.exists(SKETCH_PATH): return send_file(SKETCH_PATH, mimetype='image/png') return jsonify({"status": "error", "message": "Sketch not found."}), 404 # Flask route to serve the output image @app.route('/get_output', methods=['GET']) def get_output(): if os.path.exists(OUTPUT_PATH): return send_file(OUTPUT_PATH, mimetype='image/png') return jsonify({"status": "error", "message": "Output not found."}), 404 # HTML page for drawing @app.route('/') def draw_page(): html_template = """