File size: 7,050 Bytes
0a63786
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
'''
Usage:

from misc_utils.flow_utils import RAFTFlow, load_image_as_tensor, warp_image, MyRandomPerspective, generate_sample
image = load_image_as_tensor('hamburger_pic.jpeg', image_size)
flow_estimator = RAFTFlow()
res = generate_sample(
    image, 
    flow_estimator, 
    distortion_scale=distortion_scale,
)
f1 = res['input'][None]
f2 = res['target'][None]
flow = res['flow'][None]
f1_warp = warp_image(f1, flow)
show_image(f1_warp[0])
show_image(f2[0])
'''
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights
import numpy as np

def warp_image(image, flow, mode='bilinear'):
    """ Warp an image by sampling it at positions displaced by an optical flow.

    Output pixel (x, y) is read from ``image`` at
    ``(x + flow[:, 0, y, x], y + flow[:, 1, y, x])`` through
    ``F.grid_sample`` with ``align_corners=True``.

    Args:
        image (torch.Tensor): Input image tensor with shape (N, C, H, W).
            A 3-D tensor is treated as a batch of one.
        flow (torch.Tensor): Optical flow tensor with shape (N, 2, H, W),
            in pixels (channel 0 is added to x, channel 1 to y). A 3-D tensor
            is treated as a batch of one. It is moved to ``image.device``
            when the devices differ.
        mode (str): Interpolation mode forwarded to ``F.grid_sample``.
    Returns:
        warped_image (torch.Tensor): Warped image tensor with shape (N, C, H, W).
    Raises:
        AssertionError: If the batch sizes or the spatial sizes of ``image``
            and ``flow`` differ.
    """
    # Promote unbatched inputs to a batch of one
    if image.dim() == 3:
        image = image.unsqueeze(0)
    if flow.dim() == 3:
        flow = flow.unsqueeze(0)
    if image.device != flow.device:
        flow = flow.to(image.device)
    assert image.shape[0] == flow.shape[0], f'Batch size of image and flow must be the same. Got {image.shape[0]} and {flow.shape[0]}.'
    assert image.shape[2:] == flow.shape[2:], f'Height and width of image and flow must be the same. Got {image.shape[2:]} and {flow.shape[2:]}.'

    height, width = image.shape[2], image.shape[3]

    # grid_sample needs the grid in the same dtype as the image. Coordinates
    # are computed in at least float32 (so half-precision images do not lose
    # pixel positions) and cast to the image dtype only at the very end.
    sample_dtype = image.dtype if image.is_floating_point() else torch.float32
    calc_dtype = torch.promote_types(sample_dtype, torch.float32)

    # Pixel-coordinate grid built directly on the target device
    ys, xs = torch.meshgrid(
        torch.arange(height, dtype=calc_dtype, device=image.device),
        torch.arange(width, dtype=calc_dtype, device=image.device),
        indexing='ij',
    )
    base_grid = torch.stack((xs, ys), dim=-1)  # (H, W, 2), last dim is (x, y)
    grid = base_grid.unsqueeze(0) + flow.permute(0, 2, 3, 1).to(calc_dtype)  # (N, H, W, 2)

    # Normalize grid to [-1, 1]. A dimension of size 1 would give a zero
    # divisor (and NaN coordinates), so the divisor is clamped to 1.
    extent = torch.tensor(
        [max(width - 1, 1), max(height - 1, 1)],
        dtype=calc_dtype, device=image.device,
    )
    grid = 2 * (grid / extent - 0.5)

    # Sample input image using the grid
    warped_image = F.grid_sample(image, grid.to(sample_dtype), mode=mode, align_corners=True)

    return warped_image

def resize_flow(flow, size):
    """
    Rescale an optical flow field to a different spatial resolution.

    The displacement values are multiplied by the resize ratio (channel 0 by
    the width ratio, channel 1 by the height ratio) so they remain expressed
    in pixels of the target resolution, then the field is bilinearly
    interpolated to the target size. The input tensor is not modified.

    Args:
        flow (torch.Tensor): Optical flow tensor with shape (B, 2, H, W).
        size (tuple[int, int]): Target size as a tuple (H, W).

    Returns:
        flow_resized (torch.Tensor): Resized optical flow tensor with shape (B, 2, H, W),
            where H and W are the target height and width.
    """
    target_h, target_w = size
    source_h, source_w = flow.shape[2:]

    # Work on a copy so the caller's tensor keeps its original values
    rescaled = flow.clone()
    rescaled[:, 0].mul_(target_w / source_w)  # x displacement
    rescaled[:, 1].mul_(target_h / source_h)  # y displacement

    return F.interpolate(
        rescaled,
        size=(target_h, target_w),
        mode='bilinear',
        align_corners=False,
    )

def check_consistency(flow1: torch.Tensor, flow2: torch.Tensor) -> torch.Tensor:
    """
    Check the consistency of two optical flows.

    For every pixel, follow ``flow1`` to its target position, bilinearly read
    ``flow2`` there and follow it back. A pixel is marked reliable when the
    round trip ends close to where it started and the pixel does not lie on a
    motion edge of ``flow1``.

    Each batch element is processed independently, so batches larger than one
    yield one mask per sample.

    Args:
        flow1: (B, 2, H, W) flow whose reliability is evaluated.
        flow2: (B, 2, H, W) flow in the opposite direction.
            If the output should describe the forward flow, then flow1 is the
            forward flow and flow2 is the backward flow.

    Returns:
        Float mask of shape (B, 1, H, W): 1.0 where reliable, 0.0 elsewhere.
    """
    device = flow1.device
    height, width = flow1.shape[2:]

    # Central-difference kernels, created in flow1's dtype so conv2d also
    # accepts non-float32 flows.
    kernel_x = torch.tensor([[0.5, 0, -0.5]], dtype=flow1.dtype, device=device).unsqueeze(0).unsqueeze(0)
    kernel_y = torch.tensor([[0.5], [0], [-0.5]], dtype=flow1.dtype, device=device).unsqueeze(0).unsqueeze(0)
    # NOTE(review): only d(u)/dx and d(v)/dy enter the motion-edge term;
    # kept as in the original implementation.
    grad_x = torch.nn.functional.conv2d(flow1[:, :1], kernel_x, padding=(0, 1))
    grad_y = torch.nn.functional.conv2d(flow1[:, 1:], kernel_y, padding=(1, 0))

    motion_edge = grad_x * grad_x + grad_y * grad_y  # (B, 1, H, W)

    # Pixel coordinates (ax, ay) and where flow1 sends them (bx, by)
    ax, ay = torch.meshgrid(torch.arange(width, device=device), torch.arange(height, device=device), indexing='xy')
    u2, v2 = flow1[:, 0:1], flow1[:, 1:2]  # (B, 1, H, W)
    bx, by = ax + u2, ay + v2

    # Integer corners around the target position, clamped to the image
    x1, y1 = torch.floor(bx).long(), torch.floor(by).long()
    x2, y2 = x1 + 1, y1 + 1
    x1 = torch.clamp(x1, 0, width - 1)
    x2 = torch.clamp(x2, 0, width - 1)
    y1 = torch.clamp(y1, 0, height - 1)
    y2 = torch.clamp(y2, 0, height - 1)

    # Interpolation weights are measured from the clamped corner
    alpha_x, alpha_y = bx - x1.to(bx.dtype), by - y1.to(by.dtype)

    # Flatten flow2 spatially so each sample is read with its own indices
    # through gather; indexing flow2[:, c, y, x] with batched index tensors
    # would pair every sample with every other sample's indices.
    flow2_width = flow2.shape[3]
    flow2_flat = flow2[:, :2].reshape(flow2.shape[0], 2, -1)

    def _lookup(ys, xs):
        # Return both flow2 channels at integer positions (ys, xs): (B, 2, H, W)
        index = (ys * flow2_width + xs).reshape(ys.shape[0], 1, -1).expand(-1, 2, -1)
        return torch.gather(flow2_flat, 2, index).reshape(ys.shape[0], 2, height, width)

    # Bilinear interpolation of flow2 at (bx, by)
    top = (1.0 - alpha_x) * _lookup(y1, x1) + alpha_x * _lookup(y1, x2)
    bottom = (1.0 - alpha_x) * _lookup(y2, x1) + alpha_x * _lookup(y2, x2)
    back = (1.0 - alpha_y) * top + alpha_y * bottom
    u, v = back[:, 0:1], back[:, 1:2]

    # End point of the round trip
    cx, cy = bx + u, by + v

    round_trip_ok = ((cx - ax) ** 2 + (cy - ay) ** 2) < (0.01 * (u2 ** 2 + v2 ** 2 + u ** 2 + v ** 2) + 0.5)
    off_motion_edge = motion_edge <= 0.01 * (u2 ** 2 + v2 ** 2) + 0.002
    reliable = (round_trip_ok & off_motion_edge).float()

    return reliable # (B, 1, H, W)


class RAFTFlow(torch.nn.Module):
    '''
    Optical-flow estimator wrapping torchvision's pretrained ``raft_large``.

    # Instantiate the RAFTFlow class
    raft_flow = RAFTFlow(device='cuda')

    # Load a pair of image frames as batched PyTorch tensors (B, C, H, W)
    img1 = torch.rand(1, 3, 720, 1280, device='cuda')
    img2 = torch.rand(1, 3, 720, 1280, device='cuda')

    # Compute optical flow between the two frames
    (optional) image_size = (256, 256) or None
    flow = raft_flow(img1, img2, image_size) # flow will be computed at the original image size if image_size is None
    # this flow can be used to warp the second image to the first image

    # Warp the second image using the flow
    warped_img = warp_image(img2, flow)
    '''
    def __init__(self, *args, device=None):
        """
        Args:
            *args: Forwarded unchanged to ``torch.nn.Module.__init__``.
            device (str | torch.device | None): Device to move the module to
                ("cpu" or "cuda"). ``None`` leaves the module where
                torchvision created it.
        """
        super().__init__(*args)
        weights = Raft_Large_Weights.DEFAULT
        self.model = raft_large(weights=weights, progress=False)
        # A freshly built nn.Module is in training mode; this wrapper is only
        # used for inference. Calling .train() on the wrapper re-enables
        # training mode for the submodule.
        self.model.eval()
        self.model_transform = weights.transforms()
        if device is not None:
            self.to(device)

    def forward(self, img1, img2, img_size=None):
        """
        Compute optical flow between two frames using RAFT model.

        Args:
            img1 (torch.Tensor): First frame tensor with shape (B, C, H, W).
            img2 (torch.Tensor): Second frame tensor with shape (B, C, H, W).
            img_size (tuple): Optional size the frames are resized to before
                the model runs. When given, the predicted flow is resized back
                to the spatial size of ``img1``. ``None`` runs the model at
                the input resolution.
                # NOTE(review): RAFT presumably needs H and W divisible by 8,
                # and the weight transforms presumably expect values in
                # [0, 1] - confirm against the torchvision docs.

        Returns:
            flow (torch.Tensor): Optical flow tensor with shape (B, 2, H, W).
        """
        # Remember the input resolution so the flow can be mapped back to it
        original_size = img1.shape[2:]
        # Preprocess the input frames
        if img_size is not None:
            img1 = TF.resize(img1, size=img_size, antialias=False)
            img2 = TF.resize(img2, size=img_size, antialias=False)

        img1, img2 = self.model_transform(img1, img2)

        # Compute the optical flow using the RAFT model (inference only)
        with torch.no_grad():
            list_of_flows = self.model(img1, img2)
        # The model returns one prediction per refinement iteration; the last
        # entry is used.
        flow = list_of_flows[-1]

        if img_size is not None:
            flow = resize_flow(flow, original_size)

        return flow