File size: 9,466 Bytes
f59de63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import gradio as gr
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import FileResponse, JSONResponse
import os
import random
import torch
from PIL import Image, ImageOps
from io import BytesIO
import base64
import json
import logging
import gc
from transformers import BlipProcessor, BlipForConditionalGeneration
import torchvision.transforms.functional as F
from src.pix2pix_turbo import Pix2Pix_Turbo  # Aseg煤rate de que esta ruta de importaci贸n sea correcta
from fastapi.middleware.cors import CORSMiddleware

# Configuraci贸n de logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Cargar la configuraci贸n desde config.json
logging.info("Cargando configuraci贸n desde config.json...")
with open('config.json', 'r') as config_file:
    config = json.load(config_file)

# Variables Globales
OUTPUT_PATH = "result.jpg"  # La imagen resultante se guardar谩 como result.jpg
INPUT_PATH = "draw.jpg"     # La imagen recibida se guardar谩 como draw.jpg
STYLE_LIST = config["style_list"]
STYLES = {style["name"]: style["prompt"] for style in STYLE_LIST}
DEVICE = config["model_params"]["device"]
DEFAULT_SEED = config["model_params"]["default_seed"]
VAL_R_DEFAULT = config["model_params"]["val_r_default"]
CANVAS_WIDTH = config["canvas"]["width"]
CANVAS_HEIGHT = config["canvas"]["height"]
PIX2PIX_MODEL_NAME = config["model_params"]["pix2pix_model_name"]

logging.info(f"Dispositivo seleccionado: {DEVICE}")
logging.info(f"Modelo Pix2Pix cargado: {PIX2PIX_MODEL_NAME}")

# Cargar y configurar los modelos
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(DEVICE)
pix2pix_model = Pix2Pix_Turbo(PIX2PIX_MODEL_NAME)

def print_welcome_message(app):
    for route in app.routes:
        full_url = f"http://0.0.0.0:{app.server_port}{route.path}"
        if hasattr(route, 'methods'):
            route_info = f"URL: {full_url}, Methods: {route.methods}"
        else:
            route_info = f"URL: {full_url}, Methods: Not applicable"
        print(route_info)

def clear_memory():
    """Limpiar la memoria CUDA y recolectar basura si es necesario."""
    logging.debug("Limpiando la memoria CUDA y recolectando basura...")
    torch.cuda.empty_cache()
    gc.collect()

def generate_prompt_from_sketch(image: Image) -> str:
    """Generar un texto a partir del sketch usando BLIP."""
    logging.debug("Generando el prompt desde el sketch...")
    image = ImageOps.fit(image, (CANVAS_WIDTH, CANVAS_HEIGHT), Image.LANCZOS)
    inputs = processor(image, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        out = blip_model.generate(**inputs, max_new_tokens=50)
    text_prompt = processor.decode(out[0], skip_special_tokens=True)
    logging.debug(f"Prompt generado: {text_prompt}")

    recognized_items = [item.strip() for item in text_prompt.split(', ') if item.strip()]
    random_prefix = random.choice(config["random_values"])
    prompt = f"a photo of a {' and '.join(recognized_items)}, {random_prefix}"
    logging.debug(f"Prompt final: {prompt}")
    return prompt

def normalize_image(image, range_from=(-1, 1)):
    """Normalizar la imagen de entrada."""
    logging.debug("Normalizando la imagen...")
    image_t = F.to_tensor(image)
    if range_from == (-1, 1):
        image_t = image_t * 2 - 1
    return image_t

def process_sketch(sketch_image, prompt=None, style_name=None, seed=DEFAULT_SEED, val_r=VAL_R_DEFAULT):
    """Procesar el sketch y generar una imagen usando el modelo Pix2Pix."""
    logging.debug("Iniciando el procesamiento del sketch...")

    if not prompt:
        logging.info("Prompt no proporcionado, generando uno a partir del sketch...")
        prompt = generate_prompt_from_sketch(sketch_image)

    prompt_template = STYLES.get(style_name, STYLES[config["default_style_name"]])
    prompt = prompt_template.replace("{prompt}", prompt)
    sketch_image = sketch_image.convert("RGB")
    sketch_tensor = normalize_image(sketch_image, range_from=(-1, 1))

    #image_t = F.to_tensor(sketch_image).unsqueeze(0).to(torch.float32)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #clear_memory()

    try:
        with torch.no_grad():
            logging.info("Iniciando la inferencia del modelo Pix2Pix...")
            c_t = sketch_tensor.unsqueeze(0).to(DEVICE).float()
            torch.manual_seed(seed)
            B, C, H, W = c_t.shape
            #noise = torch.randn((1, 4, c_t.shape[2] // 8, c_t.shape[3] // 8), device=c_t.device)
            noise = torch.randn((1, 4, H // 8, W // 8), device=device)
            with torch.cuda.amp.autocast():
                output_image = pix2pix_model(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)

            output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
            output_pil.save(OUTPUT_PATH)
            logging.info("Imagen generada y guardada correctamente.")
            return output_pil

    except RuntimeError as e:
        logging.error(f"Error de runtime durante la inferencia: {str(e)}")
        if "CUDA out of memory" in str(e):
            logging.warning("Error de memoria CUDA. Cambiando a CPU.")
            with torch.no_grad():
                c_t = c_t.cpu()
                noise = noise.cpu()
                pix2pix_model_cpu = pix2pix_model.cpu()
                output_image = pix2pix_model_cpu(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)
                output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
                output_pil.save(OUTPUT_PATH)
                logging.info("Inferencia realizada en CPU y la imagen fue generada y guardada.")
                return output_pil
        else:
            raise e

def get_image_as_base64(image_path):
    """Convertir una imagen a cadena base64."""
    logging.debug(f"Convirtiendo la imagen {image_path} a base64...")
    with open(image_path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
    return encoded_string

# Crear una instancia de FastAPI
app = FastAPI()

# Configurar el middleware de CORS
logging.info("Configurando el middleware de CORS...")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Permitir todas las or铆genes. Puedes especificar or铆genes espec铆ficos en lugar de "*"
    allow_credentials=True,
    allow_methods=["*"],  # Permitir todos los m茅todos HTTP (GET, POST, etc.)
    allow_headers=["*"],  # Permitir todos los encabezados
)

@app.get("/")
def read_image():
    """
    Retorna el archivo 'result.jpg' si existe, o un mensaje de error si no.
    """
    logging.info("Petici贸n GET recibida en '/'. Verificando si existe una imagen procesada...")
    if os.path.exists(OUTPUT_PATH):
        logging.info(f"Retornando la imagen {OUTPUT_PATH}.")
        return FileResponse(OUTPUT_PATH, media_type='image/jpeg', filename="result.jpg")
    else:
        logging.warning("No se ha procesado ninguna imagen a煤n.")
        return {"error": "No image processed yet."}

@app.get("/image_base64")
def get_image_base64():
    """
    Retorna la imagen procesada como una cadena en formato base64 dentro de un objeto JSON.
    """
    if os.path.exists(OUTPUT_PATH):
        # Convertir la imagen en base64
        base64_str = get_image_as_base64(OUTPUT_PATH)
        logging.info(f"Imagen convertida a base64 y enviada como respuesta JSON.")
        return JSONResponse(content={"image_base64": base64_str})
    else:
        logging.error("No se encontr贸 ninguna imagen procesada.")
        return JSONResponse(content={"error": "No image processed yet."})


@app.post("/process_image")
async def process_image(file: UploadFile = File(...)):
    """
    Procesa la imagen enviada y devuelve la imagen generada.
    """
    logging.info("Petici贸n POST recibida en '/process_image'. Procesando imagen...")
    image = Image.open(BytesIO(await file.read()))

    # Guardar la imagen recibida como 'draw.png'
    image.save("draw.png")  # Guardar en formato PNG
    logging.info("Imagen recibida guardada como 'draw.png'.")

    # Procesar la imagen y guardar el resultado
    processed_image = process_sketch(image)
    processed_image.save(OUTPUT_PATH)  # Guardar la imagen procesada como 'result.jpg'
    logging.info("Imagen procesada y guardada correctamente.")
    return {"status": f"Image processed and saved as {OUTPUT_PATH}"}

# Montar la aplicaci贸n de Gradio en FastAPI
logging.info("Montando la interfaz de Gradio en la aplicaci贸n FastAPI...")
interface = gr.Interface(
    fn=process_sketch,
    inputs=[gr.Image(source="upload", type="pil", label="Sketch Image"),
            gr.Textbox(label="Prompt (optional)"),
            gr.Dropdown(choices=list(STYLES.keys()), label="Style"),
            gr.Slider(minimum=0, maximum=100, step=1, value=DEFAULT_SEED, label="Seed"),
            gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=VAL_R_DEFAULT, label="Sketch Guidance")],
    outputs=gr.Image(label="Generated Image"),
    title="Sketch to Image HD",
    description="Upload a sketch to generate an image."
)

app = gr.mount_gradio_app(app, interface, path="/gradio")


if __name__ == "__main__":
    logging.info("Iniciando la aplicaci贸n en Uvicorn...")
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
    print_welcome_message(interface.app)