|
import os |
|
import argparse |
|
from PIL import Image |
|
from glob import glob |
|
import numpy as np |
|
import json |
|
import torch |
|
import torchvision |
|
from torch.nn import functional as F |
|
from matplotlib import colormaps |
|
import math |
|
import scipy |
|
|
|
|
|
def get_grid(height, width, shape=None, dtype="torch", device="cpu", align_corners=True, normalize=True):
    """Build a dense grid of (x, y) coordinates.

    Args:
        height: number of rows H.
        width: number of columns W.
        shape: optional sequence (list or tuple) of leading dimensions the grid
            is broadcast (expanded) to.
        dtype: "torch" returns a tensor; "numpy" returns a NumPy array (moved
            to CPU first).
        device: device on which the coordinates are created.
        align_corners: if True, samples lie on the first/last pixel
            (0 .. 1, or 0 .. W-1 when not normalized). If False, samples lie at
            pixel centres (0.5/W .. 1-0.5/W, or 0.5 .. W-0.5 when not normalized).
        normalize: if True, coordinates are in [0, 1]; otherwise in pixel units.

    Returns:
        Array of shape (*shape, H, W, 2) whose last axis holds (x, y).
    """
    H, W = height, width
    # Accept any sequence of leading dims; a tuple used to crash on `S + [H, W]`.
    S = list(shape) if shape else []

    if align_corners:
        x = torch.linspace(0, 1, W, device=device)
        y = torch.linspace(0, 1, H, device=device)
        if not normalize:
            x = x * (W - 1)
            y = y * (H - 1)
    else:
        x = torch.linspace(0.5 / W, 1.0 - 0.5 / W, W, device=device)
        y = torch.linspace(0.5 / H, 1.0 - 0.5 / H, H, device=device)
        if not normalize:
            x = x * W
            y = y * H

    lead = [1] * len(S)
    x = x.view(*lead, 1, -1).expand(*S, H, W)
    y = y.view(*lead, -1, 1).expand(*S, H, W)

    grid = torch.stack([x, y], dim=-1)

    if dtype == "numpy":
        # Tensor.numpy() only works on CPU tensors.
        grid = grid.cpu().numpy()

    return grid
|
|
|
def translation(frame, dx, dy, pad_value):
    """Shift a (C, H, W) frame by (dx, dy) pixels using bilinear resampling.

    Each output pixel at (x, y) samples the input at (x - dx, y - dy), so the
    content moves by +dx horizontally and +dy vertically. Areas that sample
    outside the input become ``pad_value``. This works because the frame is
    offset by ``-pad_value`` before ``grid_sample`` (whose default padding is
    zeros) and the offset is added back afterwards.
    """
    _, rows, cols = frame.shape
    coords = get_grid(rows, cols, device=frame.device)
    # get_grid is normalized with align_corners=True, so one pixel is 1/(n - 1).
    coords[..., 0] = coords[..., 0] - dx / (cols - 1)
    coords[..., 1] = coords[..., 1] - dy / (rows - 1)

    sample_grid = (coords * 2 - 1).unsqueeze(0)  # grid_sample expects [-1, 1]
    shifted = F.grid_sample(
        (frame - pad_value).unsqueeze(0),
        sample_grid,
        mode='bilinear',
        align_corners=True,
    )
    return shifted[0] + pad_value
|
|
|
|
|
def project(pos, t, time_steps, heigh, width):
    """Project 2D pixel positions at frame ``t`` into a slanted space-time view.

    Positions are normalized to [0, 1], centred, and shrunk by 0.25. The time
    index is then appended as a third coordinate (1 at t = 0, 0 at
    t = T - 1), and a fixed linear map mixes time into x/y. Finally the result
    is offset and rescaled back to pixel units.

    NOTE(review): the parameter is spelled ``heigh`` (sic). It is kept
    unchanged so keyword callers keep working.

    Args:
        pos: tensor (..., 2) of (x, y) pixel coordinates.
        t: frame index (scalar, or tensor broadcastable to ``pos[..., :1]``).
        time_steps: number of frames T.
        heigh: image height H.
        width: image width W.

    Returns:
        Tensor (..., 2) of projected (x, y) pixel coordinates, in the same
        dtype/device as the normalized input.
    """
    T, H, W = time_steps, heigh, width

    # max(..., 1) keeps single-frame / single-pixel inputs finite; otherwise
    # these divisions are 0/0 and produce NaN.
    t_den = max(T - 1, 1)
    x_den = max(W - 1, 1)
    y_den = max(H - 1, 1)

    xy = torch.stack([pos[..., 0] / x_den, pos[..., 1] / y_den], dim=-1)
    xy = (xy - 0.5) * 0.25

    tt = 1 - torch.ones_like(xy[..., :1]) * t / t_den
    xyt = torch.cat([xy, tt], dim=-1)

    # Built with the input's dtype/device so float64 (or GPU) inputs work.
    M = torch.tensor(
        [
            [0.8, 0.0, 0.5],
            [-0.2, 1.0, 0.1],
            [0.0, 0.0, 0.0],
        ],
        dtype=xyt.dtype,
        device=xyt.device,
    )
    proj = xyt @ M.t()

    # The third output row of M is zero, so only x/y are kept.
    x = (proj[..., 0] + 0.25) * (W - 1)
    y = (proj[..., 1] + 0.45) * (H - 1)
    return torch.stack([x, y], dim=-1)
|
|
|
def draw(pos, vis, col, height, width, radius=1):
    """Splat the visible points onto an (H, W) canvas with weighted averaging.

    Each visible point is spread over neighbouring pixels. The four bilinear
    neighbours are used when ``radius <= 1``; otherwise every pixel in a disc
    of ``radius`` is used. The weighted colours are accumulated and normalized
    by the accumulated weight.

    Returns:
        (frame, alpha): ``frame`` is (H, W, 3) averaged colour; ``alpha`` is
        (H, W, 1), 1.0 where any weight landed and 0.0 elsewhere.
    """
    keep = vis.bool()
    points = pos[keep]
    colors = col[keep]

    if radius <= 1:
        points, weighted = get_cardinal_neighbors(points, colors)
    else:
        points, weighted = get_radius_neighbors(points, colors, radius)

    px, py = points[:, 0], points[:, 1]
    inside = (px >= 0) & (px <= width - 1) & (py >= 0) & (py <= height - 1)
    points, weighted = points[inside], weighted[inside]

    pixels = points.round().long()
    flat_index = (pixels[:, 1] * width + pixels[:, 0]).unsqueeze(1).expand(-1, 4)

    # Accumulate premultiplied RGB plus weight per pixel (4 channels).
    canvas = torch.zeros(height * width, 4, device=pos.device)
    canvas.scatter_add_(0, flat_index, weighted)
    canvas = canvas.view(height, width, 4)

    rgb, weight = canvas[..., :3], canvas[..., 3]
    covered = weight > 0
    rgb[covered] /= weight[covered].unsqueeze(-1)

    return rgb, covered.unsqueeze(-1).float()
|
|
|
def get_cardinal_neighbors(pos, col, eps=0.01):
    """Bilinearly splat each point onto its four surrounding integer pixels.

    Args:
        pos: (N, 2) tensor of (x, y) positions.
        col: (N, C) tensor of colours.
        eps: added to every 1D weight so no neighbour gets exactly zero weight.

    Returns:
        (pos, col): (4N, 2) pixel positions in the order NW, SW, NE, SE (each
        group of N consecutive), and (4N, C+1) colours premultiplied by their
        weight with the weight appended as the last channel.
    """
    x, y = pos[:, 0], pos[:, 1]
    x_lo, y_lo = x.floor(), y.floor()
    x_hi, y_hi = x_lo + 1, y_lo + 1

    # 1D bilinear weights (closer corner -> larger weight), each offset by eps.
    wt_top = y_hi - y + eps
    wt_bottom = y - y_lo + eps
    wt_left = x_hi - x + eps
    wt_right = x - x_lo + eps

    corners = (
        (x_lo, y_lo, wt_top * wt_left),      # NW
        (x_lo, y_hi, wt_bottom * wt_left),   # SW
        (x_hi, y_lo, wt_top * wt_right),     # NE
        (x_hi, y_hi, wt_bottom * wt_right),  # SE
    )

    positions = torch.cat([torch.stack([cx, cy], dim=-1) for cx, cy, _ in corners], dim=0)
    colors = torch.cat(
        [torch.cat([w[:, None] * col, w[:, None]], dim=-1) for _, _, w in corners],
        dim=0,
    )
    return positions, colors
|
|
|
|
|
def get_radius_neighbors(pos, col, radius):
    """Splat each point onto every integer pixel within ``radius`` of its rounded position.

    Args:
        pos: (N, 2) tensor of (x, y) positions.
        col: (N, C) tensor of colours (``draw`` passes C == 3).
        radius: splat radius in pixels.

    Returns:
        (pos, col): (N*K, 2) pixel positions and (N*K, C+1) colours
        premultiplied by their weight with the weight appended. K is the number
        of integer offsets inside the disc, and results are grouped
        point-major. Weights fall off linearly from 1.01 at the centre to 0.01
        at distance ``radius``.
    """
    R = math.ceil(radius)
    center = torch.stack([pos[:, 0].round(), pos[:, 1].round()], dim=-1)

    # Integer offsets on the (2R+1)^2 square, created on pos's device
    # (previously hard-coded to .cuda(), which broke CPU use and other GPUs).
    steps = torch.arange(-R, R + 1, device=pos.device)
    side = 2 * R + 1
    offsets = torch.stack(
        [steps[None, :].expand(side, -1), steps[:, None].expand(-1, side)], dim=-1
    ).reshape(-1, 2)

    # Keep only the offsets inside the disc.
    offsets = offsets[offsets[:, 0] ** 2 + offsets[:, 1] ** 2 <= radius ** 2]
    num_offsets = offsets.size(0)

    w = 1 - offsets.pow(2).sum(-1).float().sqrt() / radius + 0.01

    out_pos = (center.view(-1, 1, 2) + offsets.view(1, -1, 2)).view(-1, 2)
    # Point-major layout: each point's colour repeated once per offset, with
    # the offset weights tiled once per point to match.
    out_col = col.repeat_interleave(num_offsets, dim=0)
    w = w.repeat(pos.size(0))
    out_col = torch.cat([out_col * w[:, None], w[:, None]], dim=-1)

    return out_pos, out_col
|
|
|
|
|
def get_rainbow_colors(size):
    """Return ``size`` RGB colours sampled evenly along matplotlib's "jet" colormap.

    Args:
        size: number of colours.

    Returns:
        float32 tensor of shape (size, 3), with the first colour at the start
        of the colormap and the last at its end.
    """
    col_map = colormaps["jet"]
    # max(..., 1) avoids 0/0 for size == 1. That NaN would be mapped to the
    # colormap's "bad" colour instead of a real jet colour.
    col_range = np.arange(size) / max(size - 1, 1)
    col = torch.from_numpy(col_map(col_range)[..., :3]).float()  # drop alpha
    return col.view(-1, 3)
|
|
|
|
|
def spline_interpolation(x, length=10):
    """Upsample a (T, N, C) trajectory tensor in time with a cubic spline.

    Args:
        x: tensor of shape (T, N, C); dimension 0 is time.
        length: upsampling factor. The output has T * length samples spanning
            the same [0, T - 1] time range. ``length == 1`` returns ``x``
            unchanged.

    Returns:
        float32 tensor of shape (T * length, N, C) on the same device as ``x``
        (previously it was always moved to CUDA).
    """
    if length == 1:
        return x

    # Local import: plain `import scipy` does not load `scipy.interpolate` on
    # older SciPy versions.
    import scipy.interpolate

    T, N, C = x.shape

    if T == 1:
        # A spline needs at least two knots; a single sample is a constant path.
        return x.repeat(length, 1, 1).float()

    samples = x.detach().reshape(T, -1).cpu().numpy()
    original_time = np.arange(T)
    cs = scipy.interpolate.CubicSpline(original_time, samples)
    new_time = np.linspace(original_time[0], original_time[-1], T * length)

    return torch.from_numpy(cs(new_time)).view(-1, N, C).float().to(x.device)
|
|
|
def create_folder(path, verbose=False, exist_ok=True, safe=True):
    """Create directory ``path`` (including missing parents).

    Args:
        path: directory to create.
        verbose: print a message when a folder is actually created.
        exist_ok: if False, an already-existing ``path`` counts as a failure.
        safe: if True, failures are reported by returning False; if False,
            they raise ``OSError``.

    Returns:
        True if this call created the directory. False if it already existed
        or creation failed in safe mode.

    Raises:
        OSError: only when ``safe`` is False and either ``path`` exists while
            ``exist_ok`` is False, or the directory could not be created.
    """
    if os.path.exists(path):
        # Nothing to create. This is only an error if a fresh path was required.
        if not exist_ok and not safe:
            raise FileExistsError(f"Path already exists: {path}")
        return False

    try:
        os.makedirs(path)
    except OSError as err:
        # Another process may have created it between the check and makedirs.
        if exist_ok and os.path.isdir(path):
            return False
        if not safe:
            raise OSError(f"Could not create folder: {path}") from err
        return False

    if verbose:
        print(f"Created folder: {path}")

    return True
|
|
|
|
|
def write_video_to_file(video, path, channels, fps=8):
    """Encode a float video to an H.264 file at ``path``.

    Args:
        video: tensor (T, C, H, W) when ``channels == "first"``, otherwise
            (T, H, W, C). Values are expected in [0, 1]; anything outside is
            clamped.
        path: output file path; its parent folder is created if needed.
        channels: "first" for channels-first input, anything else for
            channels-last.
        fps: frame rate of the written video (default 8, as before).

    Returns:
        The uint8 (T, H, W, C) CPU tensor that was written.
    """
    create_folder(os.path.dirname(path))

    if channels == "first":
        video = video.permute(0, 2, 3, 1)

    # Round and clamp before the uint8 cast (as write_frame does); an unclamped
    # cast would wrap out-of-range values around.
    video = (video.detach().cpu() * 255.).round().clamp(0, 255).to(torch.uint8)

    torchvision.io.write_video(path, video, fps, "h264", options={"pix_fmt": "yuv420p", "crf": "23"})

    return video
|
|
|
|
|
def write_frame(frame, path, channels="first"):
    """Save one float image tensor to ``path`` using PIL.

    Args:
        frame: tensor (C, H, W) when ``channels == "first"``, otherwise
            (H, W, C). Values are expected in [0, 1]; they are scaled to 0-255,
            rounded and clipped.
        path: output image path; its parent folder is created if needed.
        channels: "first" for channels-first input, anything else for
            channels-last.
    """
    create_folder(os.path.dirname(path))

    # detach() so frames that are part of an autograd graph can be saved.
    frame = frame.detach().cpu().numpy()

    if channels == "first":
        frame = np.transpose(frame, (1, 2, 0))

    # PIL cannot build an image from an (H, W, 1) array; save it as grayscale.
    if frame.ndim == 3 and frame.shape[-1] == 1:
        frame = frame[..., 0]

    frame = np.clip(np.round(frame * 255), 0, 255).astype(np.uint8)

    Image.fromarray(frame).save(path)
|
|
|
|
|
def write_video_to_folder(video, path, channels, zero_padded, ext):
    """Write every frame of ``video`` as a separate image file inside ``path``.

    Files are named ``<index>.<ext>``. With ``zero_padded``, the index is
    left-padded with zeros to the number of digits in the frame count
    (e.g. 10 frames -> "00" .. "09").
    """
    create_folder(path)

    num_frames = video.shape[0]
    digits = len(str(num_frames))

    for index in range(num_frames):
        stem = str(index).zfill(digits) if zero_padded else str(index)
        write_frame(video[index], os.path.join(path, f"{stem}.{ext}"), channels)
|
|
|
|
|
|
|
def write_video(video, path, channels="first", zero_padded=True, ext="png", dtype="torch"):
    """Write ``video`` either as an .mp4 file or as a folder of frame images.

    Args:
        video: frames as a torch tensor, or a NumPy array when
            ``dtype == "numpy"``.
        path: a path ending in ".mp4" is encoded as a video file; any other
            path is treated as a folder that receives one image per frame.
        channels: "first" for channels-first frames, otherwise channels-last.
        zero_padded: zero-pad frame indices in folder mode.
        ext: image extension used in folder mode.
        dtype: "numpy" converts ``video`` with ``torch.from_numpy`` first.
    """
    frames = torch.from_numpy(video) if dtype == "numpy" else video

    if not path.endswith(".mp4"):
        write_video_to_folder(frames, path, channels, zero_padded, ext)
        return

    write_video_to_file(frames, path, channels)
|
|