File size: 5,258 Bytes
36ab632
 
 
 
 
 
 
 
 
 
 
59fb189
36ab632
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2cf1c1f
 
 
 
36ab632
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53771b4
 
36ab632
 
 
 
 
 
 
 
 
 
 
 
 
 
53771b4
 
 
 
36ab632
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53771b4
36ab632
 
b5d3526
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import torch
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download
import os

# Настройки
use_custom_weights = True
custom_weights_path = hf_hub_download(
    repo_id="focuzz/depth-estimation",
    filename="unet_weights.pth"
)

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Загрузка пайплайна
pipe = DiffusionPipeline.from_pretrained(
    "prs-eth/marigold-v1-0",
    custom_pipeline="marigold_depth_estimation",
    torch_dtype=dtype
).to(device)

# Загрузка дообученных весов
if use_custom_weights:
    state_dict = torch.load(custom_weights_path, map_location=device)
    prefix = "unet.conv_in." if any(k.startswith("unet.conv_in.") for k in state_dict) else "conv_in."
    conv_in_dict = {
        k.replace(prefix, ""): v
        for k, v in state_dict.items()
        if k.startswith(prefix)
    }
    pipe.unet.conv_in.load_state_dict(conv_in_dict)
    print("Загружены дообученные веса conv_in из:", custom_weights_path)

# Добавление overlay-текста
def add_overlay(image: Image.Image, label: str) -> Image.Image:
    image = image.copy()
    draw = ImageDraw.Draw(image)
    try:
        font = ImageFont.load_default()
    except:
        font = None
    draw.text((10, 10), label, fill="white", font=font)
    return image

# Генерация галереи из примеров
TARGET_SIZE = (768, 768)
def normalize_depth(depth_np):
    d = np.copy(depth_np)
    d_min = np.percentile(d, 1)
    d_max = np.percentile(d, 99)
    d = np.clip((d - d_min) / (d_max - d_min), 0, 1)
    return (d * 255).astype(np.uint8)

def generate_gallery():
    example_files = ["example1.jpg", "example2.jpg", "example3.jpg", "example4.jpg"]
    rgbs = []
    depths_gray = []
    depths_color = []

    for path in example_files:
        if not os.path.exists(path):
            continue

        rgb = Image.open(path).convert("RGB").resize(TARGET_SIZE)

        with torch.no_grad():
            output = pipe(
                rgb,
                denoising_steps=4,
                ensemble_size=5,
                processing_res=768,
                match_input_res=True,
                batch_size=0,
                color_map="Spectral",
                show_progress_bar=False,
            )

        depth_np = output.depth_np
        gray_normalized = normalize_depth(depth_np)
        depth_gray = Image.fromarray(gray_normalized).convert("RGB").resize(TARGET_SIZE, Image.BILINEAR)
        depth_color = output.depth_colored.resize(TARGET_SIZE, Image.BILINEAR)

        rgbs.append(add_overlay(rgb, "RGB"))
        depths_gray.append(add_overlay(depth_gray, "Глубина (серая)"))
        depths_color.append(add_overlay(depth_color, "Глубина (цветная)"))

    return rgbs + depths_color + depths_gray

# Интерфейс Blocks с галереей и инференсом
with gr.Blocks() as demo:
    gr.Markdown("## Генерация карт глубины")
    gr.Markdown(
        "Модель основана на Marigold (ETH), дообучена на indoor-сценах из NYUv2. "
        "Сохраняет способность обрабатывать произвольные изображения благодаря наличию оригинальных U-Net весов."
    )

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Загрузите RGB изображение", type="pil")
            denoise = gr.Slider(1, 50, value=4, step=1, label="Шаги денойзинга")
            ensemble = gr.Slider(1, 10, value=5, step=1, label="Размер ансамбля (количество запусков для одной картинки)")
            resolution = gr.Slider(256, 1024, value=768, step=64, label="Разрешение обработки изображений")
            match_res = gr.Checkbox(value=True, label="Сохранять исходное разрешение")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Карта глубины")

    def predict_depth(image, denoising_steps, ensemble_size, processing_res, match_input_res):
        with torch.no_grad():
            output = pipe(
                image,
                denoising_steps=denoising_steps,
                ensemble_size=ensemble_size,
                processing_res=processing_res,
                match_input_res=match_input_res,
                batch_size=0,
                color_map="Spectral",
                show_progress_bar=False,
            )
        return output.depth_colored

    submit_btn = gr.Button("Выполнить предсказание")
    submit_btn.click(
        predict_depth,
        inputs=[input_image, denoise, ensemble, resolution, match_res],
        outputs=output_image
    )

    gr.Markdown("### Примеры:")
    gallery = gr.Gallery(label="Сравнение RGB и Глубины", columns=4)
    demo.load(fn=generate_gallery, outputs=gallery)

demo.launch(ssr_mode=False)