focuzz's picture
Update app.py
b5d3526 verified
import torch
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download
import os
# Настройки
use_custom_weights = True
custom_weights_path = hf_hub_download(
repo_id="focuzz/depth-estimation",
filename="unet_weights.pth"
)
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
# Загрузка пайплайна
pipe = DiffusionPipeline.from_pretrained(
"prs-eth/marigold-v1-0",
custom_pipeline="marigold_depth_estimation",
torch_dtype=dtype
).to(device)
# Загрузка дообученных весов
if use_custom_weights:
state_dict = torch.load(custom_weights_path, map_location=device)
prefix = "unet.conv_in." if any(k.startswith("unet.conv_in.") for k in state_dict) else "conv_in."
conv_in_dict = {
k.replace(prefix, ""): v
for k, v in state_dict.items()
if k.startswith(prefix)
}
pipe.unet.conv_in.load_state_dict(conv_in_dict)
print("Загружены дообученные веса conv_in из:", custom_weights_path)
# Добавление overlay-текста
def add_overlay(image: Image.Image, label: str) -> Image.Image:
image = image.copy()
draw = ImageDraw.Draw(image)
try:
font = ImageFont.load_default()
except:
font = None
draw.text((10, 10), label, fill="white", font=font)
return image
# Генерация галереи из примеров
TARGET_SIZE = (768, 768)
def normalize_depth(depth_np):
d = np.copy(depth_np)
d_min = np.percentile(d, 1)
d_max = np.percentile(d, 99)
d = np.clip((d - d_min) / (d_max - d_min), 0, 1)
return (d * 255).astype(np.uint8)
def generate_gallery():
example_files = ["example1.jpg", "example2.jpg", "example3.jpg", "example4.jpg"]
rgbs = []
depths_gray = []
depths_color = []
for path in example_files:
if not os.path.exists(path):
continue
rgb = Image.open(path).convert("RGB").resize(TARGET_SIZE)
with torch.no_grad():
output = pipe(
rgb,
denoising_steps=4,
ensemble_size=5,
processing_res=768,
match_input_res=True,
batch_size=0,
color_map="Spectral",
show_progress_bar=False,
)
depth_np = output.depth_np
gray_normalized = normalize_depth(depth_np)
depth_gray = Image.fromarray(gray_normalized).convert("RGB").resize(TARGET_SIZE, Image.BILINEAR)
depth_color = output.depth_colored.resize(TARGET_SIZE, Image.BILINEAR)
rgbs.append(add_overlay(rgb, "RGB"))
depths_gray.append(add_overlay(depth_gray, "Глубина (серая)"))
depths_color.append(add_overlay(depth_color, "Глубина (цветная)"))
return rgbs + depths_color + depths_gray
# Интерфейс Blocks с галереей и инференсом
with gr.Blocks() as demo:
gr.Markdown("## Генерация карт глубины")
gr.Markdown(
"Модель основана на Marigold (ETH), дообучена на indoor-сценах из NYUv2. "
"Сохраняет способность обрабатывать произвольные изображения благодаря наличию оригинальных U-Net весов."
)
with gr.Row():
with gr.Column(scale=1):
input_image = gr.Image(label="Загрузите RGB изображение", type="pil")
denoise = gr.Slider(1, 50, value=4, step=1, label="Шаги денойзинга")
ensemble = gr.Slider(1, 10, value=5, step=1, label="Размер ансамбля (количество запусков для одной картинки)")
resolution = gr.Slider(256, 1024, value=768, step=64, label="Разрешение обработки изображений")
match_res = gr.Checkbox(value=True, label="Сохранять исходное разрешение")
with gr.Column(scale=1):
output_image = gr.Image(label="Карта глубины")
def predict_depth(image, denoising_steps, ensemble_size, processing_res, match_input_res):
with torch.no_grad():
output = pipe(
image,
denoising_steps=denoising_steps,
ensemble_size=ensemble_size,
processing_res=processing_res,
match_input_res=match_input_res,
batch_size=0,
color_map="Spectral",
show_progress_bar=False,
)
return output.depth_colored
submit_btn = gr.Button("Выполнить предсказание")
submit_btn.click(
predict_depth,
inputs=[input_image, denoise, ensemble, resolution, match_res],
outputs=output_image
)
gr.Markdown("### Примеры:")
gallery = gr.Gallery(label="Сравнение RGB и Глубины", columns=4)
demo.load(fn=generate_gallery, outputs=gallery)
demo.launch(ssr_mode=False)