RishitJavia committed on
Commit ec24258 · 1 Parent(s): 1a1a15e
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ models/movenet/movenet_tf/1/variables/variables.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,26 @@
+ import os
+ import cv2
+ import gradio as gr
+ from src import run_chain
+
+
+ def video_process(ref_video, test_video, crop_method):
+
+     run_chain.main(
+         ref_video,
+         test_video,
+         'output_video.mp4',
+         crop_method=crop_method
+     )
+
+     return 'output_video.mp4'
+
+
+ demo = gr.Interface(video_process,
+                     inputs=[gr.Video(label='Reference Video'), gr.Video(label='Test Video'),
+                             gr.Radio(["YOLO", "Tracker"], label="Crop Method")],
+                     outputs=[gr.PlayableVideo()]
+                     )
+
+ if __name__ == "__main__":
+     demo.launch()
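A usage note (not part of the commit): the interface above forwards the two uploaded clips and the selected crop method to run_chain.main and returns the rendered output path. A minimal sketch of calling the same entry point outside Gradio, assuming hypothetical local files ref.mp4 and test.mp4 and an environment built from requirements.txt:

    # Illustrative only: calls the function the Gradio UI wraps.
    # 'ref.mp4' and 'test.mp4' are hypothetical local files.
    from app import video_process

    output_path = video_process('ref.mp4', 'test.mp4', 'YOLO')
    print(output_path)  # 'output_video.mp4', written to the working directory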
models/movenet/movenet_tf/1/saved_model.pb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ea126ce61a8d11321a5523a5f0b07bb22184a848d193df7a5d0a42068f0ac15f
+ size 10954887
models/movenet/movenet_tf/1/variables/variables.data-00000-of-00001 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8bfcfc4f2267cbc32327ad9433ba93ee37b58cbf561b5b45bf9ed37f98215295
+ size 25390986
models/movenet/movenet_tf/1/variables/variables.index ADDED
Binary file (6.09 kB)
 
models/yolo/yolov5x.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9f27a794fa0308e2606f90565164571d6f0a0ba18a3ce2e5e5d323b71c157859
+ size 174114333
requirements.txt ADDED
@@ -0,0 +1,140 @@
+ absl-py==1.4.0
+ aiofiles==23.2.1
+ altair
+ annotated-types==0.5.0
+ anyio==4.0.0
+ astunparse==1.6.3
+ attrs==23.1.0
+ boto3==1.28.39
+ botocore==1.31.39
+ cachetools==5.3.1
+ certifi==2022.12.7
+ chardet==4.0.0
+ charset-normalizer==3.2.0
+ click==8.1.7
+ cmake==3.27.2
+ contourpy==1.1.0
+ cycler==0.10.0
+ decorator==4.4.2
+ exceptiongroup==1.1.3
+ fastapi
+ ffmpy==0.3.1
+ filelock
+ fire==0.5.0
+ flatbuffers==23.5.26
+ fonttools==4.42.1
+ fsspec==2023.6.0
+ gast==0.4.0
+ gitdb==4.0.10
+ GitPython==3.1.33
+ google-auth==2.22.0
+ google-auth-oauthlib==1.0.0
+ google-pasta==0.2.0
+ gradio
+ gradio_client
+ grpcio==1.57.0
+ h11==0.14.0
+ h5py==3.9.0
+ httpcore==0.17.3
+ httpx==0.24.1
+ huggingface-hub
+ idna==2.10
+ imageio==2.31.2
+ imageio-ffmpeg==0.4.8
+ importlib-resources==6.0.1
+ Jinja2==3.1.2
+ jmespath==1.0.1
+ joblib==1.3.2
+ jsonschema==4.19.0
+ jsonschema-specifications==2023.7.1
+ keras==2.13.1
+ kiwisolver==1.4.5
+ libclang==16.0.6
+ lit==16.0.6
+ Markdown==3.4.4
+ MarkupSafe==2.1.3
+ matplotlib==3.7.2
+ moviepy==1.0.3
+ mpmath==1.3.0
+ networkx==3.1
+ numpy==1.24.3
+ nvidia-cublas-cu11==11.10.3.66
+ nvidia-cuda-cupti-cu11==11.7.101
+ nvidia-cuda-nvrtc-cu11==11.7.99
+ nvidia-cuda-runtime-cu11==11.7.99
+ nvidia-cudnn-cu11==8.5.0.96
+ nvidia-cufft-cu11==10.9.0.58
+ nvidia-curand-cu11==10.2.10.91
+ nvidia-cusolver-cu11==11.4.0.1
+ nvidia-cusparse-cu11==11.7.4.91
+ nvidia-nccl-cu11==2.14.3
+ nvidia-nvtx-cu11==11.7.91
+ oauthlib==3.2.2
+ opencv-python==4.8.0.76
+ opencv-python-headless==4.8.0.76
+ opt-einsum==3.3.0
+ orjson==3.9.5
+ packaging==23.1
+ pandas==2.1.0
+ Pillow==9.5.0
+ proglog==0.1.10
+ protobuf==4.24.2
+ psutil==5.9.5
+ py-cpuinfo==9.0.0
+ pyasn1==0.5.0
+ pyasn1-modules==0.3.0
+ pybboxes==0.1.6
+ pydantic
+ pydantic_core
+ pydub==0.25.1
+ pyparsing==2.4.7
+ python-dateutil==2.8.2
+ python-dotenv==1.0.0
+ python-multipart==0.0.6
+ pytz==2023.3
+ PyYAML==6.0.1
+ referencing==0.30.2
+ requests==2.31.0
+ requests-oauthlib==1.3.1
+ requests-toolbelt==1.0.0
+ roboflow==1.1.4
+ rpds-py==0.10.0
+ rsa==4.9
+ s3transfer==0.6.2
+ sahi==0.11.14
+ scikit-learn==1.3.0
+ scipy==1.11.2
+ seaborn==0.12.2
+ semantic-version==2.10.0
+ shapely==2.0.1
+ six==1.16.0
+ smmap==5.0.0
+ sniffio==1.3.0
+ starlette==0.27.0
+ supervision==0.14.0
+ sympy==1.12
+ tensorboard==2.13.0
+ tensorboard-data-server==0.7.1
+ tensorflow
+ tensorflow-estimator==2.13.0
+ tensorflow-hub==0.14.0
+ tensorflow-io-gcs-filesystem==0.33.0
+ termcolor==2.3.0
+ terminaltables==3.1.10
+ thop==0.1.1.post2209072238
+ threadpoolctl==3.2.0
+ toolz==0.12.0
+ torch==2.0.1
+ torchvision==0.15.2
+ tqdm==4.66.1
+ triton==2.0.0
+ typing_extensions
+ tzdata==2023.3
+ ultralytics==8.0.168
+ urllib3==1.26.16
+ uvicorn==0.23.2
+ websockets==11.0.3
+ Werkzeug==2.3.7
+ wget==3.2
+ wrapt==1.15.0
+ yolov5==7.0.12
src/correspondence.py ADDED
@@ -0,0 +1,139 @@
+ # from __future__ import absolute_import, division
+ from collections import defaultdict
+
+ import numpy as np
+
+ from src.evaluation import AngleMAE, MAE
+
+
+ class DTW:
+     """Class for finding the correspondence between two sequences of
+     keypoints detected from videos.
+
+     Parameters:
+         cost_weightage : dictionary containing the weightage of MAE and
+             AngleMAE used to compute the cost for DTW
+     """
+
+     def __init__(self, cost_weightage=None):
+         self.anglemae_calculator = AngleMAE()
+
+         self.cost_weightage = cost_weightage
+
+     def cost(self, x, y):
+         """Computes the cost for a pair of keypoint sets through MAE and/or
+         AngleMAE.
+
+         Args:
+             x: A numpy array representing the keypoints of a reference frame
+             y: A numpy array representing the keypoints of a test frame
+
+         Returns:
+             A float value representing the weighted MAE / angle-MAE score
+             between the reference and test pose
+         """
+         cost = 0
+         if self.cost_weightage['angle_mae']:
+             mae_angle = self.anglemae_calculator.mean_absolute_error(x, y) * \
+                 self.cost_weightage['angle_mae']
+             cost += mae_angle
+
+         if self.cost_weightage['mae']:
+             mae = MAE.mean_absolute_error(x, y) * self.cost_weightage['mae']
+             cost += mae
+
+         return cost
+
+     def find_correspondence(self, x, y):
+         """Applies the Dynamic Time Warping algorithm to find the
+         correspondence between a reference video and a test video.
+
+         Args:
+             x: A numpy array representing the keypoints of a reference video
+             y: A numpy array representing the keypoints of a test video
+
+         Returns:
+             A tuple (ref_frame_indices, test_frame_indices, costs) where
+                 ref_frame_indices: A list of indices of reference video frames
+                 test_frame_indices: A list of indices of test video frames
+                 costs: A list of costs between reference and test keypoints
+         """
+
+         x = np.asanyarray(x, dtype='float')
+         y = np.asanyarray(y, dtype='float')
+         len_x, len_y = len(x), len(y)
+
+         dtw_mapping = defaultdict(lambda: (float('inf'),))
+         similarity_score = defaultdict(lambda: float('inf'))
+
+         dtw_mapping[0, 0] = (0, 0, 0)
+
+         similarity_score[0, 0] = 0
+
+         for i in range(1, len_x + 1):
+             for j in range(1, len_y + 1):
+                 dt = self.cost(x[i - 1], y[j - 1])
+
+                 dtw_mapping[i, j] = min(
+                     (dtw_mapping[i - 1, j][0] + dt, i - 1, j),
+                     (dtw_mapping[i, j - 1][0] + dt, i, j - 1),
+                     (dtw_mapping[i - 1, j - 1][0] + dt, i - 1, j - 1),
+                     key=lambda a: a[0]
+                 )
+                 similarity_score[i, j] = dt
+
+         path = []
+         i, j = len_x, len_y
+         while not (i == j == 0):
+             path.append(
+                 (i - 1, j - 1, dtw_mapping[i - 1, j - 1][0],
+                  similarity_score[i - 1, j - 1])
+             )
+             i, j = dtw_mapping[i, j][1], dtw_mapping[i, j][2]
+         path.reverse()
+
+         ref_frame_idx, test_frame_idx, _, costs = DTW.get_ref_test_mapping(path)
+         return (ref_frame_idx, test_frame_idx, costs)
+
+     @staticmethod
+     def get_ref_test_mapping(paths):
+         """Maps each reference frame to its best-matching test frame from a
+         DTW warping path.
+
+         Args:
+             paths: list of tuples (i, j, ps, c) where i and j are the indices
+                 of the x and y time series respectively that correspond, ps is
+                 the cumulative cost and c is the cost between the two instances
+
+         Returns:
+             A tuple (ref_frame_indices, test_frame_indices, path_score, costs)
+             where
+                 ref_frame_indices: A list of indices of reference video frames
+                 test_frame_indices: A list of indices of test video frames
+                 path_score: A list of path scores calculated by DTW between
+                     reference and test keypoints
+                 costs: A list of costs between reference and test keypoints
+         """
+
+         path = np.array(paths)
+         ref_2_test = {}
+         for i in range(path.shape[0]):
+             ref_2_test_val = ref_2_test.get(path[i][0], [])
+             ref_2_test_val.append([path[i][1], path[i][2], path[i][3]])
+             ref_2_test[path[i][0]] = ref_2_test_val
+         ref_frames = []
+         test_frames = []
+         path_score = []
+         costs = []
+
+         for ref_frame, test_frame_list in ref_2_test.items():
+             ref_frames.append(int(ref_frame))
+             test_frame_list.sort(key=lambda x: x[2])
+             test_frames.append(int(test_frame_list[0][0]))
+             path_score.append(test_frame_list[0][1])
+             costs.append(test_frame_list[0][2])
+
+         return ref_frames, test_frames, path_score, costs
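A usage note (not part of the commit): a minimal sketch of running the DTW correspondence on synthetic keypoint sequences, mirroring the angle-only weighting that run_chain.py uses; the shapes and random values are assumptions purely for illustration:

    # Illustrative only: DTW over synthetic (n, 17, 2) keypoint sequences.
    import numpy as np
    from src.correspondence import DTW

    ref_kpts = np.random.rand(12, 17, 2)   # 12 reference frames
    test_kpts = np.random.rand(20, 17, 2)  # 20 test frames

    dtw = DTW(cost_weightage={'mae': 0, 'angle_mae': 1})
    ref_idx, test_idx, costs = dtw.find_correspondence(ref_kpts, test_kpts)
    # one best-matching test frame (and its cost) per reference frame
    print(list(zip(ref_idx, test_idx, costs))[:3])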
src/crop_video.py ADDED
@@ -0,0 +1,230 @@
+ import cv2
+ import numpy as np
+ import yolov5
+
+
+ class CropVideo:
+     """Base class for cropping a video frame-by-frame using various object
+     detection methods such as YOLO or cv2.Tracker.
+
+     Warning: This class should not be used directly.
+     Use derived classes instead.
+
+     Parameters:
+         method : name of the object detection method
+         model_path : path to the object detection model
+     """
+
+     def __init__(self, method=None):
+         self.method = method
+
+     def video_crop(self, video_frames):
+         """Crops the given list of frames by detecting the subject using
+         different methods such as YOLO or cv2.Tracker.
+
+         Args:
+             video_frames: A list of numpy arrays representing the input images
+
+         Returns:
+             A numpy array containing the cropped frames
+         """
+         raise NotImplementedError
+
+
+ class YOLOCrop(CropVideo):
+     """Class for cropping a video frame-by-frame using the YOLO object
+     detection method.
+
+     Parameters:
+         model_path : path to the object detection model
+     """
+
+     def __init__(self, method=None, model_path=None):
+         super().__init__('yolo')
+         self.model_path = model_path or '../data/models/yolo/yolov5x.pt'
+         self.load_model(self.model_path)
+
+     def load_model(self, model_path):
+         """Loads the object detection model and restricts it to the
+         'person' class."""
+         self.model = yolov5.load(model_path)
+         self.model.classes = 0
+
+     def get_yolo_bbox(self, frame):
+         """Runs YOLO object detection on an input image.
+
+         Args:
+             frame: A [height, width, 3] numpy array representing the input image
+
+         Returns:
+             A list containing the bounding box parameters
+             [x_min, y_min, x_max, y_max] of the largest detection
+         """
+
+         results = self.model(frame)
+         predictions = results.pred[0]
+
+         boxes = predictions[:, :4].numpy().astype(np.int32)
+         if len(boxes) == 0:
+             return []
+         elif len(boxes) == 1:
+             return list(boxes[0])
+         else:
+             area = []
+             for i in boxes:
+                 area.append(cv2.contourArea(np.array([[i[:2]], [i[2:]]])))
+             largest_bbox = boxes[np.argmax(np.array(area))]
+             return list(largest_bbox)
+
+     def video_crop(self, video_frames):
+         """Crops the given list of frames by detecting the subject with YOLO.
+
+         Args:
+             video_frames: A list of numpy arrays representing the input images
+
+         Returns:
+             A numpy array containing the cropped frames
+         """
+
+         x_width_start = []
+         y_height_start = []
+         x_width_end = []
+         y_height_end = []
+         frame_height, frame_width = 0, 0
+
+         widths = []
+         heights = []
+         for frame in video_frames:
+             frame_height, frame_width, _ = frame.shape
+             bbox = self.get_yolo_bbox(frame)
+
+             if len(bbox) == 0:
+                 continue
+             else:
+                 x_width_start.append(int(max(bbox[0] - 100, 0)))
+                 y_height_start.append(int(max(bbox[1] - 100, 0)))
+                 x_width_end.append(int(min(bbox[2] + 100, frame.shape[1])))
+                 y_height_end.append(int(min(bbox[3] + 100, frame.shape[0])))
+
+             widths.append(x_width_end[-1] - x_width_start[-1])
+             heights.append(y_height_end[-1] - y_height_start[-1])
+
+         width = np.percentile(np.array(widths), 95)
+         height = np.percentile(np.array(heights), 95)
+         box_len = int(max(width, height))
+
+         cropped_frames = []
+
+         for i in range(len(widths)):
+             frame = video_frames[i]
+             xs = x_width_start[i]
+             xe = x_width_start[i] + box_len
+             ys = y_height_start[i]
+             ye = y_height_start[i] + box_len
+
+             if ye > frame_height:
+                 ye = frame_height
+                 ys = max(0, ye - box_len)
+
+             if xe > frame_width:
+                 xe = frame_width
+                 xs = max(0, xe - box_len)
+
+             cropped = frame[int(ys): int(ye), int(xs): int(xe), :]
+             cropped_frames.append(np.array(cropped))
+
+         return np.array(cropped_frames)
+
+
+ class TrackerCrop(YOLOCrop):
+     def __init__(self, model_path=None):
+         super().__init__(method='yolo')
+         self.tracker = cv2.TrackerMIL.create()
+
+     @staticmethod
+     def expand_bbox(bbox, frame_shape):
+         """Expands the given bounding box by 50 pixels on each side.
+
+         Args:
+             bbox: A list [x, y, width, height] containing the bounding box
+                 parameters of the object
+             frame_shape: (height, width) of a frame
+         """
+         bbox[0] = max(bbox[0] - 50, 0)
+         bbox[1] = max(bbox[1] - 50, 0)
+         bbox[2] = min(bbox[2] + 50, frame_shape[1] - bbox[0] - 1)
+         bbox[3] = min(bbox[3] + 50, frame_shape[0] - bbox[1] - 1)
+
+     @staticmethod
+     def pad_bbox(crop_frame, box_len):
+         """Pads the given cropped frame up to box_len on each spatial axis.
+
+         Args:
+             crop_frame: A numpy array representing the cropped frame
+             box_len: An integer value representing the maximum of width and height
+
+         Returns:
+             A numpy array containing the cropped frame with padding
+         """
+         if box_len > crop_frame.shape[0] or box_len > crop_frame.shape[1]:
+             crop_frame = np.pad(
+                 crop_frame, pad_width=(
+                     (0, box_len - crop_frame.shape[0]),
+                     (0, box_len - crop_frame.shape[1]), (0, 0))
+             )
+         return crop_frame
+
+     @staticmethod
+     def clip_coordinates(x, y, box_len, frame_shape):
+         """Clips the (x, y) coordinates of the bounding box so that the crop
+         stays inside the frame.
+
+         Args:
+             x: x-coordinate of the bounding box
+             y: y-coordinate of the bounding box
+             box_len: An integer value representing the maximum of width and height
+             frame_shape: (height, width) of a frame
+
+         Returns:
+             (x, y) clipped coordinates
+         """
+         if x + box_len > frame_shape[1]:
+             diff = x + box_len - frame_shape[1]
+             x = max(0, x - diff)
+         if y + box_len > frame_shape[0]:
+             diff = y + box_len - frame_shape[0]
+             y = max(0, y - diff)
+
+         return (x, y)
+
+     def video_crop(self, video_frames):
+         """Crops the given list of frames by tracking the subject with
+         cv2.Tracker, initialised from a YOLO detection on the first frame.
+
+         Args:
+             video_frames: A list of numpy arrays representing the input images
+
+         Returns:
+             A numpy array containing the cropped frames
+         """
+
+         frame = video_frames[0]
+         bbox = self.get_yolo_bbox(frame)
+         TrackerCrop.expand_bbox(bbox, frame.shape)
+         self.tracker.init(frame, bbox)
+         output_frame_list = []
+         for frame in video_frames:
+             _, bbox = self.tracker.update(frame)
+             x, y, w, h = bbox
+             box_len = max(w, h)
+             x, y = TrackerCrop.clip_coordinates(x, y, box_len, frame.shape)
+             crop_frame = np.array(frame[y:y + box_len, x:x + box_len, :])
+             crop_frame = TrackerCrop.pad_bbox(crop_frame, box_len)
+             output_frame_list.append(crop_frame)
+
+         output_frame_array = np.array(output_frame_list)
+
+         return output_frame_array
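A usage note (not part of the commit): a sketch of the YOLO-based cropping path, assuming a hypothetical local clip ref.mp4 and passing the yolov5x.pt weights added in this commit explicitly, since the constructor's default points to ../data/models/yolo/yolov5x.pt:

    # Illustrative only: person-centred square crops from a clip.
    from src import utils
    from src.crop_video import YOLOCrop

    frames = utils.get_video_frames('ref.mp4')   # hypothetical input clip
    cropper = YOLOCrop(model_path='models/yolo/yolov5x.pt')
    cropped = cropper.video_crop(frames)         # ndarray of cropped frames
    print(cropped.shape)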
src/evaluation.py ADDED
@@ -0,0 +1,139 @@
+ import numpy as np
+ from sklearn import metrics
+
+
+ class AngleMAE:
+     """Class for the AngleMAE evaluation metric, which computes the mean
+     absolute error between the joint angles of the reference and test video
+     keypoints.
+
+     The keypoint arrays passed to mean_absolute_error are:
+         ref_keypoints : A numpy array containing keypoints representing a
+             pose in the reference video
+         test_keypoints : A numpy array containing keypoints representing a
+             pose in the test video
+     """
+
+     def __init__(self):
+
+         self.joints_array = np.array(
+             [[11, 5, 7],
+              [12, 6, 8],
+              [6, 8, 10],
+              [5, 7, 9],
+              [11, 12, 14],
+              [12, 11, 13],
+              [12, 14, 16],
+              [11, 13, 15],
+              [5, 11, 13]]
+         )
+
+         self.joints_dict = {
+             'left_shoulder_joint': ['left_hip', 'left_shoulder', 'left_elbow'],
+             'right_shoulder_joint': ['right_hip', 'right_shoulder',
+                                      'right_elbow'],
+             'right_elbow_joint': ['right_shoulder', 'right_elbow',
+                                   'right_wrist'],
+             'left_elbow_joint': ['left_shoulder', 'left_elbow', 'left_wrist'],
+             'right_hip_joint': ['left_hip', 'right_hip', 'right_knee'],
+             'left_hip_joint': ['right_hip', 'left_hip', 'left_knee'],
+             'right_knee_joint': ['right_hip', 'right_knee', 'right_ankle'],
+             'left_knee_joint': ['left_hip', 'left_knee', 'left_ankle'],
+             'waist_joint': ['left_shoulder', 'left_hip', 'left_knee']
+         }
+
+         self.angle_mae_joints_weightage_array = [1, 1, 1, 1, 1, 1, 1, 1, 1]
+
+     def mean_absolute_error(self, ref_keypoints, test_keypoints) -> float:
+         """Calculates the weighted MAE over the joint angles between the
+         reference and test frames.
+
+         Args:
+             ref_keypoints: ndarray of shape (17, 2) containing reference frame
+                 x, y coordinates
+             test_keypoints: ndarray of shape (17, 2) containing test frame
+                 x, y coordinates
+
+         Returns:
+             MAE: A float value representing the angle-based MAE
+         """
+
+         ref_angle = self.calculate_angle_atan2(ref_keypoints)
+         test_angle = self.calculate_angle_atan2(test_keypoints)
+
+         diff = np.abs(ref_angle - test_angle)
+
+         mae = np.sum(diff * self.angle_mae_joints_weightage_array) / sum(
+             self.angle_mae_joints_weightage_array
+         )
+
+         return mae
+
+     def calculate_angle_atan2(self, kpts):
+         """Calculates the angle of each joint in self.joints_array.
+
+         Args:
+             kpts: ndarray of shape (17, 2) containing x, y coordinates
+
+         Returns:
+             joints_angle_array: A float array of joint angles in degrees
+         """
+         a = np.zeros((9, 2))
+         b = np.zeros((9, 2))
+         c = np.zeros((9, 2))
+
+         for i, j in enumerate(self.joints_array):
+             a[i] = kpts[j[0]]
+             b[i] = kpts[j[1]]
+             c[i] = kpts[j[2]]
+
+         vector_b_a = b - a
+         vector_b_c = b - c
+
+         angle_0 = np.arctan2(
+             vector_b_a[:, 1],
+             vector_b_a[:, 0]
+         )
+
+         angle_2 = np.arctan2(
+             vector_b_c[:, 1],
+             vector_b_c[:, 0]
+         )
+
+         determinant = vector_b_a[:, 0] * vector_b_c[:, 1] - \
+             vector_b_a[:, 1] * vector_b_c[:, 0]
+
+         angle_diff = (angle_0 - angle_2)
+
+         angle = np.degrees(angle_diff)
+         joints_angle_array = angle * (determinant < 0) + (360 + angle) * (
+             determinant > 0)
+
+         return joints_angle_array % 360
+
+
+ class MAE:
+     def __init__(self):
+         pass
+
+     @staticmethod
+     def mean_absolute_error(ref_keypoints, test_keypoints):
+         """Calculates the MAE of the given keypoints between the reference
+         and test frames.
+
+         Args:
+             ref_keypoints: ndarray of shape (17, 2) containing reference frame
+                 x, y coordinates
+             test_keypoints: ndarray of shape (17, 2) containing test frame
+                 x, y coordinates
+
+         Returns:
+             MAE: A float value representing the MAE
+         """
+         return metrics.mean_absolute_error(
+             ref_keypoints.flatten(),
+             test_keypoints.flatten(),
+         )
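A usage note (not part of the commit): a sketch of both metrics on a single frame pair; the keypoints are random (17, 2) arrays used purely for shape:

    # Illustrative only: angle-based and coordinate-based MAE for one frame pair.
    import numpy as np
    from src.evaluation import AngleMAE, MAE

    ref = np.random.rand(17, 2)
    test = np.random.rand(17, 2)

    angle_mae = AngleMAE().mean_absolute_error(ref, test)  # degrees over 9 joints
    coord_mae = MAE.mean_absolute_error(ref, test)          # raw coordinate MAE
    print(angle_mae, coord_mae)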
src/fastdtw_file.py ADDED
@@ -0,0 +1,86 @@
+ from __future__ import absolute_import, division
+
+ from collections import defaultdict
+
+ import numpy as np
+ from pose_estimation import joints_angle
+
+
+ def fastdtw(x, y, radius=1, dist=None, method={}, angle_comp_method=''):
+     '''Return the approximate distance between two time series with O(N)
+     time and memory complexity.
+
+     Parameters
+     ----------
+     x : array_like
+         input array 1
+     y : array_like
+         input array 2
+     radius : int
+         size of neighborhood when expanding the path. A higher value will
+         increase the accuracy of the calculation but also increase time
+         and memory consumption. A radius equal to the size of x and y
+         will yield an exact dynamic time warping calculation.
+     dist : function or int
+         The method for calculating the distance between x[i] and y[j]. If
+         dist is an int of value p > 0, then the p-norm will be used. If
+         dist is a function then dist(x[i], y[j]) will be used. If dist is
+         None then abs(x[i] - y[j]) will be used.
+
+     Returns
+     -------
+     distance : float
+         the approximate distance between the 2 time series
+     path : list
+         list of indexes for the inputs x and y
+     '''
+     x = np.asanyarray(x, dtype='float')
+     y = np.asanyarray(y, dtype='float')
+
+     return __dtw(x, y, None, angle_comp_method=angle_comp_method)
+
+
+ def cost(x, y, method, angle_comp_method=''):
+
+     mae_angle = joints_angle.mean_absolute_error(
+         x, y
+     )
+     cost = mae_angle
+
+     return cost
+
+
+ def __dtw(x, y, method, angle_comp_method=''):
+     len_x, len_y = len(x), len(y)
+
+     dtw_mapping = defaultdict(lambda: (float('inf'),))
+     similarity_score = defaultdict(lambda: float('inf'))
+
+     dtw_mapping[0, 0] = (0, 0, 0)
+
+     similarity_score[0, 0] = 0
+
+     for i in range(1, len_x + 1):
+         for j in range(1, len_y + 1):
+             dt = cost(
+                 x[i - 1], y[j - 1], method, angle_comp_method=angle_comp_method
+             )
+             dtw_mapping[i, j] = min(
+                 (dtw_mapping[i - 1, j][0] + dt, i - 1, j),
+                 (dtw_mapping[i, j - 1][0] + dt, i, j - 1),
+                 (dtw_mapping[i - 1, j - 1][0] + dt, i - 1, j - 1),
+                 key=lambda a: a[0]
+             )
+             similarity_score[i, j] = dt
+
+     path = []
+     i, j = len_x, len_y
+     while not (i == j == 0):
+         path.append(
+             (i - 1, j - 1, dtw_mapping[i - 1, j - 1][0],
+              similarity_score[i - 1, j - 1])
+         )
+         i, j = dtw_mapping[i, j][1], dtw_mapping[i, j][2]
+     path.reverse()
+     return (dtw_mapping[len_x, len_y][0], path)
src/pose_detection.py ADDED
@@ -0,0 +1,123 @@
+ import numpy as np
+ import tensorflow as tf
+ import tensorflow_hub as hub
+ import os
+
+
+ class PoseDetection:
+     """Base class for pose detection in images using various algorithms such
+     as MoveNet.
+
+     Warning: This class should not be used directly.
+     Use derived classes instead.
+
+     Parameters:
+         model_name : name of the pose detection method
+         input_size : image size of the input required by the model
+         model_path : path to the pose detection model
+     """
+
+     def __init__(self, model_name=None, input_size=None, model_path=None):
+         self.model_name = model_name
+         self.input_size = input_size
+         self.model_path = model_path
+         self.load_model()
+
+     def load_model(self):
+         """Abstract method to load the pose detection model."""
+         raise NotImplementedError
+
+     def preprocess_image(self, frame):
+         """Abstract method to preprocess an image before running pose
+         detection inference on it."""
+         raise NotImplementedError
+
+     def run_inference(self, frames):
+         """Abstract method to run pose detection inference on a list of
+         frames."""
+         raise NotImplementedError
+
+
+ class MovenetPoseDetection(PoseDetection):
+     """Class for pose detection in images using the MoveNet model.
+
+     Parameters:
+         model_name : name of the pose detection method
+         input_size : image size of the input required by the model
+         model_path : path to the pose detection model
+     """
+
+     def __init__(self, model_name=None, input_size=None, model_path=None):
+         model_name = model_name or 'movenet'
+         input_size = input_size or 256
+         model_path = model_path or 'models/movenet/movenet_tf/1'
+         super().__init__(model_name, input_size, model_path)
+
+     def load_model(self):
+         """Loads the pose detection model."""
+         if self.model_name == "movenet":
+             module = hub.load(self.model_path)
+             self.model = module.signatures['serving_default']
+
+     def preprocess_image(self, frame):
+         """Preprocesses an image into the format required for pose detection.
+
+         Args:
+             frame: A numpy array representing the input image
+
+         Returns:
+             A resized and padded input image as a tensor of 'int32' data type.
+         """
+         input_image = tf.expand_dims(frame, axis=0)
+         input_image = tf.image.resize_with_pad(
+             input_image, self.input_size, self.input_size
+         )
+         input_image = tf.cast(input_image, dtype=tf.int32)
+
+         return input_image
+
+     def run_inference(self, frames):
+         """Applies the pose detection model to a list of frames.
+
+         Args:
+             frames: A list of numpy arrays representing the input images
+
+         Returns:
+             A [n, 17, 2] float numpy array representing the keypoint
+             coordinates predicted by the MoveNet model for the list of images
+         """
+         keypoints = np.zeros((len(frames), 17, 2))
+         for i, frame in enumerate(frames):
+             frame = self.preprocess_image(frame)
+             keypoints[i, :, :] = self.run_model(frame)
+         return keypoints
+
+     def run_model(self, input_image):
+         """Runs detection on a single preprocessed input image.
+
+         Args:
+             input_image: A [1, height, width, 3] tensor representing the input
+                 image pixels. Note that the height/width should already be
+                 resized to match the expected input resolution of the model
+                 before being passed into this function.
+
+         Returns:
+             A [17, 2] float numpy array representing the keypoint (x, y)
+             coordinates predicted by the MoveNet model
+         """
+
+         outputs = self.model(input_image)
+         outputs = outputs['output_0'].numpy()
+         outputs[:, :, :, [0, 1]] = outputs[:, :, :, [1, 0]]
+
+         return outputs[0, 0, :, :2]
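A usage note (not part of the commit): a sketch of keypoint estimation with the MoveNet SavedModel added under models/movenet/movenet_tf/1; the frames are random uint8 images standing in for cropped video frames:

    # Illustrative only: MoveNet inference on dummy frames.
    import numpy as np
    from src.pose_detection import MovenetPoseDetection

    frames = np.random.randint(0, 255, size=(4, 480, 480, 3), dtype=np.uint8)
    movenet = MovenetPoseDetection()           # loads models/movenet/movenet_tf/1
    keypoints = movenet.run_inference(frames)  # shape (4, 17, 2), normalized x, y
    print(keypoints.shape)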
src/run_chain.py ADDED
@@ -0,0 +1,56 @@
+ import argparse
+ import os
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+
+ import cv2
+ from src.correspondence import DTW
+ from src.crop_video import YOLOCrop, TrackerCrop
+ from src.pose_detection import MovenetPoseDetection
+ from src import utils
+
+
+ def main(ref_video_path, test_video_path, output_video_path, crop_method="yolo"):
+
+     ref_frames = utils.get_video_frames(ref_video_path)
+     test_frames = utils.get_video_frames(test_video_path)
+
+     if crop_method == 'Tracker':
+         crop_object = TrackerCrop()
+     else:
+         crop_object = YOLOCrop()
+
+     ref_crop_frames = crop_object.video_crop(ref_frames)
+     test_crop_frames = crop_object.video_crop(test_frames)
+
+     movenet = MovenetPoseDetection()
+     ref_keypoints = movenet.run_inference(ref_crop_frames)
+     test_keypoints = movenet.run_inference(test_crop_frames)
+
+     dtw = DTW(cost_weightage={'mae': 0, 'angle_mae': 1})
+     ref_frame_idx, test_frame_idx, costs = dtw.find_correspondence(
+         ref_keypoints, test_keypoints
+     )
+
+     utils.Plot.plot_matching(
+         ref_crop_frames, test_crop_frames, ref_keypoints, test_keypoints,
+         ref_frame_idx, test_frame_idx, costs, output_video_path
+     )
+
+
+ if __name__ == "__main__":
+
+     parser = argparse.ArgumentParser(
+         description="run chained process",
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter
+     )
+     parser.add_argument('--Ref_video', type=str, required=True)
+     parser.add_argument('--Test_video', type=str, required=True)
+     parser.add_argument('--Output_path', type=str, required=True)
+
+     args = parser.parse_args()
+
+     try:
+         main(args.Ref_video, args.Test_video, args.Output_path)
+     except NameError:
+         print("Video file is not appropriate.")
+     except ValueError:
+         print("YOLO couldn't detect a bounding box for the given video.")
+     except cv2.error:
+         print(
+             "Cannot convert color from BGR to RGB. Please check the input frame."
+         )
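A usage note (not part of the commit): the __main__ block above exposes the pipeline as a CLI, roughly python -m src.run_chain --Ref_video ref.mp4 --Test_video test.mp4 --Output_path output_video.mp4 with hypothetical paths; the same pipeline can also be driven directly from Python:

    # Illustrative only; 'ref.mp4' and 'test.mp4' are hypothetical local files.
    from src import run_chain

    run_chain.main('ref.mp4', 'test.mp4', 'output_video.mp4', crop_method='YOLO')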
src/utils.py ADDED
@@ -0,0 +1,315 @@
+ import os
+
+ import cv2
+ import numpy as np
+ from moviepy.editor import ImageSequenceClip
+
+
+ def get_video_frames(video_path):
+     """Reads a video frame by frame.
+
+     Args:
+         video_path: path to a video file
+
+     Returns:
+         A list of numpy arrays representing the video frames (RGB)
+     """
+
+     vid_cap = cv2.VideoCapture(video_path)
+     frames = []
+     while True:
+         ret, frame = vid_cap.read()
+         if ret:
+             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+             frames.append(frame)
+         else:
+             break
+
+     return frames
+
+
+ class Plot:
+     """Class for plotting keypoints on video frames and creating a video."""
+
+     KEYPOINT_EDGE_INDS_TO_COLOR = {
+         (0, 1): "m",
+         (0, 2): "c",
+         (1, 3): "m",
+         (2, 4): "c",
+         (0, 5): "m",
+         (0, 6): "c",
+         (5, 7): "m",
+         (7, 9): "m",
+         (6, 8): "c",
+         (8, 10): "c",
+         (5, 6): "y",
+         (5, 11): "m",
+         (6, 12): "c",
+         (11, 12): "y",
+         (11, 13): "m",
+         (13, 15): "m",
+         (12, 14): "c",
+         (14, 16): "c",
+     }
+
+     @staticmethod
+     def add_keypoints_to_image(images_array, keypoints_list):
+         """Adds keypoints to the images.
+
+         Args:
+             images_array: array of images on which to draw the keypoints
+             keypoints_list: list of keypoints
+
+         Returns:
+             A numpy array of images overlaid with keypoints
+         """
+
+         output_overlay_array = images_array.astype(np.int32)
+         output_overlay_list = Plot.draw_prediction_on_image(
+             output_overlay_array, keypoints_list
+         )
+         return np.array(output_overlay_list)
+
+     @staticmethod
+     def draw_prediction_on_image(image_list, keypoints_list):
+         """Draws the keypoint predictions on the images.
+
+         Args:
+             image_list: A numpy array with shape [n, height, width, channel]
+                 representing the pixel values of the input images, where n is
+                 the number of images.
+             keypoints_list: A numpy array with shape [n, 17, 2] representing
+                 the coordinates of 17 keypoints, where n is the number of
+                 images.
+
+         Returns:
+             A numpy array with shape [n, out_height, out_width, channel]
+             representing the list of images overlaid with keypoint predictions.
+         """
+         height, width, channel = image_list[0].shape
+
+         keypoint_locs, keypoint_edges = Plot._keypoints_and_edges_for_display(
+             keypoints_list,
+             height, width
+         )
+         for img_i in range(keypoint_locs.shape[0]):
+             for edge in keypoint_edges[img_i]:
+                 image = cv2.line(
+                     image_list[img_i], (int(edge[0]), int(edge[1])),
+                     (int(edge[2]), int(edge[3])), color=(0, 0, 255), thickness=3
+                 )
+
+             for center_x, center_y in keypoint_locs[img_i]:
+                 image = cv2.circle(
+                     image_list[img_i], (int(center_x), int(center_y)), radius=5,
+                     color=(255, 0, 0), thickness=-1
+                 )
+
+             image_list[img_i] = image
+
+         return image_list
+
+     @staticmethod
+     def _keypoints_and_edges_for_display(keypoints_list, height, width):
+         """Returns keypoint and edge coordinates for visualization.
+
+         Args:
+             keypoints_list: A numpy array with shape [n, 17, 2] representing
+                 the normalized keypoint coordinates returned by the MoveNet
+                 model.
+             height: height of the image in pixels.
+             width: width of the image in pixels.
+
+         Returns:
+             A tuple (kpts_absolute_xy, edges_xy) containing:
+             * an array with shape [n, 17, 2] representing the coordinates of
+               all keypoints in the n images;
+             * an array with shape [n, 18, 4] representing the coordinates of
+               all skeleton edges in the n images.
+         """
+         kpts_x = width * keypoints_list[:, :, 0]
+         kpts_y = height * keypoints_list[:, :, 1]
+
+         edge_pair = np.array(list(Plot.KEYPOINT_EDGE_INDS_TO_COLOR.keys()))
+         kpts_absolute_xy = np.stack(
+             [kpts_x, kpts_y], axis=-1
+         )
+
+         x_start = kpts_x[:, edge_pair[:, 0]]
+         y_start = kpts_y[:, edge_pair[:, 0]]
+         x_end = kpts_x[:, edge_pair[:, 1]]
+         y_end = kpts_y[:, edge_pair[:, 1]]
+
+         edges = np.stack([x_start, y_start, x_end, y_end], axis=2)
+         return kpts_absolute_xy, edges
+
+     @staticmethod
+     def resize_and_concat(ref_image_list, test_image_list):
+         """Pads either the reference frames or the test frames so that both
+         have the same shape, then merges the frames side by side.
+
+         Args:
+             ref_image_list: A list of numpy arrays representing reference
+                 video frames
+             test_image_list: A list of numpy arrays representing test video
+                 frames
+
+         Returns:
+             concat_img_list: A list of numpy arrays representing the merged
+                 video frames
+         """
+
+         def pad_image(image_list, pad_axis, pad_len, odd_len_diff):
+             """Pads the given number of pixels onto image_list along the
+             given axis.
+
+             Args:
+                 image_list: A list of numpy arrays representing video frames
+                 pad_axis: axis to pad (0 for height, 1 for width)
+                 pad_len: number of pixels to pad on either side of the frame
+                 odd_len_diff: 1 if the difference between reference and test
+                     is odd, else 0
+
+             Returns:
+                 padded video frames
+             """
+             if pad_axis == 0:
+                 return np.pad(
+                     image_list, (
+                         (0, 0), (pad_len, pad_len + odd_len_diff), (0, 0),
+                         (0, 0)), 'constant',
+                     constant_values=(0)
+                 )
+             elif pad_axis == 1:
+                 return np.pad(
+                     image_list, (
+                         (0, 0), (0, 0), (pad_len, pad_len + odd_len_diff),
+                         (0, 0)), 'constant',
+                     constant_values=(0)
+                 )
+
+         ref_height, ref_width, _ = ref_image_list[0].shape
+         test_height, test_width, _ = test_image_list[0].shape
+
+         pad_height = abs(test_height - ref_height) // 2
+         odd_height_diff = (test_height - ref_height) % 2
+
+         if ref_height < test_height:
+             ref_image_list = pad_image(
+                 ref_image_list, 0, pad_height, odd_height_diff
+             )
+         elif ref_height > test_height:
+             test_image_list = pad_image(
+                 test_image_list, 0, pad_height, odd_height_diff
+             )
+
+         pad_width = abs(test_width - ref_width) // 2
+         odd_width_diff = (test_width - ref_width) % 2
+
+         if ref_width < test_width:
+             ref_image_list = pad_image(
+                 ref_image_list, 1, pad_width, odd_width_diff
+             )
+         elif ref_width > test_width:
+             test_image_list = pad_image(
+                 test_image_list, 1, pad_width, odd_width_diff
+             )
+
+         concat_img_list = np.concatenate(
+             (ref_image_list, test_image_list), axis=2
+         )
+         return concat_img_list
+
+     @staticmethod
+     def overlay_score_on_images(image_list, scores):
+         """Writes the score onto each image in the given list.
+
+         Args:
+             image_list: A numpy array of video frames
+             scores: A list of scores between reference and test keypoints
+
+         Returns:
+             A numpy array of frames with the scores overlaid on them
+         """
+         for i in range(len(image_list)):
+             image = image_list[i, :, :, :]
+
+             txt = f"Score : {scores[i]}"
+
+             image = cv2.putText(
+                 image, txt, (5, 10), cv2.FONT_HERSHEY_SIMPLEX,
+                 0.4, (255, 0, 0), 1, cv2.LINE_AA
+             )
+
+             image_list[i, :, :, :] = image
+
+         return image_list
+
+     @staticmethod
+     def plot_matching(
+         ref_frames, test_frames, ref_keypoints, test_keypoints, ref_frames_idx,
+         test_frames_idx, costs, output_path
+     ):
+         """Creates a video of the matched reference and test frames with
+         keypoints overlaid on them.
+
+         Args:
+             ref_frames: A numpy array of reference video frames
+             test_frames: A numpy array of test video frames
+             ref_keypoints: A numpy array of reference video keypoints
+             test_keypoints: A numpy array of test video keypoints
+             ref_frames_idx: A list of reference frame indices
+             test_frames_idx: A list of test frame indices
+             costs: A list of scores between reference and test keypoints
+             output_path: path at which the output video is stored
+         """
+         if (ref_frames_idx is not None and test_frames_idx is not None
+                 and len(ref_frames_idx) == len(test_frames_idx)):
+
+             ref_frames = ref_frames[ref_frames_idx]
+             test_frames = test_frames[test_frames_idx]
+
+             ref_keypoints = ref_keypoints[ref_frames_idx]
+             test_keypoints = test_keypoints[test_frames_idx]
+
+         if costs is None:
+             costs = ['N/A'] * len(ref_frames)
+
+         display_line = [
+             f"Cost: {costs[i]} Ref frame: {ref_frames_idx[i]} Test "
+             f"frame: {test_frames_idx[i]}"
+             for i in range(len(ref_frames_idx))]
+
+         ref_image_list = Plot.add_keypoints_to_image(
+             ref_frames, ref_keypoints
+         )
+         test_image_list = Plot.add_keypoints_to_image(
+             test_frames, test_keypoints
+         )
+
+         comparison_img_list = Plot.resize_and_concat(
+             ref_image_list, test_image_list
+         )
+
+         comparison_img_list = Plot.overlay_score_on_images(
+             comparison_img_list, display_line
+         )
+
+         video = ImageSequenceClip(list(comparison_img_list), fps=5)
+         video.write_videofile(output_path, fps=5)
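A usage note (not part of the commit): get_video_frames returns the list of RGB frames that the cropping, pose-detection and plotting helpers above consume; a minimal sketch with a hypothetical clip:

    # Illustrative only; 'ref.mp4' is a hypothetical local file.
    from src import utils

    frames = utils.get_video_frames('ref.mp4')
    print(len(frames), frames[0].shape if frames else None)  # e.g. 120 (720, 1280, 3)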