Inmental committed on
Commit
f59de63
verified
1 Parent(s): 6e73b66

Upload 4 files

Browse files
Files changed (4) hide show
  1. draw.py +277 -0
  2. flask_sketch2imagehd.py +462 -0
  3. gradio_sketch2imagehd.py +222 -0
  4. preview_server.py +19 -0
draw.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import numpy as np
3
+ from PIL import Image, ImageOps
4
+ import base64
5
+ from io import BytesIO
6
+ import torch
7
+ import torchvision.transforms.functional as F
8
+ import gradio as gr
9
+ from transformers import BlipProcessor, BlipForConditionalGeneration
10
+ from flask import Flask, request, jsonify, render_template_string, send_file
11
+ from flask_cors import CORS
12
+ import threading
13
+ import hashlib
14
+ import signal
15
+ import sys
16
+ import os
17
+
18
+ # Load models
19
+ processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
20
+ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to("cuda")
21
+
22
+ # Pix2Pix model placeholder (Assume you have this model correctly implemented)
23
class Pix2Pix_Turbo:
    """Stand-in for the real Pix2Pix-Turbo model.

    The constructor ignores its argument and calling an instance hands the
    conditioning input straight back, so the surrounding pipeline can be
    exercised without any model weights.
    """

    def __init__(self, mode):
        # The placeholder keeps no state; `mode` is accepted and discarded.
        return None

    def __call__(self, c_t, prompt, deterministic, r, noise_map):
        # Demonstration only: nothing is generated, the input is the output.
        echoed = c_t
        return echoed
30
+
31
+ pix2pix_model = Pix2Pix_Turbo("sketch_to_image_stochastic")
32
+
33
+ # Flask application setup
34
+ app = Flask(__name__)
35
+ CORS(app) # Handle CORS issues
36
+
37
+ # Global Constants and Configuration
38
+ STYLE_LIST = [
39
+ {"name": "Cinematic", "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy"},
40
+ {"name": "3D Model", "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting"},
41
+ {"name": "Anime", "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed"},
42
+ {"name": "Digital Art", "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed"},
43
+ {"name": "Photographic", "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed"},
44
+ {"name": "Pixel art", "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics"},
45
+ {"name": "Fantasy art", "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy"},
46
+ {"name": "Neonpunk", "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional"},
47
+ {"name": "Manga", "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style"},
48
+ ]
49
+
50
+ STYLES = {style["name"]: style["prompt"] for style in STYLE_LIST}
51
+ STYLE_NAMES = list(STYLES.keys())
52
+ DEFAULT_STYLE_NAME = "Fantasy art"
53
+ MAX_SEED = np.iinfo(np.int32).max
54
+
55
+ # Paths for storing sketches and outputs
56
+ SKETCH_PATH = "sketch.png"
57
+ OUTPUT_PATH = "output.png"
58
+
59
+ # Image processing function
60
def run(image, prompt, prompt_template, style_name, seed, val_r):
    """Turn a sketch into a generated image and return it as base64 PNG text.

    Args:
        image: PIL image holding the sketch.
        prompt: User prompt. Blank or missing text is replaced by a stock
            prompt.
        prompt_template: Style template containing a ``{prompt}``
            placeholder. ``None`` (e.g. an unknown style) means the prompt
            is used as-is.
        style_name: Accepted for interface compatibility; not used here.
        seed: Passed to ``torch.manual_seed`` before the noise is sampled.
        val_r: Forwarded to the model as ``r``.

    Returns:
        The PNG-encoded output as base64 text (no ``data:`` prefix). The
        same image is also written to ``OUTPUT_PATH``.
    """
    # A JSON null prompt arrives as None; treat it like an empty string.
    if not (prompt or "").strip():
        prompt = "Generated by drawing tool"

    # An unknown style yields no template; fall back to the bare prompt
    # instead of failing on None.replace().
    if prompt_template is None:
        prompt_template = "{prompt}"
    prompt = prompt_template.replace("{prompt}", prompt)

    image = image.convert("RGB")
    # Binarise: channel values above 0.5 become True, the rest False.
    image_tensor = F.to_tensor(image) > 0.5

    with torch.no_grad():
        c_t = image_tensor.unsqueeze(0).to("cuda").float()
        torch.manual_seed(seed)
        _, _, height, width = c_t.shape
        # Noise map at 1/8 of the input resolution with 4 channels.
        noise = torch.randn((1, 4, height // 8, width // 8), device=c_t.device)
        output_image = pix2pix_model(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)

    # NOTE(review): the rescale assumes the model emits values in [-1, 1].
    output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
    output_pil.save(OUTPUT_PATH)  # Save the output image

    buffered = BytesIO()
    output_pil.save(buffered, format="PNG")
    output_data = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return output_data
81
+
82
+ # Flask route to handle image processing
83
@app.route('/process-image', methods=['POST'])
def process_image():
    """Decode the posted sketch, run the pipeline and return the result.

    Expects a JSON object with ``image`` (a ``data:`` URL or bare base64)
    and optional ``prompt``, ``style_name``, ``seed`` and ``val_r``.

    Returns:
        ``{"image": <base64 PNG>}`` on success, ``{"error": ...}`` with
        status 400 for malformed input, or status 500 when processing fails.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    raw_image = data.get("image")
    if not isinstance(raw_image, str) or not raw_image:
        return jsonify({"error": "Missing 'image' field."}), 400
    # Strip the "data:image/png;base64," header when present.
    payload = raw_image.partition(",")[2] if "," in raw_image else raw_image

    try:
        # binascii.Error is a ValueError; PIL decode failures are OSErrors.
        image = Image.open(BytesIO(base64.b64decode(payload))).convert("RGB")
        seed = int(data.get("seed", random.randint(0, MAX_SEED)))
        val_r = float(data.get("val_r", 0.4))
    except (TypeError, ValueError, OSError) as exc:
        return jsonify({"error": f"Invalid request: {exc}"}), 400

    prompt = data.get("prompt") or ""
    if not isinstance(prompt, str):
        prompt = str(prompt)

    # Unknown or malformed style names fall back to the default style so
    # run() always receives a real template.
    style_name = data.get("style_name", DEFAULT_STYLE_NAME)
    prompt_template = STYLES.get(style_name) if isinstance(style_name, str) else None
    if prompt_template is None:
        style_name = DEFAULT_STYLE_NAME
        prompt_template = STYLES[DEFAULT_STYLE_NAME]

    try:
        output_image_uri = run(image, prompt, prompt_template, style_name, seed, val_r)
    except Exception as e:  # request boundary: report instead of crashing
        return jsonify({"error": str(e)}), 500
    return jsonify({"image": output_image_uri})
103
+
104
+ # Flask route to serve the sketch image
105
@app.route('/get_sketch', methods=['GET'])
def get_sketch():
    """Serve the stored sketch PNG, or a JSON 404 when the file is absent."""
    sketch_missing = not os.path.exists(SKETCH_PATH)
    if sketch_missing:
        return jsonify({"status": "error", "message": "Sketch not found."}), 404
    return send_file(SKETCH_PATH, mimetype='image/png')
110
+
111
+ # Flask route to serve the output image
112
@app.route('/get_output', methods=['GET'])
def get_output():
    """Serve the latest output PNG, or a JSON 404 when none exists yet."""
    output_missing = not os.path.exists(OUTPUT_PATH)
    if output_missing:
        return jsonify({"status": "error", "message": "Output not found."}), 404
    return send_file(OUTPUT_PATH, mimetype='image/png')
117
+
118
+ # HTML page for drawing
119
@app.route('/')
def draw_page():
    """Serve the single-page sketching UI.

    The page holds an 800x600 canvas and a toolbar (brush, line, eraser,
    clear, colour, size). Every completed stroke is POSTed to
    ``/process-image`` as a PNG data URL.
    """
    # The canvas is painted white up front and the eraser paints white:
    # the server converts the upload to RGB, which turns transparent pixels
    # black and would make default black strokes indistinguishable from the
    # background.
    # NOTE(review): confirm the stroke/background polarity the real model
    # expects after the `> 0.5` threshold in run().
    html_template = """
    <!doctype html>
    <html lang="en">
    <head>
        <meta charset="utf-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>Drawing Page</title>
        <style>
            body, html {
                margin: 0;
                padding: 0;
                height: 100%;
                display: flex;
                justify-content: center;
                align-items: center;
                background-color: #f0f0f0;
            }
            .canvas-container {
                border: 2px solid black;
                position: relative;
            }
            .toolbar {
                display: flex;
                justify-content: center;
                margin-bottom: 10px;
            }
            button {
                margin-right: 5px;
            }
            canvas {
                cursor: crosshair;
            }
        </style>
    </head>
    <body>
        <div class="toolbar">
            <button id="brush" onclick="setTool('brush')">Brush</button>
            <button id="line" onclick="setTool('line')">Line</button>
            <button id="eraser" onclick="setTool('eraser')">Eraser</button>
            <button id="clear" onclick="clearCanvas()">Clear</button>
            <input type="color" id="colorPicker" value="#000000">
            <input type="range" id="brushSize" min="1" max="20" value="4">
        </div>
        <div class="canvas-container">
            <canvas id="drawingCanvas" width="800" height="600"></canvas>
        </div>
        <script>
            const BACKGROUND = '#ffffff';
            let canvas = document.getElementById('drawingCanvas');
            let ctx = canvas.getContext('2d');
            let drawing = false;
            let tool = 'brush';
            let lastX = 0, lastY = 0;
            let startX = 0, startY = 0;
            let snapshot = null;  // canvas state at mousedown, for the line tool

            function fillBackground() {
                ctx.globalCompositeOperation = 'source-over';
                ctx.fillStyle = BACKGROUND;
                ctx.fillRect(0, 0, canvas.width, canvas.height);
            }
            fillBackground();

            canvas.addEventListener('mousedown', (e) => {
                drawing = true;
                [lastX, lastY] = [e.offsetX, e.offsetY];
                [startX, startY] = [e.offsetX, e.offsetY];
                snapshot = ctx.getImageData(0, 0, canvas.width, canvas.height);
            });

            canvas.addEventListener('mousemove', draw);
            canvas.addEventListener('mouseup', () => {
                // Only post when a stroke was actually in progress.
                if (!drawing) return;
                drawing = false;
                sendDrawingToBackend();
            });
            canvas.addEventListener('mouseout', () => drawing = false);

            function draw(e) {
                if (!drawing) return;

                ctx.strokeStyle = (tool === 'eraser')
                    ? BACKGROUND
                    : document.getElementById('colorPicker').value;
                ctx.lineWidth = document.getElementById('brushSize').value;
                ctx.lineJoin = 'round';
                ctx.lineCap = 'round';

                ctx.beginPath();
                if (tool === 'line') {
                    // Restore the pre-stroke canvas so only one straight
                    // segment from the anchor point is visible.
                    ctx.putImageData(snapshot, 0, 0);
                    ctx.moveTo(startX, startY);
                } else {
                    ctx.moveTo(lastX, lastY);
                }
                ctx.lineTo(e.offsetX, e.offsetY);
                ctx.stroke();
                [lastX, lastY] = [e.offsetX, e.offsetY];
            }

            function setTool(selectedTool) {
                tool = selectedTool;
            }

            function clearCanvas() {
                fillBackground();
            }

            function sendDrawingToBackend() {
                let dataURL = canvas.toDataURL('image/png');
                fetch('/process-image', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json',
                    },
                    body: JSON.stringify({ image: dataURL }),
                })
                .then(response => response.json())
                .then(data => console.log('Image processed', data))
                .catch(error => console.error('Error processing image:', error));
            }
        </script>
    </body>
    </html>
    """
    return render_template_string(html_template)
228
+
229
+ # HTML page for previewing the processed image
230
@app.route('/preview')
def preview_page():
    """Serve a full-screen viewer that re-fetches ``/get_output`` every 2 s."""
    # The timestamp query string defeats browser caching of the image.
    page = """
    <!doctype html>
    <html lang="en">
    <head>
        <meta charset="utf-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>Preview Page</title>
        <style>
            body, html {
                margin: 0;
                padding: 0;
                height: 100%;
                background-color: black;
            }
            .full-screen-image {
                width: 100%;
                height: 100%;
                object-fit: contain;
            }
        </style>
        <script>
            function refreshImage() {
                var img = document.getElementById("output-image");
                img.src = "/get_output?" + new Date().getTime();
            }

            // Auto-refresh every 2 seconds to show the latest image
            setInterval(refreshImage, 2000);
        </script>
    </head>
    <body>
        <img id="output-image" src="/get_output" class="full-screen-image">
    </body>
    </html>
    """
    return render_template_string(page)
268
+
269
def signal_handler(sig, frame):
    """SIGINT handler: announce the shutdown, then exit with status 0."""
    print("Ctrl+C pressed, shutting down.")
    # Same effect as sys.exit(0), which raises SystemExit(0).
    raise SystemExit(0)
272
+
273
+ # Register the signal handler for Ctrl+C
274
+ signal.signal(signal.SIGINT, signal_handler)
275
+
276
+ if __name__ == "__main__":
277
+ app.run(host='0.0.0.0', port=2073)
flask_sketch2imagehd.py ADDED
@@ -0,0 +1,462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import numpy as np
3
+ from PIL import Image, ImageOps
4
+ import base64
5
+ from io import BytesIO
6
+ import torch
7
+ import torchvision.transforms.functional as F
8
+ from transformers import BlipProcessor, BlipForConditionalGeneration
9
+ from src.pix2pix_turbo import Pix2Pix_Turbo
10
+ import nltk
11
+ from nltk import pos_tag
12
+ from nltk.tokenize import word_tokenize
13
+ import re
14
+ import os
15
+ import threading
16
+ import hashlib
17
+ from flask import Flask, request, send_file, jsonify, render_template_string
18
+ from flask_cors import CORS
19
+ import signal
20
+ import sys
21
+ import logging
22
+ import json
23
+ import gc
24
+ from torch.cuda.amp import autocast
25
+
26
+ # Set environment variable for better memory management
27
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
28
+
29
+ # Function to clear CUDA cache and collect garbage
30
def clear_memory():
    """Collect garbage, then release cached CUDA memory.

    The collector runs first so that tensors freed by it have their blocks
    returned by the subsequent ``empty_cache`` call rather than staying
    cached until the next invocation.
    """
    gc.collect()
    torch.cuda.empty_cache()
33
+
34
+ # Load the configuration from config.json
35
+ with open('config.json', 'r') as config_file:
36
+ config = json.load(config_file)
37
+
38
+ # Setup logging as per config
39
+ logging.basicConfig(level=config["logging"]["level"], format=config["logging"]["format"])
40
+
41
+ # Ensure NLTK resources are downloaded
42
+ nltk.download('averaged_perceptron_tagger', quiet=True)
43
+ nltk.download('punkt', quiet=True)
44
+
45
+ # File paths for storing sketches and outputs
46
+ SKETCH_PATH = config["file_paths"]["sketch_path"]
47
+ OUTPUT_PATH = config["file_paths"]["output_path"]
48
+
49
+ # Processing queue
50
+ processing_queue = []
51
+
52
+ # Global Constants and Configuration
53
+ STYLE_LIST = config["style_list"]
54
+ STYLES = {style["name"]: style["prompt"] for style in STYLE_LIST}
55
+ DEFAULT_STYLE_NAME = config["default_style_name"]
56
+ RANDOM_VALUES = config["random_values"]
57
+ PIX2PIX_MODEL_NAME = config["model_params"]["pix2pix_model_name"]
58
+ DEVICE = config["model_params"]["device"]
59
+ DEFAULT_SEED = config["model_params"]["default_seed"]
60
+ VAL_R_DEFAULT = config["model_params"]["val_r_default"]
61
+ MAX_SEED = config["model_params"]["max_seed"]
62
+
63
+ # Canvas configuration
64
+ CANVAS_WIDTH = config["canvas"]["width"]
65
+ CANVAS_HEIGHT = config["canvas"]["height"]
66
+ BACKGROUND_COLOR = config["canvas"]["background_color"]
67
+ DEFAULT_BRUSH_COLOR = config["canvas"]["default_brush_color"]
68
+ DEFAULT_BRUSH_SIZE = config["canvas"]["default_brush_size"]
69
+ ERASER_COLOR = config["canvas"]["eraser_color"]
70
+ MAX_BRUSH_SIZE = config["canvas"]["max_brush_size"]
71
+ MIN_BRUSH_SIZE = config["canvas"]["min_brush_size"]
72
+
73
+ # Preload Models
74
+ logging.debug("Loading BLIP and Pix2Pix models...")
75
+ processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
76
+ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(DEVICE).eval() # Set model to eval mode
77
+ pix2pix_model = Pix2Pix_Turbo(PIX2PIX_MODEL_NAME).to(DEVICE).eval() # Set model to eval mode
78
+ logging.debug("Models loaded.")
79
+
80
# Backward-compatible aliases for the config-driven style table.
# The previous hard-coded one-entry list also re-assigned DEFAULT_STYLE_NAME
# and MAX_SEED here, silently overriding the values loaded from config.json
# above; those overrides are dropped so the configuration stays authoritative.
style_list = STYLE_LIST
styles = STYLES
STYLE_NAMES = list(STYLES)
92
+
93
+ # Shared flag and thread for managing the current processing
94
+ current_thread = None
95
+ cancel_flag = threading.Event()
96
+
97
def pil_image_to_data_uri(img: "Image.Image", format="PNG") -> str:
    """Encode a PIL image as a ``data:`` URI.

    Args:
        img: Image to encode; it is saved via ``img.save(buffer, format=...)``.
        format: Encoder name handed to ``save``; its lower-cased form is
            also used as the MIME subtype.

    Returns:
        ``data:image/<format>;base64,<payload>``.
    """
    # The annotation is a string so it names the Image.Image class (not the
    # PIL.Image module) without being evaluated at definition time.
    buffered = BytesIO()
    img.save(buffered, format=format)
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/{format.lower()};base64,{img_str}"
103
+
104
def generate_prompt_from_sketch(image: "Image.Image") -> str:
    """Generate a text prompt for a sketch using the BLIP captioning model.

    The sketch is fitted to the canvas size, captioned, reduced to its nouns
    (per comma-separated caption fragment) and suffixed with one random entry
    from ``RANDOM_VALUES`` when that list is non-empty.

    Args:
        image: PIL image of the sketch.

    Returns:
        A prompt of the form ``"a photo of a <nouns>[, <random suffix>]"``.
    """
    logging.debug("Generating prompt from sketch...")

    image = ImageOps.fit(image, (CANVAS_WIDTH, CANVAS_HEIGHT), Image.LANCZOS)
    inputs = processor(image, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        out = blip_model.generate(**inputs, max_new_tokens=50)

    text_prompt = processor.decode(out[0], skip_special_tokens=True)
    logging.debug(f"Generated prompt: {text_prompt}")

    # Drop fragments that contain no nouns; joining their empty strings used
    # to produce prompts such as "a photo of a  and Cat".
    extracted = (extract_main_words(item) for item in text_prompt.split(', '))
    recognized_items = [words for words in extracted if words]
    if not recognized_items:
        # No nouns at all: the raw caption is better than an empty subject.
        recognized_items = [text_prompt.strip()]

    prompt = f"a photo of a {' and '.join(recognized_items)}"
    # random.choice raises IndexError on an empty sequence.
    if RANDOM_VALUES:
        random_prefix = random.choice(RANDOM_VALUES)
        prompt = f"{prompt}, {random_prefix}"

    logging.debug(f"Final prompt: {prompt}")
    return prompt
123
+
124
def extract_main_words(item: str) -> str:
    """Return the nouns of *item*, capitalised and joined by single spaces."""
    noun_tags = ('NN', 'NNP', 'NNPS', 'NNS')
    kept = []
    for token, tag in pos_tag(word_tokenize(item.strip())):
        if tag in noun_tags:
            kept.append(token.capitalize())
    return ' '.join(kept)
130
+
131
def run(image, prompt, prompt_template, style_name, seed, val_r):
    """Run the sketch-to-image pipeline for one request.

    Args:
        image: PIL sketch, or ``None`` to just write a blank sketch file.
        prompt: User prompt; when blank or missing one is generated from the
            sketch.
        prompt_template: Style template with a ``{prompt}`` placeholder;
            ``None`` means the prompt is used unchanged.
        style_name: Accepted for interface compatibility; not used here.
        seed: Passed to ``torch.manual_seed`` before sampling the noise map.
        val_r: Forwarded to the model as ``r``.

    Returns:
        ``(output_uri, inverted_sketch_uri, output_uri, original_prompt)``.
        All four are ``""`` when ``image`` is ``None``; the first three are
        ``""`` when the request was cancelled before inference.

    Raises:
        RuntimeError: Any model failure other than CUDA out-of-memory.
    """
    max_queue_history = 50  # cap so the status queue cannot grow forever

    logging.debug("Running model inference...")
    if image is None:
        blank_image = Image.new("L", (CANVAS_WIDTH, CANVAS_HEIGHT), 255)
        blank_image.save(SKETCH_PATH)  # Save blank image as sketch
        logging.debug("No image provided. Saving blank image.")
        return "", "", "", ""

    # A JSON null prompt arrives as None; treat it like an empty string.
    if not (prompt or "").strip():
        prompt = generate_prompt_from_sketch(image)

    # Save the sketch to a file
    image.save(SKETCH_PATH)

    # Show the original prompt before processing
    original_prompt = f"Original Prompt: {prompt}"
    logging.debug(original_prompt)

    # Track the task; its status is finalised in the `finally` below instead
    # of staying "processing" forever.
    task = {"prompt": prompt, "status": "processing"}
    processing_queue.append(task)
    del processing_queue[:-max_queue_history]

    if prompt_template is None:
        prompt_template = "{prompt}"
    prompt = prompt_template.replace("{prompt}", prompt)
    logging.debug(f"Processing with prompt: {prompt}")
    image = image.convert("RGB")
    image_tensor = F.to_tensor(image) * 2 - 1  # Normalize to [-1, 1]

    # Built on the CPU first so the out-of-memory fallback always has its
    # inputs, even when the failure happens while moving data to DEVICE.
    c_t = image_tensor.unsqueeze(0).float()
    noise_shape = (1, 4, c_t.shape[2] // 8, c_t.shape[3] // 8)
    noise = None
    final_status = "failed"

    clear_memory()  # Clear memory before running the model

    try:
        try:
            with torch.no_grad():
                c_t_device = c_t.to(DEVICE)
                torch.manual_seed(seed)
                noise = torch.randn(noise_shape, device=c_t_device.device)
                logging.debug("Calling Pix2Pix model...")

                if cancel_flag.is_set():
                    logging.debug("Processing canceled.")
                    final_status = "canceled"
                    return "", "", "", original_prompt

                # Enable mixed precision
                with autocast():
                    output_image = pix2pix_model(c_t_device, prompt, deterministic=False, r=val_r, noise_map=noise)

            logging.debug("Model inference completed.")
        except RuntimeError as e:
            if "CUDA out of memory" not in str(e):
                raise
            logging.warning("CUDA out of memory error. Falling back to CPU.")
            try:
                with torch.no_grad():
                    if noise is None:
                        torch.manual_seed(seed)
                        noise = torch.randn(noise_shape)
                    pix2pix_model_cpu = pix2pix_model.cpu()  # Move the model to CPU
                    output_image = pix2pix_model_cpu(c_t, prompt, deterministic=False, r=val_r, noise_map=noise.cpu())
            finally:
                # NOTE(review): assuming .cpu() moves the shared model in
                # place (as torch.nn.Module does), restore it so the next
                # request does not hit a device mismatch.
                pix2pix_model.to(DEVICE)

        output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
        output_pil.save(OUTPUT_PATH)
        logging.debug("Output image saved.")

        input_sketch_uri = pil_image_to_data_uri(Image.fromarray(255 - np.array(image)))
        output_image_uri = pil_image_to_data_uri(output_pil)
        # Log the size only; the URI embeds the whole image as base64.
        logging.debug("Generated output URI (%d characters)", len(output_image_uri))

        final_status = "done"
        return output_image_uri, input_sketch_uri, output_image_uri, original_prompt
    finally:
        task["status"] = final_status
        clear_memory()  # Clear memory after running the model
200
+
201
def process_image_task(image, prompt, style_name, seed, val_r):
    """Worker-thread entry point: run the pipeline for one submitted sketch.

    This runs outside any Flask application or request context, so it must
    not call ``jsonify``; results are reported as plain dictionaries and via
    the log. The generated image itself is written to disk by ``run``.

    Returns:
        ``{"image": <data URI>}`` on success or ``{"error": <message>}`` on
        failure.
    """
    cancel_flag.clear()  # Clear any previous cancellation flag

    try:
        # Fall back to the default style's *prompt template*, not its name.
        prompt_template = STYLES.get(style_name)
        if prompt_template is None:
            prompt_template = STYLES.get(DEFAULT_STYLE_NAME, "{prompt}")

        output_image_uri, _, _, _ = run(image, prompt, prompt_template, style_name, seed, val_r)
    except Exception as e:  # thread boundary: log, never propagate
        logging.exception("Error processing image: %s", e)
        return {"error": str(e)}

    logging.debug("Processed image URI length: %d", len(output_image_uri))
    return {"image": output_image_uri}
214
+
215
+ # Flask Server Setup for Preview and JSON endpoint
216
+ app = Flask(__name__)
217
+ CORS(app) # Enable CORS
218
+
219
@app.route('/process-image', methods=['POST'])
def process_image():
    """Validate the posted sketch and start processing it in the background.

    Expects a JSON object with ``image`` (a ``data:`` URL or bare base64)
    and optional ``prompt``, ``style_name``, ``seed`` and ``val_r``. Any job
    still running is cancelled and waited for before the new one starts.

    Returns:
        ``{"status": "processing_started"}``, or ``{"error": ...}`` with
        status 400 when the request is malformed.
    """
    global current_thread

    # Validate first so a malformed request cannot cancel work in progress.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    raw_image = data.get("image")
    if not isinstance(raw_image, str) or not raw_image:
        return jsonify({"error": "Missing 'image' field."}), 400
    # Strip the "data:image/png;base64," header when present.
    payload = raw_image.partition(",")[2] if "," in raw_image else raw_image

    try:
        # binascii.Error is a ValueError; PIL decode failures are OSErrors.
        image = Image.open(BytesIO(base64.b64decode(payload))).convert("RGB")
        seed = int(data.get("seed", DEFAULT_SEED))
        val_r = float(data.get("val_r", VAL_R_DEFAULT))
    except (TypeError, ValueError, OSError) as exc:
        return jsonify({"error": f"Invalid request: {exc}"}), 400

    prompt = data.get("prompt") or ""
    if not isinstance(prompt, str):
        prompt = str(prompt)
    style_name = data.get("style_name", DEFAULT_STYLE_NAME)
    if not isinstance(style_name, str):
        style_name = DEFAULT_STYLE_NAME

    # Cancel any ongoing processing
    if current_thread is not None and current_thread.is_alive():
        logging.debug("Cancelling previous processing...")
        cancel_flag.set()
        current_thread.join()  # Wait for the thread to finish

    # Start new processing in a separate thread
    current_thread = threading.Thread(target=process_image_task, args=(image, prompt, style_name, seed, val_r))
    current_thread.start()

    return jsonify({"status": "processing_started"})
245
+
246
@app.route('/get_sketch', methods=['GET'])
def get_sketch():
    """Serve the stored sketch PNG, or a JSON 404 when the file is absent."""
    sketch_missing = not os.path.exists(SKETCH_PATH)
    if sketch_missing:
        return jsonify({"status": "error", "message": "Sketch not found."}), 404
    return send_file(SKETCH_PATH, mimetype='image/png')
251
+
252
@app.route('/get_output', methods=['GET'])
def get_output():
    """Serve the latest output PNG, or a JSON 404 when none exists yet."""
    output_missing = not os.path.exists(OUTPUT_PATH)
    if output_missing:
        return jsonify({"status": "error", "message": "Output not found."}), 404
    return send_file(OUTPUT_PATH, mimetype='image/png')
257
+
258
@app.route('/get_status', methods=['GET'])
def get_status():
    """Report the last output image (base64 + SHA-256) and the task queue.

    Both image fields are empty strings when no output file exists.
    """
    status = {
        "image_base64": "",
        "checksum": "",
        "processing_queue": processing_queue,
    }
    if os.path.exists(OUTPUT_PATH):
        with open(OUTPUT_PATH, "rb") as output_file:
            raw_bytes = output_file.read()
        status["image_base64"] = base64.b64encode(raw_bytes).decode('utf-8')
        status["checksum"] = hashlib.sha256(raw_bytes).hexdigest()
    return jsonify(status)
275
+
276
@app.route('/')
def index():
    """Serve a full-screen viewer that re-fetches ``/get_output`` every 2 s."""
    # The timestamp query string defeats browser caching of the image.
    page = """
    <!doctype html>
    <html lang="en">
    <head>
        <meta charset="utf-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>Preview Page</title>
        <style>
            body, html {
                margin: 0;
                padding: 0;
                height: 100%;
                background-color: black;
            }
            .full-screen-image {
                width: 100%;
                height: 100%;
                object-fit: contain;
            }
        </style>
        <script>
            function refreshImage() {
                var img = document.getElementById("output-image");
                img.src = "/get_output?" + new Date().getTime();
            }

            // Auto-refresh every 2 seconds to show the latest image
            setInterval(refreshImage, 2000);
        </script>
    </head>
    <body>
        <img id="output-image" src="/get_output" class="full-screen-image">
    </body>
    </html>
    """
    return render_template_string(page)
315
+
316
@app.route('/draw')
def draw_page():
    """Serve the sketching UI at ``/draw``.

    The page holds a 512x512 white canvas and a bottom toolbar (brush, line,
    eraser, clear, colour, size). Every completed stroke is POSTed to
    ``/process-image``; "Clear" also POSTs to ``/clear_preview``.
    """
    # The stroke colour is chosen inside draw() on every move, so the eraser
    # (white paint) is decided there; setting it once in setTool() was
    # overwritten by the colour picker on the next mouse move.
    html_template = """
    <!doctype html>
    <html lang="en">
    <head>
        <meta charset="utf-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>Drawing Page</title>
        <style>
            body, html {
                margin: 0;
                padding: 0;
                height: 100%;
                display: flex;
                justify-content: center;
                align-items: center;
                background-color: #f0f0f0;
            }
            .canvas-container {
                border: none;
                position: relative;
            }
            .toolbar {
                display: flex;
                justify-content: center;
                margin-bottom: 10px;
            }
            button {
                margin-right: 5px;
            }
            canvas {
                cursor: crosshair;
            }
        </style>
    </head>
    <body>
        <div style="position: fixed;
                    bottom: 0;
                    width: 100%;">
            <div class="toolbar">
                <button id="brush" onclick="setTool('brush')">Brush</button>
                <button id="line" onclick="setTool('line')">Line</button>
                <button id="eraser" onclick="setTool('eraser')">Eraser</button>
                <button id="clear" onclick="clearCanvas()">Clear</button>
                <input type="color" id="colorPicker" value="#000000">
                <input type="range" id="brushSize" min="1" max="20" value="4">
            </div>
        </div>
        <div class="canvas-container">
            <canvas id="drawingCanvas" width="512" height="512"></canvas>
        </div>
        <script>
            const BACKGROUND = "#ffffff";
            let canvas = document.getElementById('drawingCanvas');
            let ctx = canvas.getContext('2d');
            let drawing = false;
            let tool = 'brush';
            let lastX = 0, lastY = 0;
            let startX = 0, startY = 0;
            let snapshot = null;  // canvas state at mousedown, for the line tool

            // Fill the canvas with white background
            ctx.fillStyle = BACKGROUND;
            ctx.fillRect(0, 0, canvas.width, canvas.height);

            canvas.addEventListener('mousedown', (e) => {
                drawing = true;
                [lastX, lastY] = [e.offsetX, e.offsetY];
                [startX, startY] = [e.offsetX, e.offsetY];
                snapshot = ctx.getImageData(0, 0, canvas.width, canvas.height);
            });

            canvas.addEventListener('mousemove', draw);
            canvas.addEventListener('mouseup', () => {
                // Only post when a stroke was actually in progress.
                if (!drawing) return;
                drawing = false;
                sendDrawingToBackend();
            });
            canvas.addEventListener('mouseout', () => drawing = false);

            function draw(e) {
                if (!drawing) return;

                // Eraser paints with the background colour.
                ctx.strokeStyle = (tool === 'eraser')
                    ? BACKGROUND
                    : document.getElementById('colorPicker').value;
                ctx.lineWidth = document.getElementById('brushSize').value;
                ctx.lineJoin = 'round';
                ctx.lineCap = 'round';

                ctx.beginPath();
                if (tool === 'line') {
                    // Restore the pre-stroke canvas so only one straight
                    // segment from the anchor point is visible.
                    ctx.putImageData(snapshot, 0, 0);
                    ctx.moveTo(startX, startY);
                } else {
                    ctx.moveTo(lastX, lastY);
                }
                ctx.lineTo(e.offsetX, e.offsetY);
                ctx.stroke();
                [lastX, lastY] = [e.offsetX, e.offsetY];
            }

            function setTool(selectedTool) {
                tool = selectedTool;
                ctx.globalCompositeOperation = 'source-over';
            }

            function clearCanvas() {
                ctx.fillStyle = BACKGROUND;
                ctx.fillRect(0, 0, canvas.width, canvas.height);
                fetch('/clear_preview', { method: 'POST' })
                    .then(response => response.json())
                    .then(data => console.log('Cleared preview', data))
                    .catch(error => console.error('Error clearing preview:', error));
            }

            function sendDrawingToBackend() {
                let dataURL = canvas.toDataURL('image/png');
                fetch('/process-image', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json',
                    },
                    body: JSON.stringify({ image: dataURL }),
                })
                .then(response => response.json())
                .then(data => console.log('Image processed', data))
                .catch(error => console.error('Error processing image:', error));
            }
        </script>
    </body>
    </html>
    """
    return render_template_string(html_template)
444
+
445
@app.route('/clear_preview', methods=['POST'])
def clear_preview():
    """Delete the stored output image, if any, and acknowledge.

    Returns:
        ``{"status": "cleared"}`` whether or not a file existed.
    """
    # Remove-and-ignore instead of exists-then-remove: the worker thread or a
    # concurrent request may delete or replace the file between the two steps.
    try:
        os.remove(OUTPUT_PATH)
    except FileNotFoundError:
        pass  # nothing to clear
    return jsonify({"status": "cleared"})
450
+
451
def start_flask_app():
    """Start the Flask server (threaded) on the host/port from the config."""
    server_cfg = config["server"]
    bind_host = server_cfg["host"]
    bind_port = server_cfg["port"]
    app.run(host=bind_host, port=bind_port, threaded=True)
453
+
454
def signal_handler(sig, frame):
    """SIGINT handler: print a shutdown notice and exit with status 0."""
    farewell = "Ctrl+C pressed, shutting down."
    print(farewell)
    exit_status = 0
    sys.exit(exit_status)
457
+
458
+ # Register the signal handler for Ctrl+C
459
+ signal.signal(signal.SIGINT, signal_handler)
460
+
461
+ if __name__ == "__main__":
462
+ start_flask_app()
gradio_sketch2imagehd.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from fastapi import FastAPI, UploadFile, File
3
+ from fastapi.responses import FileResponse, JSONResponse
4
+ import os
5
+ import random
6
+ import torch
7
+ from PIL import Image, ImageOps
8
+ from io import BytesIO
9
+ import base64
10
+ import json
11
+ import logging
12
+ import gc
13
+ from transformers import BlipProcessor, BlipForConditionalGeneration
14
+ import torchvision.transforms.functional as F
15
+ from src.pix2pix_turbo import Pix2Pix_Turbo # Aseg煤rate de que esta ruta de importaci贸n sea correcta
16
+ from fastapi.middleware.cors import CORSMiddleware
17
+
18
+ # Configuración de logging
19
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
20
+
21
+ # Cargar la configuración desde config.json
22
# Load the application settings once at import time.
# The log text is restored from its mis-encoded form, and the file is opened
# as UTF-8 explicitly so accented characters in config.json do not depend on
# the platform's default encoding.
logging.info("Cargando configuración desde config.json...")
with open('config.json', 'r', encoding='utf-8') as config_file:
    config = json.load(config_file)
25
+
26
+ # Variables Globales
27
+ OUTPUT_PATH = "result.jpg" # La imagen resultante se guardar谩 como result.jpg
28
+ INPUT_PATH = "draw.jpg" # La imagen recibida se guardar谩 como draw.jpg
29
+ STYLE_LIST = config["style_list"]
30
+ STYLES = {style["name"]: style["prompt"] for style in STYLE_LIST}
31
+ DEVICE = config["model_params"]["device"]
32
+ DEFAULT_SEED = config["model_params"]["default_seed"]
33
+ VAL_R_DEFAULT = config["model_params"]["val_r_default"]
34
+ CANVAS_WIDTH = config["canvas"]["width"]
35
+ CANVAS_HEIGHT = config["canvas"]["height"]
36
+ PIX2PIX_MODEL_NAME = config["model_params"]["pix2pix_model_name"]
37
+
38
+ logging.info(f"Dispositivo seleccionado: {DEVICE}")
39
+ logging.info(f"Modelo Pix2Pix cargado: {PIX2PIX_MODEL_NAME}")
40
+
41
+ # Cargar y configurar los modelos
42
+ processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
43
+ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(DEVICE)
44
+ pix2pix_model = Pix2Pix_Turbo(PIX2PIX_MODEL_NAME)
45
+
46
def print_welcome_message(app):
    """Print one line per registered route with its URL and HTTP methods.

    Reads ``app.routes`` and ``app.server_port``; routes without a
    ``methods`` attribute are reported as "Not applicable".
    """
    for entry in app.routes:
        url = f"http://0.0.0.0:{app.server_port}{entry.path}"
        methods = getattr(entry, 'methods', "Not applicable")
        print(f"URL: {url}, Methods: {methods}")
54
+
55
def clear_memory():
    """Collect garbage, then release cached CUDA memory.

    The collector runs first so that tensors freed by it have their blocks
    returned by the subsequent ``empty_cache`` call rather than staying
    cached until the next invocation.
    """
    logging.debug("Limpiando la memoria CUDA y recolectando basura...")
    gc.collect()
    torch.cuda.empty_cache()
60
+
61
def generate_prompt_from_sketch(image: Image) -> str:
    """Caption the sketch with BLIP and build a prompt from the caption.

    The sketch is fitted to the canvas size and captioned; the caption is
    split on ", ", the non-blank pieces are joined with " and ", and one
    random entry of ``config["random_values"]`` is appended.
    """
    logging.debug("Generando el prompt desde el sketch...")
    fitted = ImageOps.fit(image, (CANVAS_WIDTH, CANVAS_HEIGHT), Image.LANCZOS)
    model_inputs = processor(fitted, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        generated = blip_model.generate(**model_inputs, max_new_tokens=50)
    text_prompt = processor.decode(generated[0], skip_special_tokens=True)
    logging.debug(f"Prompt generado: {text_prompt}")

    recognized_items = []
    for piece in text_prompt.split(', '):
        cleaned = piece.strip()
        if cleaned:
            recognized_items.append(cleaned)

    random_prefix = random.choice(config["random_values"])
    subject = ' and '.join(recognized_items)
    prompt = f"a photo of a {subject}, {random_prefix}"
    logging.debug(f"Prompt final: {prompt}")
    return prompt
77
+
78
def normalize_image(image, range_from=(-1, 1)):
    """Convert *image* to a tensor, rescaled to [-1, 1] on request.

    The ``to_tensor`` result is mapped with ``x * 2 - 1`` only when
    ``range_from`` equals ``(-1, 1)``; otherwise it is returned unchanged.
    """
    logging.debug("Normalizando la imagen...")
    tensor = F.to_tensor(image)
    wants_signed_range = range_from == (-1, 1)
    return tensor * 2 - 1 if wants_signed_range else tensor
85
+
86
def _save_model_output(output_image):
    """Convert the first image of a model output batch to PIL and save it.

    The tensor is mapped with ``x * 0.5 + 0.5`` (the inverse of the
    ``x * 2 - 1`` scaling applied to the input) and written to OUTPUT_PATH.
    """
    output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
    output_pil.save(OUTPUT_PATH)
    return output_pil


def process_sketch(sketch_image, prompt=None, style_name=None, seed=DEFAULT_SEED, val_r=VAL_R_DEFAULT):
    """Generate an image from a sketch with the Pix2Pix model and save it.

    Args:
        sketch_image: PIL image with the sketch.
        prompt: Text prompt. When empty or ``None`` one is generated from the
            sketch with ``generate_prompt_from_sketch``.
        style_name: Key into ``STYLES``; unknown or missing names fall back to
            the style named by ``config["default_style_name"]``.
        seed: Seed for the noise map. Converted with ``int()`` because UI
            sliders may deliver floats, which ``torch.manual_seed`` rejects.
        val_r: Value forwarded to the model as ``r``.

    Returns:
        The generated PIL image, which is also written to OUTPUT_PATH.

    Raises:
        ValueError: If ``sketch_image`` is ``None``.
        RuntimeError: Any inference error other than a CUDA out-of-memory
            error, which is handled by retrying on the CPU.
    """
    logging.debug("Iniciando el procesamiento del sketch...")

    if sketch_image is None:
        raise ValueError("sketch_image is required (got None)")

    if not prompt:
        logging.info("Prompt no proporcionado, generando uno a partir del sketch...")
        prompt = generate_prompt_from_sketch(sketch_image)

    prompt_template = STYLES.get(style_name, STYLES[config["default_style_name"]])
    prompt = prompt_template.replace("{prompt}", prompt)
    sketch_image = sketch_image.convert("RGB")
    sketch_tensor = normalize_image(sketch_image, range_from=(-1, 1))

    seed = int(seed)
    # Build the batch on the CPU first so the fallback path below always has
    # a usable tensor, even if the transfer to DEVICE itself fails.
    c_t = sketch_tensor.unsqueeze(0).float()
    _, _, height, width = c_t.shape
    noise_shape = (1, 4, height // 8, width // 8)
    noise = None

    try:
        with torch.no_grad():
            logging.info("Iniciando la inferencia del modelo Pix2Pix...")
            c_t_device = c_t.to(DEVICE)
            torch.manual_seed(seed)
            # Create the noise on the same device as the input tensor so the
            # two can never disagree.
            noise = torch.randn(noise_shape, device=c_t_device.device)
            # Only enable CUDA autocast when the tensors actually live on CUDA.
            on_cuda = c_t_device.device.type == "cuda"
            with torch.cuda.amp.autocast(enabled=on_cuda):
                output_image = pix2pix_model(c_t_device, prompt, deterministic=False, r=val_r, noise_map=noise)

        output_pil = _save_model_output(output_image)
        logging.info("Imagen generada y guardada correctamente.")
        return output_pil

    except RuntimeError as e:
        logging.error(f"Error de runtime durante la inferencia: {str(e)}")
        if "CUDA out of memory" not in str(e):
            raise

        logging.warning("Error de memoria CUDA. Cambiando a CPU.")
        try:
            with torch.no_grad():
                if noise is None:
                    # The failure happened before the noise was created.
                    torch.manual_seed(seed)
                    noise = torch.randn(noise_shape)
                pix2pix_model_cpu = pix2pix_model.cpu()
                output_image = pix2pix_model_cpu(c_t, prompt, deterministic=False, r=val_r, noise_map=noise.cpu())
        finally:
            # .cpu() above is applied to the shared module-level model, while
            # later calls still send their input to DEVICE. Move it back so
            # the next request does not hit a device mismatch.
            # Assumes pix2pix_model offers .to() alongside .cpu() - confirm.
            try:
                pix2pix_model.to(DEVICE)
            except RuntimeError:
                logging.exception("No se pudo devolver el modelo Pix2Pix al dispositivo configurado.")

        output_pil = _save_model_output(output_image)
        logging.info("Inferencia realizada en CPU y la imagen fue generada y guardada.")
        return output_pil
134
+
135
def get_image_as_base64(image_path):
    """Read a file and return its bytes base64-encoded as a UTF-8 string.

    Args:
        image_path: Path of the file to read in binary mode.

    Returns:
        The base64 text for the file's contents.
    """
    logging.debug(f"Convirtiendo la imagen {image_path} a base64...")
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.b64encode(raw_bytes).decode("utf-8")
141
+
142
# Create the FastAPI application instance.
app = FastAPI()

# Register the CORS middleware.
# NOTE(review): wildcard origins are combined with allow_credentials=True
# here; confirm this is intended, since credentialed cross-origin access is
# normally restricted to an explicit list of origins.
logging.info("Configurando el middleware de CORS...")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow every origin; specific origins can be listed instead of "*"
    allow_credentials=True,
    allow_methods=["*"],  # Allow every HTTP method (GET, POST, etc.)
    allow_headers=["*"],  # Allow every request header
)
154
+
155
@app.get("/")
def read_image():
    """Serve the last generated image, or an error payload if none exists.

    Returns:
        A ``FileResponse`` for OUTPUT_PATH (sent as JPEG with the download
        name "result.jpg") when that file exists; otherwise the dict
        ``{"error": "No image processed yet."}``.
    """
    logging.info("Petición GET recibida en '/'. Verificando si existe una imagen procesada...")
    if not os.path.exists(OUTPUT_PATH):
        logging.warning("No se ha procesado ninguna imagen aún.")
        return {"error": "No image processed yet."}

    logging.info(f"Retornando la imagen {OUTPUT_PATH}.")
    return FileResponse(OUTPUT_PATH, media_type='image/jpeg', filename="result.jpg")
167
+
168
@app.get("/image_base64")
def get_image_base64():
    """Return the last generated image as base64 text inside a JSON object.

    Returns:
        A ``JSONResponse`` with ``{"image_base64": <text>}`` when OUTPUT_PATH
        exists; otherwise a ``JSONResponse`` with
        ``{"error": "No image processed yet."}``.
    """
    if not os.path.exists(OUTPUT_PATH):
        logging.error("No se encontró ninguna imagen procesada.")
        return JSONResponse(content={"error": "No image processed yet."})

    # Encode the saved image file as base64 text.
    base64_str = get_image_as_base64(OUTPUT_PATH)
    logging.info("Imagen convertida a base64 y enviada como respuesta JSON.")
    return JSONResponse(content={"image_base64": base64_str})
181
+
182
+
183
@app.post("/process_image")
async def process_image(file: UploadFile = File(...)):
    """Process an uploaded sketch and store the generated image.

    The upload is saved as 'draw.png', passed to ``process_sketch`` with its
    default arguments, and the result is written to OUTPUT_PATH.

    Args:
        file: Uploaded image file.

    Returns:
        A dict with a ``status`` message naming OUTPUT_PATH.

    Raises:
        HTTPException: With status 400 when the upload cannot be decoded as
            an image.
    """
    from fastapi import HTTPException

    logging.info("Petición POST recibida en '/process_image'. Procesando imagen...")
    payload = await file.read()
    try:
        image = Image.open(BytesIO(payload))
        # Image.open is lazy; force decoding now so corrupt data is reported
        # here as a client error rather than later as an unhandled failure.
        image.load()
    except OSError as exc:
        logging.warning("El archivo recibido no es una imagen válida.")
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image.") from exc

    # Keep a copy of the received sketch as 'draw.png'.
    image.save("draw.png")
    logging.info("Imagen recibida guardada como 'draw.png'.")

    # process_sketch already writes OUTPUT_PATH; the save here is kept so this
    # endpoint does not depend on that side effect.
    processed_image = process_sketch(image)
    processed_image.save(OUTPUT_PATH)
    logging.info("Imagen procesada y guardada correctamente.")
    return {"status": f"Image processed and saved as {OUTPUT_PATH}"}
200
+
201
# Build the Gradio interface around process_sketch and mount it on the
# FastAPI application under /gradio.
logging.info("Montando la interfaz de Gradio en la aplicación FastAPI...")
interface = gr.Interface(
    fn=process_sketch,
    # The input order matches process_sketch's positional parameters:
    # (sketch_image, prompt, style_name, seed, val_r).
    # NOTE(review): the `source=` keyword on gr.Image is presumably tied to
    # the Gradio version in use - confirm before upgrading Gradio.
    inputs=[
        gr.Image(source="upload", type="pil", label="Sketch Image"),
        gr.Textbox(label="Prompt (optional)"),
        gr.Dropdown(choices=list(STYLES.keys()), label="Style"),
        gr.Slider(minimum=0, maximum=100, step=1, value=DEFAULT_SEED, label="Seed"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=VAL_R_DEFAULT, label="Sketch Guidance"),
    ],
    outputs=gr.Image(label="Generated Image"),
    title="Sketch to Image HD",
    description="Upload a sketch to generate an image.",
)

app = gr.mount_gradio_app(app, interface, path="/gradio")
216
+
217
+
218
if __name__ == "__main__":
    import types

    import uvicorn

    port = 7860
    logging.info("Iniciando la aplicación en Uvicorn...")
    # List the routes before starting the server: uvicorn.run() blocks until
    # shutdown, so anything placed after it is not shown while serving.
    # print_welcome_message reads `routes` and `server_port` from its
    # argument, so hand it a small adapter carrying the app's routes and the
    # port used below.
    print_welcome_message(types.SimpleNamespace(routes=app.routes, server_port=port))
    uvicorn.run(app, host="0.0.0.0", port=port)
preview_server.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
# Module-level holder for the most recently generated image as base64 text.
# Nothing in this module assigns it, so it stays None unless set externally.
last_image_base64 = None


def get_last_image():
    """Return the stored base64 image, or a placeholder message when unset."""
    return last_image_base64 or "No image processed yet."
8
+
9
# Build the Gradio interface that exposes the last image; it is launched on
# port 7861 by the __main__ guard below.
last_image_interface = gr.Interface(
    fn=get_last_image,
    inputs=[],  # No inputs: the function only reads module state.
    outputs="text",  # The base64 string (or placeholder message) is shown as text.
    title="Last Processed Image",
    description="Retrieve the last processed image in base64 format."
)
17
+
18
if __name__ == "__main__":
    # Serve the interface on all network interfaces, port 7861.
    # NOTE(review): share=True presumably requests a public Gradio share
    # link - confirm that exposing this endpoint publicly is intended.
    last_image_interface.launch(server_name="0.0.0.0", server_port=7861, share=True)