# img2img-turbo / i2i_sk.py
# Inmental's picture
# Upload folder using huggingface_hub
# 343e5a8 verified
import base64
import logging
from io import BytesIO

import torch
import torchvision.transforms.functional as F
from flask import Flask, request, jsonify
from flask_cors import CORS  # Import CORS
from PIL import Image
from torch.cuda.amp import autocast

from src.pix2pix_turbo import Pix2Pix_Turbo
# --- Settings ---------------------------------------------------------------
model_type = "sketch_to_image_stochastic"  # variant name handed to Pix2Pix_Turbo
output_format = "PNG"  # format used when serializing the generated image
# NOTE(review): desired_size is not referenced anywhere in the code visible
# here; confirm whether inputs were meant to be resized to it.
desired_size = (768, 768)

# --- Web application --------------------------------------------------------
app = Flask(__name__)
CORS(app)  # allow cross-origin requests on every route

# --- Model ------------------------------------------------------------------
# The model is instantiated once, at import time, and shared by all requests.
print("Loading model...")
model = Pix2Pix_Turbo(model_type)
print("Model loaded successfully.")
# Style presets as (name, prompt template) pairs. Each template carries a
# "{prompt}" placeholder marking where the user's text belongs.
# Replace these examples with your own styles as needed.
_STYLE_PRESETS = (
    ("Cinematic", "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy"),
    ("3D Model", "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting"),
    ("Anime", "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed"),
    ("Digital Art", "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed"),
    ("Photographic", "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed"),
    ("Pixel art", "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics"),
    ("Fantasy art", "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy"),
    ("Neonpunk", "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional"),
    ("Manga", "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style"),
)

# List-of-dicts view of the presets, in declaration order.
style_list = [
    {"name": preset_name, "prompt": preset_prompt}
    for preset_name, preset_prompt in _STYLE_PRESETS
]

# Lookup table: style name -> prompt template.
styles = {entry["name"]: entry["prompt"] for entry in style_list}
def process_image(image, prompt, prompt_template, style_name, seed, val_r):
    """Run the sketch-to-image model on a PIL image and return the result.

    Args:
        image: Input PIL image. It is converted to RGB and binarized with a
            0.5 threshold before being passed to the model.
        prompt: User-supplied prompt text.
        prompt_template: Style template containing a ``{prompt}``
            placeholder. When non-empty, the user prompt is substituted into
            it and the resulting text is what the model receives.
        style_name: Name of the selected style. Not used by this function;
            kept so the signature stays unchanged for callers.
        seed: Integer passed to ``torch.manual_seed`` before the noise map
            is sampled.
        val_r: Value forwarded to the model as its ``r`` argument.

    Returns:
        The generated image as a PIL image.
    """
    # Apply the style: previously the template was accepted but never used,
    # so the chosen style had no effect on the generated image.
    if prompt_template:
        prompt = prompt_template.replace("{prompt}", prompt)

    # Binarize the sketch: every channel value becomes 0.0 or 1.0.
    sketch = (F.to_tensor(image.convert("RGB")) > 0.5).float()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Mixed precision only makes sense on CUDA; keep it disabled on CPU.
    with torch.no_grad(), autocast(enabled=device.type == "cuda"):
        c_t = sketch.unsqueeze(0).to(device)

        torch.manual_seed(seed)
        _, _, height, width = c_t.shape
        # The noise map is sampled at 1/8 of the input's spatial size.
        noise = torch.randn((1, 4, height // 8, width // 8), device=device)

        logging.debug(
            "Calling Pix2Pix model... ct: %s, prompt: %s, deterministic: False, r: %s, noise_map: %s",
            c_t.shape, prompt, val_r, noise.shape,
        )
        output_image = model(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)

    # Cast to float32 (the output may be half precision under autocast), then
    # rescale with x * 0.5 + 0.5 before converting to a PIL image.
    return F.to_pil_image(output_image[0].float().cpu() * 0.5 + 0.5)
def _decode_image_payload(image_data):
    """Decode a base64 string (optionally a ``data:`` URI) into a loaded PIL image."""
    if image_data.startswith("data:"):
        # Drop the "data:<mime>;base64," header whatever the image type is.
        image_data = image_data.split(",", 1)[-1]
    image = Image.open(BytesIO(base64.b64decode(image_data)))
    # Image.open is lazy; force decoding now so corrupt data fails here.
    image.load()
    return image


@app.route('/process-image', methods=['POST'])
def process_image_route():
    """Handle POST /process-image: decode a sketch, run the model, return a base64 image.

    Expects a JSON object with:
        image: base64-encoded image, optionally prefixed with a data-URI header.
        prompt: prompt text (default ``'a cat'``).
        style_name: style to apply, matched case-insensitively against
            ``styles`` (default ``'Fantasy art'``).
        seed: value convertible to ``int`` (default 42).
        val_r: value convertible to ``float`` (default 0.8).

    Returns:
        ``{"image": <base64>}`` on success, or ``{"error": <message>}`` with
        status 400 for invalid input and 500 if processing fails.
    """
    # silent=True yields None for a malformed/non-JSON body instead of
    # raising, so the client always gets a JSON error from this handler.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'image' not in data:
        print("Error: No image provided")
        return jsonify({"error": "No image provided"}), 400

    # Log only the field names; the payload itself can be megabytes of base64.
    print("Received JSON keys:", list(data))

    image_data = data['image']
    if not isinstance(image_data, str):
        print("Error: image field is not a string")
        return jsonify({"error": "Invalid image data"}), 400
    print("Received base64 image data (truncated):", image_data[:100])

    try:
        image = _decode_image_payload(image_data)
    except Exception as e:  # request boundary: any decode failure is a client error
        print("Error decoding base64 image:", str(e))
        return jsonify({"error": "Invalid image data"}), 400

    # Retrieve and validate the remaining parameters.
    prompt = data.get('prompt', 'a cat')
    if not isinstance(prompt, str):
        print("Error: prompt is not a string")
        return jsonify({"error": "Invalid prompt"}), 400

    try:
        seed = int(data.get('seed', 42))
        val_r = float(data.get('val_r', 0.8))
    except (TypeError, ValueError) as e:
        print("Error parsing seed/val_r:", str(e))
        return jsonify({"error": "Invalid seed or val_r"}), 400

    raw_style = data.get('style_name', 'Fantasy art')
    print(f"Available styles: {list(styles.keys())}")
    print(f"Received style name: {raw_style}")

    # Case-insensitive lookup; a non-string style can never match.
    wanted = raw_style.strip().lower() if isinstance(raw_style, str) else None
    style_name = next((key for key in styles if key.lower() == wanted), None)
    if style_name is None:
        print(f"Error: Style '{data.get('style_name')}' not found")
        return jsonify({"error": f"Style '{data.get('style_name')}' not found"}), 400

    prompt_template = styles[style_name]
    print(f"Using style: {style_name} with prompt: {prompt}")

    # Run the model.
    try:
        processed_image = process_image(image, prompt, prompt_template, style_name, seed, val_r)
    except Exception as e:  # request boundary: report failure without crashing the worker
        print("Error processing image:", str(e))
        return jsonify({"error": "Failed to process image"}), 500

    # Serialize the result and return it base64-encoded.
    img_io = BytesIO()
    processed_image.save(img_io, format=output_format)
    img_base64 = base64.b64encode(img_io.getvalue()).decode('utf-8')
    print("Processed image successfully, sending back to client")
    return jsonify({"image": img_base64})
if __name__ == "__main__":
    # Executed as a script: start Flask's built-in server, bound to every
    # network interface on port 5000.
    server_options = {"host": '0.0.0.0', "port": 5000}
    app.run(**server_options)