Spaces:
Runtime error
Runtime error
File size: 9,466 Bytes
f59de63 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 |
import gradio as gr
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import FileResponse, JSONResponse
import os
import random
import torch
from PIL import Image, ImageOps
from io import BytesIO
import base64
import json
import logging
import gc
from transformers import BlipProcessor, BlipForConditionalGeneration
import torchvision.transforms.functional as F
from src.pix2pix_turbo import Pix2Pix_Turbo # Aseg煤rate de que esta ruta de importaci贸n sea correcta
from fastapi.middleware.cors import CORSMiddleware
# Configuraci贸n de logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Cargar la configuraci贸n desde config.json
logging.info("Cargando configuraci贸n desde config.json...")
with open('config.json', 'r') as config_file:
config = json.load(config_file)
# Variables Globales
OUTPUT_PATH = "result.jpg" # La imagen resultante se guardar谩 como result.jpg
INPUT_PATH = "draw.jpg" # La imagen recibida se guardar谩 como draw.jpg
STYLE_LIST = config["style_list"]
STYLES = {style["name"]: style["prompt"] for style in STYLE_LIST}
DEVICE = config["model_params"]["device"]
DEFAULT_SEED = config["model_params"]["default_seed"]
VAL_R_DEFAULT = config["model_params"]["val_r_default"]
CANVAS_WIDTH = config["canvas"]["width"]
CANVAS_HEIGHT = config["canvas"]["height"]
PIX2PIX_MODEL_NAME = config["model_params"]["pix2pix_model_name"]
logging.info(f"Dispositivo seleccionado: {DEVICE}")
logging.info(f"Modelo Pix2Pix cargado: {PIX2PIX_MODEL_NAME}")
# Cargar y configurar los modelos
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(DEVICE)
pix2pix_model = Pix2Pix_Turbo(PIX2PIX_MODEL_NAME)
def print_welcome_message(app):
for route in app.routes:
full_url = f"http://0.0.0.0:{app.server_port}{route.path}"
if hasattr(route, 'methods'):
route_info = f"URL: {full_url}, Methods: {route.methods}"
else:
route_info = f"URL: {full_url}, Methods: Not applicable"
print(route_info)
def clear_memory():
"""Limpiar la memoria CUDA y recolectar basura si es necesario."""
logging.debug("Limpiando la memoria CUDA y recolectando basura...")
torch.cuda.empty_cache()
gc.collect()
def generate_prompt_from_sketch(image: Image) -> str:
"""Generar un texto a partir del sketch usando BLIP."""
logging.debug("Generando el prompt desde el sketch...")
image = ImageOps.fit(image, (CANVAS_WIDTH, CANVAS_HEIGHT), Image.LANCZOS)
inputs = processor(image, return_tensors="pt").to(DEVICE)
with torch.no_grad():
out = blip_model.generate(**inputs, max_new_tokens=50)
text_prompt = processor.decode(out[0], skip_special_tokens=True)
logging.debug(f"Prompt generado: {text_prompt}")
recognized_items = [item.strip() for item in text_prompt.split(', ') if item.strip()]
random_prefix = random.choice(config["random_values"])
prompt = f"a photo of a {' and '.join(recognized_items)}, {random_prefix}"
logging.debug(f"Prompt final: {prompt}")
return prompt
def normalize_image(image, range_from=(-1, 1)):
"""Normalizar la imagen de entrada."""
logging.debug("Normalizando la imagen...")
image_t = F.to_tensor(image)
if range_from == (-1, 1):
image_t = image_t * 2 - 1
return image_t
def process_sketch(sketch_image, prompt=None, style_name=None, seed=DEFAULT_SEED, val_r=VAL_R_DEFAULT):
"""Procesar el sketch y generar una imagen usando el modelo Pix2Pix."""
logging.debug("Iniciando el procesamiento del sketch...")
if not prompt:
logging.info("Prompt no proporcionado, generando uno a partir del sketch...")
prompt = generate_prompt_from_sketch(sketch_image)
prompt_template = STYLES.get(style_name, STYLES[config["default_style_name"]])
prompt = prompt_template.replace("{prompt}", prompt)
sketch_image = sketch_image.convert("RGB")
sketch_tensor = normalize_image(sketch_image, range_from=(-1, 1))
#image_t = F.to_tensor(sketch_image).unsqueeze(0).to(torch.float32)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#clear_memory()
try:
with torch.no_grad():
logging.info("Iniciando la inferencia del modelo Pix2Pix...")
c_t = sketch_tensor.unsqueeze(0).to(DEVICE).float()
torch.manual_seed(seed)
B, C, H, W = c_t.shape
#noise = torch.randn((1, 4, c_t.shape[2] // 8, c_t.shape[3] // 8), device=c_t.device)
noise = torch.randn((1, 4, H // 8, W // 8), device=device)
with torch.cuda.amp.autocast():
output_image = pix2pix_model(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)
output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
output_pil.save(OUTPUT_PATH)
logging.info("Imagen generada y guardada correctamente.")
return output_pil
except RuntimeError as e:
logging.error(f"Error de runtime durante la inferencia: {str(e)}")
if "CUDA out of memory" in str(e):
logging.warning("Error de memoria CUDA. Cambiando a CPU.")
with torch.no_grad():
c_t = c_t.cpu()
noise = noise.cpu()
pix2pix_model_cpu = pix2pix_model.cpu()
output_image = pix2pix_model_cpu(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)
output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
output_pil.save(OUTPUT_PATH)
logging.info("Inferencia realizada en CPU y la imagen fue generada y guardada.")
return output_pil
else:
raise e
def get_image_as_base64(image_path):
"""Convertir una imagen a cadena base64."""
logging.debug(f"Convirtiendo la imagen {image_path} a base64...")
with open(image_path, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
return encoded_string
# Crear una instancia de FastAPI
app = FastAPI()
# Configurar el middleware de CORS
logging.info("Configurando el middleware de CORS...")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Permitir todas las or铆genes. Puedes especificar or铆genes espec铆ficos en lugar de "*"
allow_credentials=True,
allow_methods=["*"], # Permitir todos los m茅todos HTTP (GET, POST, etc.)
allow_headers=["*"], # Permitir todos los encabezados
)
@app.get("/")
def read_image():
"""
Retorna el archivo 'result.jpg' si existe, o un mensaje de error si no.
"""
logging.info("Petici贸n GET recibida en '/'. Verificando si existe una imagen procesada...")
if os.path.exists(OUTPUT_PATH):
logging.info(f"Retornando la imagen {OUTPUT_PATH}.")
return FileResponse(OUTPUT_PATH, media_type='image/jpeg', filename="result.jpg")
else:
logging.warning("No se ha procesado ninguna imagen a煤n.")
return {"error": "No image processed yet."}
@app.get("/image_base64")
def get_image_base64():
"""
Retorna la imagen procesada como una cadena en formato base64 dentro de un objeto JSON.
"""
if os.path.exists(OUTPUT_PATH):
# Convertir la imagen en base64
base64_str = get_image_as_base64(OUTPUT_PATH)
logging.info(f"Imagen convertida a base64 y enviada como respuesta JSON.")
return JSONResponse(content={"image_base64": base64_str})
else:
logging.error("No se encontr贸 ninguna imagen procesada.")
return JSONResponse(content={"error": "No image processed yet."})
@app.post("/process_image")
async def process_image(file: UploadFile = File(...)):
"""
Procesa la imagen enviada y devuelve la imagen generada.
"""
logging.info("Petici贸n POST recibida en '/process_image'. Procesando imagen...")
image = Image.open(BytesIO(await file.read()))
# Guardar la imagen recibida como 'draw.png'
image.save("draw.png") # Guardar en formato PNG
logging.info("Imagen recibida guardada como 'draw.png'.")
# Procesar la imagen y guardar el resultado
processed_image = process_sketch(image)
processed_image.save(OUTPUT_PATH) # Guardar la imagen procesada como 'result.jpg'
logging.info("Imagen procesada y guardada correctamente.")
return {"status": f"Image processed and saved as {OUTPUT_PATH}"}
# Montar la aplicaci贸n de Gradio en FastAPI
logging.info("Montando la interfaz de Gradio en la aplicaci贸n FastAPI...")
interface = gr.Interface(
fn=process_sketch,
inputs=[gr.Image(source="upload", type="pil", label="Sketch Image"),
gr.Textbox(label="Prompt (optional)"),
gr.Dropdown(choices=list(STYLES.keys()), label="Style"),
gr.Slider(minimum=0, maximum=100, step=1, value=DEFAULT_SEED, label="Seed"),
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=VAL_R_DEFAULT, label="Sketch Guidance")],
outputs=gr.Image(label="Generated Image"),
title="Sketch to Image HD",
description="Upload a sketch to generate an image."
)
app = gr.mount_gradio_app(app, interface, path="/gradio")
if __name__ == "__main__":
logging.info("Iniciando la aplicaci贸n en Uvicorn...")
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)
print_welcome_message(interface.app)
|