"""KAPAO flash mob demo: run multi-person pose inference on a segment of the
Sydney 'Uptown Funk' flash-mob YouTube video and write the result to an MP4
(or GIF), or display it live."""
import sys
from pathlib import Path

import argparse
from pytube import YouTube
import os.path as osp
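
# NOTE: the project-local imports below (utils.*, models.*, val) need the KAPAO
# repository root on sys.path. Following the layout of the other demo scripts,
# this assumes the file sits one directory below the repository root.
FILE = Path(__file__).absolute()
sys.path.append(FILE.parents[1].as_posix())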

from utils.torch_utils import select_device, time_sync
from utils.general import check_img_size
from utils.datasets import LoadImages
from models.experimental import attempt_load

import torch
import cv2
import numpy as np
import yaml
from tqdm import tqdm
import imageio

from val import run_nms, post_process_batch

# demo video and its YouTube source
VIDEO_NAME = 'Crazy Uptown Funk Flashmob in Sydney for sydney domains campaign.mp4'
URL = 'https://youtu.be/1WLMahXDnuI'

# drawing / overlay settings
COLOR = (255, 0, 255)  # BGR
ALPHA = 0.5
SEG_THICK = 3  # pose segment line thickness
FPS_TEXT_SIZE = 2


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, default='data/coco-kp.yaml')
    parser.add_argument('--imgsz', type=int, default=448)
    parser.add_argument('--vid', type=str, default='', help='input video path (downloads the demo video if empty)')
    parser.add_argument('--weights', default='kapao_s_coco.pt')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or cpu')
    parser.add_argument('--half', action='store_true')
    parser.add_argument('--conf-thres', type=float, default=0.5, help='confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
    parser.add_argument('--no-kp-dets', action='store_true', help='do not use keypoint objects')
    parser.add_argument('--conf-thres-kp', type=float, default=0.5)
    parser.add_argument('--conf-thres-kp-person', type=float, default=0.2)
    parser.add_argument('--iou-thres-kp', type=float, default=0.45)
    parser.add_argument('--overwrite-tol', type=int, default=50)
    parser.add_argument('--scales', type=float, nargs='+', default=[1])
    parser.add_argument('--flips', type=int, nargs='+', default=[-1])
    parser.add_argument('--display', action='store_true', help='display inference results')
    parser.add_argument('--fps', action='store_true', help='display fps')
    parser.add_argument('--gif', action='store_true', help='create gif')
    parser.add_argument('--start', type=int, default=68, help='start time (s)')
    parser.add_argument('--end', type=int, default=98, help='end time (s)')
    args = parser.parse_args()

    with open(args.data) as f:
        data = yaml.safe_load(f)  # model/data dict

    # add inference settings to the data dict consumed by run_nms / post_process_batch
    data['imgsz'] = args.imgsz
    data['conf_thres'] = args.conf_thres
    data['iou_thres'] = args.iou_thres
    data['use_kp_dets'] = not args.no_kp_dets
    data['conf_thres_kp'] = args.conf_thres_kp
    data['iou_thres_kp'] = args.iou_thres_kp
    data['conf_thres_kp_person'] = args.conf_thres_kp_person
    data['overwrite_tol'] = args.overwrite_tol
    data['scales'] = args.scales
    data['flips'] = [None if f == -1 else f for f in args.flips]  # -1 means no flip
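
    # If --vid is not given, fall back to the demo flash-mob video, downloading it
    # with pytube when it is not already on disk. The stream selection here is a
    # reasonable default rather than the exact stream used in the original demo.
    if not args.vid:
        if not osp.isfile(VIDEO_NAME):
            print('Downloading demo video...')
            YouTube(URL).streams.get_highest_resolution().download(filename=VIDEO_NAME)
            print('Done.')
        args.vid = VIDEO_NAME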

    device = select_device(args.device, batch_size=1)
    print('Using device: {}'.format(device))

    model = attempt_load(args.weights, map_location=device)  # load FP32 model
    half = args.half and device.type != 'cpu'  # half precision only supported on CUDA
    if half:
        model.half()
    stride = int(model.stride.max())  # model stride

    imgsz = check_img_size(args.imgsz, s=stride)  # verify imgsz is a multiple of the stride
    dataset = LoadImages(args.vid, img_size=imgsz, stride=stride, auto=True)

    if device.type != 'cpu':
        # run one dummy forward pass to initialize the model on the GPU
        model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))

    cap = dataset.cap
    cap.set(cv2.CAP_PROP_POS_MSEC, args.start * 1000)  # seek to the start time
    fps = cap.get(cv2.CAP_PROP_FPS)
    n = int(fps * (args.end - args.start))  # number of frames to process
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    gif_frames = []
    video_name = 'flash_mob_inference_{}'.format(osp.splitext(args.weights)[0])

    # write results to video unless displaying; show a progress bar unless timing FPS
    if not args.display:
        writer = cv2.VideoWriter(video_name + '.mp4',
                                 cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
        if not args.fps:
            dataset = tqdm(dataset, desc='Writing inference video', total=n)

    t0 = time_sync()
    for i, (path, img, im0, _) in enumerate(dataset):
        img = torch.from_numpy(img).to(device)
        img = img.half() if half else img.float()
        img = img / 255.0  # 0 - 255 to 0.0 - 1.0
        if len(img.shape) == 3:
            img = img[None]  # add batch dimension

        # forward pass with test-time augmentation (scales / flips), then NMS and
        # fusion of person and keypoint-object detections
        out = model(img, augment=True, kp_flip=data['kp_flip'], scales=data['scales'], flips=data['flips'])[0]
        person_dets, kp_dets = run_nms(data, out)
        bboxes, poses, _, _, _ = post_process_batch(data, img, [], [[im0.shape[:2]]], person_dets, kp_dets)

        im0_copy = im0.copy()

        # draw person boxes and pose segments on a copy, then blend with the original
        # frame for a translucent overlay
        for j, (bbox, pose) in enumerate(zip(bboxes, poses)):
            x1, y1, x2, y2 = bbox
            size = ((x2 - x1) ** 2 + (y2 - y1) ** 2) ** 0.5  # bbox diagonal (currently unused)
            cv2.rectangle(im0_copy, (int(x1), int(y1)), (int(x2), int(y2)), COLOR, thickness=2)
            for seg in data['segments'].values():
                pt1 = (int(pose[seg[0], 0]), int(pose[seg[0], 1]))
                pt2 = (int(pose[seg[1], 0]), int(pose[seg[1], 1]))
                cv2.line(im0_copy, pt1, pt2, COLOR, SEG_THICK)
        im0 = cv2.addWeighted(im0, ALPHA, im0_copy, 1 - ALPHA, gamma=0)

        # per-frame latency: the first frame is timed from t0, later frames from the
        # end of the previous iteration
        if i == 0:
            t = time_sync() - t0
        else:
            t = time_sync() - t1

        if args.fps:
            s = FPS_TEXT_SIZE
            cv2.putText(im0, '{:.1f} FPS'.format(1 / t), (5 * s, 25 * s),
                        cv2.FONT_HERSHEY_SIMPLEX, s, (255, 255, 255), thickness=2 * s)

        if args.gif:
            # downscale and convert BGR -> RGB for imageio
            gif_frames.append(cv2.resize(im0, dsize=None, fx=0.375, fy=0.375)[:, :, [2, 1, 0]])
        elif not args.display:
            writer.write(im0)
        else:
            cv2.imshow('', im0)
            cv2.waitKey(1)

        t1 = time_sync()
        if i == n - 1:  # stop after the requested clip length
            break

    cv2.destroyAllWindows()
    cap.release()
    if not args.display:
        writer.release()

    if args.gif:
        print('Saving GIF...')
        with imageio.get_writer(video_name + '.gif', mode="I", fps=fps) as gif_writer:
            for frame in tqdm(gif_frames):
                gif_writer.append_data(frame)