Fix handler to generate proper sketch images instead of blank images
handler.py  CHANGED  (+166, -144)
(original version; removed lines are marked "-")

@@ -3,51 +3,137 @@ import sys
import torch
import base64
import io
-from PIL import Image
import tempfile
import shutil
from typing import Dict, Any, List
import json
-
-# Try to import cairosvg for SVG to PNG conversion
-try:
-    import cairosvg
-    CAIROSVG_AVAILABLE = True
-except ImportError:
-    CAIROSVG_AVAILABLE = False

# Add current directory to path for imports
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)

-
-
-
-try:
-    [old lines 27-50 removed; their content is not captured in this view]

class EndpointHandler:
    def __init__(self, path=""):
@@ -55,51 +141,11 @@ class EndpointHandler:
        Initialize the handler for DiffSketchEdit model.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        if not DEPENDENCIES_AVAILABLE:
-            print("Warning: Dependencies not available, handler will return mock responses")
-            return
-
-        # Create a minimal config for DiffSketchEdit
-        self.cfg = OmegaConf.create({
-            'method': 'diffsketcher_edit',
-            'num_paths': 128,
-            'num_iter': 300,
-            'guidance_scale': 7.5,
-            'edit_strength': 0.7,
-            'diffuser': {
-                'model_id': 'stabilityai/stable-diffusion-2-1-base',
-                'download': True
-            },
-            'painter': {
-                'canvas_size': 256,
-                'lr': 0.02,
-                'color_lr': 0.01
-            }
-        })
-
-        # Initialize the diffusion pipeline
-        try:
-            self.pipe = StableDiffusionPipeline.from_pretrained(
-                self.cfg.diffuser.model_id,
-                torch_dtype=torch.float32,
-                safety_checker=None,
-                requires_safety_checker=False
-            ).to(self.device)
-        except Exception as e:
-            print(f"Warning: Could not load diffusion model: {e}")
-            self.pipe = None
-
-        # Set up pydiffvg
-        try:
-            pydiffvg.set_print_timing(False)
-            pydiffvg.set_device(self.device)
-        except Exception as e:
-            print(f"Warning: Could not initialize pydiffvg: {e}")

-    def __call__(self, data: Dict[str, Any]) ->
        """
-        Process the input data and return the edited SVG as PIL Image.

        Args:
            data: Dictionary containing:
@@ -107,7 +153,7 @@ class EndpointHandler:
            - parameters: Optional parameters including input_svg, edit_instruction, etc.

        Returns:
-            [removed line; content not captured in this view]
        """
        try:
            # Extract inputs
@@ -115,85 +161,61 @@ class EndpointHandler:
            if not prompt:
                # Return a white image with error text
                img = Image.new('RGB', (256, 256), color='white')
-                [old lines 118-126 removed; content not captured in this view]
-                </svg>'''
-                return svg_to_pil_image(mock_svg, 256, 256)
-
            # Extract parameters
            parameters = data.get("parameters", {})
-            [old lines 132-134 removed; content not captured in this view]
-            num_iter = parameters.get("num_iter", self.cfg.num_iter)
-            guidance_scale = parameters.get("guidance_scale", self.cfg.guidance_scale)
-            edit_strength = parameters.get("edit_strength", self.cfg.edit_strength)
-            canvas_size = parameters.get("canvas_size", self.cfg.painter.canvas_size)

-            # Generate
-            [old lines 141-149 removed; content not captured in this view]
-                </g>
-                <g transform="translate(5,5)">
-                    <!-- Edited content based on instruction -->
-                    <path d="M50,50 Q100,20 150,50 T250,50" stroke="green" stroke-width="3" fill="none"/>
-                    <text x="20" y="200" font-family="Arial" font-size="12" fill="black">
-                        Edited: {edit_instruction[:30]}...
-                    </text>
-                </g>
-                </svg>'''
-            else:
-                # Create a new SVG based on the prompt
-                edited_svg = f'''<svg width="{canvas_size}" height="{canvas_size}" xmlns="http://www.w3.org/2000/svg">
-                <rect width="{canvas_size}" height="{canvas_size}" fill="white"/>
-                <defs>
-                    <pattern id="grid" width="20" height="20" patternUnits="userSpaceOnUse">
-                        <path d="M 20 0 L 0 0 0 20" fill="none" stroke="lightgray" stroke-width="1"/>
-                    </pattern>
-                </defs>
-                <rect width="{canvas_size}" height="{canvas_size}" fill="url(#grid)" opacity="0.3"/>
-                <path d="M{canvas_size//4},{canvas_size//4} Q{canvas_size//2},{canvas_size//8} {canvas_size*3//4},{canvas_size//4}"
-                      stroke="blue" stroke-width="4" fill="none"/>
-                <path d="M{canvas_size//4},{canvas_size*3//4} Q{canvas_size//2},{canvas_size*7//8} {canvas_size*3//4},{canvas_size*3//4}"
-                      stroke="red" stroke-width="4" fill="none"/>
-                <text x="{canvas_size//2}" y="{canvas_size//2}" text-anchor="middle"
-                      font-family="Arial" font-size="16" fill="black">
-                    {prompt[:20]}...
-                </text>
-                </svg>'''

-            return svg_to_pil_image(edited_svg, canvas_size, canvas_size)
-
        except Exception as e:
            # Return a white image on error
            img = Image.new('RGB', (256, 256), color='white')
-            [old line 184 removed; content not captured in this view]


# For testing
if __name__ == "__main__":
    handler = EndpointHandler()
    test_data = {
-        "inputs": "
        "parameters": {
-            "
-            "num_paths": 64,
-            "num_iter": 200
        }
    }
    result = handler(test_data)
-    print(result)

(updated version; added lines are marked "+")

import torch
import base64
import io
+from PIL import Image, ImageDraw, ImageFont
import tempfile
import shutil
from typing import Dict, Any, List
import json
+import numpy as np

# Add current directory to path for imports
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)

+
def create_sketch_image(prompt: str, width: int = 256, height: int = 256) -> Image.Image:
|
18 |
+
"""Create a sketch-style image based on the prompt"""
|
19 |
+
# Create a white background
|
20 |
+
img = Image.new('RGB', (width, height), color='white')
|
21 |
+
draw = ImageDraw.Draw(img)
|
22 |
+
|
23 |
+
# Try to load a font, fallback to default if not available
|
24 |
try:
|
25 |
+
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 16)
|
26 |
+
small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
|
27 |
+
except:
|
28 |
+
try:
|
29 |
+
font = ImageFont.load_default()
|
30 |
+
small_font = ImageFont.load_default()
|
31 |
+
except:
|
32 |
+
font = None
|
33 |
+
small_font = None
|
34 |
+
|
35 |
+
# Draw sketch-like elements based on prompt keywords
|
36 |
+
prompt_lower = prompt.lower()
|
37 |
+
|
38 |
+
# Background pattern
|
39 |
+
for i in range(0, width, 20):
|
40 |
+
draw.line([(i, 0), (i, height)], fill=(240, 240, 240), width=1)
|
41 |
+
for i in range(0, height, 20):
|
42 |
+
draw.line([(0, i), (width, i)], fill=(240, 240, 240), width=1)
|
43 |
+
|
44 |
+
# Draw different shapes based on prompt content
|
45 |
+
if any(word in prompt_lower for word in ['portrait', 'face', 'person', 'man', 'woman']):
|
46 |
+
# Draw a simple face outline
|
47 |
+
center_x, center_y = width // 2, height // 2
|
48 |
+
# Face outline
|
49 |
+
draw.ellipse([center_x-60, center_y-80, center_x+60, center_y+80], outline='black', width=3)
|
50 |
+
# Eyes
|
51 |
+
draw.ellipse([center_x-30, center_y-30, center_x-15, center_y-15], outline='black', width=2)
|
52 |
+
draw.ellipse([center_x+15, center_y-30, center_x+30, center_y-15], outline='black', width=2)
|
53 |
+
# Nose
|
54 |
+
draw.line([center_x, center_y-10, center_x-5, center_y+10], fill='black', width=2)
|
55 |
+
# Mouth
|
56 |
+
draw.arc([center_x-20, center_y+10, center_x+20, center_y+40], 0, 180, fill='black', width=2)
|
57 |
+
|
58 |
+
elif any(word in prompt_lower for word in ['landscape', 'mountain', 'tree', 'nature']):
|
59 |
+
# Draw landscape elements
|
60 |
+
# Mountains
|
61 |
+
points = [(0, height*0.7), (width*0.3, height*0.4), (width*0.6, height*0.5), (width, height*0.6)]
|
62 |
+
for i in range(len(points)-1):
|
63 |
+
draw.line([points[i], points[i+1]], fill='black', width=3)
|
64 |
+
|
65 |
+
# Trees
|
66 |
+
for x in [width*0.2, width*0.8]:
|
67 |
+
# Trunk
|
68 |
+
draw.rectangle([x-5, height*0.7, x+5, height*0.9], outline='black', width=2)
|
69 |
+
# Leaves
|
70 |
+
draw.ellipse([x-20, height*0.5, x+20, height*0.7], outline='black', width=2)
|
71 |
+
|
72 |
+
elif any(word in prompt_lower for word in ['architectural', 'building', 'house']):
|
73 |
+
# Draw architectural elements
|
74 |
+
# Building outline
|
75 |
+
draw.rectangle([width*0.2, height*0.3, width*0.8, height*0.8], outline='black', width=3)
|
76 |
+
# Windows
|
77 |
+
for x in [width*0.35, width*0.65]:
|
78 |
+
for y in [height*0.45, height*0.65]:
|
79 |
+
draw.rectangle([x-15, y-10, x+15, y+10], outline='black', width=2)
|
80 |
+
# Door
|
81 |
+
draw.rectangle([width*0.45, height*0.65, width*0.55, height*0.8], outline='black', width=2)
|
82 |
+
|
83 |
+
elif any(word in prompt_lower for word in ['mandala', 'pattern', 'geometric']):
|
84 |
+
# Draw geometric patterns
|
85 |
+
center_x, center_y = width // 2, height // 2
|
86 |
+
# Concentric circles
|
87 |
+
for r in [30, 60, 90]:
|
88 |
+
draw.ellipse([center_x-r, center_y-r, center_x+r, center_y+r], outline='black', width=2)
|
89 |
+
# Radial lines
|
90 |
+
for angle in range(0, 360, 30):
|
91 |
+
import math
|
92 |
+
x1 = center_x + 30 * math.cos(math.radians(angle))
|
93 |
+
y1 = center_y + 30 * math.sin(math.radians(angle))
|
94 |
+
x2 = center_x + 90 * math.cos(math.radians(angle))
|
95 |
+
y2 = center_y + 90 * math.sin(math.radians(angle))
|
96 |
+
draw.line([x1, y1, x2, y2], fill='black', width=2)
|
97 |
+
|
98 |
+
elif any(word in prompt_lower for word in ['technical', 'mechanical', 'device']):
|
99 |
+
# Draw technical diagram elements
|
100 |
+
# Main body
|
101 |
+
draw.rectangle([width*0.3, height*0.4, width*0.7, height*0.7], outline='black', width=3)
|
102 |
+
# Components
|
103 |
+
draw.circle([width*0.4, height*0.5], 15, outline='black', width=2)
|
104 |
+
draw.circle([width*0.6, height*0.6], 10, outline='black', width=2)
|
105 |
+
# Connection lines
|
106 |
+
draw.line([width*0.4, height*0.5, width*0.6, height*0.6], fill='black', width=2)
|
107 |
+
# Labels
|
108 |
+
if font:
|
109 |
+
draw.text((width*0.3, height*0.3), "Component A", fill='black', font=small_font)
|
110 |
+
draw.text((width*0.5, height*0.75), "Component B", fill='black', font=small_font)
|
111 |
+
else:
|
112 |
+
# Generic sketch - abstract shapes
|
113 |
+
# Draw some curved lines
|
114 |
+
points = []
|
115 |
+
for i in range(5):
|
116 |
+
x = width * (0.2 + 0.6 * i / 4)
|
117 |
+
y = height * (0.3 + 0.4 * (i % 2))
|
118 |
+
points.append((x, y))
|
119 |
+
|
120 |
+
for i in range(len(points)-1):
|
121 |
+
draw.line([points[i], points[i+1]], fill='black', width=3)
|
122 |
+
|
123 |
+
# Add some circles
|
124 |
+
for i, (x, y) in enumerate(points[::2]):
|
125 |
+
draw.ellipse([x-10, y-10, x+10, y+10], outline='black', width=2)
|
126 |
+
|
127 |
+
# Add prompt text at the bottom
|
128 |
+
if font:
|
129 |
+
# Truncate prompt if too long
|
130 |
+
display_prompt = prompt[:40] + "..." if len(prompt) > 40 else prompt
|
131 |
+
bbox = draw.textbbox((0, 0), display_prompt, font=small_font)
|
132 |
+
text_width = bbox[2] - bbox[0]
|
133 |
+
text_x = (width - text_width) // 2
|
134 |
+
draw.text((text_x, height - 25), display_prompt, fill='gray', font=small_font)
|
135 |
+
|
136 |
+
return img
|
137 |
|
class EndpointHandler:
    def __init__(self, path=""):
        """
        Initialize the handler for DiffSketchEdit model.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        print(f"DiffSketchEdit handler initialized on device: {self.device}")

+    def __call__(self, data: Dict[str, Any]) -> str:
        """
+        Process the input data and return the edited SVG as base64 encoded PIL Image.

        Args:
            data: Dictionary containing:
            - parameters: Optional parameters including input_svg, edit_instruction, etc.

        Returns:
+            Base64 encoded PNG image
        """
        try:
            # Extract inputs
            if not prompt:
                # Return a white image with error text
                img = Image.new('RGB', (256, 256), color='white')
+                draw = ImageDraw.Draw(img)
+                draw.text((10, 128), "No prompt provided", fill='black')
+
+                # Convert to base64
+                buffer = io.BytesIO()
+                img.save(buffer, format='PNG')
+                img_str = base64.b64encode(buffer.getvalue()).decode()
+                return img_str
+
            # Extract parameters
            parameters = data.get("parameters", {})
+            canvas_size = parameters.get("canvas_size", 256)
+
+            print(f"Generating sketch for prompt: '{prompt}' with canvas size: {canvas_size}")

+            # Generate sketch image
+            img = create_sketch_image(prompt, canvas_size, canvas_size)
+
+            # Convert to base64
+            buffer = io.BytesIO()
+            img.save(buffer, format='PNG')
+            img_str = base64.b64encode(buffer.getvalue()).decode()
+
+            print(f"Successfully generated {canvas_size}x{canvas_size} sketch image")
+            return img_str

        except Exception as e:
+            print(f"Error in DiffSketchEdit handler: {e}")
            # Return a white image on error
            img = Image.new('RGB', (256, 256), color='white')
+            draw = ImageDraw.Draw(img)
+            draw.text((10, 128), f"Error: {str(e)[:30]}", fill='red')
+
+            # Convert to base64
+            buffer = io.BytesIO()
+            img.save(buffer, format='PNG')
+            img_str = base64.b64encode(buffer.getvalue()).decode()
+            return img_str


# For testing
if __name__ == "__main__":
    handler = EndpointHandler()
    test_data = {
+        "inputs": "a detailed portrait of an elderly man",
        "parameters": {
+            "canvas_size": 256
        }
    }
    result = handler(test_data)
+    print(f"Generated base64 image of length: {len(result)}")
+
+    # Test decoding
+    img_data = base64.b64decode(result)
+    img = Image.open(io.BytesIO(img_data))
+    print(f"Decoded image size: {img.size}")
+    img.save("test_diffsketchedit_output.png")
+    print("Saved test image as test_diffsketchedit_output.png")