Spaces:
Runtime error
Runtime error
import os | |
import random | |
import json | |
import pickle as pkl | |
import cv2 | |
import numpy as np | |
import imageio | |
import torch | |
from packaging import version as pver | |
from yacs.config import CfgNode as CN | |
def load_config(path, default_path=None):
    """Load a yacs config from ``path``, optionally layered over defaults.

    Args:
        path: config file merged last, so its values override the defaults.
        default_path: optional config file merged first as the base layer.

    Returns:
        A ``CfgNode`` built with ``new_allowed=True``, so the merged files may
        introduce keys that the base node does not already contain.
    """
    layers = [path] if default_path is None else [default_path, path]
    cfg = CN(new_allowed=True)
    for layer in layers:
        cfg.merge_from_file(layer)
    return cfg
def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Inner product over the last dimension; the reduced axis is kept (size 1)."""
    return (x * y).sum(dim=-1, keepdim=True)
def custom_meshgrid(*args):
    """``torch.meshgrid`` with 'ij' (matrix) indexing on every torch version.

    torch < 1.10 has no ``indexing`` keyword (its default is 'ij'); newer
    versions accept it, so it is passed explicitly there.
    """
    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    if pver.parse(torch.__version__) >= pver.parse('1.10'):
        return torch.meshgrid(*args, indexing='ij')
    return torch.meshgrid(*args)
def plot_grid_images(images, row, col, save_path=None):
    """Tile a batch of images into a ``row`` x ``col`` mosaic (row-major order).

    Args:
        images: np.array [B, H, W, 3] with B == row * col.
        row: number of tile rows.
        col: number of tile columns.
        save_path: if truthy, the mosaic scaled by 255 is written with cv2.

    Returns:
        The unscaled mosaic, shape [row * H, col * W, 3].
    """
    assert row * col == images.shape[0]
    strips = []
    for r in range(row):
        # Images r*col .. (r+1)*col-1 sit side by side in one strip.
        strips.append(np.hstack(images[r * col:(r + 1) * col]))
    mosaic = np.vstack(strips)
    if save_path:
        cv2.imwrite(save_path, mosaic * 255)
    return mosaic
def safe_normalize(x, eps=1e-20):
    """Scale vectors along the last axis to unit length.

    The squared norm is clamped to at least ``eps`` before the square root,
    so zero vectors come back as zeros instead of NaN.
    """
    sq_norm = (x * x).sum(dim=-1, keepdim=True)
    return x / sq_norm.clamp(min=eps).sqrt()
def seed_everything(seed):
    """Seed Python's ``random``, NumPy, torch (CPU) and the current CUDA device.

    ``PYTHONHASHSEED`` is also exported. Setting it at runtime does not change
    hashing in the already-running interpreter; only processes started
    afterwards inherit it. cuDNN determinism/benchmark flags are not touched.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
def torch_vis_2d(x, renormalize=False):
    """Show a 2-D image / feature map with matplotlib (calls ``plt.show()``).

    Tensor inputs may be [3, H, W], [H, W, 3], [1, H, W] or [H, W]. They are
    brought to a layout ``imshow`` accepts, detached and moved to CPU, and a
    one-line summary (shape, dtype, value range) is printed. Non-tensor inputs
    are passed to ``imshow`` as-is, after the float32 cast and optional
    renormalization.

    Args:
        x: torch tensor or numpy array to display.
        renormalize: if True, min-max rescale along axis 0 before display.
    """
    import matplotlib.pyplot as plt

    if isinstance(x, torch.Tensor):
        if x.dim() == 3 and x.shape[0] == 3:
            # Channel-first RGB -> [H, W, 3]. No squeeze: a 1-pixel-tall or
            # 1-pixel-wide image must keep its spatial axes to stay RGB.
            x = x.permute(1, 2, 0)
        elif x.dim() == 3 and x.shape[0] == 1 and x.shape[-1] not in (3, 4):
            # [1, H, W] single channel -> [H, W]. Only done when imshow could
            # not read the array as a 1-row RGB(A) image, i.e. it would reject it.
            x = x[0]
        x = x.detach().cpu().numpy()
        print(f'[torch_vis_2d] {x.shape}, {x.dtype}, {x.min()} ~ {x.max()}')
    x = x.astype(np.float32)
    if renormalize:
        # NOTE(review): min/max are taken along axis 0 (per column / channel),
        # not globally -- confirm this is the intended normalization.
        lo = x.min(axis=0, keepdims=True)
        hi = x.max(axis=0, keepdims=True)
        x = (x - lo) / (hi - lo + 1e-8)
    plt.imshow(x)
    plt.show()
def get_rays(poses, intrinsics, H, W, N=-1, error_map=None):
    ''' get rays

    Pixel centres (pixel index + 0.5) are back-projected with a pinhole model
    into camera-space directions
        x = (i - cx) / fx,  y = -(j - cy) / fy,  z = -1
    (i = column, j = row). These are rotated into world space by each pose.
    Directions are NOT normalized.

    Args:
        poses: [B, 4, 4], cam2world
        intrinsics: (fx, fy, cx, cy)
        H, W: image height / width in pixels.
        N: rays sampled per pose. If N <= 0, every pixel is used (N = H * W).
           N is capped at H * W.
        error_map: [B, 128 * 128], sample probability based on training error;
            when given (and N > 0), pixels are drawn without replacement on a
            128x128 grid and jittered to full resolution. Otherwise they are
            drawn uniformly (duplicates possible, same pixels for every pose).
    Returns:
        rays_o, rays_d: [B, N, 3]
    '''
    device = poses.device
    B = poses.shape[0]
    fx, fy, cx, cy = intrinsics

    # Row-major pixel grid: flat index k = row * W + col, so i[k] = col, j[k] = row.
    cols = torch.linspace(0, W - 1, W, device=device)
    rows = torch.linspace(0, H - 1, H, device=device)
    i = cols.repeat(H).reshape([1, H * W]).expand([B, H * W]) + 0.5
    j = rows.repeat_interleave(W).reshape([1, H * W]).expand([B, H * W]) + 0.5

    if N > 0:
        N = min(N, H * W)
        if error_map is None:
            inds = torch.randint(0, H * W, size=[N], device=device)  # may duplicate
            inds = inds.expand([B, N])
        else:
            # weighted sample on a low-reso grid
            inds_coarse = torch.multinomial(error_map.to(device), N, replacement=False)  # [B, N], in [0, 128*128)
            # map to the original resolution with random perturb.
            inds_x, inds_y = inds_coarse // 128, inds_coarse % 128  # (coarse row, coarse col)
            sx, sy = H / 128, W / 128
            inds_x = (inds_x * sx + torch.rand(B, N, device=device) * sx).long().clamp(max=H - 1)
            inds_y = (inds_y * sy + torch.rand(B, N, device=device) * sy).long().clamp(max=W - 1)
            inds = inds_x * W + inds_y
        i = torch.gather(i, -1, inds)
        j = torch.gather(j, -1, inds)

    zs = - torch.ones_like(i)
    xs = - (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    directions = torch.stack((xs, ys, zs), dim=-1)
    rays_d = directions @ poses[:, :3, :3].transpose(-1, -2)  # (B, N, 3)
    rays_o = poses[..., :3, 3]  # [B, 3]
    rays_o = rays_o[..., None, :].expand_as(rays_d)  # [B, N, 3]
    return rays_o, rays_d
def scale_img_nhwc(x, size, mag='bilinear', min='bilinear'):
    """Resize an [N, H, W, C] tensor to spatial ``size`` = (H', W').

    Both axes must shrink-or-stay, or both must grow (asserted). If both axes
    strictly shrink, ``min`` is the interpolation mode. Otherwise ``mag`` is
    used, with ``align_corners=True`` for the bilinear/bicubic modes.
    Returns a contiguous [N, H', W', C] tensor.
    """
    src_h, src_w = x.shape[1], x.shape[2]
    shrinks = src_h >= size[0] and src_w >= size[1]
    grows = src_h < size[0] and src_w < size[1]
    assert shrinks or grows, "Trying to magnify image in one dimension and minify in the other"
    nchw = x.permute(0, 3, 1, 2)  # NHWC -> NCHW, the layout interpolate expects
    if src_h > size[0] and src_w > size[1]:
        nchw = torch.nn.functional.interpolate(nchw, size, mode=min)
    elif mag in ('bilinear', 'bicubic'):
        nchw = torch.nn.functional.interpolate(nchw, size, mode=mag, align_corners=True)
    else:
        nchw = torch.nn.functional.interpolate(nchw, size, mode=mag)
    return nchw.permute(0, 2, 3, 1).contiguous()  # NCHW -> NHWC
def scale_img_hwc(x, size, mag='bilinear', min='bilinear'):
    """Resize one [H, W, C] image by running ``scale_img_nhwc`` on a batch of one."""
    batched = x.unsqueeze(0)
    resized = scale_img_nhwc(batched, size, mag, min)
    return resized[0]
def scale_img_nhw(x, size, mag='bilinear', min='bilinear'):
    """Resize [N, H, W] single-channel images via ``scale_img_nhwc`` (C = 1)."""
    with_channel = x.unsqueeze(-1)
    resized = scale_img_nhwc(with_channel, size, mag, min)
    return resized[..., 0]
def scale_img_hw(x, size, mag='bilinear', min='bilinear'):
    """Resize a single [H, W] image via ``scale_img_nhwc`` (N = 1, C = 1)."""
    nhwc = x.unsqueeze(0).unsqueeze(-1)
    resized = scale_img_nhwc(nhwc, size, mag, min)
    return resized[0, ..., 0]
def trunc_rev_sigmoid(x, eps=1e-6):
    """Inverse sigmoid (logit), with ``x`` clamped to [eps, 1 - eps] to stay finite."""
    p = torch.clamp(x, min=eps, max=1 - eps)
    odds = p / (1 - p)
    return torch.log(odds)
def save_image(fn, x: np.ndarray):
    """Best-effort save of a float image to ``fn`` with imageio.

    ``x`` is scaled by 255, rounded, clipped to [0, 255] and cast to uint8, so
    values are expected in [0, 1]. Files with a ``.png`` extension (matched
    case-insensitively) are written with ``compress_level=3`` to favour speed.

    Any ordinary exception is reported as a printed warning instead of being
    raised, so a failed save never aborts the caller. KeyboardInterrupt and
    SystemExit still propagate.
    """
    try:
        img = np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8)
        if os.path.splitext(fn)[1].lower() == ".png":
            imageio.imwrite(fn, img, compress_level=3)  # Low compression for faster saving
        else:
            imageio.imwrite(fn, img)
    except Exception as err:  # deliberate best-effort; report why it failed
        print("WARNING: FAILED to save image %s (%s)" % (fn, err))
# Reworked so this matches gluPerspective / glm::perspective, using fovy
def perspective(fovy=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None):
    """4x4 float32 perspective projection from a vertical field of view.

    Args:
        fovy: vertical field of view in radians (default ~pi/4).
        aspect: width / height.
        n, f: near and far clip distances.
        device: torch device for the result.

    Row 1 uses ``1 / -tan(fovy / 2)``, i.e. the y axis is flipped relative to
    the gluPerspective matrix; the depth rows follow the OpenGL form.
    """
    half_tan = np.tan(fovy / 2)
    depth_range = f - n
    rows = [
        [1 / (half_tan * aspect), 0, 0, 0],
        [0, 1 / -half_tan, 0, 0],
        [0, 0, -(f + n) / depth_range, -(2 * f * n) / depth_range],
        [0, 0, -1, 0],
    ]
    return torch.tensor(rows, dtype=torch.float32, device=device)
def translate(x, y, z, device=None):
    """4x4 float32 homogeneous translation matrix by (x, y, z) (last column)."""
    rows = [
        [1, 0, 0, x],
        [0, 1, 0, y],
        [0, 0, 1, z],
        [0, 0, 0, 1],
    ]
    return torch.tensor(rows, dtype=torch.float32, device=device)
def rotate_x(a, device=None):
    """4x4 float32 homogeneous rotation about x by angle ``a`` (radians).

    The y/z block is laid out as [[cos, sin], [-sin, cos]].
    """
    cos_a = np.cos(a)
    sin_a = np.sin(a)
    rows = [
        [1, 0, 0, 0],
        [0, cos_a, sin_a, 0],
        [0, -sin_a, cos_a, 0],
        [0, 0, 0, 1],
    ]
    return torch.tensor(rows, dtype=torch.float32, device=device)
def rotate_y(a, device=None):
    """4x4 float32 homogeneous rotation about y by angle ``a`` (radians).

    The x/z block is laid out as [[cos, sin], [-sin, cos]].
    """
    cos_a = np.cos(a)
    sin_a = np.sin(a)
    rows = [
        [cos_a, 0, sin_a, 0],
        [0, 1, 0, 0],
        [-sin_a, 0, cos_a, 0],
        [0, 0, 0, 1],
    ]
    return torch.tensor(rows, dtype=torch.float32, device=device)
def random_rotation_translation(t, device=None):
    """Random 4x4 float32 rigid transform drawn from NumPy's global RNG.

    A Gaussian 3x3 matrix is made row-orthogonal by cross products (row 1 =
    r0 x r2, then row 2 = r0 x r1), and its rows are normalized to unit length.
    The translation is drawn uniformly from [-t, t] per axis.
    """
    basis = np.random.normal(size=[3, 3])
    basis[1] = np.cross(basis[0], basis[2])
    basis[2] = np.cross(basis[0], basis[1])
    basis = basis / np.linalg.norm(basis, axis=1, keepdims=True)
    mat = np.eye(4)
    mat[:3, :3] = basis
    mat[:3, 3] = np.random.uniform(-t, t, size=[3])
    return torch.tensor(mat, dtype=torch.float32, device=device)
def make_rotate(rx, ry, rz):
    """3x3 float64 rotation R = Rz @ Ry @ Rx from per-axis angles in radians.

    Each factor is the standard right-handed rotation about its axis, so
    ``Rx`` is applied first when the result multiplies column vectors.
    """
    sx, cx = np.sin(rx), np.cos(rx)
    sy, cy = np.sin(ry), np.cos(ry)
    sz, cz = np.sin(rz), np.cos(rz)
    rot_x = np.array([[1.0, 0.0, 0.0],
                      [0.0, cx, -sx],
                      [0.0, sx, cx]])
    rot_y = np.array([[cy, 0.0, sy],
                      [0.0, 1.0, 0.0],
                      [-sy, 0.0, cy]])
    rot_z = np.array([[cz, -sz, 0.0],
                      [sz, cz, 0.0],
                      [0.0, 0.0, 1.0]])
    return (rot_z @ rot_y) @ rot_x
class SMPLXSeg:
    """Vertex-index groups for an SMPL-X mesh plus a face-level re-mesh mask.

    Reads four assets from ``<base_dir>/smplx``:
      * ``smplx_vert_segementation.json``: dict of SMPL-X vertex-id lists.
      * ``FLAME_masks.pkl``: dict of FLAME region vertex ids.
      * ``smplx_faces.npy``: face index array.
      * ``FLAME_SMPLX_vertex_ids.npy``: lookup indexed by FLAME vertex id
        (presumably giving the SMPL-X vertex id -- TODO confirm).

    Attributes (lists of vertex ids): eyeball_ids, hands_ids, neck_ids,
    head_ids, front_face_ids, ears_ids, forehead_ids, lips_ids, nose_ids,
    eyes_ids.

    remesh_mask: boolean array with one entry per face. An entry is True when
    none of the face's vertices belong to the front face (minus forehead), the
    ears, the eyeballs or the hands.
    """

    # Length of the per-vertex mask; presumably the SMPL-X vertex count
    # (TODO confirm). Every index in smplx_faces.npy must be below it.
    NUM_VERTS = 10475

    def __init__(self, base_dir):
        smplx_dir = os.path.join(base_dir, "smplx")
        # The file name is spelled "segementation"; that is the asset name loaded here.
        with open(os.path.join(smplx_dir, "smplx_vert_segementation.json")) as f:
            smplx_segs = json.load(f)
        # NOTE(review): pickle (and np.load with allow_pickle=True) can execute
        # arbitrary code -- only point base_dir at trusted assets.
        with open(os.path.join(smplx_dir, "FLAME_masks.pkl"), "rb") as f:
            flame_segs = pkl.load(f, encoding='latin1')
        smplx_face = np.load(os.path.join(smplx_dir, "smplx_faces.npy"))
        smplx_flame_vid = np.load(os.path.join(smplx_dir, "FLAME_SMPLX_vertex_ids.npy"), allow_pickle=True)

        def flame_region(name):
            # Map one FLAME region's vertex ids through the FLAME->SMPL-X lookup.
            return list(smplx_flame_vid[flame_segs[name]])

        # Groups taken directly from the SMPL-X segmentation json.
        self.eyeball_ids = smplx_segs["leftEye"] + smplx_segs["rightEye"]
        self.hands_ids = (smplx_segs["leftHand"] + smplx_segs["rightHand"]
                          + smplx_segs["leftHandIndex1"] + smplx_segs["rightHandIndex1"])
        self.neck_ids = smplx_segs["neck"]
        self.head_ids = smplx_segs["head"]

        # Groups defined on FLAME regions, mapped to SMPL-X ids.
        self.front_face_ids = flame_region("face")
        self.ears_ids = flame_region("left_ear") + flame_region("right_ear")
        self.forehead_ids = flame_region("forehead")
        self.lips_ids = flame_region("lips")
        self.nose_ids = flame_region("nose")
        self.eyes_ids = flame_region("right_eye_region") + flame_region("left_eye_region")

        # re-mesh mask: a face is kept (True) only if none of its vertices are in remesh_ids.
        remesh_ids = (list(set(self.front_face_ids) - set(self.forehead_ids))
                      + self.ears_ids + self.eyeball_ids + self.hands_ids)
        remesh_mask = ~np.isin(np.arange(self.NUM_VERTS), remesh_ids)
        self.remesh_mask = remesh_mask[smplx_face].all(axis=1)
def create_checkerboard(h, w, c, grid_size):
    """Return an (h, w, c) float64 image holding a centred checkerboard.

    Cells are ``grid_size`` x ``grid_size`` pixels. The cell at grid position
    (row, col) is 1.0 when ``row + col`` is odd and 0.0 otherwise, identically
    in all ``c`` channels. Only whole cells are drawn
    (``h // grid_size`` x ``w // grid_size`` of them). The board is centred and
    the leftover border is 1.0. If no full cell fits along an axis, the result
    is all ones.
    """
    num_rows = h // grid_size
    num_cols = w // grid_size
    out = np.ones((h, w, c))
    if num_rows == 0 or num_cols == 0:
        # Board would be empty; only the 1.0 padding remains.
        return out
    # Grid-cell index of every pixel row / column inside the board.
    cell_row = np.arange(num_rows * grid_size) // grid_size
    cell_col = np.arange(num_cols * grid_size) // grid_size
    board = ((cell_row[:, None] + cell_col[None, :]) % 2 == 1).astype(np.float64)
    board_h, board_w = board.shape
    # Centre the board; integer division puts any odd leftover pixel after it.
    top = (h - board_h) // 2
    left = (w - board_w) // 2
    out[top:top + board_h, left:left + board_w] = board[..., None]
    return out
if __name__ == '__main__':
    # Quick visual check: write a 512x512 RGB checkerboard (64-px cells) to ck.png.
    out = create_checkerboard(512, 512, 3, 64)
    # cv2 and np come from the module-level imports. Convert to 8-bit
    # explicitly rather than relying on imwrite's implicit float conversion.
    cv2.imwrite("ck.png", (out * 255).astype(np.uint8))