import torch import numpy as np from PIL import Image, ImageDraw, ImageFont import gradio as gr from diffusers import DiffusionPipeline from huggingface_hub import hf_hub_download import os # Настройки use_custom_weights = True custom_weights_path = hf_hub_download( repo_id="focuzz/depth-estimation", filename="unet_weights.pth" ) device = "cuda" if torch.cuda.is_available() else "cpu" dtype = torch.float16 if device == "cuda" else torch.float32 # Загрузка пайплайна pipe = DiffusionPipeline.from_pretrained( "prs-eth/marigold-v1-0", custom_pipeline="marigold_depth_estimation", torch_dtype=dtype ).to(device) # Загрузка дообученных весов if use_custom_weights: state_dict = torch.load(custom_weights_path, map_location=device) prefix = "unet.conv_in." if any(k.startswith("unet.conv_in.") for k in state_dict) else "conv_in." conv_in_dict = { k.replace(prefix, ""): v for k, v in state_dict.items() if k.startswith(prefix) } pipe.unet.conv_in.load_state_dict(conv_in_dict) print("Загружены дообученные веса conv_in из:", custom_weights_path) # Добавление overlay-текста def add_overlay(image: Image.Image, label: str) -> Image.Image: image = image.copy() draw = ImageDraw.Draw(image) try: font = ImageFont.load_default() except: font = None draw.text((10, 10), label, fill="white", font=font) return image # Генерация галереи из примеров TARGET_SIZE = (768, 768) def normalize_depth(depth_np): d = np.copy(depth_np) d_min = np.percentile(d, 1) d_max = np.percentile(d, 99) d = np.clip((d - d_min) / (d_max - d_min), 0, 1) return (d * 255).astype(np.uint8) def generate_gallery(): example_files = ["example1.jpg", "example2.jpg", "example3.jpg", "example4.jpg"] rgbs = [] depths_gray = [] depths_color = [] for path in example_files: if not os.path.exists(path): continue rgb = Image.open(path).convert("RGB").resize(TARGET_SIZE) with torch.no_grad(): output = pipe( rgb, denoising_steps=4, ensemble_size=5, processing_res=768, match_input_res=True, batch_size=0, color_map="Spectral", show_progress_bar=False, ) depth_np = output.depth_np gray_normalized = normalize_depth(depth_np) depth_gray = Image.fromarray(gray_normalized).convert("RGB").resize(TARGET_SIZE, Image.BILINEAR) depth_color = output.depth_colored.resize(TARGET_SIZE, Image.BILINEAR) rgbs.append(add_overlay(rgb, "RGB")) depths_gray.append(add_overlay(depth_gray, "Глубина (серая)")) depths_color.append(add_overlay(depth_color, "Глубина (цветная)")) return rgbs + depths_color + depths_gray # Интерфейс Blocks с галереей и инференсом with gr.Blocks() as demo: gr.Markdown("## Генерация карт глубины") gr.Markdown( "Модель основана на Marigold (ETH), дообучена на indoor-сценах из NYUv2. " "Сохраняет способность обрабатывать произвольные изображения благодаря наличию оригинальных U-Net весов." ) with gr.Row(): with gr.Column(scale=1): input_image = gr.Image(label="Загрузите RGB изображение", type="pil") denoise = gr.Slider(1, 50, value=4, step=1, label="Шаги денойзинга") ensemble = gr.Slider(1, 10, value=5, step=1, label="Размер ансамбля (количество запусков для одной картинки)") resolution = gr.Slider(256, 1024, value=768, step=64, label="Разрешение обработки изображений") match_res = gr.Checkbox(value=True, label="Сохранять исходное разрешение") with gr.Column(scale=1): output_image = gr.Image(label="Карта глубины") def predict_depth(image, denoising_steps, ensemble_size, processing_res, match_input_res): with torch.no_grad(): output = pipe( image, denoising_steps=denoising_steps, ensemble_size=ensemble_size, processing_res=processing_res, match_input_res=match_input_res, batch_size=0, color_map="Spectral", show_progress_bar=False, ) return output.depth_colored submit_btn = gr.Button("Выполнить предсказание") submit_btn.click( predict_depth, inputs=[input_image, denoise, ensemble, resolution, match_res], outputs=output_image ) gr.Markdown("### Примеры:") gallery = gr.Gallery(label="Сравнение RGB и Глубины", columns=4) demo.load(fn=generate_gallery, outputs=gallery) demo.launch(ssr_mode=False)