Spaces:
Runtime error
Runtime error
import random | |
import numpy as np | |
from PIL import Image, ImageOps | |
import base64 | |
from io import BytesIO | |
import torch | |
import torchvision.transforms.functional as F | |
import gradio as gr | |
from transformers import BlipProcessor, BlipForConditionalGeneration | |
from flask import Flask, request, jsonify, render_template_string, send_file | |
from flask_cors import CORS | |
import threading | |
import hashlib | |
import signal | |
import sys | |
import os | |
# Load the BLIP captioning processor and model. The model is moved to the GPU
# when CUDA is available and stays on the CPU otherwise: an unconditional
# .to("cuda") raises at import time on machines without a CUDA device.
# NOTE(review): neither `processor` nor `blip_model` is referenced elsewhere in
# this file - confirm they are still needed before paying their load cost.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base"
).to("cuda" if torch.cuda.is_available() else "cpu")
class Pix2Pix_Turbo:
    """Placeholder for the pix2pix-turbo sketch-to-image model.

    It performs no generation at all: calling an instance hands the
    conditioning tensor straight back, which lets the rest of the pipeline be
    exercised without the real network. Swap in a real implementation.
    """

    def __init__(self, mode):
        """Accept *mode* for signature compatibility; it is ignored."""

    def __call__(self, c_t, prompt, deterministic, r, noise_map):
        """Return *c_t* unchanged; all remaining arguments are ignored."""
        del prompt, deterministic, r, noise_map  # unused by the placeholder
        return c_t
# Model instance used by run() to turn sketches into images.
pix2pix_model = Pix2Pix_Turbo("sketch_to_image_stochastic")

# Flask application; flask_cors is applied to the whole app with its defaults.
app = Flask(__name__)
CORS(app)

# Style presets. Each "prompt" is a template whose "{prompt}" placeholder is
# replaced with the user's text in run().
STYLE_LIST = [
    {"name": name, "prompt": template}
    for name, template in (
        ("Cinematic", "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy"),
        ("3D Model", "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting"),
        ("Anime", "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed"),
        ("Digital Art", "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed"),
        ("Photographic", "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed"),
        ("Pixel art", "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics"),
        ("Fantasy art", "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy"),
        ("Neonpunk", "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional"),
        ("Manga", "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style"),
    )
]

# Name -> template lookup, and the style names in declaration order.
STYLES = dict((entry["name"], entry["prompt"]) for entry in STYLE_LIST)
STYLE_NAMES = [*STYLES]
DEFAULT_STYLE_NAME = "Fantasy art"

# Largest signed 32-bit integer; upper bound when a random seed is drawn.
MAX_SEED = np.iinfo(np.int32).max

# Files, relative to the working directory, for the latest sketch and output.
SKETCH_PATH = "sketch.png"
OUTPUT_PATH = "output.png"
# Image processing function | |
def run(image, prompt, prompt_template, style_name, seed, val_r):
    """Generate an image from a sketch and return it as a base64-encoded PNG.

    Args:
        image: PIL image of the sketch; converted to RGB here.
        prompt: User prompt. ``None`` or whitespace-only falls back to
            "Generated by drawing tool".
        prompt_template: Style template; every "{prompt}" is replaced by the prompt.
        style_name: Name of the selected style. Unused here (the template is
            already resolved); kept for the caller's signature.
        seed: Seed passed to ``torch.manual_seed`` before drawing the noise map.
        val_r: Forwarded to ``pix2pix_model`` as ``r``.

    Returns:
        The generated image as plain base64 PNG data (no "data:" prefix).
        As a side effect the image is also written to ``OUTPUT_PATH``.
    """
    if prompt is None or not prompt.strip():
        prompt = "Generated by drawing tool"
    prompt = prompt_template.replace("{prompt}", prompt)

    # Binary conditioning map: True wherever a channel value is above 0.5.
    image = image.convert("RGB")
    image_tensor = F.to_tensor(image) > 0.5

    # Prefer the GPU, but keep working on CPU-only hosts; hard-coding "cuda"
    # raises there.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    with torch.no_grad():
        c_t = image_tensor.unsqueeze(0).to(device).float()
        torch.manual_seed(seed)
        # 4-channel noise at 1/8 of the sketch's height and width.
        noise = torch.randn((1, 4, c_t.shape[2] // 8, c_t.shape[3] // 8), device=c_t.device)
        output_image = pix2pix_model(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)

    # `* 0.5 + 0.5` presumes the model emits values in [-1, 1]. Clamp after the
    # shift: to_pil_image multiplies float tensors by 255 and casts to uint8
    # without clipping, so out-of-range values would wrap into garbage pixels.
    pixels = (output_image[0].cpu() * 0.5 + 0.5).clamp(0.0, 1.0)
    output_pil = F.to_pil_image(pixels)
    output_pil.save(OUTPUT_PATH)  # latest result, served by get_output

    buffered = BytesIO()
    output_pil.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
# Flask route to handle image processing | |
def _decode_sketch(image_field):
    """Decode the request's ``image`` string into an RGB PIL image.

    Accepts a data URL ("data:image/png;base64,<payload>") or a bare base64
    payload. Transparent pixels are flattened onto white before converting to
    RGB. The drawing canvas is never filled, so its background (and anything
    erased or cleared) is exported transparent. A plain RGB conversion drops
    the alpha channel, which typically leaves those pixels black: the same as
    the default black stroke colour, which wipes out the sketch.
    NOTE(review): confirm the stroke/background polarity the model expects.

    Raises:
        ValueError: the field is missing/empty or is not valid base64.
        OSError: the decoded bytes are not a readable image.
    """
    if not isinstance(image_field, str) or not image_field.strip():
        raise ValueError("missing 'image' field")
    head, sep, payload = image_field.partition(",")
    if not sep:  # no "data:...;base64," prefix - the whole string is the payload
        payload = head
    with Image.open(BytesIO(base64.b64decode(payload))) as img:
        rgba = img.convert("RGBA")
    white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
    return Image.alpha_composite(white, rgba).convert("RGB")


def process_image():
    """Run the sketch-to-image pipeline for a JSON request and reply with JSON.

    Intended as the handler for the drawing page's ``POST /process-image``.
    JSON body fields:
        image (required): PNG data URL or bare base64 payload.
        prompt: text prompt; missing/null becomes "".
        style_name: key of ``STYLES``; missing or unknown names fall back to
            ``DEFAULT_STYLE_NAME``.
        seed: integer seed; a random one in [0, MAX_SEED] when missing/null.
        val_r: float forwarded to ``run`` (default 0.4).

    Returns:
        ``{"image": <base64 PNG>}`` on success, ``({"error": ...}, 400)`` for a
        malformed request, and ``({"error": ...}, 500)`` if generation fails.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    # Validate and normalize everything up front so bad input is reported as a
    # client error instead of surfacing as a 500 from deep inside run().
    try:
        image = _decode_sketch(data.get("image"))
        prompt = data.get("prompt")
        prompt = "" if prompt is None else str(prompt)
        style_name = data.get("style_name")
        if style_name not in STYLES:
            style_name = DEFAULT_STYLE_NAME
        seed = data.get("seed")
        seed = random.randint(0, MAX_SEED) if seed is None else int(seed)
        val_r = float(data.get("val_r", 0.4))
    except (ValueError, TypeError, OSError) as e:
        return jsonify({"error": f"Invalid request: {e}"}), 400

    try:
        # Keep the latest (flattened) input so get_sketch has a file to serve.
        image.save(SKETCH_PATH)
        output_data = run(image, prompt, STYLES[style_name], style_name, seed, val_r)
    except Exception as e:  # request boundary: log the traceback, report a 500
        app.logger.exception("Image generation failed")
        return jsonify({"error": str(e)}), 500
    return jsonify({"image": output_data})
# Flask route to serve the sketch image | |
def get_sketch():
    """Serve the latest sketch (``SKETCH_PATH``) as PNG, or a 404 JSON error.

    The path is made absolute before both the existence check and
    ``send_file``. ``os.path`` resolves a relative path against the working
    directory, while Flask's ``send_file`` joins it onto ``app.root_path``.
    When the server is started from another directory, the check can pass
    and the send then fail with a 500.
    """
    sketch_file = os.path.abspath(SKETCH_PATH)
    if not os.path.isfile(sketch_file):
        return jsonify({"status": "error", "message": "Sketch not found."}), 404
    return send_file(sketch_file, mimetype='image/png')
# Flask route to serve the output image | |
def get_output():
    """Serve the latest generated image (``OUTPUT_PATH``) as PNG, or a 404 JSON error.

    The path is made absolute before both the existence check and
    ``send_file``. ``os.path`` resolves a relative path against the working
    directory, while Flask's ``send_file`` joins it onto ``app.root_path``.
    When the server is started from another directory, the check can pass
    and the send then fail with a 500.
    """
    output_file = os.path.abspath(OUTPUT_PATH)
    if not os.path.isfile(output_file):
        return jsonify({"status": "error", "message": "Output not found."}), 404
    return send_file(output_file, mimetype='image/png')
# HTML page for drawing | |
def draw_page():
    """Return the HTML page with the sketching canvas.

    The page offers brush, straight-line and eraser tools, a colour picker, a
    brush-size slider and a Clear button. Whenever a stroke ends (mouse
    released, or the pointer leaving the canvas mid-stroke), the canvas is
    POSTed as a PNG data URL in ``{"image": ...}`` to ``/process-image``. The
    JSON reply is only logged to the browser console.
    """
    html_template = """
<!doctype html>
<html lang="en">
<head>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Drawing Page</title>
    <style>
        body, html {
            margin: 0;
            padding: 0;
            height: 100%;
            display: flex;
            flex-direction: column;  /* toolbar above the canvas, not beside it */
            justify-content: center;
            align-items: center;
            background-color: #f0f0f0;
        }
        .canvas-container {
            border: 2px solid black;
            position: relative;
        }
        .toolbar {
            display: flex;
            justify-content: center;
            margin-bottom: 10px;
        }
        button {
            margin-right: 5px;
        }
        canvas {
            cursor: crosshair;
        }
    </style>
</head>
<body>
    <div class="toolbar">
        <button id="brush" onclick="setTool('brush')">Brush</button>
        <button id="line" onclick="setTool('line')">Line</button>
        <button id="eraser" onclick="setTool('eraser')">Eraser</button>
        <button id="clear" onclick="clearCanvas()">Clear</button>
        <input type="color" id="colorPicker" value="#000000">
        <input type="range" id="brushSize" min="1" max="20" value="4">
    </div>
    <div class="canvas-container">
        <canvas id="drawingCanvas" width="800" height="600"></canvas>
    </div>
    <script>
        const canvas = document.getElementById('drawingCanvas');
        const ctx = canvas.getContext('2d');
        let drawing = false;
        let tool = 'brush';
        let lastX = 0, lastY = 0;    // previous point of a freehand stroke
        let startX = 0, startY = 0;  // anchor point of a straight line
        let snapshot = null;         // canvas pixels saved when a line drag starts

        canvas.addEventListener('mousedown', (e) => {
            drawing = true;
            [startX, startY] = [e.offsetX, e.offsetY];
            [lastX, lastY] = [e.offsetX, e.offsetY];
            if (tool === 'line') {
                snapshot = ctx.getImageData(0, 0, canvas.width, canvas.height);
            }
        });
        canvas.addEventListener('mousemove', draw);
        canvas.addEventListener('mouseup', endStroke);
        // Leaving the canvas ends the stroke too, and that stroke must be sent.
        canvas.addEventListener('mouseout', endStroke);

        function applyPen() {
            ctx.strokeStyle = document.getElementById('colorPicker').value;
            ctx.lineWidth = Number(document.getElementById('brushSize').value);
            ctx.lineJoin = 'round';
            ctx.lineCap = 'round';
        }

        function strokeSegment(x0, y0, x1, y1) {
            ctx.beginPath();
            ctx.moveTo(x0, y0);
            ctx.lineTo(x1, y1);
            ctx.stroke();
        }

        function draw(e) {
            if (!drawing) return;
            applyPen();
            if (tool === 'line') {
                // Restore the pre-drag pixels, then preview the straight line.
                ctx.putImageData(snapshot, 0, 0);
                strokeSegment(startX, startY, e.offsetX, e.offsetY);
                return;
            }
            strokeSegment(lastX, lastY, e.offsetX, e.offsetY);
            [lastX, lastY] = [e.offsetX, e.offsetY];
        }

        function endStroke() {
            if (!drawing) return;
            drawing = false;
            snapshot = null;
            sendDrawingToBackend();
        }

        function setTool(selectedTool) {
            tool = selectedTool;
            ctx.globalCompositeOperation = (tool === 'eraser') ? 'destination-out' : 'source-over';
        }

        function clearCanvas() {
            ctx.clearRect(0, 0, canvas.width, canvas.height);
        }

        function sendDrawingToBackend() {
            let dataURL = canvas.toDataURL('image/png');
            fetch('/process-image', {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json',
                },
                body: JSON.stringify({ image: dataURL }),
            })
            .then(response => response.json())
            .then(data => console.log('Image processed', data))
            .catch(error => console.error('Error processing image:', error));
        }
    </script>
</body>
</html>
"""
    return render_template_string(html_template)
# HTML page for previewing the processed image | |
def preview_page():
    """Return a full-screen HTML page that polls ``/get_output`` every 2 seconds.

    Each poll loads the image off-screen, using a timestamp query string to
    defeat caching. It is swapped into the visible <img> only after loading
    successfully, so a missing output (404) never replaces the last good
    frame with a broken-image icon. The <img> stays hidden until the first
    successful load.
    """
    html_template = """
<!doctype html>
<html lang="en">
<head>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Preview Page</title>
    <style>
        body, html {
            margin: 0;
            padding: 0;
            height: 100%;
            background-color: black;
        }
        .full-screen-image {
            display: block;  /* avoids the inline baseline gap that adds a scrollbar */
            width: 100%;
            height: 100%;
            object-fit: contain;
        }
    </style>
    <script>
        // Fetch the latest output off-screen and show it only once it loaded.
        function refreshImage() {
            var loader = new Image();
            loader.onload = function () {
                var img = document.getElementById("output-image");
                img.src = loader.src;
                img.style.visibility = "visible";
            };
            loader.src = "/get_output?" + new Date().getTime();
        }
        window.addEventListener("load", refreshImage);
        // Auto-refresh every 2 seconds to show the latest image
        setInterval(refreshImage, 2000);
    </script>
</head>
<body>
    <img id="output-image" class="full-screen-image" alt="Latest generated image" style="visibility: hidden">
</body>
</html>
"""
    return render_template_string(html_template)
def signal_handler(sig, frame):
    """SIGINT handler: print a shutdown notice and exit with status 0."""
    print("Ctrl+C pressed, shutting down.")
    sys.exit(0)


# Register the signal handler for Ctrl+C
signal.signal(signal.SIGINT, signal_handler)

# Wire the view functions to URLs. None of them carries an @app.route
# decorator, so without these rules every request would 404.
# "/process-image" and "/get_output" are the URLs the embedded pages fetch;
# the remaining paths are chosen here.
app.add_url_rule("/process-image", view_func=process_image, methods=["POST"])
app.add_url_rule("/get_sketch", view_func=get_sketch)
app.add_url_rule("/get_output", view_func=get_output)
app.add_url_rule("/", view_func=draw_page)
app.add_url_rule("/preview", view_func=preview_page)

if __name__ == "__main__":
    # The listening port can be overridden through the PORT environment
    # variable; 2073 remains the default.
    app.run(host='0.0.0.0', port=int(os.environ.get("PORT", "2073")))