# Hugging Face Space app (page status at capture time: running on T4).
import torch | |
import numpy as np | |
from PIL import Image, ImageDraw, ImageFont | |
import gradio as gr | |
from diffusers import DiffusionPipeline | |
from huggingface_hub import hf_hub_download | |
import os | |
# Settings.
use_custom_weights = True

# Fetch the fine-tuned U-Net weights file from the Hugging Face Hub; the
# returned value is later used as the file path passed to torch.load.
custom_weights_path = hf_hub_download(repo_id="focuzz/depth-estimation", filename="unet_weights.pth")

# Half precision on CUDA, full precision on CPU.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.float16
else:
    device, dtype = "cpu", torch.float32
# Load the pipeline: the "prs-eth/marigold-v1-0" checkpoint driven by the
# "marigold_depth_estimation" custom pipeline, in the precision chosen above,
# then move it to the selected device.
pipe = DiffusionPipeline.from_pretrained(
    "prs-eth/marigold-v1-0",
    custom_pipeline="marigold_depth_estimation",
    torch_dtype=dtype,
)
pipe = pipe.to(device)
# Load the fine-tuned weights.
# Only the U-Net's `conv_in` layer is overwritten here; every other layer keeps
# the weights that came with the base pipeline.
if use_custom_weights:
    # NOTE(review): torch.load unpickles the file. If the checkpoint is a plain
    # tensor state dict, consider passing weights_only=True so that loading a
    # downloaded file cannot execute arbitrary pickled code.
    state_dict = torch.load(custom_weights_path, map_location=device)

    # Keys may be stored either as "unet.conv_in.*" or as "conv_in.*".
    if any(key.startswith("unet.conv_in.") for key in state_dict):
        prefix = "unet.conv_in."
    else:
        prefix = "conv_in."

    # Slice the prefix off instead of using str.replace(): replace() would also
    # delete any later occurrence of the same substring inside a key.
    conv_in_dict = {
        key[len(prefix):]: tensor
        for key, tensor in state_dict.items()
        if key.startswith(prefix)
    }
    pipe.unet.conv_in.load_state_dict(conv_in_dict)
    print("Загружены дообученные веса conv_in из:", custom_weights_path)
# Overlay-text helper.
def add_overlay(image: Image.Image, label: str) -> Image.Image:
    """Return a copy of *image* with *label* drawn in white at position (10, 10).

    The input image is never modified. Labelling is best-effort: if no usable
    font can be loaded, or the chosen font cannot render the label, the copy
    is returned without text instead of raising.

    Args:
        image: Source image.
        label: Text to draw; callers in this file pass Cyrillic labels.

    Returns:
        A new image carrying the label where it could be drawn.
    """
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)

    try:
        # Prefer a TrueType font: Pillow's built-in bitmap font (the default in
        # older Pillow releases) only handles Latin-1 text, which would make
        # Cyrillic labels fail. DejaVu Sans is looked up by file name and may
        # be absent on some systems.
        font = ImageFont.truetype("DejaVuSans.ttf", 24)
    except OSError:
        try:
            font = ImageFont.load_default()
        except OSError:
            # Let ImageDraw pick its own default font.
            font = None

    try:
        draw.text((10, 10), label, fill="white", font=font)
    except UnicodeEncodeError:
        # The fallback bitmap font cannot encode this label; keep the image
        # unlabelled rather than failing the caller.
        pass
    return annotated
# Example gallery generation.
TARGET_SIZE = (768, 768)


def normalize_depth(depth_np):
    """Convert a depth map to 8-bit grayscale using a percentile contrast stretch.

    The 1st and 99th percentiles of the input become 0 and 255 respectively;
    values outside that range are clipped, which keeps a few outliers from
    flattening the contrast of the rest of the map.

    Args:
        depth_np: Array-like of numeric depth values. It is not modified.

    Returns:
        A ``uint8`` array with the same shape as the input. If the two
        percentiles coincide (for example a constant map) or are not finite
        numbers, an all-zero array is returned instead of dividing by zero.
    """
    depth = np.asarray(depth_np)
    low, high = np.percentile(depth, (1, 99))
    span = high - low

    # `not span > 0` also catches NaN, which would otherwise propagate into an
    # undefined float -> uint8 cast.
    if not span > 0:
        return np.zeros(depth.shape, dtype=np.uint8)

    scaled = np.clip((depth - low) / span, 0, 1)
    return (scaled * 255).astype(np.uint8)
# Result of the first gallery build. The gallery function is wired to the
# page-load event, so without this every page visit would rerun the full
# multi-image inference.
_gallery_cache = None


def generate_gallery():
    """Build the example gallery: RGB inputs, colored depth maps, grayscale depth maps.

    Each existing example file is resized to ``TARGET_SIZE``, passed through
    the depth pipeline, and rendered in three labelled variants. Example files
    that do not exist are skipped.

    The images are computed on the first call only; later calls return the
    cached images in a fresh list.

    Returns:
        A list of images ordered as all RGB inputs, then all colored depth
        maps, then all grayscale depth maps.
    """
    global _gallery_cache
    if _gallery_cache is not None:
        return list(_gallery_cache)

    example_files = ["example1.jpg", "example2.jpg", "example3.jpg", "example4.jpg"]
    rgbs = []
    depths_gray = []
    depths_color = []

    for path in example_files:
        if not os.path.exists(path):
            continue

        # Close the file handle as soon as the pixels have been copied out.
        with Image.open(path) as source:
            rgb = source.convert("RGB").resize(TARGET_SIZE)

        with torch.no_grad():
            output = pipe(
                rgb,
                denoising_steps=4,
                ensemble_size=5,
                processing_res=768,
                match_input_res=True,
                batch_size=0,
                color_map="Spectral",
                show_progress_bar=False,
            )

        gray_normalized = normalize_depth(output.depth_np)
        depth_gray = (
            Image.fromarray(gray_normalized)
            .convert("RGB")
            .resize(TARGET_SIZE, Image.BILINEAR)
        )
        depth_color = output.depth_colored.resize(TARGET_SIZE, Image.BILINEAR)

        rgbs.append(add_overlay(rgb, "RGB"))
        depths_gray.append(add_overlay(depth_gray, "Глубина (серая)"))
        depths_color.append(add_overlay(depth_color, "Глубина (цветная)"))

    _gallery_cache = rgbs + depths_color + depths_gray
    # Hand out a copy so callers cannot mutate the cached list.
    return list(_gallery_cache)
# Blocks interface with single-image inference and an example gallery.
def predict_depth(image, denoising_steps, ensemble_size, processing_res, match_input_res):
    """Run the depth pipeline on one uploaded image and return the colored depth map.

    Args:
        image: Image from the upload widget, or ``None`` if nothing was uploaded.
        denoising_steps: Number of denoising steps.
        ensemble_size: Number of pipeline runs combined for the image.
        processing_res: Processing resolution passed to the pipeline.
        match_input_res: Whether the output keeps the input resolution.

    Returns:
        The pipeline's colored depth image (``output.depth_colored``).

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Show a readable message in the UI instead of an error from deep
        # inside the pipeline.
        raise gr.Error("Сначала загрузите изображение.")

    with torch.no_grad():
        output = pipe(
            image,
            # Slider values may arrive as floats depending on the Gradio
            # version; the pipeline is given integers.
            denoising_steps=int(denoising_steps),
            ensemble_size=int(ensemble_size),
            processing_res=int(processing_res),
            match_input_res=bool(match_input_res),
            batch_size=0,
            color_map="Spectral",
            show_progress_bar=False,
        )
    return output.depth_colored


# NOTE(review): the original indentation of this section was lost; the nesting
# below (controls in the left column, result in the right column, button and
# gallery under the row) is the most likely reading - confirm against the
# deployed layout.
with gr.Blocks() as demo:
    gr.Markdown("## Генерация карт глубины")
    gr.Markdown(
        "Модель основана на Marigold (ETH), дообучена на indoor-сценах из NYUv2. "
        "Сохраняет способность обрабатывать произвольные изображения благодаря наличию оригинальных U-Net весов."
    )

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Загрузите RGB изображение", type="pil")
            denoise = gr.Slider(1, 50, value=4, step=1, label="Шаги денойзинга")
            ensemble = gr.Slider(1, 10, value=5, step=1, label="Размер ансамбля (количество запусков для одной картинки)")
            resolution = gr.Slider(256, 1024, value=768, step=64, label="Разрешение обработки изображений")
            match_res = gr.Checkbox(value=True, label="Сохранять исходное разрешение")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Карта глубины")

    submit_btn = gr.Button("Выполнить предсказание")
    submit_btn.click(
        predict_depth,
        inputs=[input_image, denoise, ensemble, resolution, match_res],
        outputs=output_image,
    )

    gr.Markdown("### Примеры:")
    gallery = gr.Gallery(label="Сравнение RGB и Глубины", columns=4)
    # Fill the gallery when the page is opened.
    demo.load(fn=generate_gallery, outputs=gallery)

demo.launch(ssr_mode=False)