|
import base64
import io
import json
import math
import os
import re
import shutil
import sys
import tempfile
from typing import Any, Dict, List

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
|
|
|
|
|
# Absolute directory holding this handler file. It is pushed onto the front of
# the import search path so that modules shipped alongside the handler are
# importable regardless of the process's working directory.
current_dir = os.path.split(os.path.abspath(__file__))[0]
sys.path[0:0] = [current_dir]
|
|
|
# Caption / label font. This is the usual DejaVu location on Debian/Ubuntu
# images; when it is missing, Pillow's built-in default font is used instead.
_SKETCH_FONT_PATH = "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf"

# Spacing (px) and colour of the light background grid.
_SKETCH_GRID_STEP = 20
_SKETCH_GRID_COLOR = (240, 240, 240)

# Longest prompt prefix shown in the caption before it is elided with "...".
_SKETCH_CAPTION_MAX_CHARS = 40


def _sketch_keyword_pattern(*words: str) -> "re.Pattern[str]":
    """Compile a regex matching any of *words* as a whole word, optionally plural.

    A keyword embedded inside a longer word does not match. For example, "man"
    does not match "mandala" and "face" does not match "surface". A trailing
    "s"/"es" is still accepted, so "trees" and "faces" match. Patterns are
    lowercase and are meant to be applied to lowercased text.
    """
    alternation = "|".join(re.escape(word) for word in words)
    return re.compile(rf"\b(?:{alternation})(?:e?s)?\b")


# Subject keywords, checked in this order by create_sketch_image.
_SKETCH_PORTRAIT_RE = _sketch_keyword_pattern('portrait', 'face', 'person', 'man', 'woman')
_SKETCH_LANDSCAPE_RE = _sketch_keyword_pattern('landscape', 'mountain', 'tree', 'nature')
_SKETCH_ARCHITECTURE_RE = _sketch_keyword_pattern('architectural', 'building', 'house')
_SKETCH_PATTERN_RE = _sketch_keyword_pattern('mandala', 'pattern', 'geometric')
_SKETCH_TECHNICAL_RE = _sketch_keyword_pattern('technical', 'mechanical', 'device')


def _load_sketch_label_font():
    """Return the 12pt DejaVu font, Pillow's default font, or None if neither loads."""
    try:
        return ImageFont.truetype(_SKETCH_FONT_PATH, 12)
    except (OSError, ImportError):
        # OSError: font file missing/unreadable; ImportError: Pillow built
        # without FreeType support.
        pass
    try:
        return ImageFont.load_default()
    except (OSError, ImportError):
        return None


def _draw_portrait_sketch(draw, width, height):
    """Draw a head outline, two eyes, a nose stroke and a mouth arc around the centre."""
    cx, cy = width // 2, height // 2
    draw.ellipse([cx - 60, cy - 80, cx + 60, cy + 80], outline='black', width=3)
    draw.ellipse([cx - 30, cy - 30, cx - 15, cy - 15], outline='black', width=2)
    draw.ellipse([cx + 15, cy - 30, cx + 30, cy - 15], outline='black', width=2)
    draw.line([cx, cy - 10, cx - 5, cy + 10], fill='black', width=2)
    # Pillow measures arc angles clockwise from 3 o'clock, so 0..180 is the lower half.
    draw.arc([cx - 20, cy + 10, cx + 20, cy + 40], 0, 180, fill='black', width=2)


def _draw_landscape_sketch(draw, width, height):
    """Draw a mountain-ridge polyline and two trees (trunk rectangle + canopy ellipse)."""
    ridge = [
        (0, height * 0.7),
        (width * 0.3, height * 0.4),
        (width * 0.6, height * 0.5),
        (width, height * 0.6),
    ]
    for start, end in zip(ridge, ridge[1:]):
        draw.line([start, end], fill='black', width=3)
    for x in (width * 0.2, width * 0.8):
        draw.rectangle([x - 5, height * 0.7, x + 5, height * 0.9], outline='black', width=2)
        draw.ellipse([x - 20, height * 0.5, x + 20, height * 0.7], outline='black', width=2)


def _draw_building_sketch(draw, width, height):
    """Draw a building facade with a 2x2 grid of windows and a door."""
    draw.rectangle([width * 0.2, height * 0.3, width * 0.8, height * 0.8], outline='black', width=3)
    for x in (width * 0.35, width * 0.65):
        for y in (height * 0.45, height * 0.65):
            draw.rectangle([x - 15, y - 10, x + 15, y + 10], outline='black', width=2)
    draw.rectangle([width * 0.45, height * 0.65, width * 0.55, height * 0.8], outline='black', width=2)


def _draw_mandala_sketch(draw, width, height):
    """Draw three concentric rings joined by twelve radial spokes every 30 degrees."""
    cx, cy = width // 2, height // 2
    for radius in (30, 60, 90):
        draw.ellipse([cx - radius, cy - radius, cx + radius, cy + radius], outline='black', width=2)
    for angle in range(0, 360, 30):
        dx = math.cos(math.radians(angle))
        dy = math.sin(math.radians(angle))
        draw.line([cx + 30 * dx, cy + 30 * dy, cx + 90 * dx, cy + 90 * dy], fill='black', width=2)


def _draw_device_sketch(draw, width, height, label_font):
    """Draw a housing with two linked round parts, labelled when a font is available."""
    draw.rectangle([width * 0.3, height * 0.4, width * 0.7, height * 0.7], outline='black', width=3)
    # Circles are drawn with ellipse() on a centre +/- radius box rather than
    # ImageDraw.circle(), which only exists in Pillow >= 10.4 and would make
    # every "technical" prompt fail on older installs.
    for (px, py), radius in (((width * 0.4, height * 0.5), 15), ((width * 0.6, height * 0.6), 10)):
        draw.ellipse([px - radius, py - radius, px + radius, py + radius], outline='black', width=2)
    draw.line([width * 0.4, height * 0.5, width * 0.6, height * 0.6], fill='black', width=2)
    if label_font is not None:
        draw.text((width * 0.3, height * 0.3), "Component A", fill='black', font=label_font)
        draw.text((width * 0.5, height * 0.75), "Component B", fill='black', font=label_font)


def _draw_abstract_sketch(draw, width, height):
    """Draw a five-point zig-zag with circles on the 1st, 3rd and 5th vertices."""
    points = [(width * (0.2 + 0.6 * i / 4), height * (0.3 + 0.4 * (i % 2))) for i in range(5)]
    for start, end in zip(points, points[1:]):
        draw.line([start, end], fill='black', width=3)
    for x, y in points[::2]:
        draw.ellipse([x - 10, y - 10, x + 10, y + 10], outline='black', width=2)


def create_sketch_image(prompt: str, width: int = 256, height: int = 256) -> Image.Image:
    """Create a sketch-style image based on the prompt.

    Draws a light grid on a white ``width`` x ``height`` RGB canvas. It then adds
    one simple line drawing chosen by keywords in ``prompt``, checked in this
    order: portrait, landscape, architecture, mandala/pattern, technical. If no
    keyword matches, an abstract zig-zag is drawn. Keywords are matched
    case-insensitively as whole words, optionally plural. Finally the prompt,
    elided after 40 characters, is written as a grey caption near the bottom
    edge when a font could be loaded.

    Several figures (face, mandala, component circles) use absolute pixel
    offsets, so they look proportioned on the default 256px canvas but do not
    scale with ``width``/``height``.

    Args:
        prompt: Free-text description used for keyword selection and the caption.
        width: Canvas width in pixels.
        height: Canvas height in pixels.

    Returns:
        The rendered RGB image.
    """
    img = Image.new('RGB', (width, height), color='white')
    draw = ImageDraw.Draw(img)
    label_font = _load_sketch_label_font()
    prompt_lower = prompt.lower()

    for x in range(0, width, _SKETCH_GRID_STEP):
        draw.line([(x, 0), (x, height)], fill=_SKETCH_GRID_COLOR, width=1)
    for y in range(0, height, _SKETCH_GRID_STEP):
        draw.line([(0, y), (width, y)], fill=_SKETCH_GRID_COLOR, width=1)

    if _SKETCH_PORTRAIT_RE.search(prompt_lower):
        _draw_portrait_sketch(draw, width, height)
    elif _SKETCH_LANDSCAPE_RE.search(prompt_lower):
        _draw_landscape_sketch(draw, width, height)
    elif _SKETCH_ARCHITECTURE_RE.search(prompt_lower):
        _draw_building_sketch(draw, width, height)
    elif _SKETCH_PATTERN_RE.search(prompt_lower):
        _draw_mandala_sketch(draw, width, height)
    elif _SKETCH_TECHNICAL_RE.search(prompt_lower):
        _draw_device_sketch(draw, width, height, label_font)
    else:
        _draw_abstract_sketch(draw, width, height)

    if label_font is not None:
        if len(prompt) > _SKETCH_CAPTION_MAX_CHARS:
            caption = prompt[:_SKETCH_CAPTION_MAX_CHARS] + "..."
        else:
            caption = prompt
        left, _, right, _ = draw.textbbox((0, 0), caption, font=label_font)
        # Horizontally centre the caption on the canvas.
        text_x = (width - (right - left)) // 2
        draw.text((text_x, height - 25), caption, fill='gray', font=label_font)

    return img
|
|
|
class EndpointHandler:
    """Inference-endpoint handler that turns a text prompt into a base64 PNG sketch.

    No model is loaded. Each request is rendered by ``create_sketch_image`` and
    returned as base64-encoded PNG data.
    """

    # Largest accepted square canvas edge, in pixels. ``canvas_size`` arrives in
    # the request, so it is bounded to stop one call from allocating an
    # arbitrarily large image.
    MAX_CANVAS_SIZE = 2048
    # Canvas edge used when the request does not specify ``canvas_size``.
    DEFAULT_CANVAS_SIZE = 256

    def __init__(self, path=""):
        """
        Initialize the handler for DiffSketchEdit model.

        Args:
            path: Model directory supplied by the serving runtime; not used.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"DiffSketchEdit handler initialized on device: {self.device}")

    def __call__(self, data: Dict[str, Any]) -> str:
        """
        Render a sketch for the request's prompt and return it as base64 PNG data.

        Args:
            data: Dictionary containing:
                - inputs: Text prompt describing the sketch.
                - parameters: Optional dict. Only ``canvas_size`` (square edge in
                  pixels, 1..MAX_CANVAS_SIZE, default 256) is read.

        Returns:
            Base64 encoded PNG image. An empty prompt yields a 256x256 image with
            "No prompt provided". Any failure yields a 256x256 image showing the
            start of the error message, so this method does not raise for bad input.
        """
        try:
            prompt = data.get("inputs", "")
            if not prompt:
                return self._encode_png(self._message_image("No prompt provided", 'black'))

            # `or {}` also covers an explicit "parameters": null in the request.
            parameters = data.get("parameters") or {}
            canvas_size = self._parse_canvas_size(parameters.get("canvas_size"))

            print(f"Generating sketch for prompt: '{prompt}' with canvas size: {canvas_size}")
            img = create_sketch_image(prompt, canvas_size, canvas_size)
            img_str = self._encode_png(img)
            print(f"Successfully generated {canvas_size}x{canvas_size} sketch image")
            return img_str

        except Exception as e:
            # Endpoint boundary: report the failure inside the returned image
            # instead of propagating it to the serving runtime.
            print(f"Error in DiffSketchEdit handler: {e}")
            return self._encode_png(self._message_image(f"Error: {str(e)[:30]}", 'red'))

    @classmethod
    def _parse_canvas_size(cls, value: Any) -> int:
        """Validate a requested canvas edge and return it as an int.

        ``None`` (absent or explicit null) maps to DEFAULT_CANVAS_SIZE. Numeric
        strings such as ``"512"`` are accepted.

        Raises:
            ValueError: if ``value`` is not integer-convertible or is outside
                ``1..MAX_CANVAS_SIZE``.
        """
        if value is None:
            return cls.DEFAULT_CANVAS_SIZE
        try:
            size = int(value)
        except (TypeError, ValueError) as err:
            raise ValueError(f"canvas_size must be an integer, got {value!r}") from err
        if not 1 <= size <= cls.MAX_CANVAS_SIZE:
            raise ValueError(
                f"canvas_size must be between 1 and {cls.MAX_CANVAS_SIZE}, got {size}"
            )
        return size

    @staticmethod
    def _message_image(text: str, fill: str) -> "Image.Image":
        """Return a 256x256 white image with ``text`` drawn at (10, 128) in ``fill``."""
        img = Image.new('RGB', (256, 256), color='white')
        ImageDraw.Draw(img).text((10, 128), text, fill=fill)
        return img

    @staticmethod
    def _encode_png(img: "Image.Image") -> str:
        """Serialize ``img`` as PNG and return the bytes base64-encoded as ASCII text."""
        buffer = io.BytesIO()
        img.save(buffer, format='PNG')
        return base64.b64encode(buffer.getvalue()).decode()
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: push one request through the handler, confirm the
    # base64 payload decodes back into an image, and save it for inspection.
    sample_request = {
        "inputs": "a detailed portrait of an elderly man",
        "parameters": {"canvas_size": 256},
    }
    encoded_png = EndpointHandler()(sample_request)
    print(f"Generated base64 image of length: {len(encoded_png)}")

    png_bytes = base64.b64decode(encoded_png)
    decoded = Image.open(io.BytesIO(png_bytes))
    print(f"Decoded image size: {decoded.size}")
    decoded.save("test_diffsketchedit_output.png")
    print("Saved test image as test_diffsketchedit_output.png")