# PyTorch3D-based mesh / point-cloud rendering utilities.
from pytorch3d.renderer import ( | |
BlendParams, blending, look_at_view_transform, FoVOrthographicCameras, | |
PointLights, RasterizationSettings, PointsRasterizationSettings, | |
PointsRenderer, AlphaCompositor, PointsRasterizer, MeshRenderer, | |
MeshRasterizer, SoftPhongShader, SoftSilhouetteShader, TexturesVertex) | |
from pytorch3d.renderer.mesh import TexturesVertex | |
from pytorch3d.structures import Meshes, Pointclouds | |
import torch | |
import numpy as np | |
import math | |
import cv2 | |
class cleanShader(torch.nn.Module):
    """Minimal mesh shader: sample vertex textures and soft-blend them.

    Unlike ``SoftPhongShader`` this applies no lighting model, so the output
    image shows the mesh's sampled texture colors directly.
    """

    def __init__(self, device="cpu", cameras=None, blend_params=None):
        """
        Args:
            device: kept for interface parity with other shaders (unused here).
            cameras: optional cameras; may instead be passed per forward call.
            blend_params: optional ``BlendParams``; defaults to ``BlendParams()``.
        """
        super().__init__()
        self.cameras = cameras
        self.blend_params = blend_params if blend_params is not None else BlendParams()

    def forward(self, fragments, meshes, **kwargs):
        """Shade rasterized ``fragments`` of ``meshes`` into an RGBA image batch.

        Raises:
            ValueError: if no cameras were given at init or in ``kwargs``.
        """
        cameras = kwargs.get("cameras", self.cameras)
        if cameras is None:
            # Single clean message; the original used a backslash continuation
            # inside the string literal, which embedded the next source line's
            # indentation into the runtime error text.
            raise ValueError(
                "Cameras must be specified either at initialization "
                "or in the forward pass of TexturedSoftPhongShader")

        # Per-call blend params override the ones stored at init.
        blend_params = kwargs.get("blend_params", self.blend_params)
        texels = meshes.sample_textures(fragments)
        # Wide znear/zfar covers the +/-100-unit orthographic frustum used by
        # Render (with margin); these bound the depth range for soft blending.
        images = blending.softmax_rgb_blend(
            texels, fragments, blend_params, znear=-256, zfar=256)
        return images
class Render:
    """Orthographic multi-view renderer for meshes and point clouds.

    Wraps PyTorch3D cameras, rasterizers and shaders behind a small API:
    ``load_mesh`` / ``load_pcd`` prepare geometry and the camera ring, and the
    ``get_*`` methods render images from one or more fixed viewpoints.
    """

    def __init__(self, size=512, device=torch.device("cuda:0")):
        """
        Args:
            size: output image resolution (square, in pixels).
            device: torch device all rendering tensors are moved to.
        """
        self.device = device
        self.mesh_y_center = 100.0   # vertical look-at target; updated by set_camera
        self.dis = 100.0             # camera distance from the vertical axis
        self.scale = 1.0             # orthographic zoom; updated by set_camera
        self.size = size
        self.cam_pos = [(0, 100, 100)]  # placeholder until set_camera runs

        self.mesh = None
        self.pcd = None
        self.renderer = None
        self.meshRas = None

    def get_camera(self, cam_id):
        """Build an orthographic camera looking at the mesh center.

        Args:
            cam_id: index into ``self.cam_pos``.

        Returns:
            A ``FoVOrthographicCameras`` on ``self.device``.
        """
        R, T = look_at_view_transform(eye=[self.cam_pos[cam_id]],
                                      at=((0, self.mesh_y_center, 0),),
                                      up=((0, 1, 0),))
        # NOTE(review): znear > zfar looks inverted but is kept as-is from the
        # original setup; the orthographic box spans +/-100 units on each axis.
        camera = FoVOrthographicCameras(device=self.device,
                                        R=R,
                                        T=T,
                                        znear=100.0,
                                        zfar=-100.0,
                                        max_y=100.0,
                                        min_y=-100.0,
                                        max_x=100.0,
                                        min_x=-100.0,
                                        scale_xyz=(self.scale * np.ones(3),))
        return camera

    def init_renderer(self, camera, type='clean_mesh', bg='gray'):
        """Configure ``self.renderer`` for a camera and render type.

        Args:
            camera: a PyTorch3D camera (see ``get_camera``).
            type: one of 'ori_mesh', 'clean_mesh', 'silhouette', 'pointcloud'.
            bg: background keyword: 'black', 'white' or 'gray'.
        """
        if 'mesh' in type:
            # Shared rasterizer for the 'ori_mesh' and 'clean_mesh' paths.
            self.raster_settings_mesh = RasterizationSettings(
                image_size=self.size,
                blur_radius=np.log(1.0 / 1e-4) * 1e-7,
                faces_per_pixel=30,
            )
            self.meshRas = MeshRasterizer(cameras=camera,
                                          raster_settings=self.raster_settings_mesh)

        if bg == 'black':
            blendparam = BlendParams(1e-4, 1e-4, (0.0, 0.0, 0.0))
        elif bg == 'white':
            blendparam = BlendParams(1e-4, 1e-8, (1.0, 1.0, 1.0))
        else:
            # 'gray' and any unrecognized value: the original left
            # ``blendparam`` unbound for unknown ``bg`` (latent NameError);
            # fall back to the gray background instead.
            blendparam = BlendParams(1e-4, 1e-8, (0.5, 0.5, 0.5))

        if type == 'ori_mesh':
            # Phong-shaded render with a fixed point light above/behind camera.
            lights = PointLights(device=self.device,
                                 ambient_color=((0.8, 0.8, 0.8),),
                                 diffuse_color=((0.2, 0.2, 0.2),),
                                 specular_color=((0.0, 0.0, 0.0),),
                                 location=[[0.0, 200.0, 200.0]])
            self.renderer = MeshRenderer(
                rasterizer=self.meshRas,
                shader=SoftPhongShader(device=self.device,
                                       cameras=camera,
                                       lights=lights,
                                       blend_params=blendparam))

        if type == 'silhouette':
            # Soft silhouette: alpha channel carries the mask.
            self.raster_settings_silhouette = RasterizationSettings(
                image_size=self.size,
                blur_radius=np.log(1. / 1e-4 - 1.) * 5e-5,
                faces_per_pixel=50,
                cull_backfaces=True,
            )
            self.silhouetteRas = MeshRasterizer(
                cameras=camera, raster_settings=self.raster_settings_silhouette)
            self.renderer = MeshRenderer(rasterizer=self.silhouetteRas,
                                         shader=SoftSilhouetteShader())

        if type == 'pointcloud':
            self.raster_settings_pcd = PointsRasterizationSettings(
                image_size=self.size,
                radius=0.006,
                points_per_pixel=10)
            self.pcdRas = PointsRasterizer(cameras=camera,
                                           raster_settings=self.raster_settings_pcd)
            self.renderer = PointsRenderer(
                rasterizer=self.pcdRas,
                compositor=AlphaCompositor(background_color=(0, 0, 0)))

        if type == 'clean_mesh':
            # Unlit render of vertex colors via cleanShader.
            self.renderer = MeshRenderer(
                rasterizer=self.meshRas,
                shader=cleanShader(device=self.device,
                                   cameras=camera,
                                   blend_params=blendparam))

    def set_camera(self, verts, normalize=False):
        """Fit the orthographic zoom/center to ``verts`` and place 4 cameras.

        Args:
            verts ([B, N, 3] tensor): mesh/point vertices.
            normalize: if True, scale and center so the object's y-extent
                fills ~95% of the frame; otherwise use scale=100, center=0.
        """
        self.scale = 100
        self.mesh_y_center = 0

        if normalize:
            y_max = verts.max(dim=1)[0][0, 1].item()
            y_min = verts.min(dim=1)[0][0, 1].item()
            # 1e-10 guards against division by zero for degenerate (flat) input.
            self.scale *= 0.95 / ((y_max - y_min) * 0.5 + 1e-10)
            self.mesh_y_center = (y_max + y_min) * 0.5

        # Four cameras around the y-axis: front, right, back, left.
        self.cam_pos = [(0, self.mesh_y_center, self.dis),
                        (self.dis, self.mesh_y_center, 0),
                        (0, self.mesh_y_center, -self.dis),
                        (-self.dis, self.mesh_y_center, 0)]

    def load_mesh(self, verts, faces, verts_rgb=None, normalize=False, use_normal=False):
        """Load a mesh into the renderer and set up cameras for it.

        Args:
            verts ([N, 3] or [B, N, 3]): vertices (tensor or array-like).
            faces ([F, 3] or [B, F, 3]): triangle indices.
            verts_rgb ([N, 3] or [B, N, 3], optional): per-vertex colors.
            normalize (bool): fit cameras to the mesh extent (see set_camera).
            use_normal (bool): when no colors are given, color by full normals
                instead of the normals' z component only.

        Returns:
            The textured ``Meshes`` object (also stored on ``self.mesh``).
        """
        if not torch.is_tensor(verts):
            verts = torch.tensor(verts)
        if not torch.is_tensor(faces):
            faces = torch.tensor(faces)

        # Promote unbatched input to batch size 1 with the expected dtypes.
        if verts.ndimension() == 2:
            verts = verts.unsqueeze(0).float()
        if faces.ndimension() == 2:
            faces = faces.unsqueeze(0).long()

        verts = verts.to(self.device)
        faces = faces.to(self.device)

        self.set_camera(verts, normalize)
        self.mesh = Meshes(verts, faces).to(self.device)

        if verts_rgb is not None:
            if not torch.is_tensor(verts_rgb):
                verts_rgb = torch.as_tensor(verts_rgb)
            if verts_rgb.ndimension() == 2:
                verts_rgb = verts_rgb.unsqueeze(0).float()
            verts_rgb = verts_rgb.to(self.device)
        elif use_normal:
            # Map normals from [-1, 1] into [0, 1] RGB.
            verts_rgb = self.mesh.verts_normals_padded()
            verts_rgb = (verts_rgb + 1.0) * 0.5
        else:
            # Grayscale from the normals' z component (camera-facing-ness).
            verts_rgb = self.mesh.verts_normals_padded()[..., 2:3].expand(-1, -1, 3)
            verts_rgb = (verts_rgb + 1.0) * 0.5

        textures = TexturesVertex(verts_features=verts_rgb)
        self.mesh.textures = textures

        return self.mesh

    def load_pcd(self, verts, verts_rgb, normalize=False):
        """Load a point cloud and set up cameras for it.

        Args:
            verts ([B, N, 3]): points (tensor or array-like).
            verts_rgb ([B, N, 3]): per-point colors, same shape as ``verts``.
            normalize (bool): fit cameras to the point extent (see set_camera).

        Returns:
            A ``Pointclouds`` object on ``self.device``.
        """
        assert verts.shape == verts_rgb.shape and len(verts.shape) == 3

        if not torch.is_tensor(verts):
            verts = torch.as_tensor(verts)
        if not torch.is_tensor(verts_rgb):
            verts_rgb = torch.as_tensor(verts_rgb)

        verts = verts.float().to(self.device)
        verts_rgb = verts_rgb.float().to(self.device)

        self.set_camera(verts, normalize)
        pcd = Pointclouds(points=verts, features=verts_rgb).to(self.device)
        return pcd

    def get_image(self, cam_ids=(0, 2), type='clean_mesh', bg='gray'):
        """Render the loaded mesh from ``cam_ids`` and tile the views side by side.

        Returns:
            [H, W * len(cam_ids), 3] numpy array.
        """
        images = []
        for cam_id in range(len(self.cam_pos)):
            if cam_id in cam_ids:
                self.init_renderer(self.get_camera(cam_id), type, bg)
                rendered_img = self.renderer(self.mesh)[0, :, :, :3]
                if cam_id == 2 and len(cam_ids) == 2:
                    # Mirror the back view so it lines up with the front view.
                    rendered_img = torch.flip(rendered_img, dims=[1])
                images.append(rendered_img)
        images = torch.cat(images, 1)

        return images.detach().cpu().numpy()

    def get_clean_image(self, cam_ids=(0, 2), type='clean_mesh', bg='gray'):
        """Render the loaded mesh from ``cam_ids``.

        Returns:
            List of [1, H, W, 3] tensors (gradients preserved), one per view.
        """
        images = []
        for cam_id in range(len(self.cam_pos)):
            if cam_id in cam_ids:
                self.init_renderer(self.get_camera(cam_id), type, bg)
                rendered_img = self.renderer(self.mesh)[0:1, :, :, :3]
                if cam_id == 2 and len(cam_ids) == 2:
                    # Mirror the back view so it lines up with the front view.
                    rendered_img = torch.flip(rendered_img, dims=[2])
                images.append(rendered_img)

        return images

    def get_silhouette_image(self, cam_ids=(0, 2)):
        """Render soft silhouettes (the alpha channel) from ``cam_ids``.

        Returns:
            List of [1, H, W] tensors, one per view.
        """
        images = []
        for cam_id in range(len(self.cam_pos)):
            if cam_id in cam_ids:
                self.init_renderer(self.get_camera(cam_id), 'silhouette')
                rendered_img = self.renderer(self.mesh)[0:1, :, :, 3]
                if cam_id == 2 and len(cam_ids) == 2:
                    # Mirror the back view so it lines up with the front view.
                    rendered_img = torch.flip(rendered_img, dims=[2])
                images.append(rendered_img)

        return images

    def get_image_pcd(self, pcd, cam_ids=(0, 1, 2, 3)):
        """Render a point cloud from ``cam_ids`` and tile the views side by side.

        Returns:
            [H, W * len(cam_ids), 3] numpy array.
        """
        images = torch.zeros((self.size, self.size * len(cam_ids), 3)).to(self.device)
        for i, cam_id in enumerate(cam_ids):
            self.init_renderer(self.get_camera(cam_id), 'pointcloud')
            images[:, self.size * i:self.size * (i + 1), :] = self.renderer(pcd)[0, :, :, :3]

        return images.cpu().numpy()

    def get_rendered_video(self, save_path, num_angle=100, s=0):
        """Render a turntable video of the loaded mesh to ``save_path``.

        Args:
            save_path: output .mp4 path.
            num_angle: number of frames (one full 360-degree orbit).
            s: starting angle in degrees.
        """
        # Replace the 4 fixed cameras with a ring of num_angle positions.
        self.cam_pos = []
        interval = 360. / num_angle
        for i in range(num_angle):
            angle = (s + i * interval) % 360
            self.cam_pos.append(
                (self.dis * math.cos(np.pi / 180 * angle), self.mesh_y_center,
                 self.dis * math.sin(np.pi / 180 * angle)))

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video = cv2.VideoWriter(save_path, fourcc, 30, (self.size, self.size))

        for cam_id in range(len(self.cam_pos)):
            self.init_renderer(self.get_camera(cam_id), 'clean_mesh', 'gray')
            # NOTE(review): the renderer outputs RGB but cv2.VideoWriter expects
            # BGR, so red/blue are swapped in the video — confirm intended.
            rendered_img = (self.renderer(self.mesh)[0, :, :, :3] *
                            255.0).detach().cpu().numpy().astype(np.uint8)
            video.write(rendered_img)

        video.release()