File size: 9,035 Bytes
5560bac
 
 
8675cd3
 
b3c65f5
8675cd3
 
 
5560bac
b3c65f5
0423f95
8675cd3
 
 
5560bac
b3c65f5
 
 
 
 
 
 
0423f95
b3c65f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5560bac
 
8675cd3
 
 
 
5560bac
b3c65f5
8675cd3
b3c65f5
8675cd3
b3c65f5
326670f
8675cd3
 
 
 
5560bac
8675cd3
b3c65f5
8675cd3
5560bac
8675cd3
 
 
0423f95
 
b3c65f5
 
 
 
 
 
 
 
 
5560bac
8675cd3
b3c65f5
 
 
5560bac
b3c65f5
 
 
 
 
 
 
 
 
 
5560bac
 
b3c65f5
0423f95
 
b3c65f5
 
 
 
 
 
 
 
8675cd3
 
 
 
 
 
b3c65f5
8675cd3
b3c65f5
8675cd3
 
 
b3c65f5
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import base64
import io
import json
import math
import os
import re
import shutil
import sys
import tempfile
from typing import Any, Dict, List

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont

# Add this file's own directory to the front of sys.path so modules placed next
# to the handler can be imported regardless of the process's working directory.
# NOTE(review): no import of a sibling module is visible in this file — confirm
# this path tweak is still needed.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)

def _mentions_any(text: str, keywords) -> bool:
    """Return True if any of *keywords* appears in *text* as a whole word.

    A trailing plural ``s``/``es`` is accepted, so ``"trees"`` matches
    ``"tree"``. Whole-word matching is required because plain substring tests
    misfire: ``"man"`` is a substring of ``"mandala"`` and ``"tree"`` of
    ``"street"``.

    Args:
        text: Lower-cased text to search.
        keywords: Iterable of lower-case keywords.

    Returns:
        True when at least one keyword matches; False for an empty keyword list.
    """
    words = tuple(keywords)
    if not words:
        return False
    alternatives = "|".join(re.escape(word) for word in words)
    return re.search(rf"\b(?:{alternatives})(?:e?s)?\b", text) is not None


def create_sketch_image(prompt: str, width: int = 256, height: int = 256) -> Image.Image:
    """Render a placeholder sketch-style image for *prompt*.

    The lower-cased prompt is checked against keyword groups in a fixed order
    (portrait, landscape, architecture, geometric pattern, technical diagram).
    The first group with a whole-word match picks the simple line drawing
    placed on a light 20 px grid. Prompts matching no group get an abstract
    zig-zag. When a font is available, the prompt (truncated to 40 characters
    plus "...") is written centred along the bottom edge.

    Args:
        prompt: Free-form text describing the desired sketch.
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        A new RGB image of size ``(width, height)``.
    """
    img = Image.new('RGB', (width, height), color='white')
    draw = ImageDraw.Draw(img)

    # Prefer DejaVu Sans. Fall back to Pillow's built-in font, then to no text.
    # truetype() raises OSError when the font file can't be opened and
    # ImportError when Pillow's FreeType support is unavailable.
    try:
        small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
    except (OSError, ImportError):
        try:
            small_font = ImageFont.load_default()
        except (OSError, ImportError):
            small_font = None

    prompt_lower = prompt.lower()

    # Light background grid every 20 px.
    grid_colour = (240, 240, 240)
    for x in range(0, width, 20):
        draw.line([(x, 0), (x, height)], fill=grid_colour, width=1)
    for y in range(0, height, 20):
        draw.line([(0, y), (width, y)], fill=grid_colour, width=1)

    if _mentions_any(prompt_lower, ('portrait', 'face', 'person', 'man', 'woman')):
        # Simple face: outline, two eyes, nose stroke, smiling mouth arc.
        center_x, center_y = width // 2, height // 2
        draw.ellipse([center_x-60, center_y-80, center_x+60, center_y+80], outline='black', width=3)
        draw.ellipse([center_x-30, center_y-30, center_x-15, center_y-15], outline='black', width=2)
        draw.ellipse([center_x+15, center_y-30, center_x+30, center_y-15], outline='black', width=2)
        draw.line([center_x, center_y-10, center_x-5, center_y+10], fill='black', width=2)
        draw.arc([center_x-20, center_y+10, center_x+20, center_y+40], 0, 180, fill='black', width=2)

    elif _mentions_any(prompt_lower, ('landscape', 'mountain', 'tree', 'nature')):
        # Mountain ridge as a polyline across the canvas.
        ridge = [(0, height*0.7), (width*0.3, height*0.4), (width*0.6, height*0.5), (width, height*0.6)]
        for start, end in zip(ridge, ridge[1:]):
            draw.line([start, end], fill='black', width=3)
        # Two trees: rectangular trunk with an elliptical crown.
        for x in (width*0.2, width*0.8):
            draw.rectangle([x-5, height*0.7, x+5, height*0.9], outline='black', width=2)
            draw.ellipse([x-20, height*0.5, x+20, height*0.7], outline='black', width=2)

    elif _mentions_any(prompt_lower, ('architectural', 'building', 'house')):
        # Building outline, a 2x2 grid of windows, and a door.
        draw.rectangle([width*0.2, height*0.3, width*0.8, height*0.8], outline='black', width=3)
        for x in (width*0.35, width*0.65):
            for y in (height*0.45, height*0.65):
                draw.rectangle([x-15, y-10, x+15, y+10], outline='black', width=2)
        draw.rectangle([width*0.45, height*0.65, width*0.55, height*0.8], outline='black', width=2)

    elif _mentions_any(prompt_lower, ('mandala', 'pattern', 'geometric')):
        # Concentric circles joined by 12 radial spokes (every 30 degrees).
        center_x, center_y = width // 2, height // 2
        for r in (30, 60, 90):
            draw.ellipse([center_x-r, center_y-r, center_x+r, center_y+r], outline='black', width=2)
        for angle in range(0, 360, 30):
            rad = math.radians(angle)
            cos_a, sin_a = math.cos(rad), math.sin(rad)
            draw.line(
                [center_x + 30 * cos_a, center_y + 30 * sin_a,
                 center_x + 90 * cos_a, center_y + 90 * sin_a],
                fill='black', width=2,
            )

    elif _mentions_any(prompt_lower, ('technical', 'mechanical', 'device')):
        # Main body with two circular components connected by a line.
        draw.rectangle([width*0.3, height*0.4, width*0.7, height*0.7], outline='black', width=3)
        # Circles are drawn via their bounding box. ImageDraw.circle() only
        # exists in recent Pillow releases, and this is its exact equivalent.
        ax, ay, ar = width*0.4, height*0.5, 15
        bx, by, br = width*0.6, height*0.6, 10
        draw.ellipse([ax-ar, ay-ar, ax+ar, ay+ar], outline='black', width=2)
        draw.ellipse([bx-br, by-br, bx+br, by+br], outline='black', width=2)
        draw.line([ax, ay, bx, by], fill='black', width=2)
        if small_font is not None:
            draw.text((width*0.3, height*0.3), "Component A", fill='black', font=small_font)
            draw.text((width*0.5, height*0.75), "Component B", fill='black', font=small_font)

    else:
        # Generic abstract sketch: a 5-point zig-zag with circles on alternate vertices.
        points = [
            (width * (0.2 + 0.6 * i / 4), height * (0.3 + 0.4 * (i % 2)))
            for i in range(5)
        ]
        for start, end in zip(points, points[1:]):
            draw.line([start, end], fill='black', width=3)
        for x, y in points[::2]:
            draw.ellipse([x-10, y-10, x+10, y+10], outline='black', width=2)

    # Caption: the (truncated) prompt centred near the bottom edge.
    if small_font is not None:
        display_prompt = prompt[:40] + "..." if len(prompt) > 40 else prompt
        try:
            left, _, right, _ = draw.textbbox((0, 0), display_prompt, font=small_font)
            text_width = right - left
        except ValueError:
            # Older Pillow releases support textbbox() only for TrueType fonts, so
            # the built-in bitmap fallback lands here. Use a rough per-character
            # width estimate rather than failing the whole render.
            text_width = 6 * len(display_prompt)
        # Clamp so text wider than the canvas starts at the left edge, not off-canvas.
        text_x = max(0, (width - text_width) // 2)
        draw.text((text_x, height - 25), display_prompt, fill='gray', font=small_font)

    return img

class EndpointHandler:
    """Endpoint handler that renders a placeholder sketch for a text prompt.

    Every call returns a base64-encoded PNG string. On success it is the
    generated sketch. If the prompt is empty or any error occurs, it is a
    256x256 white image carrying a short message.
    """

    # Largest accepted ``canvas_size``. The value comes straight from the
    # request payload, so this bounds how much memory one request can allocate.
    MAX_CANVAS_SIZE = 2048

    def __init__(self, path=""):
        """
        Initialize the handler for DiffSketchEdit model.

        Args:
            path: Presumably the model directory supplied by the serving
                framework; not used by this handler.
        """
        # Recorded and logged only; the placeholder renderer does not use it.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"DiffSketchEdit handler initialized on device: {self.device}")

    @staticmethod
    def _encode_png(img):
        """Return *img* serialized as PNG and base64-encoded to a str."""
        buffer = io.BytesIO()
        img.save(buffer, format='PNG')
        return base64.b64encode(buffer.getvalue()).decode()

    @classmethod
    def _message_image(cls, text, fill):
        """Return a base64 PNG of a 256x256 white canvas with *text* drawn at (10, 128)."""
        img = Image.new('RGB', (256, 256), color='white')
        ImageDraw.Draw(img).text((10, 128), text, fill=fill)
        return cls._encode_png(img)

    @classmethod
    def _parse_canvas_size(cls, value):
        """Coerce a requested canvas size to an int in ``[1, MAX_CANVAS_SIZE]``.

        Raises:
            TypeError, ValueError: if *value* is not numeric or is out of range.
        """
        size = int(value)
        if not 1 <= size <= cls.MAX_CANVAS_SIZE:
            raise ValueError(
                f"canvas_size must be between 1 and {cls.MAX_CANVAS_SIZE}, got {size}"
            )
        return size

    def __call__(self, data: Dict[str, Any]) -> str:
        """
        Render a sketch for the request and return it as a base64-encoded PNG.

        Args:
            data: Request dictionary containing:
                - inputs: Text prompt describing the sketch.
                - parameters: Optional dict. Only ``canvas_size`` (side length
                  in pixels, default 256) is read.

        Returns:
            Base64-encoded PNG image. Errors are reported as an image, never raised.
        """
        try:
            prompt = data.get("inputs", "")
            if not prompt:
                return self._message_image("No prompt provided", 'black')

            # "parameters" may be missing or an explicit JSON null.
            parameters = data.get("parameters") or {}
            canvas_size = self._parse_canvas_size(parameters.get("canvas_size", 256))

            print(f"Generating sketch for prompt: '{prompt}' with canvas size: {canvas_size}")
            img = create_sketch_image(prompt, canvas_size, canvas_size)
            img_str = self._encode_png(img)
            print(f"Successfully generated {canvas_size}x{canvas_size} sketch image")
            return img_str

        except Exception as e:
            # Endpoint boundary: report the failure to the caller as an image.
            print(f"Error in DiffSketchEdit handler: {e}")
            return self._message_image(f"Error: {str(e)[:30]}", 'red')


# Manual smoke test: render one prompt, round-trip the base64 payload through
# PIL to confirm it is a valid PNG, and save it next to the working directory.
if __name__ == "__main__":
    sample_request = {
        "inputs": "a detailed portrait of an elderly man",
        "parameters": {"canvas_size": 256},
    }
    encoded = EndpointHandler()(sample_request)
    print(f"Generated base64 image of length: {len(encoded)}")

    decoded = Image.open(io.BytesIO(base64.b64decode(encoded)))
    print(f"Decoded image size: {decoded.size}")
    decoded.save("test_diffsketchedit_output.png")
    print("Saved test image as test_diffsketchedit_output.png")