RelightVid / misc_utils /flow_utils.py
aleafy's picture
Start fresh
0a63786
'''
Usage:
from misc_utils.flow_utils import RAFTFlow, load_image_as_tensor, warp_image, MyRandomPerspective, generate_sample
image = load_image_as_tensor('hamburger_pic.jpeg', image_size)
flow_estimator = RAFTFlow()
res = generate_sample(
image,
flow_estimator,
distortion_scale=distortion_scale,
)
f1 = res['input'][None]
f2 = res['target'][None]
flow = res['flow'][None]
f1_warp = warp_image(f1, flow)
show_image(f1_warp[0])
show_image(f2[0])
'''
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights
import numpy as np
def warp_image(image, flow, mode='bilinear'):
""" Warp an image using optical flow.
Args:
image (torch.Tensor): Input image tensor with shape (N, C, H, W).
flow (torch.Tensor): Optical flow tensor with shape (N, 2, H, W).
Returns:
warped_image (torch.Tensor): Warped image tensor with shape (N, C, H, W).
"""
# check shape
if len(image.shape) == 3:
image = image.unsqueeze(0)
if len(flow.shape) == 3:
flow = flow.unsqueeze(0)
if image.device != flow.device:
flow = flow.to(image.device)
assert image.shape[0] == flow.shape[0], f'Batch size of image and flow must be the same. Got {image.shape[0]} and {flow.shape[0]}.'
assert image.shape[2:] == flow.shape[2:], f'Height and width of image and flow must be the same. Got {image.shape[2:]} and {flow.shape[2:]}.'
# Generate a grid of sampling points
grid = torch.tensor(
np.array(np.meshgrid(range(image.shape[3]), range(image.shape[2]), indexing='xy')),
dtype=torch.float32, device=image.device
)[None]
grid = grid.permute(0, 2, 3, 1).repeat(image.shape[0], 1, 1, 1) # (N, H, W, 2)
grid += flow.permute(0, 2, 3, 1) # add optical flow to grid
# Normalize grid to [-1, 1]
grid[:, :, :, 0] = 2 * (grid[:, :, :, 0] / (image.shape[3] - 1) - 0.5)
grid[:, :, :, 1] = 2 * (grid[:, :, :, 1] / (image.shape[2] - 1) - 0.5)
# Sample input image using the grid
warped_image = F.grid_sample(image, grid, mode=mode, align_corners=True)
return warped_image
def resize_flow(flow, size):
"""
Resize optical flow tensor to a new size.
Args:
flow (torch.Tensor): Optical flow tensor with shape (B, 2, H, W).
size (tuple[int, int]): Target size as a tuple (H, W).
Returns:
flow_resized (torch.Tensor): Resized optical flow tensor with shape (B, 2, H, W).
"""
# Unpack the target size
H, W = size
# Compute the scaling factors
h, w = flow.shape[2:]
scale_x = W / w
scale_y = H / h
# Scale the optical flow by the resizing factors
flow_scaled = flow.clone()
flow_scaled[:, 0] *= scale_x
flow_scaled[:, 1] *= scale_y
# Resize the optical flow to the new size (H, W)
flow_resized = F.interpolate(flow_scaled, size=(H, W), mode='bilinear', align_corners=False)
return flow_resized
def check_consistency(flow1: torch.Tensor, flow2: torch.Tensor) -> torch.Tensor:
"""
Check the consistency of two optical flows.
flow1: (B, 2, H, W)
flow2: (B, 2, H, W)
if want the output to be forward flow, then flow1 is the forward flow and flow2 is the backward flow
return: (H, W)
"""
device = flow1.device
height, width = flow1.shape[2:]
kernel_x = torch.tensor([[0.5, 0, -0.5]]).unsqueeze(0).unsqueeze(0).to(device)
kernel_y = torch.tensor([[0.5], [0], [-0.5]]).unsqueeze(0).unsqueeze(0).to(device)
grad_x = torch.nn.functional.conv2d(flow1[:, :1], kernel_x, padding=(0, 1))
grad_y = torch.nn.functional.conv2d(flow1[:, 1:], kernel_y, padding=(1, 0))
motion_edge = (grad_x * grad_x + grad_y * grad_y).sum(dim=1).squeeze(0)
ax, ay = torch.meshgrid(torch.arange(width, device=device), torch.arange(height, device=device), indexing='xy')
bx, by = ax + flow1[:, 0], ay + flow1[:, 1]
x1, y1 = torch.floor(bx).long(), torch.floor(by).long()
x2, y2 = x1 + 1, y1 + 1
x1 = torch.clamp(x1, 0, width - 1)
x2 = torch.clamp(x2, 0, width - 1)
y1 = torch.clamp(y1, 0, height - 1)
y2 = torch.clamp(y2, 0, height - 1)
alpha_x, alpha_y = bx - x1.float(), by - y1.float()
a = (1.0 - alpha_x) * flow2[:, 0, y1, x1] + alpha_x * flow2[:, 0, y1, x2]
b = (1.0 - alpha_x) * flow2[:, 0, y2, x1] + alpha_x * flow2[:, 0, y2, x2]
u = (1.0 - alpha_y) * a + alpha_y * b
a = (1.0 - alpha_x) * flow2[:, 1, y1, x1] + alpha_x * flow2[:, 1, y1, x2]
b = (1.0 - alpha_x) * flow2[:, 1, y2, x1] + alpha_x * flow2[:, 1, y2, x2]
v = (1.0 - alpha_y) * a + alpha_y * b
cx, cy = bx + u, by + v
u2, v2 = flow1[:, 0], flow1[:, 1]
reliable = ((((cx - ax) ** 2 + (cy - ay) ** 2) < (0.01 * (u2 ** 2 + v2 ** 2 + u ** 2 + v ** 2) + 0.5)) & (motion_edge <= 0.01 * (u2 ** 2 + v2 ** 2) + 0.002)).float()
return reliable # (B, 1, H, W)
class RAFTFlow(torch.nn.Module):
'''
# Instantiate the RAFTFlow class
raft_flow = RAFTFlow(device='cuda')
# Load a pair of image frames as PyTorch tensors
img1 = torch.tensor(np.random.rand(3, 720, 1280), dtype=torch.float32)
img2 = torch.tensor(np.random.rand(3, 720, 1280), dtype=torch.float32)
# Compute optical flow between the two frames
(optional) image_size = (256, 256) or None
flow = raft_flow.compute_flow(img1, img2, image_size) # flow will be computed at the original image size if image_size is None
# this flow can be used to warp the second image to the first image
# Warp the second image using the flow
warped_img = warp_image(img2, flow)
'''
def __init__(self, *args):
"""
Args:
device (str): Device to run the model on ("cpu" or "cuda").
"""
super().__init__(*args)
weights = Raft_Large_Weights.DEFAULT
self.model = raft_large(weights=weights, progress=False)
self.model_transform = weights.transforms()
def forward(self, img1, img2, img_size=None):
"""
Compute optical flow between two frames using RAFT model.
Args:
img1 (torch.Tensor): First frame tensor with shape (B, C, H, W).
img2 (torch.Tensor): Second frame tensor with shape (B, C, H, W).
img_size (tuple): Size of the input images to be processed.
Returns:
flow (torch.Tensor): Optical flow tensor with shape (B, 2, H, W).
"""
original_size = img1.shape[2:]
# Preprocess the input frames
if img_size is not None:
img1 = TF.resize(img1, size=img_size, antialias=False)
img2 = TF.resize(img2, size=img_size, antialias=False)
img1, img2 = self.model_transform(img1, img2)
# Compute the optical flow using the RAFT model
with torch.no_grad():
list_of_flows = self.model(img1, img2)
flow = list_of_flows[-1]
if img_size is not None:
flow = resize_flow(flow, original_size)
return flow