# SemanticBoost/TADA/lib/common/pytorch3d_renderer.py
import math

import cv2
import numpy as np
import torch
from pytorch3d.renderer import (
    AlphaCompositor, BlendParams, FoVOrthographicCameras, MeshRasterizer,
    MeshRenderer, PointLights, PointsRasterizationSettings, PointsRasterizer,
    PointsRenderer, RasterizationSettings, SoftPhongShader,
    SoftSilhouetteShader, TexturesVertex, blending, look_at_view_transform)
from pytorch3d.structures import Meshes, Pointclouds


class cleanShader(torch.nn.Module):
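    """Lighting-free shader: samples the mesh's per-vertex textures and
    softmax-blends them, so rendered colors are the raw texture colors
    (no Phong shading)."""
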
def __init__(self, device="cpu", cameras=None, blend_params=None):
super().__init__()
self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams(
)
def forward(self, fragments, meshes, **kwargs):
cameras = kwargs.get("cameras", self.cameras)
        if cameras is None:
            raise ValueError(
                "Cameras must be specified either at initialization "
                "or in the forward pass of cleanShader")
        # Blend the sampled textures across faces_per_pixel into an RGBA image.
        blend_params = kwargs.get("blend_params", self.blend_params)
        texels = meshes.sample_textures(fragments)
        images = blending.softmax_rgb_blend(
            texels, fragments, blend_params, znear=-256, zfar=256)
        return images


class Render:
def __init__(self, size=512, device=torch.device("cuda:0")):
self.device = device
self.mesh_y_center = 100.0
self.dis = 100.0
self.scale = 1.0
self.size = size
self.cam_pos = [(0, 100, 100)]
self.mesh = None
self.pcd = None
self.renderer = None
self.meshRas = None
def get_camera(self, cam_id):
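        """Build an orthographic camera at self.cam_pos[cam_id], looking at
        the mesh center with +y up."""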
R, T = look_at_view_transform(eye=[self.cam_pos[cam_id]],
at=((0, self.mesh_y_center, 0),),
up=((0, 1, 0),))
camera = FoVOrthographicCameras(device=self.device,
R=R,
T=T,
znear=100.0,
zfar=-100.0,
max_y=100.0,
min_y=-100.0,
max_x=100.0,
min_x=-100.0,
scale_xyz=(self.scale * np.ones(3),))
return camera
def init_renderer(self, camera, type='clean_mesh', bg='gray'):
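        """(Re)build self.renderer for the given camera.

        type selects the pipeline ('ori_mesh', 'clean_mesh', 'silhouette',
        'pointcloud'); bg selects the mesh background color.
        """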
        if 'mesh' in type:
            # Rasterizer shared by the mesh pipelines.
            self.raster_settings_mesh = RasterizationSettings(
                image_size=self.size,
                blur_radius=np.log(1.0 / 1e-4) * 1e-7,
                faces_per_pixel=30,
            )
            self.meshRas = MeshRasterizer(cameras=camera,
                                          raster_settings=self.raster_settings_mesh)

        if bg == 'black':
            blendparam = BlendParams(1e-4, 1e-4, (0.0, 0.0, 0.0))
        elif bg == 'white':
            blendparam = BlendParams(1e-4, 1e-8, (1.0, 1.0, 1.0))
        else:
            # Default to gray for any unrecognized background value.
            blendparam = BlendParams(1e-4, 1e-8, (0.5, 0.5, 0.5))
if type == 'ori_mesh':
lights = PointLights(device=self.device,
ambient_color=((0.8, 0.8, 0.8),),
diffuse_color=((0.2, 0.2, 0.2),),
specular_color=((0.0, 0.0, 0.0),),
location=[[0.0, 200.0, 200.0]])
self.renderer = MeshRenderer(
rasterizer=self.meshRas,
shader=SoftPhongShader(
device=self.device,
cameras=camera,
lights=lights,
blend_params=blendparam))
if type == 'silhouette':
self.raster_settings_silhouette = RasterizationSettings(
image_size=self.size,
blur_radius=np.log(1. / 1e-4 - 1.) * 5e-5,
faces_per_pixel=50,
cull_backfaces=True,
)
self.silhouetteRas = MeshRasterizer(
cameras=camera, raster_settings=self.raster_settings_silhouette)
self.renderer = MeshRenderer(rasterizer=self.silhouetteRas,
shader=SoftSilhouetteShader())
if type == 'pointcloud':
self.raster_settings_pcd = PointsRasterizationSettings(
image_size=self.size,
radius=0.006,
points_per_pixel=10)
self.pcdRas = PointsRasterizer(cameras=camera,
raster_settings=self.raster_settings_pcd)
self.renderer = PointsRenderer(
rasterizer=self.pcdRas,
compositor=AlphaCompositor(background_color=(0, 0, 0)))
if type == 'clean_mesh':
self.renderer = MeshRenderer(
rasterizer=self.meshRas,
shader=cleanShader(
device=self.device,
cameras=camera,
blend_params=blendparam))
def set_camera(self, verts, normalize=False):
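        """Place four side cameras (front, right, back, left) around the mesh.

        With normalize=True, the orthographic scale and vertical center are
        fit to the verts so the subject fills the view.
        """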
self.scale = 100
self.mesh_y_center = 0
if normalize:
y_max = verts.max(dim=1)[0][0, 1].item()
y_min = verts.min(dim=1)[0][0, 1].item()
self.scale *= 0.95 / ((y_max - y_min) * 0.5 + 1e-10)
self.mesh_y_center = (y_max + y_min) * 0.5
self.cam_pos = [(0, self.mesh_y_center, self.dis),
(self.dis, self.mesh_y_center, 0),
(0, self.mesh_y_center, -self.dis),
(-self.dis, self.mesh_y_center, 0)]
    def load_mesh(self, verts, faces, verts_rgb=None, normalize=False, use_normal=False):
        """Load a mesh into the pytorch3d renderer.

        Args:
            verts ([N, 3] or [B, N, 3]): vertex positions
            faces ([M, 3] or [B, M, 3]): triangle indices
            verts_rgb ([N, 3] or [B, N, 3], optional): per-vertex colors
            normalize (bool): fit the cameras to the mesh extent
            use_normal (bool): when verts_rgb is None, color vertices by their
                full normals instead of the normal z-component
        """
        if not torch.is_tensor(verts):
            verts = torch.as_tensor(verts)
        if not torch.is_tensor(faces):
            faces = torch.as_tensor(faces)
        # Add a batch dimension to unbatched inputs.
        if verts.ndimension() == 2:
            verts = verts.unsqueeze(0)
        if faces.ndimension() == 2:
            faces = faces.unsqueeze(0)
        verts = verts.float().to(self.device)
        faces = faces.long().to(self.device)
        self.set_camera(verts, normalize)
        self.mesh = Meshes(verts, faces).to(self.device)

        if verts_rgb is not None:
            if not torch.is_tensor(verts_rgb):
                verts_rgb = torch.as_tensor(verts_rgb)
            if verts_rgb.ndimension() == 2:
                verts_rgb = verts_rgb.unsqueeze(0)
            verts_rgb = verts_rgb.float().to(self.device)
        elif use_normal:
            # Color by vertex normals, mapped from [-1, 1] to [0, 1].
            verts_rgb = self.mesh.verts_normals_padded()
            verts_rgb = (verts_rgb + 1.0) * 0.5
        else:
            # Grayscale shading from the z-component of the vertex normals.
            verts_rgb = self.mesh.verts_normals_padded()[..., 2:3].expand(-1, -1, 3)
            verts_rgb = (verts_rgb + 1.0) * 0.5
        self.mesh.textures = TexturesVertex(verts_features=verts_rgb)
        return self.mesh
    def load_pcd(self, verts, verts_rgb, normalize=False):
        """Load a point cloud into the pytorch3d renderer.

        Args:
            verts ([B, N, 3]): point positions
            verts_rgb ([B, N, 3]): point colors
            normalize (bool): fit the cameras to the point cloud extent
        """
        assert verts.shape == verts_rgb.shape and len(verts.shape) == 3
        # Convert to float tensors on the target device.
        if not torch.is_tensor(verts):
            verts = torch.as_tensor(verts)
        if not torch.is_tensor(verts_rgb):
            verts_rgb = torch.as_tensor(verts_rgb)
        verts = verts.float().to(self.device)
        verts_rgb = verts_rgb.float().to(self.device)
# camera setting
self.set_camera(verts, normalize)
pcd = Pointclouds(points=verts, features=verts_rgb).to(self.device)
return pcd
def get_image(self, cam_ids=[0, 2], type='clean_mesh', bg='gray'):
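        """Render self.mesh from the cameras in cam_ids and return the views
        concatenated side by side as an [H, len(cam_ids)*W, 3] numpy array."""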
images = []
for cam_id in range(len(self.cam_pos)):
if cam_id in cam_ids:
self.init_renderer(self.get_camera(cam_id), type, bg)
rendered_img = self.renderer(self.mesh)[0, :, :, :3]
                # Mirror the back view so it lines up with the front view.
                if cam_id == 2 and len(cam_ids) == 2:
                    rendered_img = torch.flip(rendered_img, dims=[1])
images.append(rendered_img)
images = torch.cat(images, 1)
return images.detach().cpu().numpy()
def get_clean_image(self, cam_ids=[0, 2], type='clean_mesh', bg='gray'):
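        """Like get_image, but returns a list of [1, H, W, 3] tensors (one per
        camera in cam_ids) kept on the current device."""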
images = []
for cam_id in range(len(self.cam_pos)):
if cam_id in cam_ids:
self.init_renderer(self.get_camera(cam_id), type, bg)
rendered_img = self.renderer(self.mesh)[0:1, :, :, :3]
if cam_id == 2 and len(cam_ids) == 2:
rendered_img = torch.flip(rendered_img, dims=[2])
images.append(rendered_img)
return images
def get_silhouette_image(self, cam_ids=[0, 2]):
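        """Render soft silhouettes; returns a list of [1, H, W] alpha maps,
        one per camera in cam_ids."""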
images = []
for cam_id in range(len(self.cam_pos)):
if cam_id in cam_ids:
self.init_renderer(self.get_camera(cam_id), 'silhouette')
rendered_img = self.renderer(self.mesh)[0:1, :, :, 3]
if cam_id == 2 and len(cam_ids) == 2:
rendered_img = torch.flip(rendered_img, dims=[2])
images.append(rendered_img)
return images
def get_image_pcd(self, pcd, cam_ids=[0, 1, 2, 3]):
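        """Render the point cloud pcd from each camera in cam_ids, tiled
        horizontally into a single [H, len(cam_ids)*W, 3] numpy image."""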
images = torch.zeros((self.size, self.size * len(cam_ids), 3)).to(self.device)
for i, cam_id in enumerate(cam_ids):
self.init_renderer(self.get_camera(cam_id), 'pointcloud')
images[:, self.size * i:self.size * (i + 1), :] = self.renderer(pcd)[0, :, :, :3]
return images.cpu().numpy()
def get_rendered_video(self, save_path, num_angle=100, s=0):
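        """Write a turntable mp4 of self.mesh to save_path: num_angle frames
        along a circular camera path starting at angle s (degrees). Note that
        this overwrites self.cam_pos."""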
self.cam_pos = []
interval = 360. / num_angle
for i in range(num_angle):
angle = (s + i * interval) % 360
self.cam_pos.append(
(self.dis * math.cos(np.pi / 180 * angle), self.mesh_y_center,
self.dis * math.sin(np.pi / 180 * angle)))
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
video = cv2.VideoWriter(save_path, fourcc, 30, (self.size, self.size))
for cam_id in range(len(self.cam_pos)):
self.init_renderer(self.get_camera(cam_id), 'clean_mesh', 'gray')
            rendered_img = (self.renderer(self.mesh)[0, :, :, :3] * 255.0).detach().cpu().numpy().astype(np.uint8)
            # OpenCV expects BGR channel order for VideoWriter frames.
            video.write(cv2.cvtColor(rendered_img, cv2.COLOR_RGB2BGR))
video.release()
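

# ---------------------------------------------------------------------------
# Usage sketch (editor's addition, not part of the original module): render a
# placeholder cube from the front and back cameras and save the result. The
# cube data is illustrative only; TADA feeds SMPL-X style meshes instead.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    cube_verts = np.array(
        [[-1, -1, -1], [1, -1, -1], [1, 1, -1], [-1, 1, -1],
         [-1, -1, 1], [1, -1, 1], [1, 1, 1], [-1, 1, 1]], dtype=np.float32)
    cube_faces = np.array(
        [[0, 2, 1], [0, 3, 2], [4, 5, 6], [4, 6, 7],
         [0, 1, 5], [0, 5, 4], [1, 2, 6], [1, 6, 5],
         [2, 3, 7], [2, 7, 6], [3, 0, 4], [3, 4, 7]], dtype=np.int64)

    render = Render(size=512, device=torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu"))
    render.load_mesh(cube_verts, cube_faces, normalize=True)
    img = render.get_image(cam_ids=[0, 2], type='clean_mesh', bg='gray')
    # get_image returns float RGB in [0, 1]; convert to uint8 BGR for OpenCV.
    cv2.imwrite("cube_views.png", (img[..., ::-1] * 255).astype(np.uint8))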