EgoHackZero commited on
Commit
f71d4bc
·
1 Parent(s): 7d74b9b

solve problem with input third try

Browse files
Files changed (1) hide show
  1. app.py +27 -11
app.py CHANGED
@@ -5,28 +5,44 @@ from PIL import Image
5
  import cv2
6
 
7
  # Загрузка модели
8
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
- midas = torch.hub.load("intel-isl/MiDaS", "MiDaS_small").to(device)
10
  midas.eval()
11
 
12
- midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms").to(device)
13
  transform = midas_transforms.small_transform
14
 
15
-
16
  def predict_depth(image):
17
- img = np.array(image)
18
- img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
19
- input_tensor = transform(img_rgb).to(device)
20
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  with torch.no_grad():
22
  prediction = midas(input_tensor)
23
  prediction = torch.nn.functional.interpolate(
24
  prediction.unsqueeze(1),
25
- size=img_rgb.shape[:2],
26
  mode="bicubic",
27
  align_corners=False,
28
  ).squeeze()
29
 
 
30
  depth_map = prediction.cpu().numpy()
31
  depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
32
  depth_map = (depth_map * 255).astype(np.uint8)
@@ -34,13 +50,13 @@ def predict_depth(image):
34
 
35
  return depth_img
36
 
37
- # Интерфейс Gradio
38
  iface = gr.Interface(
39
  fn=predict_depth,
40
  inputs=gr.Image(type="pil"),
41
  outputs=gr.Image(type="pil"),
42
  title="MiDaS Depth Estimation",
43
- description="Загрузите изображение и получите карту глубины."
44
  )
45
 
46
  if __name__ == "__main__":
 
5
  import cv2
6
 
7
  # Загрузка модели
8
+ midas = torch.hub.load("intel-isl/MiDaS", "MiDaS_small")
 
9
  midas.eval()
10
 
11
+ midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
12
  transform = midas_transforms.small_transform
13
 
 
14
  def predict_depth(image):
15
+ # ======= 1. Проверка типа входных данных =======
16
+ if isinstance(image, torch.Tensor):
17
+ print(f"Пришёл Tensor с формой: {image.shape}")
18
+
19
+ if len(image.shape) == 4:
20
+ input_tensor = image # уже батч [1, 3, H, W]
21
+ elif len(image.shape) == 3:
22
+ input_tensor = image.unsqueeze(0) # сделаем батч
23
+ else:
24
+ raise ValueError(f"Неожиданный размер Tensor: {image.shape}")
25
+
26
+ else:
27
+ print("Пришёл PIL Image или numpy array")
28
+ # Если пришло обычное изображение (PIL или numpy)
29
+ if not isinstance(image, Image.Image):
30
+ image = Image.fromarray(image)
31
+ img = np.array(image)
32
+ img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
33
+ input_tensor = transform(img_rgb).unsqueeze(0)
34
+
35
+ # ======= 2. Предсказание =======
36
  with torch.no_grad():
37
  prediction = midas(input_tensor)
38
  prediction = torch.nn.functional.interpolate(
39
  prediction.unsqueeze(1),
40
+ size=(input_tensor.shape[2], input_tensor.shape[3]),
41
  mode="bicubic",
42
  align_corners=False,
43
  ).squeeze()
44
 
45
+ # ======= 3. Нормализация карты глубины =======
46
  depth_map = prediction.cpu().numpy()
47
  depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
48
  depth_map = (depth_map * 255).astype(np.uint8)
 
50
 
51
  return depth_img
52
 
53
+ # Gradio интерфейс
54
  iface = gr.Interface(
55
  fn=predict_depth,
56
  inputs=gr.Image(type="pil"),
57
  outputs=gr.Image(type="pil"),
58
  title="MiDaS Depth Estimation",
59
+ description="Загрузите изображение или отправьте через API. Получите карту глубины."
60
  )
61
 
62
  if __name__ == "__main__":