# NOTE(review): the original paste began with "Spaces:" / "Runtime error"
# header lines — residue from the hosting page, not part of the source.
# Imports (original order preserved): web framework, imaging, ML stacks,
# and the project-local Pix2Pix-Turbo model.
import gradio as gr
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import FileResponse, JSONResponse
import os
import random
import torch
from PIL import Image, ImageOps
from io import BytesIO
import base64
import json
import logging
import gc
from transformers import BlipProcessor, BlipForConditionalGeneration
import torchvision.transforms.functional as F
from src.pix2pix_turbo import Pix2Pix_Turbo  # Make sure this import path is correct.
from fastapi.middleware.cors import CORSMiddleware
# Logging configuration: timestamped INFO-level messages.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Load runtime configuration (styles, model params, canvas size) from
# config.json in the working directory; raises if the file is missing.
logging.info("Cargando configuraci贸n desde config.json...")
with open('config.json', 'r') as config_file:
    config = json.load(config_file)

# Global constants derived from the configuration.
OUTPUT_PATH = "result.jpg"  # The generated result image is written here.
INPUT_PATH = "draw.jpg"     # Intended destination for the received sketch.
STYLE_LIST = config["style_list"]
# Map style name -> prompt template (templates embed a "{prompt}" placeholder).
STYLES = {style["name"]: style["prompt"] for style in STYLE_LIST}
DEVICE = config["model_params"]["device"]  # taken verbatim from config, e.g. "cuda" or "cpu"
DEFAULT_SEED = config["model_params"]["default_seed"]
VAL_R_DEFAULT = config["model_params"]["val_r_default"]  # default sketch-guidance strength
CANVAS_WIDTH = config["canvas"]["width"]
CANVAS_HEIGHT = config["canvas"]["height"]
PIX2PIX_MODEL_NAME = config["model_params"]["pix2pix_model_name"]
logging.info(f"Dispositivo seleccionado: {DEVICE}")
logging.info(f"Modelo Pix2Pix cargado: {PIX2PIX_MODEL_NAME}")

# Load models at import time: BLIP for captioning sketches, Pix2Pix-Turbo
# for the sketch-to-image generation itself.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(DEVICE)
pix2pix_model = Pix2Pix_Turbo(PIX2PIX_MODEL_NAME)
def print_welcome_message(app, port=7860):
    """Print every route registered on *app* together with its full URL.

    Args:
        app: A FastAPI/Starlette application (anything exposing ``.routes``).
        port: Port used to build the displayed URLs (default 7860, matching
            the uvicorn launch below). The original read ``app.server_port``
            unconditionally, which raises AttributeError on a plain FastAPI
            instance; it is now only used when present.
    """
    port = getattr(app, "server_port", port)
    for route in app.routes:
        full_url = f"http://0.0.0.0:{port}{route.path}"
        # Not every route object (e.g. mounts) exposes an HTTP-methods set.
        if hasattr(route, 'methods'):
            route_info = f"URL: {full_url}, Methods: {route.methods}"
        else:
            route_info = f"URL: {full_url}, Methods: Not applicable"
        print(route_info)
def clear_memory():
    """Free cached CUDA allocations, then run the Python garbage collector."""
    logging.debug("Limpiando la memoria CUDA y recolectando basura...")
    # empty_cache() is a no-op when CUDA has not been initialized.
    torch.cuda.empty_cache()
    gc.collect()
def generate_prompt_from_sketch(image: Image) -> str:
    """Derive a text prompt from a sketch via BLIP image captioning.

    The sketch is fitted to the configured canvas size, captioned with the
    global BLIP model, and the caption fragments are joined into a
    "a photo of a ... and ..., <random suffix>" prompt.
    """
    logging.debug("Generando el prompt desde el sketch...")
    fitted = ImageOps.fit(image, (CANVAS_WIDTH, CANVAS_HEIGHT), Image.LANCZOS)
    inputs = processor(fitted, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        generated_ids = blip_model.generate(**inputs, max_new_tokens=50)
    caption = processor.decode(generated_ids[0], skip_special_tokens=True)
    logging.debug(f"Prompt generado: {caption}")
    # Split the caption on ", " boundaries, discarding empty fragments.
    fragments = [piece.strip() for piece in caption.split(', ') if piece.strip()]
    suffix = random.choice(config["random_values"])
    final_prompt = f"a photo of a {' and '.join(fragments)}, {suffix}"
    logging.debug(f"Prompt final: {final_prompt}")
    return final_prompt
def normalize_image(image, range_from=(-1, 1)):
    """Convert a PIL image to a float tensor, rescaled from [0, 1] to
    [-1, 1] when ``range_from`` is (-1, 1); otherwise left in [0, 1]."""
    logging.debug("Normalizando la imagen...")
    tensor = F.to_tensor(image)
    if range_from != (-1, 1):
        return tensor
    return tensor * 2 - 1
def process_sketch(sketch_image, prompt=None, style_name=None, seed=DEFAULT_SEED, val_r=VAL_R_DEFAULT):
    """Run Pix2Pix-Turbo on a sketch and return the generated image.

    Args:
        sketch_image: Input sketch as a PIL image.
        prompt: Optional text prompt; when falsy, one is generated from the
            sketch via BLIP.
        style_name: Key into STYLES; unknown/None falls back to the
            configured default style.
        seed: RNG seed used for the latent noise map.
        val_r: Sketch-guidance strength forwarded to the model.

    Returns:
        The generated image as a PIL image (also saved to OUTPUT_PATH).

    Raises:
        RuntimeError: re-raised for non-OOM inference failures; a CUDA
            out-of-memory error is handled by retrying inference on the CPU.
    """
    logging.debug("Iniciando el procesamiento del sketch...")
    if not prompt:
        logging.info("Prompt no proporcionado, generando uno a partir del sketch...")
        prompt = generate_prompt_from_sketch(sketch_image)
    # Wrap the user prompt in the selected style template.
    prompt_template = STYLES.get(style_name, STYLES[config["default_style_name"]])
    prompt = prompt_template.replace("{prompt}", prompt)
    sketch_image = sketch_image.convert("RGB")
    sketch_tensor = normalize_image(sketch_image, range_from=(-1, 1))
    try:
        with torch.no_grad():
            logging.info("Iniciando la inferencia del modelo Pix2Pix...")
            c_t = sketch_tensor.unsqueeze(0).to(DEVICE).float()
            torch.manual_seed(seed)
            B, C, H, W = c_t.shape
            # Latent noise: 4 channels at 1/8 spatial resolution. Allocate it
            # on the same device as c_t — the original put it on the local
            # CUDA device unconditionally, which mismatches (and crashes)
            # when DEVICE is "cpu" on a CUDA-capable machine.
            noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
            with torch.cuda.amp.autocast():
                output_image = pix2pix_model(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)
            # Model output is in [-1, 1]; map back to [0, 1] for PIL.
            output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
            output_pil.save(OUTPUT_PATH)
            logging.info("Imagen generada y guardada correctamente.")
            return output_pil
    except RuntimeError as e:
        logging.error(f"Error de runtime durante la inferencia: {str(e)}")
        if "CUDA out of memory" in str(e):
            logging.warning("Error de memoria CUDA. Cambiando a CPU.")
            # Retry on CPU. NOTE(review): this moves the global model to the
            # CPU permanently, so subsequent requests keep running on CPU.
            with torch.no_grad():
                c_t = c_t.cpu()
                noise = noise.cpu()
                pix2pix_model_cpu = pix2pix_model.cpu()
                output_image = pix2pix_model_cpu(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)
                output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
                output_pil.save(OUTPUT_PATH)
                logging.info("Inferencia realizada en CPU y la imagen fue generada y guardada.")
                return output_pil
        raise
def get_image_as_base64(image_path):
    """Read the file at *image_path* and return its base64-encoded contents
    as a UTF-8 string."""
    logging.debug(f"Convirtiendo la imagen {image_path} a base64...")
    with open(image_path, "rb") as handle:
        raw = handle.read()
    return base64.b64encode(raw).decode("utf-8")
# Create the FastAPI application instance.
app = FastAPI()

# CORS middleware: wide-open policy so any front-end origin may call the API.
logging.info("Configurando el middleware de CORS...")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],      # Allow every origin; restrict to specific origins in production.
    allow_credentials=True,
    allow_methods=["*"],      # Allow every HTTP method (GET, POST, ...).
    allow_headers=["*"],      # Allow every header.
)
def read_image():
    """Serve the processed image ('result.jpg') when it exists; otherwise
    return an error payload.

    NOTE(review): no route decorator is attached here although the log text
    suggests GET '/' — presumably lost in transit; confirm against the
    original deployment.
    """
    logging.info("Petici贸n GET recibida en '/'. Verificando si existe una imagen procesada...")
    if not os.path.exists(OUTPUT_PATH):
        logging.warning("No se ha procesado ninguna imagen a煤n.")
        return {"error": "No image processed yet."}
    logging.info(f"Retornando la imagen {OUTPUT_PATH}.")
    return FileResponse(OUTPUT_PATH, media_type='image/jpeg', filename="result.jpg")
def get_image_base64():
    """Return the processed image as a base64 string inside a JSON object,
    or a JSON error payload when no image has been produced yet."""
    if not os.path.exists(OUTPUT_PATH):
        logging.error("No se encontr贸 ninguna imagen procesada.")
        return JSONResponse(content={"error": "No image processed yet."})
    encoded = get_image_as_base64(OUTPUT_PATH)
    logging.info(f"Imagen convertida a base64 y enviada como respuesta JSON.")
    return JSONResponse(content={"image_base64": encoded})
async def process_image(file: UploadFile = File(...)):
    """Receive an uploaded sketch, run it through the model, persist and
    report the result."""
    logging.info("Petici贸n POST recibida en '/process_image'. Procesando imagen...")
    payload = await file.read()
    image = Image.open(BytesIO(payload))
    # Persist the raw upload as PNG. NOTE(review): INPUT_PATH is "draw.jpg"
    # but the upload is stored as "draw.png" — confirm which is intended.
    image.save("draw.png")
    logging.info("Imagen recibida guardada como 'draw.png'.")
    # Generate the output image; process_sketch also writes OUTPUT_PATH,
    # and it is saved again here for good measure.
    processed_image = process_sketch(image)
    processed_image.save(OUTPUT_PATH)
    logging.info("Imagen procesada y guardada correctamente.")
    return {"status": f"Image processed and saved as {OUTPUT_PATH}"}
# Build the Gradio interface and mount it on the FastAPI app under /gradio.
logging.info("Montando la interfaz de Gradio en la aplicaci贸n FastAPI...")
interface = gr.Interface(
    fn=process_sketch,
    # NOTE(review): gr.Image(source=...) is Gradio 3.x API — confirm the
    # pinned gradio version before upgrading.
    inputs=[gr.Image(source="upload", type="pil", label="Sketch Image"),
            gr.Textbox(label="Prompt (optional)"),
            gr.Dropdown(choices=list(STYLES.keys()), label="Style"),
            gr.Slider(minimum=0, maximum=100, step=1, value=DEFAULT_SEED, label="Seed"),
            gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=VAL_R_DEFAULT, label="Sketch Guidance")],
    outputs=gr.Image(label="Generated Image"),
    title="Sketch to Image HD",
    description="Upload a sketch to generate an image."
)
app = gr.mount_gradio_app(app, interface, path="/gradio")
if __name__ == "__main__":
    logging.info("Iniciando la aplicaci贸n en Uvicorn...")
    import uvicorn
    # List the registered routes BEFORE starting the server: uvicorn.run()
    # blocks until shutdown, so the original call placed after it never
    # executed — and it passed interface.app instead of the mounted FastAPI
    # app. Route listing is best-effort diagnostics only, so failures are
    # swallowed and logged.
    try:
        print_welcome_message(app)
    except Exception as exc:
        logging.debug("Route listing failed: %s", exc)
    uvicorn.run(app, host="0.0.0.0", port=7860)