import base64
import io
import json
import math
import os
import re
import shutil
import sys
import tempfile
from typing import Any, Dict, List

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont

# Add current directory to path for imports
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)

# Default edge length (pixels) used when the request does not specify one.
DEFAULT_CANVAS_SIZE = 256
# Upper bound on the requested canvas edge length. ``canvas_size`` comes from
# the request payload; without a cap one request could make PIL allocate an
# arbitrarily large RGB buffer.
MAX_CANVAS_SIZE = 4096

_FONT_PATH = "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf"
_GRID_COLOR = (240, 240, 240)
_GRID_SPACING = 20


def _keyword_pattern(*keywords: str) -> "re.Pattern[str]":
    """Compile a regex matching any of *keywords* as a whole word.

    A trailing "s"/"es" is accepted so simple plurals ("trees", "faces") still
    match. Whole-word matching avoids false hits such as "man" inside
    "mandala"/"germany", "face" inside "surface" or "tree" inside "street".
    """
    alternatives = "|".join(re.escape(keyword) for keyword in keywords)
    return re.compile(rf"\b(?:{alternatives})(?:s|es)?\b")


# Ordered (category, pattern) pairs; the first matching category wins.
_SKETCH_CATEGORIES = (
    ("portrait", _keyword_pattern("portrait", "face", "person", "man", "woman")),
    ("landscape", _keyword_pattern("landscape", "mountain", "tree", "nature")),
    ("architecture", _keyword_pattern("architectural", "building", "house")),
    ("geometric", _keyword_pattern("mandala", "pattern", "geometric")),
    ("technical", _keyword_pattern("technical", "mechanical", "device")),
)


def _classify_prompt(prompt_lower: str) -> str:
    """Return the first category whose keywords occur in the prompt, else "generic"."""
    for category, pattern in _SKETCH_CATEGORIES:
        if pattern.search(prompt_lower):
            return category
    return "generic"


def _load_small_font():
    """Return the font used for labels and captions, or None if none can be loaded.

    Prefers DejaVu Sans at 12pt and falls back to PIL's built-in default font.
    """
    try:
        return ImageFont.truetype(_FONT_PATH, 12)
    except (OSError, ImportError):
        # Font file missing/unreadable, or PIL built without FreeType support.
        pass
    try:
        return ImageFont.load_default()
    except Exception:
        # Best effort: without any font the sketch is simply drawn without text.
        return None


def _draw_grid(draw: ImageDraw.ImageDraw, width: int, height: int) -> None:
    """Draw the faint background grid."""
    for x in range(0, width, _GRID_SPACING):
        draw.line([(x, 0), (x, height)], fill=_GRID_COLOR, width=1)
    for y in range(0, height, _GRID_SPACING):
        draw.line([(0, y), (width, y)], fill=_GRID_COLOR, width=1)


def _draw_portrait(draw, width, height, font) -> None:
    """Draw a simple face: outline, two eyes, nose and mouth."""
    cx, cy = width // 2, height // 2
    draw.ellipse([cx - 60, cy - 80, cx + 60, cy + 80], outline='black', width=3)  # face
    draw.ellipse([cx - 30, cy - 30, cx - 15, cy - 15], outline='black', width=2)  # left eye
    draw.ellipse([cx + 15, cy - 30, cx + 30, cy - 15], outline='black', width=2)  # right eye
    draw.line([cx, cy - 10, cx - 5, cy + 10], fill='black', width=2)  # nose
    draw.arc([cx - 20, cy + 10, cx + 20, cy + 40], 0, 180, fill='black', width=2)  # mouth


def _draw_landscape(draw, width, height, font) -> None:
    """Draw a mountain ridge line and two trees."""
    ridge = [
        (0, height * 0.7),
        (width * 0.3, height * 0.4),
        (width * 0.6, height * 0.5),
        (width, height * 0.6),
    ]
    for start, end in zip(ridge, ridge[1:]):
        draw.line([start, end], fill='black', width=3)
    for x in (width * 0.2, width * 0.8):
        draw.rectangle([x - 5, height * 0.7, x + 5, height * 0.9], outline='black', width=2)  # trunk
        draw.ellipse([x - 20, height * 0.5, x + 20, height * 0.7], outline='black', width=2)  # leaves


def _draw_architecture(draw, width, height, font) -> None:
    """Draw a building outline with four windows and a door."""
    draw.rectangle([width * 0.2, height * 0.3, width * 0.8, height * 0.8], outline='black', width=3)
    for x in (width * 0.35, width * 0.65):
        for y in (height * 0.45, height * 0.65):
            draw.rectangle([x - 15, y - 10, x + 15, y + 10], outline='black', width=2)
    draw.rectangle([width * 0.45, height * 0.65, width * 0.55, height * 0.8], outline='black', width=2)


def _draw_geometric(draw, width, height, font) -> None:
    """Draw concentric circles joined by radial spokes every 30 degrees."""
    cx, cy = width // 2, height // 2
    for r in (30, 60, 90):
        draw.ellipse([cx - r, cy - r, cx + r, cy + r], outline='black', width=2)
    for angle in range(0, 360, 30):
        rad = math.radians(angle)
        cos_a, sin_a = math.cos(rad), math.sin(rad)
        draw.line([cx + 30 * cos_a, cy + 30 * sin_a, cx + 90 * cos_a, cy + 90 * sin_a],
                  fill='black', width=2)


def _draw_technical(draw, width, height, font) -> None:
    """Draw a box with two connected components and (if a font is available) labels."""
    draw.rectangle([width * 0.3, height * 0.4, width * 0.7, height * 0.7], outline='black', width=3)
    # Same bounding boxes ImageDraw.circle() would use, but ellipse() also works
    # on Pillow releases older than 10.4, where circle() does not exist.
    for cx, cy, r in ((width * 0.4, height * 0.5, 15), (width * 0.6, height * 0.6, 10)):
        draw.ellipse([cx - r, cy - r, cx + r, cy + r], outline='black', width=2)
    draw.line([width * 0.4, height * 0.5, width * 0.6, height * 0.6], fill='black', width=2)
    if font is not None:
        draw.text((width * 0.3, height * 0.3), "Component A", fill='black', font=font)
        draw.text((width * 0.5, height * 0.75), "Component B", fill='black', font=font)


def _draw_generic(draw, width, height, font) -> None:
    """Draw an abstract zig-zag of straight segments with circles on alternate vertices."""
    points = [(width * (0.2 + 0.6 * i / 4), height * (0.3 + 0.4 * (i % 2))) for i in range(5)]
    for start, end in zip(points, points[1:]):
        draw.line([start, end], fill='black', width=3)
    for x, y in points[::2]:
        draw.ellipse([x - 10, y - 10, x + 10, y + 10], outline='black', width=2)


# Category name -> drawing routine; every routine takes (draw, width, height, font).
_CATEGORY_DRAWERS = {
    "portrait": _draw_portrait,
    "landscape": _draw_landscape,
    "architecture": _draw_architecture,
    "geometric": _draw_geometric,
    "technical": _draw_technical,
    "generic": _draw_generic,
}


def _draw_caption(draw, prompt: str, width: int, height: int, font) -> None:
    """Draw the prompt (truncated to 40 chars) centred near the bottom edge."""
    display_prompt = prompt[:40] + "..." if len(prompt) > 40 else prompt
    left, _, right, _ = draw.textbbox((0, 0), display_prompt, font=font)
    text_x = (width - (right - left)) // 2
    draw.text((text_x, height - 25), display_prompt, fill='gray', font=font)


def create_sketch_image(prompt: str, width: int = 256, height: int = 256) -> Image.Image:
    """Create a sketch-style image based on the prompt.

    The prompt is matched (case-insensitively, whole words, simple plurals
    allowed) against keyword groups checked in order: portrait, landscape,
    architecture, geometric, technical. The first matching group picks the
    drawing; otherwise an abstract sketch is drawn. A light grid is always
    drawn, and the prompt is captioned at the bottom when a font is available.

    Args:
        prompt: Text description used to choose the drawing.
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        A new RGB image with a white background.
    """
    img = Image.new('RGB', (width, height), color='white')
    draw = ImageDraw.Draw(img)
    font = _load_small_font()

    _draw_grid(draw, width, height)
    category = _classify_prompt(prompt.lower())
    _CATEGORY_DRAWERS[category](draw, width, height, font)

    if font is not None:
        _draw_caption(draw, prompt, width, height, font)
    return img


def _encode_png_base64(img: Image.Image) -> str:
    """Serialize *img* as PNG and return it base64-encoded as ASCII text."""
    buffer = io.BytesIO()
    img.save(buffer, format='PNG')
    return base64.b64encode(buffer.getvalue()).decode()


def _message_image(message: str, color: str) -> Image.Image:
    """Return a 256x256 white image with *message* drawn in the default font."""
    img = Image.new('RGB', (256, 256), color='white')
    ImageDraw.Draw(img).text((10, 128), message, fill=color)
    return img


def _parse_canvas_size(value: Any) -> int:
    """Validate the requested canvas size.

    ``None`` selects the default. Numeric strings such as "128" are accepted.

    Raises:
        ValueError: if *value* is not an integer in 1..MAX_CANVAS_SIZE.
    """
    if value is None:
        return DEFAULT_CANVAS_SIZE
    try:
        size = int(value)
    except (TypeError, ValueError):
        raise ValueError("canvas_size must be an integer") from None
    if not 1 <= size <= MAX_CANVAS_SIZE:
        raise ValueError(f"canvas_size must be 1-{MAX_CANVAS_SIZE}")
    return size


class EndpointHandler:
    """Request handler that renders keyword-driven sketches.

    No model weights are loaded here; images are drawn procedurally by
    ``create_sketch_image``.
    """

    def __init__(self, path=""):
        """
        Initialize the handler for DiffSketchEdit model.

        ``path`` is accepted for interface compatibility but is not used.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"DiffSketchEdit handler initialized on device: {self.device}")

    def __call__(self, data: Dict[str, Any]) -> str:
        """Render a sketch for the request and return it as a base64-encoded PNG.

        Args:
            data: Dictionary containing:
                - inputs: Text prompt describing the sketch.
                - parameters: Optional dict. Only ``canvas_size`` (square edge
                  length, default 256, at most MAX_CANVAS_SIZE) is read. Other
                  keys such as input_svg or edit_instruction are ignored.

        Returns:
            Base64 encoded PNG image. If the prompt is missing, or any error
            occurs, a 256x256 image containing a short message is returned
            instead of raising.
        """
        try:
            prompt = data.get("inputs", "")
            if not prompt:
                return _encode_png_base64(_message_image("No prompt provided", 'black'))

            # ``or {}`` also covers an explicit JSON null for "parameters".
            parameters = data.get("parameters") or {}
            canvas_size = _parse_canvas_size(parameters.get("canvas_size"))

            print(f"Generating sketch for prompt: '{prompt}' with canvas size: {canvas_size}")
            img = create_sketch_image(prompt, canvas_size, canvas_size)
            img_str = _encode_png_base64(img)
            print(f"Successfully generated {canvas_size}x{canvas_size} sketch image")
            return img_str

        except Exception as e:
            # Endpoint boundary: report the failure as an image rather than raising.
            print(f"Error in DiffSketchEdit handler: {e}")
            return _encode_png_base64(_message_image(f"Error: {str(e)[:30]}", 'red'))


# For testing
if __name__ == "__main__":
    handler = EndpointHandler()
    test_data = {
        "inputs": "a detailed portrait of an elderly man",
        "parameters": {
            "canvas_size": 256
        }
    }
    result = handler(test_data)
    print(f"Generated base64 image of length: {len(result)}")

    # Test decoding
    img_data = base64.b64decode(result)
    img = Image.open(io.BytesIO(img_data))
    print(f"Decoded image size: {img.size}")
    img.save("test_diffsketchedit_output.png")
    print("Saved test image as test_diffsketchedit_output.png")