File size: 4,299 Bytes
f8b3886
 
 
d5bc6d9
e8486cb
3548ace
 
4f3e058
e687f81
f8b3886
 
893be2d
d5bc6d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
893be2d
d5bc6d9
 
 
 
fd26002
d5bc6d9
 
 
 
893be2d
7b83683
f8b3886
a42d79c
edd2a5b
99bbe3e
 
a42d79c
dcf5ae7
e687f81
dcf5ae7
e687f81
 
 
 
 
 
 
 
dcf5ae7
e687f81
 
 
dcf5ae7
e687f81
 
 
 
 
 
 
772e909
e687f81
 
 
 
 
 
dcf5ae7
2505cbf
e687f81
2505cbf
dcf5ae7
 
 
 
 
 
3548ace
79684c1
d6a18b3
1f906f0
 
cafea28
d5bc6d9
40334e7
d5bc6d9
40334e7
d5bc6d9
dcf5ae7
edd2a5b
dcf5ae7
7bc8ed0
4f1fd81
 
d6a18b3
4f1fd81
 
 
f8b3886
f170544
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import cv2
import torch
import numpy as np
from transformers import DPTImageProcessor
import gradio as gr
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import torch.nn as nn
from scipy.interpolate import interp2d

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load your custom trained model
class CompressedStudentModel(nn.Module):
    def __init__(self):
        super(CompressedStudentModel, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=3, padding=1),
        )

    def forward(self, x):
        features = self.encoder(x)
        depth = self.decoder(features)
        return depth

# Initialize and load weights into the student model
model = CompressedStudentModel().to(device)
model.load_state_dict(torch.load("huntrezz_depth_v2.pt", map_location=device))
model.eval()

processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256")

def preprocess_image(image):
    image = cv2.resize(image, (128, 72))
    image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float().to(device)
    return image / 255.0

def plot_depth_map(depth_map, original_image, resolution=255):
    """Render ``depth_map`` as two 3D surfaces and return the figure as an RGB image.

    The left subplot colours the surface with the viridis colormap of the depth
    values. The right subplot colours it with the pixels of ``original_image``.
    Both are resampled onto a ``resolution`` x ``resolution`` grid with bilinear
    interpolation before plotting.

    Args:
        depth_map: 2-D array of depth values. The z-axis is fixed to [0, 1] and
            the values feed the viridis colormap directly, so they are expected
            to be normalized to that range (``process_frame`` does this).
        original_image: image array whose first three channels become the
            surface colours. Presumably (H, W, 3) uint8 RGB; it is resized to
            the depth map's size and divided by 255.
        resolution: number of samples along each axis of the plotted surface.

    Returns:
        ``np.uint8`` array of shape (fig_height_px, fig_width_px, 3) with the
        rendered figure.
    """
    # scipy.interpolate.interp2d was deprecated in SciPy 1.10 and raises from
    # 1.14 on; RegularGridInterpolator does the same bilinear interpolation on a
    # regular grid and also carries trailing value dimensions (RGB) through.
    from scipy.interpolate import RegularGridInterpolator

    height, width = depth_map.shape[:2]
    grid = (np.arange(height), np.arange(width))

    # Evenly spaced sample positions spanning the whole map, corners included.
    sample_x = np.linspace(0, width - 1, resolution)
    sample_y = np.linspace(0, height - 1, resolution)
    x, y = np.meshgrid(sample_x, sample_y)
    points = np.stack([y, x], axis=-1)  # (row, col) order to match `grid`

    z = RegularGridInterpolator(grid, depth_map)(points)

    # Interpolate all colour channels at once -> shape (resolution, resolution, 3).
    # Clip guards against float rounding nudging values just outside [0, 1],
    # which matplotlib rejects for RGB facecolors.
    original_image_resized = cv2.resize(original_image, (width, height))
    colors = original_image_resized[..., :3] / 255.0
    new_colors = np.clip(RegularGridInterpolator(grid, colors)(points), 0.0, 1.0)

    fig = plt.figure(figsize=(32, 9))
    try:
        # Surface coloured by depth.
        ax1 = fig.add_subplot(121, projection='3d')
        ax1.plot_surface(x, y, z, facecolors=plt.cm.viridis(z), shade=False)
        ax1.set_zlim(0, 1)
        ax1.view_init(elev=150, azim=90)
        ax1.set_title("Depth Map Color")
        ax1.set_axis_off()

        # Same surface coloured by the input frame.
        ax2 = fig.add_subplot(122, projection='3d')
        ax2.plot_surface(x, y, z, facecolors=new_colors, shade=False)
        ax2.set_zlim(0, 1)
        ax2.view_init(elev=150, azim=90)
        ax2.set_title("RGB Color")
        ax2.set_axis_off()

        fig.tight_layout()

        # No plt.show(): this runs inside a server callback, where show() would
        # either block the worker or do nothing.
        fig.canvas.draw()
        # buffer_rgba() replaces the deprecated/removed tostring_rgb() and is
        # already shaped (rows, cols, 4); copy so the result owns its memory.
        img = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        # pyplot keeps every open figure alive; one figure per frame would leak.
        plt.close(fig)

    return img

@torch.inference_mode()
def process_frame(image):
    """Predict depth for one webcam frame and return the rendered 3D plot.

    Args:
        image: frame delivered by the Gradio ``gr.Image`` input, or ``None``
            when no frame is available yet.

    Returns:
        The RGB image array produced by ``plot_depth_map``, or ``None`` when
        ``image`` is ``None``.
    """
    if image is None:
        return None

    predicted_depth = model(preprocess_image(image)).squeeze().cpu().numpy()

    # Min-max normalize into [0, 1], the fixed z-range and colormap domain used
    # for plotting. A constant prediction has zero range; dividing by it would
    # produce NaNs, so it is rendered as a flat surface instead.
    depth_min = predicted_depth.min()
    depth_range = predicted_depth.max() - depth_min
    if depth_range > 0:
        depth_map = (predicted_depth - depth_min) / depth_range
    else:
        depth_map = np.zeros_like(predicted_depth)

    # Gradio's numpy image input is already RGB (image_mode="RGB" by default), so
    # the frame is passed through unchanged. A BGR->RGB swap here would draw the
    # "RGB Color" surface with red and blue exchanged.
    return plot_depth_map(depth_map, image)

# Wire the depth pipeline to a live webcam stream: frames from the streaming
# gr.Image input go to process_frame and its return value is shown as an image.
interface = gr.Interface(
    live=True,
    fn=process_frame,
    inputs=gr.Image(sources="webcam", streaming=True),
    outputs="image",
)

# Start the Gradio app. This runs at import time (there is no __main__ guard).
interface.launch()