import math

import cv2
import numpy as np
import torch
from pytorch3d.renderer import (
    AlphaCompositor,
    BlendParams,
    FoVOrthographicCameras,
    MeshRasterizer,
    MeshRenderer,
    PointLights,
    PointsRasterizationSettings,
    PointsRasterizer,
    PointsRenderer,
    RasterizationSettings,
    SoftPhongShader,
    SoftSilhouetteShader,
    TexturesVertex,
    blending,
    look_at_view_transform,
)
from pytorch3d.structures import Meshes, Pointclouds


class cleanShader(torch.nn.Module):
    """Shader that blends sampled textures without applying any lighting model."""

    def __init__(self, device="cpu", cameras=None, blend_params=None):
        super().__init__()
        self.cameras = cameras
        self.blend_params = blend_params if blend_params is not None else BlendParams()

    def forward(self, fragments, meshes, **kwargs):
        cameras = kwargs.get("cameras", self.cameras)
        if cameras is None:
            msg = ("Cameras must be specified either at initialization "
                   "or in the forward pass of cleanShader")
            raise ValueError(msg)

        # blend the sampled texels; the wide znear/zfar range matches the
        # reversed clipping planes set up in Render.get_camera() below
        blend_params = kwargs.get("blend_params", self.blend_params)
        texels = meshes.sample_textures(fragments)
        images = blending.softmax_rgb_blend(texels, fragments, blend_params,
                                            znear=-256, zfar=256)

        return images


class Render:

    def __init__(self, size=512, device=torch.device("cuda:0")):
        self.device = device
        self.mesh_y_center = 100.0
        self.dis = 100.0
        self.scale = 1.0
        self.size = size
        self.cam_pos = [(0, 100, 100)]

        self.mesh = None
        self.pcd = None
        self.renderer = None
        self.meshRas = None

    def get_camera(self, cam_id):
        # look from cam_pos[cam_id] at the vertical center of the mesh
        R, T = look_at_view_transform(eye=[self.cam_pos[cam_id]],
                                      at=((0, self.mesh_y_center, 0),),
                                      up=((0, 1, 0),))

        # orthographic camera; note the reversed znear/zfar convention,
        # which cleanShader compensates for with its own blend range
        camera = FoVOrthographicCameras(device=self.device,
                                        R=R,
                                        T=T,
                                        znear=100.0,
                                        zfar=-100.0,
                                        max_y=100.0,
                                        min_y=-100.0,
                                        max_x=100.0,
                                        min_x=-100.0,
                                        scale_xyz=(self.scale * np.ones(3),))

        return camera

    def init_renderer(self, camera, type='clean_mesh', bg='gray'):

        if 'mesh' in type:
            # mesh rasterizer shared by the 'ori_mesh' and 'clean_mesh' paths
            self.raster_settings_mesh = RasterizationSettings(
                image_size=self.size,
                blur_radius=np.log(1.0 / 1e-4) * 1e-7,
                faces_per_pixel=30,
            )
            self.meshRas = MeshRasterizer(
                cameras=camera, raster_settings=self.raster_settings_mesh)

        if bg == 'black':
            blendparam = BlendParams(1e-4, 1e-4, (0.0, 0.0, 0.0))
        elif bg == 'white':
            blendparam = BlendParams(1e-4, 1e-8, (1.0, 1.0, 1.0))
        else:
            # default to gray so blendparam is always defined
            blendparam = BlendParams(1e-4, 1e-8, (0.5, 0.5, 0.5))

        if type == 'ori_mesh':
            lights = PointLights(device=self.device,
                                 ambient_color=((0.8, 0.8, 0.8),),
                                 diffuse_color=((0.2, 0.2, 0.2),),
                                 specular_color=((0.0, 0.0, 0.0),),
                                 location=[[0.0, 200.0, 200.0]])

            self.renderer = MeshRenderer(rasterizer=self.meshRas,
                                         shader=SoftPhongShader(
                                             device=self.device,
                                             cameras=camera,
                                             lights=lights,
                                             blend_params=blendparam))

        if type == 'silhouette':
            self.raster_settings_silhouette = RasterizationSettings(
                image_size=self.size,
                blur_radius=np.log(1.0 / 1e-4 - 1.0) * 5e-5,
                faces_per_pixel=50,
                cull_backfaces=True,
            )

            self.silhouetteRas = MeshRasterizer(
                cameras=camera, raster_settings=self.raster_settings_silhouette)
            self.renderer = MeshRenderer(rasterizer=self.silhouetteRas,
                                         shader=SoftSilhouetteShader())

        if type == 'pointcloud':
            self.raster_settings_pcd = PointsRasterizationSettings(
                image_size=self.size, radius=0.006, points_per_pixel=10)

            self.pcdRas = PointsRasterizer(
                cameras=camera, raster_settings=self.raster_settings_pcd)
            self.renderer = PointsRenderer(
                rasterizer=self.pcdRas,
                compositor=AlphaCompositor(background_color=(0, 0, 0)))

        if type == 'clean_mesh':
            self.renderer = MeshRenderer(rasterizer=self.meshRas,
                                         shader=cleanShader(
                                             device=self.device,
                                             cameras=camera,
                                             blend_params=blendparam))

    def set_camera(self, verts, normalize=False):
        self.scale = 100
        self.mesh_y_center = 0

        if normalize:
            # fit the mesh into the +-100 orthographic frustum and center it
            y_max = verts.max(dim=1)[0][0, 1].item()
            y_min = verts.min(dim=1)[0][0, 1].item()
            self.scale *= 0.95 / ((y_max - y_min) * 0.5 + 1e-10)
            self.mesh_y_center = (y_max + y_min) * 0.5

        # four orbit positions: front, right, back, left
        self.cam_pos = [(0, self.mesh_y_center, self.dis),
                        (self.dis, self.mesh_y_center, 0),
                        (0, self.mesh_y_center, -self.dis),
                        (-self.dis, self.mesh_y_center, 0)]

    def load_mesh(self, verts, faces, verts_rgb=None, normalize=False, use_normal=False):
        """Load a mesh into the PyTorch3D renderer.

        Args:
            verts ([N, 3] or [B, N, 3]): vertex positions
            faces ([F, 3] or [B, F, 3]): face indices
            verts_rgb ([N, 3] or [B, N, 3], optional): per-vertex colors
            normalize (bool): center the mesh and fit it to the camera frustum
            use_normal (bool): color vertices by their full normal vector
                instead of only the z component
        """
        if not torch.is_tensor(verts):
            verts = torch.tensor(verts)
        if not torch.is_tensor(faces):
            faces = torch.tensor(faces)

        if verts.ndimension() == 2:
            verts = verts.unsqueeze(0)
        if faces.ndimension() == 2:
            faces = faces.unsqueeze(0)

        verts = verts.float().to(self.device)
        faces = faces.long().to(self.device)

        self.set_camera(verts, normalize)
        self.mesh = Meshes(verts, faces).to(self.device)

        if verts_rgb is not None:
            if not torch.is_tensor(verts_rgb):
                verts_rgb = torch.as_tensor(verts_rgb)
            if verts_rgb.ndimension() == 2:
                verts_rgb = verts_rgb.unsqueeze(0)
            verts_rgb = verts_rgb.float().to(self.device)
        elif use_normal:
            # map normals from [-1, 1] to [0, 1] and use them as RGB
            verts_rgb = (self.mesh.verts_normals_padded() + 1.0) * 0.5
        else:
            # grayscale shading from the z component of the normals
            verts_rgb = self.mesh.verts_normals_padded()[..., 2:3].expand(-1, -1, 3)
            verts_rgb = (verts_rgb + 1.0) * 0.5

        self.mesh.textures = TexturesVertex(verts_features=verts_rgb)

        return self.mesh

    def load_pcd(self, verts, verts_rgb, normalize=False):
        """Load a point cloud into the PyTorch3D renderer.

        Args:
            verts ([B, N, 3]): point positions
            verts_rgb ([B, N, 3]): per-point colors
            normalize (bool): center the point cloud before rendering
        """
        assert verts.shape == verts_rgb.shape and len(verts.shape) == 3

        # data format conversion
        if not torch.is_tensor(verts):
            verts = torch.as_tensor(verts)
        if not torch.is_tensor(verts_rgb):
            verts_rgb = torch.as_tensor(verts_rgb)

        verts = verts.float().to(self.device)
        verts_rgb = verts_rgb.float().to(self.device)

        # camera setting
        self.set_camera(verts, normalize)

        pcd = Pointclouds(points=verts, features=verts_rgb).to(self.device)

        return pcd

    def get_image(self, cam_ids=[0, 2], type='clean_mesh', bg='gray'):
        images = []
        for cam_id in range(len(self.cam_pos)):
            if cam_id in cam_ids:
                self.init_renderer(self.get_camera(cam_id), type, bg)
                rendered_img = self.renderer(self.mesh)[0, :, :, :3]
                # mirror the back view so it lines up with the front view
                if cam_id == 2 and len(cam_ids) == 2:
                    rendered_img = torch.flip(rendered_img, dims=[1])
                images.append(rendered_img)

        # concatenate the selected views side by side
        images = torch.cat(images, 1)

        return images.detach().cpu().numpy()

    def get_clean_image(self, cam_ids=[0, 2], type='clean_mesh', bg='gray'):
        images = []
        for cam_id in range(len(self.cam_pos)):
            if cam_id in cam_ids:
                self.init_renderer(self.get_camera(cam_id), type, bg)
                rendered_img = self.renderer(self.mesh)[0:1, :, :, :3]
                # mirror the back view so it lines up with the front view
                if cam_id == 2 and len(cam_ids) == 2:
                    rendered_img = torch.flip(rendered_img, dims=[2])
                images.append(rendered_img)

        return images

    def get_silhouette_image(self, cam_ids=[0, 2]):
        images = []
        for cam_id in range(len(self.cam_pos)):
            if cam_id in cam_ids:
                self.init_renderer(self.get_camera(cam_id), 'silhouette')
                # keep only the alpha channel of the silhouette render
                rendered_img = self.renderer(self.mesh)[0:1, :, :, 3]
                if cam_id == 2 and len(cam_ids) == 2:
                    rendered_img = torch.flip(rendered_img, dims=[2])
                images.append(rendered_img)

        return images

    def get_image_pcd(self, pcd, cam_ids=[0, 1, 2, 3]):
        images = torch.zeros((self.size, self.size * len(cam_ids), 3)).to(self.device)
        for i, cam_id in enumerate(cam_ids):
            self.init_renderer(self.get_camera(cam_id), 'pointcloud')
            images[:, self.size * i:self.size * (i + 1), :] = \
                self.renderer(pcd)[0, :, :, :3]

        return images.detach().cpu().numpy()

    def get_rendered_video(self, save_path, num_angle=100, s=0):
        # place num_angle cameras evenly on a circle around the mesh,
        # starting at angle s (degrees)
        self.cam_pos = []
        interval = 360.0 / num_angle
        for i in range(num_angle):
            angle = (s + i * interval) % 360
            self.cam_pos.append((self.dis * math.cos(np.pi / 180 * angle),
                                 self.mesh_y_center,
                                 self.dis * math.sin(np.pi / 180 * angle)))

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video = cv2.VideoWriter(save_path, fourcc, 30, (self.size, self.size))

        for cam_id in range(len(self.cam_pos)):
            self.init_renderer(self.get_camera(cam_id), 'clean_mesh', 'gray')
            frame = (self.renderer(self.mesh)[0, :, :, :3] *
                     255.0).detach().cpu().numpy().astype(np.uint8)
            # VideoWriter expects BGR frames while the renderer outputs RGB
            video.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

        video.release()
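

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): render a unit
# icosphere from the front and back cameras with the normal-based shading
# that load_mesh() falls back to when no vertex colors are given. The
# ico_sphere test mesh and the output filename are illustrative; a working
# PyTorch3D install is assumed, with CPU fallback if CUDA is unavailable.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from pytorch3d.utils import ico_sphere

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    render = Render(size=512, device=device)

    # load_mesh() sets up the four orbit cameras via set_camera()
    sphere = ico_sphere(level=3, device=device)
    render.load_mesh(sphere.verts_padded(), sphere.faces_padded(), normalize=True)

    # get_image() returns an (H, W * len(cam_ids), 3) float array in [0, 1];
    # cam_ids 0 and 2 are the front and back views
    img = render.get_image(cam_ids=[0, 2], type='clean_mesh', bg='gray')
    cv2.imwrite("sphere_front_back.png",
                cv2.cvtColor((img * 255.0).astype(np.uint8), cv2.COLOR_RGB2BGR))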