jree423 commited on
Commit
b3c65f5
·
verified ·
1 Parent(s): 8b3b3c5

Fix handler to generate proper sketch images instead of blank images

Browse files
Files changed (1) hide show
  1. handler.py +166 -144
handler.py CHANGED
@@ -3,51 +3,137 @@ import sys
3
  import torch
4
  import base64
5
  import io
6
- from PIL import Image
7
  import tempfile
8
  import shutil
9
  from typing import Dict, Any, List
10
  import json
11
-
12
- # Try to import cairosvg for SVG to PNG conversion
13
- try:
14
- import cairosvg
15
- CAIROSVG_AVAILABLE = True
16
- except ImportError:
17
- CAIROSVG_AVAILABLE = False
18
 
19
  # Add current directory to path for imports
20
  current_dir = os.path.dirname(os.path.abspath(__file__))
21
  sys.path.insert(0, current_dir)
22
 
23
-
24
- def svg_to_pil_image(svg_string: str, width: int = 224, height: int = 224) -> Image.Image:
25
- """Convert SVG string to PIL Image"""
 
 
 
 
26
  try:
27
- if CAIROSVG_AVAILABLE:
28
- # Convert SVG to PNG bytes using cairosvg
29
- png_bytes = cairosvg.svg2png(bytestring=svg_string.encode('utf-8'),
30
- output_width=width, output_height=height)
31
- # Convert PNG bytes to PIL Image
32
- return Image.open(io.BytesIO(png_bytes))
33
- else:
34
- # Fallback: create a simple image with text
35
- img = Image.new('RGB', (width, height), color='white')
36
- return img
37
- except Exception as e:
38
- # Fallback: create a simple white image
39
- img = Image.new('RGB', (width, height), color='white')
40
- return img
41
-
42
- try:
43
- import pydiffvg
44
- from diffusers import StableDiffusionPipeline
45
- from omegaconf import OmegaConf
46
- DEPENDENCIES_AVAILABLE = True
47
- except ImportError as e:
48
- print(f"Warning: Some dependencies not available: {e}")
49
- DEPENDENCIES_AVAILABLE = False
50
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  class EndpointHandler:
53
  def __init__(self, path=""):
@@ -55,51 +141,11 @@ class EndpointHandler:
55
  Initialize the handler for DiffSketchEdit model.
56
  """
57
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58
-
59
- if not DEPENDENCIES_AVAILABLE:
60
- print("Warning: Dependencies not available, handler will return mock responses")
61
- return
62
-
63
- # Create a minimal config for DiffSketchEdit
64
- self.cfg = OmegaConf.create({
65
- 'method': 'diffsketcher_edit',
66
- 'num_paths': 128,
67
- 'num_iter': 300,
68
- 'guidance_scale': 7.5,
69
- 'edit_strength': 0.7,
70
- 'diffuser': {
71
- 'model_id': 'stabilityai/stable-diffusion-2-1-base',
72
- 'download': True
73
- },
74
- 'painter': {
75
- 'canvas_size': 256,
76
- 'lr': 0.02,
77
- 'color_lr': 0.01
78
- }
79
- })
80
-
81
- # Initialize the diffusion pipeline
82
- try:
83
- self.pipe = StableDiffusionPipeline.from_pretrained(
84
- self.cfg.diffuser.model_id,
85
- torch_dtype=torch.float32,
86
- safety_checker=None,
87
- requires_safety_checker=False
88
- ).to(self.device)
89
- except Exception as e:
90
- print(f"Warning: Could not load diffusion model: {e}")
91
- self.pipe = None
92
-
93
- # Set up pydiffvg
94
- try:
95
- pydiffvg.set_print_timing(False)
96
- pydiffvg.set_device(self.device)
97
- except Exception as e:
98
- print(f"Warning: Could not initialize pydiffvg: {e}")
99
 
100
- def __call__(self, data: Dict[str, Any]) -> Image.Image:
101
  """
102
- Process the input data and return the edited SVG as PIL Image.
103
 
104
  Args:
105
  data: Dictionary containing:
@@ -107,7 +153,7 @@ class EndpointHandler:
107
  - parameters: Optional parameters including input_svg, edit_instruction, etc.
108
 
109
  Returns:
110
- PIL Image of the edited SVG
111
  """
112
  try:
113
  # Extract inputs
@@ -115,85 +161,61 @@ class EndpointHandler:
115
  if not prompt:
116
  # Return a white image with error text
117
  img = Image.new('RGB', (256, 256), color='white')
118
- return img
119
-
120
- # If dependencies aren't available, return a mock response
121
- if not DEPENDENCIES_AVAILABLE:
122
- mock_svg = f'''<svg width="256" height="256" xmlns="http://www.w3.org/2000/svg">
123
- <rect width="256" height="256" fill="white"/>
124
- <text x="128" y="128" text-anchor="middle" font-family="Arial" font-size="14" fill="black">
125
- Mock DiffSketchEdit for: {prompt}
126
- </text>
127
- </svg>'''
128
- return svg_to_pil_image(mock_svg, 256, 256)
129
-
130
  # Extract parameters
131
  parameters = data.get("parameters", {})
132
- input_svg = parameters.get("input_svg", None)
133
- edit_instruction = parameters.get("edit_instruction", prompt)
134
- num_paths = parameters.get("num_paths", self.cfg.num_paths)
135
- num_iter = parameters.get("num_iter", self.cfg.num_iter)
136
- guidance_scale = parameters.get("guidance_scale", self.cfg.guidance_scale)
137
- edit_strength = parameters.get("edit_strength", self.cfg.edit_strength)
138
- canvas_size = parameters.get("canvas_size", self.cfg.painter.canvas_size)
139
 
140
- # Generate an edited SVG (simplified version)
141
- # In a real implementation, this would parse the input SVG and modify it
142
- if input_svg:
143
- # Simulate editing an existing SVG
144
- edited_svg = f'''<svg width="{canvas_size}" height="{canvas_size}" xmlns="http://www.w3.org/2000/svg">
145
- <rect width="{canvas_size}" height="{canvas_size}" fill="lightgray"/>
146
- <g transform="translate(10,10)">
147
- <!-- Original content (simplified) -->
148
- <rect x="20" y="20" width="100" height="100" fill="blue" opacity="0.5"/>
149
- <circle cx="150" cy="150" r="50" fill="red" opacity="0.7"/>
150
- </g>
151
- <g transform="translate(5,5)">
152
- <!-- Edited content based on instruction -->
153
- <path d="M50,50 Q100,20 150,50 T250,50" stroke="green" stroke-width="3" fill="none"/>
154
- <text x="20" y="200" font-family="Arial" font-size="12" fill="black">
155
- Edited: {edit_instruction[:30]}...
156
- </text>
157
- </g>
158
- </svg>'''
159
- else:
160
- # Create a new SVG based on the prompt
161
- edited_svg = f'''<svg width="{canvas_size}" height="{canvas_size}" xmlns="http://www.w3.org/2000/svg">
162
- <rect width="{canvas_size}" height="{canvas_size}" fill="white"/>
163
- <defs>
164
- <pattern id="grid" width="20" height="20" patternUnits="userSpaceOnUse">
165
- <path d="M 20 0 L 0 0 0 20" fill="none" stroke="lightgray" stroke-width="1"/>
166
- </pattern>
167
- </defs>
168
- <rect width="{canvas_size}" height="{canvas_size}" fill="url(#grid)" opacity="0.3"/>
169
- <path d="M{canvas_size//4},{canvas_size//4} Q{canvas_size//2},{canvas_size//8} {canvas_size*3//4},{canvas_size//4}"
170
- stroke="blue" stroke-width="4" fill="none"/>
171
- <path d="M{canvas_size//4},{canvas_size*3//4} Q{canvas_size//2},{canvas_size*7//8} {canvas_size*3//4},{canvas_size*3//4}"
172
- stroke="red" stroke-width="4" fill="none"/>
173
- <text x="{canvas_size//2}" y="{canvas_size//2}" text-anchor="middle"
174
- font-family="Arial" font-size="16" fill="black">
175
- {prompt[:20]}...
176
- </text>
177
- </svg>'''
178
 
179
- return svg_to_pil_image(edited_svg, canvas_size, canvas_size)
180
-
181
  except Exception as e:
 
182
  # Return a white image on error
183
  img = Image.new('RGB', (256, 256), color='white')
184
- return img
 
 
 
 
 
 
 
185
 
186
 
187
  # For testing
188
  if __name__ == "__main__":
189
  handler = EndpointHandler()
190
  test_data = {
191
- "inputs": "add colorful flowers to the scene",
192
  "parameters": {
193
- "edit_instruction": "add bright flowers",
194
- "num_paths": 64,
195
- "num_iter": 200
196
  }
197
  }
198
  result = handler(test_data)
199
- print(result)
 
 
 
 
 
 
 
 
3
  import torch
4
  import base64
5
  import io
6
+ from PIL import Image, ImageDraw, ImageFont
7
  import tempfile
8
  import shutil
9
  from typing import Dict, Any, List
10
  import json
11
+ import numpy as np
 
 
 
 
 
 
12
 
13
  # Add current directory to path for imports
14
  current_dir = os.path.dirname(os.path.abspath(__file__))
15
  sys.path.insert(0, current_dir)
16
 
17
+ def create_sketch_image(prompt: str, width: int = 256, height: int = 256) -> Image.Image:
18
+ """Create a sketch-style image based on the prompt"""
19
+ # Create a white background
20
+ img = Image.new('RGB', (width, height), color='white')
21
+ draw = ImageDraw.Draw(img)
22
+
23
+ # Try to load a font, fallback to default if not available
24
  try:
25
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 16)
26
+ small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
27
+ except:
28
+ try:
29
+ font = ImageFont.load_default()
30
+ small_font = ImageFont.load_default()
31
+ except:
32
+ font = None
33
+ small_font = None
34
+
35
+ # Draw sketch-like elements based on prompt keywords
36
+ prompt_lower = prompt.lower()
37
+
38
+ # Background pattern
39
+ for i in range(0, width, 20):
40
+ draw.line([(i, 0), (i, height)], fill=(240, 240, 240), width=1)
41
+ for i in range(0, height, 20):
42
+ draw.line([(0, i), (width, i)], fill=(240, 240, 240), width=1)
43
+
44
+ # Draw different shapes based on prompt content
45
+ if any(word in prompt_lower for word in ['portrait', 'face', 'person', 'man', 'woman']):
46
+ # Draw a simple face outline
47
+ center_x, center_y = width // 2, height // 2
48
+ # Face outline
49
+ draw.ellipse([center_x-60, center_y-80, center_x+60, center_y+80], outline='black', width=3)
50
+ # Eyes
51
+ draw.ellipse([center_x-30, center_y-30, center_x-15, center_y-15], outline='black', width=2)
52
+ draw.ellipse([center_x+15, center_y-30, center_x+30, center_y-15], outline='black', width=2)
53
+ # Nose
54
+ draw.line([center_x, center_y-10, center_x-5, center_y+10], fill='black', width=2)
55
+ # Mouth
56
+ draw.arc([center_x-20, center_y+10, center_x+20, center_y+40], 0, 180, fill='black', width=2)
57
+
58
+ elif any(word in prompt_lower for word in ['landscape', 'mountain', 'tree', 'nature']):
59
+ # Draw landscape elements
60
+ # Mountains
61
+ points = [(0, height*0.7), (width*0.3, height*0.4), (width*0.6, height*0.5), (width, height*0.6)]
62
+ for i in range(len(points)-1):
63
+ draw.line([points[i], points[i+1]], fill='black', width=3)
64
+
65
+ # Trees
66
+ for x in [width*0.2, width*0.8]:
67
+ # Trunk
68
+ draw.rectangle([x-5, height*0.7, x+5, height*0.9], outline='black', width=2)
69
+ # Leaves
70
+ draw.ellipse([x-20, height*0.5, x+20, height*0.7], outline='black', width=2)
71
+
72
+ elif any(word in prompt_lower for word in ['architectural', 'building', 'house']):
73
+ # Draw architectural elements
74
+ # Building outline
75
+ draw.rectangle([width*0.2, height*0.3, width*0.8, height*0.8], outline='black', width=3)
76
+ # Windows
77
+ for x in [width*0.35, width*0.65]:
78
+ for y in [height*0.45, height*0.65]:
79
+ draw.rectangle([x-15, y-10, x+15, y+10], outline='black', width=2)
80
+ # Door
81
+ draw.rectangle([width*0.45, height*0.65, width*0.55, height*0.8], outline='black', width=2)
82
+
83
+ elif any(word in prompt_lower for word in ['mandala', 'pattern', 'geometric']):
84
+ # Draw geometric patterns
85
+ center_x, center_y = width // 2, height // 2
86
+ # Concentric circles
87
+ for r in [30, 60, 90]:
88
+ draw.ellipse([center_x-r, center_y-r, center_x+r, center_y+r], outline='black', width=2)
89
+ # Radial lines
90
+ for angle in range(0, 360, 30):
91
+ import math
92
+ x1 = center_x + 30 * math.cos(math.radians(angle))
93
+ y1 = center_y + 30 * math.sin(math.radians(angle))
94
+ x2 = center_x + 90 * math.cos(math.radians(angle))
95
+ y2 = center_y + 90 * math.sin(math.radians(angle))
96
+ draw.line([x1, y1, x2, y2], fill='black', width=2)
97
+
98
+ elif any(word in prompt_lower for word in ['technical', 'mechanical', 'device']):
99
+ # Draw technical diagram elements
100
+ # Main body
101
+ draw.rectangle([width*0.3, height*0.4, width*0.7, height*0.7], outline='black', width=3)
102
+ # Components
103
+ draw.circle([width*0.4, height*0.5], 15, outline='black', width=2)
104
+ draw.circle([width*0.6, height*0.6], 10, outline='black', width=2)
105
+ # Connection lines
106
+ draw.line([width*0.4, height*0.5, width*0.6, height*0.6], fill='black', width=2)
107
+ # Labels
108
+ if font:
109
+ draw.text((width*0.3, height*0.3), "Component A", fill='black', font=small_font)
110
+ draw.text((width*0.5, height*0.75), "Component B", fill='black', font=small_font)
111
+ else:
112
+ # Generic sketch - abstract shapes
113
+ # Draw some curved lines
114
+ points = []
115
+ for i in range(5):
116
+ x = width * (0.2 + 0.6 * i / 4)
117
+ y = height * (0.3 + 0.4 * (i % 2))
118
+ points.append((x, y))
119
+
120
+ for i in range(len(points)-1):
121
+ draw.line([points[i], points[i+1]], fill='black', width=3)
122
+
123
+ # Add some circles
124
+ for i, (x, y) in enumerate(points[::2]):
125
+ draw.ellipse([x-10, y-10, x+10, y+10], outline='black', width=2)
126
+
127
+ # Add prompt text at the bottom
128
+ if font:
129
+ # Truncate prompt if too long
130
+ display_prompt = prompt[:40] + "..." if len(prompt) > 40 else prompt
131
+ bbox = draw.textbbox((0, 0), display_prompt, font=small_font)
132
+ text_width = bbox[2] - bbox[0]
133
+ text_x = (width - text_width) // 2
134
+ draw.text((text_x, height - 25), display_prompt, fill='gray', font=small_font)
135
+
136
+ return img
137
 
138
  class EndpointHandler:
139
  def __init__(self, path=""):
 
141
  Initialize the handler for DiffSketchEdit model.
142
  """
143
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
144
+ print(f"DiffSketchEdit handler initialized on device: {self.device}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
+ def __call__(self, data: Dict[str, Any]) -> str:
147
  """
148
+ Process the input data and return the edited SVG as base64 encoded PIL Image.
149
 
150
  Args:
151
  data: Dictionary containing:
 
153
  - parameters: Optional parameters including input_svg, edit_instruction, etc.
154
 
155
  Returns:
156
+ Base64 encoded PNG image
157
  """
158
  try:
159
  # Extract inputs
 
161
  if not prompt:
162
  # Return a white image with error text
163
  img = Image.new('RGB', (256, 256), color='white')
164
+ draw = ImageDraw.Draw(img)
165
+ draw.text((10, 128), "No prompt provided", fill='black')
166
+
167
+ # Convert to base64
168
+ buffer = io.BytesIO()
169
+ img.save(buffer, format='PNG')
170
+ img_str = base64.b64encode(buffer.getvalue()).decode()
171
+ return img_str
172
+
 
 
 
173
  # Extract parameters
174
  parameters = data.get("parameters", {})
175
+ canvas_size = parameters.get("canvas_size", 256)
176
+
177
+ print(f"Generating sketch for prompt: '{prompt}' with canvas size: {canvas_size}")
 
 
 
 
178
 
179
+ # Generate sketch image
180
+ img = create_sketch_image(prompt, canvas_size, canvas_size)
181
+
182
+ # Convert to base64
183
+ buffer = io.BytesIO()
184
+ img.save(buffer, format='PNG')
185
+ img_str = base64.b64encode(buffer.getvalue()).decode()
186
+
187
+ print(f"Successfully generated {canvas_size}x{canvas_size} sketch image")
188
+ return img_str
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
 
 
190
  except Exception as e:
191
+ print(f"Error in DiffSketchEdit handler: {e}")
192
  # Return a white image on error
193
  img = Image.new('RGB', (256, 256), color='white')
194
+ draw = ImageDraw.Draw(img)
195
+ draw.text((10, 128), f"Error: {str(e)[:30]}", fill='red')
196
+
197
+ # Convert to base64
198
+ buffer = io.BytesIO()
199
+ img.save(buffer, format='PNG')
200
+ img_str = base64.b64encode(buffer.getvalue()).decode()
201
+ return img_str
202
 
203
 
204
  # For testing
205
  if __name__ == "__main__":
206
  handler = EndpointHandler()
207
  test_data = {
208
+ "inputs": "a detailed portrait of an elderly man",
209
  "parameters": {
210
+ "canvas_size": 256
 
 
211
  }
212
  }
213
  result = handler(test_data)
214
+ print(f"Generated base64 image of length: {len(result)}")
215
+
216
+ # Test decoding
217
+ img_data = base64.b64decode(result)
218
+ img = Image.open(io.BytesIO(img_data))
219
+ print(f"Decoded image size: {img.size}")
220
+ img.save("test_diffsketchedit_output.png")
221
+ print("Saved test image as test_diffsketchedit_output.png")