# Image-to-Video
# Uploaded by zzwustc via huggingface_hub (commit ef296aa, verified).
import os
import argparse
from PIL import Image
from glob import glob
import numpy as np
import json
import torch
import torchvision
from torch.nn import functional as F
from matplotlib import colormaps
import math
import scipy
def get_grid(height, width, shape=None, dtype="torch", device="cpu", align_corners=True, normalize=True):
    """Build a 2-D grid of (x, y) sampling coordinates.

    Args:
        height: Number of rows ``H``.
        width: Number of columns ``W``.
        shape: Optional leading dimensions (list, tuple or ``torch.Size``); the
            grid is broadcast with ``expand`` to ``(*shape, H, W, 2)``.
        dtype: ``"torch"`` to return a tensor, ``"numpy"`` to return an ndarray.
        device: Device on which the coordinates are created.
        align_corners: If True, coordinates run from the first to the last
            pixel (``0 .. 1``); otherwise they are offset by half a pixel
            (``0.5/W .. 1 - 0.5/W``).
        normalize: If True, coordinates are in ``[0, 1]``; otherwise they are
            scaled to pixel units.

    Returns:
        Grid of shape ``(*shape, H, W, 2)`` whose last axis holds ``(x, y)``.
    """
    H, W = height, width
    # Accept tuples / torch.Size too: concatenating a tuple with a list below
    # would otherwise raise TypeError.
    S = list(shape) if shape else []
    if align_corners:
        x = torch.linspace(0, 1, W, device=device)
        y = torch.linspace(0, 1, H, device=device)
        if not normalize:
            x = x * (W - 1)
            y = y * (H - 1)
    else:
        x = torch.linspace(0.5 / W, 1.0 - 0.5 / W, W, device=device)
        y = torch.linspace(0.5 / H, 1.0 - 0.5 / H, H, device=device)
        if not normalize:
            x = x * W
            y = y * H
    # x varies along the last (column) axis, y along the row axis.
    x_view, y_view, exp = [1 for _ in S] + [1, -1], [1 for _ in S] + [-1, 1], S + [H, W]
    x = x.view(*x_view).expand(*exp)
    y = y.view(*y_view).expand(*exp)
    grid = torch.stack([x, y], dim=-1)
    if dtype == "numpy":
        # Move to host first so grids built on an accelerator can be converted.
        grid = grid.cpu().numpy()
    return grid
def translation(frame, dx, dy, pad_value):
    """Shift a channels-first image by ``(dx, dy)`` pixels using bilinear resampling.

    Uncovered pixels are filled with ``pad_value``. The image is offset by
    ``-pad_value`` before sampling, so ``grid_sample``'s default zero padding
    becomes ``pad_value`` once the offset is added back.

    Args:
        frame: Tensor of shape ``(C, H, W)``.
        dx: Horizontal shift in pixels.
        dy: Vertical shift in pixels.
        pad_value: Value used for pixels sampled from outside the image.

    Returns:
        Shifted tensor of shape ``(C, H, W)``.
    """
    _, rows, cols = frame.shape
    sample_at = get_grid(rows, cols, device=frame.device)
    # Each output pixel reads from (output location - shift), expressed in the
    # normalized [0, 1] coordinates produced by get_grid.
    sample_at[..., 0] -= dx / (cols - 1)
    sample_at[..., 1] -= dy / (rows - 1)
    centered = frame - pad_value
    shifted = F.grid_sample(
        centered[None],
        sample_at[None] * 2 - 1,  # grid_sample expects coordinates in [-1, 1]
        mode='bilinear',
        align_corners=True,
    )
    return shifted[0] + pad_value
def project(pos, t, time_steps, heigh, width):
    """Map pixel positions at time step ``t`` through a fixed oblique projection.

    Steps:
    - Positions are normalized to [0, 1], centered, and scaled by 0.25.
    - Time is appended as a third coordinate. It runs from 1 (``t = 0``) to 0
      (``t = T - 1``).
    - The result is multiplied by a fixed 3x3 matrix.
    - The output is shifted by (0.25, 0.45) and scaled back to pixel units.

    (Presumably this renders trajectories in a pseudo-3-D "time-stacked" view;
    verify against callers.)

    Args:
        pos: Tensor whose last axis holds (x, y) pixel coordinates.
        t: Time step; a scalar, or a tensor broadcastable with ``pos[..., :1]``.
        time_steps: Total number of time steps ``T`` (must be > 1).
        heigh: Image height ``H`` (name kept for backward compatibility).
        width: Image width ``W``.

    Returns:
        Tensor ``(..., 2)`` of projected pixel coordinates.
    """
    T, H, W = time_steps, heigh, width
    pos = torch.stack([pos[..., 0] / (W - 1), pos[..., 1] / (H - 1)], dim=-1)
    pos = pos - 0.5
    pos = pos * 0.25
    t = 1 - torch.ones_like(pos[..., :1]) * t / (T - 1)
    pos = torch.cat([pos, t], dim=-1)
    # Build the matrix with pos's dtype and device. A fixed float32 matrix makes
    # the matmul fail for float64 / half inputs.
    M = torch.tensor([
        [0.8, 0, 0.5],
        [-0.2, 1.0, 0.1],
        [0.0, 0.0, 0.0]
    ], dtype=pos.dtype, device=pos.device)
    pos = pos @ M.t()
    pos = pos[..., :2]
    pos[..., 0] += 0.25
    pos[..., 1] += 0.45
    pos[..., 0] *= (W - 1)
    pos[..., 1] *= (H - 1)
    return pos
def draw(pos, vis, col, height, width, radius=1):
    """Splat visible points into a ``height x width`` RGB image.

    Each visible point is spread onto nearby integer pixels. When
    ``radius > 1`` it uses ``get_radius_neighbors``; otherwise it uses
    ``get_cardinal_neighbors``. The (weighted color, weight) pairs are
    summed per pixel, and the color is then divided by the summed weight.

    Args:
        pos: Point positions; assumed ``(N, 2)`` pixel (x, y) coordinates (TODO confirm).
        vis: Visibility mask (converted with ``.bool()``) selecting rows of ``pos`` and ``col``.
        col: Per-point colors; assumed ``(N, 3)``, so color plus weight fills the
            4 accumulator channels (TODO confirm).
        height: Output image height ``H``.
        width: Output image width ``W``.
        radius: Splat radius in pixels.

    Returns:
        ``(frame, alpha)``. ``frame`` is ``(H, W, 3)`` and holds the
        weight-normalized color. ``alpha`` is an ``(H, W, 1)`` float mask that is
        1.0 where the accumulated weight is positive and 0.0 elsewhere.
    """
    keep = vis.bool()
    pts, colors = pos[keep], col[keep]
    if radius > 1:
        pts, colors = get_radius_neighbors(pts, colors, radius)
    else:
        pts, colors = get_cardinal_neighbors(pts, colors)
    # Drop splats outside the image before converting to integer pixel indices.
    xs, ys = pts[:, 0], pts[:, 1]
    inside = (xs >= 0) & (xs <= width - 1) & (ys >= 0) & (ys <= height - 1)
    pts = pts[inside].round().long()
    colors = colors[inside]
    # Accumulate (weighted color, weight) per pixel in a flattened H*W buffer.
    flat_idx = (pts[:, 1] * width + pts[:, 0]).view(-1, 1).expand(-1, 4)
    canvas = torch.zeros(height * width, 4, device=pos.device)
    canvas.scatter_add_(0, flat_idx, colors)
    canvas = canvas.view(height, width, 4)
    rgb, weight = canvas[..., :3], canvas[..., 3]
    covered = weight > 0
    # Divide the weighted color sum by the total weight to get a weighted average.
    rgb[covered] /= weight[covered][..., None]
    return rgb, covered[..., None].float()
def get_cardinal_neighbors(pos, col, eps=0.01):
    """Spread each point onto its four surrounding integer pixels with bilinear weights.

    Args:
        pos: Tensor ``(N, 2)`` of (x, y) positions.
        col: Tensor ``(N, C)`` of per-point colors.
        eps: Added to every 1-D weight so no corner receives exactly zero weight.

    Returns:
        ``(pos, col)``. ``pos`` is ``(4N, 2)`` corner coordinates, in the block
        order NW, SW, NE, SE (floor/ceil of x and y). ``col`` is ``(4N, C + 1)``:
        the color multiplied by the corner weight, followed by the weight itself.
    """
    x, y = pos[:, 0], pos[:, 1]
    x_lo, y_lo = x.floor(), y.floor()
    x_hi, y_hi = x_lo + 1, y_lo + 1
    # 1-D weights: closeness to the lower / upper integer coordinate on each axis.
    wy_lo, wy_hi = y_hi - y + eps, y - y_lo + eps
    wx_lo, wx_hi = x_hi - x + eps, x - x_lo + eps
    corners = (
        (x_lo, y_lo, wy_lo * wx_lo),  # north-west
        (x_lo, y_hi, wy_hi * wx_lo),  # south-west
        (x_hi, y_lo, wy_lo * wx_hi),  # north-east
        (x_hi, y_hi, wy_hi * wx_hi),  # south-east
    )
    corner_pos, corner_col = [], []
    for cx, cy, weight in corners:
        weight = weight[:, None]
        corner_pos.append(torch.stack([cx, cy], dim=-1))
        corner_col.append(torch.cat([weight * col, weight], dim=-1))
    return torch.cat(corner_pos, dim=0), torch.cat(corner_col, dim=0)
def get_radius_neighbors(pos, col, radius):
    """Spread each point onto the integer pixels within ``radius`` of its rounded position.

    Args:
        pos: Tensor ``(N, 2)`` of (x, y) positions.
        col: Tensor ``(N, C)`` of per-point colors (``draw`` uses C = 3).
        radius: Splat radius in pixels.

    Returns:
        ``(pos, col)``. ``pos`` is ``(N * K, 2)`` neighbor pixel coordinates on
        ``pos``'s device, where K is the number of integer offsets inside the
        disc. ``col`` is ``(N * K, C + 1)``: the color multiplied by a weight,
        followed by the weight. The weight decreases linearly from 1.01 at the
        center to 0.01 at distance ``radius``.
    """
    R = math.ceil(radius)
    center = torch.stack([pos[:, 0].round(), pos[:, 1].round()], dim=-1)
    # Integer (dx, dy) offsets on the (2R+1)^2 square. They are created on pos's
    # device; a hard-coded .cuda() broke CPU-only and multi-device use.
    steps = torch.arange(-R, R + 1, device=pos.device)
    offsets = torch.stack(
        [steps[None, :].expand(2 * R + 1, -1), steps[:, None].expand(-1, 2 * R + 1)], dim=-1
    ).view(-1, 2)
    # Keep only the offsets inside the disc.
    offsets = offsets[offsets[:, 0] ** 2 + offsets[:, 1] ** 2 <= radius ** 2]
    num_offsets = offsets.size(0)
    # Linear falloff with distance; +0.01 keeps rim pixels strictly positive.
    w = 1 - offsets.pow(2).sum(-1).sqrt() / radius + 0.01
    w = w[None].expand(pos.size(0), -1).reshape(-1)
    pos = (center.view(-1, 1, 2) + offsets.view(1, -1, 2)).view(-1, 2)
    channels = col.size(-1)
    col = col.view(-1, 1, channels).repeat(1, num_offsets, 1)
    col = col.view(-1, channels)
    col = torch.cat([col * w[:, None], w[:, None]], dim=-1)
    return pos, col
def get_rainbow_colors(size):
    """Return ``size`` RGB colors sampled evenly along matplotlib's "jet" colormap.

    Args:
        size: Number of colors to generate.

    Returns:
        Float32 tensor of shape ``(size, 3)``. The first color is the start of
        the colormap and, for ``size > 1``, the last color is its end.
    """
    col_map = colormaps["jet"]
    # Guard the denominator. With size == 1 the original computed 0 / 0 = NaN,
    # which matplotlib maps to the colormap's "bad" color, not jet's first color.
    col_range = np.arange(size) / max(size - 1, 1)
    col = torch.from_numpy(col_map(col_range)[..., :3]).float()
    col = col.view(-1, 3)
    return col
def spline_interpolation(x, length=10):
    """Upsample a ``(T, N, C)`` tensor along its first (time) axis with a cubic spline.

    Args:
        x: Tensor of shape ``(T, N, C)``. Each of the ``N * C`` series is
            interpolated independently over time. When ``length != 1``, SciPy's
            ``CubicSpline`` needs ``T >= 2``.
        length: Upsampling factor. The output has ``T * length`` evenly spaced
            steps from the first to the last original step. ``1`` returns ``x``
            unchanged.

    Returns:
        ``x`` itself when ``length == 1``. Otherwise a float32 tensor of shape
        ``(T * length, N, C)`` on the same device as ``x``. The original always
        moved the result to CUDA, which crashed on CPU-only machines and
        disagreed with the ``length == 1`` path.
    """
    if length == 1:
        return x
    # Import the submodule explicitly: plain `import scipy` does not expose
    # scipy.interpolate on older SciPy releases.
    from scipy.interpolate import CubicSpline

    T, N, C = x.shape
    # reshape (not view) tolerates non-contiguous input; detach allows grad tensors.
    series = x.reshape(T, -1).detach().cpu().numpy()
    original_time = np.arange(T)
    spline = CubicSpline(original_time, series)
    new_time = np.linspace(original_time[0], original_time[-1], T * length)
    out = torch.from_numpy(spline(new_time)).view(-1, N, C).float()
    return out.to(x.device)
def create_folder(path, verbose=False, exist_ok=True, safe=True):
    """Create the directory ``path``, including any missing parents.

    Args:
        path: Directory to create.
        verbose: Print a message when a folder is actually created.
        exist_ok: If False, an already-existing ``path`` is treated as a failure.
        safe: If True, failures return ``False``; if False, they raise.

    Returns:
        True if this call created the folder. False if nothing was created:
        either the folder already existed, or creation failed while ``safe``
        is True.

    Raises:
        OSError: Only when ``safe`` is False. It is raised as
            ``FileExistsError`` if ``path`` exists and ``exist_ok`` is False,
            or as the error from ``os.makedirs`` if creation fails.
    """
    if os.path.exists(path):
        if not exist_ok:
            if not safe:
                raise FileExistsError(f"Folder already exists: {path}")
            return False
        if os.path.isdir(path):
            # Already present, and that is allowed. Previously this fell through
            # to os.makedirs, whose FileExistsError made safe=False calls raise
            # even with exist_ok=True.
            return False
    try:
        os.makedirs(path)
    except OSError:
        # Narrowed from a bare `except:`, which also swallowed
        # KeyboardInterrupt / SystemExit and programming errors.
        if not safe:
            raise
        return False
    if verbose:
        print(f"Created folder: {path}")
    return True
def write_video_to_file(video, path, channels, fps=8):
    """Encode a float video with values in [0, 1] to an H.264 file.

    Args:
        video: Tensor of shape ``(T, C, H, W)`` when ``channels == "first"``,
            otherwise ``(T, H, W, C)``.
        path: Output file path. Its parent folder is created if needed.
        channels: ``"first"`` for channels-first input; any other value means
            channels-last.
        fps: Frame rate written to the file (default 8, the previously
            hard-coded value).

    Returns:
        The uint8 ``(T, H, W, C)`` tensor that was written.
    """
    create_folder(os.path.dirname(path))
    if channels == "first":
        video = video.permute(0, 2, 3, 1)
    # Round and clamp before the uint8 cast, matching write_frame. A raw
    # float->uint8 cast does not saturate values outside [0, 255].
    video = (video.detach().cpu() * 255.).round().clamp(0, 255).to(torch.uint8)
    torchvision.io.write_video(path, video, fps, "h264", options={"pix_fmt": "yuv420p", "crf": "23"})
    return video
def write_frame(frame, path, channels="first"):
    """Save one float image with values in [0, 1] to ``path`` using PIL.

    Args:
        frame: Tensor ``(C, H, W)`` when ``channels == "first"``, otherwise
            ``(H, W, C)`` or ``(H, W)``. Single-channel frames are saved as
            grayscale.
        path: Output image path. Its parent folder is created if needed, and
            the format is chosen by PIL from the extension.
        channels: ``"first"`` for channels-first input.
    """
    create_folder(os.path.dirname(path))
    # detach() so tensors that require grad can be converted to NumPy.
    frame = frame.detach().cpu().numpy()
    if channels == "first":
        frame = np.transpose(frame, (1, 2, 0))
    # PIL cannot build an image from an (H, W, 1) array; drop the singleton
    # channel so grayscale frames are saved as a 2-D image.
    if frame.ndim == 3 and frame.shape[-1] == 1:
        frame = frame[..., 0]
    frame = np.clip(np.round(frame * 255), 0, 255).astype(np.uint8)
    Image.fromarray(frame).save(path)
def write_video_to_folder(video, path, channels, zero_padded, ext):
    """Write each frame of ``video`` as an image file ``<index>.<ext>`` inside ``path``.

    Args:
        video: Sequence of frames indexed along axis 0 (each passed to ``write_frame``).
        path: Destination folder. It is created if needed.
        channels: Channel layout forwarded to ``write_frame``.
        zero_padded: If True, indices are left-padded with zeros to the number
            of digits in the frame count (e.g. 10 frames -> ``00`` .. ``09``).
        ext: File extension for the images.
    """
    create_folder(path)
    digits = len(str(video.shape[0])) if zero_padded else 0
    for index, frame in enumerate(video):
        file_name = f"{str(index).zfill(digits)}.{ext}"
        write_frame(frame, os.path.join(path, file_name), channels)
def write_video(video, path, channels="first", zero_padded=True, ext="png", dtype="torch"):
    """Write a video either to a single ``.mp4`` file or to a folder of images.

    Args:
        video: Video tensor with frames on axis 0, or a NumPy array when
            ``dtype == "numpy"``.
        path: Destination. A path ending in ``.mp4`` is encoded as one video
            file; any other path is treated as a folder with one image per frame.
        channels: ``"first"`` for channels-first frames.
        zero_padded: Zero-pad frame indices in file names (folder output only).
        ext: Image extension (folder output only).
        dtype: ``"numpy"`` if ``video`` must first be converted to a tensor.
    """
    frames = torch.from_numpy(video) if dtype == "numpy" else video
    if not path.endswith(".mp4"):
        write_video_to_folder(frames, path, channels, zero_padded, ext)
        return
    write_video_to_file(frames, path, channels)