EgoHackZero's picture
Update app.py
fa1476f verified
raw
history blame
2.3 kB
import torch
import gradio as gr
import numpy as np
import cv2
from PIL import Image
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Загрузка модели
midas = torch.hub.load("intel-isl/MiDaS", "DPT_Large")
midas.to(device)
midas.eval()
# Загрузка трансформаций
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
transform = midas_transforms.dpt_transform
def predict_depth(image):
# ======= 1. Преобразование в OpenCV формат =======
if not isinstance(image, Image.Image):
image = Image.fromarray(image)
image_np = np.array(image)
# OpenCV читает в BGR, но image_np скорее всего уже в RGB
img_rgb = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
img_rgb = cv2.cvtColor(img_rgb, cv2.COLOR_BGR2RGB) # На всякий случай двойная проверка
# ======= 2. Преобразование как в официальном туториале =======
input_tensor = transform(img_rgb).to(device) # shape: [3, H, W]
# ======= 3. Добавление batch размерности =======
if len(input_tensor.shape) == 3:
input_batch = input_tensor.unsqueeze(0) # shape: [1, 3, H, W]
else:
input_batch = input_tensor # Уже batch
# ======= 4. Предсказание =======
with torch.no_grad():
prediction = midas(input_batch)
prediction = torch.nn.functional.interpolate(
prediction.unsqueeze(1),
size=(img_rgb.shape[0], img_rgb.shape[1]), # (H, W)
mode="bicubic",
align_corners=False,
).squeeze()
# ======= 5. Нормализация и преобразование в изображение =======
depth_map = prediction.cpu().numpy()
depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
depth_map = (depth_map * 255).astype(np.uint8)
depth_img = Image.fromarray(depth_map)
return depth_img
# Gradio интерфейс
iface = gr.Interface(
fn=predict_depth,
inputs=gr.Image(type="pil"),
outputs=gr.Image(type="pil"),
title="MiDaS Depth Estimation",
description="Drop img -> depth map"
)
if __name__ == "__main__":
iface.launch()