File size: 5,733 Bytes
ad84c6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343e5a8
ad84c6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import base64
import logging
from io import BytesIO

import torch
import torchvision.transforms.functional as F
from flask import Flask, request, jsonify
from flask_cors import CORS  # Import CORS
from PIL import Image
from torch.cuda.amp import autocast

from src.pix2pix_turbo import Pix2Pix_Turbo

# Flask application with CORS enabled on every route (flask_cors default).
app = Flask(__name__)
CORS(app)

# --- Configuration ---------------------------------------------------------
# model_type:    variant name handed to Pix2Pix_Turbo.
# output_format: Pillow format used when encoding the returned image.
# desired_size:  target resolution (768x768).
model_type, output_format, desired_size = (
    "sketch_to_image_stochastic",
    "PNG",
    (768, 768),
)
# NOTE(review): desired_size is not referenced by the code visible in this
# file; incoming images appear to be processed at their own resolution.

# Load the model once at import time; the module-level `model` is what the
# request handlers use.
print("Loading model...")
model = Pix2Pix_Turbo(model_type)
print("Model loaded successfully.")

# Selectable prompt styles. Each template contains a "{prompt}" placeholder
# for the user's text; entries are dicts with "name" and "prompt" keys.
style_list = [
    {"name": label, "prompt": template}
    for label, template in (
        ("Cinematic", "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy"),
        ("3D Model", "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting"),
        ("Anime", "anime artwork {prompt} . anime style, key visual, vibrant, studio anime,  highly detailed"),
        ("Digital Art", "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed"),
        ("Photographic", "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed"),
        ("Pixel art", "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics"),
        ("Fantasy art", "ethereal fantasy concept art of  {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy"),
        ("Neonpunk", "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional"),
        ("Manga", "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style"),
    )
]

# Style name -> template, in the same order as style_list.
styles = {item["name"]: item["prompt"] for item in style_list}

def process_image(image, prompt, prompt_template, style_name, seed, val_r):
    """Run the sketch-to-image model on *image* and return the generated image.

    Args:
        image: Input sketch as a PIL image; converted to RGB and binarized.
        prompt: User text prompt.
        prompt_template: Style template containing a ``{prompt}`` placeholder
            (a value of ``styles``). If falsy, *prompt* is used unchanged.
        style_name: Name of the selected style; only used in the debug log.
        seed: Seed for the noise map, so a given seed reproduces its output.
        val_r: Forwarded to the model as ``r``.

    Returns:
        The generated image as a PIL image.
    """
    # Apply the style: previously the template was accepted but never used,
    # so every style produced identical prompts.
    if prompt_template:
        prompt = prompt_template.replace("{prompt}", prompt)

    image = image.convert("RGB")
    # Binarize the sketch: values brighter than 0.5 become 1.0, the rest 0.0.
    image_t = (F.to_tensor(image) > 0.5).float()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"

    # Mixed precision only applies on CUDA; enabling CUDA autocast without a
    # GPU just emits a warning and is disabled anyway.
    with torch.no_grad(), autocast(enabled=use_cuda):
        c_t = image_t.unsqueeze(0).to(device)
        # Seed the global RNGs as before (in case the model draws random numbers
        # itself), but take the noise map from a private generator so that
        # concurrent requests on a threaded server cannot disturb each other.
        torch.manual_seed(seed)
        generator = torch.Generator(device=device).manual_seed(seed)
        _, _, height, width = c_t.shape
        noise = torch.randn(
            (1, 4, height // 8, width // 8), device=device, generator=generator
        )
        logging.debug(
            "Calling Pix2Pix model... style: %s, ct: %s, prompt: %s, "
            "deterministic: False, r: %s, noise_map: %s",
            style_name, tuple(c_t.shape), prompt, val_r, tuple(noise.shape),
        )

        output_image = model(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)

    # Rescale the output (presumably in [-1, 1], given the * 0.5 + 0.5 mapping)
    # to [0, 1]. .float() undoes autocast's reduced precision and clamp() keeps
    # to_pil_image's byte conversion from wrapping on slight overshoot.
    output = (output_image[0].float().cpu() * 0.5 + 0.5).clamp(0, 1)
    return F.to_pil_image(output)

def _strip_data_url_prefix(image_data):
    """Return *image_data* without a leading ``data:<mime>;base64,`` header.

    Any data URL is accepted (PNG, JPEG, ...); strings without such a header
    are returned unchanged.
    """
    if image_data.startswith('data:'):
        _, sep, payload = image_data.partition(',')
        if sep:
            return payload
    return image_data

def _decode_image(image_data):
    """Decode a base64 (optionally data-URL wrapped) string into a loaded PIL image.

    Raises:
        Exception: whatever ``base64`` or Pillow raise for malformed input
            (e.g. ``binascii.Error``, ``PIL.UnidentifiedImageError``).
    """
    image_bytes = base64.b64decode(_strip_data_url_prefix(image_data))
    image = Image.open(BytesIO(image_bytes))
    # Image.open only parses the header; force a full decode here so corrupt
    # pixel data is reported as a 400 rather than failing later as a 500.
    image.load()
    return image

@app.route('/process-image', methods=['POST'])
def process_image_route():
    """Turn a base64-encoded sketch into a generated image.

    Expects a JSON object with:
        image (str, required): base64 image data, optionally as a data URL.
        prompt (str): text prompt, default ``'a cat'``.
        style_name (str): key of ``styles`` (case-insensitive, whitespace
            stripped), default ``'Fantasy art'``.
        seed (int): default 42.
        val_r (float): default 0.8.

    Returns:
        200 with ``{"image": <base64 output>}``; 400 with ``{"error": ...}`` for
        a missing/invalid image, invalid parameters or an unknown style; 500 if
        the model call fails.
    """
    # silent=True: a missing or malformed JSON body yields None, reported as a
    # JSON 400 below instead of Flask's default error response.
    data = request.get_json(silent=True)

    if not isinstance(data, dict) or 'image' not in data:
        print("Error: No image provided")
        return jsonify({"error": "No image provided"}), 400

    # Log only the field names; the payload itself contains the whole image.
    print("Received JSON fields:", list(data))

    image_data = data['image']
    if not isinstance(image_data, str):
        print("Error: image must be a base64 string")
        return jsonify({"error": "Invalid image data"}), 400
    print("Received base64 image data (truncated):", image_data[:100])  # Print first 100 chars of base64 data

    try:
        image = _decode_image(image_data)
    except Exception as e:  # untrusted input: any decode failure is a client error
        print("Error decoding base64 image:", str(e))
        return jsonify({"error": "Invalid image data"}), 400

    # Retrieve and validate the remaining parameters.
    prompt = data.get('prompt', 'a cat')
    if not isinstance(prompt, str):
        print("Error: prompt must be a string")
        return jsonify({"error": "prompt must be a string"}), 400
    try:
        seed = int(data.get('seed', 42))
        val_r = float(data.get('val_r', 0.8))
    except (TypeError, ValueError) as e:
        print("Error parsing seed/val_r:", str(e))
        return jsonify({"error": "seed must be an integer and val_r a number"}), 400

    requested_style = data.get('style_name', 'Fantasy art')
    style_name = requested_style.strip() if isinstance(requested_style, str) else ''

    print(f"Available styles: {list(styles.keys())}")
    print(f"Received style name: {style_name}")

    # Case-insensitive lookup
    wanted = style_name.lower()
    style_name = next((key for key in styles if key.lower() == wanted), None)
    if style_name is None:
        print(f"Error: Style '{requested_style}' not found")
        return jsonify({"error": f"Style '{requested_style}' not found"}), 400

    prompt_template = styles[style_name]

    print(f"Using style: {style_name} with prompt: {prompt}")

    try:
        processed_image = process_image(image, prompt, prompt_template, style_name, seed, val_r)
    except Exception:
        # Keep the traceback in the server log; the client gets a generic error.
        app.logger.exception("Error processing image")
        return jsonify({"error": "Failed to process image"}), 500

    # Encode the result as base64 in the configured output format.
    img_io = BytesIO()
    processed_image.save(img_io, format=output_format)
    img_base64 = base64.b64encode(img_io.getvalue()).decode('utf-8')

    print("Processed image successfully, sending back to client")
    return jsonify({"image": img_base64})

if __name__ == "__main__":
    # Start Flask's built-in server on all network interfaces, port 5000.
    server_options = {"host": "0.0.0.0", "port": 5000}
    app.run(**server_options)