# img2img-turbo / draw.py
# Author: Inmental. Commit 343e5a8 (verified): "Upload folder using huggingface_hub".
import random
import numpy as np
from PIL import Image, ImageOps
import base64
from io import BytesIO
import torch
import torchvision.transforms.functional as F
import gradio as gr
from transformers import BlipProcessor, BlipForConditionalGeneration
from flask import Flask, request, jsonify, render_template_string, send_file
from flask_cors import CORS
import threading
import hashlib
import signal
import sys
import os
# Load models
# BLIP image-captioning checkpoint, fetched via transformers at import time.
# The captioning model is moved to CUDA, so importing this module needs a GPU.
# NOTE(review): neither `processor` nor `blip_model` is referenced elsewhere
# in this file; confirm they are still needed before paying the load cost.
_BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"
processor = BlipProcessor.from_pretrained(_BLIP_CHECKPOINT)
blip_model = BlipForConditionalGeneration.from_pretrained(_BLIP_CHECKPOINT).to("cuda")
# Placeholder for the pix2pix-turbo model; swap in the real implementation.
class Pix2Pix_Turbo:
    """Stand-in for the sketch-to-image model so the server runs end to end.

    ``run()`` converts the model output to pixels with ``x * 0.5 + 0.5``, so
    it expects outputs in [-1, 1]. The conditioning tensor ``run()`` passes in
    holds only 0.0 and 1.0. This placeholder therefore rescales that tensor
    from [0, 1] to [-1, 1]. The round trip through ``run()`` then reproduces
    the thresholded sketch exactly, instead of a washed-out copy.
    """

    def __init__(self, mode):
        # Stored for introspection only; the placeholder does not branch on it.
        self.mode = mode

    def __call__(self, c_t, prompt, deterministic, r, noise_map):
        """Return ``c_t`` mapped from [0, 1] to [-1, 1].

        ``prompt``, ``deterministic``, ``r`` and ``noise_map`` mirror the real
        model's call signature and are ignored here.
        """
        return c_t * 2.0 - 1.0
# Single model instance shared by all request handlers.
_PIX2PIX_MODE = "sketch_to_image_stochastic"
pix2pix_model = Pix2Pix_Turbo(_PIX2PIX_MODE)

# Flask application setup.
app = Flask(__name__)
# Enable cross-origin requests using flask_cors's default (permissive) settings.
# NOTE(review): the server binds 0.0.0.0, so any web page can POST to
# /process-image; consider restricting `origins` if exposed beyond localhost.
CORS(app)
# Global Constants and Configuration
# Prompt templates per style; run() substitutes the user prompt for "{prompt}".
STYLE_LIST = [
    {"name": "Cinematic", "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy"},
    {"name": "3D Model", "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting"},
    {"name": "Anime", "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed"},
    {"name": "Digital Art", "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed"},
    {"name": "Photographic", "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed"},
    {"name": "Pixel art", "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics"},
    {"name": "Fantasy art", "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy"},
    {"name": "Neonpunk", "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional"},
    {"name": "Manga", "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style"},
]
# Style name -> prompt template, plus the names in declaration order.
STYLES = {style["name"]: style["prompt"] for style in STYLE_LIST}
STYLE_NAMES = list(STYLES.keys())
DEFAULT_STYLE_NAME = "Fantasy art"
# Largest int32; upper bound for randomly drawn seeds.
MAX_SEED = np.iinfo(np.int32).max
# Paths for storing sketches and outputs.
# Absolute so that every consumer agrees on the location:
# - os.path.exists() and Image.save() resolve relative paths against the CWD;
# - Flask's send_file may resolve them against the application root instead.
# abspath() pins them to the CWD at import time, which is where the files
# were already being written.
SKETCH_PATH = os.path.abspath("sketch.png")
OUTPUT_PATH = os.path.abspath("output.png")
# Image processing function
def run(image, prompt, prompt_template, style_name, seed, val_r):
    """Generate an image from a sketch and return it as a base64-encoded PNG.

    Args:
        image: PIL sketch image. It is converted to RGB and thresholded at 0.5
            per channel: values above mid-grey become 1 in the conditioning
            tensor, everything else 0.
        prompt: User prompt. ``None`` or whitespace-only falls back to
            "Generated by drawing tool".
        prompt_template: Style template whose ``{prompt}`` placeholder is
            replaced by the prompt.
        style_name: Selected style name. Not used here; accepted so callers
            can pass the full request context.
        seed: Seed for torch's global RNG, set right before sampling the
            noise map.
        val_r: Forwarded to the model as ``r``.

    Returns:
        The output image as a base64 string of PNG bytes (no ``data:``
        prefix). The same PNG is also written to OUTPUT_PATH, which the
        /get_output route serves.
    """
    if prompt is None or not prompt.strip():
        prompt = "Generated by drawing tool"
    prompt = prompt_template.replace("{prompt}", prompt)

    # Binary conditioning map: True where a channel value exceeds 0.5.
    image_tensor = F.to_tensor(image.convert("RGB")) > 0.5
    with torch.no_grad():
        c_t = image_tensor.unsqueeze(0).to("cuda").float()
        torch.manual_seed(seed)
        # 4-channel noise at 1/8 of the input resolution. Presumably this is
        # the model's latent size; TODO confirm against the real model.
        _, _, height, width = c_t.shape
        noise = torch.randn((1, 4, height // 8, width // 8), device=c_t.device)
        output_image = pix2pix_model(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)

    # `* 0.5 + 0.5` assumes model output in [-1, 1].
    # Clamp because to_pil_image scales floats by 255 and casts to uint8
    # without saturating, so out-of-range values would wrap instead of clip.
    output_pil = F.to_pil_image((output_image[0].cpu() * 0.5 + 0.5).clamp(0.0, 1.0))

    # Encode once; reuse the bytes for both the file and the response.
    buffered = BytesIO()
    output_pil.save(buffered, format="PNG")
    png_bytes = buffered.getvalue()

    # The preview page polls /get_output while new results are being written.
    # Write to a per-thread temp file, then os.replace() it into place, so
    # readers see either the previous or the new PNG, never a partial one.
    tmp_path = f"{OUTPUT_PATH}.{threading.get_ident()}.tmp"
    with open(tmp_path, "wb") as fh:
        fh.write(png_bytes)
    os.replace(tmp_path, OUTPUT_PATH)

    return base64.b64encode(png_bytes).decode("utf-8")
def _decode_image_payload(payload):
    """Decode an image sent as a data URL or a bare base64 string.

    Args:
        payload: e.g. ``"data:image/png;base64,iVBOR..."`` (the format
            ``canvas.toDataURL`` produces), or just the base64 part.

    Returns:
        The decoded PIL image, fully loaded.

    Raises:
        ValueError: ``payload`` is missing, empty or not a string, or is not
            valid base64 (``binascii.Error`` subclasses ``ValueError``).
        OSError: the bytes are not an image PIL can read
            (``PIL.UnidentifiedImageError`` subclasses ``OSError``).
    """
    if not isinstance(payload, str) or not payload:
        raise ValueError("'image' must be a non-empty data URL or base64 string.")
    # Everything after the last comma is the base64 body. Without a comma,
    # the whole string is treated as base64.
    _, _, encoded = payload.rpartition(",")
    image = Image.open(BytesIO(base64.b64decode(encoded)))
    image.load()  # decode now so corrupt data is reported as a client error
    return image


def _flatten_sketch(sketch):
    """Return ``sketch`` as RGB, compositing any transparency onto white.

    The drawing page never fills a background: it paints strokes on a
    transparent canvas and Clear uses clearRect. A plain ``convert("RGB")``
    would therefore turn the background black, the same colour as the default
    #000000 stroke, and the sketch would vanish. Compositing onto white keeps
    what the user saw. Opaque images are unaffected.

    NOTE(review): run() marks bright pixels as 1. If the real model expects
    strokes as 1 (white-on-black), invert the result here. TODO confirm.
    """
    if sketch.mode in ("RGBA", "LA") or "transparency" in sketch.info:
        rgba = sketch.convert("RGBA")
        background = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        return Image.alpha_composite(background, rgba).convert("RGB")
    return sketch.convert("RGB")


def _save_latest_sketch(image):
    """Best-effort write of ``image`` as PNG to SKETCH_PATH for /get_sketch.

    The image goes to a per-thread temp file first and is then moved into
    place with os.replace(), so a concurrent reader never sees a partial
    file. Failures are logged and ignored, because saving the sketch must not
    block image generation.
    """
    tmp_path = f"{SKETCH_PATH}.{threading.get_ident()}.tmp"
    try:
        image.save(tmp_path, format="PNG")
        os.replace(tmp_path, SKETCH_PATH)
    except OSError:
        app.logger.warning("Could not save sketch to %s", SKETCH_PATH, exc_info=True)


# Flask route to handle image processing
@app.route('/process-image', methods=['POST'])
def process_image():
    """Decode a posted sketch, run the model on it and return the result.

    JSON body (only ``image`` is required):
        image: data URL or bare base64 of the sketch image.
        prompt: text prompt (string, default "").
        style_name: a key of STYLES (default DEFAULT_STYLE_NAME).
        seed: integer seed (default: random in [0, MAX_SEED]).
        val_r: float forwarded to run() (default 0.4).

    Returns:
        - 200 ``{"image": <base64 PNG>}`` on success.
        - 400 ``{"error": ...}`` for a malformed request.
        - 500 ``{"error": ...}`` if processing fails; the traceback is logged.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    # Validate everything up front so bad input is a 400, not a 500.
    try:
        sketch = _decode_image_payload(data.get("image"))
        prompt = data.get("prompt") or ""
        if not isinstance(prompt, str):
            raise TypeError("'prompt' must be a string.")
        style_name = data.get("style_name") or DEFAULT_STYLE_NAME
        if style_name not in STYLES:
            raise ValueError(f"Unknown style_name {style_name!r}; expected one of {STYLE_NAMES}.")
        raw_seed = data.get("seed")
        # Only draw a random seed when the client did not supply one.
        seed = random.randint(0, MAX_SEED) if raw_seed is None else int(raw_seed)
        val_r = float(data.get("val_r", 0.4))
    except (TypeError, ValueError, OSError, Image.DecompressionBombError) as exc:
        return jsonify({"error": f"Invalid request: {exc}"}), 400

    try:
        model_input = _flatten_sketch(sketch)
        _save_latest_sketch(model_input)
        output_b64 = run(model_input, prompt, STYLES[style_name], style_name, seed, val_r)
    except Exception as exc:  # request boundary: report the failure instead of crashing
        app.logger.exception("Image processing failed")
        return jsonify({"error": str(exc)}), 500
    return jsonify({"image": output_b64})
# Flask route to serve the sketch image
@app.route('/get_sketch', methods=['GET'])
def get_sketch():
    """Serve the PNG at SKETCH_PATH, or a 404 JSON error if the file is absent."""
    if not os.path.exists(SKETCH_PATH):
        return jsonify({"status": "error", "message": "Sketch not found."}), 404
    return send_file(SKETCH_PATH, mimetype='image/png')
# Flask route to serve the output image
@app.route('/get_output', methods=['GET'])
def get_output():
    """Serve the latest generated PNG (OUTPUT_PATH).

    Returns a 404 JSON error when no output has been written yet. The
    preview page polls this route.
    """
    if not os.path.exists(OUTPUT_PATH):
        return jsonify({"status": "error", "message": "Output not found."}), 404
    return send_file(OUTPUT_PATH, mimetype='image/png')
# HTML page for drawing
@app.route('/')
def draw_page():
    """Serve the drawing UI.

    Tools:
        - Brush: freehand.
        - Line: straight segment from the press point to the release point.
        - Eraser: 'destination-out' compositing.
        - Clear: wipes the canvas.

    Each finished stroke (mouse released, or pointer leaving the canvas) and
    each Clear posts the canvas to /process-image as a PNG data URL. The
    canvas is never filled, so undrawn areas are transparent in that PNG.
    Responses are only logged to the browser console.
    """
    html_template = """
<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Drawing Page</title>
<style>
body, html {
    margin: 0;
    padding: 0;
    height: 100%;
    display: flex;
    /* Stack the toolbar above the canvas (the toolbar's margin-bottom assumes this). */
    flex-direction: column;
    justify-content: center;
    align-items: center;
    background-color: #f0f0f0;
}
.canvas-container {
    border: 2px solid black;
    position: relative;
}
.toolbar {
    display: flex;
    justify-content: center;
    margin-bottom: 10px;
}
button {
    margin-right: 5px;
}
canvas {
    cursor: crosshair;
}
</style>
</head>
<body>
<div class="toolbar">
    <button id="brush" onclick="setTool('brush')">Brush</button>
    <button id="line" onclick="setTool('line')">Line</button>
    <button id="eraser" onclick="setTool('eraser')">Eraser</button>
    <button id="clear" onclick="clearCanvas()">Clear</button>
    <input type="color" id="colorPicker" value="#000000">
    <input type="range" id="brushSize" min="1" max="20" value="4">
</div>
<div class="canvas-container">
    <canvas id="drawingCanvas" width="800" height="600"></canvas>
</div>
<script>
const canvas = document.getElementById('drawingCanvas');
const ctx = canvas.getContext('2d');
let drawing = false;
let tool = 'brush';
let lastX = 0, lastY = 0;    // previous pointer position (freehand tools)
let startX = 0, startY = 0;  // press position (line tool)
let snapshot = null;         // canvas pixels at press time (line tool)

canvas.addEventListener('mousedown', (e) => {
    drawing = true;
    [lastX, lastY] = [e.offsetX, e.offsetY];
    [startX, startY] = [e.offsetX, e.offsetY];
    if (tool === 'line') {
        // Remember the canvas so the rubber-band preview can be redrawn.
        snapshot = ctx.getImageData(0, 0, canvas.width, canvas.height);
    }
});
canvas.addEventListener('mousemove', draw);
canvas.addEventListener('mouseup', finishStroke);
// Leaving the canvas also ends the stroke; send it so it is not lost.
canvas.addEventListener('mouseout', finishStroke);

function applyStrokeStyle() {
    ctx.strokeStyle = document.getElementById('colorPicker').value;
    ctx.lineWidth = document.getElementById('brushSize').value;
    ctx.lineJoin = 'round';
    ctx.lineCap = 'round';
}

function draw(e) {
    if (!drawing) return;
    applyStrokeStyle();
    ctx.beginPath();
    if (tool === 'line' && snapshot) {
        // Undo the previous preview segment, then draw from the press point.
        ctx.putImageData(snapshot, 0, 0);
        ctx.moveTo(startX, startY);
    } else {
        ctx.moveTo(lastX, lastY);
    }
    ctx.lineTo(e.offsetX, e.offsetY);
    ctx.stroke();
    [lastX, lastY] = [e.offsetX, e.offsetY];
}

function finishStroke() {
    if (!drawing) return;
    drawing = false;
    snapshot = null;
    sendDrawingToBackend();
}

function setTool(selectedTool) {
    tool = selectedTool;
    ctx.globalCompositeOperation = (tool === 'eraser') ? 'destination-out' : 'source-over';
}

function clearCanvas() {
    ctx.clearRect(0, 0, canvas.width, canvas.height);
    // Re-process the empty canvas so the preview does not keep a stale result.
    sendDrawingToBackend();
}

function sendDrawingToBackend() {
    const dataURL = canvas.toDataURL('image/png');
    fetch('/process-image', {
        method: 'POST',
        headers: {
            'Content-Type': 'application/json',
        },
        body: JSON.stringify({ image: dataURL }),
    })
    .then(response => response.json().then(data => [response.ok, data]))
    .then(([ok, data]) => {
        if (ok) {
            console.log('Image processed', data);
        } else {
            console.error('Error processing image:', data.error);
        }
    })
    .catch(error => console.error('Error processing image:', error));
}
</script>
</body>
</html>
"""
    return render_template_string(html_template)
# HTML page for previewing the processed image
@app.route('/preview')
def preview_page():
    """Serve a full-screen viewer for the latest generated image.

    The page's script re-points its <img> at /get_output every 2 seconds. It
    appends a timestamp query string so the browser fetches a fresh copy
    instead of reusing its cache.
    """
    page_markup = """
<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Preview Page</title>
<style>
body, html {
margin: 0;
padding: 0;
height: 100%;
background-color: black;
}
.full-screen-image {
width: 100%;
height: 100%;
object-fit: contain;
}
</style>
<script>
function refreshImage() {
var img = document.getElementById("output-image");
img.src = "/get_output?" + new Date().getTime();
}
// Auto-refresh every 2 seconds to show the latest image
setInterval(refreshImage, 2000);
</script>
</head>
<body>
<img id="output-image" src="/get_output" class="full-screen-image">
</body>
</html>
"""
    return render_template_string(page_markup)
def signal_handler(sig, frame):
    """SIGINT handler: announce the shutdown and exit with status 0.

    ``sig`` and ``frame`` are required by the signal-handler signature and
    are not used.
    """
    del sig, frame  # unused
    print("Ctrl+C pressed, shutting down.")
    raise SystemExit(0)
# Install the Ctrl+C (SIGINT) handler at import time, before the server starts.
signal.signal(signal.SIGINT, signal_handler)

if __name__ == "__main__":
    # Flask's built-in development server, listening on all interfaces.
    bind_host, bind_port = '0.0.0.0', 2073
    app.run(host=bind_host, port=bind_port)