# Spaces: Paused
# (Page-header residue captured together with the source; not part of the program.)
import base64
import logging
from io import BytesIO

import torch
import torchvision.transforms.functional as F
from flask import Flask, request, jsonify
from flask_cors import CORS  # Import CORS
from PIL import Image
from torch.cuda.amp import autocast

from src.pix2pix_turbo import Pix2Pix_Turbo
# Flask application object; CORS is enabled for every route.
app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

# Configuration variables
model_type = "sketch_to_image_stochastic"  # Passed to Pix2Pix_Turbo below
output_format = "PNG"  # Format used when the generated image is serialized
# Increased resolution for better quality.
# NOTE(review): desired_size is not referenced anywhere in the visible code.
desired_size = (768, 768)

# Load the model once, at import time, so every request reuses the same
# instance instead of paying the load cost per call.
print("Loading model...")
model = Pix2Pix_Turbo(model_type)
print("Model loaded successfully.")
# Example styles list (update this with your actual styles).
# Each entry pairs a display name with a prompt template; the literal
# "{prompt}" marker in a template is where the user's prompt text belongs.
style_list = [
    {"name": "Cinematic", "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy"},
    {"name": "3D Model", "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting"},
    {"name": "Anime", "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed"},
    {"name": "Digital Art", "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed"},
    {"name": "Photographic", "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed"},
    {"name": "Pixel art", "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics"},
    {"name": "Fantasy art", "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy"},
    {"name": "Neonpunk", "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional"},
    {"name": "Manga", "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style"},
]

# Lookup table: style name -> prompt template.
styles = {entry["name"]: entry["prompt"] for entry in style_list}
def process_image(image, prompt, prompt_template, style_name, seed, val_r):
    """Run the module-level model on one sketch and return the generated image.

    Args:
        image: PIL image holding the sketch. It is converted to RGB and
            binarized (values above 0.5 become 1.0, everything else 0.0).
        prompt: User-supplied text prompt.
        prompt_template: Style template. When it contains the literal
            ``{prompt}`` marker, the marker is replaced by ``prompt`` and the
            result is what the model receives; otherwise ``prompt`` is passed
            through unchanged.
        style_name: Name of the selected style; only used in the debug log.
        seed: Value handed to ``torch.manual_seed`` before noise is sampled.
        val_r: Forwarded to the model as its ``r`` argument.

    Returns:
        A PIL image built from the first element of the model output.
    """
    image = image.convert("RGB")

    # The style template was previously accepted but never applied, so the
    # chosen style had no effect. str.replace is used rather than str.format
    # so that braces inside the user's prompt cannot break the substitution.
    if prompt_template and "{prompt}" in prompt_template:
        styled_prompt = prompt_template.replace("{prompt}", prompt)
    else:
        styled_prompt = prompt

    # Convert image to tensor and threshold, then convert to float
    sketch = (F.to_tensor(image) > 0.5).float()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with torch.no_grad(), autocast():
        # Add a batch dimension and move the tensor to the chosen device
        c_t = sketch.unsqueeze(0).to(device)

        torch.manual_seed(seed)
        height, width = c_t.shape[-2:]
        # Noise is created directly on the same device as the input, at 1/8
        # of the spatial size with 4 channels.
        noise = torch.randn((1, 4, height // 8, width // 8), device=device)

        logging.debug(
            "Calling Pix2Pix model... ct: %s, style: %s, prompt: %s, "
            "deterministic: False, r: %s, noise_map: %s",
            c_t.shape, style_name, styled_prompt, val_r, noise.shape,
        )

        # Pass through the model
        output_image = model(
            c_t, styled_prompt, deterministic=False, r=val_r, noise_map=noise
        )

    # Rescale with x * 0.5 + 0.5 before building the PIL image
    # (presumably the model emits values in [-1, 1] -- TODO confirm).
    return F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
def process_image_route():
    """Handle a JSON image-generation request and answer with a base64 image.

    Expected JSON body fields:
        image: Base64 image data, optionally preceded by a ``data:...,``
            data-URL prefix (required).
        prompt: Text prompt (default ``'a cat'``).
        style_name: Style name, matched case-insensitively against ``styles``
            (default ``'Fantasy art'``).
        seed: Integer seed (default ``42``).
        val_r: Float forwarded to ``process_image`` (default ``0.8``).

    Returns:
        ``{"image": <base64 string>}`` on success; otherwise a JSON body with
        an ``"error"`` message and status 400 (bad input) or 500 (processing
        failure).
    """
    # NOTE(review): no @app.route decorator is visible above this function in
    # this view; without one Flask never dispatches requests here -- confirm.

    # silent=True turns an unparsable body into None so the client gets the
    # JSON error below instead of Flask's default error page.
    data = request.get_json(silent=True)

    if not isinstance(data, dict) or 'image' not in data:
        print("Error: No image provided")
        return jsonify({"error": "No image provided"}), 400

    # Debugging: list the received fields (the raw payload holds the whole
    # base64 image, which is far too large to print in full).
    print("Received JSON keys:", list(data))

    image_data = data['image']
    if not isinstance(image_data, str):
        print("Error decoding base64 image:", "image field is not a string")
        return jsonify({"error": "Invalid image data"}), 400
    print("Received base64 image data (truncated):", image_data[:100])  # Print first 100 chars of base64 data

    # Drop a data-URL prefix such as 'data:image/png;base64,' if present;
    # any media type is accepted, not only PNG.
    if image_data.startswith('data:'):
        image_data = image_data.partition(',')[2]

    try:
        image_bytes = base64.b64decode(image_data)
        image = Image.open(BytesIO(image_bytes))
    except Exception as e:
        # Request boundary: any decoding failure is reported as a client error.
        print("Error decoding base64 image:", str(e))
        return jsonify({"error": "Invalid image data"}), 400

    # Retrieve other parameters
    prompt = data.get('prompt', 'a cat')
    # str() guards against non-string JSON values (null, numbers) that have
    # no .strip(); such values simply fail the style lookup below.
    requested_style = str(data.get('style_name', 'Fantasy art')).strip()  # Strip any leading/trailing whitespace
    try:
        seed = int(data.get('seed', 42))
        val_r = float(data.get('val_r', 0.8))
    except (TypeError, ValueError) as e:
        # Malformed numbers are a client mistake, not a server failure.
        print("Error: Invalid seed or val_r:", str(e))
        return jsonify({"error": "Invalid 'seed' or 'val_r' value"}), 400

    # Debug: print available styles
    print(f"Available styles: {list(styles.keys())}")
    print(f"Received style name: {requested_style}")

    # Case-insensitive lookup
    wanted = requested_style.lower()
    style_name = next((key for key in styles if key.lower() == wanted), None)
    if not style_name:
        print(f"Error: Style '{data.get('style_name')}' not found")
        return jsonify({"error": f"Style '{data.get('style_name')}' not found"}), 400

    prompt_template = styles[style_name]
    print(f"Using style: {style_name} with prompt: {prompt}")

    # Process the image
    try:
        processed_image = process_image(image, prompt, prompt_template, style_name, seed, val_r)
    except Exception as e:
        # Request boundary: report the failure instead of crashing the worker.
        print("Error processing image:", str(e))
        return jsonify({"error": "Failed to process image"}), 500

    # Convert the processed image to base64
    img_io = BytesIO()
    processed_image.save(img_io, format=output_format)
    img_base64 = base64.b64encode(img_io.getvalue()).decode('utf-8')

    print("Processed image successfully, sending back to client")
    return jsonify({"image": img_base64})
if __name__ == "__main__":
    # Only start the server when this file is executed directly; listen on
    # every network interface, port 5000.
    app.run(host='0.0.0.0', port=5000)