EgoHackZero commited on
Commit
afae8e5
·
1 Parent(s): f71d4bc

solve problem with input third try

Browse files
Files changed (2) hide show
  1. app.py +26 -24
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,48 +1,50 @@
1
  import torch
2
  import gradio as gr
3
  import numpy as np
4
- from PIL import Image
5
  import cv2
 
 
 
6
 
7
  # Загрузка модели
8
  midas = torch.hub.load("intel-isl/MiDaS", "MiDaS_small")
 
9
  midas.eval()
10
 
 
11
  midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
12
  transform = midas_transforms.small_transform
13
 
14
  def predict_depth(image):
15
- # ======= 1. Проверка типа входных данных =======
16
- if isinstance(image, torch.Tensor):
17
- print(f"Пришёл Tensor с формой: {image.shape}")
 
18
 
19
- if len(image.shape) == 4:
20
- input_tensor = image # уже батч [1, 3, H, W]
21
- elif len(image.shape) == 3:
22
- input_tensor = image.unsqueeze(0) # сделаем батч
23
- else:
24
- raise ValueError(f"Неожиданный размер Tensor: {image.shape}")
25
 
 
 
 
 
 
 
26
  else:
27
- print("Пришёл PIL Image или numpy array")
28
- # Если пришло обычное изображение (PIL или numpy)
29
- if not isinstance(image, Image.Image):
30
- image = Image.fromarray(image)
31
- img = np.array(image)
32
- img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
33
- input_tensor = transform(img_rgb).unsqueeze(0)
34
-
35
- # ======= 2. Предсказание =======
36
  with torch.no_grad():
37
- prediction = midas(input_tensor)
38
  prediction = torch.nn.functional.interpolate(
39
  prediction.unsqueeze(1),
40
- size=(input_tensor.shape[2], input_tensor.shape[3]),
41
  mode="bicubic",
42
  align_corners=False,
43
  ).squeeze()
44
 
45
- # ======= 3. Нормализация карты глубины =======
46
  depth_map = prediction.cpu().numpy()
47
  depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
48
  depth_map = (depth_map * 255).astype(np.uint8)
@@ -55,8 +57,8 @@ iface = gr.Interface(
55
  fn=predict_depth,
56
  inputs=gr.Image(type="pil"),
57
  outputs=gr.Image(type="pil"),
58
- title="MiDaS Depth Estimation",
59
- description="Загрузите изображение или отправьте через API. Получите карту глубины."
60
  )
61
 
62
  if __name__ == "__main__":
 
1
  import torch
2
  import gradio as gr
3
  import numpy as np
 
4
  import cv2
5
+ from PIL import Image
6
+
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
 
9
  # Загрузка модели
10
  midas = torch.hub.load("intel-isl/MiDaS", "MiDaS_small")
11
+ midas.to(device)
12
  midas.eval()
13
 
14
+ # Загрузка трансформаций
15
  midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
16
  transform = midas_transforms.small_transform
17
 
18
  def predict_depth(image):
19
+ # ======= 1. Преобразование в OpenCV формат =======
20
+ if not isinstance(image, Image.Image):
21
+ image = Image.fromarray(image)
22
+ image_np = np.array(image)
23
 
24
+ # OpenCV читает в BGR, но image_np скорее всего уже в RGB
25
+ img_rgb = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
26
+ img_rgb = cv2.cvtColor(img_rgb, cv2.COLOR_BGR2RGB) # На всякий случай двойная проверка
 
 
 
27
 
28
+ # ======= 2. Преобразование как в официальном туториале =======
29
+ input_tensor = transform(img_rgb).to(device) # shape: [3, H, W]
30
+
31
+ # ======= 3. Добавление batch размерности =======
32
+ if len(input_tensor.shape) == 3:
33
+ input_batch = input_tensor.unsqueeze(0) # shape: [1, 3, H, W]
34
  else:
35
+ input_batch = input_tensor # Уже batch
36
+
37
+ # ======= 4. Предсказание =======
 
 
 
 
 
 
38
  with torch.no_grad():
39
+ prediction = midas(input_batch)
40
  prediction = torch.nn.functional.interpolate(
41
  prediction.unsqueeze(1),
42
+ size=(img_rgb.shape[0], img_rgb.shape[1]), # (H, W)
43
  mode="bicubic",
44
  align_corners=False,
45
  ).squeeze()
46
 
47
+ # ======= 5. Нормализация и преобразование в изображение =======
48
  depth_map = prediction.cpu().numpy()
49
  depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
50
  depth_map = (depth_map * 255).astype(np.uint8)
 
57
  fn=predict_depth,
58
  inputs=gr.Image(type="pil"),
59
  outputs=gr.Image(type="pil"),
60
+ title="MiDaS Depth Estimation (Official Tutorial Style)",
61
+ description="Загрузите изображение. Возвращается карта глубины."
62
  )
63
 
64
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -2,4 +2,5 @@ torch
2
  torchvision
3
  timm
4
  opencv-python
 
5
  gradio
 
2
  torchvision
3
  timm
4
  opencv-python
5
+ pillow
6
  gradio