img2img-turbo-sketch / gradio_sketch2imagehd.py
Inmental's picture
Upload 4 files
f59de63 verified
import gradio as gr
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import FileResponse, JSONResponse
import os
import random
import torch
from PIL import Image, ImageOps
from io import BytesIO
import base64
import json
import logging
import gc
from transformers import BlipProcessor, BlipForConditionalGeneration
import torchvision.transforms.functional as F
from src.pix2pix_turbo import Pix2Pix_Turbo # Aseg煤rate de que esta ruta de importaci贸n sea correcta
from fastapi.middleware.cors import CORSMiddleware
# Configuraci贸n de logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Cargar la configuraci贸n desde config.json
logging.info("Cargando configuraci贸n desde config.json...")
with open('config.json', 'r') as config_file:
config = json.load(config_file)
# Variables Globales
OUTPUT_PATH = "result.jpg" # La imagen resultante se guardar谩 como result.jpg
INPUT_PATH = "draw.jpg" # La imagen recibida se guardar谩 como draw.jpg
STYLE_LIST = config["style_list"]
STYLES = {style["name"]: style["prompt"] for style in STYLE_LIST}
DEVICE = config["model_params"]["device"]
DEFAULT_SEED = config["model_params"]["default_seed"]
VAL_R_DEFAULT = config["model_params"]["val_r_default"]
CANVAS_WIDTH = config["canvas"]["width"]
CANVAS_HEIGHT = config["canvas"]["height"]
PIX2PIX_MODEL_NAME = config["model_params"]["pix2pix_model_name"]
logging.info(f"Dispositivo seleccionado: {DEVICE}")
logging.info(f"Modelo Pix2Pix cargado: {PIX2PIX_MODEL_NAME}")
# Cargar y configurar los modelos
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(DEVICE)
pix2pix_model = Pix2Pix_Turbo(PIX2PIX_MODEL_NAME)
def print_welcome_message(app):
for route in app.routes:
full_url = f"http://0.0.0.0:{app.server_port}{route.path}"
if hasattr(route, 'methods'):
route_info = f"URL: {full_url}, Methods: {route.methods}"
else:
route_info = f"URL: {full_url}, Methods: Not applicable"
print(route_info)
def clear_memory():
"""Limpiar la memoria CUDA y recolectar basura si es necesario."""
logging.debug("Limpiando la memoria CUDA y recolectando basura...")
torch.cuda.empty_cache()
gc.collect()
def generate_prompt_from_sketch(image: Image) -> str:
"""Generar un texto a partir del sketch usando BLIP."""
logging.debug("Generando el prompt desde el sketch...")
image = ImageOps.fit(image, (CANVAS_WIDTH, CANVAS_HEIGHT), Image.LANCZOS)
inputs = processor(image, return_tensors="pt").to(DEVICE)
with torch.no_grad():
out = blip_model.generate(**inputs, max_new_tokens=50)
text_prompt = processor.decode(out[0], skip_special_tokens=True)
logging.debug(f"Prompt generado: {text_prompt}")
recognized_items = [item.strip() for item in text_prompt.split(', ') if item.strip()]
random_prefix = random.choice(config["random_values"])
prompt = f"a photo of a {' and '.join(recognized_items)}, {random_prefix}"
logging.debug(f"Prompt final: {prompt}")
return prompt
def normalize_image(image, range_from=(-1, 1)):
"""Normalizar la imagen de entrada."""
logging.debug("Normalizando la imagen...")
image_t = F.to_tensor(image)
if range_from == (-1, 1):
image_t = image_t * 2 - 1
return image_t
def process_sketch(sketch_image, prompt=None, style_name=None, seed=DEFAULT_SEED, val_r=VAL_R_DEFAULT):
"""Procesar el sketch y generar una imagen usando el modelo Pix2Pix."""
logging.debug("Iniciando el procesamiento del sketch...")
if not prompt:
logging.info("Prompt no proporcionado, generando uno a partir del sketch...")
prompt = generate_prompt_from_sketch(sketch_image)
prompt_template = STYLES.get(style_name, STYLES[config["default_style_name"]])
prompt = prompt_template.replace("{prompt}", prompt)
sketch_image = sketch_image.convert("RGB")
sketch_tensor = normalize_image(sketch_image, range_from=(-1, 1))
#image_t = F.to_tensor(sketch_image).unsqueeze(0).to(torch.float32)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#clear_memory()
try:
with torch.no_grad():
logging.info("Iniciando la inferencia del modelo Pix2Pix...")
c_t = sketch_tensor.unsqueeze(0).to(DEVICE).float()
torch.manual_seed(seed)
B, C, H, W = c_t.shape
#noise = torch.randn((1, 4, c_t.shape[2] // 8, c_t.shape[3] // 8), device=c_t.device)
noise = torch.randn((1, 4, H // 8, W // 8), device=device)
with torch.cuda.amp.autocast():
output_image = pix2pix_model(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)
output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
output_pil.save(OUTPUT_PATH)
logging.info("Imagen generada y guardada correctamente.")
return output_pil
except RuntimeError as e:
logging.error(f"Error de runtime durante la inferencia: {str(e)}")
if "CUDA out of memory" in str(e):
logging.warning("Error de memoria CUDA. Cambiando a CPU.")
with torch.no_grad():
c_t = c_t.cpu()
noise = noise.cpu()
pix2pix_model_cpu = pix2pix_model.cpu()
output_image = pix2pix_model_cpu(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)
output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
output_pil.save(OUTPUT_PATH)
logging.info("Inferencia realizada en CPU y la imagen fue generada y guardada.")
return output_pil
else:
raise e
def get_image_as_base64(image_path):
"""Convertir una imagen a cadena base64."""
logging.debug(f"Convirtiendo la imagen {image_path} a base64...")
with open(image_path, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
return encoded_string
# Crear una instancia de FastAPI
app = FastAPI()
# Configurar el middleware de CORS
logging.info("Configurando el middleware de CORS...")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Permitir todas las or铆genes. Puedes especificar or铆genes espec铆ficos en lugar de "*"
allow_credentials=True,
allow_methods=["*"], # Permitir todos los m茅todos HTTP (GET, POST, etc.)
allow_headers=["*"], # Permitir todos los encabezados
)
@app.get("/")
def read_image():
"""
Retorna el archivo 'result.jpg' si existe, o un mensaje de error si no.
"""
logging.info("Petici贸n GET recibida en '/'. Verificando si existe una imagen procesada...")
if os.path.exists(OUTPUT_PATH):
logging.info(f"Retornando la imagen {OUTPUT_PATH}.")
return FileResponse(OUTPUT_PATH, media_type='image/jpeg', filename="result.jpg")
else:
logging.warning("No se ha procesado ninguna imagen a煤n.")
return {"error": "No image processed yet."}
@app.get("/image_base64")
def get_image_base64():
"""
Retorna la imagen procesada como una cadena en formato base64 dentro de un objeto JSON.
"""
if os.path.exists(OUTPUT_PATH):
# Convertir la imagen en base64
base64_str = get_image_as_base64(OUTPUT_PATH)
logging.info(f"Imagen convertida a base64 y enviada como respuesta JSON.")
return JSONResponse(content={"image_base64": base64_str})
else:
logging.error("No se encontr贸 ninguna imagen procesada.")
return JSONResponse(content={"error": "No image processed yet."})
@app.post("/process_image")
async def process_image(file: UploadFile = File(...)):
"""
Procesa la imagen enviada y devuelve la imagen generada.
"""
logging.info("Petici贸n POST recibida en '/process_image'. Procesando imagen...")
image = Image.open(BytesIO(await file.read()))
# Guardar la imagen recibida como 'draw.png'
image.save("draw.png") # Guardar en formato PNG
logging.info("Imagen recibida guardada como 'draw.png'.")
# Procesar la imagen y guardar el resultado
processed_image = process_sketch(image)
processed_image.save(OUTPUT_PATH) # Guardar la imagen procesada como 'result.jpg'
logging.info("Imagen procesada y guardada correctamente.")
return {"status": f"Image processed and saved as {OUTPUT_PATH}"}
# Montar la aplicaci贸n de Gradio en FastAPI
logging.info("Montando la interfaz de Gradio en la aplicaci贸n FastAPI...")
interface = gr.Interface(
fn=process_sketch,
inputs=[gr.Image(source="upload", type="pil", label="Sketch Image"),
gr.Textbox(label="Prompt (optional)"),
gr.Dropdown(choices=list(STYLES.keys()), label="Style"),
gr.Slider(minimum=0, maximum=100, step=1, value=DEFAULT_SEED, label="Seed"),
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=VAL_R_DEFAULT, label="Sketch Guidance")],
outputs=gr.Image(label="Generated Image"),
title="Sketch to Image HD",
description="Upload a sketch to generate an image."
)
app = gr.mount_gradio_app(app, interface, path="/gradio")
if __name__ == "__main__":
logging.info("Iniciando la aplicaci贸n en Uvicorn...")
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)
print_welcome_message(interface.app)