# huntrezz's picture
# Update app.py
# 40334e7 verified
# raw
# history blame
# 2.96 kB
import cv2
import torch
import numpy as np
from transformers import DPTForDepthEstimation, DPTImageProcessor
import gradio as gr
import torch.nn.utils.prune as prune
import open3d as o3d
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained DPT depth-estimation checkpoint in float32 and switch
# it to evaluation mode.
model = DPTForDepthEstimation.from_pretrained(
    "Intel/dpt-swinv2-tiny-256", torch_dtype=torch.float32
)
model.eval()

# Apply global unstructured pruning over every Conv2d / Linear weight.
parameters_to_prune = [
    (module, "weight")
    for module in model.modules()
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear))
]
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.4,  # Prune 40% of weights
)
# Make the pruning permanent: drop the mask re-parametrization so each module
# again owns a plain "weight" tensor.
for module, _ in parameters_to_prune:
    prune.remove(module, "weight")

# Bug fix: dynamic int8 quantization produces layers that only run on CPU
# quantized backends, so quantizing and then moving the model to a CUDA
# device breaks inference. Quantize only when we are actually running on CPU.
# NOTE(review): dynamic quantization appears to swap only Linear-style
# layers; the Conv2d entry is likely a no-op - confirm against PyTorch docs.
if device.type == "cpu":
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear, torch.nn.Conv2d}, dtype=torch.qint8
    )
model = model.to(device)

# Image processor matching the checkpoint (not used by the code visible here).
processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256")

# Inferno colour lookup table for the 256 possible byte values, kept as a
# tensor on the inference device (not used by the code visible here).
color_map = cv2.applyColorMap(np.arange(256, dtype=np.uint8), cv2.COLORMAP_INFERNO)
color_map = torch.from_numpy(color_map).to(device)
def preprocess_image(image):
    """Turn a raw frame into a model-ready batch tensor.

    The frame is resized to 128x128, reordered from height/width/channel to
    channel/height/width, given a leading batch axis, cast to float, moved to
    the module-level ``device`` and scaled by 1/255.

    Args:
        image: NumPy image array with the channel axis last
            (presumably uint8 RGB from the webcam - TODO confirm).

    Returns:
        A float tensor of shape (1, C, 128, 128) on ``device``.
    """
    resized = cv2.resize(image, (128, 128))
    channels_first = torch.from_numpy(resized).permute(2, 0, 1)
    batch = channels_first.unsqueeze(0).float().to(device)
    return batch / 255.0
def create_point_cloud(depth_map, color_image):
    """Build a coloured Open3D point cloud from a depth map.

    Each pixel becomes one point: ``z`` is the depth value and ``x`` / ``y``
    are the pixel's offset from the image centre (as a fraction of the image
    size) scaled by that depth. Pixels whose depth is not strictly between
    0 and 1000 are collapsed to the origin rather than removed, so points and
    colours keep a one-to-one correspondence.

    Args:
        depth_map: 2-D NumPy array of depth values.
        color_image: NumPy image with 3 channels in the last axis and values
            in 0..255. If its height/width differ from ``depth_map`` it is
            resized to match.

    Returns:
        An ``o3d.geometry.PointCloud`` with one point and one colour per
        depth-map pixel.
    """
    rows, cols = depth_map.shape

    # Bug fix: points are generated per depth pixel, so the colour image must
    # live on the same grid; otherwise the colour array has a different
    # length than the point array and the cloud has no usable per-point
    # colours. cv2.resize takes the target size as (width, height).
    if color_image.shape[:2] != (rows, cols):
        color_image = cv2.resize(color_image, (cols, rows))

    # Sparse grids: c has shape (1, cols), r has shape (rows, 1); they
    # broadcast against the depth map below.
    c, r = np.meshgrid(np.arange(cols), np.arange(rows), sparse=True)
    valid = (depth_map > 0) & (depth_map < 1000)
    z = np.where(valid, depth_map, 0)
    x = np.where(valid, z * (c - cols / 2) / cols, 0)
    y = np.where(valid, z * (r - rows / 2) / rows, 0)

    points = np.dstack((x, y, z)).reshape(-1, 3)
    colors = color_image.reshape(-1, 3)

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    # Colours are scaled from 0..255 down to 0..1.
    pcd.colors = o3d.utility.Vector3dVector(colors / 255.0)
    return pcd
@torch.inference_mode()
def process_frame(image):
    """Estimate depth for one frame and return a rendered point-cloud image.

    The frame is preprocessed, run through the depth model, normalized to
    [0, 1], converted to a coloured point cloud and rendered with an Open3D
    visualizer whose screen buffer is returned.

    Args:
        image: Input frame as a NumPy array with the channel axis last, or
            ``None`` when no frame is available.

    Returns:
        A uint8 NumPy image of the rendered point cloud, or ``None`` if
        ``image`` is ``None``.
    """
    if image is None:
        return None

    preprocessed = preprocess_image(image)
    predicted_depth = model(preprocessed).predicted_depth
    depth_map = predicted_depth.squeeze().cpu().numpy()

    # Normalize depth map to [0, 1]. A completely flat (or NaN) prediction
    # has no usable range; emit zeros instead of dividing by zero.
    depth_min = depth_map.min()
    depth_range = depth_map.max() - depth_min
    if depth_range > 0:
        depth_map = (depth_map - depth_min) / depth_range
    else:
        depth_map = np.zeros_like(depth_map)

    # Bug fix: the depth map is at the model's output resolution, not the
    # frame's. Bring the colour frame onto the same grid so that every point
    # gets exactly one colour. cv2.resize takes (width, height).
    depth_rows, depth_cols = depth_map.shape
    color_image = image
    if color_image.shape[:2] != (depth_rows, depth_cols):
        color_image = cv2.resize(color_image, (depth_cols, depth_rows))

    # Create point cloud
    pcd = create_point_cloud(depth_map, color_image)

    # Visualize point cloud. The window is always destroyed, even when a
    # render or capture call raises, so failed frames do not leak windows.
    vis = o3d.visualization.Visualizer()
    vis.create_window()
    try:
        vis.add_geometry(pcd)
        vis.poll_events()
        vis.update_renderer()
        # Capture the visualization as an image (float values, no re-render).
        frame_buffer = vis.capture_screen_float_buffer(False)
    finally:
        vis.destroy_window()

    # Convert the float buffer to a uint8 numpy array
    point_cloud_image = (np.asarray(frame_buffer) * 255).astype(np.uint8)
    return point_cloud_image
# Streaming webcam input: frames are pushed to the callback continuously.
_webcam_input = gr.Image(sources="webcam", streaming=True)

# Live Gradio UI that maps each webcam frame through process_frame and shows
# the resulting image.
interface = gr.Interface(
    fn=process_frame,
    inputs=_webcam_input,
    outputs="image",
    live=True,
)

# Start the web server.
interface.launch()